diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java new file mode 100644 index 00000000000..c3fb502387f --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java @@ -0,0 +1,22 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; +import lombok.Builder; + +import java.util.List; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record TraceThreadBatchUpdate( + @Schema(description = "List of thread model IDs to update", requiredMode = Schema.RequiredMode.REQUIRED) @NotNull(message = "Thread model IDs are required") @Size(min = 1, max = 1000, message = "Thread model IDs must contain between 1 and 1000 items") List<@NotNull UUID> threadModelIds, + + @Schema(description = "Update to apply to all threads", requiredMode = Schema.RequiredMode.REQUIRED) @NotNull(message = "Update is required") @Valid TraceThreadUpdate update) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index 06003bacea4..c59bf5381d9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -18,6 +18,7 @@ import com.comet.opik.api.TraceSearchStreamRequest; import com.comet.opik.api.TraceThread; import com.comet.opik.api.TraceThreadBatchIdentifier; +import com.comet.opik.api.TraceThreadBatchUpdate; import com.comet.opik.api.TraceThreadIdentifier; import com.comet.opik.api.TraceThreadSearchStreamRequest; import com.comet.opik.api.TraceThreadUpdate; @@ -778,6 +779,30 @@ public Response updateThread(@PathParam("threadModelId") UUID threadModelId, return Response.noContent().build(); } + @PATCH + @Path("/threads/batch") + @Operation(operationId = "batchUpdateThreads", summary = "Batch update threads", description = "Batch update threads", responses = { + @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited + public Response batchUpdateThreads( + @RequestBody(content = @Content(schema = @Schema(implementation = TraceThreadBatchUpdate.class))) @NotNull @Valid TraceThreadBatchUpdate batchUpdate) { + + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + log.info("Batch updating threads on workspace_id: '{}', user: '{}', thread_count: '{}'", + workspaceId, userName, batchUpdate.threadModelIds().size()); + + traceThreadService.batchUpdate(batchUpdate.threadModelIds(), batchUpdate.update()) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + + log.info("Successfully batch updated threads on workspace_id: '{}', thread_count: '{}'", + workspaceId, batchUpdate.threadModelIds().size()); + + return Response.noContent().build(); + } + @PUT @Path("/threads/feedback-scores") @Operation(operationId = "scoreBatchOfThreads", summary = "Batch feedback scoring for threads", description = "Batch feedback scoring for threads", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java index 5f69bc96c86..d8b14bc79e5 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java @@ -61,6 +61,8 @@ Flux findProjectsWithPendingClosureThread Mono updateThread(UUID threadModelId, UUID projectId, TraceThreadUpdate threadUpdate); + Mono batchUpdateThreads(List threadModelIds, TraceThreadUpdate threadUpdate); + Mono setScoredAt(UUID projectId, List threadIds, Instant scoredAt); Flux> streamPendingClosureThreads(UUID projectId, Instant lastUpdatedAt); @@ -119,6 +121,29 @@ INSERT INTO trace_threads ( ; """; + private static final String BATCH_UPDATE_THREADS_SQL = """ + INSERT INTO trace_threads ( + workspace_id, project_id, thread_id, id, status, tags, created_by, last_updated_by, created_at, last_updated_at, sampling_per_rule, scored_at + ) + SELECT + workspace_id, + project_id, + thread_id, + id, + status, + :tags tags as tags, + created_by, + :user_name as last_updated_by, + created_at, + now64(6), + sampling_per_rule, + scored_at + FROM trace_threads final + WHERE workspace_id = :workspace_id + AND id IN :ids + ; + """; + private static final String FIND_THREADS_BY_PROJECT_SQL = """ SELECT * FROM trace_threads @@ -510,6 +535,44 @@ public Mono updateThread(@NonNull UUID threadModelId, @NonNull UUID projec .ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new))); return makeMonoContextAware(bindUserNameAndWorkspaceContext(statement)) + .flatMap(result -> Mono.from(result.getRowsUpdated())) + .then(); + }); + } + + @Override + public Mono batchUpdateThreads(@NonNull List threadModelIds, + @NonNull TraceThreadUpdate threadUpdate) { + if (CollectionUtils.isEmpty(threadModelIds)) { + log.warn("batchUpdateThreads called with empty threadModelIds list"); + return Mono.empty(); + } + + log.info("Starting batch update for threads, thread_count: '{}'", threadModelIds.size()); + + return asyncTemplate.nonTransaction(connection -> { + + var template = new ST(BATCH_UPDATE_THREADS_SQL); + + Optional.ofNullable(threadUpdate.tags()) + .ifPresent(tags -> template.add("tags", tags.toString())); + + String renderedSql = template.render(); + log.debug("Rendered SQL: '{}'", renderedSql); + + var statement = connection.createStatement(renderedSql) + .bind("ids", threadModelIds.toArray(UUID[]::new)); + + Optional.ofNullable(threadUpdate.tags()) + .ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new))); + + return makeMonoContextAware(bindUserNameAndWorkspaceContext(statement)) + .flatMap(result -> Mono.from(result.getRowsUpdated())) + .doOnSuccess( + rowsUpdated -> log.info("Batch update executed, rows_updated: '{}', thread_count: '{}'", + rowsUpdated, threadModelIds.size())) + .doOnError(ex -> log.error("Error executing batch update, thread_count: '{}'", + threadModelIds.size(), ex)) .then(); }); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java index d36dd08b775..b047af84ddd 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java @@ -68,6 +68,8 @@ Mono processProjectWithTraceThreadsPendingClosure(UUID projectId, Instant Mono update(UUID threadModelId, TraceThreadUpdate threadUpdate); + Mono batchUpdate(List threadModelIds, TraceThreadUpdate threadUpdate); + Mono setScoredAt(UUID projectId, List threadIds, Instant scoredAt); } @@ -152,6 +154,22 @@ public Mono update(@NonNull UUID threadModelId, @NonNull TraceThreadUpdate traceThreadIdModel.projectId(), threadUpdate)); } + @Override + public Mono batchUpdate(@NonNull List threadModelIds, @NonNull TraceThreadUpdate threadUpdate) { + if (CollectionUtils.isEmpty(threadModelIds)) { + log.info("No thread model IDs provided for batch update. Skipping update."); + return Mono.empty(); + } + + log.info("Batch updating threads, thread_count: '{}'", threadModelIds.size()); + + return traceThreadDAO.batchUpdateThreads(threadModelIds, threadUpdate) + .doOnSuccess(__ -> log.info("Successfully batch updated threads, thread_count: '{}'", + threadModelIds.size())) + .doOnError(ex -> log.error("Error batch updating threads, thread_count: '{}'", + threadModelIds.size(), ex)); + } + @Override public Mono setScoredAt(@NonNull UUID projectId, @NonNull List threadIds, @NonNull Instant scoredAt) { if (threadIds.isEmpty()) { diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java index 5004d076a3d..fc315b834c2 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java @@ -13,6 +13,7 @@ import com.comet.opik.api.TraceSearchStreamRequest; import com.comet.opik.api.TraceThread; import com.comet.opik.api.TraceThreadBatchIdentifier; +import com.comet.opik.api.TraceThreadBatchUpdate; import com.comet.opik.api.TraceThreadIdentifier; import com.comet.opik.api.TraceThreadSearchStreamRequest; import com.comet.opik.api.TraceThreadUpdate; @@ -651,6 +652,22 @@ public void updateThread(TraceThreadUpdate threadUpdate, UUID threadModelId, Str } } + public void batchUpdateThreads(TraceThreadBatchUpdate batchUpdate, String apiKey, String workspaceName, + int expectedStatus) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .path("threads") + .path("batch") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .method(HttpMethod.PATCH, Entity.json(batchUpdate))) { + + assertThat(response.getStatus()).isEqualTo(expectedStatus); + } + } + public Response callBatchCreateTracesWithCookie(List traces, String sessionToken, String workspaceName) { return client.target(RESOURCE_PATH.formatted(baseURI)) .path("batch") diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 82dee61aec0..294c3a5af19 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -21,6 +21,7 @@ import com.comet.opik.api.TraceSearchStreamRequest; import com.comet.opik.api.TraceThread; import com.comet.opik.api.TraceThread.TraceThreadPage; +import com.comet.opik.api.TraceThreadBatchUpdate; import com.comet.opik.api.TraceThreadIdentifier; import com.comet.opik.api.TraceThreadStatus; import com.comet.opik.api.TraceThreadUpdate; @@ -10923,6 +10924,84 @@ void createAndUpdateThread() { .build()), List.of(actualThread)); } + + @Test + void batchUpdateThreads() { + // Create multiple threads + var thread1 = createThread(); + var thread2 = createThread(); + var thread3 = createThread(); + + // Check that they don't have tags + assertThat(thread1.tags()).isNull(); + assertThat(thread2.tags()).isNull(); + assertThat(thread3.tags()).isNull(); + + // Batch update threads + var tags = Set.of("tag1", "tag2", "tag3"); + var update = TraceThreadUpdate.builder().tags(tags).build(); + var batchUpdate = TraceThreadBatchUpdate.builder() + .threadModelIds(List.of(thread1.threadModelId(), thread2.threadModelId(), thread3.threadModelId())) + .update(update) + .build(); + + traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 204); + + // Check that batch update was applied to all threads + var actualThread1 = traceResourceClient.getTraceThread(thread1.id(), thread1.projectId(), API_KEY, + TEST_WORKSPACE); + var actualThread2 = traceResourceClient.getTraceThread(thread2.id(), thread2.projectId(), API_KEY, + TEST_WORKSPACE); + var actualThread3 = traceResourceClient.getTraceThread(thread3.id(), thread3.projectId(), API_KEY, + TEST_WORKSPACE); + + TraceAssertions.assertThreads(List.of(thread1.toBuilder().tags(tags).build()), List.of(actualThread1)); + TraceAssertions.assertThreads(List.of(thread2.toBuilder().tags(tags).build()), List.of(actualThread2)); + TraceAssertions.assertThreads(List.of(thread3.toBuilder().tags(tags).build()), List.of(actualThread3)); + } + + @Test + void batchUpdateThreadsWithEmptyListShouldSucceed() { + var update = TraceThreadUpdate.builder().tags(Set.of("tag1")).build(); + var batchUpdate = TraceThreadBatchUpdate.builder() + .threadModelIds(List.of()) + .update(update) + .build(); + + traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 422); + } + + @Test + void batchUpdateThreadsWithExistingTags() { + // Create threads with existing tags + var thread1 = createThread(); + var thread2 = createThread(); + + // Add initial tags + var initialTags = Set.of("existing1", "existing2"); + var initialUpdate = TraceThreadUpdate.builder().tags(initialTags).build(); + traceResourceClient.updateThread(initialUpdate, thread1.threadModelId(), API_KEY, TEST_WORKSPACE, 204); + traceResourceClient.updateThread(initialUpdate, thread2.threadModelId(), API_KEY, TEST_WORKSPACE, 204); + + // Batch update with new tags + var newTags = Set.of("new1", "new2"); + var update = TraceThreadUpdate.builder().tags(newTags).build(); + var batchUpdate = TraceThreadBatchUpdate.builder() + .threadModelIds(List.of(thread1.threadModelId(), thread2.threadModelId())) + .update(update) + .build(); + + traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 204); + + // Check that tags were replaced (not appended) + var actualThread1 = traceResourceClient.getTraceThread(thread1.id(), thread1.projectId(), API_KEY, + TEST_WORKSPACE); + var actualThread2 = traceResourceClient.getTraceThread(thread2.id(), thread2.projectId(), API_KEY, + TEST_WORKSPACE); + + assertThat(actualThread1.tags()).isEqualTo(newTags); + assertThat(actualThread2.tags()).isEqualTo(newTags); + } } private TraceThread createThread() { diff --git a/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts b/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts new file mode 100644 index 00000000000..ed24c63738a --- /dev/null +++ b/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts @@ -0,0 +1,56 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { AxiosError } from "axios"; +import get from "lodash/get"; + +import api, { TRACES_REST_ENDPOINT } from "@/api/api"; +import { useToast } from "@/components/ui/use-toast"; + +type UseThreadBatchUpdateMutationParams = { + threadModelIds: string[]; + tags: string[]; +}; + +const useThreadBatchUpdateMutation = () => { + const queryClient = useQueryClient(); + const { toast } = useToast(); + + return useMutation({ + mutationFn: async ({ + threadModelIds, + tags, + }: UseThreadBatchUpdateMutationParams) => { + const { data } = await api.patch(`${TRACES_REST_ENDPOINT}threads/batch`, { + thread_model_ids: threadModelIds, + update: { + tags, + }, + }); + return data; + }, + onError: (error: AxiosError) => { + const message = get( + error, + ["response", "data", "message"], + error.message, + ); + + toast({ + title: "Error", + description: `Failed to update threads: ${message}`, + variant: "destructive", + }); + }, + onSuccess: () => { + toast({ + title: "Threads updated", + description: "Tags have been successfully updated", + }); + + return queryClient.invalidateQueries({ + queryKey: [TRACES_REST_ENDPOINT, "threads"], + }); + }, + }); +}; + +export default useThreadBatchUpdateMutation; diff --git a/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx b/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx index e2b77e50d80..5dfc825bba8 100644 --- a/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx +++ b/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx @@ -1,5 +1,5 @@ import React, { useState } from "react"; -import { Trace, Span } from "@/types/traces"; +import { Trace, Span, Thread } from "@/types/traces"; import { TRACE_DATA_TYPE } from "@/hooks/useTracesOrSpansList"; import { Dialog, @@ -13,10 +13,11 @@ import { Input } from "@/components/ui/input"; import { useToast } from "@/components/ui/use-toast"; import useTraceUpdateMutation from "@/api/traces/useTraceUpdateMutation"; import useSpanUpdateMutation from "@/api/traces/useSpanUpdateMutation"; +import useThreadBatchUpdateMutation from "@/api/traces/useThreadBatchUpdateMutation"; import useAppStore from "@/store/AppStore"; type AddTagDialogProps = { - rows: Array; + rows: Array; open: boolean | number; setOpen: (open: boolean | number) => void; projectId: string; @@ -37,6 +38,7 @@ const AddTagDialog: React.FunctionComponent = ({ const [newTag, setNewTag] = useState(""); const traceUpdateMutation = useTraceUpdateMutation(); const spanUpdateMutation = useSpanUpdateMutation(); + const threadBatchUpdateMutation = useThreadBatchUpdateMutation(); const MAX_ENTITIES = 10; const handleClose = () => { @@ -47,6 +49,39 @@ const AddTagDialog: React.FunctionComponent = ({ const handleAddTag = () => { if (!newTag) return; + if (type === TRACE_DATA_TYPE.threads) { + const threadRows = rows as Thread[]; + const threadModelIds = threadRows.map((row) => row.id); + + const allTags = new Set(); + threadRows.forEach((row) => { + (row.tags || []).forEach((tag) => allTags.add(tag)); + }); + allTags.add(newTag); + + threadBatchUpdateMutation + .mutateAsync({ + threadModelIds, + tags: Array.from(allTags), + }) + .then(() => { + toast({ + title: "Success", + description: `Tag "${newTag}" added to ${rows.length} selected threads`, + }); + + if (onSuccess) { + onSuccess(); + } + + handleClose(); + }) + .catch(() => { + // Error handling is already done by the mutation hook + }); + return; + } + const promises: Promise[] = []; rows.forEach((row) => { @@ -114,7 +149,11 @@ const AddTagDialog: React.FunctionComponent = ({ Add tag to {rows.length}{" "} - {type === TRACE_DATA_TYPE.traces ? "traces" : "spans"} + {type === TRACE_DATA_TYPE.traces + ? "traces" + : type === TRACE_DATA_TYPE.threads + ? "threads" + : "spans"} {rows.length > MAX_ENTITIES && ( diff --git a/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx b/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx index a32c873e355..93205494c15 100644 --- a/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx +++ b/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx @@ -1,5 +1,5 @@ import React, { useCallback, useRef, useState } from "react"; -import { Trash } from "lucide-react"; +import { Tag, Trash } from "lucide-react"; import get from "lodash/get"; import first from "lodash/first"; import slugify from "slugify"; @@ -11,7 +11,9 @@ import ConfirmDialog from "@/components/shared/ConfirmDialog/ConfirmDialog"; import TooltipWrapper from "@/components/shared/TooltipWrapper/TooltipWrapper"; import ExportToButton from "@/components/shared/ExportToButton/ExportToButton"; import AddToDropdown from "@/components/pages-shared/traces/AddToDropdown/AddToDropdown"; +import AddTagDialog from "@/components/pages-shared/traces/AddTagDialog/AddTagDialog"; import { COLUMN_FEEDBACK_SCORES_ID } from "@/types/shared"; +import { TRACE_DATA_TYPE } from "@/hooks/useTracesOrSpansList"; type ThreadsActionsPanelProps = { getDataForExport: () => Promise; @@ -19,6 +21,7 @@ type ThreadsActionsPanelProps = { columnsToExport: string[]; projectName: string; projectId: string; + onClearSelection?: () => void; }; const ThreadsActionsPanel: React.FunctionComponent< @@ -29,6 +32,7 @@ const ThreadsActionsPanel: React.FunctionComponent< columnsToExport, projectName, projectId, + onClearSelection, }) => { const resetKeyRef = useRef(0); const [open, setOpen] = useState(false); @@ -89,12 +93,35 @@ const ThreadsActionsPanel: React.FunctionComponent< confirmText="Delete threads" confirmButtonVariant="destructive" /> + + + + = ({ })); }, [feedbackScoresNames]); + const onClearSelection = useCallback(() => { + setRowSelection({}); + }, []); + const scoresColumnsData = useMemo(() => { // Always include "User feedback" column, even if it has no data const userFeedbackColumn: DynamicColumn = { @@ -599,6 +603,7 @@ export const ThreadsTab: React.FC = ({ getDataForExport={getDataForExport} selectedRows={selectedRows} columnsToExport={columnsToExport} + onClearSelection={onClearSelection} /> diff --git a/apps/opik-frontend/src/hooks/useTracesOrSpansList.ts b/apps/opik-frontend/src/hooks/useTracesOrSpansList.ts index 1afa5bedbde..fc972ae1438 100644 --- a/apps/opik-frontend/src/hooks/useTracesOrSpansList.ts +++ b/apps/opik-frontend/src/hooks/useTracesOrSpansList.ts @@ -16,6 +16,7 @@ import { Sorting } from "@/types/sorting"; export enum TRACE_DATA_TYPE { traces = "traces", spans = "spans", + threads = "threads", } type UseTracesOrSpansListParams = {