From e5b96c64d4de7cac5b74be0eae78ad775fb3c259 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 18 Aug 2025 15:26:40 +0100 Subject: [PATCH 1/5] wip remove node pinning logic --- .../xpack/inference/InferencePlugin.java | 17 -- .../action/BaseTransportInferenceAction.java | 84 +------- .../action/TransportInferenceAction.java | 3 - ...sportUnifiedCompletionInferenceAction.java | 2 - ...nceAPIClusterAwareRateLimitingFeature.java | 28 --- ...ceServiceNodeLocalRateLimitCalculator.java | 197 ------------------ .../InferenceServiceRateLimitCalculator.java | 18 -- .../NoopNodeLocalRateLimitCalculator.java | 27 --- .../http/sender/RequestExecutorService.java | 3 +- 9 files changed, 2 insertions(+), 377 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 8380ac1d87c37..ec391efbbf90e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -81,9 +81,6 @@ import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; -import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -160,7 +157,6 @@ import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; -import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG; public class InferencePlugin extends Plugin implements @@ -382,19 +378,6 @@ public Collection createComponents(PluginServices services) { new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager) ); components.add(inferenceStatsBinding); - - // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, - // if the rate limiting feature flags are enabled, otherwise provide noop implementation - InferenceServiceRateLimitCalculator calculator; - if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG) { - calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry); - } else { - calculator = new NoopNodeLocalRateLimitCalculator(); - } - - // Add binding for interface -> implementation - components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator)); - components.add( new InferenceEndpointRegistry( services.clusterService(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 8e34cafa3e878..8ffda54ddc085 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; @@ -31,25 +30,19 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.TransportException; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Random; -import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -64,10 +57,6 @@ /** * Base class for transport actions that handle inference requests. - * Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to - * route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail - * in {@link InferenceServiceNodeLocalRateLimitCalculator}. - * * @param The specific type of inference request being handled */ public abstract class BaseTransportInferenceAction extends HandledTransportAction< @@ -82,7 +71,6 @@ public abstract class BaseTransportInferenceAction requestReader, - InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ) { @@ -108,7 +95,6 @@ public BaseTransportInferenceAction( this.serviceRegistry = serviceRegistry; this.inferenceStats = inferenceStats; this.streamingTaskManager = streamingTaskManager; - this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator; this.nodeClient = nodeClient; this.threadPool = threadPool; this.transportService = transportService; @@ -161,15 +147,8 @@ protected void doExecute(Task task, Request request, ActionListener { try { inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); @@ -195,73 +174,12 @@ private void validateRequest(Request request, Model model) { validationHelper(() -> isInvalidTaskTypeForInferenceEndpoint(request, model), () -> createInvalidTaskTypeException(request, model)); } - private NodeRoutingDecision determineRouting(String serviceName, Request request, TaskType modelTaskType, String localNodeId) { - // Rerouting not supported or request was already rerouted - if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false - || request.hasBeenRerouted()) { - return NodeRoutingDecision.handleLocally(); - } - - var rateLimitAssignment = inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType); - - // No assignment yet - if (rateLimitAssignment == null) { - return NodeRoutingDecision.handleLocally(); - } - - var responsibleNodes = rateLimitAssignment.responsibleNodes(); - - // Empty assignment - if (responsibleNodes == null || responsibleNodes.isEmpty()) { - return NodeRoutingDecision.handleLocally(); - } - - var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size())); - - // The drawn node is the current node - if (nodeToHandleRequest.getId().equals(localNodeId)) { - return NodeRoutingDecision.handleLocally(); - } - - // Reroute request - return NodeRoutingDecision.routeTo(nodeToHandleRequest); - } - private static void validationHelper(Supplier validationFailure, Supplier exceptionCreator) { if (validationFailure.get()) { throw exceptionCreator.get(); } } - private void rerouteRequest(Request request, ActionListener listener, DiscoveryNode nodeToHandleRequest) { - transportService.sendRequest( - nodeToHandleRequest, - InferenceAction.NAME, - request, - new TransportResponseHandler() { - @Override - public Executor executor() { - return threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); - } - - @Override - public void handleResponse(InferenceAction.Response response) { - listener.onResponse(response); - } - - @Override - public void handleException(TransportException exp) { - listener.onFailure(exp); - } - - @Override - public InferenceAction.Response read(StreamInput in) throws IOException { - return new InferenceAction.Response(in); - } - } - ); - } - private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index f0fb0ec82757a..8d7a37ca52ea7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -22,7 +22,6 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; public class TransportInferenceAction extends BaseTransportInferenceAction { @@ -36,7 +35,6 @@ public TransportInferenceAction( InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, - InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ) { @@ -50,7 +48,6 @@ public TransportInferenceAction( inferenceStats, streamingTaskManager, InferenceAction.Request::new, - inferenceServiceNodeLocalRateLimitCalculator, nodeClient, threadPool ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 4fe5dd3a55a12..5c1c78fc8c9df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -43,7 +43,6 @@ public TransportUnifiedCompletionInferenceAction( InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, - InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ) { @@ -57,7 +56,6 @@ public TransportUnifiedCompletionInferenceAction( inferenceStats, streamingTaskManager, UnifiedCompletionAction.Request::new, - inferenceServiceNodeLocalRateLimitCalculator, nodeClient, threadPool ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java deleted file mode 100644 index 518bd4ea85d9e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.common; - -import org.elasticsearch.common.util.FeatureFlag; -import org.elasticsearch.xpack.inference.InferencePlugin; - -/** - * Cluster aware rate limiting feature flag. When the feature is complete and fully rolled out, this flag will be removed. - * Enable feature via JVM option: `-Des.inference_cluster_aware_rate_limiting_feature_flag_enabled=true`. - * - * This controls, whether {@link InferenceServiceNodeLocalRateLimitCalculator} gets instantiated and - * added as injectable {@link InferencePlugin} component. - */ -public class InferenceAPIClusterAwareRateLimitingFeature { - - public static final boolean INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG = new FeatureFlag( - "inference_cluster_aware_rate_limiting" - ).isEnabled(); - - private InferenceAPIClusterAwareRateLimitingFeature() {} - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java deleted file mode 100644 index 4778e4cc6d30c..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.common; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.cluster.ClusterChangedEvent; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; -import org.elasticsearch.xpack.inference.action.BaseTransportInferenceAction; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.services.SenderService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Note: {@link InferenceAPIClusterAwareRateLimitingFeature} needs to be enabled for this class to get - * instantiated inside {@link org.elasticsearch.xpack.inference.InferencePlugin} and be available via dependency injection. - * - * Calculates and manages node-local rate limits for inference services based on changes in the cluster topology. - * This calculator calculates a "node-local" rate-limit, which essentially divides the rate limit for a service/task type - * through the number of nodes, which got assigned to this service/task type pair. Without this calculator the rate limit stored - * in the inference endpoint configuration would get effectively multiplied by the number of nodes in a cluster (assuming a ~ uniform - * distribution of requests to the nodes in the cluster). - * - * The calculator works in conjunction with several other components: - * - {@link BaseTransportInferenceAction} - Uses the calculator to determine, whether to reroute a request or not - * - {@link BaseInferenceActionRequest} - Tracks, if the request (an instance of a subclass of {@link BaseInferenceActionRequest}) - * already got re-routed at least once - * - {@link HttpRequestSender} - Provides original rate limits that this calculator divides through the number of nodes - * responsible for a service/task type - */ -public class InferenceServiceNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator { - - public static final Integer DEFAULT_MAX_NODES_PER_GROUPING = 3; - - /** - * Configuration mapping services to their task type rate limiting settings. - * Each service can have multiple configs defining: - * - Which task types support request re-routing and "node-local" rate limit calculation - * - How many nodes should handle requests for each task type, based on cluster size (dynamically calculated or statically provided) - **/ - static final Map> SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS = Map.of( - ElasticInferenceService.NAME, - // TODO: should probably be a map/set - List.of(new NodeLocalRateLimitConfig(TaskType.SPARSE_EMBEDDING, (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING)) - ); - - record NodeLocalRateLimitConfig(TaskType taskType, MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy) {} - - @FunctionalInterface - private interface MaxNodesPerGroupingStrategy { - - Integer calculate(Integer numberOfNodesInCluster); - - } - - private static final Logger logger = LogManager.getLogger(InferenceServiceNodeLocalRateLimitCalculator.class); - - private final InferenceServiceRegistry serviceRegistry; - - private final ConcurrentHashMap> serviceAssignments; - - @Inject - public InferenceServiceNodeLocalRateLimitCalculator(ClusterService clusterService, InferenceServiceRegistry serviceRegistry) { - clusterService.addListener(this); - this.serviceRegistry = serviceRegistry; - this.serviceAssignments = new ConcurrentHashMap<>(); - } - - @Override - public void clusterChanged(ClusterChangedEvent event) { - boolean clusterTopologyChanged = event.nodesChanged(); - - // TODO: feature flag per node? We should not reroute to nodes not having eis and/or the inference plugin enabled - // Every node should land on the same grouping by calculation, so no need to put anything into the cluster state - if (clusterTopologyChanged) { - updateAssignments(event); - } - } - - public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) { - return SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.getOrDefault(serviceName, Collections.emptyList()) - .stream() - .anyMatch(rateLimitConfig -> taskType.equals(rateLimitConfig.taskType)); - } - - public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) { - var assignmentsPerTaskType = serviceAssignments.get(service); - - if (assignmentsPerTaskType == null) { - return null; - } - - return assignmentsPerTaskType.get(taskType); - } - - /** - * Updates instances of {@link RateLimitAssignment} for each service and task type when the cluster topology changes. - * For each service and supported task type, calculates which nodes should handle requests - * and what their local rate limits should be per inference endpoint. - */ - private void updateAssignments(ClusterChangedEvent event) { - var newClusterState = event.state(); - var nodes = newClusterState.nodes().getAllNodes(); - - // Sort nodes by id (every node lands on the same result) - var sortedNodes = nodes.stream().sorted(Comparator.comparing(DiscoveryNode::getId)).toList(); - - // Sort inference services by name (every node lands on the same result) - var sortedServices = new ArrayList<>(serviceRegistry.getServices().values()); - sortedServices.sort(Comparator.comparing(InferenceService::name)); - - for (String serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { - Optional service = serviceRegistry.getService(serviceName); - - if (service.isPresent()) { - var inferenceService = service.get(); - - for (NodeLocalRateLimitConfig rateLimitConfig : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName)) { - Map perTaskTypeAssignments = new HashMap<>(); - TaskType taskType = rateLimitConfig.taskType(); - - // Calculate node assignments needed for re-routing - var assignedNodes = calculateServiceAssignment(rateLimitConfig.maxNodesPerGroupingStrategy(), sortedNodes); - - // Update rate limits to be "node-local" - var numAssignedNodes = assignedNodes.size(); - updateRateLimits(inferenceService, numAssignedNodes); - - perTaskTypeAssignments.put(taskType, new RateLimitAssignment(assignedNodes)); - serviceAssignments.put(serviceName, perTaskTypeAssignments); - } - } else { - logger.warn( - "Service [{}] is configured for node-local rate limiting but was not found in the service registry", - serviceName - ); - } - } - } - - private List calculateServiceAssignment( - MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy, - List sortedNodes - ) { - int numberOfNodes = sortedNodes.size(); - int nodesPerGrouping = Math.min(numberOfNodes, maxNodesPerGroupingStrategy.calculate(numberOfNodes)); - - List assignedNodes = new ArrayList<>(); - - // TODO: here we can probably be smarter: if |num nodes in cluster| > |num nodes per task types| - // -> make sure a service provider is not assigned the same nodes for all task types; only relevant as soon as we support more task - // types - for (int j = 0; j < nodesPerGrouping; j++) { - var assignedNode = sortedNodes.get(j % numberOfNodes); - assignedNodes.add(assignedNode); - } - - return assignedNodes; - } - - private void updateRateLimits(InferenceService service, int responsibleNodes) { - if ((service instanceof SenderService) == false) { - return; - } - - SenderService senderService = (SenderService) service; - Sender sender = senderService.getSender(); - // TODO: this needs to take in service and task type as soon as multiple services/task types are supported - sender.updateRateLimitDivisor(responsibleNodes); - } - - InferenceServiceRegistry serviceRegistry() { - return serviceRegistry; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java deleted file mode 100644 index e05637f629ec6..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.common; - -import org.elasticsearch.cluster.ClusterStateListener; -import org.elasticsearch.inference.TaskType; - -public interface InferenceServiceRateLimitCalculator extends ClusterStateListener { - - boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType); - - RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType); -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java deleted file mode 100644 index a07217d9e9af7..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.common; - -import org.elasticsearch.cluster.ClusterChangedEvent; -import org.elasticsearch.inference.TaskType; - -public class NoopNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator { - - @Override - public void clusterChanged(ClusterChangedEvent event) { - // Do nothing - } - - public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) { - return false; - } - - public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) { - return null; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 81d249add0262..c9f4cb9353bdf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -19,7 +19,6 @@ import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; -import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -411,7 +410,7 @@ public void init() { * * @param divisor - divisor to divide the initial requests per time unit by */ - public synchronized void updateTokensPerTimeUnit(Integer divisor) { + private synchronized void updateTokensPerTimeUnit(Integer divisor) { double updatedTokensPerTimeUnit = (double) originalRequestsPerTimeUnit / divisor; rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, updatedTokensPerTimeUnit, rateLimitSettings.timeUnit()); } From 1d86ce831a754ba8a8dd67cd88acecf1a6968656 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 12 Sep 2025 13:09:21 +0100 Subject: [PATCH 2/5] the rest --- ...sportUnifiedCompletionInferenceAction.java | 1 - .../http/sender/RequestExecutorService.java | 6 - .../BaseTransportInferenceActionTestCase.java | 5 - .../action/TransportInferenceActionTests.java | 151 ----------- ...TransportUnifiedCompletionActionTests.java | 3 - ...viceNodeLocalRateLimitCalculatorTests.java | 246 ------------------ 6 files changed, 412 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 5c1c78fc8c9df..4000b2e175c52 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -27,7 +27,6 @@ import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.concurrent.Flow; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index c9f4cb9353bdf..3e2d0c5d0171e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -97,8 +97,6 @@ interface RateLimiterCreator { private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); private final ConcurrentMap rateLimitGroupings = new ConcurrentHashMap<>(); - // TODO: add one atomic integer (number of nodes); also explain the assumption and why this works - // TODO: document that this impacts chat completion (and increase the default rate limit) private final AtomicInteger rateLimitDivisor = new AtomicInteger(1); private final ThreadPool threadPool; private final CountDownLatch startupLatch; @@ -404,10 +402,6 @@ public void init() { } /** - * This method is solely called by {@link InferenceServiceNodeLocalRateLimitCalculator} to update - * rate limits, so they're "node-local". - * The general idea is described in {@link InferenceServiceNodeLocalRateLimitCalculator} in more detail. - * * @param divisor - divisor to divide the initial requests per time unit by */ private synchronized void updateTokensPerTimeUnit(Integer divisor) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 47053b7cbe5eb..da2e8589c4525 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -68,7 +67,6 @@ public abstract class BaseTransportInferenceActionTestCase createAction( InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, - InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index dd0a1b952233b..f63f322c91576 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -9,34 +9,17 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.TransportException; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.common.RateLimitAssignment; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; -import java.util.List; - -import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.assertArg; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { @@ -53,7 +36,6 @@ protected BaseTransportInferenceAction createAction( InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, - InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ) { @@ -65,7 +47,6 @@ protected BaseTransportInferenceAction createAction( serviceRegistry, inferenceStats, streamingTaskManager, - inferenceServiceNodeLocalRateLimitCalculator, nodeClient, threadPool ); @@ -75,136 +56,4 @@ protected BaseTransportInferenceAction createAction( protected InferenceAction.Request createRequest() { return mock(InferenceAction.Request.class); } - - public void testNoRerouting_WhenTaskTypeNotSupported() { - TaskType unsupportedTaskType = TaskType.COMPLETION; - mockService(listener -> listener.onResponse(mock())); - - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false); - - var listener = doExecute(unsupportedTaskType); - - verify(listener).onResponse(any()); - // Verify request was handled locally (not rerouted using TransportService) - verify(transportService, never()).sendRequest(any(), any(), any(), any()); - // Verify request metric attributes were recorded on the node performing inference - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); - })); - } - - public void testNoRerouting_WhenNoGroupingCalculatedYet() { - mockService(listener -> listener.onResponse(mock())); - - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); - when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - // Verify request was handled locally (not rerouted using TransportService) - verify(transportService, never()).sendRequest(any(), any(), any(), any()); - // Verify request metric attributes were recorded on the node performing inference - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); - })); - } - - public void testNoRerouting_WhenEmptyNodeList() { - mockService(listener -> listener.onResponse(mock())); - - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); - when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn( - new RateLimitAssignment(List.of()) - ); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - // Verify request was handled locally (not rerouted using TransportService) - verify(transportService, never()).sendRequest(any(), any(), any(), any()); - // Verify request metric attributes were recorded on the node performing inference - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); - })); - } - - public void testRerouting_ToOtherNode() { - DiscoveryNode otherNode = mock(DiscoveryNode.class); - when(otherNode.getId()).thenReturn("other-node"); - - // The local node is different to the "other-node" responsible for serviceId - when(nodeClient.getLocalNodeId()).thenReturn("local-node"); - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); - // Requests for serviceId are always routed to "other-node" - var assignment = new RateLimitAssignment(List.of(otherNode)); - when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); - - mockService(listener -> listener.onResponse(mock())); - var listener = doExecute(taskType); - - // Verify request was rerouted - verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any()); - // Verify local execution didn't happen - verify(listener, never()).onResponse(any()); - // Verify that request metric attributes were NOT recorded on the node rerouting the request to another node - verify(inferenceStats.inferenceDuration(), never()).record(anyLong(), any()); - } - - public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() { - DiscoveryNode localNode = mock(DiscoveryNode.class); - String localNodeId = "local-node"; - when(localNode.getId()).thenReturn(localNodeId); - - // The local node is the only one responsible for serviceId - when(nodeClient.getLocalNodeId()).thenReturn(localNodeId); - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); - var assignment = new RateLimitAssignment(List.of(localNode)); - when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); - - mockService(listener -> listener.onResponse(mock())); - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - // Verify request was handled locally (not rerouted using TransportService) - verify(transportService, never()).sendRequest(any(), any(), any(), any()); - // Verify request metric attributes were recorded on the node performing inference - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); - })); - } - - public void testRerouting_HandlesTransportException_FromOtherNode() { - DiscoveryNode otherNode = mock(DiscoveryNode.class); - when(otherNode.getId()).thenReturn("other-node"); - - when(nodeClient.getLocalNodeId()).thenReturn("local-node"); - when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true); - var assignment = new RateLimitAssignment(List.of(otherNode)); - when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment); - - mockService(listener -> listener.onResponse(mock())); - - TransportException expectedException = new TransportException("Failed to route"); - doAnswer(invocation -> { - TransportResponseHandler handler = invocation.getArgument(3); - handler.handleException(expectedException); - return null; - }).when(transportService).sendRequest(any(), any(), any(), any()); - - var listener = doExecute(taskType); - - // Verify request was rerouted - verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any()); - // Verify local execution didn't happen - verify(listener, never()).onResponse(any()); - // Verify exception was propagated from "other-node" to "local-node" - verify(listener).onFailure(same(expectedException)); - // Verify that request metric attributes were NOT recorded on the node rerouting the request to another node - verify(inferenceStats.inferenceDuration(), never()).record(anyLong(), any()); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 0b05509acaf8b..45d2addaa7075 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.Optional; @@ -49,7 +48,6 @@ protected BaseTransportInferenceAction createAc InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, - InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool ) { @@ -61,7 +59,6 @@ protected BaseTransportInferenceAction createAc serviceRegistry, inferenceStats, streamingTaskManager, - inferenceServiceRateLimitCalculator, nodeClient, threadPool ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java deleted file mode 100644 index 0914a081acf07..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.common; - -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.InternalTestCluster; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.services.SenderService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; - -import java.util.Arrays; -import java.util.Collection; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING; -import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0) -public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase { - - private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15; - - public void setUp() throws Exception { - super.setUp(); - assumeTrue( - "If inference_cluster_aware_rate_limiting_feature_flag_enabled=false we'll fallback to " - + "NoopNodeLocalRateLimitCalculator, which shouldn't be tested by this class.", - InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG - ); - } - - public void testInitialClusterGrouping_Correct() throws Exception { - // Start with 2-5 nodes - var numNodes = randomIntBetween(2, 5); - var nodeNames = internalCluster().startNodes(numNodes); - ensureStableCluster(numNodes); - - var firstCalculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst()); - waitForRateLimitingAssignments(firstCalculator); - - RateLimitAssignment firstAssignment = firstCalculator.getRateLimitAssignment( - ElasticInferenceService.NAME, - TaskType.SPARSE_EMBEDDING - ); - - // Verify that all other nodes land on the same assignment - for (String nodeName : nodeNames.subList(1, nodeNames.size())) { - var calculator = getCalculatorInstance(internalCluster(), nodeName); - waitForRateLimitingAssignments(calculator); - var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING); - assertEquals(firstAssignment, currentAssignment); - } - } - - public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws Exception { - // Start with 3-5 nodes - var numNodes = randomIntBetween(3, 5); - var nodeNames = internalCluster().startNodes(numNodes); - ensureStableCluster(numNodes); - - var nodeLeftInCluster = nodeNames.getFirst(); - var currentNumberOfNodes = numNodes; - - // Stop all nodes except one - for (String nodeName : nodeNames) { - if (nodeName.equals(nodeLeftInCluster)) { - continue; - } - internalCluster().stopNode(nodeName); - currentNumberOfNodes--; - ensureStableCluster(currentNumberOfNodes); - } - - var calculator = getCalculatorInstance(internalCluster(), nodeLeftInCluster); - waitForRateLimitingAssignments(calculator); - - Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); - - // Check assignments for each supported service - for (var service : supportedServices) { - var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING); - - assertNotNull(assignment); - // Should have exactly one responsible node - assertEquals(1, assignment.responsibleNodes().size()); - // That node should be our remaining node - assertEquals(nodeLeftInCluster, assignment.responsibleNodes().get(0).getName()); - } - } - - public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception { - // Start with more nodes possible per grouping - var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3); - var nodeNames = internalCluster().startNodes(numNodes); - ensureStableCluster(numNodes); - - var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst()); - waitForRateLimitingAssignments(calculator); - - Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); - - for (var service : supportedServices) { - var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING); - - assertNotNull(assignment); - assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size())); - } - } - - public void testInitialRateLimitsCalculation_Correct() throws Exception { - // Start with max nodes per grouping (=3) - int numNodes = DEFAULT_MAX_NODES_PER_GROUPING; - var nodeNames = internalCluster().startNodes(numNodes); - ensureStableCluster(numNodes); - - var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst()); - waitForRateLimitingAssignments(calculator); - - Set supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet(); - - for (var serviceName : supportedServices) { - try (var serviceRegistry = calculator.serviceRegistry()) { - var serviceOptional = serviceRegistry.getService(serviceName); - assertTrue(serviceOptional.isPresent()); - var service = serviceOptional.get(); - - if ((service instanceof SenderService senderService)) { - var sender = senderService.getSender(); - if (sender instanceof HttpRequestSender) { - var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING); - - assertNotNull(assignment); - assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size())); - } - } - } - - } - } - - public void testRateLimits_Decrease_OnNodeJoin() throws Exception { - // Start with 2 nodes - var initialNodes = 2; - var nodeNames = internalCluster().startNodes(initialNodes); - ensureStableCluster(initialNodes); - - var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst()); - waitForRateLimitingAssignments(calculator); - - for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { - var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName); - for (var config : configs) { - // Get initial assignments and rate limits - var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); - assertEquals(2, initialAssignment.responsibleNodes().size()); - - // Add a new node - internalCluster().startNode(); - ensureStableCluster(initialNodes + 1); - waitForRateLimitingAssignments(calculator); - - // Get updated assignments - var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); - - // Verify number of responsible nodes increased - assertEquals(3, updatedAssignment.responsibleNodes().size()); - } - } - } - - public void testRateLimits_Increase_OnNodeLeave() throws Exception { - // Start with max nodes per grouping (=3) - int numNodes = DEFAULT_MAX_NODES_PER_GROUPING; - var nodeNames = internalCluster().startNodes(numNodes); - ensureStableCluster(numNodes); - - var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst()); - waitForRateLimitingAssignments(calculator); - - for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) { - var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName); - for (var config : configs) { - // Get initial assignments and rate limits - var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); - assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(initialAssignment.responsibleNodes().size())); - - // Remove a node - var nodeToRemove = nodeNames.get(numNodes - 1); - internalCluster().stopNode(nodeToRemove); - ensureStableCluster(numNodes - 1); - waitForRateLimitingAssignments(calculator); - - // Get updated assignments - var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType()); - - // Verify number of responsible nodes decreased - assertThat(2, equalTo(updatedAssignment.responsibleNodes().size())); - } - } - } - - @Override - protected Collection> nodePlugins() { - return Arrays.asList(LocalStateInferencePlugin.class); - } - - private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(InternalTestCluster internalTestCluster, String nodeName) { - InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster.getInstance( - InferenceServiceRateLimitCalculator.class, - nodeName - ); - assertThat( - "[" - + InferenceServiceNodeLocalRateLimitCalculatorTests.class.getName() - + "] should use [" - + InferenceServiceNodeLocalRateLimitCalculator.class.getName() - + "] as implementation for [" - + InferenceServiceRateLimitCalculator.class.getName() - + "]. Provided implementation was [" - + calculatorInstance.getClass().getName() - + "].", - calculatorInstance, - instanceOf(InferenceServiceNodeLocalRateLimitCalculator.class) - ); - return (InferenceServiceNodeLocalRateLimitCalculator) calculatorInstance; - } - - private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception { - assertBusy(() -> { - var assignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING); - assertNotNull(assignment); - assertFalse(assignment.responsibleNodes().isEmpty()); - }, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS, TimeUnit.SECONDS); - } -} From 93f0f19e1d53187ecd091af0712f7199550128e5 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 17 Sep 2025 09:52:24 +0100 Subject: [PATCH 3/5] Remove rate limit divisor --- .../external/http/RequestExecutor.java | 2 -- .../http/sender/HttpRequestSender.java | 4 ---- .../http/sender/RequestExecutorService.java | 23 ------------------- .../external/http/sender/Sender.java | 2 -- .../client/AmazonBedrockRequestSender.java | 5 ---- .../AmazonBedrockMockRequestSender.java | 5 ---- 6 files changed, 41 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java index 6c7c6e0d114c7..63c042ce8a623 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java @@ -21,8 +21,6 @@ public interface RequestExecutor { void shutdown(); - void updateRateLimitDivisor(int newDivisor); - boolean isShutdown(); boolean isTerminated(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index f870f997153a4..0c551f67cc531 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -113,10 +113,6 @@ public void start() { } } - public void updateRateLimitDivisor(int rateLimitDivisor) { - service.updateRateLimitDivisor(rateLimitDivisor); - } - private void waitForStartToComplete() { try { if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 3e2d0c5d0171e..7138cd30aa4d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -162,19 +162,6 @@ public int queueSize() { return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } - @Override - public void updateRateLimitDivisor(int numResponsibleNodes) { - // in the unlikely case where we get an invalid value, we'll just ignore it - if (numResponsibleNodes <= 0) { - return; - } - - rateLimitDivisor.set(numResponsibleNodes); - for (var rateLimitingEndpointHandler : rateLimitGroupings.values()) { - rateLimitingEndpointHandler.updateTokensPerTimeUnit(rateLimitDivisor.get()); - } - } - /** * Begin servicing tasks. *

