Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/134673.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 134673
summary: Gracefully shut down model deployment when node is removed from assignment
routing
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
routingStateListener
);

stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
}

private ActionListener<Void> updateRoutingStateToStoppedListener(
Expand All @@ -573,11 +573,18 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
// This model is not routed to the current node at all
TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
if (task == null) {
logger.debug(
() -> format(
"[%s] Unable to stop unreferenced deployment for node %s because task does not exist",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be "task does not exist"?

deploymentId,
currentNode
)
);
return;
}

logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
stopDeploymentAsync(
stopDeploymentAfterCompletingPendingWorkAsync(
task,
NODE_NO_LONGER_REFERENCED,
ActionListener.wrap(
Expand Down Expand Up @@ -614,8 +621,12 @@ private void stopDeploymentHelper(
});
}

private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
/**
 * Stops the given deployment by delegating to {@code stopDeploymentHelper}, using
 * {@code deploymentManager::stopAfterCompletingPendingWork} as the stop action.
 *
 * @param task     the deployment task to stop
 * @param reason   the stop reason, forwarded unchanged to {@code stopDeploymentHelper}
 * @param listener forwarded unchanged to {@code stopDeploymentHelper}
 */
private void stopDeploymentAfterCompletingPendingWorkAsync(
TrainedModelDeploymentTask task,
String reason,
ActionListener<Void> listener
) {
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
}

private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ public void testClusterChangedWithResetMode() throws InterruptedException {
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork()
throws InterruptedException {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -430,17 +429,19 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
fail("Failed waiting for the stop process call to complete");
}

verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
assertBusy(() -> {
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
});
verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
);
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
String node2 = "test-node-2";
final DiscoveryNodes nodes = DiscoveryNodes.builder()
Expand Down Expand Up @@ -488,7 +489,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -529,7 +530,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -571,7 +572,46 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
/**
 * Checks that a deployment prepared on this node is stopped through
 * {@code deploymentManager.stopAfterCompletingPendingWork} when the cluster state holds an assignment
 * for that deployment to which this test adds no routing entries, and that no routing state update
 * is sent to {@code trainedModelAssignmentService}.
 */
public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
String deploymentOne = "deployment-1";

var taskParams = newParams(deploymentOne, modelOne);

// The assignment below is created via Builder.empty(...) and no routing entry is added to it in this test,
// so NODE_ID is presumably absent from its routing table (confirm against Builder.empty).
ClusterChangedEvent event = new ClusterChangedEvent(
"testClusterChanged",
ClusterState.builder(new ClusterName("testClusterChanged"))
.nodes(nodes)
.metadata(
Metadata.builder()
.putCustom(
TrainedModelAssignmentMetadata.NAME,
TrainedModelAssignmentMetadata.Builder.empty()
.addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null))
.build()
)
.build()
)
.build(),
ClusterState.EMPTY_STATE
);

// Prepare the deployment on the service first, then deliver the cluster state change under test.
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

// Wrapped in assertBusy, presumably because the stop call is issued asynchronously (confirm).
assertBusy(() -> { verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, are the curly brackets needed around the lambda here?

// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
// entry for stopping
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
);
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -603,7 +643,6 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded
ClusterState.EMPTY_STATE
);

// trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);
loadQueuedModels(trainedModelAssignmentNodeService);

Expand Down Expand Up @@ -724,7 +763,9 @@ public void testClusterChanged() throws Exception {

assertBusy(() -> {
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture());
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
// gracefully stop it
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
});
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
Expand Down