diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 137060595d..70f70c45c1 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -96,6 +96,7 @@ public final class GrpcCallContext implements ApiCallContext { private final ImmutableMap> extraHeaders; private final ApiCallContextOptions options; private final EndpointContext endpointContext; + private final boolean isDirectPath; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -111,7 +112,8 @@ public static GrpcCallContext createDefault() { ApiCallContextOptions.getDefaultOptions(), null, null, - null); + null, + false); } /** Returns an instance with the given channel and {@link CallOptions}. */ @@ -128,7 +130,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { ApiCallContextOptions.getDefaultOptions(), null, null, - null); + null, + false); } private GrpcCallContext( @@ -143,10 +146,14 @@ private GrpcCallContext( ApiCallContextOptions options, @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes, - @Nullable EndpointContext endpointContext) { + @Nullable EndpointContext endpointContext, + boolean isDirectPath) { this.channel = channel; this.credentials = credentials; - this.callOptions = Preconditions.checkNotNull(callOptions); + Preconditions.checkNotNull(callOptions); + // CallCredentials is stripped from CallOptions because CallCredentials are attached + // to ChannelCredentials in DirectPath flows. Adding it again would duplicate the headers. + this.callOptions = isDirectPath ? callOptions.withCallCredentials(null) : callOptions; this.timeout = timeout; this.streamWaitTimeout = streamWaitTimeout; this.streamIdleTimeout = streamIdleTimeout; @@ -159,6 +166,7 @@ private GrpcCallContext( // a valid EndpointContext with user configurations after the client has been initialized. this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; + this.isDirectPath = isDirectPath; } /** @@ -199,7 +207,8 @@ public GrpcCallContext withCredentials(Credentials newCredentials) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @Override @@ -210,7 +219,20 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { "Expected GrpcTransportChannel, got " + inputChannel.getClass().getName()); } GrpcTransportChannel transportChannel = (GrpcTransportChannel) inputChannel; - return withChannel(transportChannel.getChannel()); + return new GrpcCallContext( + transportChannel.getChannel(), + credentials, + callOptions, + timeout, + streamWaitTimeout, + streamIdleTimeout, + channelAffinity, + extraHeaders, + options, + retrySettings, + retryableCodes, + endpointContext, + transportChannel.isDirectPath()); } @Override @@ -228,7 +250,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -262,7 +285,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -310,7 +334,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration( options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } /** @@ -344,7 +369,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration( options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") @@ -361,7 +387,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -382,7 +409,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @Override @@ -404,7 +432,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @Override @@ -426,7 +455,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } @Override @@ -456,6 +486,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newDeadline = callOptions.getDeadline(); } + boolean newIsDirectPath = grpcCallContext.isDirectPath; + CallCredentials newCallCredentials = grpcCallContext.callOptions.getCredentials(); if (newCallCredentials == null) { newCallCredentials = callOptions.getCredentials(); @@ -525,7 +557,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newOptions, newRetrySettings, newRetryableCodes, - endpointContext); + endpointContext, + newIsDirectPath); } /** The {@link Channel} set on this context. */ @@ -588,7 +621,11 @@ public Map> getExtraHeaders() { return extraHeaders; } - /** Returns a new instance with the channel set to the given channel. */ + /** + * This method is obsolete. Use {@link #withTransportChannel()} instead. Returns a new instance + * with the channel set to the given channel. + */ + @ObsoleteApi("Use withTransportChannel() instead") public GrpcCallContext withChannel(Channel newChannel) { return new GrpcCallContext( newChannel, @@ -602,7 +639,8 @@ public GrpcCallContext withChannel(Channel newChannel) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } /** Returns a new instance with the call options set to the given call options. */ @@ -619,7 +657,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { options, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -663,7 +702,8 @@ public GrpcCallContext withOption(Key key, T value) { newOptions, retrySettings, retryableCodes, - endpointContext); + endpointContext, + isDirectPath); } /** {@inheritDoc} */ diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java index 63de03a88a..5f30ed58c9 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java @@ -42,11 +42,14 @@ import com.google.api.gax.rpc.testing.FakeTransportChannel; import com.google.api.gax.tracing.ApiTracer; import com.google.auth.Credentials; +import com.google.auth.oauth2.GoogleCredentials; import com.google.common.collect.ImmutableMap; import com.google.common.truth.Truth; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ManagedChannel; import io.grpc.Metadata.Key; +import io.grpc.auth.MoreCallCredentials; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -100,6 +103,26 @@ void testWithTransportChannelWrongType() { } } + @Test + void testWithTransportChannelIsDirectPath() { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Credentials credentials = Mockito.mock(GoogleCredentials.class); + GrpcCallContext context = GrpcCallContext.createDefault().withCredentials(credentials); + assertNotNull(context.getCallOptions().getCredentials()); + context = + context.withTransportChannel( + GrpcTransportChannel.newBuilder() + .setDirectPath(true) + .setManagedChannel(channel) + .build()); + assertNull(context.getCallOptions().getCredentials()); + + // Call credentials from the call options will be stripped. + context.withCallOptions( + CallOptions.DEFAULT.withCallCredentials(MoreCallCredentials.from(credentials))); + assertNull(context.getCallOptions().getCredentials()); + } + @Test void testMergeWrongType() { try { @@ -320,6 +343,25 @@ void testMergeWithCustomCallOptions() { .isEqualTo(ctx2.getCallOptions().getOption(key)); } + @Test + void testMergeWithIsDirectPath() { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + CallCredentials callCredentials = Mockito.mock(CallCredentials.class); + GrpcCallContext ctx1 = + GrpcCallContext.createDefault() + .withCallOptions(CallOptions.DEFAULT.withCallCredentials(callCredentials)); + GrpcCallContext ctx2 = + GrpcCallContext.createDefault() + .withTransportChannel( + GrpcTransportChannel.newBuilder() + .setDirectPath(true) + .setManagedChannel(channel) + .build()); + + GrpcCallContext merged = (GrpcCallContext) ctx1.merge(ctx2); + assertNull(merged.getCallOptions().getCredentials()); + } + @Test void testWithExtraHeaders() { Map> extraHeaders =