Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.util.List;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record TraceThreadBatchUpdate(
@Schema(description = "List of thread model IDs to update", requiredMode = Schema.RequiredMode.REQUIRED) @NotNull(message = "Thread model IDs are required") @Size(min = 1, max = 1000, message = "Thread model IDs must contain between 1 and 1000 items") List<@NotNull UUID> threadModelIds,

@Schema(description = "Update to apply to all threads", requiredMode = Schema.RequiredMode.REQUIRED) @NotNull(message = "Update is required") @Valid TraceThreadUpdate update) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.comet.opik.api.TraceSearchStreamRequest;
import com.comet.opik.api.TraceThread;
import com.comet.opik.api.TraceThreadBatchIdentifier;
import com.comet.opik.api.TraceThreadBatchUpdate;
import com.comet.opik.api.TraceThreadIdentifier;
import com.comet.opik.api.TraceThreadSearchStreamRequest;
import com.comet.opik.api.TraceThreadUpdate;
Expand Down Expand Up @@ -778,6 +779,30 @@ public Response updateThread(@PathParam("threadModelId") UUID threadModelId,
return Response.noContent().build();
}

@PATCH
@Path("/threads/batch")
@Operation(operationId = "batchUpdateThreads", summary = "Batch update threads", description = "Batch update threads", responses = {
@ApiResponse(responseCode = "204", description = "No Content")})
@RateLimited
public Response batchUpdateThreads(
@RequestBody(content = @Content(schema = @Schema(implementation = TraceThreadBatchUpdate.class))) @NotNull @Valid TraceThreadBatchUpdate batchUpdate) {

String workspaceId = requestContext.get().getWorkspaceId();
String userName = requestContext.get().getUserName();

log.info("Batch updating threads on workspace_id: '{}', user: '{}', thread_count: '{}'",
workspaceId, userName, batchUpdate.threadModelIds().size());

traceThreadService.batchUpdate(batchUpdate.threadModelIds(), batchUpdate.update())
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();

log.info("Successfully batch updated threads on workspace_id: '{}', thread_count: '{}'",
workspaceId, batchUpdate.threadModelIds().size());

return Response.noContent().build();
}

@PUT
@Path("/threads/feedback-scores")
@Operation(operationId = "scoreBatchOfThreads", summary = "Batch feedback scoring for threads", description = "Batch feedback scoring for threads", responses = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Flux<ProjectWithPendingClosureTraceThreads> findProjectsWithPendingClosureThread

Mono<Void> updateThread(UUID threadModelId, UUID projectId, TraceThreadUpdate threadUpdate);

Mono<Void> batchUpdateThreads(List<UUID> threadModelIds, TraceThreadUpdate threadUpdate);

Mono<Long> setScoredAt(UUID projectId, List<String> threadIds, Instant scoredAt);

Flux<List<TraceThreadModel>> streamPendingClosureThreads(UUID projectId, Instant lastUpdatedAt);
Expand Down Expand Up @@ -119,6 +121,29 @@ INSERT INTO trace_threads (
;
""";

private static final String BATCH_UPDATE_THREADS_SQL = """
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this query and all the related methods in this DAO, the service etc. There's a minimal difference, which is allowing multiple ids instead of one.

Let's just update the existing query and code to allow that:

  1. Update the original query where clause to id IN :ids.
  2. Update the methods to accept a Set of UUIDs instead of single one.
    2.1 You can keep the old methods, but overload them to just pass a set with a single UUID.

INSERT INTO trace_threads (
workspace_id, project_id, thread_id, id, status, tags, created_by, last_updated_by, created_at, last_updated_at, sampling_per_rule, scored_at
)
SELECT
workspace_id,
project_id,
thread_id,
id,
status,
<if(tags)> :tags <else> tags <endif> as tags,
created_by,
:user_name as last_updated_by,
created_at,
now64(6),
sampling_per_rule,
scored_at
FROM trace_threads final
WHERE workspace_id = :workspace_id
AND id IN :ids
;
""";

private static final String FIND_THREADS_BY_PROJECT_SQL = """
SELECT *
FROM trace_threads
Expand Down Expand Up @@ -510,6 +535,44 @@ public Mono<Void> updateThread(@NonNull UUID threadModelId, @NonNull UUID projec
.ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new)));

return makeMonoContextAware(bindUserNameAndWorkspaceContext(statement))
.flatMap(result -> Mono.from(result.getRowsUpdated()))
.then();
});
}

