diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 10114e429cc7a..09759cadaea70 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -324,6 +324,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);
public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00);
public static final TransportVersion INDEX_SOURCE = def(9_158_0_00);
+ public static final TransportVersion ESQL_CHAT_COMPLETION_SUPPORT = def(9_159_0_00);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 1f51055094c92..c8d2de3d1ad7c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -623,7 +623,7 @@ private LogicalPlan resolveCompletion(Completion p, List childrenOutp
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}
- return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
+ return new Completion(p.source(), p.child(), p.inferenceId(), p.taskType(), prompt, targetField);
}
private LogicalPlan resolveMvExpand(MvExpand p, List childrenOutput) {
@@ -1349,20 +1349,7 @@ private LogicalPlan resolveInferencePlan(InferencePlan> plan, AnalyzerContext
return plan.withInferenceResolutionError(inferenceId, error);
}
- if (resolvedInference.taskType() != plan.taskType()) {
- String error = "cannot use inference endpoint ["
- + inferenceId
- + "] with task type ["
- + resolvedInference.taskType()
- + "] within a "
- + plan.nodeName()
- + " command. Only inference endpoints with the task type ["
- + plan.taskType()
- + "] are supported.";
- return plan.withInferenceResolutionError(inferenceId, error);
- }
-
- return plan;
+ return plan.withResolvedInference(resolvedInference);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
new file mode 100644
index 0000000000000..20b1bc24143da
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.bulk;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
+
+import java.util.Objects;
+
+public sealed interface BulkInferenceRequestItem permits
+ BulkInferenceRequestItem.AbstractBulkInferenceRequestItem { // an inference request plus an optional sequence number
+
+ TaskType taskType();
+
+ T inferenceRequest(); // may be null: iterators wrap a null request for null inputs
+
+ BulkInferenceRequestItem withSeqNo(long seqNo); // returns a new item carrying seqNo; this item is left unchanged
+
+ Long seqNo(); // null until an item has been produced by withSeqNo
+
+ static InferenceRequestItem from(InferenceAction.Request request) {
+ return new InferenceRequestItem(request);
+ }
+
+ static ChatCompletionRequestItem from(UnifiedCompletionAction.Request request) {
+ return new ChatCompletionRequestItem(request);
+ }
+
+ abstract sealed class AbstractBulkInferenceRequestItem implements BulkInferenceRequestItem
+ permits InferenceRequestItem, ChatCompletionRequestItem {
+ private final T request;
+ private final Long seqNo;
+
+ protected AbstractBulkInferenceRequestItem(T request) {
+ this(request, null);
+ }
+
+ protected AbstractBulkInferenceRequestItem(T request, Long seqNo) {
+ this.request = request;
+ this.seqNo = seqNo;
+ }
+
+ @Override
+ public T inferenceRequest() {
+ return request;
+ }
+
+ @Override
+ public Long seqNo() {
+ return seqNo;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ AbstractBulkInferenceRequestItem> that = (AbstractBulkInferenceRequestItem>) o;
+ return Objects.equals(request, that.request) && Objects.equals(seqNo, that.seqNo);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(request, seqNo);
+ }
+
+ @Override
+ public TaskType taskType() {
+ return request.getTaskType(); // throws NullPointerException when the wrapped request is null
+ }
+ }
+
+ final class InferenceRequestItem extends AbstractBulkInferenceRequestItem {
+ private InferenceRequestItem(InferenceAction.Request request) {
+ this(request, null);
+ }
+
+ private InferenceRequestItem(InferenceAction.Request request, Long seqNo) {
+ super(request, seqNo);
+ }
+
+ @Override
+ public InferenceRequestItem withSeqNo(long seqNo) {
+ return new InferenceRequestItem(inferenceRequest(), seqNo);
+ }
+ }
+
+ final class ChatCompletionRequestItem extends AbstractBulkInferenceRequestItem {
+
+ private ChatCompletionRequestItem(UnifiedCompletionAction.Request request) {
+ this(request, null);
+ }
+
+ private ChatCompletionRequestItem(UnifiedCompletionAction.Request request, Long seqNo) {
+ super(request, seqNo);
+ }
+
+ @Override
+ public TaskType taskType() {
+ return TaskType.CHAT_COMPLETION; // constant: does not touch the (possibly null) wrapped request
+ }
+
+ @Override
+ public ChatCompletionRequestItem withSeqNo(long seqNo) {
+ return new ChatCompletionRequestItem(inferenceRequest(), seqNo);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
index 7327b182d0b6c..03739af92f7ff 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
@@ -8,17 +8,15 @@
package org.elasticsearch.xpack.esql.inference.bulk;
import org.elasticsearch.core.Releasable;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import java.util.Iterator;
-public interface BulkInferenceRequestIterator extends Iterator, Releasable {
+public interface BulkInferenceRequestIterator extends Iterator>, Releasable {
/**
* Returns an estimate of the number of requests that will be produced.
*
- * This is typically used to pre-allocate buffers or output to th appropriate size.
+ * This is typically used to pre-allocate buffers or output to the appropriate size.
*/
int estimatedSize();
-
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
index 203a3031bcad4..4d35411e3d3bb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
@@ -10,8 +10,10 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import java.util.ArrayList;
import java.util.List;
@@ -175,12 +177,12 @@ private class BulkInferenceRequest {
* to the request iterator.
*
*
- * @return A BulkRequestItem if a request and permit are available, null otherwise
+ * @return A BulkInferenceRequestItem if a request and permit are available, null otherwise
*/
- private BulkRequestItem pollPendingRequest() {
+ private BulkInferenceRequestItem> pollPendingRequest() {
synchronized (requests) {
if (requests.hasNext()) {
- return new BulkRequestItem(executionState.generateSeqNo(), requests.next());
+ return requests.next().withSeqNo(executionState.generateSeqNo());
}
}
@@ -226,7 +228,7 @@ private void executePendingRequests(int recursionDepth) {
}
return;
} else {
- BulkRequestItem bulkRequestItem = pollPendingRequest();
+ BulkInferenceRequestItem> bulkRequestItem = pollPendingRequest();
if (bulkRequestItem == null) {
// No more requests available
@@ -234,14 +236,14 @@ private void executePendingRequests(int recursionDepth) {
permits.release();
// Check if another bulk request is pending for execution.
- BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
+ BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
- while (nexBulkRequest == this) {
- nexBulkRequest = pendingBulkRequests.poll();
+ while (nextBulkRequest == this) {
+ nextBulkRequest = pendingBulkRequests.poll();
}
- if (nexBulkRequest != null) {
- executor.execute(nexBulkRequest::executePendingRequests);
+ if (nextBulkRequest != null) {
+ executor.execute(nextBulkRequest::executePendingRequests);
}
return;
@@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) {
// Response has already been sent
// No need to continue processing this bulk.
// Check if another bulk request is pending for execution.
- BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
- if (nexBulkRequest != null) {
- executor.execute(nexBulkRequest::executePendingRequests);
+ BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
+ if (nextBulkRequest != null) {
+ executor.execute(nextBulkRequest::executePendingRequests);
}
return;
}
@@ -298,19 +300,26 @@ private void executePendingRequests(int recursionDepth) {
);
// Handle null requests (edge case in some iterators)
- if (bulkRequestItem.request() == null) {
+ if (bulkRequestItem.inferenceRequest() == null) {
inferenceResponseListener.onResponse(null);
return;
}
// Execute the inference request with proper origin context
- executeAsyncWithOrigin(
- client,
- INFERENCE_ORIGIN,
- InferenceAction.INSTANCE,
- bulkRequestItem.request(),
- inferenceResponseListener
- );
+ if (bulkRequestItem.taskType() == TaskType.CHAT_COMPLETION) {
+ handleStreamingRequest(
+ (UnifiedCompletionAction.Request) bulkRequestItem.inferenceRequest(),
+ inferenceResponseListener
+ );
+ } else {
+ executeAsyncWithOrigin(
+ client,
+ INFERENCE_ORIGIN,
+ InferenceAction.INSTANCE,
+ bulkRequestItem.inferenceRequest(),
+ inferenceResponseListener
+ );
+ }
}
}
} catch (Exception e) {
@@ -318,6 +327,30 @@ private void executePendingRequests(int recursionDepth) {
}
}
+ /**
+ * Handles streaming inference requests for chat completion tasks.
+ *
+ * This method executes UnifiedCompletionAction requests and sets up proper streaming
+ * response handling through the BulkInferenceStreamingHandler. The streaming handler
+ * manages the asynchronous stream processing and ensures responses are properly
+ * delivered to the completion listener.
+ *
+ *
+ * @param request The UnifiedCompletionAction request to execute
+ * @param listener The listener to receive the final aggregated response, or the failure if the request fails
+ */
+ private void handleStreamingRequest(UnifiedCompletionAction.Request request, ActionListener listener) {
+ executeAsyncWithOrigin(
+ client,
+ INFERENCE_ORIGIN,
+ UnifiedCompletionAction.INSTANCE,
+ request,
+ listener.delegateFailureAndWrap((l, inferenceResponse) -> {
+ inferenceResponse.publisher().subscribe(new BulkInferenceStreamingHandler(l)); // handler aggregates chunks, then completes l
+ })
+ );
+ }
+
/**
* Processes and delivers buffered responses in order, ensuring proper sequencing.
*
@@ -360,20 +393,6 @@ private void onBulkCompletion() {
}
}
- /**
- * Encapsulates an inference request with its associated sequence number.
- *
- * The sequence number is used for ordering responses and tracking completion
- * in the bulk execution state.
- *
- *
- * @param seqNo Unique sequence number for this request in the bulk operation
- * @param request The actual inference request to execute
- */
- private record BulkRequestItem(long seqNo, InferenceAction.Request request) {
-
- }
-
public static Factory factory(Client client) {
return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests());
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java
new file mode 100644
index 0000000000000..ab693a0787a52
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.bulk;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+
+import java.util.List;
+import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Handles streaming inference responses for chat completion requests in bulk inference operations.
+ *
+ * This class implements the Reactive Streams {@link Flow.Subscriber} interface to process
+ * streaming inference results from chat completion services. It accumulates content from
+ * streaming chunks and delivers the final aggregated response to the completion listener.
+ *
+ * Ultimately, it constructs a {@link ChatCompletionResults} object containing the
+ * complete content from all streaming chunks received during the inference operation.
+ * By doing so, the result of chat_completion requests is made available in the same form as the legacy completion
+ * results and can be consumed in the same way.
+ *
+ */
+class BulkInferenceStreamingHandler implements Flow.Subscriber {
+
+ /**
+ * Set to {@code true} by the first terminal signal ({@code onError} or {@code onComplete}); guards against duplicate processing.
+ */
+ private final AtomicBoolean isLastPart = new AtomicBoolean(false);
+
+ /**
+ * The subscription handle for controlling the flow of streaming data.
+ */
+ private Flow.Subscription subscription;
+
+ /**
+ * Buffer for accumulating content from streaming chunks into the final response.
+ */
+ private final StringBuilder resultBuilder = new StringBuilder();
+
+ /**
+ * Listener to receive the final aggregated inference response.
+ */
+ private final ActionListener inferenceResponseListener;
+
+ /**
+ * Creates a new streaming handler for processing inference chat_completion responses.
+ *
+ * @param inferenceResponseListener The listener that will receive the final aggregated response
+ * once all streaming chunks have been processed
+ */
+ BulkInferenceStreamingHandler(ActionListener inferenceResponseListener) {
+ this.inferenceResponseListener = inferenceResponseListener;
+ }
+
+ /**
+ * Called when the streaming publisher is ready to start sending data.
+ *
+ * This method establishes the subscription and requests the first chunk of data.
+ * If the streaming session has already completed, it cancels the subscription
+ * to prevent resource leaks.
+ *
+ *
+ * @param subscription The subscription handle for controlling data flow
+ */
+ @Override
+ public void onSubscribe(Flow.Subscription subscription) {
+ if (isLastPart.get() == false) {
+ this.subscription = subscription;
+ subscription.request(1);
+ } else {
+ subscription.cancel();
+ }
+ }
+
+ /**
+ * Processes each streaming chunk as it arrives from the inference service.
+ *
+ * This method extracts content from streaming chat completion chunks and accumulates
+ * it in the result builder. It handles the specific structure of streaming unified
+ * chat completion results, extracting text content from delta objects within choices.
+ *
+ *
+ * After processing each chunk, it requests the next chunk from the subscription
+ * to continue the streaming process. Items received after a terminal signal are ignored.
+ *
+ *
+ * @param item The streaming result item containing chunk data from the inference service
+ */
+ @Override
+ public void onNext(InferenceServiceResults.Result item) {
+ if (isLastPart.get() == false) {
+ if (item instanceof StreamingUnifiedChatCompletionResults.Results streamingChunkResults) {
+ for (var chunk : streamingChunkResults.chunks()) {
+ for (var choice : chunk.choices()) {
+ if (choice.delta() != null && choice.delta().content() != null) {
+ resultBuilder.append(choice.delta().content());
+ }
+ }
+ }
+ subscription.request(1);
+ } else {
+ // Unexpected result types are skipped; request the next item so the stream keeps flowing
+ subscription.request(1);
+ }
+ }
+ }
+
+ /**
+ * Called when an error occurs during streaming processing.
+ *
+ * This method wraps the error in a {@link RuntimeException} (keeping the original as its cause), reports it to the
+ * inference listener, and marks the streaming session as completed to prevent further processing.
+ *
+ *
+ * @param throwable The error that occurred during streaming
+ */
+ @Override
+ public void onError(Throwable throwable) {
+ if (isLastPart.compareAndSet(false, true)) {
+ inferenceResponseListener.onFailure(new RuntimeException("Streaming inference failed", throwable));
+ }
+ }
+
+ /**
+ * Called when the streaming process completes successfully.
+ *
+ * This method finalizes the streaming process by creating a complete inference response
+ * from the accumulated content and delivering it to the listener. It constructs a
+ * {@link ChatCompletionResults} object containing the aggregated content from all
+ * streaming chunks.
+ *
+ */
+ @Override
+ public void onComplete() {
+ if (isLastPart.compareAndSet(false, true)) {
+ // Create the final aggregated response from accumulated content
+ String finalContent = resultBuilder.toString();
+ ChatCompletionResults.Result completionResult = new ChatCompletionResults.Result(finalContent);
+ ChatCompletionResults chatResults = new ChatCompletionResults(List.of(completionResult));
+ InferenceAction.Response response = new InferenceAction.Response(chatResults);
+
+ // Deliver the final response to the listener
+ inferenceResponseListener.onResponse(response);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java
new file mode 100644
index 0000000000000..89dc03f86b953
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+import org.elasticsearch.xpack.esql.inference.InferenceService;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
+
+/**
+ * {@link ChatCompletionOperator} is an {@link InferenceOperator} that performs inference using chat_completion inference endpoints.
+ * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output.
+ */
+public class ChatCompletionOperator extends InferenceOperator {
+
+ private final ExpressionEvaluator promptEvaluator;
+
+ public ChatCompletionOperator(
+ DriverContext driverContext,
+ BulkInferenceRunner bulkInferenceRunner,
+ String inferenceId,
+ ExpressionEvaluator promptEvaluator,
+ int maxOutstandingPages
+ ) {
+ super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
+ this.promptEvaluator = promptEvaluator;
+ }
+
+ @Override
+ protected void doClose() {
+ Releasables.close(promptEvaluator);
+ }
+
+ @Override
+ public String toString() {
+ return "ChatCompletionOperator[inference_id=[" + inferenceId() + "]]";
+ }
+
+ /**
+ * Constructs the chat completion inference requests iterator for the given input page by evaluating the prompt expression.
+ * The evaluated prompt block is handed to a {@link ChatCompletionOperatorRequestIterator}.
+ * @param inputPage The input data page.
+ */
+ @Override
+ protected BulkInferenceRequestIterator requests(Page inputPage) {
+ return new ChatCompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
+ }
+
+ /**
+ * Creates a new {@link CompletionOperatorOutputBuilder} to collect and emit the chat completion results.
+ * The output block builder is created with the input's position count.
+ * @param input The input page for which results will be constructed.
+ */
+ @Override
+ protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
+ BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount());
+ return new CompletionOperatorOutputBuilder(outputBlockBuilder, input);
+ }
+
+ /**
+ * Factory for creating {@link ChatCompletionOperator} instances.
+ */
+ public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
+ implements
+ OperatorFactory {
+ @Override
+ public String describe() {
+ return "ChatCompletionOperator[inference_id=[" + inferenceId + "]]";
+ }
+
+ @Override
+ public Operator get(DriverContext driverContext) {
+ return new ChatCompletionOperator(
+ driverContext,
+ inferenceService.bulkInferenceRunner(),
+ inferenceId,
+ promptEvaluatorFactory.get(driverContext),
+ BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests() // passed as the operator's maxOutstandingPages
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java
new file mode 100644
index 0000000000000..7d0e0e3c6edd4
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link UnifiedCompletionAction.Request}
+ * instances of type {@link TaskType#CHAT_COMPLETION} (rather than the {@link InferenceAction.Request} used for completion).
+ */
+public class ChatCompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
+
+ private final PromptReader promptReader;
+ private final String inferenceId;
+ private final int size;
+ private int currentPos = 0; // position of the next prompt to read
+
+ /**
+ * Constructs a new iterator from the given block of prompts.
+ *
+ * @param promptBlock The input block containing prompts.
+ * @param inferenceId The ID of the inference model to invoke.
+ */
+ public ChatCompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
+ this.promptReader = new PromptReader(promptBlock);
+ this.size = promptBlock.getPositionCount();
+ this.inferenceId = inferenceId;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return currentPos < size;
+ }
+
+ @Override
+ public BulkInferenceRequestItem.ChatCompletionRequestItem next() {
+ if (hasNext() == false) {
+ throw new NoSuchElementException();
+ }
+
+ UnifiedCompletionAction.Request inferenceRequest = inferenceRequest(promptReader.readPrompt(currentPos++));
+ return BulkInferenceRequestItem.from(inferenceRequest);
+ }
+
+ /**
+ * Wraps a single prompt string into a {@link UnifiedCompletionAction.Request} holding one "user" message; null prompts yield null.
+ */
+ private UnifiedCompletionAction.Request inferenceRequest(String prompt) {
+ if (prompt == null) {
+ return null; // next() still wraps this null in a request item
+ }
+
+ return new UnifiedCompletionAction.Request(
+ inferenceId,
+ TaskType.CHAT_COMPLETION,
+ UnifiedCompletionRequest.of(
+ List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(prompt), "user", null, null))
+ ),
+ TimeValue.THIRTY_SECONDS // NOTE(review): presumably the request timeout — confirm against UnifiedCompletionAction.Request
+ );
+ }
+
+ @Override
+ public int estimatedSize() {
+ return promptReader.estimatedSize();
+ }
+
+ @Override
+ public void close() {
+ Releasables.close(promptReader);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
index 65b560f3cf9ce..023c4670b6bc0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
@@ -20,7 +20,7 @@
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
/**
- * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using prompt-based model (e.g., text completion).
+ * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using completion inference endpoints.
* It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output.
*/
public class CompletionOperator extends InferenceOperator {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
index f526cd9edb077..dc9922b3ecd55 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
@@ -7,12 +7,11 @@
package org.elasticsearch.xpack.esql.inference.completion;
-import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import java.util.List;
@@ -47,12 +46,12 @@ public boolean hasNext() {
}
@Override
- public InferenceAction.Request next() {
+ public BulkInferenceRequestItem next() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
- return inferenceRequest(promptReader.readPrompt(currentPos++));
+ return BulkInferenceRequestItem.from(inferenceRequest(promptReader.readPrompt(currentPos++)));
}
/**
@@ -75,53 +74,4 @@ public int estimatedSize() {
public void close() {
Releasables.close(promptReader);
}
-
- /**
- * Helper class that reads prompts from a {@link BytesRefBlock}.
- */
- private static class PromptReader implements Releasable {
- private final BytesRefBlock promptBlock;
- private final StringBuilder strBuilder = new StringBuilder();
- private BytesRef readBuffer = new BytesRef();
-
- private PromptReader(BytesRefBlock promptBlock) {
- this.promptBlock = promptBlock;
- }
-
- /**
- * Reads the prompt string at the given position..
- *
- * @param pos the position index in the block
- */
- public String readPrompt(int pos) {
- if (promptBlock.isNull(pos)) {
- return null;
- }
-
- strBuilder.setLength(0);
-
- for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
- readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
- strBuilder.append(readBuffer.utf8ToString());
- if (valueIndex != promptBlock.getValueCount(pos) - 1) {
- strBuilder.append("\n");
- }
- }
-
- return strBuilder.toString();
- }
-
- /**
- * Returns the total number of positions (prompts) in the block.
- */
- public int estimatedSize() {
- return promptBlock.getPositionCount();
- }
-
- @Override
- public void close() {
- promptBlock.allowPassingToDifferentDriver();
- Releasables.close(promptBlock);
- }
- }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java
new file mode 100644
index 0000000000000..58c509fd6f021
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+
+public class PromptReader implements Releasable { // Helper that reads prompts from a BytesRefBlock.
+ private final BytesRefBlock promptBlock;
+ private final StringBuilder strBuilder = new StringBuilder();
+ private BytesRef readBuffer = new BytesRef(); // scratch buffer reused across getBytesRef calls
+
+ public PromptReader(BytesRefBlock promptBlock) {
+ this.promptBlock = promptBlock;
+ }
+
+ /**
+ * Reads the prompt string at the given position; multiple values are joined with a newline and a null position yields null.
+ *
+ * @param pos the position index in the block
+ */
+ public String readPrompt(int pos) {
+ if (promptBlock.isNull(pos)) {
+ return null;
+ }
+ strBuilder.setLength(0);
+ for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
+ readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
+ strBuilder.append(readBuffer.utf8ToString());
+ if (valueIndex != promptBlock.getValueCount(pos) - 1) {
+ strBuilder.append("\n");
+ }
+ }
+ return strBuilder.toString();
+ }
+
+ /**
+ * Returns the total number of positions (prompts) in the block.
+ */
+ public int estimatedSize() {
+ return promptBlock.getPositionCount();
+ }
+
+ @Override
+ public void close() {
+ promptBlock.allowPassingToDifferentDriver(); // NOTE(review): presumably close() can run outside the creating driver — confirm
+ Releasables.close(promptBlock);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
index 4b1cfe5870ad7..fe33c799aaec8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
@@ -13,6 +13,7 @@
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import java.util.ArrayList;
@@ -46,7 +47,7 @@ public boolean hasNext() {
}
@Override
- public InferenceAction.Request next() {
+ public BulkInferenceRequestItem next() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
@@ -59,7 +60,7 @@ public InferenceAction.Request next() {
if (inputBlock.isNull(startIndex)) {
remainingPositions -= 1;
- return null;
+ return BulkInferenceRequestItem.from((InferenceAction.Request) null);
}
for (int i = 0; i < maxInputSize; i++) {
@@ -73,7 +74,7 @@ public InferenceAction.Request next() {
}
remainingPositions -= inputs.size();
- return inferenceRequest(inputs);
+ return BulkInferenceRequestItem.from(inferenceRequest(inputs));
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
index 191664bea9a81..12950dcafbf6c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.plan.logical.inference;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -22,6 +23,7 @@
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -37,6 +39,8 @@ public class Completion extends InferencePlan implements TelemetryAw
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
+ public static final List SUPPORTED_TASK_TYPES = List.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
LogicalPlan.class,
"Completion",
@@ -47,11 +51,18 @@ public class Completion extends InferencePlan implements TelemetryAw
private List lazyOutput;
public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField) {
- this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField);
- }
-
- public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
- super(source, child, inferenceId);
+ this(source, p, Literal.NULL, null, prompt, targetField);
+ }
+
+ public Completion(
+ Source source,
+ LogicalPlan child,
+ Expression inferenceId,
+ TaskType taskType,
+ Expression prompt,
+ Attribute targetField
+ ) {
+ super(source, child, inferenceId, taskType);
this.prompt = prompt;
this.targetField = targetField;
}
@@ -61,6 +72,9 @@ public Completion(StreamInput in) throws IOException {
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(LogicalPlan.class),
in.readNamedWriteable(Expression.class),
+ in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)
+ ? in.readOptional(input -> TaskType.fromString(input.readString()))
+ : TaskType.COMPLETION,
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Attribute.class)
);
@@ -69,6 +83,9 @@ public Completion(StreamInput in) throws IOException {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
+ if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
+ out.writeOptional((output, taskType) -> output.writeString(taskType.toString()), taskType());
+ }
out.writeNamedWriteable(prompt);
out.writeNamedWriteable(targetField);
}
@@ -87,17 +104,17 @@ public Completion withInferenceId(Expression newInferenceId) {
return this;
}
- return new Completion(source(), child(), newInferenceId, prompt, targetField);
+ return new Completion(source(), child(), newInferenceId, taskType(), prompt, targetField);
}
@Override
- public Completion replaceChild(LogicalPlan newChild) {
- return new Completion(source(), newChild, inferenceId(), prompt, targetField);
+ public List supportedTaskTypes() {
+ return SUPPORTED_TASK_TYPES;
}
@Override
- public TaskType taskType() {
- return TaskType.COMPLETION;
+ public Completion replaceChild(LogicalPlan newChild) {
+ return new Completion(source(), newChild, inferenceId(), taskType(), prompt, targetField);
}
@Override
@@ -122,7 +139,7 @@ public List generatedAttributes() {
@Override
public Completion withGeneratedNames(List newNames) {
checkNumberOfNewNames(newNames);
- return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)));
+ return new Completion(source(), child(), inferenceId(), taskType(), prompt, this.renameTargetField(newNames.get(0)));
}
private Attribute renameTargetField(String newName) {
@@ -133,6 +150,24 @@ private Attribute renameTargetField(String newName) {
return targetField.withName(newName).withId(new NameId());
}
+ /**
+  * Runs the parent task-type validation and, when the inference id is still resolved afterwards, returns a copy of this
+  * plan that carries the task type of the resolved endpoint.
+  *
+  * @param resolvedInference the resolved inference endpoint
+  * @return the plan returned by the parent validation when its inference id is not resolved; otherwise a new
+  *         {@code Completion} using {@code resolvedInference.taskType()}
+  */
+ @Override
+ public Completion withResolvedInference(ResolvedInference resolvedInference) {
+     final Completion validated = super.withResolvedInference(resolvedInference);
+
+     // An unresolved inference id means validation reported an error; hand that plan back untouched.
+     if (validated.inferenceId().resolved() == false) {
+         return validated;
+     }
+
+     final Expression resolvedInferenceId = validated.inferenceId();
+     final TaskType resolvedTaskType = resolvedInference.taskType();
+     return new Completion(source(), child(), resolvedInferenceId, resolvedTaskType, validated.prompt(), validated.targetField());
+ }
+
@Override
protected AttributeSet computeReferences() {
return prompt.references();
@@ -152,7 +187,7 @@ public void postAnalysisVerification(Failures failures) {
@Override
protected NodeInfo extends LogicalPlan> info() {
- return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);
+ return NodeInfo.create(this, Completion::new, child(), inferenceId(), taskType(), prompt, targetField);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
index 633ed74d8addb..70769cd7e0140 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -32,10 +33,12 @@ public abstract class InferencePlan> ex
public static final List VALID_INFERENCE_OPTION_NAMES = List.of(INFERENCE_ID_OPTION_NAME);
private final Expression inferenceId;
+ private final TaskType taskType;
- protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
+ protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId, TaskType taskType) {
super(source, child);
this.inferenceId = inferenceId;
+ this.taskType = taskType;
}
@Override
@@ -60,18 +63,40 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
InferencePlan> other = (InferencePlan>) o;
- return Objects.equals(inferenceId(), other.inferenceId());
+ return Objects.equals(inferenceId(), other.inferenceId()) && taskType == other.taskType;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), inferenceId());
+ return Objects.hash(super.hashCode(), inferenceId(), taskType);
}
- public abstract TaskType taskType();
+ public TaskType taskType() {
+ return taskType;
+ }
public abstract PlanType withInferenceId(Expression newInferenceId);
+ public abstract List supportedTaskTypes();
+
+ /**
+  * Checks that the task type of the resolved endpoint is one of {@link #supportedTaskTypes()}.
+  *
+  * @param resolvedInference the resolved inference endpoint
+  * @return this plan when the task type is supported; otherwise the plan produced by
+  *         {@link #withInferenceResolutionError(String, String)} with a message naming the endpoint and the supported types
+  */
+ @SuppressWarnings("unchecked")
+ public PlanType withResolvedInference(ResolvedInference resolvedInference) {
+     final boolean supported = supportedTaskTypes().stream().anyMatch(resolvedInference.taskType()::equals);
+     if (supported) {
+         return (PlanType) this;
+     }
+
+     final String endpointId = resolvedInference.inferenceId();
+     final String error = "cannot use inference endpoint ["
+         + endpointId
+         + "] with task type ["
+         + resolvedInference.taskType()
+         + "] within a "
+         + nodeName()
+         + " command. Only inference endpoints with the task type "
+         + supportedTaskTypes()
+         + " are supported.";
+     return withInferenceResolutionError(endpointId, error);
+ }
+
public PlanType withInferenceResolutionError(String inferenceId, String error) {
return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
index 6f86138397fa6..25894c75f2289 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
@@ -59,7 +59,7 @@ public Rerank(
List rerankFields,
Attribute scoreAttribute
) {
- super(source, child, inferenceId);
+ super(source, child, inferenceId, TaskType.RERANK);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = scoreAttribute;
@@ -96,11 +96,6 @@ public Attribute scoreAttribute() {
return scoreAttribute;
}
- @Override
- public TaskType taskType() {
- return TaskType.RERANK;
- }
-
@Override
public Rerank withInferenceId(Expression newInferenceId) {
if (inferenceId().equals(newInferenceId)) {
@@ -109,6 +104,11 @@ public Rerank withInferenceId(Expression newInferenceId) {
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
}
+ /**
+  * Returns the task types accepted by this command: a single-element list containing {@code TaskType.RERANK}.
+  */
+ @Override
+ public List supportedTaskTypes() {
+ return List.of(TaskType.RERANK);
+ }
+
public Rerank withRerankFields(List newRerankFields) {
if (rerankFields.equals(newRerankFields)) {
return this;
@@ -163,7 +163,7 @@ public static AttributeSet computeReferences(List fields) {
}
public boolean isValidRerankField(Alias rerankField) {
- // Only supportinng the following datatypes for now: text, numeric and boolean
+ // Only supporting the following datatypes for now: text, numeric and boolean
return DataType.isString(rerankField.dataType())
|| rerankField.dataType() == DataType.BOOLEAN
|| rerankField.dataType().isNumeric();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
index 80887ad08fe69..fc8d09b83566e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
@@ -7,9 +7,11 @@
package org.elasticsearch.xpack.esql.plan.physical.inference;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -37,8 +39,15 @@ public class CompletionExec extends InferenceExec {
private final Attribute targetField;
private List lazyOutput;
- public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
- super(source, child, inferenceId);
+ public CompletionExec(
+ Source source,
+ PhysicalPlan child,
+ Expression inferenceId,
+ TaskType taskType,
+ Expression prompt,
+ Attribute targetField
+ ) {
+ super(source, child, inferenceId, taskType);
this.prompt = prompt;
this.targetField = targetField;
}
@@ -48,6 +57,9 @@ public CompletionExec(StreamInput in) throws IOException {
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(PhysicalPlan.class),
in.readNamedWriteable(Expression.class),
+ in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)
+ ? TaskType.fromString(in.readString())
+ : TaskType.COMPLETION,
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Attribute.class)
);
@@ -61,6 +73,9 @@ public String getWriteableName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
+ if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
+ out.writeString(taskType().toString());
+ }
out.writeNamedWriteable(prompt);
out.writeNamedWriteable(targetField);
}
@@ -75,12 +90,12 @@ public Attribute targetField() {
@Override
protected NodeInfo extends PhysicalPlan> info() {
- return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField);
+ return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), taskType(), prompt, targetField);
}
@Override
public UnaryExec replaceChild(PhysicalPlan newChild) {
- return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField);
+ return new CompletionExec(source(), newChild, inferenceId(), taskType(), prompt, targetField);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
index d60a5ecccc384..f973fcd58e346 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.plan.physical.inference;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -18,16 +19,22 @@
public abstract class InferenceExec extends UnaryExec {
private final Expression inferenceId;
+ private final TaskType taskType;
- protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
+ protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId, TaskType taskType) {
super(source, child);
- this.inferenceId = inferenceId;
+ this.inferenceId = Objects.requireNonNull(inferenceId);
+ this.taskType = Objects.requireNonNull(taskType);
}
public Expression inferenceId() {
return inferenceId;
}
+ public TaskType taskType() {
+ return taskType;
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
@@ -41,11 +48,11 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
InferenceExec that = (InferenceExec) o;
- return inferenceId.equals(that.inferenceId);
+ return inferenceId.equals(that.inferenceId) && taskType == that.taskType;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), inferenceId());
+ return Objects.hash(super.hashCode(), inferenceId(), taskType);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
index ad852d0ac20db..28cb18b17ecdc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
@@ -10,6 +10,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
@@ -48,7 +49,7 @@ public RerankExec(
List rerankFields,
Attribute scoreAttribute
) {
- super(source, child, inferenceId);
+ super(source, child, inferenceId, TaskType.RERANK);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = scoreAttribute;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index 878d223535df5..29fee28801aa8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -89,6 +89,7 @@
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
+import org.elasticsearch.xpack.esql.inference.completion.ChatCompletionOperator;
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator;
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -323,7 +324,12 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti
source.layout
);
- return source.with(new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory), outputLayout);
+ OperatorFactory operatorFactory = switch (completion.taskType()) {
+ case CHAT_COMPLETION -> new ChatCompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory);
+ default -> new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory);
+ };
+
+ return source.with(operatorFactory, outputLayout);
}
private PhysicalOperation planFuseScoreEvalExec(FuseScoreEvalExec fuse, LocalExecutionPlannerContext context) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index aabb18326fe11..05cee4b6a4adf 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -103,7 +103,14 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
}
if (p instanceof Completion completion) {
- return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
+ return new CompletionExec(
+ completion.source(),
+ child,
+ completion.inferenceId(),
+ completion.taskType(),
+ completion.prompt(),
+ completion.targetField()
+ );
}
if (p instanceof Enrich enrich) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 417eb0f1a7834..cf89671a2543e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -4026,7 +4026,7 @@ public void testResolveCompletionInferenceIdInvalidTaskType() {
"mapping-books.json",
new QueryParams(),
"cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
- + " Only inference endpoints with the task type [completion] are supported"
+ + " Only inference endpoints with the task type [completion, chat_completion] are supported"
);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
index dedbf895860b9..02eec173eb3cd 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
@@ -193,7 +193,10 @@ private BulkInferenceRunnerConfig randomBulkExecutionConfig() {
}
private BulkInferenceRequestIterator requestIterator(List requests) {
- final Iterator delegate = requests.iterator();
+ final Iterator> delegate = requests.stream()
+ .map(BulkInferenceRequestItem::from)
+ .toList()
+ .iterator();
BulkInferenceRequestIterator iterator = mock(BulkInferenceRequestIterator.class);
doAnswer(i -> delegate.hasNext()).when(iterator).hasNext();
doAnswer(i -> delegate.next()).when(iterator).next();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
index 86592256d26bc..63b9a7677911c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
@@ -32,7 +32,7 @@ private void assertIterate(int size) throws Exception {
BytesRef scratch = new BytesRef();
for (int currentPos = 0; requestIterator.hasNext(); currentPos++) {
- InferenceAction.Request request = requestIterator.next();
+ InferenceAction.Request request = requestIterator.next().inferenceRequest();
assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch);
assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString()));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
index 72397efcf1be3..010a1ed35816a 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
@@ -37,7 +37,7 @@ private void assertIterate(int size, int batchSize) throws Exception {
BytesRef scratch = new BytesRef();
for (int currentPos = 0; requestIterator.hasNext();) {
- InferenceAction.Request request = requestIterator.next();
+ InferenceAction.Request request = requestIterator.next().inferenceRequest();
assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
assertThat(request.getQuery(), equalTo(queryText));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index 1a181fe805e81..8a912f3d7c11c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -5488,6 +5488,7 @@ record PushdownShadowingGeneratingPlanTestCase(
EMPTY,
plan,
randomLiteral(TEXT),
+ randomFrom(Completion.SUPPORTED_TASK_TYPES),
new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)),
new ReferenceAttribute(EMPTY, "y", KEYWORD)
),
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
index 11c64a82e3f57..988718ff365c2 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
@@ -303,6 +303,7 @@ public void testPushDownFilterPastCompletion() {
EMPTY,
new Filter(EMPTY, relation, new And(EMPTY, conditionA, conditionB)),
completion.inferenceId(),
+ completion.taskType(),
completion.prompt(),
completion.targetField()
),
@@ -350,6 +351,7 @@ private static Completion completion(LogicalPlan child) {
EMPTY,
child,
randomLiteral(DataType.KEYWORD),
+ randomFrom(Completion.SUPPORTED_TASK_TYPES),
randomLiteral(randomBoolean() ? DataType.TEXT : DataType.KEYWORD),
referenceAttribute(randomIdentifier(), DataType.KEYWORD)
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
index b1626e4b77ce8..973b770bedac3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
@@ -77,7 +77,14 @@ public void checkOptimizedPlan(LogicalPlan basePlan, LogicalPlan optimizedPlan)
),
new PushDownLimitTestCase<>(
Completion.class,
- (plan, attr) -> new Completion(EMPTY, plan, randomLiteral(KEYWORD), randomLiteral(KEYWORD), attr),
+ (plan, attr) -> new Completion(
+ EMPTY,
+ plan,
+ randomLiteral(KEYWORD),
+ randomFrom(Completion.SUPPORTED_TASK_TYPES),
+ randomLiteral(KEYWORD),
+ attr
+ ),
(basePlan, optimizedPlan) -> {
assertEquals(basePlan.source(), optimizedPlan.source());
assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
index e9810454224aa..732335f958fc4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.plan.logical.inference;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
@@ -21,7 +22,14 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
@Override
protected Completion createTestInstance() {
- return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
+ return new Completion(
+ randomSource(),
+ randomChild(0),
+ randomInferenceId(),
+ randomTaskTypeOrNull(),
+ randomPrompt(),
+ randomAttribute()
+ );
}
@Override
@@ -30,14 +38,16 @@ protected Completion mutateInstance(Completion instance) throws IOException {
Expression inferenceId = instance.inferenceId();
Expression prompt = instance.prompt();
Attribute targetField = instance.targetField();
+ TaskType taskType = instance.taskType();
- switch (between(0, 3)) {
+ switch (between(0, 4)) {
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
- case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
- case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
+ case 2 -> taskType = randomValueOtherThan(taskType, this::randomTaskTypeOrNull);
+ case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
+ case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
}
- return new Completion(instance.source(), child, inferenceId, prompt, targetField);
+ return new Completion(instance.source(), child, inferenceId, taskType, prompt, targetField);
}
private Literal randomInferenceId() {
@@ -51,4 +61,12 @@ private Expression randomPrompt() {
private Attribute randomAttribute() {
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
}
+
+ private TaskType randomTaskType() {
+ return randomFrom(Completion.SUPPORTED_TASK_TYPES);
+ }
+
+ private TaskType randomTaskTypeOrNull() {
+ return randomBoolean() ? randomTaskType() : null;
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
index 9fd41a2432462..486ae323c982c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
@@ -7,11 +7,13 @@
package org.elasticsearch.xpack.esql.plan.physical.inference;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -20,7 +22,14 @@
public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests {
@Override
protected CompletionExec createTestInstance() {
- return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
+ return new CompletionExec(
+ randomSource(),
+ randomChild(0),
+ randomInferenceId(),
+ randomFrom(Completion.SUPPORTED_TASK_TYPES),
+ randomPrompt(),
+ randomAttribute()
+ );
}
@Override
@@ -29,14 +38,16 @@ protected CompletionExec mutateInstance(CompletionExec instance) throws IOExcept
Expression inferenceId = instance.inferenceId();
Expression prompt = instance.prompt();
Attribute targetField = instance.targetField();
+ TaskType taskType = instance.taskType();
- switch (between(0, 3)) {
+ switch (between(0, 4)) {
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
- case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
- case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
+ case 2 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(Completion.SUPPORTED_TASK_TYPES));
+ case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
+ case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
}
- return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField);
+ return new CompletionExec(instance.source(), child, inferenceId, taskType, prompt, targetField);
}
private Literal randomInferenceId() {