Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/docs/asciidoc/usage.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ include::{test-examples}/RpcApi.java[tag=rpc-client-request]
The `RpcClient#publish(Message)` method returns a `CompletableFuture<Message>` that holds the reply message.
It is then possible to wait for the reply asynchronously or synchronously.

The RPC server has the following behavior:

* when receiving a message request, it calls the processing logic (handler), extracts the correlation ID, calls a reply post-processor if defined, and sends the reply message.
* if all these operations succeed, the server accepts the request message (settles it with the `ACCEPTED` outcome).
* if any of these operations throws an exception, the server discards the request message (the message is removed from the request queue and is https://www.rabbitmq.com/client-libraries/amqp-client-libraries#message-processing-result-outcome[dead-lettered] if configured).

The RPC server uses the following defaults:

* it uses the _request_ https://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties[`message-id` property] for the correlation ID.
Expand Down
27 changes: 18 additions & 9 deletions src/main/java/com/rabbitmq/client/amqp/impl/AmqpRpcServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,24 @@ public Message message(byte[] body) {
.queue(builder.requestQueue())
.messageHandler(
(ctx, msg) -> {
ctx.accept();
Message reply = handler.handle(context, msg);
if (reply != null && msg.replyTo() != null) {
reply.to(msg.replyTo());
}
Object correlationId = correlationIdExtractor.apply(msg);
reply = replyPostProcessor.apply(reply, correlationId);
if (reply != null && reply.to() != null) {
sendReply(reply);
Object correlationId = null;
try {
Message reply = handler.handle(context, msg);
if (reply != null && msg.replyTo() != null) {
reply.to(msg.replyTo());
}
correlationId = correlationIdExtractor.apply(msg);
reply = replyPostProcessor.apply(reply, correlationId);
if (reply != null && reply.to() != null) {
sendReply(reply);
}
ctx.accept();
} catch (Exception e) {
LOGGER.info(
"Error while processing RPC request (correlation ID {}): {}",
correlationId,
e.getMessage());
ctx.discard();
}
})
.build();
Expand Down
117 changes: 91 additions & 26 deletions src/test/java/com/rabbitmq/client/amqp/impl/RpcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
// info@rabbitmq.com.
package com.rabbitmq.client.amqp.impl;

import static com.rabbitmq.client.amqp.impl.TestUtils.waitAtMost;
import static com.rabbitmq.client.amqp.Management.ExchangeType.FANOUT;
import static com.rabbitmq.client.amqp.impl.Assertions.assertThat;
import static com.rabbitmq.client.amqp.impl.TestUtils.*;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.Duration.ofMillis;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;

import com.rabbitmq.client.amqp.*;
import com.rabbitmq.client.amqp.impl.TestUtils.Sync;
import java.time.Duration;
import java.util.Random;
import java.util.UUID;
Expand Down Expand Up @@ -75,7 +80,7 @@ void rpcWithDefaults() {
serverConnection.rpcServerBuilder().requestQueue(requestQueue).handler(HANDLER).build();

int requestCount = 100;
CountDownLatch latch = new CountDownLatch(requestCount);
Sync sync = sync(requestCount);
IntStream.range(0, requestCount)
.forEach(
ignored ->
Expand All @@ -86,10 +91,10 @@ void rpcWithDefaults() {
rpcClient.publish(rpcClient.message(request.getBytes(UTF_8)));
Message response = responseFuture.get(10, TimeUnit.SECONDS);
assertThat(response.body()).asString(UTF_8).isEqualTo(process(request));
latch.countDown();
sync.down();
return null;
}));
Assertions.assertThat(latch).completes();
assertThat(sync).completes();
}
}

Expand Down Expand Up @@ -136,7 +141,7 @@ void rpcWithCustomSettings() {
.build();

int requestCount = 100;
CountDownLatch latch = new CountDownLatch(requestCount);
Sync sync = sync(requestCount);
IntStream.range(0, requestCount)
.forEach(
ignored ->
Expand All @@ -146,11 +151,13 @@ void rpcWithCustomSettings() {
CompletableFuture<Message> responseFuture =
rpcClient.publish(rpcClient.message(request.getBytes(UTF_8)));
Message response = responseFuture.get(10, TimeUnit.SECONDS);
assertThat(response.body()).asString(UTF_8).isEqualTo(process(request));
latch.countDown();
org.assertj.core.api.Assertions.assertThat(response.body())
.asString(UTF_8)
.isEqualTo(process(request));
sync.down();
return null;
}));
Assertions.assertThat(latch).completes();
assertThat(sync).completes();
}
}

