From 0ca086fa1ec00ecaa282f25d8b733ce3a9b909d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arnaud=20Cogolu=C3=A8gnes?= <514737+acogoluegnes@users.noreply.github.com> Date: Mon, 10 Mar 2025 14:57:01 +0100 Subject: [PATCH] Make RPC server accept or discard request message Accept it if successfully processed and replied to, discard it if failure somewhere. Fixes #163 --- src/docs/asciidoc/usage.adoc | 6 + .../client/amqp/impl/AmqpRpcServer.java | 27 ++-- .../rabbitmq/client/amqp/impl/RpcTest.java | 117 ++++++++++++++---- 3 files changed, 115 insertions(+), 35 deletions(-) diff --git a/src/docs/asciidoc/usage.adoc b/src/docs/asciidoc/usage.adoc index c0b65958d..b1f46b9a8 100644 --- a/src/docs/asciidoc/usage.adoc +++ b/src/docs/asciidoc/usage.adoc @@ -220,6 +220,12 @@ include::{test-examples}/RpcApi.java[tag=rpc-client-request] The `RpcClient#publish(Message)` method returns a `CompletableFuture` that holds the reply message. It is then possible to wait for the reply asynchronously or synchronously. +The RPC server has the following behavior: + +* when receiving a message request, it calls the processing logic (handler), extracts the correlation ID, calls a reply post-processor if defined, and sends the reply message. +* if all these operations succeed, the server accepts the request message (settles it with the `ACCEPTED` outcome). +* if any of these operations throws an exception, the server discards the request message (the message is removed from the request queue and is https://www.rabbitmq.com/client-libraries/amqp-client-libraries#message-processing-result-outcome[dead-lettered] if configured). + The RPC server uses the following defaults: * it uses the _request_ https://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties[`message-id` property] for the correlation ID. diff --git a/src/main/java/com/rabbitmq/client/amqp/impl/AmqpRpcServer.java b/src/main/java/com/rabbitmq/client/amqp/impl/AmqpRpcServer.java index 3ce2346c1..1cda9b6f8 100644 --- a/src/main/java/com/rabbitmq/client/amqp/impl/AmqpRpcServer.java +++ b/src/main/java/com/rabbitmq/client/amqp/impl/AmqpRpcServer.java @@ -89,15 +89,24 @@ public Message message(byte[] body) { .queue(builder.requestQueue()) .messageHandler( (ctx, msg) -> { - ctx.accept(); - Message reply = handler.handle(context, msg); - if (reply != null && msg.replyTo() != null) { - reply.to(msg.replyTo()); - } - Object correlationId = correlationIdExtractor.apply(msg); - reply = replyPostProcessor.apply(reply, correlationId); - if (reply != null && reply.to() != null) { - sendReply(reply); + Object correlationId = null; + try { + Message reply = handler.handle(context, msg); + if (reply != null && msg.replyTo() != null) { + reply.to(msg.replyTo()); + } + correlationId = correlationIdExtractor.apply(msg); + reply = replyPostProcessor.apply(reply, correlationId); + if (reply != null && reply.to() != null) { + sendReply(reply); + } + ctx.accept(); + } catch (Exception e) { + LOGGER.info( + "Error while processing RPC request (correlation ID {}): {}", + correlationId, + e.getMessage()); + ctx.discard(); } }) .build(); diff --git a/src/test/java/com/rabbitmq/client/amqp/impl/RpcTest.java b/src/test/java/com/rabbitmq/client/amqp/impl/RpcTest.java index f3f658102..720ab6e39 100644 --- a/src/test/java/com/rabbitmq/client/amqp/impl/RpcTest.java +++ b/src/test/java/com/rabbitmq/client/amqp/impl/RpcTest.java @@ -17,13 +17,18 @@ // info@rabbitmq.com. package com.rabbitmq.client.amqp.impl; -import static com.rabbitmq.client.amqp.impl.TestUtils.waitAtMost; +import static com.rabbitmq.client.amqp.Management.ExchangeType.FANOUT; +import static com.rabbitmq.client.amqp.impl.Assertions.assertThat; +import static com.rabbitmq.client.amqp.impl.TestUtils.*; import static java.nio.charset.StandardCharsets.UTF_8; import static java.time.Duration.ofMillis; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; import com.rabbitmq.client.amqp.*; +import com.rabbitmq.client.amqp.impl.TestUtils.Sync; import java.time.Duration; import java.util.Random; import java.util.UUID; @@ -75,7 +80,7 @@ void rpcWithDefaults() { serverConnection.rpcServerBuilder().requestQueue(requestQueue).handler(HANDLER).build(); int requestCount = 100; - CountDownLatch latch = new CountDownLatch(requestCount); + Sync sync = sync(requestCount); IntStream.range(0, requestCount) .forEach( ignored -> @@ -86,10 +91,10 @@ void rpcWithDefaults() { rpcClient.publish(rpcClient.message(request.getBytes(UTF_8))); Message response = responseFuture.get(10, TimeUnit.SECONDS); assertThat(response.body()).asString(UTF_8).isEqualTo(process(request)); - latch.countDown(); + sync.down(); return null; })); - Assertions.assertThat(latch).completes(); + assertThat(sync).completes(); } } @@ -136,7 +141,7 @@ void rpcWithCustomSettings() { .build(); int requestCount = 100; - CountDownLatch latch = new CountDownLatch(requestCount); + Sync sync = sync(requestCount); IntStream.range(0, requestCount) .forEach( ignored -> @@ -146,11 +151,13 @@ void rpcWithCustomSettings() { CompletableFuture responseFuture = rpcClient.publish(rpcClient.message(request.getBytes(UTF_8))); Message response = responseFuture.get(10, TimeUnit.SECONDS); - assertThat(response.body()).asString(UTF_8).isEqualTo(process(request)); - latch.countDown(); + org.assertj.core.api.Assertions.assertThat(response.body()) + .asString(UTF_8) + .isEqualTo(process(request)); + sync.down(); return null; })); - Assertions.assertThat(latch).completes(); + assertThat(sync).completes(); } } @@ -184,7 +191,7 @@ void rpcUseCorrelationIdRequestProperty() { .build(); int requestCount = 100; - CountDownLatch latch = new CountDownLatch(requestCount); + Sync sync = sync(requestCount); IntStream.range(0, requestCount) .forEach( ignored -> @@ -195,10 +202,10 @@ void rpcUseCorrelationIdRequestProperty() { rpcClient.publish(rpcClient.message(request.getBytes(UTF_8))); Message response = responseFuture.get(10, TimeUnit.SECONDS); assertThat(response.body()).asString(UTF_8).isEqualTo(process(request)); - latch.countDown(); + sync.down(); return null; })); - Assertions.assertThat(latch).completes(); + assertThat(sync).completes(); } } @@ -207,16 +214,16 @@ void rpcUseCorrelationIdRequestProperty() { void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources) throws ExecutionException, InterruptedException, TimeoutException { String clientConnectionName = UUID.randomUUID().toString(); - CountDownLatch clientConnectionLatch = new CountDownLatch(1); + Sync clientConnectionSync = sync(); String serverConnectionName = UUID.randomUUID().toString(); - CountDownLatch serverConnectionLatch = new CountDownLatch(1); + Sync serverConnectionSync = sync(); BackOffDelayPolicy backOffDelayPolicy = BackOffDelayPolicy.fixed(ofMillis(100)); Connection serverConnection = connectionBuilder() .name(serverConnectionName) .isolateResources(isolateResources) - .listeners(recoveredListener(serverConnectionLatch)) + .listeners(recoveredListener(serverConnectionSync)) .recovery() .backOffDelayPolicy(backOffDelayPolicy) .connectionBuilder() @@ -226,7 +233,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources) connectionBuilder() .name(clientConnectionName) .isolateResources(isolateResources) - .listeners(recoveredListener(clientConnectionLatch)) + .listeners(recoveredListener(clientConnectionSync)) .recovery() .backOffDelayPolicy(backOffDelayPolicy) .connectionBuilder() @@ -254,7 +261,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources) } catch (AmqpException e) { // OK } - Assertions.assertThat(clientConnectionLatch).completes(); + assertThat(clientConnectionSync).completes(); requestBody = request(UUID.randomUUID().toString()); response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID())); assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody)); @@ -263,7 +270,7 @@ void rpcShouldRecoverAfterConnectionIsClosed(boolean isolateResources) requestBody = request(UUID.randomUUID().toString()); response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID())); assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody)); - Assertions.assertThat(serverConnectionLatch).completes(); + assertThat(serverConnectionSync).completes(); requestBody = request(UUID.randomUUID().toString()); response = rpcClient.publish(rpcClient.message(requestBody).messageId(UUID.randomUUID())); assertThat(response.get(10, TimeUnit.SECONDS).body()).isEqualTo(process(requestBody)); @@ -308,7 +315,7 @@ void poisonRequestsShouldTimeout() { int requestCount = 100; AtomicInteger expectedPoisonCount = new AtomicInteger(); AtomicInteger timedOutRequestCount = new AtomicInteger(); - CountDownLatch latch = new CountDownLatch(requestCount); + Sync sync = sync(requestCount); Random random = new Random(); IntStream.range(0, requestCount) .forEach( @@ -330,18 +337,18 @@ void poisonRequestsShouldTimeout() { if (ex != null) { timedOutRequestCount.incrementAndGet(); } - latch.countDown(); + sync.down(); return null; }); }); }); - Assertions.assertThat(latch).completes(); + assertThat(sync).completes(); assertThat(timedOutRequestCount).hasPositiveValue().hasValue(expectedPoisonCount.get()); } } @Test - void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws Exception { + void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() { try (Connection clientConnection = environment.connectionBuilder().build(); Connection serverConnection = environment.connectionBuilder().build()) { @@ -375,7 +382,7 @@ void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws E AtomicInteger timedOutRequestCount = new AtomicInteger(); AtomicInteger completedRequestCount = new AtomicInteger(); Random random = new Random(); - CountDownLatch allRequestSubmitted = new CountDownLatch(requestCount); + Sync allRequestSubmitted = sync(requestCount); IntStream.range(0, requestCount) .forEach( ignored -> { @@ -401,9 +408,9 @@ void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws E return null; }); }); - allRequestSubmitted.countDown(); + allRequestSubmitted.down(); }); - Assertions.assertThat(allRequestSubmitted).completes(); + assertThat(allRequestSubmitted).completes(); waitAtMost(() -> completedRequestCount.get() == requestCount - expectedPoisonCount.get()); assertThat(timedOutRequestCount).hasValue(0); rpcClient.close(); @@ -411,6 +418,64 @@ void outstandingRequestsShouldCompleteExceptionallyOnRpcClientClosing() throws E } } + @Test + void errorDuringProcessingShouldDiscardMessageAndDeadLetterIfSet(TestInfo info) + throws ExecutionException, InterruptedException, TimeoutException { + try (Connection clientConnection = environment.connectionBuilder().build(); + Connection serverConnection = environment.connectionBuilder().build()) { + + String dlx = name(info); + String dlq = name(info); + Management management = serverConnection.management(); + management.exchange(dlx).type(FANOUT).autoDelete(true).declare(); + management.queue(dlq).exclusive(true).declare(); + management.binding().sourceExchange(dlx).destinationQueue(dlq).bind(); + + String requestQueue = + management.queue().exclusive(true).deadLetterExchange(dlx).declare().name(); + + Duration requestTimeout = Duration.ofSeconds(1); + RpcClient rpcClient = + clientConnection + .rpcClientBuilder() + .requestTimeout(requestTimeout) + .requestAddress() + .queue(requestQueue) + .rpcClient() + .build(); + + serverConnection + .rpcServerBuilder() + .requestQueue(requestQueue) + .handler( + (ctx, request) -> { + String body = new String(request.body(), UTF_8); + if (body.contains("poison")) { + throw new RuntimeException("Poison message"); + } + return HANDLER.handle(ctx, request); + }) + .build(); + + String request = UUID.randomUUID().toString(); + CompletableFuture responseFuture = + rpcClient.publish(rpcClient.message(request.getBytes(UTF_8))); + Message response = responseFuture.get(10, TimeUnit.SECONDS); + assertThat(response.body()).asString(UTF_8).isEqualTo(process(request)); + + assertThat(management.queueInfo(dlq)).isEmpty(); + + request = "poison"; + CompletableFuture poisonFuture = + rpcClient.publish(rpcClient.message(request.getBytes(UTF_8))); + waitAtMost(() -> management.queueInfo(dlq).messageCount() == 1); + assertThatThrownBy( + () -> poisonFuture.get(requestTimeout.multipliedBy(3).toMillis(), MILLISECONDS)) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(AmqpException.class); + } + } + private static AmqpConnectionBuilder connectionBuilder() { return (AmqpConnectionBuilder) environment.connectionBuilder(); } @@ -427,11 +492,11 @@ private static byte[] process(byte[] in) { return process(new String(in, UTF_8)).getBytes(UTF_8); } - private static Resource.StateListener recoveredListener(CountDownLatch latch) { + private static Resource.StateListener recoveredListener(Sync sync) { return context -> { if (context.previousState() == Resource.State.RECOVERING && context.currentState() == Resource.State.OPEN) { - latch.countDown(); + sync.down(); } }; }