@Override
public Mono<Void> batchUpdateThreads(@NonNull List<UUID> threadModelIds,
@NonNull TraceThreadUpdate threadUpdate) {
if (CollectionUtils.isEmpty(threadModelIds)) {
log.warn("batchUpdateThreads called with empty threadModelIds list");
return Mono.empty();
}

log.info("Starting batch update for threads, thread_count: '{}'", threadModelIds.size());

return asyncTemplate.nonTransaction(connection -> {

var template = new ST(BATCH_UPDATE_THREADS_SQL);

Optional.ofNullable(threadUpdate.tags())
.ifPresent(tags -> template.add("tags", tags.toString()));

String renderedSql = template.render();
log.debug("Rendered SQL: '{}'", renderedSql);

var statement = connection.createStatement(renderedSql)
.bind("ids", threadModelIds.toArray(UUID[]::new));

Optional.ofNullable(threadUpdate.tags())
.ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new)));

return makeMonoContextAware(bindUserNameAndWorkspaceContext(statement))
.flatMap(result -> Mono.from(result.getRowsUpdated()))
.doOnSuccess(
rowsUpdated -> log.info("Batch update executed, rows_updated: '{}', thread_count: '{}'",
rowsUpdated, threadModelIds.size()))
.doOnError(ex -> log.error("Error executing batch update, thread_count: '{}'",
threadModelIds.size(), ex))
.then();
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Mono<Void> processProjectWithTraceThreadsPendingClosure(UUID projectId, Instant

Mono<Void> update(UUID threadModelId, TraceThreadUpdate threadUpdate);

Mono<Void> batchUpdate(List<UUID> threadModelIds, TraceThreadUpdate threadUpdate);

Mono<Void> setScoredAt(UUID projectId, List<String> threadIds, Instant scoredAt);
}

Expand Down Expand Up @@ -152,6 +154,22 @@ public Mono<Void> update(@NonNull UUID threadModelId, @NonNull TraceThreadUpdate
traceThreadIdModel.projectId(), threadUpdate));
}

@Override
public Mono<Void> batchUpdate(@NonNull List<UUID> threadModelIds, @NonNull TraceThreadUpdate threadUpdate) {
if (CollectionUtils.isEmpty(threadModelIds)) {
log.info("No thread model IDs provided for batch update. Skipping update.");
return Mono.empty();
}

log.info("Batch updating threads, thread_count: '{}'", threadModelIds.size());

return traceThreadDAO.batchUpdateThreads(threadModelIds, threadUpdate)
.doOnSuccess(__ -> log.info("Successfully batch updated threads, thread_count: '{}'",
threadModelIds.size()))
.doOnError(ex -> log.error("Error batch updating threads, thread_count: '{}'",
threadModelIds.size(), ex));
}

@Override
public Mono<Void> setScoredAt(@NonNull UUID projectId, @NonNull List<String> threadIds, @NonNull Instant scoredAt) {
if (threadIds.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.comet.opik.api.TraceSearchStreamRequest;
import com.comet.opik.api.TraceThread;
import com.comet.opik.api.TraceThreadBatchIdentifier;
import com.comet.opik.api.TraceThreadBatchUpdate;
import com.comet.opik.api.TraceThreadIdentifier;
import com.comet.opik.api.TraceThreadSearchStreamRequest;
import com.comet.opik.api.TraceThreadUpdate;
Expand Down Expand Up @@ -651,6 +652,22 @@ public void updateThread(TraceThreadUpdate threadUpdate, UUID threadModelId, Str
}
}

public void batchUpdateThreads(TraceThreadBatchUpdate batchUpdate, String apiKey, String workspaceName,
int expectedStatus) {

try (var response = client.target(RESOURCE_PATH.formatted(baseURI))
.path("threads")
.path("batch")
.request()
.accept(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.method(HttpMethod.PATCH, Entity.json(batchUpdate))) {

assertThat(response.getStatus()).isEqualTo(expectedStatus);
}
}

public Response callBatchCreateTracesWithCookie(List<Trace> traces, String sessionToken, String workspaceName) {
return client.target(RESOURCE_PATH.formatted(baseURI))
.path("batch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.comet.opik.api.TraceSearchStreamRequest;
import com.comet.opik.api.TraceThread;
import com.comet.opik.api.TraceThread.TraceThreadPage;
import com.comet.opik.api.TraceThreadBatchUpdate;
import com.comet.opik.api.TraceThreadIdentifier;
import com.comet.opik.api.TraceThreadStatus;
import com.comet.opik.api.TraceThreadUpdate;
Expand Down Expand Up @@ -10923,6 +10924,84 @@ void createAndUpdateThread() {
.build()), List.of(actualThread));

}

@Test
void batchUpdateThreads() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create a new separated subclass for these new tests. It's a different endpoint, so we usually follow that convention.

// Create multiple threads
var thread1 = createThread();
var thread2 = createThread();
var thread3 = createThread();
Comment on lines +10931 to +10933
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker: I'd find or create a flavour of these createThread methods to do it in batches. Test would be faster and more maintainable.


// Check that they don't have tags
assertThat(thread1.tags()).isNull();
assertThat(thread2.tags()).isNull();
assertThat(thread3.tags()).isNull();
Comment on lines +10935 to +10938
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: better keep a list of threads and iterate on all of them to assert this.


// Batch update threads
var tags = Set.of("tag1", "tag2", "tag3");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use podam factory to just generate a random TraceThreadUpdate.

var update = TraceThreadUpdate.builder().tags(tags).build();
var batchUpdate = TraceThreadBatchUpdate.builder()
.threadModelIds(List.of(thread1.threadModelId(), thread2.threadModelId(), thread3.threadModelId()))
.update(update)
.build();

traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 204);

// Check that batch update was applied to all threads
var actualThread1 = traceResourceClient.getTraceThread(thread1.id(), thread1.projectId(), API_KEY,
TEST_WORKSPACE);
var actualThread2 = traceResourceClient.getTraceThread(thread2.id(), thread2.projectId(), API_KEY,
TEST_WORKSPACE);
var actualThread3 = traceResourceClient.getTraceThread(thread3.id(), thread3.projectId(), API_KEY,
TEST_WORKSPACE);
Comment on lines +10951 to +10956
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, please search all threads in one call if possible.


TraceAssertions.assertThreads(List.of(thread1.toBuilder().tags(tags).build()), List.of(actualThread1));
TraceAssertions.assertThreads(List.of(thread2.toBuilder().tags(tags).build()), List.of(actualThread2));
TraceAssertions.assertThreads(List.of(thread3.toBuilder().tags(tags).build()), List.of(actualThread3));
Comment on lines +10958 to +10960
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use one assertion, but passing the list with all threads instead of one by one.

}

