diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index ac23731c38b84..1038308bc6bf3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -121,7 +121,8 @@ protected void doExecute(Task task, OpenPointInTimeRequest request, ActionListen .source(new SearchSourceBuilder().query(request.indexFilter())); searchRequest.setMaxConcurrentShardRequests(request.maxConcurrentShardRequests()); searchRequest.setCcsMinimizeRoundtrips(false); - transportSearchAction.executeRequest((SearchTask) task, searchRequest, listener.map(r -> { + + transportSearchAction.executeOpenPit((SearchTask) task, searchRequest, listener.map(r -> { assert r.pointInTimeId() != null : r; return new OpenPointInTimeResponse( r.pointInTimeId(), diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 154e2fcc8d539..8c5e861c199b6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -167,7 +167,7 @@ public class TransportSearchAction extends HandledTransportAction listener) { - executeRequest((SearchTask) task, searchRequest, new SearchResponseActionListener(listener), AsyncSearchActionProvider::new); + executeRequest((SearchTask) task, searchRequest, listener, AsyncSearchActionProvider::new, true); } - void executeRequest( + void executeOpenPit( SearchTask task, SearchRequest original, - ActionListener listener, + ActionListener originalListener, Function, SearchPhaseProvider> searchPhaseProvider + ) { + executeRequest(task, original, originalListener, searchPhaseProvider, false); + } + + private void executeRequest( + SearchTask task, + SearchRequest original, + ActionListener originalListener, + Function, SearchPhaseProvider> searchPhaseProvider, + boolean collectSearchTelemetry ) { final long relativeStartNanos = System.nanoTime(); final SearchTimeProvider timeProvider = new SearchTimeProvider( @@ -372,48 +382,93 @@ void executeRequest( frozenIndexCheck(resolvedIndices); } - ActionListener rewriteListener = listener.delegateFailureAndWrap((delegate, rewritten) -> { + final SearchSourceBuilder source = original.source(); + if (shouldOpenPIT(source)) { + // disabling shard reordering for request + original.setPreFilterShardSize(Integer.MAX_VALUE); + openPIT( + client, + original, + searchService.getDefaultKeepAliveInMillis(), + originalListener.delegateFailureAndWrap((delegate, resp) -> { + // We set the keep alive to -1 to indicate that we don't need the pit id in the response. + // This is needed since we delete the pit prior to sending the response so the id doesn't exist anymore. + source.pointInTimeBuilder(new PointInTimeBuilder(resp.getPointInTimeId()).setKeepAlive(TimeValue.MINUS_ONE)); + var pitListener = new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + // we need to close the PIT first so we delay the release of the response to after the closing + response.incRef(); + closePIT( + client, + original.source().pointInTimeBuilder(), + () -> ActionListener.respondAndRelease(delegate, response) + ); + } + + @Override + public void onFailure(Exception e) { + closePIT(client, original.source().pointInTimeBuilder(), () -> delegate.onFailure(e)); + } + }; + executeRequest(task, original, pitListener, searchPhaseProvider, true); + }) + ); + return; + } + + ActionListener rewriteListener = originalListener.delegateFailureAndWrap((delegate, rewritten) -> { if (ccsCheckCompatibility) { checkCCSVersionCompatibility(rewritten); } - if (resolvedIndices.getRemoteClusterIndices().isEmpty()) { - executeLocalSearch( - task, - timeProvider, - rewritten, - resolvedIndices, - projectState, - SearchResponse.Clusters.EMPTY, - searchPhaseProvider.apply(delegate) - ); - } else { - if (delegate instanceof TelemetryListener tl) { - tl.setRemotes(resolvedIndices.getRemoteClusterIndices().size()); + final ActionListener searchResponseActionListener; + if (collectSearchTelemetry) { + if (collectCCSTelemetry == false || resolvedIndices.getRemoteClusterIndices().isEmpty()) { + searchResponseActionListener = new SearchTelemetryListener(delegate, searchResponseMetrics); + } else { + CCSUsage.Builder usageBuilder = new CCSUsage.Builder(); + usageBuilder.setRemotesCount(resolvedIndices.getRemoteClusterIndices().size()); + usageBuilder.setClientFromTask(task); if (task.isAsync()) { - tl.setFeature(CCSUsageTelemetry.ASYNC_FEATURE); + usageBuilder.setFeature(CCSUsageTelemetry.ASYNC_FEATURE); } if (original.pointInTimeBuilder() != null) { - tl.setFeature(CCSUsageTelemetry.PIT_FEATURE); + usageBuilder.setFeature(CCSUsageTelemetry.PIT_FEATURE); } - tl.setClient(task); // Check if any of the index patterns are wildcard patterns var localIndices = resolvedIndices.getLocalIndices(); if (localIndices != null && Arrays.stream(localIndices.indices()).anyMatch(Regex::isSimpleMatchPattern)) { - tl.setFeature(CCSUsageTelemetry.WILDCARD_FEATURE); + usageBuilder.setFeature(CCSUsageTelemetry.WILDCARD_FEATURE); } if (resolvedIndices.getRemoteClusterIndices() .values() .stream() .anyMatch(indices -> Arrays.stream(indices.indices()).anyMatch(Regex::isSimpleMatchPattern))) { - tl.setFeature(CCSUsageTelemetry.WILDCARD_FEATURE); + usageBuilder.setFeature(CCSUsageTelemetry.WILDCARD_FEATURE); + } + if (shouldMinimizeRoundtrips(rewritten)) { + usageBuilder.setFeature(CCSUsageTelemetry.MRT_FEATURE); } + searchResponseActionListener = new SearchTelemetryListener(delegate, searchResponseMetrics, usageService, usageBuilder); } + } else { + searchResponseActionListener = delegate; + } + + if (resolvedIndices.getRemoteClusterIndices().isEmpty()) { + executeLocalSearch( + task, + timeProvider, + rewritten, + resolvedIndices, + projectState, + SearchResponse.Clusters.EMPTY, + searchPhaseProvider.apply(searchResponseActionListener) + ); + } else { final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).taskId(); if (shouldMinimizeRoundtrips(rewritten)) { - if (delegate instanceof TelemetryListener tl) { - tl.setFeature(CCSUsageTelemetry.MRT_FEATURE); - } final AggregationReduceContext.Builder aggregationReduceContextBuilder = rewritten.source() != null && rewritten.source().aggregations() != null ? searchService.aggReduceContextBuilder(task::isCancelled, rewritten.source().aggregations()) @@ -439,7 +494,7 @@ void executeRequest( aggregationReduceContextBuilder, remoteClusterService, threadPool, - delegate, + searchResponseActionListener, (r, l) -> executeLocalSearch( task, timeProvider, @@ -473,7 +528,7 @@ void executeRequest( clusters, timeProvider, transportService, - delegate.delegateFailureAndWrap((finalDelegate, searchShardsResponses) -> { + searchResponseActionListener.delegateFailureAndWrap((finalDelegate, searchShardsResponses) -> { final BiFunction clusterNodeLookup = getRemoteClusterNodeLookup( searchShardsResponses ); @@ -517,49 +572,20 @@ void executeRequest( } }); - final SearchSourceBuilder source = original.source(); final boolean isExplain = source != null && source.explain() != null && source.explain(); - if (shouldOpenPIT(source)) { - // disabling shard reordering for request - original.setPreFilterShardSize(Integer.MAX_VALUE); - openPIT(client, original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> { - // We set the keep alive to -1 to indicate that we don't need the pit id in the response. - // This is needed since we delete the pit prior to sending the response so the id doesn't exist anymore. - source.pointInTimeBuilder(new PointInTimeBuilder(resp.getPointInTimeId()).setKeepAlive(TimeValue.MINUS_ONE)); - var pitListener = new SearchResponseActionListener(delegate) { - @Override - public void onResponse(SearchResponse response) { - // we need to close the PIT first so we delay the release of the response to after the closing - response.incRef(); - closePIT( - client, - original.source().pointInTimeBuilder(), - () -> ActionListener.respondAndRelease(delegate, response) - ); - } - - @Override - public void onFailure(Exception e) { - closePIT(client, original.source().pointInTimeBuilder(), () -> delegate.onFailure(e)); - } - }; - executeRequest(task, original, pitListener, searchPhaseProvider); - })); - } else { - Rewriteable.rewriteAndFetch( - original, - searchService.getRewriteContext( - timeProvider::absoluteStartMillis, - clusterState.getMinTransportVersion(), - original.getLocalClusterAlias(), - resolvedIndices, - original.pointInTimeBuilder(), - shouldMinimizeRoundtrips(original), - isExplain - ), - rewriteListener - ); - } + Rewriteable.rewriteAndFetch( + original, + searchService.getRewriteContext( + timeProvider::absoluteStartMillis, + clusterState.getMinTransportVersion(), + original.getLocalClusterAlias(), + resolvedIndices, + original.pointInTimeBuilder(), + shouldMinimizeRoundtrips(original), + isExplain + ), + rewriteListener + ); } /** @@ -2001,49 +2027,34 @@ static String[] ignoreBlockedIndices(ProjectState projectState, String[] concret .toArray(String[]::new); } return concreteIndices; - } - - private interface TelemetryListener { - void setRemotes(int count); - - void setFeature(String feature); - void setClient(Task task); } - private class SearchResponseActionListener extends DelegatingActionListener - implements - TelemetryListener { + private static class SearchTelemetryListener extends DelegatingActionListener { private final CCSUsage.Builder usageBuilder; - - SearchResponseActionListener(ActionListener listener) { + private final SearchResponseMetrics searchResponseMetrics; + private final UsageService usageService; + private final boolean collectCCSTelemetry; + + SearchTelemetryListener( + ActionListener listener, + SearchResponseMetrics searchResponseMetrics, + UsageService usageService, + CCSUsage.Builder usageBuilder + ) { super(listener); - if (listener instanceof SearchResponseActionListener srListener) { - usageBuilder = srListener.usageBuilder; - } else { - usageBuilder = new CCSUsage.Builder(); - } - } - - /** - * Should we collect telemetry for this search? - */ - private boolean collectTelemetry() { - return collectTelemetry && usageBuilder.getRemotesCount() > 0; - } - - public void setRemotes(int count) { - usageBuilder.setRemotesCount(count); - } - - @Override - public void setFeature(String feature) { - usageBuilder.setFeature(feature); + this.searchResponseMetrics = searchResponseMetrics; + this.collectCCSTelemetry = true; + this.usageService = usageService; + this.usageBuilder = usageBuilder; } - @Override - public void setClient(Task task) { - usageBuilder.setClientFromTask(task); + SearchTelemetryListener(ActionListener listener, SearchResponseMetrics searchResponseMetrics) { + super(listener); + this.searchResponseMetrics = searchResponseMetrics; + this.collectCCSTelemetry = false; + this.usageService = null; + this.usageBuilder = null; } @Override @@ -2069,7 +2080,7 @@ public void onResponse(SearchResponse searchResponse) { } searchResponseMetrics.incrementResponseCount(responseCountTotalStatus); - if (collectTelemetry()) { + if (collectCCSTelemetry) { extractCCSTelemetry(searchResponse); recordTelemetry(); } @@ -2084,7 +2095,7 @@ public void onResponse(SearchResponse searchResponse) { @Override public void onFailure(Exception e) { searchResponseMetrics.incrementResponseCount(SearchResponseMetrics.ResponseCountTotalStatus.FAILURE); - if (collectTelemetry()) { + if (collectCCSTelemetry) { usageBuilder.setFailure(e); recordTelemetry(); } @@ -2109,8 +2120,6 @@ private void extractCCSTelemetry(SearchResponse searchResponse) { usageBuilder.perClusterUsage(clusterAlias, cluster.getTook()); } } - } - } } diff --git a/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchTookTimeTelemetryTests.java b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchTookTimeTelemetryTests.java index 2c000314c7dad..3d9797a5f5f48 100644 --- a/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchTookTimeTelemetryTests.java +++ b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchTookTimeTelemetryTests.java @@ -16,12 +16,17 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.retriever.RescorerRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.telemetry.Measurement; import org.elasticsearch.telemetry.TestTelemetryPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -85,6 +90,29 @@ public void testSimpleQuery() { assertEquals(searchResponse.getTook().millis(), measurements.getFirst().getLong()); } + public void testCompoundRetriever() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.retriever( + new RescorerRetrieverBuilder( + new StandardRetrieverBuilder(new MatchAllQueryBuilder()), + List.of(new QueryRescorerBuilder(new MatchAllQueryBuilder())) + ) + ); + SearchResponse searchResponse = client().prepareSearch(indexName).setSource(searchSourceBuilder).get(); + try { + assertNoFailures(searchResponse); + assertSearchHits(searchResponse, "1", "2"); + } finally { + searchResponse.decRef(); + } + + List measurements = getTestTelemetryPlugin().getLongHistogramMeasurement(TOOK_DURATION_TOTAL_HISTOGRAM_NAME); + // compound retriever does its own search as an async action, whose took time is recorded separately + assertEquals(2, measurements.size()); + assertThat(measurements.getFirst().getLong(), Matchers.lessThan(searchResponse.getTook().millis())); + assertEquals(searchResponse.getTook().millis(), measurements.getLast().getLong()); + } + public void testMultiSearch() { MultiSearchRequestBuilder multiSearchRequestBuilder = client().prepareMultiSearch(); int numSearchRequests = randomIntBetween(3, 10);