diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 7201e1835..36b5782a5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -41,6 +41,7 @@ import dev.langchain4j.agent.tool.ToolMemoryId; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker; @@ -82,6 +83,16 @@ public class ToolProcessor { Object.class); private static final Logger log = Logger.getLogger(ToolProcessor.class); + public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional"); + public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt"); + public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong"); + public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble"); + + private static final DotName DATE = DotName.createSimple("java.util.Date"); + private static final DotName LOCAL_DATE = DotName.createSimple("java.time.LocalDate"); + private static final DotName LOCAL_DATE_TIME = DotName.createSimple("java.time.LocalDateTime"); + private static final DotName OFFSET_DATE_TIME = DotName.createSimple("java.time.OffsetDateTime"); + @BuildStep public void telemetry(Capabilities capabilities, BuildProducer additionalBeanProducer) { var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER); @@ -452,7 +463,15 @@ private Iterable toJsonSchemaProperties(Type type, IndexView || DotNames.BIG_DECIMAL.equals(typeName)) { 
return removeNulls(NUMBER, description); } + if (LOCAL_DATE_TIME.equals(typeName) || OFFSET_DATE_TIME.equals(typeName)) { + return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date-time"), + description); + } + if (DATE.equals(typeName) || LOCAL_DATE.equals(typeName)) { + return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date"), + description); + } // TODO something else? if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) { ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType() @@ -487,17 +506,35 @@ private Iterable toJsonSchemaProperties(Type type, IndexView ClassInfo classInfo = index.getClassByName(type.name()); List required = new ArrayList<>(); + if (classInfo != null) { for (FieldInfo field : classInfo.fields()) { String fieldName = field.name(); + Type fieldType = field.type(); + + boolean isOptional = isJavaOptionalType(fieldType); + if (isOptional) { + fieldType = unwrapOptionalType(fieldType); + } - Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); + Iterable fieldSchema = toJsonSchemaProperties(fieldType, index, null); Map fieldDescription = new HashMap<>(); for (JsonSchemaProperty fieldProperty : fieldSchema) { fieldDescription.put(fieldProperty.key(), fieldProperty.value()); } + if (field.hasAnnotation(Description.class)) { + AnnotationInstance descriptionAnnotation = field.annotation(Description.class); + if (descriptionAnnotation != null && descriptionAnnotation.value() != null) { + String[] descriptionValue = descriptionAnnotation.value().asStringArray(); + fieldDescription.put("description", String.join(",", descriptionValue)); + } + } + if (!isOptional) { + required.add(fieldName); + } + properties.put(fieldName, fieldDescription); } } @@ -509,10 +546,39 @@ private Iterable toJsonSchemaProperties(Type type, IndexView throw new 
IllegalArgumentException("Unsupported type: " + type); } + private boolean isJavaOptionalType(Type type) { + DotName typeName = type.name(); + return typeName.equals(DotName.createSimple("java.util.Optional")) + || typeName.equals(DotName.createSimple("java.util.OptionalInt")) + || typeName.equals(DotName.createSimple("java.util.OptionalLong")) + || typeName.equals(DotName.createSimple("java.util.OptionalDouble")); + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) { + ParameterizedType parameterizedType = optionalType.asParameterizedType(); + return parameterizedType.arguments().get(0); + } + return optionalType; + } + private boolean isComplexType(Type type) { return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } + private boolean isOptionalField(FieldInfo field, IndexView index) { + Type fieldType = field.type(); + DotName fieldTypeName = fieldType.name(); + + if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName) + || OPTIONAL_DOUBLE.equals(fieldTypeName)) { + return true; + } + + return false; + + } + private Iterable removeNulls(JsonSchemaProperty... 
properties) { return stream(properties) .filter(Objects::nonNull) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java index 00f3f80f7..4914a3adc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java @@ -4,6 +4,8 @@ import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -15,6 +17,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.json.JsonReadFeature; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -27,11 +30,11 @@ public class QuarkusJsonCodecFactory implements JsonCodecFactory { @Override - public Json.JsonCodec create() { + public Codec create() { return new Codec(); } - private static class Codec implements Json.JsonCodec { + public static class Codec implements Json.JsonCodec { private static final Pattern sanitizePattern = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]"); @@ -60,6 +63,31 @@ public T fromJson(String json, Class type) { } } + public T fromJson(String json, Type type) { + try { + String sanitizedJson = sanitize(json, type.getClass()); + JavaType javaType = ObjectMapperHolder.MAPPER.getTypeFactory().constructType(type); + return ObjectMapperHolder.MAPPER.readValue(sanitizedJson, javaType); + } catch (JsonProcessingException e) { + if (e instanceof JsonParseException && isEnumType(type)) { + // this is the case where LangChain4j 
simply passes the string value of the enum to Json.fromJson() + // and Jackson does not handle it + if (type instanceof ParameterizedType) { + Class enumClass = (Class) ((ParameterizedType) type).getRawType(); + return (T) Enum.valueOf(enumClass, json); + } else { + + return (T) Enum.valueOf((Class) type, json); + } + } + throw new UncheckedIOException(e); + } + } + + private boolean isEnumType(Type type) { + return type instanceof Class && ((Class) type).isEnum(); + } + private String sanitize(String original, Class type) { if (String.class.equals(type)) { return original; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java index 33f977ff8..1c4f5ccd6 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java @@ -1,21 +1,222 @@ package io.quarkiverse.langchain4j.runtime; -import static dev.langchain4j.service.TypeUtils.getRawClass; +import java.lang.reflect.*; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import java.lang.reflect.Type; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.structured.Description; +import dev.langchain4j.service.Result; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.TypeUtils; import dev.langchain4j.service.output.ServiceOutputParser; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; import io.smallrye.mutiny.Multi; public class QuarkusServiceOutputParser extends ServiceOutputParser { + 
private static final Pattern JSON_BLOCK_PATTERN = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]"); @Override public String outputFormatInstructions(Type returnType) { - Class rawClass = getRawClass(returnType); - if (Multi.class.equals(rawClass)) { - // when Multi is used as the return type, Multi is the only supported type, thus we don't need want any formatting instructions - return ""; + boolean isOptional = isJavaOptional(returnType); + Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType; + + Class rawClass = getRawClass(actualType); + + if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class + && rawClass != ChatMessage.class + && rawClass != Response.class && !Multi.class.equals(rawClass)) { + try { + var schema = this.toJsonSchema(returnType); + return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: " + + schema; + } catch (Exception e) { + return ""; + } + } + + return ""; + } + + public Object parse(Response response, Type returnType) { + QuarkusJsonCodecFactory factory = new QuarkusJsonCodecFactory(); + var codec = factory.create(); + + if (TypeUtils.typeHasRawClass(returnType, Result.class)) { + returnType = TypeUtils.resolveFirstGenericParameterClass(returnType); + } + + Class rawReturnClass = TypeUtils.getRawClass(returnType); + + if (rawReturnClass == Response.class) { + return response; + } else { + AiMessage aiMessage = response.content(); + if (rawReturnClass == AiMessage.class || rawReturnClass == ChatMessage.class) { + return aiMessage; + } else { + String text = aiMessage.text(); + if (rawReturnClass == String.class) { + return text; + } else { + try { + return codec.fromJson(text, returnType); + } catch (Exception var10) { + String jsonBlock = this.extractJsonBlock(text); + return codec.fromJson(jsonBlock, returnType); + } + } + } + } + } + + private String extractJsonBlock(String text) { + Matcher matcher = 
JSON_BLOCK_PATTERN.matcher(text); + return matcher.find() ? matcher.group() : text; + } + + public String toJsonSchema(Type type) throws Exception { + Map schema = new HashMap<>(); + boolean isOptional = isJavaOptional(type); + Type actualType = isOptional ? unwrapOptionalType(type) : type; + + Class rawClass = getRawClass(actualType); + + if (type instanceof WildcardType wildcardType) { + Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0] + : wildcardType.getLowerBounds()[0]; + return toJsonSchema(boundType); + } + + if (rawClass == String.class || rawClass == Character.class) { + schema.put("type", "string"); + } else if (rawClass == Boolean.class || rawClass == boolean.class) { + schema.put("type", "boolean"); + } else if (Number.class.isAssignableFrom(rawClass) || rawClass.isPrimitive()) { + schema.put("type", (rawClass == double.class || rawClass == float.class) ? "number" : "integer"); + } else if (Collection.class.isAssignableFrom(rawClass) || rawClass.isArray()) { + schema.put("type", "array"); + + Type elementType = getElementType(type); + Map itemsSchema = toJsonSchemaMap(elementType); + schema.put("items", itemsSchema); + } else if (rawClass == LocalDate.class || rawClass == Date.class) { + schema.put("type", "string"); + schema.put("format", "date"); + } else if (rawClass == LocalDateTime.class || rawClass == OffsetDateTime.class) { + schema.put("type", "string"); + schema.put("format", "date-time"); + } else if (rawClass.isEnum()) { + schema.put("type", "string"); + schema.put("enum", getEnumConstants(rawClass)); + } else { + schema.put("type", "object"); + Map properties = new HashMap<>(); + + List required = new ArrayList<>(); + for (Field field : rawClass.getDeclaredFields()) { + try { + field.setAccessible(true); + Type fieldType = field.getGenericType(); + + // Check if the field is Optional and unwrap it if necessary + boolean fieldIsOptional = isJavaOptional(fieldType); + Type fieldActualType = 
fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType; + + Map fieldSchema = toJsonSchemaMap(fieldActualType); + properties.put(field.getName(), fieldSchema); + + if (field.isAnnotationPresent(Description.class)) { + Description description = field.getAnnotation(Description.class); + fieldSchema.put("description", String.join(",", description.value())); + } + + // Only add to required if it is not Optional + if (!fieldIsOptional) { + required.add(field.getName()); + } else { + fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema + } + + } catch (Exception e) { + + } + + } + schema.put("properties", properties); + if (!required.isEmpty()) { + schema.put("required", required); + } + } + if (isOptional) { + schema.put("nullable", true); + } + ObjectMapper mapper = new ObjectMapper(); + return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string + } + + private boolean isJavaOptional(Type type) { + if (type instanceof ParameterizedType) { + Type rawType = ((ParameterizedType) type).getRawType(); + return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class + || rawType == OptionalDouble.class; + } + return false; + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType instanceof ParameterizedType) { + return ((ParameterizedType) optionalType).getActualTypeArguments()[0]; + } + return optionalType; + } + + private Class getRawClass(Type type) { + if (type instanceof Class) { + return (Class) type; + } else if (type instanceof ParameterizedType) { + return (Class) ((ParameterizedType) type).getRawType(); + } else if (type instanceof GenericArrayType) { + Type componentType = ((GenericArrayType) type).getGenericComponentType(); + return Array.newInstance(getRawClass(componentType), 0).getClass(); + } else if (type instanceof WildcardType) { + Type boundType = ((WildcardType) type).getUpperBounds().length > 0 ? 
((WildcardType) type).getUpperBounds()[0] + : ((WildcardType) type).getLowerBounds()[0]; + return getRawClass(boundType); + } + throw new IllegalArgumentException("Unsupported type: " + type); + } + + private Type getElementType(Type type) { + if (type instanceof ParameterizedType) { + return ((ParameterizedType) type).getActualTypeArguments()[0]; + } else if (type instanceof GenericArrayType) { + return ((GenericArrayType) type).getGenericComponentType(); + } else if (type instanceof Class && ((Class) type).isArray()) { + return ((Class) type).getComponentType(); + } + return Object.class; // Fallback for cases where element type cannot be determined + } + + private Map toJsonSchemaMap(Type type) throws Exception { + String jsonSchema = toJsonSchema(type); + ObjectMapper mapper = new ObjectMapper(); + return mapper.readValue(jsonSchema, Map.class); + } + + private List getEnumConstants(Class enumClass) { + List constants = new ArrayList<>(); + for (Object constant : enumClass.getEnumConstants()) { + constants.add(constant.toString()); } - return super.outputFormatInstructions(returnType); + return constants; } } diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index cdec3a989..522c486a5 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -2,6 +2,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import jakarta.annotation.PreDestroy; @@ -12,10 +13,13 @@ import org.jboss.resteasy.reactive.RestQuery; +import com.fasterxml.jackson.annotation.JsonProperty; + import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.memory.ChatMemory; 
import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.RegisterAiService; @Path("assistant-with-tool") @@ -28,14 +32,25 @@ public AssistantWithToolsResource(Assistant assistant) { } public static class TestData { + @Description("Foo description for structured output") + @JsonProperty("foo") String foo; + + @Description("Foo description for structured output") + @JsonProperty("bar") Integer bar; - Double baz; + + @Description("Foo description for structured output") + @JsonProperty("baz") + Optional baz; + + public TestData() { + } TestData(String foo, Integer bar, Double baz) { this.foo = foo; this.bar = bar; - this.baz = baz; + this.baz = Optional.of(baz); } } @@ -48,6 +63,7 @@ public String get(@RestQuery String message) { public interface Assistant { String chat(String userMessage); + } @Singleton diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java new file mode 100644 index 000000000..d798fbbfd --- /dev/null +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java @@ -0,0 +1,85 @@ +package org.acme.example.openai.aiservices; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; + +import org.jboss.resteasy.reactive.RestQuery; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.structured.Description; +import dev.langchain4j.service.UserMessage; +import 
io.quarkiverse.langchain4j.RegisterAiService; + +@Path("collection-entity-mapping") +public class EntityMappedResource { + + private final EntityMappedDescriber describer; + + public EntityMappedResource(EntityMappedDescriber describer) { + this.describer = describer; + } + + public static class TestData { + @Description("Foo description for structured output") + @JsonProperty("foo") + String foo; + + @Description("Foo description for structured output") + @JsonProperty("bar") + Integer bar; + + @Description("Foo description for structured output") + @JsonProperty("baz") + Optional baz; + + public TestData() { + } + + TestData(String foo, Integer bar, Double baz) { + this.foo = foo; + this.bar = bar; + this.baz = Optional.of(baz); + } + } + + public static class MirrorModelSupplier implements Supplier { + @Override + public ChatLanguageModel get() { + return (messages) -> new Response<>(new AiMessage(""" + [ + { + "foo": "asd", + "bar": 1, + "baz": 2.0 + } + ] + """)); + } + } + + @POST + @Path("generateMapped") + public List generateMapped(@RestQuery String message) { + List inputs = new ArrayList<>(); + inputs.add(new TestData(message, 100, 100.0)); + + var test = describer.describeMapped(inputs); + return test; + } + + @RegisterAiService(chatLanguageModelSupplier = MirrorModelSupplier.class) + public interface EntityMappedDescriber { + + @UserMessage("This is a describer returning a collection of mapped entities") + List describeMapped(List inputs); + } +} diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java new file mode 100644 index 000000000..5058b935d --- /dev/null +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java @@ -0,0 +1,34 @@ +package org.acme.example.openai.aiservices; + +import static 
io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.*; + +import java.net.URL; + +import org.junit.jupiter.api.Test; + +import io.quarkus.test.common.http.TestHTTPEndpoint; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class AssistantResourceWithEntityMappingTest { + + @TestHTTPEndpoint(EntityMappedResource.class) + @TestHTTPResource + URL url; + + @Test + public void getMany() { + given() + .baseUri(url.toString() + "/generateMapped") + .queryParam("message", "This is a test") + .post() + .then() + .statusCode(200) + .body("$", hasSize(1)) // Ensure that the response is an array with exactly one item + .body("[0].foo", equalTo("asd")) // Check that foo is set correctly + .body("[0].bar", equalTo(1)) // Check that bar is 1 + .body("[0].baz", equalTo(2.0F)); + } +} diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java index 5017569db..524df4e25 100644 --- a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java @@ -28,4 +28,5 @@ public void get() { .statusCode(200) .body(containsString("MockGPT")); } + }