From 6bcc25df509830169e8d7b997d4728992f54662f Mon Sep 17 00:00:00 2001 From: Viet Nguyen Duc Date: Tue, 5 Aug 2025 08:59:36 +0700 Subject: [PATCH] [grid] External datastore Redis-backed for Session Queue Signed-off-by: Viet Nguyen Duc --- .../distributor/local/LocalDistributor.java | 66 +- .../config/NewSessionQueueOptions.java | 10 +- .../grid/sessionqueue/redis/BUILD.bazel | 30 + .../redis/RedisBackedSessionQueue.java | 851 ++++++++++++++++++ .../remote/RemoteNewSessionQueue.java | 54 +- .../grid/sessionqueue/redis/BUILD.bazel | 22 + .../redis/RedisBackedSessionQueueTest.java | 216 +++++ 7 files changed, 1230 insertions(+), 19 deletions(-) create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel create mode 100644 java/src/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueue.java create mode 100644 java/test/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel create mode 100644 java/test/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueueTest.java diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java index 835150eb94f01..ef93904b43f0d 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java @@ -38,6 +38,7 @@ import java.io.UncheckedIOException; import java.net.URI; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -248,13 +249,15 @@ public LocalDistributor( this.healthcheckInterval.toMillis(), TimeUnit.MILLISECONDS); - // if sessionRequestRetryInterval is 0, we will schedule session creation every 10 millis + // Default to 100ms if no interval is specified (was 10ms) long period = - sessionRequestRetryInterval.isZero() ? 
10 : sessionRequestRetryInterval.toMillis(); - newSessionService.scheduleAtFixedRate( + sessionRequestRetryInterval.isZero() ? 100 : sessionRequestRetryInterval.toMillis(); + + // Use scheduleWithFixedDelay instead of scheduleAtFixedRate to prevent task pileup + newSessionService.scheduleWithFixedDelay( GuardedRunnable.guard(newSessionRunnable), - sessionRequestRetryInterval.toMillis(), - period, + period, // Initial delay + period, // Subsequent delays TimeUnit.MILLISECONDS); new JMXHelper().register(this); @@ -771,12 +774,47 @@ public void close() { } private class NewSessionRunnable implements Runnable { + private long backoffMs = 100; // Start with 100ms backoff + private static final long MAX_BACKOFF_MS = 5000; // Max 5 seconds backoff + private static final long MIN_BACKOFF_MS = 100; // Min 100ms backoff + private Instant lastNodeAvailableCheck = Instant.MIN; + private boolean hadNodesLastCheck = false; @Override public void run() { Set inQueue; boolean pollQueue; + // Check if we have any available nodes + boolean hasNodes = !getAvailableNodes().isEmpty(); + + // If we had nodes before but don't now, or vice versa, reset the backoff + if (hasNodes != hadNodesLastCheck + || Duration.between(lastNodeAvailableCheck, Instant.now()).toMillis() > 5000) { + backoffMs = MIN_BACKOFF_MS; + } + + hadNodesLastCheck = hasNodes; + lastNodeAvailableCheck = Instant.now(); + + // If no nodes available, apply backoff before proceeding + if (!hasNodes) { + try { + // Add some jitter to prevent thundering herd + long jitter = (long) (Math.random() * backoffMs * 0.1); // Up to 10% jitter + Thread.sleep(backoffMs + jitter); + + // Double the backoff for next time, up to the max + backoffMs = Math.min(backoffMs * 2, MAX_BACKOFF_MS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + } else { + // Reset backoff when we have nodes + backoffMs = MIN_BACKOFF_MS; + } + if (rejectUnsupportedCaps) { inQueue = 
sessionQueue.getQueueContents().stream() @@ -801,9 +839,21 @@ public void run() { Collectors.groupingBy(ImmutableCapabilities::copyOf, Collectors.counting())); if (!stereotypes.isEmpty()) { - List matchingRequests = sessionQueue.getNextAvailable(stereotypes); - matchingRequests.forEach( - req -> sessionCreatorExecutor.execute(() -> handleNewSessionRequest(req))); + try { + List matchingRequests = sessionQueue.getNextAvailable(stereotypes); + if (!matchingRequests.isEmpty()) { + // Process requests in batch + matchingRequests.forEach( + req -> sessionCreatorExecutor.execute(() -> handleNewSessionRequest(req))); + } else if (backoffMs < MAX_BACKOFF_MS) { + // If we didn't get any requests, increase backoff slightly + backoffMs = Math.min((long) (backoffMs * 1.5), MAX_BACKOFF_MS); + } + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error processing session requests", e); + // On error, back off more aggressively + backoffMs = Math.min(backoffMs * 2, MAX_BACKOFF_MS); + } } } diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/config/NewSessionQueueOptions.java b/java/src/org/openqa/selenium/grid/sessionqueue/config/NewSessionQueueOptions.java index ef6e74e2736c5..8e2c8f27eeddb 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/config/NewSessionQueueOptions.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/config/NewSessionQueueOptions.java @@ -50,6 +50,12 @@ public NewSessionQueueOptions(Config config) { public URI getSessionQueueUri() { + BaseServerOptions serverOptions = new BaseServerOptions(config); + String scheme = + config + .get(SESSION_QUEUE_SECTION, "scheme") + .orElse((serverOptions.isSecure() || serverOptions.isSelfSigned()) ? "https" : "http"); + Optional host = config .get(SESSION_QUEUE_SECTION, "host") @@ -72,8 +78,6 @@ public URI getSessionQueueUri() { return host.get(); } - BaseServerOptions serverOptions = new BaseServerOptions(config); - String schema = (serverOptions.isSecure() || serverOptions.isSelfSigned()) ? 
"https" : "http"; Optional port = config.getInt(SESSION_QUEUE_SECTION, "port"); Optional hostname = config.get(SESSION_QUEUE_SECTION, "hostname"); @@ -82,7 +86,7 @@ public URI getSessionQueueUri() { } try { - return new URI(schema, null, hostname.get(), port.get(), "", null, null); + return new URI(scheme, null, hostname.get(), port.get(), "", null, null); } catch (URISyntaxException e) { throw new ConfigException( "Session queue server uri configured through host (%s) and port (%d) is not a valid URI", diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel b/java/src/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel new file mode 100644 index 0000000000000..d87d5854c5e15 --- /dev/null +++ b/java/src/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel @@ -0,0 +1,30 @@ +load("@rules_jvm_external//:defs.bzl", "artifact") +load("//java:defs.bzl", "java_export") +load("//java:version.bzl", "SE_VERSION") + +java_export( + name = "redis", + srcs = glob(["*.java"]), + maven_coordinates = "org.seleniumhq.selenium:selenium-session-queue-redis:%s" % SE_VERSION, + pom_template = "//java/src/org/openqa/selenium:template-pom", + tags = [ + "release-artifact", + ], + visibility = [ + "//visibility:public", + ], + exports = [ + "//java/src/org/openqa/selenium/grid", + ], + deps = [ + "//java:auto-service", + "//java/src/org/openqa/selenium/grid", + "//java/src/org/openqa/selenium/json", + "//java/src/org/openqa/selenium/redis", + "//java/src/org/openqa/selenium/remote", + artifact("com.beust:jcommander"), + artifact("com.google.guava:guava"), + artifact("io.lettuce:lettuce-core"), + artifact("org.redisson:redisson"), + ], +) diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueue.java new file mode 100644 index 0000000000000..76a47332639d1 --- /dev/null +++ 
b/java/src/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueue.java @@ -0,0 +1,851 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.grid.sessionqueue.redis; + +import static java.net.HttpURLConnection.HTTP_INTERNAL_ERROR; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.openqa.selenium.concurrent.ExecutorServices.shutdownGracefully; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import java.io.Closeable; +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import 
java.util.function.Predicate; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import org.openqa.selenium.Capabilities; +import org.openqa.selenium.SessionNotCreatedException; +import org.openqa.selenium.concurrent.GuardedRunnable; +import org.openqa.selenium.grid.config.Config; +import org.openqa.selenium.grid.data.CreateSessionResponse; +import org.openqa.selenium.grid.data.RequestId; +import org.openqa.selenium.grid.data.SessionRequest; +import org.openqa.selenium.grid.data.SessionRequestCapability; +import org.openqa.selenium.grid.data.SlotMatcher; +import org.openqa.selenium.grid.data.TraceSessionRequest; +import org.openqa.selenium.grid.distributor.config.DistributorOptions; +import org.openqa.selenium.grid.jmx.JMXHelper; +import org.openqa.selenium.grid.jmx.ManagedAttribute; +import org.openqa.selenium.grid.jmx.ManagedService; +import org.openqa.selenium.grid.log.LoggingOptions; +import org.openqa.selenium.grid.security.Secret; +import org.openqa.selenium.grid.security.SecretOptions; +import org.openqa.selenium.grid.sessionqueue.NewSessionQueue; +import org.openqa.selenium.grid.sessionqueue.config.NewSessionQueueOptions; +import org.openqa.selenium.internal.Either; +import org.openqa.selenium.internal.Require; +import org.openqa.selenium.json.Json; +import org.openqa.selenium.redis.GridRedisClient; +import org.openqa.selenium.remote.http.Contents; +import org.openqa.selenium.remote.http.HttpResponse; +import org.openqa.selenium.remote.tracing.Span; +import org.openqa.selenium.remote.tracing.TraceContext; +import org.openqa.selenium.remote.tracing.Tracer; + +/** + * A Redis-backed implementation of the list of new session requests. + * + *

<p>The lifecycle of a request can be described as: + * + * <ol> + *
  <li>User adds an item onto the queue using {@link #addToQueue(SessionRequest)}. This will + * block until the request completes in some way. + *
  <li>If the session request is completed, then {@link #complete(RequestId, Either)} must be + * called. This will ensure that {@link #addToQueue(SessionRequest)} returns. + *
  <li>If the request cannot be handled right now, call {@link #retryAddToQueue(SessionRequest)} + * to return the session request to the front of the queue. + * </ol>
+ * + * <p>There is a background thread that will reap {@link SessionRequest}s that have timed out. This + * means that a request can either complete by a listener calling {@link #complete(RequestId, + * Either)} directly, or by being reaped by the thread. + * + * <p>
Redis persistence ensures that session requests survive restarts and can be processed by + * multiple Grid instances. + */ +@ManagedService( + objectName = "org.seleniumhq.grid:type=SessionQueue,name=RedisBackedSessionQueue", + description = "Redis backed session queue") +public class RedisBackedSessionQueue extends NewSessionQueue implements Closeable { + + private static final Logger LOG = Logger.getLogger(RedisBackedSessionQueue.class.getName()); + private static final String NAME = "Redis Backed New Session Queue"; + private static final Json JSON = new Json(); + + // Redis keys + private static final String QUEUE_KEY = "selenium:session:queue"; + private static final String REQUEST_KEY_PREFIX = "selenium:session:request:"; + private static final String DATA_KEY_PREFIX = "selenium:session:data:"; + private static final String CONTEXT_KEY_PREFIX = "selenium:session:context:"; + + private final SlotMatcher slotMatcher; + private final GridRedisClient redisClient; + private final URI sessionQueueUri; + private final Duration requestTimeout; + private final Duration maximumResponseDelay; + private final int batchSize; + + // In-memory state management (mirrors LocalNewSessionQueue) + private final Map requests; + private final Map contexts; + private final Deque queue; + private final ReadWriteLock lock = new ReentrantReadWriteLock(); + + private final ScheduledExecutorService service = + Executors.newSingleThreadScheduledExecutor( + r -> { + Thread thread = new Thread(r); + thread.setDaemon(true); + thread.setName(NAME); + return thread; + }); + + public RedisBackedSessionQueue( + Tracer tracer, + SlotMatcher slotMatcher, + URI sessionQueueUri, + Duration requestTimeoutCheck, + Duration requestTimeout, + Duration maximumResponseDelay, + Secret registrationSecret, + int batchSize) { + super(tracer, registrationSecret); + + this.slotMatcher = Require.nonNull("Slot matcher", slotMatcher); + this.sessionQueueUri = Require.nonNull("Redis URI", sessionQueueUri); + 
this.redisClient = new GridRedisClient(sessionQueueUri); + + Require.nonNegative("Retry period", requestTimeoutCheck); + this.requestTimeout = Require.positive("Request timeout", requestTimeout); + this.maximumResponseDelay = Require.positive("Maximum response delay", maximumResponseDelay); + this.batchSize = Require.positive("Batch size", batchSize); + + this.requests = new ConcurrentHashMap<>(); + this.queue = new ConcurrentLinkedDeque<>(); + this.contexts = new ConcurrentHashMap<>(); + + // Restore state from Redis on startup + restoreStateFromRedis(); + + service.scheduleAtFixedRate( + GuardedRunnable.guard(this::timeoutSessions), + requestTimeoutCheck.toMillis(), + requestTimeoutCheck.toMillis(), + MILLISECONDS); + + new JMXHelper().register(this); + } + + public static NewSessionQueue create(Config config) { + LoggingOptions loggingOptions = new LoggingOptions(config); + Tracer tracer = loggingOptions.getTracer(); + + NewSessionQueueOptions newSessionQueueOptions = new NewSessionQueueOptions(config); + SecretOptions secretOptions = new SecretOptions(config); + + // Use the factory to create a SlotMatcher to avoid circular dependencies + SlotMatcher slotMatcher = new DistributorOptions(config).getSlotMatcher(); + + return new RedisBackedSessionQueue( + tracer, + slotMatcher, + newSessionQueueOptions.getSessionQueueUri(), + newSessionQueueOptions.getSessionRequestTimeoutPeriod(), + newSessionQueueOptions.getSessionRequestTimeout(), + newSessionQueueOptions.getMaximumResponseDelay(), + secretOptions.getRegistrationSecret(), + newSessionQueueOptions.getBatchSize()); + } + + /** Restores in-memory state from Redis on startup */ + private void restoreStateFromRedis() { + try { + String queueData = redisClient.get(QUEUE_KEY); + LOG.info( + "[RedisBackedSessionQueue.restoreStateFromRedis] Raw queue data from Redis: [" + + queueData + + "]"); + if (queueData != null && !queueData.isEmpty()) { + String[] requestIdArray = queueData.split(","); + for (String requestIdStr 
: requestIdArray) { + if (requestIdStr.trim().isEmpty()) continue; + try { + RequestId requestId = new RequestId(java.util.UUID.fromString(requestIdStr.trim())); + String requestKey = REQUEST_KEY_PREFIX + requestIdStr.trim(); + String dataKey = DATA_KEY_PREFIX + requestIdStr.trim(); + String requestJson = redisClient.get(requestKey); + String dataJson = redisClient.get(dataKey); + LOG.info( + "[restoreStateFromRedis] Read from Redis: requestKey=" + + requestKey + + ", requestJson=" + + requestJson); + LOG.info( + "[restoreStateFromRedis] Read from Redis: dataKey=" + + dataKey + + ", dataJson=" + + dataJson); + if (requestJson != null && dataJson != null) { + SessionRequest request = JSON.toType(requestJson, SessionRequest.class); + Data data = JSON.toType(dataJson, Data.class); // Deserialize full Data object + requests.put(requestId, data); + queue.add(request); + if (isTimedOut(Instant.now(), data)) { + failDueToTimeout(requestId); + } + } + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to restore request from Redis: " + requestIdStr, e); + cleanupRedisRequest(requestIdStr.trim()); + } + } + } + LOG.info("Restored " + requests.size() + " session requests from Redis"); + } catch (Exception e) { + LOG.log(Level.SEVERE, "Failed to restore state from Redis", e); + } + } + + /** Persists request state to Redis asynchronously */ + private void persistToRedis(SessionRequest request, Data data) { + String requestKey = REQUEST_KEY_PREFIX + request.getRequestId(); + String dataKey = DATA_KEY_PREFIX + request.getRequestId(); + String requestJson = JSON.toJson(request); + String dataJson = JSON.toJson(data); + + // Log what we're about to write + LOG.info( + "[persistToRedis] Writing to Redis: requestKey=" + + requestKey + + ", requestJson=" + + requestJson); + LOG.info("[persistToRedis] Writing to Redis: dataKey=" + dataKey + ", dataJson=" + dataJson); + + // Use mset to write both keys at once + Map keyValues = new HashMap<>(); + keyValues.put(requestKey, 
requestJson); + keyValues.put(dataKey, dataJson); + redisClient.mset(keyValues); + + // Update the queue key + StringBuilder queueBuilder = new StringBuilder(); + for (SessionRequest req : queue) { + queueBuilder.append(req.getRequestId().toString()).append(","); + } + String queueString = + queueBuilder.length() > 0 ? queueBuilder.substring(0, queueBuilder.length() - 1) : ""; + + LOG.info( + "[persistToRedis] Writing to Redis: QUEUE_KEY=" + + QUEUE_KEY + + ", queueString=" + + queueString); + + // Update queue using mset + Map queueUpdate = new HashMap<>(); + queueUpdate.put(QUEUE_KEY, queueString); + redisClient.mset(queueUpdate); + } + + /** Removes request from Redis asynchronously */ + private void removeFromRedis(RequestId requestId) { + service.execute( + () -> { + try { + String requestIdStr = requestId.toString(); + cleanupRedisRequest(requestIdStr); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to remove request from Redis: " + requestId, e); + } + }); + } + + private void cleanupRedisRequest(String requestIdStr) { + try { + String requestKey = REQUEST_KEY_PREFIX + requestIdStr; + String dataKey = DATA_KEY_PREFIX + requestIdStr; + String contextKey = CONTEXT_KEY_PREFIX + requestIdStr; + + // Remove all associated data + redisClient.del(requestKey, dataKey, contextKey); + + // Remove from queue list + String currentQueue = redisClient.get(QUEUE_KEY); + if (currentQueue != null && !currentQueue.isEmpty()) { + String[] requestIds = currentQueue.split(","); + StringBuilder newQueue = new StringBuilder(); + for (String id : requestIds) { + if (!id.trim().equals(requestIdStr)) { + if (newQueue.length() > 0) { + newQueue.append(","); + } + newQueue.append(id.trim()); + } + } + Map queueUpdate = new HashMap<>(); + queueUpdate.put(QUEUE_KEY, newQueue.toString()); + redisClient.mset(queueUpdate); + int queueCount = newQueue.toString().isEmpty() ? 
0 : newQueue.toString().split(",").length; + LOG.info( + "[RedisBackedSessionQueue.cleanupRedisRequest] Queue after removal: [" + + newQueue + + "] (" + + queueCount + + " requests)"); + } + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to cleanup Redis request: " + requestIdStr, e); + } + } + + private void timeoutSessions() { + Instant now = Instant.now(); + + Lock readLock = lock.readLock(); + readLock.lock(); + Set ids; + try { + ids = + requests.entrySet().stream() + .filter( + entry -> + queue.stream() + .anyMatch( + sessionRequest -> + sessionRequest.getRequestId().equals(entry.getKey()))) + .filter(entry -> isTimedOut(now, entry.getValue())) + .map(Map.Entry::getKey) + .collect(HashSet::new, Set::add, Set::addAll); + } finally { + readLock.unlock(); + } + ids.forEach(this::failDueToTimeout); + } + + private boolean isTimedOut(Instant now, Data data) { + return data.endTime.isBefore(now); + } + + @Override + public boolean peekEmpty() { + Lock readLock = lock.readLock(); + readLock.lock(); + try { + return requests.isEmpty() && queue.isEmpty(); + } finally { + readLock.unlock(); + } + } + + @Override + public HttpResponse addToQueue(SessionRequest request) { + Require.nonNull("New session request", request); + Require.nonNull("Request id", request.getRequestId()); + + TraceContext context = TraceSessionRequest.extract(tracer, request); + try (Span ignored = context.createSpan("sessionqueue.add_to_queue")) { + contexts.put(request.getRequestId(), context); + Data data = injectIntoQueue(request); + + if (isTimedOut(Instant.now(), data)) { + failDueToTimeout(request.getRequestId()); + } + + Either result; + try { + + boolean sessionCreated = data.latch.await(requestTimeout.toMillis(), MILLISECONDS); + + if (sessionCreated) { + result = data.getResult(); + } else { + result = Either.left(new SessionNotCreatedException("New session request timed out")); + } + } catch (InterruptedException e) { + // the client will never see the session, ensure the 
session is disposed + data.cancel(); + Thread.currentThread().interrupt(); + result = + Either.left(new SessionNotCreatedException("Interrupted when creating the session", e)); + } catch (RuntimeException e) { + // the client will never see the session, ensure the session is disposed + data.cancel(); + result = + Either.left( + new SessionNotCreatedException("An error occurred creating the session", e)); + } + + Lock writeLock = this.lock.writeLock(); + if (!writeLock.tryLock()) { + writeLock.lock(); + } + try { + requests.remove(request.getRequestId()); + queue.remove(request); + contexts.remove(request.getRequestId()); + } finally { + writeLock.unlock(); + } + + // Clean up from Redis + removeFromRedis(request.getRequestId()); + + HttpResponse res = new HttpResponse(); + if (result.isRight()) { + res.setContent(Contents.bytes(result.right().getDownstreamEncodedResponse())); + } else { + res.setStatus(HTTP_INTERNAL_ERROR) + .setContent( + Contents.asJson( + ImmutableMap.of( + "value", + ImmutableMap.of( + "error", "session not created", + "message", result.left().getMessage(), + "stacktrace", result.left().getStackTrace())))); + } + + return res; + } + } + + @VisibleForTesting + Data injectIntoQueue(SessionRequest request) { + Require.nonNull("Session request", request); + + Data data = new Data(request.getEnqueued(), requestTimeout); + + Lock writeLock = lock.writeLock(); + if (!writeLock.tryLock()) { + writeLock.lock(); + } + try { + requests.put(request.getRequestId(), data); + queue.addLast(request); + } finally { + writeLock.unlock(); + } + + // Persist to Redis asynchronously + persistToRedis(request, data); + + return data; + } + + @Override + public boolean retryAddToQueue(SessionRequest request) { + Require.nonNull("New session request", request); + + boolean added; + TraceContext context = + contexts.getOrDefault(request.getRequestId(), tracer.getCurrentContext()); + try (Span ignored = context.createSpan("sessionqueue.retry")) { + Lock writeLock = 
lock.writeLock(); + if (!writeLock.tryLock()) { + writeLock.lock(); + } + try { + if (!requests.containsKey(request.getRequestId())) { + return false; + } + Data data = requests.get(request.getRequestId()); + if (isTimedOut(Instant.now(), data)) { + // as we try to re-add a session request that has already expired, force session timeout + failDueToTimeout(request.getRequestId()); + // return true to avoid handleNewSessionRequest to call 'complete' an other time + return true; + } else if (data.isCanceled()) { + failDueToCanceled(request.getRequestId()); + // return true to avoid handleNewSessionRequest to call 'complete' an other time + return true; + } + + if (queue.contains(request)) { + // No need to re-add this + return true; + } else { + added = queue.offerFirst(request); + } + } finally { + writeLock.unlock(); + } + + return added; + } + } + + @Override + public Optional remove(RequestId reqId) { + Require.nonNull("Request ID", reqId); + + Lock writeLock = lock.writeLock(); + if (!writeLock.tryLock()) { + writeLock.lock(); + } + try { + Iterator iterator = queue.iterator(); + while (iterator.hasNext()) { + SessionRequest req = iterator.next(); + if (reqId.equals(req.getRequestId())) { + iterator.remove(); + // Remove from Redis + removeFromRedis(reqId); + return Optional.of(req); + } + } + return Optional.empty(); + } finally { + writeLock.unlock(); + } + } + + private volatile long lastNonEmptyQueueTime = System.currentTimeMillis(); + private volatile boolean wasEmpty = false; + private static final long MAX_BACKOFF_MS = 1000; // Maximum 1 second backoff + private static final long MIN_BACKOFF_MS = 10; // Minimum 10ms backoff + + @Override + public List getNextAvailable(Map stereotypes) { + Require.nonNull("Stereotypes", stereotypes); + + // Convert maximumResponseDelay to milliseconds for easier comparison + long maxDelayMs = maximumResponseDelay.toMillis(); + long startTime = System.currentTimeMillis(); + long backoffMs = MIN_BACKOFF_MS; + + // delay the 
response to avoid heavy polling via http + while (maxDelayMs > System.currentTimeMillis() - startTime) { + boolean isEmpty = true; + + // Check queue status with read lock + Lock readLock = lock.readLock(); + readLock.lock(); + try { + isEmpty = queue.isEmpty(); + if (!isEmpty) { + lastNonEmptyQueueTime = System.currentTimeMillis(); + wasEmpty = false; + break; // Exit loop if we found requests to process + } + } finally { + readLock.unlock(); + } + + // If queue is empty, use backoff with jitter + if (isEmpty) { + long timeSinceLastNonEmpty = System.currentTimeMillis() - lastNonEmptyQueueTime; + + // If queue has been empty for a while, increase backoff + if (wasEmpty && timeSinceLastNonEmpty > 100) { + backoffMs = Math.min(backoffMs * 2, MAX_BACKOFF_MS); + } else { + backoffMs = MIN_BACKOFF_MS; + } + + // Don't sleep longer than the remaining delay time + long remainingDelay = maxDelayMs - (System.currentTimeMillis() - startTime); + if (remainingDelay <= 0) { + break; + } + + long sleepTime = Math.min(backoffMs, remainingDelay); + if (sleepTime <= 0) { + break; + } + + // Add jitter to prevent thundering herd (up to 10% of sleep time) + long jitter = (long) (Math.random() * sleepTime * 0.1); + wasEmpty = true; + + try { + Thread.sleep(sleepTime + jitter); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + break; + } + } + } + + Predicate matchesStereotype = + caps -> + stereotypes.entrySet().stream() + .filter(entry -> entry.getValue() > 0) + .anyMatch( + entry -> { + boolean matches = slotMatcher.matches(entry.getKey(), caps); + if (matches) { + Long value = entry.getValue(); + entry.setValue(value - 1); + } + return matches; + }); + + Lock writeLock = lock.writeLock(); + if (!writeLock.tryLock()) { + writeLock.lock(); + } + try { + List availableRequests = + queue.stream() + .filter(req -> req.getDesiredCapabilities().stream().anyMatch(matchesStereotype)) + .limit(batchSize) + .collect(Collectors.toList()); + + 
availableRequests.removeIf(
+          (req) -> {
+            Data data = this.requests.get(req.getRequestId());
+
+            if (data.isCanceled()) {
+              failDueToCanceled(req.getRequestId());
+              return true;
+            }
+
+            this.remove(req.getRequestId());
+            return false;
+          });
+
+      return availableRequests;
+    } finally {
+      writeLock.unlock();
+    }
+  }
+
+  /** Completes the request; returns false if it is unknown, already completed, or canceled. */
+  @Override
+  public boolean complete(
+      RequestId reqId, Either result) {
+    Require.nonNull("New session request", reqId);
+    Require.nonNull("Result", result);
+    TraceContext context = contexts.getOrDefault(reqId, tracer.getCurrentContext());
+    try (Span ignored = context.createSpan("sessionqueue.completed")) {
+      Data data;
+      Lock writeLock = lock.writeLock();
+      if (!writeLock.tryLock()) {
+        writeLock.lock();
+      }
+      try {
+        data = requests.remove(reqId);
+        queue.removeIf(req -> reqId.equals(req.getRequestId()));
+        contexts.remove(reqId);
+      } finally {
+        writeLock.unlock();
+      }
+
+      if (data == null) {
+        return false;
+      }
+
+      return data.setResult(result);
+    }
+  }
+
+  @Override
+  public int clearQueue() {
+    Lock writeLock = lock.writeLock();
+    if (!writeLock.tryLock()) {
+      writeLock.lock();
+    }
+
+    try {
+      int size = queue.size();
+      queue.clear();
+      requests.forEach(
+          (reqId, data) ->
+              data.setResult(
+                  Either.left(new SessionNotCreatedException("Request queue was cleared"))));
+      requests.clear();
+      // Contexts are intentionally not cleared ("strict alignment"); NOTE(review): confirm no leak
+      // Clear Redis asynchronously; a failed delete is only logged
+      service.execute(
+          () -> {
+            try {
+              redisClient.del(QUEUE_KEY);
+            } catch (Exception e) {
+              LOG.log(Level.WARNING, "Failed to clear Redis queue", e);
+            }
+          });
+
+      return size;
+    } finally {
+      writeLock.unlock();
+    }
+  }
+
+  @Override
+  public List getQueueContents() {
+    Lock readLock = lock.readLock();
+    readLock.lock();
+
+    try {
+      return queue.stream()
+          .map(
+              req -> new SessionRequestCapability(req.getRequestId(), req.getDesiredCapabilities()))
+          .collect(Collectors.toList());
+    } finally {
+      readLock.unlock();
+    }
+  }
+
+  @ManagedAttribute(name = "NewSessionQueueSize")
+  public int getQueueSize() {
+    return queue.size();
+  }
+
+  @ManagedAttribute(name = "SessionQueueUri")
+  public String getSessionQueueUri() {
+    return sessionQueueUri.toString();
+  }
+
+  @Override
+  public boolean isReady() {
+    try {
+      return redisClient.isOpen();
+    } catch (Exception e) {
+      return false;
+    }
+  }
+
+  @Override
+  public void close() {
+    shutdownGracefully(NAME, service);
+    try {
+      redisClient.close();
+    } catch (Exception e) {
+      LOG.log(Level.WARNING, "Failed to close Redis connection", e);
+    }
+  }
+
+  private void failDueToTimeout(RequestId reqId) {
+    complete(reqId, Either.left(new SessionNotCreatedException("Timed out creating session")));
+  }
+
+  private void failDueToCanceled(RequestId reqId) {
+    // this error should never reach the client, as this is a client initiated state
+    complete(reqId, Either.left(new SessionNotCreatedException("Client has gone away")));
+  }
+
+  private static class Data {
+    public Instant endTime;
+    private final CountDownLatch latch = new CountDownLatch(1);
+    private Either result;
+    private boolean complete;
+    private boolean canceled;
+
+    // No-arg constructor for JSON deserialization
+    public Data() {
+      this.endTime = Instant.now();
+      this.complete = false;
+      this.canceled = false;
+      this.result = Either.left(new SessionNotCreatedException("Session not created"));
+    }
+
+    public Data(Instant enqueued, Duration requestTimeout) {
+      this.endTime = enqueued.plus(requestTimeout); // deadline counts from the enqueue time
+      this.result = Either.left(new SessionNotCreatedException("Session not created"));
+    }
+
+    // Constructor for JSON deserialization
+    public Data(Instant endTime, boolean complete, boolean canceled) {
+      this.endTime = endTime;
+      this.complete = complete;
+      this.canceled = canceled;
+      this.result = Either.left(new SessionNotCreatedException("Session not created"));
+    }
+
+    // Constructor for full deserialization, including a previously stored result
+    public Data(
+        Instant endTime,
+        boolean complete,
+        boolean canceled,
+        Either result) {
+      this.endTime = endTime;
+      this.complete = complete;
+      this.canceled = canceled;
+      this.result = result;
+    }
+
+    public synchronized Either getResult() {
+      return result;
+    }
+
+    public synchronized void cancel() {
+      canceled = true;
+    }
+
+    public synchronized boolean isCanceled() {
+      return canceled;
+    }
+
+    public synchronized boolean setResult(
+        Either result) {
+      if (complete || canceled) {
+        return false;
+      }
+      this.result = result;
+      complete = true;
+      latch.countDown();
+      return true;
+    }
+
+    // Instance-level factory: builds a new Data from the given fields; this instance is not used
+    public Data fromJson(
+        Instant endTime, boolean complete, boolean canceled, Map resultMap) {
+      Either result;
+      if (resultMap.containsKey("right")) {
+        CreateSessionResponse resp = (CreateSessionResponse) resultMap.get("right");
+        result = Either.right(resp);
+      } else {
+        SessionNotCreatedException ex = (SessionNotCreatedException) resultMap.get("left");
+        result = Either.left(ex);
+      }
+      return new Data(endTime, complete, canceled, result);
+    }
+
+    // Setters for JSON deserialization; the flag setters use the same monitor as setResult()
+    public void setEndTime(Instant endTime) {
+      this.endTime = endTime;
+    }
+
+    public synchronized void setComplete(boolean complete) {
+      this.complete = complete;
+    }
+
+    public synchronized void setCanceled(boolean canceled) {
+      this.canceled = canceled;
+    }
+  }
+}
diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java
index 7b5fc00c3820f..fba0c168c4fce 100644
--- a/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java
+++ b/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java
@@ -68,6 +68,10 @@ public class RemoteNewSessionQueue extends NewSessionQueue {
   private static final Json JSON = new Json();
   private final HttpClient client;
   private final Filter addSecret;
+  private volatile long backoffMs = 100; // Start
with 100ms backoff
+  private static final long MAX_BACKOFF_MS = 5000; // Max 5 seconds backoff
+  private static final long MIN_BACKOFF_MS = 100; // Min 100ms backoff
+  private volatile long lastRequestTime = 0;
 
   public RemoteNewSessionQueue(Tracer tracer, HttpClient client, Secret registrationSecret) {
     super(tracer, registrationSecret);
@@ -146,17 +150,51 @@ public Optional remove(RequestId reqId) {
   public List getNextAvailable(Map stereotypes) {
     Require.nonNull("Stereotypes", stereotypes);
 
-    Map stereotypeJson = new HashMap<>();
-    stereotypes.forEach((k, v) -> stereotypeJson.put(JSON.toJson(k), v));
+    // Wait out the rest of the current backoff window; clamp in case the wall clock stepped back
+    long now = System.currentTimeMillis();
+    long timeSinceLastRequest = Math.max(0L, now - lastRequestTime);
+
+    if (timeSinceLastRequest < backoffMs) {
+      long sleepTime = backoffMs - timeSinceLastRequest;
+      try {
+        // Add some jitter to prevent thundering herd
+        long jitter = (long) (Math.random() * sleepTime * 0.1); // Up to 10% jitter
+        Thread.sleep(sleepTime + jitter);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        return List.of();
+      }
+    }
 
-    HttpRequest upstream =
-        new HttpRequest(POST, "/se/grid/newsessionqueue/session/next")
-            .setContent(Contents.asJson(stereotypeJson));
+    try {
+      Map stereotypeJson = new HashMap<>();
+      stereotypes.forEach((k, v) -> stereotypeJson.put(JSON.toJson(k), v));
 
-    HttpTracing.inject(tracer, tracer.getCurrentContext(), upstream);
-    HttpResponse response = client.with(addSecret).execute(upstream);
+      HttpRequest upstream =
+          new HttpRequest(POST, "/se/grid/newsessionqueue/session/next")
+              .setContent(Contents.asJson(stereotypeJson));
 
-    return Values.get(response, SESSION_REQUEST_TYPE);
+      HttpTracing.inject(tracer, tracer.getCurrentContext(), upstream);
+      HttpResponse response = client.with(addSecret).execute(upstream);
+
+      List result = Values.get(response, SESSION_REQUEST_TYPE);
+
+      // If we got results, halve the backoff. Otherwise, grow it by 50% (both within the bounds).
+      if (result == null || result.isEmpty()) {
+        backoffMs = Math.min((long) (backoffMs * 1.5), MAX_BACKOFF_MS);
+      } else {
+        backoffMs = Math.max(MIN_BACKOFF_MS, backoffMs / 2);
+      }
+
+      return result != null ? result : List.of();
+
+    } catch (Exception e) {
+      // On error, increase backoff more aggressively
+      backoffMs = Math.min(backoffMs * 2, MAX_BACKOFF_MS);
+      throw e;
+    } finally {
+      lastRequestTime = System.currentTimeMillis();
+    }
   }
 
   @Override
diff --git a/java/test/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel b/java/test/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel
new file mode 100644
index 0000000000000..10adbef298ba9
--- /dev/null
+++ b/java/test/org/openqa/selenium/grid/sessionqueue/redis/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@rules_jvm_external//:defs.bzl", "artifact")
+load("//java:defs.bzl", "JUNIT5_DEPS", "java_test_suite")
+
+java_test_suite(
+    name = "MediumTests",
+    size = "medium",
+    srcs = glob(["*Test.java"]),
+    deps = [
+        "//java/src/org/openqa/selenium/events/local",
+        "//java/src/org/openqa/selenium/grid/sessionqueue/redis",
+        "//java/src/org/openqa/selenium/json",
+        "//java/src/org/openqa/selenium/redis",
+        "//java/src/org/openqa/selenium/remote",
+        "//java/test/org/openqa/selenium/remote/tracing:tracing-support",
+        "//java/test/org/openqa/selenium/testing:test-base",
+        artifact("io.lettuce:lettuce-core"),
+        artifact("io.opentelemetry:opentelemetry-api"),
+        artifact("org.junit.jupiter:junit-jupiter-api"),
+        artifact("org.assertj:assertj-core"),
+        artifact("org.mockito:mockito-core"),
+    ] + JUNIT5_DEPS,
+)
diff --git a/java/test/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueueTest.java b/java/test/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueueTest.java
new file mode 100644
index 0000000000000..81a2acb57d574
--- /dev/null
+++ b/java/test/org/openqa/selenium/grid/sessionqueue/redis/RedisBackedSessionQueueTest.java
@@ -0,0 +1,216 @@
+// Licensed to the Software Freedom Conservancy (SFC)
under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The SFC licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.openqa.selenium.grid.sessionqueue.redis;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.openqa.selenium.ImmutableCapabilities;
+import org.openqa.selenium.SessionNotCreatedException;
+import org.openqa.selenium.grid.data.CreateSessionResponse;
+import org.openqa.selenium.grid.data.RequestId;
+import org.openqa.selenium.grid.data.Session;
+import org.openqa.selenium.grid.data.SessionId;
+import org.openqa.selenium.grid.data.SessionRequest;
+import org.openqa.selenium.grid.data.SessionRequestCapability;
+import org.openqa.selenium.grid.security.Secret;
+import org.openqa.selenium.internal.Either;
+import org.openqa.selenium.remote.http.Contents;
+import org.openqa.selenium.remote.http.HttpMethod;
+import org.openqa.selenium.remote.http.HttpRequest;
+import org.openqa.selenium.remote.http.HttpResponse;
+import org.openqa.selenium.remote.tracing.DefaultTestTracer;
+import org.openqa.selenium.remote.tracing.Tracer;
+
+class RedisBackedSessionQueueTest {
+
+  private static final Tracer tracer = DefaultTestTracer.createTracer();
+  private static final Secret secret = new Secret("test-secret");
+  private static final URI redisUri = URI.create("redis://localhost:6379");
+  private static final Duration REQUEST_TIMEOUT_CHECK = Duration.ofMillis(50);
+  private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(1);
+  private static final Duration MAX_RESPONSE_DELAY = Duration.ofSeconds(2);
+  private static final int BATCH_SIZE = 3;
+
+  private RedisBackedSessionQueue queue;
+
+  @BeforeEach
+  void setUp() {
+    queue =
+        new RedisBackedSessionQueue(
+            tracer,
+            secret,
+            redisUri,
+            REQUEST_TIMEOUT_CHECK,
+            REQUEST_TIMEOUT,
+            MAX_RESPONSE_DELAY,
+            BATCH_SIZE);
+  }
+
+  @AfterEach
+  void tearDown() {
+    queue.clearQueue(); // NOTE(review): queue.close() is never called; confirm nothing leaks
+  }
+
+  @Test
+  void shouldThrowIllegalArgumentExceptionIfRedisUriIsNull() {
+    assertThatThrownBy(
+            () ->
+                new RedisBackedSessionQueue(
+                    tracer,
+                    secret,
+                    null,
+                    REQUEST_TIMEOUT_CHECK,
+                    REQUEST_TIMEOUT,
+                    MAX_RESPONSE_DELAY,
+                    BATCH_SIZE))
+        .isInstanceOf(IllegalArgumentException.class);
+  }
+
+  @Test
+  void shouldThrowIllegalArgumentExceptionIfTracerIsNull() {
+    assertThatThrownBy(
+            () ->
+                new RedisBackedSessionQueue(
+                    null,
+                    secret,
+                    redisUri,
+                    REQUEST_TIMEOUT_CHECK,
+                    REQUEST_TIMEOUT,
+                    MAX_RESPONSE_DELAY,
+                    BATCH_SIZE))
+        .isInstanceOf(IllegalArgumentException.class);
+  }
+
+  @Test
+  void shouldThrowIllegalArgumentExceptionIfSecretIsNull() {
+    assertThatThrownBy(
+            () ->
+                new RedisBackedSessionQueue(
+                    tracer,
+                    null,
+                    redisUri,
+                    REQUEST_TIMEOUT_CHECK,
+                    REQUEST_TIMEOUT,
+                    MAX_RESPONSE_DELAY,
+                    BATCH_SIZE))
+        .isInstanceOf(IllegalArgumentException.class);
+  }
+
+  @Test
+  void canAddSessionRequestToQueue() {
+    RequestId requestId = new RequestId(UUID.randomUUID());
+    SessionRequest request = createSessionRequest(requestId);
+
+    HttpResponse response = queue.addToQueue(request);
+
+    assertThat(response.getStatus()).isEqualTo(200);
+  }
+
+  @Test
+  void canRemoveSessionRequestFromQueue() {
+    RequestId requestId = new RequestId(UUID.randomUUID());
+    SessionRequest originalRequest = createSessionRequest(requestId);
+
+    queue.addToQueue(originalRequest);
+
+    Optional removed = queue.remove(requestId);
+
+    assertThat(removed).isPresent();
+    assertThat(removed.get().getRequestId()).isEqualTo(requestId);
+  }
+
+  @Test
+  void getNextAvailableShouldReturnOldestRequest() {
+    RequestId requestId = new RequestId(UUID.randomUUID());
+    SessionRequest originalRequest = createSessionRequest(requestId);
+
+    queue.addToQueue(originalRequest);
+
+    List next = queue.getNextAvailable(Map.of());
+
+    assertThat(next).hasSize(1);
+    assertThat(next.get(0).getRequestId()).isEqualTo(requestId);
+  }
+
+  @Test
+  void completeShouldReturnTrueAndCleanupRequestData() {
+    RequestId requestId = new RequestId(UUID.randomUUID());
+    Session dummySession =
+        new Session(
+            new SessionId("dummy"),
+            URI.create("dummy-uri"),
+            new ImmutableCapabilities(),
+            new ImmutableCapabilities(),
+            Instant.now());
+    CreateSessionResponse response = new CreateSessionResponse(dummySession, new byte[0]);
+    Either result = Either.right(response);
+
+    queue.addToQueue(createSessionRequest(requestId));
+
+    boolean completed = queue.complete(requestId, result);
+
+    assertThat(completed).isTrue();
+  }
+
+  @Test
+  void clearQueueShouldRemoveAllRequests() {
+    RequestId requestId1 = new RequestId(UUID.randomUUID());
+    RequestId requestId2 = new RequestId(UUID.randomUUID());
+
+    queue.addToQueue(createSessionRequest(requestId1));
+    queue.addToQueue(createSessionRequest(requestId2));
+
+    int cleared = queue.clearQueue();
+
+    assertThat(cleared).isEqualTo(2);
+  }
+
+  @Test
+  void getQueueContentsShouldReturnAllRequests() {
+    RequestId requestId1 = new RequestId(UUID.randomUUID());
+    RequestId requestId2 = new RequestId(UUID.randomUUID());
+    SessionRequest request1 = createSessionRequest(requestId1);
+    SessionRequest request2 = createSessionRequest(requestId2);
+
+    queue.addToQueue(request1);
+    queue.addToQueue(request2);
+
+    List contents = queue.getQueueContents();
+
+    assertThat(contents).hasSize(2);
+    assertThat(contents.get(0).getRequestId()).isEqualTo(requestId1);
+    assertThat(contents.get(1).getRequestId()).isEqualTo(requestId2);
+  }
+
+  private SessionRequest createSessionRequest(RequestId requestId) {
+    HttpRequest httpRequest = new HttpRequest(HttpMethod.POST, "/session");
+    httpRequest.setContent(Contents.utf8String("{\"capabilities\":{\"browserName\":\"chrome\"}}"));
+    return new SessionRequest(requestId, httpRequest, Instant.now());
+  }
+}