diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 5dc597d37c..9eab72ff57 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; +import static org.opensearch.ml.common.CommonValue.VERSION_3_3_0; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; @@ -51,12 +52,14 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { public static final String MEMORY_TYPE_FIELD = "type"; public static final String MEMORY_SESSION_ID_FIELD = "session_id"; public static final String MEMORY_WINDOW_SIZE_FIELD = "window_size"; + public static final String TYPE_FIELD = "type"; public static final String APP_TYPE_FIELD = "app_type"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @Getter private String agentId; private String name; + private String type; private String description; private String llmModelId; private Map llmParameters; @@ -73,6 +76,7 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { public MLAgentUpdateInput( String agentId, String name, + String type, String description, String llmModelId, Map llmParameters, @@ -87,6 +91,7 @@ public MLAgentUpdateInput( ) { this.agentId = agentId; this.name = name; + this.type = type; this.description = description; this.llmModelId = llmModelId; this.llmParameters = llmParameters; @@ -105,6 +110,7 @@ public MLAgentUpdateInput(StreamInput in) throws IOException { Version streamInputVersion = in.getVersion(); agentId = in.readString(); name = in.readOptionalString(); + type = streamInputVersion.onOrAfter(VERSION_3_3_0) ? in.readOptionalString() : null; description = in.readOptionalString(); llmModelId = in.readOptionalString(); if (in.readBoolean()) { @@ -135,6 +141,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (name != null) { builder.field(AGENT_NAME_FIELD, name); } + if (type != null) { + builder.field(TYPE_FIELD, type); + } if (description != null) { builder.field(DESCRIPTION_FIELD, description); } @@ -185,6 +194,9 @@ public void writeTo(StreamOutput out) throws IOException { Version streamOutputVersion = out.getVersion(); out.writeString(agentId); out.writeOptionalString(name); + if (streamOutputVersion.onOrAfter(VERSION_3_3_0)) { + out.writeOptionalString(type); + } out.writeOptionalString(description); out.writeOptionalString(llmModelId); if (llmParameters != null && !llmParameters.isEmpty()) { @@ -221,6 +233,7 @@ public void writeTo(StreamOutput out) throws IOException { public static MLAgentUpdateInput parse(XContentParser parser) throws IOException { String agentId = null; String name = null; + String type = null; String description = null; String llmModelId = null; Map llmParameters = null; @@ -244,6 +257,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException case AGENT_NAME_FIELD: name = parser.text(); break; + case TYPE_FIELD: + type = parser.text(); + break; case DESCRIPTION_FIELD: description = parser.text(); break; @@ -314,6 +330,7 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException return new MLAgentUpdateInput( agentId, name, + type, description, llmModelId, llmParameters, @@ -329,6 +346,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException } public MLAgent toMLAgent(MLAgent originalAgent) { + if (type != null && !type.equals(originalAgent.getType())) { + throw new IllegalArgumentException("Agent type cannot be updated"); + } LLMSpec finalLlm; if (llmModelId == null && (llmParameters == null || llmParameters.isEmpty())) { finalLlm = originalAgent.getLlm(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index d7a143fa57..efcc6602a1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -392,6 +392,7 @@ public void testParseWithAllFields() throws Exception { { "agent_id": "test-agent-id", "name": "test-agent", + "type": "flow", "description": "test description", "llm": { "model_id": "test-model-id", @@ -423,6 +424,7 @@ public void testParseWithAllFields() throws Exception { """; testParseFromJsonString(inputStr, parsedInput -> { assertEquals("test-agent", parsedInput.getName()); + assertEquals("flow", parsedInput.getType()); assertEquals("test description", parsedInput.getDescription()); assertEquals("test-model-id", parsedInput.getLlmModelId()); assertEquals(1, parsedInput.getTools().size()); @@ -959,6 +961,41 @@ public void testCombinedLLMAndMemoryPartialUpdates() { assertEquals(Integer.valueOf(10), updatedAgent.getMemory().getWindowSize()); // Updated } + @Test + public void testAgentTypeValidation() { + MLAgent originalAgent = MLAgent.builder().type(MLAgentType.FLOW.name()).name("Test Agent").build(); + + // Same type should be allowed + MLAgentUpdateInput sameTypeInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .type(MLAgentType.FLOW.name()) + .name("Updated Name") + .build(); + + MLAgent updatedAgent = sameTypeInput.toMLAgent(originalAgent); + assertEquals(MLAgentType.FLOW.name(), updatedAgent.getType()); + assertEquals("Updated Name", updatedAgent.getName()); + + // Different type should throw error + MLAgentUpdateInput differentTypeInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .type(MLAgentType.CONVERSATIONAL.name()) + .name("Updated Name") + .build(); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { differentTypeInput.toMLAgent(originalAgent); }); + assertEquals("Agent type cannot be updated", e.getMessage()); + + // No type provided should work (original type) + MLAgentUpdateInput noTypeInput = MLAgentUpdateInput.builder().agentId("test-agent-id").name("Updated Name").build(); + + MLAgent originalAgentType = noTypeInput.toMLAgent(originalAgent); + assertEquals(MLAgentType.FLOW.name(), originalAgentType.getType()); + assertEquals("Updated Name", originalAgentType.getName()); + } + @Test public void testStreamInputOutputWithVersion() throws IOException { MLAgentUpdateInput input = MLAgentUpdateInput