diff --git a/docs/changelog/134673.yaml b/docs/changelog/134673.yaml new file mode 100644 index 0000000000000..4b9c22d958d4a --- /dev/null +++ b/docs/changelog/134673.yaml @@ -0,0 +1,6 @@ +pr: 134673 +summary: Gracefully shut down model deployment when node is removed from assignment + routing +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 83db62e2da4de..846fb7a530283 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -523,7 +523,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) { if (task == null) { logger.debug( () -> format( - "[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exit", + "[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist", deploymentId, currentNode ) @@ -547,7 +547,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) { routingStateListener ); - stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped); + stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped); } private ActionListener updateRoutingStateToStoppedListener( @@ -573,11 +573,18 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode) // This model is not routed to the current node at all TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId); if (task == null) { + logger.debug( + () -> format( + "[%s] Unable to stop unreferenced deployment for node %s because task does not 
exist", + deploymentId, + currentNode + ) + ); return; } logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode)); - stopDeploymentAsync( + stopDeploymentAfterCompletingPendingWorkAsync( task, NODE_NO_LONGER_REFERENCED, ActionListener.wrap( @@ -614,8 +621,12 @@ private void stopDeploymentHelper( }); } - private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener listener) { - stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener); + private void stopDeploymentAfterCompletingPendingWorkAsync( + TrainedModelDeploymentTask task, + String reason, + ActionListener listener + ) { + stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener); } private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index 9fbc2b43f1137..7aa4faa6459dc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -380,8 +380,7 @@ public void testClusterChangedWithResetMode() throws InterruptedException { verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() - throws InterruptedException { + public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception { final TrainedModelAssignmentNodeService 
trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -430,9 +429,11 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop fail("Failed waiting for the stop process call to complete"); } - verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture()); - assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne)); - assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne)); + assertBusy(() -> { + verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture()); + assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne)); + assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne)); + }); verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState( any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), any() @@ -440,7 +441,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() { + public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); String node2 = "test-node-2"; final DiscoveryNodes nodes = DiscoveryNodes.builder() @@ -488,7 +489,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() { + public void 
testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -529,7 +530,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() { + public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -571,7 +572,46 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException { + public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception { + final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); + final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); + String modelOne = "model-1"; + String deploymentOne = "deployment-1"; + + var taskParams = newParams(deploymentOne, modelOne); + + ClusterChangedEvent event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(new ClusterName("testClusterChanged")) + .nodes(nodes) + .metadata( + 
Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null)) + .build() + ) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + + trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); + trainedModelAssignmentNodeService.clusterChanged(event); + + assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any())); + // This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing + // entry for stopping + verify(trainedModelAssignmentService, never()).updateModelAssignmentState( + any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), + any() + ); + verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); + } + + public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -603,7 +643,6 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded ClusterState.EMPTY_STATE ); - // trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); trainedModelAssignmentNodeService.clusterChanged(event); loadQueuedModels(trainedModelAssignmentNodeService); @@ -724,7 +763,9 @@ public void testClusterChanged() throws Exception { assertBusy(() -> { ArgumentCaptor stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture()); + // deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will + // 
gracefully stop it + verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture()); assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo)); }); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);