Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);
public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00);
public static final TransportVersion INDEX_SOURCE = def(9_158_0_00);
public static final TransportVersion ESQL_CHAT_COMPLETION_SUPPORT = def(9_159_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}

return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
return new Completion(p.source(), p.child(), p.inferenceId(), p.taskType(), prompt, targetField);
}

private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
Expand Down Expand Up @@ -1349,20 +1349,7 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext
return plan.withInferenceResolutionError(inferenceId, error);
}

if (resolvedInference.taskType() != plan.taskType()) {
String error = "cannot use inference endpoint ["
+ inferenceId
+ "] with task type ["
+ resolvedInference.taskType()
+ "] within a "
+ plan.nodeName()
+ " command. Only inference endpoints with the task type ["
+ plan.taskType()
+ "] are supported.";
return plan.withInferenceResolutionError(inferenceId, error);
}

return plan;
return plan.withResolvedInference(resolvedInference);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference.bulk;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;

import java.util.Objects;

/**
 * A single request within a bulk inference operation, pairing an inference action
 * request with an optional sequence number used to order the corresponding responses.
 *
 * <p>The hierarchy is sealed: concrete items are either {@link InferenceRequestItem}
 * (plain inference requests) or {@link ChatCompletionRequestItem} (chat-completion
 * requests). Items are immutable; {@link #withSeqNo(long)} returns a copy.</p>
 *
 * @param <T> the concrete inference action request type wrapped by this item
 */