@@ -393,22 +380,12 @@ static class RateLimitingEndpointHandler { rateLimitSettings.requestsPerTimeUnit(), rateLimitSettings.timeUnit() ); - - this.updateTokensPerTimeUnit(rateLimitDivisor); } public void init() { requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange); } - /** - * @param divisor - divisor to divide the initial requests per time unit by - */ - private synchronized void updateTokensPerTimeUnit(Integer divisor) { - double updatedTokensPerTimeUnit = (double) originalRequestsPerTimeUnit / divisor; - rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, updatedTokensPerTimeUnit, rateLimitSettings.timeUnit()); - } - public String id() { return id; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java index fed92263f9999..3975a554586b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java @@ -27,8 +27,6 @@ void send( ActionListener listener ); - void updateRateLimitDivisor(int rateLimitDivisor); - void sendWithoutQueuing( Logger logger, Request request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java index ccc52087288d0..3f3b0db571bae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java @@ -97,11 +97,6 @@ protected AmazonBedrockRequestSender( this.startCompleted = Objects.requireNonNull(startCompleted); } - @Override - public void updateRateLimitDivisor(int rateLimitDivisor) { - executorService.updateRateLimitDivisor(rateLimitDivisor); - } - @Override public void start() { if (started.compareAndSet(false, true)) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 797d50878a0b7..ac82367af6865 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -70,11 +70,6 @@ public void start() { // do nothing } - @Override - public void updateRateLimitDivisor(int rateLimitDivisor) { - // do nothing - } - @Override public void send( RequestManager requestCreator, From ab7e442d00d18882411e56cd48e8948bbe3a80da Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 17 Sep 2025 09:52:38 +0100 Subject: [PATCH 4/5] remove routing stats --- .../org/elasticsearch/TransportVersions.java | 1 + .../inference/telemetry/InferenceStats.java | 4 -- .../action/BaseInferenceActionRequest.java | 28 +++-------- .../action/InferenceActionRequestTests.java | 49 +++++++++++++------ .../UnifiedCompletionActionRequestTests.java | 19 ------- .../action/BaseTransportInferenceAction.java | 3 -- .../BaseTransportInferenceActionTestCase.java | 18 ------- 7 files changed, 42 insertions(+), 80 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 36a6e27dd2a3f..69d30b04d9120 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -325,6 +325,7 @@ static TransportVersion def(int id) { public static final TransportVersion MAX_HEAP_SIZE_PER_NODE_IN_CLUSTER_INFO = def(9_159_0_00); public static final TransportVersion TIMESERIES_DEFAULT_LIMIT = def(9_160_0_00); public static final TransportVersion INFERENCE_API_OPENAI_HEADERS = def(9_161_0_00); + public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED = def(9_162_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java index e73b1ad9c5ff6..53e1a103d7722 100644 --- a/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java +++ b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java @@ -61,10 +61,6 @@ public static Map modelAttributes(Model model) { return modelAttributesMap; } - public static Map routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) { - return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest); - } - public static Map modelAttributes(UnparsedModel model) { return Map.of("service", model.service(), "task_type", model.taskType().toString()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java index 799f75cfd0527..48c9b235769e9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -23,8 +23,6 @@ */ public abstract class BaseInferenceActionRequest extends LegacyActionRequest { - private boolean hasBeenRerouted; - private final InferenceContext context; public BaseInferenceActionRequest(InferenceContext context) { @@ -34,12 +32,9 @@ public BaseInferenceActionRequest(InferenceContext context) { public BaseInferenceActionRequest(StreamInput in) throws IOException { super(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { - this.hasBeenRerouted = in.readBoolean(); - } else { - // For backwards compatibility, we treat all inference requests coming from ES nodes having - // a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior. - this.hasBeenRerouted = true; + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING) && + in.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) { + in.readBoolean(); } if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT) @@ -56,14 +51,6 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException { public abstract String getInferenceEntityId(); - public void setHasBeenRerouted(boolean hasBeenRerouted) { - this.hasBeenRerouted = hasBeenRerouted; - } - - public boolean hasBeenRerouted() { - return hasBeenRerouted; - } - public InferenceContext getContext() { return context; } @@ -71,8 +58,9 @@ public InferenceContext getContext() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { - out.writeBoolean(hasBeenRerouted); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING) + && out.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) { + out.writeBoolean(true); } if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT) @@ -86,11 +74,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; BaseInferenceActionRequest that = (BaseInferenceActionRequest) o; - return hasBeenRerouted == that.hasBeenRerouted && Objects.equals(context, that.context); + return Objects.equals(context, that.context); } @Override public int hashCode() { - return Objects.hash(hasBeenRerouted, context); + return Objects.hash(context); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 2e2b9bf9b0d23..e6d9a97e0b8c1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -691,13 +691,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque mutated = instance; } - // We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting - if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { - mutated.setHasBeenRerouted(true); - } else { - mutated.setHasBeenRerouted(instance.hasBeenRerouted()); - } - return mutated; } @@ -749,7 +742,7 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED)); } - public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException { + public void testWriteTo_ForHasBeenReroutedChanges() throws IOException { var instance = new InferenceAction.Request( TaskType.TEXT_EMBEDDING, "model", @@ -763,15 +756,39 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen false ); - InferenceAction.Request deserializedInstance = copyWriteable( - instance, - getNamedWriteableRegistry(), - instanceReader(), - TransportVersions.V_8_13_0 - ); + { + // From a version before the rerouting logic was added + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.V_8_17_0 + ); + + assertEquals(instance, deserializedInstance); + } + { + // From a version with rerouting + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING + ); - // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version - assertTrue(deserializedInstance.hasBeenRerouted()); + assertEquals(instance, deserializedInstance); + } + { + // From a version with rerouting removed + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED + ); + + assertEquals(instance, deserializedInstance); + } } public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java index d050c2fe57bd8..ac5afeb73eb53 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -68,25 +68,6 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() { assertNull(request.validate()); } - public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException { - var instance = new UnifiedCompletionAction.Request( - "model", - TaskType.ANY, - UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), - TimeValue.timeValueSeconds(10) - ); - - UnifiedCompletionAction.Request deserializedInstance = copyWriteable( - instance, - getNamedWriteableRegistry(), - instanceReader(), - TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION - ); - - // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version - assertTrue(deserializedInstance.hasBeenRerouted()); - } - public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException { var instance = new UnifiedCompletionAction.Request( "model", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 8ffda54ddc085..e943e03da07a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -52,7 +52,6 @@ import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; -import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; /** @@ -272,7 +271,6 @@ protected Flow.Publisher streamErrorHandler(Flow.Publisher upstream) { private void recordRequestCountMetrics(Model model, Request request, String localNodeId) { Map requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); inferenceStats.requestCount().incrementBy(1, requestCountAttributes); } @@ -286,7 +284,6 @@ private void recordRequestDurationMetrics( ) { Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t))); - metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index da2e8589c4525..cdde8c64eb537 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -64,7 +64,6 @@ public abstract class BaseTransportInferenceActionTestCase createAction( @@ -237,8 +235,6 @@ public void testMetricsAfterInferError() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), nullValue()); assertThat(attributes.get("error.type"), is(expectedError)); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -261,8 +257,6 @@ public void testMetricsAfterStreamUnsupported() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); assertThat(attributes.get("error.type"), is(expectedError)); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -278,8 +272,6 @@ public void testMetricsAfterInferSuccess() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), is(200)); assertThat(attributes.get("error.type"), nullValue()); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -291,8 +283,6 @@ public void testMetricsAfterStreamInferSuccess() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), is(200)); assertThat(attributes.get("error.type"), nullValue()); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -306,8 +296,6 @@ public void testMetricsAfterStreamInferFailure() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), nullValue()); assertThat(attributes.get("error.type"), is(expectedError)); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -341,8 +329,6 @@ public void onComplete() { assertThat(attributes.get("model_id"), nullValue()); assertThat(attributes.get("status_code"), is(200)); assertThat(attributes.get("error.type"), nullValue()); - assertThat(attributes.get("rerouted"), is(Boolean.FALSE)); - assertThat(attributes.get("node_id"), is(localNodeId)); })); } @@ -445,8 +431,4 @@ protected Model mockModel(TaskType expectedTaskType) { protected void mockValidLicenseState() { when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); } - - private void mockNodeClient() { - when(nodeClient.getLocalNodeId()).thenReturn(localNodeId); - } } From 25dd25a78a1de3b8c7c6a4d23087902bcf904d90 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 17 Sep 2025 08:59:39 +0000 Subject: [PATCH 5/5] [CI] Auto commit changes from spotless --- .../core/inference/action/BaseInferenceActionRequest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java index 48c9b235769e9..7b663d5731dd2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -32,8 +32,8 @@ public BaseInferenceActionRequest(InferenceContext context) { public BaseInferenceActionRequest(StreamInput in) throws IOException { super(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING) && - in.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING) + && in.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) { in.readBoolean(); }