|
| 1 | +import argparse |
| 2 | +import asyncio |
| 3 | +import json |
| 4 | +import os |
| 5 | +from typing import Annotated, cast |
| 6 | + |
| 7 | +import lilypad |
| 8 | +from dotenv import load_dotenv |
| 9 | +from lilypad.generated.types.span_more_details import SpanMoreDetails |
| 10 | +from lilypad.generated.types.span_public import SpanPublic |
| 11 | +from mirascope import llm, prompt_template |
| 12 | +from mirascope.core import FromCallArgs |
| 13 | +from pydantic import BaseModel, Field |
| 14 | + |
# Load variables from a .env file into the process environment before anything
# reads it; main() looks up LILYPAD_API_KEY and LILYPAD_PROJECT_ID there.
load_dotenv()
# Initialise Lilypad at import time.
# NOTE(review): presumably this picks up its settings from the environment
# loaded above and has to run before the @lilypad.trace decorator below is
# applied -- confirm against the Lilypad docs.
lilypad.configure()
| 17 | + |
| 18 | + |
# Structured verdict used as the `response_model` of the `judge` call.
# NOTE(review): documented with comments instead of a class docstring on
# purpose -- a docstring may be picked up as the schema description sent to the
# model, which would change the request; confirm before adding one.
class JudgeResults(BaseModel):
    # Inputs echoed back on the result. `FromCallArgs` presumably fills these
    # from the same-named arguments of the decorated call rather than from the
    # model's output -- confirm against the Mirascope docs.
    question_md: Annotated[str, FromCallArgs()]
    answer_md: Annotated[str, FromCallArgs()]
    docs: Annotated[list[str], FromCallArgs()]
    # Fields carrying the judge's explanation and its boolean verdict; the
    # descriptions below are part of the generated schema.
    reasoning: str = Field(description="The reasoning process of the judge")
    is_correct: bool = Field(description="Whether the answer is correct")
| 25 | + |
| 26 | + |
# LLM-as-judge call. The function body is intentionally empty: the decorators
# build the request from the prompt template below, send it to OpenAI's
# gpt-4o-mini, parse the reply into `JudgeResults`, and trace the call with
# Lilypad.
# NOTE(review): `{docs}` receives a list[str]; presumably it is rendered with
# the list's default string formatting -- confirm this is the intended layout.
# NOTE(review): with `versioning="automatic"` Lilypad presumably derives the
# function version from this code, so editing the prompt likely yields a new
# version -- confirm.
@lilypad.trace(versioning="automatic")
@llm.call(provider="openai", model="gpt-4o-mini", response_model=JudgeResults)
@prompt_template("""SYSTEM: You are a helpful assistant that judges the correctness of an answer to a question.

<criteria>
- Correct answers should be based on provided documents.
</criteria>

<question>
{question_md}
</question>
<answer>
{answer_md}
</answer>
<docs>
{docs}
</docs>
""")
async def judge(question_md: str, answer_md: str, docs: list[str]): ...  # noqa: ANN201
| 46 | + |
| 47 | + |
async def get_child_span_by_name(
    trace: SpanPublic, name: str, client: lilypad.AsyncLilypad
) -> SpanMoreDetails:
    """Fetch the detailed record of the first child span recorded for *name*.

    Args:
        trace: Span whose ``child_spans`` are searched.
        name: Function name compared against each child's ``function.name``.
        client: Lilypad client used to fetch the span details.

    Returns:
        The result of ``client.spans.get`` for the first matching child span.

    Raises:
        RuntimeError: If no child span belongs to a function called *name*.
    """
    # getattr tolerates children whose ``function`` is None or has no ``name``.
    span = next(
        (t for t in trace.child_spans if getattr(t.function, "name", None) == name),
        None,
    )
    if span is None:
        # A bare next() would raise StopIteration here, which a coroutine
        # surfaces as the opaque "coroutine raised StopIteration" RuntimeError.
        # Keep the exception type but say what was actually missing.
        raise RuntimeError(f"No child span named {name!r} found in the trace.")
    return await client.spans.get(span.uuid_)
| 55 | + |
| 56 | + |
async def main() -> None:
    """Judge the annotated answer of one Lilypad trace and print the verdict.

    Command-line arguments:
        --function_uuid: UUID of the Lilypad function whose spans are listed
            (a hard-coded UUID is used when omitted).
        --trace_id: ID of the trace to judge (required).

    ``LILYPAD_API_KEY`` and ``LILYPAD_PROJECT_ID`` are read from the
    environment.

    Raises:
        SystemExit: If the trace is not among the listed spans, has no
            annotations, or its annotated span has no ``user_message``
            argument.
        KeyError: If either environment variable is missing.

    Errors from the Lilypad client, ``get_child_span_by_name`` and ``judge``
    propagate unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--function_uuid", type=str, default="dd75bac5-bde0-4e55-9bf6-7c9437e8b3a2"
    )
    parser.add_argument("--trace_id", type=str, required=True)
    args = parser.parse_args()

    client = lilypad.AsyncLilypad(
        base_url="https://lilypad-api.mirascope.com/v0",
        api_key=os.environ["LILYPAD_API_KEY"],
    )
    # NOTE(review): only the items returned by this single call are searched;
    # confirm whether further pages have to be fetched for older traces.
    traces = await client.projects.functions.spans.list_paginated(
        os.environ["LILYPAD_PROJECT_ID"], args.function_uuid
    )
    # next() without a default raises StopIteration, which a coroutine turns
    # into an opaque RuntimeError; use a None default and report it clearly.
    trace = next((t for t in traces.items if t.trace_id == args.trace_id), None)
    if trace is None:
        raise SystemExit(
            f"Trace {args.trace_id!r} not found for function {args.function_uuid!r}."
        )

    # Validate the cheap, local preconditions before issuing another request.
    annotations = trace.annotations or []
    if not annotations:
        raise SystemExit(f"Trace {args.trace_id!r} has no annotations to judge.")
    annotated_span = annotations[0].span
    arg_values = annotated_span.arg_values or {}
    if "user_message" not in arg_values:
        raise SystemExit(
            f"Annotated span of trace {args.trace_id!r} has no 'user_message' argument."
        )
    question_md: str = cast(str, arg_values["user_message"])
    answer_md: str = annotated_span.output or ""

    # The "get_docs" child span's output is parsed as the JSON document list;
    # an empty output is treated as no documents.
    docs_span = await get_child_span_by_name(trace, "get_docs", client)
    docs = json.loads(docs_span.output or "[]")

    result = await judge(question_md, answer_md, docs)
    print(result.reasoning, "\n", f"Is correct: {result.is_correct}")  # noqa: T201
| 81 | + |
| 82 | + |
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
0 commit comments