diff --git a/chat_with_docs/main.py b/chat_with_docs/main.py
index 71d0cbe..af3336d 100644
--- a/chat_with_docs/main.py
+++ b/chat_with_docs/main.py
@@ -92,7 +92,7 @@ def get_docs(user_message: str, k: int = 3) -> list[str]:
 @lilypad.trace(versioning="automatic")
-def bot_response(user_message: str) -> str:
+async def bot_response(user_message: str) -> str:  # NOTE(review): made a coroutine so evaluate.py can asyncio.gather it, but the body below still never awaits — gather() gains real concurrency only if get_docs/bot_response_with_docs don't block the event loop; confirm
     docs = get_docs(user_message)
     return bot_response_with_docs(user_message, docs).content
diff --git a/evaluate.py b/evaluate.py
index f2ad4ac..6301098 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,6 +2,7 @@
 """Evaluation script for processing query files."""
 import argparse
+import asyncio
 from pathlib import Path
 from pydantic import BaseModel
@@ -41,7 +42,7 @@ def load_queries(queries_dir: str) -> list[Query]:
     return queries
-def evaluate_queries(queries_dir: str) -> None:
+async def evaluate_queries(queries_dir: str) -> list[str]:  # NOTE(review): now returns the bot responses; the docstring below (unchanged context) still describes the old None-returning behavior — update it in a follow-up commit
     """Load queries and process each one through bot_response.
     Args:
@@ -50,15 +51,13 @@ def evaluate_queries(queries_dir: str) -> None:
     try:
         queries = load_queries(queries_dir)
         print(f"Loaded {len(queries)} queries from {queries_dir}")
-
-        for query in queries:
-            print(f"\n--- Processing Query ID: {query.id} ---")
-
-            bot_response(query.content)
+        tasks = [bot_response(query.content) for query in queries]  # one coroutine per query; none started until gather schedules them
+        results = await asyncio.gather(*tasks)  # runs all queries concurrently; first exception propagates (return_exceptions=False), remaining tasks keep running
+        return results
     except Exception as e:
         print(f"Error during evaluation: {e}")
-        return
+        raise  # re-raise after logging so the process exits non-zero instead of silently succeeding
 def main() -> None:
@@ -67,7 +66,7 @@ def main() -> None:
     parser.add_argument("queries_dir", help="Directory containing markdown query files")
     args = parser.parse_args()
-    evaluate_queries(args.queries_dir)
+    asyncio.run(evaluate_queries(args.queries_dir))  # NOTE(review): the returned list is discarded here — confirm results are consumed via lilypad traces, or print/persist them
 if __name__ == "__main__":
diff --git a/judge.py b/judge.py
new file mode 100644
index 0000000..919cd2a
--- /dev/null
+++ b/judge.py
@@ -0,0 +1,84 @@
+import argparse
+import asyncio
+import json
+import os
+from typing import Annotated, cast
+
+import lilypad
+from dotenv import load_dotenv
+from lilypad.generated.types.span_more_details import SpanMoreDetails
+from lilypad.generated.types.span_public import SpanPublic
+from mirascope import llm, prompt_template
+from mirascope.core import FromCallArgs
+from pydantic import BaseModel, Field
+
+load_dotenv()
+lilypad.configure()
+
+
+class JudgeResults(BaseModel):  # structured verdict produced by the judge LLM call
+    question_md: Annotated[str, FromCallArgs()]  # FromCallArgs: echoed from the call arguments, not generated by the model
+    answer_md: Annotated[str, FromCallArgs()]
+    docs: Annotated[list[str], FromCallArgs()]
+    reasoning: str = Field(description="The reasoning process of the judge")
+    is_correct: bool = Field(description="Whether the answer is correct")
+
+
+@lilypad.trace(versioning="automatic")
+@llm.call(provider="openai", model="gpt-4o-mini", response_model=JudgeResults)  # decorators supply the body: the prompt below is sent to OpenAI and the reply is parsed into JudgeResults
+@prompt_template("""SYSTEM: You are a helpful assistant that judges the correctness of an answer to a question.
+
+
+- Correct answers should be based on provided documents.
+
+
+
+{question_md}
+
+
+{answer_md}
+
+
+{docs}
+
+""")
+async def judge(question_md: str, answer_md: str, docs: list[str]): ... # noqa: ANN201
+
+
+async def get_child_span_by_name(
+    trace: SpanPublic, name: str, client: lilypad.AsyncLilypad
+) -> SpanMoreDetails:
+    # Find the first child span whose traced function has the given name,
+    # then fetch its full details (which include the recorded output).
+    span = next(  # NOTE(review): raises bare StopIteration if no child span matches — consider a default + clear error
+        t for t in trace.child_spans if getattr(t.function, "name", None) == name
+    )
+    return await client.spans.get(span.uuid_)
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--function_uuid", type=str, default="dd75bac5-bde0-4e55-9bf6-7c9437e8b3a2"  # NOTE(review): hard-coded default UUID is environment-specific — document where it comes from or read it from env
+    )
+    parser.add_argument("--trace_id", type=str, required=True)
+    args = parser.parse_args()
+
+    client = lilypad.AsyncLilypad(
+        base_url="https://lilypad-api.mirascope.com/v0",  # NOTE(review): consider reading the base URL from the environment like the API key
+        api_key=os.environ["LILYPAD_API_KEY"],
+    )
+    traces = await client.projects.functions.spans.list_paginated(  # NOTE(review): only the first page is fetched — the target trace may be missed once spans span multiple pages
+        os.environ["LILYPAD_PROJECT_ID"], args.function_uuid
+    )
+    trace = next(t for t in traces.items if t.trace_id == args.trace_id)  # NOTE(review): bare StopIteration if trace_id is absent from this page — surface a clear error instead
+    docs_span = await get_child_span_by_name(trace, "get_docs", client)
+    docs = json.loads(docs_span.output or "[]")  # recorded get_docs output is a JSON string; empty list when missing
+    annotation = (trace.annotations or [])[0]  # NOTE(review): IndexError when the trace has no annotations — guard with a clear message
+    question_md: str = cast(str, (annotation.span.arg_values or {})["user_message"])  # question = the traced call's user_message argument
+    answer_md: str = annotation.span.output or ""
+
+    result = await judge(question_md, answer_md, docs)
+    print(result.reasoning, "\n", f"Is correct: {result.is_correct}") # noqa: T201
+
+
+if __name__ == "__main__":
+    asyncio.run(main())