Skip to content

Commit 77951bc

Browse files
committed
Handle credentials refresh error
[#167029587]
1 parent 7b50067 commit 77951bc

File tree

3 files changed

+190
-10
lines changed

3 files changed

+190
-10
lines changed

src/main/java/com/rabbitmq/client/impl/AMQConnection.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,7 @@ public AMQConnection(ConnectionParams params, FrameHandler frameHandler, Metrics
243243

244244
this.credentialsRefreshService = params.getCredentialsRefreshService();
245245

246-
this._channel0 = new AMQChannel(this, 0) {
247-
@Override public boolean processAsync(Command c) throws IOException {
248-
return getConnection().processControlCommand(c);
249-
}
250-
};
246+
this._channel0 = createChannel0();
251247

252248
this._channelManager = null;
253249

@@ -262,6 +258,14 @@ public AMQConnection(ConnectionParams params, FrameHandler frameHandler, Metrics
262258
this.workPoolTimeout = params.getWorkPoolTimeout();
263259
}
264260

261+
AMQChannel createChannel0() {
262+
return new AMQChannel(this, 0) {
263+
@Override public boolean processAsync(Command c) throws IOException {
264+
return getConnection().processControlCommand(c);
265+
}
266+
};
267+
}
268+
265269
private void initializeConsumerWorkService() {
266270
this._workService = new ConsumerWorkService(consumerWorkServiceExecutor, threadFactory, workPoolTimeout, shutdownTimeout);
267271
}
@@ -438,7 +442,12 @@ public void start()
438442
AMQImpl.Connection.UpdateSecret updateSecret = new AMQImpl.Connection.UpdateSecret(
439443
LongStringHelper.asLongString(refreshedPassword), "Refresh scheduled by client"
440444
);
441-
_channel0.rpc(updateSecret);
445+
try {
446+
_channel0.rpc(updateSecret);
447+
} catch (ShutdownSignalException e) {
448+
LOGGER.warn("Error while trying to update secret: {}. Connection has been closed.", e.getMessage());
449+
return false;
450+
}
442451
return true;
443452
});
444453

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright (c) 2019 Pivotal Software, Inc. All rights reserved.
2+
//
3+
// This software, the RabbitMQ Java client library, is triple-licensed under the
4+
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
5+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
6+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
7+
// please see LICENSE-APACHE2.
8+
//
9+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
10+
// either express or implied. See the LICENSE file for specific language governing
11+
// rights and limitations of this software.
12+
//
13+
// If you have any questions regarding licensing, please contact us at
14+
// info@rabbitmq.com.
15+
16+
package com.rabbitmq.client.impl;
17+
18+
import com.rabbitmq.client.Method;
19+
import com.rabbitmq.client.*;
20+
import com.rabbitmq.client.test.TestUtils;
21+
import org.junit.ClassRule;
22+
import org.junit.Test;
23+
import org.junit.rules.TestRule;
24+
import org.junit.runner.RunWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.MockitoJUnitRunner;
27+
28+
import java.io.IOException;
29+
import java.time.Duration;
30+
import java.util.UUID;
31+
import java.util.concurrent.Callable;
32+
import java.util.concurrent.CountDownLatch;
33+
import java.util.concurrent.TimeUnit;
34+
import java.util.concurrent.atomic.AtomicReference;
35+
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.mockito.ArgumentMatchers.any;
38+
import static org.mockito.ArgumentMatchers.eq;
39+
import static org.mockito.Mockito.*;
40+
41+
@RunWith(MockitoJUnitRunner.class)
42+
public class AMQConnectionRefreshCredentialsTest {
43+
44+
@ClassRule
45+
public static TestRule brokerVersionTestRule = TestUtils.atLeast38();
46+
47+
@Mock
48+
CredentialsProvider credentialsProvider;
49+
50+
@Mock
51+
CredentialsRefreshService refreshService;
52+
53+
private static ConnectionFactory connectionFactoryThatSendsGarbageAfterUpdateSecret() {
54+
ConnectionFactory cf = new ConnectionFactory() {
55+
@Override
56+
protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler, MetricsCollector metricsCollector) {
57+
return new AMQConnection(params, frameHandler, metricsCollector) {
58+
59+
@Override
60+
AMQChannel createChannel0() {
61+
return new AMQChannel(this, 0) {
62+
@Override
63+
public boolean processAsync(Command c) throws IOException {
64+
return getConnection().processControlCommand(c);
65+
}
66+
67+
@Override
68+
public AMQCommand rpc(Method m) throws IOException, ShutdownSignalException {
69+
if (m instanceof AMQImpl.Connection.UpdateSecret) {
70+
super.rpc(m);
71+
return super.rpc(new AMQImpl.Connection.UpdateSecret(LongStringHelper.asLongString(""), "Refresh scheduled by client") {
72+
@Override
73+
public int protocolMethodId() {
74+
return 255;
75+
}
76+
});
77+
} else {
78+
return super.rpc(m);
79+
}
80+
81+
}
82+
};
83+
84+
}
85+
};
86+
}
87+
};
88+
cf.setAutomaticRecoveryEnabled(false);
89+
if (TestUtils.USE_NIO) {
90+
cf.useNio();
91+
}
92+
return cf;
93+
}
94+
95+
@Test
96+
@SuppressWarnings("unchecked")
97+
public void connectionIsUnregisteredFromRefreshServiceWhenClosed() throws Exception {
98+
when(credentialsProvider.getUsername()).thenReturn("guest");
99+
when(credentialsProvider.getPassword()).thenReturn("guest");
100+
when(credentialsProvider.getTimeBeforeExpiration()).thenReturn(Duration.ofSeconds(10));
101+
102+
ConnectionFactory cf = TestUtils.connectionFactory();
103+
cf.setCredentialsProvider(credentialsProvider);
104+
105+
String registrationId = UUID.randomUUID().toString();
106+
CountDownLatch unregisteredLatch = new CountDownLatch(1);
107+
108+
AtomicReference<Callable<Boolean>> refreshTokenCallable = new AtomicReference<>();
109+
when(refreshService.register(eq(credentialsProvider), any(Callable.class))).thenAnswer(invocation -> {
110+
refreshTokenCallable.set(invocation.getArgument(1));
111+
return registrationId;
112+
});
113+
doAnswer(invocation -> {
114+
unregisteredLatch.countDown();
115+
return null;
116+
}).when(refreshService).unregister(credentialsProvider, registrationId);
117+
118+
cf.setCredentialsRefreshService(refreshService);
119+
120+
verify(refreshService, never()).register(any(CredentialsProvider.class), any(Callable.class));
121+
try (Connection c = cf.newConnection()) {
122+
verify(refreshService, times(1)).register(eq(credentialsProvider), any(Callable.class));
123+
Channel ch = c.createChannel();
124+
String queue = ch.queueDeclare().getQueue();
125+
TestUtils.sendAndConsumeMessage("", queue, queue, c);
126+
verify(refreshService, never()).unregister(any(CredentialsProvider.class), anyString());
127+
// calling refresh
128+
assertThat(refreshTokenCallable.get().call()).isTrue();
129+
}
130+
verify(refreshService, times(1)).register(eq(credentialsProvider), any(Callable.class));
131+
assertThat(unregisteredLatch.await(5, TimeUnit.SECONDS)).isTrue();
132+
verify(refreshService, times(1)).unregister(credentialsProvider, registrationId);
133+
}
134+
135+
@Test
136+
@SuppressWarnings("unchecked")
137+
public void connectionIsUnregisteredFromRefreshServiceIfUpdateSecretFails() throws Exception {
138+
when(credentialsProvider.getUsername()).thenReturn("guest");
139+
when(credentialsProvider.getPassword()).thenReturn("guest");
140+
when(credentialsProvider.getTimeBeforeExpiration()).thenReturn(Duration.ofSeconds(10));
141+
142+
ConnectionFactory cf = connectionFactoryThatSendsGarbageAfterUpdateSecret();
143+
cf.setCredentialsProvider(credentialsProvider);
144+
145+
String registrationId = UUID.randomUUID().toString();
146+
CountDownLatch unregisteredLatch = new CountDownLatch(1);
147+
AtomicReference<Callable<Boolean>> refreshTokenCallable = new AtomicReference<>();
148+
when(refreshService.register(eq(credentialsProvider), any(Callable.class))).thenAnswer(invocation -> {
149+
refreshTokenCallable.set(invocation.getArgument(1));
150+
return registrationId;
151+
});
152+
doAnswer(invocation -> {
153+
unregisteredLatch.countDown();
154+
return null;
155+
}).when(refreshService).unregister(credentialsProvider, registrationId);
156+
157+
cf.setCredentialsRefreshService(refreshService);
158+
159+
Connection c = cf.newConnection();
160+
verify(refreshService, times(1)).register(eq(credentialsProvider), any(Callable.class));
161+
Channel ch = c.createChannel();
162+
String queue = ch.queueDeclare().getQueue();
163+
TestUtils.sendAndConsumeMessage("", queue, queue, c);
164+
verify(refreshService, never()).unregister(any(CredentialsProvider.class), anyString());
165+
166+
verify(refreshService, never()).unregister(any(CredentialsProvider.class), anyString());
167+
// calling refresh, this sends garbage and should make the broker close the connection
168+
assertThat(refreshTokenCallable.get().call()).isFalse();
169+
assertThat(unregisteredLatch.await(5, TimeUnit.SECONDS)).isTrue();
170+
verify(refreshService, times(1)).unregister(credentialsProvider, registrationId);
171+
assertThat(c.isOpen()).isFalse();
172+
}
173+
}

src/test/java/com/rabbitmq/client/test/ClientTests.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
package com.rabbitmq.client.test;
1818

1919
import com.rabbitmq.client.JacksonJsonRpcTest;
20-
import com.rabbitmq.client.impl.DefaultCredentialsRefreshServiceTest;
21-
import com.rabbitmq.client.impl.OAuth2ClientCredentialsGrantCredentialsProviderTest;
22-
import com.rabbitmq.client.impl.RefreshProtectedCredentialsProviderTest;
23-
import com.rabbitmq.client.impl.ValueWriterTest;
20+
import com.rabbitmq.client.impl.*;
2421
import com.rabbitmq.utility.IntAllocatorTests;
2522

2623
import org.junit.runner.RunWith;
@@ -78,6 +75,7 @@
7875
DefaultCredentialsRefreshServiceTest.class,
7976
OAuth2ClientCredentialsGrantCredentialsProviderTest.class,
8077
RefreshCredentialsTest.class,
78+
AMQConnectionRefreshCredentialsTest.class,
8179
ValueWriterTest.class
8280
})
8381
public class ClientTests {

0 commit comments

Comments
 (0)