Expand Down Expand Up @@ -184,7 +191,7 @@ void rpcUseCorrelationIdRequestProperty() {
.build();

int requestCount = 100;
CountDownLatch latch = new CountDownLatch(requestCount);
Sync sync = sync(requestCount);
IntStream.range(0, requestCount)
.forEach(
ignored ->
Expand All @@ -195,10 +202,10 @@ void rpcUseCorrelationIdRequestProperty() {
rpcClient.publish(rpcClient.message(request.getBytes(UTF_8)));
Message response = responseFuture.get(10, TimeUnit.SECONDS);
assertThat(response.body()).asString(UTF_8).isEqualTo(process(request));
latch.countDown();
sync.down();
return null;
}));
Assertions.assertThat(latch).completes();
assertThat(sync).completes();
}
}

Expand All @@ -207,16 +214,16 @@ void rpcUseCorrelationIdRequestProperty() {
void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources)
throws ExecutionException, InterruptedException, TimeoutException {
String clientConnectionName = UUID.randomUUID().toString();
CountDownLatch clientConnectionLatch = new CountDownLatch(1);
Sync clientConnectionSync = sync();
String serverConnectionName = UUID.randomUUID().toString();
CountDownLatch serverConnectionLatch = new CountDownLatch(1);
Sync serverConnectionSync = sync();

BackOffDelayPolicy backOffDelayPolicy = BackOffDelayPolicy.fixed(ofMillis(100));
Connection serverConnection =
connectionBuilder()
.name(serverConnectionName)
.isolateResources(isolateResources)
.listeners(recoveredListener(serverConnectionLatch))
.listeners(recoveredListener(serverConnectionSync))
.recovery()
.backOffDelayPolicy(backOffDelayPolicy)
.connectionBuilder()
Expand All @@ -226,7 +233,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources)
connectionBuilder()
.name(clientConnectionName)
.isolateResources(isolateResources)
.listeners(recoveredListener(clientConnectionLatch))
.listeners(recoveredListener(clientConnectionSync))
.recovery()
.backOffDelayPolicy(backOffDelayPolicy)
.connectionBuilder()
Expand Down Expand Up @@ -254,7 +261,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources)
} catch (AmqpException e) {
// OK
}
Assertions.assertThat(clientConnectionLatch).completes();
assertThat(clientConnectionSync).completes();
requestBody = request(UUID.randomUUID().toString());
response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID()));
assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody));
Expand All @@ -263,7 +270,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources)
requestBody = request(UUID.randomUUID().toString());
response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID()));
assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody));
Assertions.assertThat(serverConnectionLatch).completes();
assertThat(serverConnectionSync).completes();
requestBody = request(UUID.randomUUID().toString());
response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID()));
assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody));
Expand Down Expand Up @@ -308,7 +315,7 @@ void poisonRequestsShouldTimeout() {
int requestCount = 100;
AtomicInteger expectedPoisonCount = new AtomicInteger();
AtomicInteger timedOutRequestCount = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(requestCount);
Sync sync = sync(requestCount);
Random random = new Random();
IntStream.range(0, requestCount)
.forEach(
Expand All @@ -330,18 +337,18 @@ void poisonRequestsShouldTimeout() {
if (ex != null) {
timedOutRequestCount.incrementAndGet();
}
latch.countDown();
sync.down();
return null;
});
});
});
Assertions.assertThat(latch).completes();
assertThat(sync).completes();
assertThat(timedOutRequestCount).hasPositiveValue().hasValue(expectedPoisonCount.get());
}
}

