Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
Expand Down Expand Up @@ -160,7 +157,6 @@

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

public class InferencePlugin extends Plugin
implements
Expand Down Expand Up @@ -382,19 +378,6 @@ public Collection<?> createComponents(PluginServices services) {
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
);
components.add(inferenceStatsBinding);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
// if the rate limiting feature flags are enabled, otherwise provide noop implementation
InferenceServiceRateLimitCalculator calculator;
if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG) {
calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry);
} else {
calculator = new NoopNodeLocalRateLimitCalculator();
}

// Add binding for interface -> implementation
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));

components.add(
new InferenceEndpointRegistry(
services.clusterService(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
Expand All @@ -31,25 +30,19 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -64,10 +57,6 @@

/**
* Base class for transport actions that handle inference requests.
* Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to
* route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail
* in {@link InferenceServiceNodeLocalRateLimitCalculator}.
*
* @param <Request> The specific type of inference request being handled
*/
public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest> extends HandledTransportAction<
Expand All @@ -82,7 +71,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;
private final StreamingTaskManager streamingTaskManager;
private final InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
private final NodeClient nodeClient;
private final ThreadPool threadPool;
private final TransportService transportService;
Expand All @@ -98,7 +86,6 @@ public BaseTransportInferenceAction(
InferenceStats inferenceStats,
StreamingTaskManager streamingTaskManager,
Writeable.Reader<Request> requestReader,
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
NodeClient nodeClient,
ThreadPool threadPool
) {
Expand All @@ -108,7 +95,6 @@ public BaseTransportInferenceAction(
this.serviceRegistry = serviceRegistry;
this.inferenceStats = inferenceStats;
this.streamingTaskManager = streamingTaskManager;
this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator;
this.nodeClient = nodeClient;
this.threadPool = threadPool;
this.transportService = transportService;
Expand Down Expand Up @@ -161,15 +147,8 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct

var service = serviceRegistry.getService(serviceName).get();
var localNodeId = nodeClient.getLocalNodeId();
var routingDecision = determineRouting(serviceName, request, model.getTaskType(), localNodeId);
inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);

if (routingDecision.currentNodeShouldHandleRequest()) {
inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);
} else {
// Reroute request
request.setHasBeenRerouted(true);
rerouteRequest(request, listener, routingDecision.targetNode);
}
}, e -> {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
Expand All @@ -195,73 +174,12 @@ private void validateRequest(Request request, Model model) {
validationHelper(() -> isInvalidTaskTypeForInferenceEndpoint(request, model), () -> createInvalidTaskTypeException(request, model));
}

private NodeRoutingDecision determineRouting(String serviceName, Request request, TaskType modelTaskType, String localNodeId) {
// Rerouting not supported or request was already rerouted
if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false
|| request.hasBeenRerouted()) {
return NodeRoutingDecision.handleLocally();
}

var rateLimitAssignment = inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType);

// No assignment yet
if (rateLimitAssignment == null) {
return NodeRoutingDecision.handleLocally();
}

var responsibleNodes = rateLimitAssignment.responsibleNodes();

// Empty assignment
if (responsibleNodes == null || responsibleNodes.isEmpty()) {
return NodeRoutingDecision.handleLocally();
}

var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));

// The drawn node is the current node
if (nodeToHandleRequest.getId().equals(localNodeId)) {
return NodeRoutingDecision.handleLocally();
}

// Reroute request
return NodeRoutingDecision.routeTo(nodeToHandleRequest);
}

private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
if (validationFailure.get()) {
throw exceptionCreator.get();
}
}

private void rerouteRequest(Request request, ActionListener<InferenceAction.Response> listener, DiscoveryNode nodeToHandleRequest) {
transportService.sendRequest(
nodeToHandleRequest,
InferenceAction.NAME,
request,
new TransportResponseHandler<InferenceAction.Response>() {
@Override
public Executor executor() {
return threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
}

@Override
public void handleResponse(InferenceAction.Response response) {
listener.onResponse(response);
}

@Override
public void handleException(TransportException exp) {
listener.onFailure(exp);
}

@Override
public InferenceAction.Response read(StreamInput in) throws IOException {
return new InferenceAction.Response(in);
}
}
);
}

private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;

public class TransportInferenceAction extends BaseTransportInferenceAction<InferenceAction.Request> {
Expand All @@ -36,7 +35,6 @@ public TransportInferenceAction(
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
StreamingTaskManager streamingTaskManager,
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
NodeClient nodeClient,
ThreadPool threadPool
) {
Expand All @@ -50,7 +48,6 @@ public TransportInferenceAction(
inferenceStats,
streamingTaskManager,
InferenceAction.Request::new,
inferenceServiceNodeLocalRateLimitCalculator,
nodeClient,
threadPool
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;

import java.util.concurrent.Flow;
Expand All @@ -43,7 +42,6 @@ public TransportUnifiedCompletionInferenceAction(
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
StreamingTaskManager streamingTaskManager,
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
NodeClient nodeClient,
ThreadPool threadPool
) {
Expand All @@ -57,7 +55,6 @@ public TransportUnifiedCompletionInferenceAction(
inferenceStats,
streamingTaskManager,
UnifiedCompletionAction.Request::new,
inferenceServiceNodeLocalRateLimitCalculator,
nodeClient,
threadPool
);
Expand Down

This file was deleted.

Loading