public sealed interface BulkInferenceRequestItem<T extends BaseInferenceActionRequest> permits
    BulkInferenceRequestItem.AbstractBulkInferenceRequestItem {

    /** Returns the inference task type of this item (drives how the request is dispatched). */
    TaskType taskType();

    /** Returns the wrapped inference request to execute. */
    T inferenceRequest();

    /**
     * Returns a copy of this item carrying the given sequence number.
     * The receiver is left unchanged.
     */
    BulkInferenceRequestItem<T> withSeqNo(long seqNo);

    /** Returns the sequence number assigned to this item, or {@code null} if none has been assigned yet. */
    Long seqNo();

    /** Wraps a plain inference request into a bulk item with no sequence number assigned. */
    static InferenceRequestItem from(InferenceAction.Request request) {
        return new InferenceRequestItem(request);
    }

    /** Wraps a unified (chat) completion request into a bulk item with no sequence number assigned. */
    static ChatCompletionRequestItem from(UnifiedCompletionAction.Request request) {
        return new ChatCompletionRequestItem(request);
    }

    /**
     * Shared state and behavior for bulk request items: holds the wrapped request and
     * the (nullable) sequence number, and implements equality over both.
     */
    abstract sealed class AbstractBulkInferenceRequestItem<T extends BaseInferenceActionRequest> implements BulkInferenceRequestItem<T>
        permits InferenceRequestItem, ChatCompletionRequestItem {
        private final T request;
        private final Long seqNo;

        protected AbstractBulkInferenceRequestItem(T request) {
            this(request, null);
        }

        protected AbstractBulkInferenceRequestItem(T request, Long seqNo) {
            this.request = request;
            this.seqNo = seqNo;
        }

        @Override
        public T inferenceRequest() {
            return request;
        }

        @Override
        public Long seqNo() {
            return seqNo;
        }

        @Override
        public boolean equals(Object o) {
            // Identity fast path (standard equals contract idiom) before the class/field checks.
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            AbstractBulkInferenceRequestItem<?> that = (AbstractBulkInferenceRequestItem<?>) o;
            return Objects.equals(request, that.request) && Objects.equals(seqNo, that.seqNo);
        }

        @Override
        public int hashCode() {
            return Objects.hash(request, seqNo);
        }

        @Override
        public TaskType taskType() {
            // Default: derive the task type from the wrapped request.
            return request.getTaskType();
        }
    }

    /** Bulk item wrapping a plain {@link InferenceAction.Request}. */
    final class InferenceRequestItem extends AbstractBulkInferenceRequestItem<InferenceAction.Request> {
        private InferenceRequestItem(InferenceAction.Request request) {
            this(request, null);
        }

        private InferenceRequestItem(InferenceAction.Request request, Long seqNo) {
            super(request, seqNo);
        }

        @Override
        public InferenceRequestItem withSeqNo(long seqNo) {
            return new InferenceRequestItem(inferenceRequest(), seqNo);
        }
    }

    /** Bulk item wrapping a {@link UnifiedCompletionAction.Request} for streaming chat completion. */
    final class ChatCompletionRequestItem extends AbstractBulkInferenceRequestItem<UnifiedCompletionAction.Request> {

        private ChatCompletionRequestItem(UnifiedCompletionAction.Request request) {
            this(request, null);
        }

        @Override
        public TaskType taskType() {
            // Always CHAT_COMPLETION, regardless of what the wrapped request reports.
            // NOTE(review): presumably a normalization in case the request's own task type differs — confirm.
            return TaskType.CHAT_COMPLETION;
        }

        private ChatCompletionRequestItem(UnifiedCompletionAction.Request request, Long seqNo) {
            super(request, seqNo);
        }

        @Override
        public ChatCompletionRequestItem withSeqNo(long seqNo) {
            return new ChatCompletionRequestItem(inferenceRequest(), seqNo);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@
package org.elasticsearch.xpack.esql.inference.bulk;

import org.elasticsearch.core.Releasable;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;

import java.util.Iterator;

public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request>, Releasable {
/**
 * Iterates over the {@link BulkInferenceRequestItem} instances that make up a bulk
 * inference operation.
 *
 * <p>Extends {@code Releasable}, so implementations are expected to be closed by the
 * caller when iteration is finished — NOTE(review): presumably to free any underlying
 * input pages/resources; confirm ownership with implementations.</p>
 */
public interface BulkInferenceRequestIterator extends Iterator<BulkInferenceRequestItem<?>>, Releasable {

    /**
     * Returns an estimate of the number of requests that will be produced.
     *
     * <p>This is typically used to pre-allocate buffers or output to the appropriate size.</p>
     */
    int estimatedSize();

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -175,12 +177,12 @@ private class BulkInferenceRequest {
* to the request iterator.
* </p>
*
* @return A BulkRequestItem if a request and permit are available, null otherwise
* @return A BulkInferenceRequestItem if a request and permit are available, null otherwise
*/
private BulkRequestItem pollPendingRequest() {
private BulkInferenceRequestItem<?> pollPendingRequest() {
synchronized (requests) {
if (requests.hasNext()) {
return new BulkRequestItem(executionState.generateSeqNo(), requests.next());
return requests.next().withSeqNo(executionState.generateSeqNo());
}
}

Expand Down Expand Up @@ -226,22 +228,22 @@ private void executePendingRequests(int recursionDepth) {
}
return;
} else {
BulkRequestItem bulkRequestItem = pollPendingRequest();
BulkInferenceRequestItem<?> bulkRequestItem = pollPendingRequest();

if (bulkRequestItem == null) {
// No more requests available
// Release the permit we didn't use and stop processing
permits.release();

// Check if another bulk request is pending for execution.
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();

while (nexBulkRequest == this) {
nexBulkRequest = pendingBulkRequests.poll();
while (nextBulkRequest == this) {
nextBulkRequest = pendingBulkRequests.poll();
}

if (nexBulkRequest != null) {
executor.execute(nexBulkRequest::executePendingRequests);
if (nextBulkRequest != null) {
executor.execute(nextBulkRequest::executePendingRequests);
}

return;
Expand Down Expand Up @@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) {
// Response has already been sent
// No need to continue processing this bulk.
// Check if another bulk request is pending for execution.
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
if (nexBulkRequest != null) {
executor.execute(nexBulkRequest::executePendingRequests);
BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
if (nextBulkRequest != null) {
executor.execute(nextBulkRequest::executePendingRequests);
}
return;
}
Expand All @@ -298,26 +300,57 @@ private void executePendingRequests(int recursionDepth) {
);

// Handle null requests (edge case in some iterators)
if (bulkRequestItem.request() == null) {
if (bulkRequestItem.inferenceRequest() == null) {
inferenceResponseListener.onResponse(null);
return;
}

// Execute the inference request with proper origin context
executeAsyncWithOrigin(
client,
INFERENCE_ORIGIN,
InferenceAction.INSTANCE,
bulkRequestItem.request(),
inferenceResponseListener
);
if (bulkRequestItem.taskType() == TaskType.CHAT_COMPLETION) {
handleStreamingRequest(
(UnifiedCompletionAction.Request) bulkRequestItem.inferenceRequest(),
inferenceResponseListener
);
} else {
executeAsyncWithOrigin(
client,
INFERENCE_ORIGIN,
InferenceAction.INSTANCE,
bulkRequestItem.inferenceRequest(),
inferenceResponseListener
);
}
}
}
} catch (Exception e) {
executionState.addFailure(e);
}
}

/**
 * Executes a chat-completion request and bridges its streaming response back to a
 * standard inference response listener.
 *
 * <p>The {@code UnifiedCompletionAction} is invoked under the inference origin; once a
 * streaming response arrives, a {@link BulkInferenceStreamingHandler} is subscribed to
 * its publisher so the stream is consumed and the final result is delivered to
 * {@code listener}. Failures are propagated to {@code listener} unchanged.</p>
 *
 * @param request  the chat-completion request to execute
 * @param listener receives the aggregated inference response (or the failure)
 */
private void handleStreamingRequest(UnifiedCompletionAction.Request request, ActionListener<InferenceAction.Response> listener) {
    executeAsyncWithOrigin(
        client,
        INFERENCE_ORIGIN,
        UnifiedCompletionAction.INSTANCE,
        request,
        listener.delegateFailureAndWrap((completionListener, streamingResponse) -> {
            var streamingHandler = new BulkInferenceStreamingHandler(completionListener);
            streamingResponse.publisher().subscribe(streamingHandler);
        })
    );
}

/**
* Processes and delivers buffered responses in order, ensuring proper sequencing.
* <p>
Expand Down Expand Up @@ -360,20 +393,6 @@ private void onBulkCompletion() {
}
}

/**
* Encapsulates an inference request with its associated sequence number.
* <p>
* The sequence number is used for ordering responses and tracking completion
* in the bulk execution state.
* </p>
*
* @param seqNo Unique sequence number for this request in the bulk operation
* @param request The actual inference request to execute
*/
private record BulkRequestItem(long seqNo, InferenceAction.Request request) {

}

public static Factory factory(Client client) {
return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests());
}
Expand Down
Loading