@Test
void batchUpdateThreadsWithEmptyListShouldSucceed() {
var update = TraceThreadUpdate.builder().tags(Set.of("tag1")).build();
var batchUpdate = TraceThreadBatchUpdate.builder()
.threadModelIds(List.of())
.update(update)
.build();

traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 422);
}

@Test
void batchUpdateThreadsWithExistingTags() {
Comment on lines +10974 to +10975
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a parameterised tests of the first one, basically:

  1. Initial value would be null or random.
  2. It's set to some other expected random value.

// Create threads with existing tags
var thread1 = createThread();
var thread2 = createThread();

// Add initial tags
var initialTags = Set.of("existing1", "existing2");
var initialUpdate = TraceThreadUpdate.builder().tags(initialTags).build();
traceResourceClient.updateThread(initialUpdate, thread1.threadModelId(), API_KEY, TEST_WORKSPACE, 204);
traceResourceClient.updateThread(initialUpdate, thread2.threadModelId(), API_KEY, TEST_WORKSPACE, 204);

// Batch update with new tags
var newTags = Set.of("new1", "new2");
var update = TraceThreadUpdate.builder().tags(newTags).build();
var batchUpdate = TraceThreadBatchUpdate.builder()
.threadModelIds(List.of(thread1.threadModelId(), thread2.threadModelId()))
.update(update)
.build();

traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE, 204);

// Check that tags were replaced (not appended)
var actualThread1 = traceResourceClient.getTraceThread(thread1.id(), thread1.projectId(), API_KEY,
TEST_WORKSPACE);
var actualThread2 = traceResourceClient.getTraceThread(thread2.id(), thread2.projectId(), API_KEY,
TEST_WORKSPACE);

assertThat(actualThread1.tags()).isEqualTo(newTags);
assertThat(actualThread2.tags()).isEqualTo(newTags);
}
}

private TraceThread createThread() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { AxiosError } from "axios";
import get from "lodash/get";

import api, { TRACES_REST_ENDPOINT } from "@/api/api";
import { useToast } from "@/components/ui/use-toast";

type UseThreadBatchUpdateMutationParams = {
threadModelIds: string[];
tags: string[];
};

const useThreadBatchUpdateMutation = () => {
const queryClient = useQueryClient();
const { toast } = useToast();

return useMutation({
mutationFn: async ({
threadModelIds,
tags,
}: UseThreadBatchUpdateMutationParams) => {
const { data } = await api.patch(`${TRACES_REST_ENDPOINT}threads/batch`, {
thread_model_ids: threadModelIds,
update: {
tags,
},
});
return data;
},
onError: (error: AxiosError) => {
const message = get(
error,
["response", "data", "message"],
error.message,
);

toast({
title: "Error",
description: `Failed to update threads: ${message}`,
variant: "destructive",
});
},
onSuccess: () => {
toast({
title: "Threads updated",
description: "Tags have been successfully updated",
});
Comment on lines +44 to +47
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate success toast: The mutation's onSuccess displays a generic toast, but AddTagDialog already shows a more specific toast (Tag "${newTag}" added to ${rows.length} selected threads). Remove this duplicate toast from the mutation hook to avoid showing two toasts for the same action.

Copilot generated this review using guidance from repository custom instructions.

return queryClient.invalidateQueries({
queryKey: [TRACES_REST_ENDPOINT, "threads"],
});
},
});
};

export default useThreadBatchUpdateMutation;
Loading