diff --git a/chat_with_docs/main.py b/chat_with_docs/main.py
index 71d0cbe..af3336d 100644
--- a/chat_with_docs/main.py
+++ b/chat_with_docs/main.py
@@ -92,7 +92,7 @@ def get_docs(user_message: str, k: int = 3) -> list[str]:
 
 
 @lilypad.trace(versioning="automatic")
-def bot_response(user_message: str) -> str:
+async def bot_response(user_message: str) -> str:
     docs = get_docs(user_message)
     return bot_response_with_docs(user_message, docs).content
 
diff --git a/evaluate.py b/evaluate.py
index f2ad4ac..6301098 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,6 +2,7 @@
 """Evaluation script for processing query files."""
 
 import argparse
+import asyncio
 from pathlib import Path
 
 from pydantic import BaseModel
@@ -41,7 +42,7 @@ def load_queries(queries_dir: str) -> list[Query]:
     return queries
 
 
-def evaluate_queries(queries_dir: str) -> None:
+async def evaluate_queries(queries_dir: str) -> list[str]:
     """Load queries and process each one through bot_response.
 
     Args:
@@ -50,15 +51,13 @@
     try:
         queries = load_queries(queries_dir)
         print(f"Loaded {len(queries)} queries from {queries_dir}")
-
-        for query in queries:
-            print(f"\n--- Processing Query ID: {query.id} ---")
-
-            bot_response(query.content)
+        tasks = [bot_response(query.content) for query in queries]
+        results = await asyncio.gather(*tasks)
+        return results
 
     except Exception as e:
         print(f"Error during evaluation: {e}")
-        return
+        raise
 
 
 def main() -> None:
@@ -67,7 +66,7 @@
     parser.add_argument("queries_dir", help="Directory containing markdown query files")
     args = parser.parse_args()
 
-    evaluate_queries(args.queries_dir)
+    asyncio.run(evaluate_queries(args.queries_dir))
 
 
 if __name__ == "__main__":
diff --git a/judge.py b/judge.py
new file mode 100644
index 0000000..919cd2a
--- /dev/null
+++ b/judge.py
@@ -0,0 +1,84 @@
+import argparse
+import asyncio
+import json
+import os
+from typing import Annotated, cast
+
+import lilypad
+from dotenv import load_dotenv
+from lilypad.generated.types.span_more_details import SpanMoreDetails
+from lilypad.generated.types.span_public import SpanPublic
+from mirascope import llm, prompt_template
+from mirascope.core import FromCallArgs
+from pydantic import BaseModel, Field
+
+load_dotenv()
+lilypad.configure()
+
+
+class JudgeResults(BaseModel):
+    question_md: Annotated[str, FromCallArgs()]
+    answer_md: Annotated[str, FromCallArgs()]
+    docs: Annotated[list[str], FromCallArgs()]
+    reasoning: str = Field(description="The reasoning process of the judge")
+    is_correct: bool = Field(description="Whether the answer is correct")
+
+
+@lilypad.trace(versioning="automatic")
+@llm.call(provider="openai", model="gpt-4o-mini", response_model=JudgeResults)
+@prompt_template("""SYSTEM: You are a helpful assistant that judges the correctness of an answer to a question.
+
+
+- Correct answers should be based on provided documents.
+
+
+
+{question_md}
+
+
+{answer_md}
+
+
+{docs}
+
+""")
+async def judge(question_md: str, answer_md: str, docs: list[str]): ...  # noqa: ANN201
+
+
+async def get_child_span_by_name(
+    trace: SpanPublic, name: str, client: lilypad.AsyncLilypad
+) -> SpanMoreDetails:
+    span = next(
+        t for t in trace.child_spans if getattr(t.function, "name", None) == name
+    )
+    return await client.spans.get(span.uuid_)
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--function_uuid", type=str, default="dd75bac5-bde0-4e55-9bf6-7c9437e8b3a2"
+    )
+    parser.add_argument("--trace_id", type=str, required=True)
+    args = parser.parse_args()
+
+    client = lilypad.AsyncLilypad(
+        base_url="https://lilypad-api.mirascope.com/v0",
+        api_key=os.environ["LILYPAD_API_KEY"],
+    )
+    traces = await client.projects.functions.spans.list_paginated(
+        os.environ["LILYPAD_PROJECT_ID"], args.function_uuid
+    )
+    trace = next(t for t in traces.items if t.trace_id == args.trace_id)
+    docs_span = await get_child_span_by_name(trace, "get_docs", client)
+    docs = json.loads(docs_span.output or "[]")
+    annotation = (trace.annotations or [])[0]
+    question_md: str = cast(str, (annotation.span.arg_values or {})["user_message"])
+    answer_md: str = annotation.span.output or ""
+
+    result = await judge(question_md, answer_md, docs)
+    print(result.reasoning, "\n", f"Is correct: {result.is_correct}")  # noqa: T201
+
+
+if __name__ == "__main__":
+    asyncio.run(main())