Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ static TransportVersion def(int id) {
public static final TransportVersion MAX_HEAP_SIZE_PER_NODE_IN_CLUSTER_INFO = def(9_159_0_00);
public static final TransportVersion TIMESERIES_DEFAULT_LIMIT = def(9_160_0_00);
public static final TransportVersion INFERENCE_API_OPENAI_HEADERS = def(9_161_0_00);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED = def(9_162_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ public static Map<String, Object> modelAttributes(Model model) {
return modelAttributesMap;
}

/**
 * Builds the immutable telemetry attribute map describing how an inference request was routed.
 *
 * @param hasBeenRerouted       whether the request was already forwarded to another node
 * @param nodeIdHandlingRequest id of the node that ends up handling the request (must not be null)
 * @return an unmodifiable map with the {@code rerouted} and {@code node_id} attributes
 */
public static Map<String, Object> routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) {
    return Map.ofEntries(Map.entry("rerouted", hasBeenRerouted), Map.entry("node_id", nodeIdHandlingRequest));
}

/**
 * Builds the immutable telemetry attribute map for a model whose configuration is still unparsed.
 *
 * @param model the unparsed model to describe
 * @return an unmodifiable map with the {@code service} and {@code task_type} attributes
 */
public static Map<String, Object> modelAttributes(UnparsedModel model) {
    String service = model.service();
    String taskType = model.taskType().toString();
    return Map.of("service", service, "task_type", taskType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
*/
public abstract class BaseInferenceActionRequest extends LegacyActionRequest {

private boolean hasBeenRerouted;

private final InferenceContext context;

public BaseInferenceActionRequest(InferenceContext context) {
Expand All @@ -34,12 +32,9 @@ public BaseInferenceActionRequest(InferenceContext context) {

public BaseInferenceActionRequest(StreamInput in) throws IOException {
super(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
this.hasBeenRerouted = in.readBoolean();
} else {
// For backwards compatibility, we treat all inference requests coming from ES nodes having
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
this.hasBeenRerouted = true;
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)
&& in.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) {
in.readBoolean();
}

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
Expand All @@ -56,23 +51,16 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {

public abstract String getInferenceEntityId();

/**
 * Records whether this request has already been rerouted for node-local rate limiting.
 *
 * @param hasBeenRerouted {@code true} if the request was already forwarded to the node
 *                        responsible for handling it
 */
public void setHasBeenRerouted(boolean hasBeenRerouted) {
    this.hasBeenRerouted = hasBeenRerouted;
}

/**
 * @return whether this request has already been rerouted for node-local rate limiting;
 *         requests deserialized from pre-adaptive-rate-limiting nodes report {@code true}
 *         to preserve the pre-node-local-rate-limiting behavior
 */
public boolean hasBeenRerouted() {
    return hasBeenRerouted;
}

/**
 * @return the {@link InferenceContext} carried by this request
 */
public InferenceContext getContext() {
    return context;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
out.writeBoolean(hasBeenRerouted);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)
&& out.getTransportVersion().before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED)) {
out.writeBoolean(true);
}

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
Expand All @@ -86,11 +74,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BaseInferenceActionRequest that = (BaseInferenceActionRequest) o;
return hasBeenRerouted == that.hasBeenRerouted && Objects.equals(context, that.context);
return Objects.equals(context, that.context);
}

@Override
public int hashCode() {
return Objects.hash(hasBeenRerouted, context);
return Objects.hash(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -691,13 +691,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
mutated = instance;
}

// We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting
if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
mutated.setHasBeenRerouted(true);
} else {
mutated.setHasBeenRerouted(instance.hasBeenRerouted());
}

return mutated;
}

Expand Down Expand Up @@ -749,7 +742,7 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
}

public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
public void testWriteTo_ForHasBeenReroutedChanges() throws IOException {
var instance = new InferenceAction.Request(
TaskType.TEXT_EMBEDDING,
"model",
Expand All @@ -763,15 +756,39 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen
false
);

InferenceAction.Request deserializedInstance = copyWriteable(
instance,
getNamedWriteableRegistry(),
instanceReader(),
TransportVersions.V_8_13_0
);
{
// From a version before the rerouting logic was added
InferenceAction.Request deserializedInstance = copyWriteable(
instance,
getNamedWriteableRegistry(),
instanceReader(),
TransportVersions.V_8_17_0
);

assertEquals(instance, deserializedInstance);
}
{
// From a version with rerouting
InferenceAction.Request deserializedInstance = copyWriteable(
instance,
getNamedWriteableRegistry(),
instanceReader(),
TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING
);

// Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
assertTrue(deserializedInstance.hasBeenRerouted());
assertEquals(instance, deserializedInstance);
}
{
// From a version with rerouting removed
InferenceAction.Request deserializedInstance = copyWriteable(
instance,
getNamedWriteableRegistry(),
instanceReader(),
TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED
);

assertEquals(instance, deserializedInstance);
}
}

public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,6 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() {
assertNull(request.validate());
}

public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
    var request = new UnifiedCompletionAction.Request(
        "model",
        TaskType.ANY,
        UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
        TimeValue.timeValueSeconds(10)
    );

    // Round-trip the request through a transport version that predates adaptive rate limiting.
    var roundTripped = copyWriteable(
        request,
        getNamedWriteableRegistry(),
        instanceReader(),
        TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION
    );

    // Requests arriving from such older nodes must be treated as already rerouted.
    assertTrue(roundTripped.hasBeenRerouted());
}

public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException {
var instance = new UnifiedCompletionAction.Request(
"model",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
Expand Down Expand Up @@ -160,7 +157,6 @@

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

public class InferencePlugin extends Plugin
implements
Expand Down Expand Up @@ -382,19 +378,6 @@ public Collection<?> createComponents(PluginServices services) {
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
);
components.add(inferenceStatsBinding);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
// if the rate limiting feature flags are enabled, otherwise provide noop implementation
InferenceServiceRateLimitCalculator calculator;
if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG) {
calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry);
} else {
calculator = new NoopNodeLocalRateLimitCalculator();
}

// Add binding for interface -> implementation
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));

components.add(
new InferenceEndpointRegistry(
services.clusterService(),
Expand Down
Loading