@Test
void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws Exception {
void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() {
try (Connection clientConnection = environment.connectionBuilder().build();
Connection serverConnection = environment.connectionBuilder().build()) {

Expand Down Expand Up @@ -375,7 +382,7 @@ void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws E
AtomicInteger timedOutRequestCount = new AtomicInteger();
AtomicInteger completedRequestCount = new AtomicInteger();
Random random = new Random();
CountDownLatch allRequestSubmitted = new CountDownLatch(requestCount);
Sync allRequestSubmitted = sync(requestCount);
IntStream.range(0, requestCount)
.forEach(
ignored -> {
Expand All @@ -401,16 +408,74 @@ void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws E
return null;
});
});
allRequestSubmitted.countDown();
allRequestSubmitted.down();
});
Assertions.assertThat(allRequestSubmitted).completes();
assertThat(allRequestSubmitted).completes();
waitAtMost(() -> completedRequestCount.get() == requestCount - expectedPoisonCount.get());
assertThat(timedOutRequestCount).hasValue(0);
rpcClient.close();
assertThat(timedOutRequestCount).hasPositiveValue().hasValue(expectedPoisonCount.get());
}
}

@Test
void errorDuringProcessingShouldDiscardMessageAndDeadLetterIfSet(TestInfo info)
throws ExecutionException, InterruptedException, TimeoutException {
try (Connection clientConnection = environment.connectionBuilder().build();
Connection serverConnection = environment.connectionBuilder().build()) {

String dlx = name(info);
String dlq = name(info);
Management management = serverConnection.management();
management.exchange(dlx).type(FANOUT).autoDelete(true).declare();
management.queue(dlq).exclusive(true).declare();
management.binding().sourceExchange(dlx).destinationQueue(dlq).bind();

String requestQueue =
management.queue().exclusive(true).deadLetterExchange(dlx).declare().name();

Duration requestTimeout = Duration.ofSeconds(1);
RpcClient rpcClient =
clientConnection
.rpcClientBuilder()
.requestTimeout(requestTimeout)
.requestAddress()
.queue(requestQueue)
.rpcClient()
.build();

serverConnection
.rpcServerBuilder()
.requestQueue(requestQueue)
.handler(
(ctx, request) -> {
String body = new String(request.body(), UTF_8);
if (body.contains("poison")) {
throw new RuntimeException("Poison message");
}
return HANDLER.handle(ctx, request);
})
.build();

String request = UUID.randomUUID().toString();
CompletableFuture<Message> responseFuture =
rpcClient.publish(rpcClient.message(request.getBytes(UTF_8)));
Message response = responseFuture.get(10, TimeUnit.SECONDS);
assertThat(response.body()).asString(UTF_8).isEqualTo(process(request));

assertThat(management.queueInfo(dlq)).isEmpty();

request = "poison";
CompletableFuture<Message> poisonFuture =
rpcClient.publish(rpcClient.message(request.getBytes(UTF_8)));
waitAtMost(() -> management.queueInfo(dlq).messageCount() == 1);
assertThatThrownBy(
() -> poisonFuture.get(requestTimeout.multipliedBy(3).toMillis(), MILLISECONDS))
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(AmqpException.class);
}
}

private static AmqpConnectionBuilder connectionBuilder() {
return (AmqpConnectionBuilder) environment.connectionBuilder();
}
Expand All @@ -427,11 +492,11 @@ private static byte[] process(byte[] in) {
return process(new String(in, UTF_8)).getBytes(UTF_8);
}

private static Resource.StateListener recoveredListener(CountDownLatch latch) {
private static Resource.StateListener recoveredListener(Sync sync) {
return context -> {
if (context.previousState() == Resource.State.RECOVERING
&& context.currentState() == Resource.State.OPEN) {
latch.countDown();
sync.down();
}
};
}
Expand Down