diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 10114e429cc7a..09759cadaea70 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -324,6 +324,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00); public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00); public static final TransportVersion INDEX_SOURCE = def(9_158_0_00); + public static final TransportVersion ESQL_CHAT_COMPLETION_SUPPORT = def(9_159_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 1f51055094c92..c8d2de3d1ad7c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -623,7 +623,7 @@ private LogicalPlan resolveCompletion(Completion p, List childrenOutp prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); } - return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField); + return new Completion(p.source(), p.child(), p.inferenceId(), p.taskType(), prompt, targetField); } private LogicalPlan resolveMvExpand(MvExpand p, List childrenOutput) { @@ -1349,20 +1349,7 @@ private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext return plan.withInferenceResolutionError(inferenceId, error); } - if (resolvedInference.taskType() != plan.taskType()) { - String error = "cannot use inference endpoint [" - + inferenceId - + "] with task type [" - + resolvedInference.taskType() - + "] within a " - + plan.nodeName() - + " 
command. Only inference endpoints with the task type [" - + plan.taskType() - + "] are supported."; - return plan.withInferenceResolutionError(inferenceId, error); - } - - return plan; + return plan.withResolvedInference(resolvedInference); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java new file mode 100644 index 0000000000000..20b1bc24143da --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; + +import java.util.Objects; + +public sealed interface BulkInferenceRequestItem permits + BulkInferenceRequestItem.AbstractBulkInferenceRequestItem { + + TaskType taskType(); + + T inferenceRequest(); + + BulkInferenceRequestItem withSeqNo(long seqNo); + + Long seqNo(); + + static InferenceRequestItem from(InferenceAction.Request request) { + return new InferenceRequestItem(request); + } + + static ChatCompletionRequestItem from(UnifiedCompletionAction.Request request) { + return new ChatCompletionRequestItem(request); + } + + abstract sealed class AbstractBulkInferenceRequestItem implements BulkInferenceRequestItem + permits InferenceRequestItem, ChatCompletionRequestItem { + private final T request; + private final 
Long seqNo; + + protected AbstractBulkInferenceRequestItem(T request) { + this(request, null); + } + + protected AbstractBulkInferenceRequestItem(T request, Long seqNo) { + this.request = request; + this.seqNo = seqNo; + } + + @Override + public T inferenceRequest() { + return request; + } + + @Override + public Long seqNo() { + return seqNo; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + AbstractBulkInferenceRequestItem that = (AbstractBulkInferenceRequestItem) o; + return Objects.equals(request, that.request) && Objects.equals(seqNo, that.seqNo); + } + + @Override + public int hashCode() { + return Objects.hash(request, seqNo); + } + + @Override + public TaskType taskType() { + return request.getTaskType(); + } + } + + final class InferenceRequestItem extends AbstractBulkInferenceRequestItem { + private InferenceRequestItem(InferenceAction.Request request) { + this(request, null); + } + + private InferenceRequestItem(InferenceAction.Request request, Long seqNo) { + super(request, seqNo); + } + + @Override + public InferenceRequestItem withSeqNo(long seqNo) { + return new InferenceRequestItem(inferenceRequest(), seqNo); + } + } + + final class ChatCompletionRequestItem extends AbstractBulkInferenceRequestItem { + + private ChatCompletionRequestItem(UnifiedCompletionAction.Request request) { + this(request, null); + } + + private ChatCompletionRequestItem(UnifiedCompletionAction.Request request, Long seqNo) { + super(request, seqNo); + } + + @Override + public TaskType taskType() { + return TaskType.CHAT_COMPLETION; + } + + @Override + public ChatCompletionRequestItem withSeqNo(long seqNo) { + return new ChatCompletionRequestItem(inferenceRequest(), seqNo); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java index 7327b182d0b6c..03739af92f7ff 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java @@ -8,17 +8,15 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.core.Releasable; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.Iterator; -public interface BulkInferenceRequestIterator extends Iterator, Releasable { +public interface BulkInferenceRequestIterator extends Iterator>, Releasable { /** * Returns an estimate of the number of requests that will be produced. * - *

This is typically used to pre-allocate buffers or output to th appropriate size.

+ *

This is typically used to pre-allocate buffers or output to the appropriate size.

*/ int estimatedSize(); - } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java index 203a3031bcad4..4d35411e3d3bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java @@ -10,8 +10,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import java.util.ArrayList; import java.util.List; @@ -175,12 +177,12 @@ private class BulkInferenceRequest { * to the request iterator. *

* - * @return A BulkRequestItem if a request and permit are available, null otherwise + * @return A BulkInferenceRequestItem if a request and permit are available, null otherwise */ - private BulkRequestItem pollPendingRequest() { + private BulkInferenceRequestItem pollPendingRequest() { synchronized (requests) { if (requests.hasNext()) { - return new BulkRequestItem(executionState.generateSeqNo(), requests.next()); + return requests.next().withSeqNo(executionState.generateSeqNo()); } } @@ -226,7 +228,7 @@ private void executePendingRequests(int recursionDepth) { } return; } else { - BulkRequestItem bulkRequestItem = pollPendingRequest(); + BulkInferenceRequestItem bulkRequestItem = pollPendingRequest(); if (bulkRequestItem == null) { // No more requests available @@ -234,14 +236,14 @@ private void executePendingRequests(int recursionDepth) { permits.release(); // Check if another bulk request is pending for execution. - BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll(); - while (nexBulkRequest == this) { - nexBulkRequest = pendingBulkRequests.poll(); + while (nextBulkRequest == this) { + nextBulkRequest = pendingBulkRequests.poll(); } - if (nexBulkRequest != null) { - executor.execute(nexBulkRequest::executePendingRequests); + if (nextBulkRequest != null) { + executor.execute(nextBulkRequest::executePendingRequests); } return; @@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) { // Response has already been sent // No need to continue processing this bulk. // Check if another bulk request is pending for execution. 
- BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); - if (nexBulkRequest != null) { - executor.execute(nexBulkRequest::executePendingRequests); + BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll(); + if (nextBulkRequest != null) { + executor.execute(nextBulkRequest::executePendingRequests); } return; } @@ -298,19 +300,26 @@ private void executePendingRequests(int recursionDepth) { ); // Handle null requests (edge case in some iterators) - if (bulkRequestItem.request() == null) { + if (bulkRequestItem.inferenceRequest() == null) { inferenceResponseListener.onResponse(null); return; } // Execute the inference request with proper origin context - executeAsyncWithOrigin( - client, - INFERENCE_ORIGIN, - InferenceAction.INSTANCE, - bulkRequestItem.request(), - inferenceResponseListener - ); + if (bulkRequestItem.taskType() == TaskType.CHAT_COMPLETION) { + handleStreamingRequest( + (UnifiedCompletionAction.Request) bulkRequestItem.inferenceRequest(), + inferenceResponseListener + ); + } else { + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + InferenceAction.INSTANCE, + bulkRequestItem.inferenceRequest(), + inferenceResponseListener + ); + } } } } catch (Exception e) { @@ -318,6 +327,30 @@ private void executePendingRequests(int recursionDepth) { } } + /** + * Handles streaming inference requests for chat completion tasks. + *

+ * This method executes UnifiedCompletionAction requests and sets up proper streaming + * response handling through the BulkInferenceStreamingHandler. The streaming handler + * manages the asynchronous stream processing and ensures responses are properly + * delivered to the completion listener. + *

+ * + * @param request The UnifiedCompletionAction request to execute + * @param listener The listener to receive the final aggregated response + */ + private void handleStreamingRequest(UnifiedCompletionAction.Request request, ActionListener listener) { + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + UnifiedCompletionAction.INSTANCE, + request, + listener.delegateFailureAndWrap((l, inferenceResponse) -> { + inferenceResponse.publisher().subscribe(new BulkInferenceStreamingHandler(l)); + }) + ); + } + /** * Processes and delivers buffered responses in order, ensuring proper sequencing. *

@@ -360,20 +393,6 @@ private void onBulkCompletion() { } } - /** - * Encapsulates an inference request with its associated sequence number. - *

- * The sequence number is used for ordering responses and tracking completion - * in the bulk execution state. - *

- * - * @param seqNo Unique sequence number for this request in the bulk operation - * @param request The actual inference request to execute - */ - private record BulkRequestItem(long seqNo, InferenceAction.Request request) { - - } - public static Factory factory(Client client) { return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java new file mode 100644 index 0000000000000..ab693a0787a52 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Handles streaming inference responses for chat completion requests in bulk inference operations. + *

+ * This class implements the Reactive Streams {@link Flow.Subscriber} interface to process + * streaming inference results from chat completion services. It accumulates content from + * streaming chunks and delivers the final aggregated response to the completion listener. + *

+ * Ultimately, it constructs a {@link ChatCompletionResults} object containing the + * complete content from all streaming chunks received during the inference operation. + * By doing so, the result of chat_completion requests is made available in the same format as the legacy completion + * and can be consumed in the same way. + *

+ */ +class BulkInferenceStreamingHandler implements Flow.Subscriber { + + /** + * Flag to track whether this streaming session has completed to prevent duplicate processing. + */ + private final AtomicBoolean isLastPart = new AtomicBoolean(false); + + /** + * The subscription handle for controlling the flow of streaming data. + */ + private Flow.Subscription subscription; + + /** + * Buffer for accumulating content from streaming chunks into the final response. + */ + private final StringBuilder resultBuilder = new StringBuilder(); + + /** + * Listener to receive the final aggregated inference response. + */ + private final ActionListener inferenceResponseListener; + + /** + * Creates a new streaming handler for processing inference chat_completion responses. + * + * @param inferenceResponseListener The listener that will receive the final aggregated response + * once all streaming chunks have been processed + */ + BulkInferenceStreamingHandler(ActionListener inferenceResponseListener) { + this.inferenceResponseListener = inferenceResponseListener; + } + + /** + * Called when the streaming publisher is ready to start sending data. + *

+ * This method establishes the subscription and requests the first chunk of data. + * If the streaming session has already completed, it cancels the subscription + * to prevent resource leaks. + *

+ * + * @param subscription The subscription handle for controlling data flow + */ + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (isLastPart.get() == false) { + this.subscription = subscription; + subscription.request(1); + } else { + subscription.cancel(); + } + } + + /** + * Processes each streaming chunk as it arrives from the inference service. + *

+ * This method extracts content from streaming chat completion chunks and accumulates + * it in the result builder. It handles the specific structure of streaming unified + * chat completion results, extracting text content from delta objects within choices. + *

+ *

+ * After processing each chunk, it requests the next chunk from the subscription + * to continue the streaming process. + *

+ * + * @param item The streaming result item containing chunk data from the inference service + */ + @Override + public void onNext(InferenceServiceResults.Result item) { + if (isLastPart.get() == false) { + if (item instanceof StreamingUnifiedChatCompletionResults.Results streamingChunkResults) { + for (var chunk : streamingChunkResults.chunks()) { + for (var choice : chunk.choices()) { + if (choice.delta() != null && choice.delta().content() != null) { + resultBuilder.append(choice.delta().content()); + } + } + } + subscription.request(1); + } else { + // Handle unexpected result types by requesting the next item + subscription.request(1); + } + } + } + + /** + * Called when an error occurs during streaming processing. + *

+ * This method ensures that errors are properly propagated to the inference listener + * and that the streaming session is marked as completed to prevent further processing. + *

+ * + * @param throwable The error that occurred during streaming + */ + @Override + public void onError(Throwable throwable) { + if (isLastPart.compareAndSet(false, true)) { + inferenceResponseListener.onFailure(new RuntimeException("Streaming inference failed", throwable)); + } + } + + /** + * Called when the streaming process completes successfully. + *

+ * This method finalizes the streaming process by creating a complete inference response + * from the accumulated content and delivering it to the listener. It constructs a + * {@link ChatCompletionResults} object containing the aggregated content from all + * streaming chunks. + *

+ */ + @Override + public void onComplete() { + if (isLastPart.compareAndSet(false, true)) { + // Create the final aggregated response from accumulated content + String finalContent = resultBuilder.toString(); + ChatCompletionResults.Result completionResult = new ChatCompletionResults.Result(finalContent); + ChatCompletionResults chatResults = new ChatCompletionResults(List.of(completionResult)); + InferenceAction.Response response = new InferenceAction.Response(chatResults); + + // Deliver the final response to the listener + inferenceResponseListener.onResponse(response); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java new file mode 100644 index 0000000000000..89dc03f86b953 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceService; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; + +/** + * {@link ChatCompletionOperator} is an {@link InferenceOperator} that performs inference using chat_completion inference endpoints. + * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output. + */ +public class ChatCompletionOperator extends InferenceOperator { + + private final ExpressionEvaluator promptEvaluator; + + public ChatCompletionOperator( + DriverContext driverContext, + BulkInferenceRunner bulkInferenceRunner, + String inferenceId, + ExpressionEvaluator promptEvaluator, + int maxOutstandingPages + ) { + super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages); + this.promptEvaluator = promptEvaluator; + } + + @Override + protected void doClose() { + Releasables.close(promptEvaluator); + } + + @Override + public String toString() { + return "ChatCompletionOperator[inference_id=[" + inferenceId() + "]]"; + } + + /** + * Constructs the chat completion inference requests iterator for the given input page by evaluating the prompt expression. + * + * @param inputPage The input data page. 
+ */ + @Override + protected BulkInferenceRequestIterator requests(Page inputPage) { + return new ChatCompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId()); + } + + /** + * Creates a new {@link CompletionOperatorOutputBuilder} to collect and emit the chat completion results. + * + * @param input The input page for which results will be constructed. + */ + @Override + protected CompletionOperatorOutputBuilder outputBuilder(Page input) { + BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount()); + return new CompletionOperatorOutputBuilder(outputBlockBuilder, input); + } + + /** + * Factory for creating {@link ChatCompletionOperator} instances. + */ + public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory) + implements + OperatorFactory { + @Override + public String describe() { + return "ChatCompletionOperator[inference_id=[" + inferenceId + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new ChatCompletionOperator( + driverContext, + inferenceService.bulkInferenceRunner(), + inferenceId, + promptEvaluatorFactory.get(driverContext), + BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests() + ); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java new file mode 100644 index 0000000000000..7d0e0e3c6edd4 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.List; +import java.util.NoSuchElementException; + +/** + * This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances + * of type {@link TaskType#CHAT_COMPLETION}. + */ +public class ChatCompletionOperatorRequestIterator implements BulkInferenceRequestIterator { + + private final PromptReader promptReader; + private final String inferenceId; + private final int size; + private int currentPos = 0; + + /** + * Constructs a new iterator from the given block of prompts. + * + * @param promptBlock The input block containing prompts. + * @param inferenceId The ID of the inference model to invoke. 
+ */ + public ChatCompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) { + this.promptReader = new PromptReader(promptBlock); + this.size = promptBlock.getPositionCount(); + this.inferenceId = inferenceId; + } + + @Override + public boolean hasNext() { + return currentPos < size; + } + + @Override + public BulkInferenceRequestItem.ChatCompletionRequestItem next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + UnifiedCompletionAction.Request inferenceRequest = inferenceRequest(promptReader.readPrompt(currentPos++)); + return BulkInferenceRequestItem.from(inferenceRequest); + } + + /** + * Wraps a single prompt string into an {@link UnifiedCompletionRequest}. + */ + private UnifiedCompletionAction.Request inferenceRequest(String prompt) { + if (prompt == null) { + return null; + } + + return new UnifiedCompletionAction.Request( + inferenceId, + TaskType.CHAT_COMPLETION, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(prompt), "user", null, null)) + ), + TimeValue.THIRTY_SECONDS + ); + } + + @Override + public int estimatedSize() { + return promptReader.estimatedSize(); + } + + @Override + public void close() { + Releasables.close(promptReader); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java index 65b560f3cf9ce..023c4670b6bc0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; /** - * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using 
prompt-based model (e.g., text completion). + * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using completion inference endpoints. * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output. */ public class CompletionOperator extends InferenceOperator { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java index f526cd9edb077..dc9922b3ecd55 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java @@ -7,12 +7,11 @@ package org.elasticsearch.xpack.esql.inference.completion; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import java.util.List; @@ -47,12 +46,12 @@ public boolean hasNext() { } @Override - public InferenceAction.Request next() { + public BulkInferenceRequestItem next() { if (hasNext() == false) { throw new NoSuchElementException(); } - return inferenceRequest(promptReader.readPrompt(currentPos++)); + return BulkInferenceRequestItem.from(inferenceRequest(promptReader.readPrompt(currentPos++))); } /** @@ -75,53 +74,4 @@ public int estimatedSize() { public void close() { Releasables.close(promptReader); } - - /** - * Helper class that reads prompts 
from a {@link BytesRefBlock}. - */ - private static class PromptReader implements Releasable { - private final BytesRefBlock promptBlock; - private final StringBuilder strBuilder = new StringBuilder(); - private BytesRef readBuffer = new BytesRef(); - - private PromptReader(BytesRefBlock promptBlock) { - this.promptBlock = promptBlock; - } - - /** - * Reads the prompt string at the given position.. - * - * @param pos the position index in the block - */ - public String readPrompt(int pos) { - if (promptBlock.isNull(pos)) { - return null; - } - - strBuilder.setLength(0); - - for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) { - readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); - strBuilder.append(readBuffer.utf8ToString()); - if (valueIndex != promptBlock.getValueCount(pos) - 1) { - strBuilder.append("\n"); - } - } - - return strBuilder.toString(); - } - - /** - * Returns the total number of positions (prompts) in the block. - */ - public int estimatedSize() { - return promptBlock.getPositionCount(); - } - - @Override - public void close() { - promptBlock.allowPassingToDifferentDriver(); - Releasables.close(promptBlock); - } - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java new file mode 100644 index 0000000000000..58c509fd6f021 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +public class PromptReader implements Releasable { + private final BytesRefBlock promptBlock; + private final StringBuilder strBuilder = new StringBuilder(); + private BytesRef readBuffer = new BytesRef(); + + public PromptReader(BytesRefBlock promptBlock) { + this.promptBlock = promptBlock; + } + + /** + * Reads the prompt string at the given position. + * + * @param pos the position index in the block + */ + public String readPrompt(int pos) { + if (promptBlock.isNull(pos)) { + return null; + } + strBuilder.setLength(0); + for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) { + readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); + strBuilder.append(readBuffer.utf8ToString()); + if (valueIndex != promptBlock.getValueCount(pos) - 1) { + strBuilder.append("\n"); + } + } + return strBuilder.toString(); + } + + /** + * Returns the total number of positions (prompts) in the block. 
+ */ + public int estimatedSize() { + return promptBlock.getPositionCount(); + } + + @Override + public void close() { + promptBlock.allowPassingToDifferentDriver(); + Releasables.close(promptBlock); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java index 4b1cfe5870ad7..fe33c799aaec8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java @@ -13,6 +13,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import java.util.ArrayList; @@ -46,7 +47,7 @@ public boolean hasNext() { } @Override - public InferenceAction.Request next() { + public BulkInferenceRequestItem next() { if (hasNext() == false) { throw new NoSuchElementException(); } @@ -59,7 +60,7 @@ public InferenceAction.Request next() { if (inputBlock.isNull(startIndex)) { remainingPositions -= 1; - return null; + return BulkInferenceRequestItem.from((InferenceAction.Request) null); } for (int i = 0; i < maxInputSize; i++) { @@ -73,7 +74,7 @@ public InferenceAction.Request next() { } remainingPositions -= inputs.size(); - return inferenceRequest(inputs); + return BulkInferenceRequestItem.from(inferenceRequest(inputs)); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java index 
191664bea9a81..12950dcafbf6c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.plan.logical.inference; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.inference.ResolvedInference; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -37,6 +39,8 @@ public class Completion extends InferencePlan implements TelemetryAw public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion"; + public static final List SUPPORTED_TASK_TYPES = List.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( LogicalPlan.class, "Completion", @@ -47,11 +51,18 @@ public class Completion extends InferencePlan implements TelemetryAw private List lazyOutput; public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField) { - this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField); - } - - public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) { - super(source, child, inferenceId); + this(source, p, Literal.NULL, null, prompt, targetField); + } + + public Completion( + Source source, + LogicalPlan child, + Expression inferenceId, + TaskType taskType, + Expression prompt, + Attribute targetField + 
) { + super(source, child, inferenceId, taskType); this.prompt = prompt; this.targetField = targetField; } @@ -61,6 +72,9 @@ public Completion(StreamInput in) throws IOException { Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(LogicalPlan.class), in.readNamedWriteable(Expression.class), + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT) + ? in.readOptional(input -> TaskType.fromString(input.readString())) + : TaskType.COMPLETION, in.readNamedWriteable(Expression.class), in.readNamedWriteable(Attribute.class) ); @@ -69,6 +83,9 @@ public Completion(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) { + out.writeOptional((output, taskType) -> output.writeString(taskType.toString()), taskType()); + } out.writeNamedWriteable(prompt); out.writeNamedWriteable(targetField); } @@ -87,17 +104,17 @@ public Completion withInferenceId(Expression newInferenceId) { return this; } - return new Completion(source(), child(), newInferenceId, prompt, targetField); + return new Completion(source(), child(), newInferenceId, taskType(), prompt, targetField); } @Override - public Completion replaceChild(LogicalPlan newChild) { - return new Completion(source(), newChild, inferenceId(), prompt, targetField); + public List supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; } @Override - public TaskType taskType() { - return TaskType.COMPLETION; + public Completion replaceChild(LogicalPlan newChild) { + return new Completion(source(), newChild, inferenceId(), taskType(), prompt, targetField); } @Override @@ -122,7 +139,7 @@ public List generatedAttributes() { @Override public Completion withGeneratedNames(List newNames) { checkNumberOfNewNames(newNames); - return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0))); + return new 
Completion(source(), child(), inferenceId(), taskType(), prompt, this.renameTargetField(newNames.get(0))); } private Attribute renameTargetField(String newName) { @@ -133,6 +150,24 @@ private Attribute renameTargetField(String newName) { return targetField.withName(newName).withId(new NameId()); } + @Override + public Completion withResolvedInference(ResolvedInference resolvedInference) { + Completion completion = super.withResolvedInference(resolvedInference); + + if (completion.inferenceId().resolved()) { + return new Completion( + source(), + child(), + completion.inferenceId(), + resolvedInference.taskType(), + completion.prompt(), + completion.targetField() + ); + } + + return completion; + } + @Override protected AttributeSet computeReferences() { return prompt.references(); @@ -152,7 +187,7 @@ public void postAnalysisVerification(Failures failures) { @Override protected NodeInfo info() { - return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField); + return NodeInfo.create(this, Completion::new, child(), inferenceId(), taskType(), prompt, targetField); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java index 633ed74d8addb..70769cd7e0140 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.inference.ResolvedInference; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn; 
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -32,10 +33,12 @@ public abstract class InferencePlan> ex public static final List VALID_INFERENCE_OPTION_NAMES = List.of(INFERENCE_ID_OPTION_NAME); private final Expression inferenceId; + private final TaskType taskType; - protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) { + protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId, TaskType taskType) { super(source, child); this.inferenceId = inferenceId; + this.taskType = taskType; } @Override @@ -60,18 +63,40 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; InferencePlan other = (InferencePlan) o; - return Objects.equals(inferenceId(), other.inferenceId()); + return Objects.equals(inferenceId(), other.inferenceId()) && taskType == other.taskType; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), inferenceId()); + return Objects.hash(super.hashCode(), inferenceId(), taskType); } - public abstract TaskType taskType(); + public TaskType taskType() { + return taskType; + } public abstract PlanType withInferenceId(Expression newInferenceId); + public abstract List supportedTaskTypes(); + + @SuppressWarnings("unchecked") + public PlanType withResolvedInference(ResolvedInference resolvedInference) { + if (supportedTaskTypes().stream().noneMatch(resolvedInference.taskType()::equals)) { + String error = "cannot use inference endpoint [" + + resolvedInference.inferenceId() + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + nodeName() + + " command. 
Only inference endpoints with the task type " + + supportedTaskTypes() + + " are supported."; + return withInferenceResolutionError(resolvedInference.inferenceId(), error); + } + + return (PlanType) this; + } + public PlanType withInferenceResolutionError(String inferenceId, String error) { return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java index 6f86138397fa6..25894c75f2289 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java @@ -59,7 +59,7 @@ public Rerank( List rerankFields, Attribute scoreAttribute ) { - super(source, child, inferenceId); + super(source, child, inferenceId, TaskType.RERANK); this.queryText = queryText; this.rerankFields = rerankFields; this.scoreAttribute = scoreAttribute; @@ -96,11 +96,6 @@ public Attribute scoreAttribute() { return scoreAttribute; } - @Override - public TaskType taskType() { - return TaskType.RERANK; - } - @Override public Rerank withInferenceId(Expression newInferenceId) { if (inferenceId().equals(newInferenceId)) { @@ -109,6 +104,11 @@ public Rerank withInferenceId(Expression newInferenceId) { return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute); } + @Override + public List supportedTaskTypes() { + return List.of(TaskType.RERANK); + } + public Rerank withRerankFields(List newRerankFields) { if (rerankFields.equals(newRerankFields)) { return this; @@ -163,7 +163,7 @@ public static AttributeSet computeReferences(List fields) { } public boolean isValidRerankField(Alias rerankField) { - // Only supportinng the following datatypes for now: text, numeric and boolean + // Only 
supporting the following datatypes for now: text, numeric and boolean return DataType.isString(rerankField.dataType()) || rerankField.dataType() == DataType.BOOLEAN || rerankField.dataType().isNumeric(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java index 80887ad08fe69..fc8d09b83566e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.esql.plan.physical.inference; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -37,8 +39,15 @@ public class CompletionExec extends InferenceExec { private final Attribute targetField; private List lazyOutput; - public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) { - super(source, child, inferenceId); + public CompletionExec( + Source source, + PhysicalPlan child, + Expression inferenceId, + TaskType taskType, + Expression prompt, + Attribute targetField + ) { + super(source, child, inferenceId, taskType); this.prompt = prompt; this.targetField = targetField; } @@ -48,6 +57,9 @@ public CompletionExec(StreamInput in) throws IOException { Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(Expression.class), 
+ in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT) + ? TaskType.fromString(in.readString()) + : TaskType.COMPLETION, in.readNamedWriteable(Expression.class), in.readNamedWriteable(Attribute.class) ); @@ -61,6 +73,9 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) { + out.writeString(taskType().toString()); + } out.writeNamedWriteable(prompt); out.writeNamedWriteable(targetField); } @@ -75,12 +90,12 @@ public Attribute targetField() { @Override protected NodeInfo info() { - return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField); + return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), taskType(), prompt, targetField); } @Override public UnaryExec replaceChild(PhysicalPlan newChild) { - return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField); + return new CompletionExec(source(), newChild, inferenceId(), taskType(), prompt, targetField); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java index d60a5ecccc384..f973fcd58e346 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.plan.physical.inference; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; 
@@ -18,16 +19,22 @@ public abstract class InferenceExec extends UnaryExec { private final Expression inferenceId; + private final TaskType taskType; - protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) { + protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId, TaskType taskType) { super(source, child); - this.inferenceId = inferenceId; + this.inferenceId = Objects.requireNonNull(inferenceId); + this.taskType = Objects.requireNonNull(taskType); } public Expression inferenceId() { return inferenceId; } + public TaskType taskType() { + return taskType; + } + @Override public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); @@ -41,11 +48,11 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; InferenceExec that = (InferenceExec) o; - return inferenceId.equals(that.inferenceId); + return inferenceId.equals(that.inferenceId) && taskType == that.taskType; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), inferenceId()); + return Objects.hash(super.hashCode(), inferenceId(), taskType); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java index ad852d0ac20db..28cb18b17ecdc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Alias; import 
org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; @@ -48,7 +49,7 @@ public RerankExec( List rerankFields, Attribute scoreAttribute ) { - super(source, child, inferenceId); + super(source, child, inferenceId, TaskType.RERANK); this.queryText = queryText; this.rerankFields = rerankFields; this.scoreAttribute = scoreAttribute; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 878d223535df5..29fee28801aa8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -89,6 +89,7 @@ import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; +import org.elasticsearch.xpack.esql.inference.completion.ChatCompletionOperator; import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -323,7 +324,12 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti source.layout ); - return source.with(new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory), outputLayout); + OperatorFactory operatorFactory = switch (completion.taskType()) { + case CHAT_COMPLETION -> new ChatCompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory); + default -> new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory); + }; + + return source.with(operatorFactory, outputLayout); } private PhysicalOperation 
planFuseScoreEvalExec(FuseScoreEvalExec fuse, LocalExecutionPlannerContext context) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index aabb18326fe11..05cee4b6a4adf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -103,7 +103,14 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) { } if (p instanceof Completion completion) { - return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField()); + return new CompletionExec( + completion.source(), + child, + completion.inferenceId(), + completion.taskType(), + completion.prompt(), + completion.targetField() + ); } if (p instanceof Enrich enrich) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 417eb0f1a7834..cf89671a2543e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -4026,7 +4026,7 @@ public void testResolveCompletionInferenceIdInvalidTaskType() { "mapping-books.json", new QueryParams(), "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command." 
- + " Only inference endpoints with the task type [completion] are supported" + + " Only inference endpoints with the task type [completion, chat_completion] are supported" ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java index dedbf895860b9..02eec173eb3cd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java @@ -193,7 +193,10 @@ private BulkInferenceRunnerConfig randomBulkExecutionConfig() { } private BulkInferenceRequestIterator requestIterator(List requests) { - final Iterator delegate = requests.iterator(); + final Iterator> delegate = requests.stream() + .map(BulkInferenceRequestItem::from) + .toList() + .iterator(); BulkInferenceRequestIterator iterator = mock(BulkInferenceRequestIterator.class); doAnswer(i -> delegate.hasNext()).when(iterator).hasNext(); doAnswer(i -> delegate.next()).when(iterator).next(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java index 86592256d26bc..63b9a7677911c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java @@ -32,7 +32,7 @@ private void assertIterate(int size) throws Exception { BytesRef scratch = new BytesRef(); for (int currentPos = 0; requestIterator.hasNext(); currentPos++) { - InferenceAction.Request request = 
requestIterator.next(); + InferenceAction.Request request = requestIterator.next().inferenceRequest(); assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString())); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java index 72397efcf1be3..010a1ed35816a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java @@ -37,7 +37,7 @@ private void assertIterate(int size, int batchSize) throws Exception { BytesRef scratch = new BytesRef(); for (int currentPos = 0; requestIterator.hasNext();) { - InferenceAction.Request request = requestIterator.next(); + InferenceAction.Request request = requestIterator.next().inferenceRequest(); assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); assertThat(request.getQuery(), equalTo(queryText)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 1a181fe805e81..8a912f3d7c11c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -5488,6 +5488,7 @@ record PushdownShadowingGeneratingPlanTestCase( EMPTY, plan, randomLiteral(TEXT), + randomFrom(Completion.SUPPORTED_TASK_TYPES), new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)), 
new ReferenceAttribute(EMPTY, "y", KEYWORD) ), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java index 11c64a82e3f57..988718ff365c2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java @@ -303,6 +303,7 @@ public void testPushDownFilterPastCompletion() { EMPTY, new Filter(EMPTY, relation, new And(EMPTY, conditionA, conditionB)), completion.inferenceId(), + completion.taskType(), completion.prompt(), completion.targetField() ), @@ -350,6 +351,7 @@ private static Completion completion(LogicalPlan child) { EMPTY, child, randomLiteral(DataType.KEYWORD), + randomFrom(Completion.SUPPORTED_TASK_TYPES), randomLiteral(randomBoolean() ? 
DataType.TEXT : DataType.KEYWORD), referenceAttribute(randomIdentifier(), DataType.KEYWORD) ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java index b1626e4b77ce8..973b770bedac3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java @@ -77,7 +77,14 @@ public void checkOptimizedPlan(LogicalPlan basePlan, LogicalPlan optimizedPlan) ), new PushDownLimitTestCase<>( Completion.class, - (plan, attr) -> new Completion(EMPTY, plan, randomLiteral(KEYWORD), randomLiteral(KEYWORD), attr), + (plan, attr) -> new Completion( + EMPTY, + plan, + randomLiteral(KEYWORD), + randomFrom(Completion.SUPPORTED_TASK_TYPES), + randomLiteral(KEYWORD), + attr + ), (basePlan, optimizedPlan) -> { assertEquals(basePlan.source(), optimizedPlan.source()); assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java index e9810454224aa..732335f958fc4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.plan.logical.inference; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Attribute; import 
org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -21,7 +22,14 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati @Override protected Completion createTestInstance() { - return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute()); + return new Completion( + randomSource(), + randomChild(0), + randomInferenceId(), + randomTaskTypeOrNull(), + randomPrompt(), + randomAttribute() + ); } @Override @@ -30,14 +38,16 @@ protected Completion mutateInstance(Completion instance) throws IOException { Expression inferenceId = instance.inferenceId(); Expression prompt = instance.prompt(); Attribute targetField = instance.targetField(); + TaskType taskType = instance.taskType(); - switch (between(0, 3)) { + switch (between(0, 4)) { case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); - case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt); - case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute); + case 2 -> taskType = randomValueOtherThan(taskType, this::randomTaskTypeOrNull); + case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt); + case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute); } - return new Completion(instance.source(), child, inferenceId, prompt, targetField); + return new Completion(instance.source(), child, inferenceId, taskType, prompt, targetField); } private Literal randomInferenceId() { @@ -51,4 +61,12 @@ private Expression randomPrompt() { private Attribute randomAttribute() { return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); } + + private TaskType randomTaskType() { + return randomFrom(Completion.SUPPORTED_TASK_TYPES); + } + + private TaskType randomTaskTypeOrNull() { + return randomBoolean() ? 
randomTaskType() : null; + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java index 9fd41a2432462..486ae323c982c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java @@ -7,11 +7,13 @@ package org.elasticsearch.xpack.esql.plan.physical.inference; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -20,7 +22,14 @@ public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests { @Override protected CompletionExec createTestInstance() { - return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute()); + return new CompletionExec( + randomSource(), + randomChild(0), + randomInferenceId(), + randomFrom(Completion.SUPPORTED_TASK_TYPES), + randomPrompt(), + randomAttribute() + ); } @Override @@ -29,14 +38,16 @@ protected CompletionExec mutateInstance(CompletionExec instance) throws IOExcept Expression inferenceId = instance.inferenceId(); Expression prompt = instance.prompt(); Attribute targetField = instance.targetField(); + TaskType 
taskType = instance.taskType(); - switch (between(0, 3)) { + switch (between(0, 4)) { case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); - case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt); - case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute); + case 2 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(Completion.SUPPORTED_TASK_TYPES)); + case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt); + case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute); } - return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField); + return new CompletionExec(instance.source(), child, inferenceId, taskType, prompt, targetField); } private Literal randomInferenceId() {