Skip to content

Commit c963611

Browse files
authored
Merge pull request #5 from Mirascope/live-stream2
Live stream 2: initial set up for LLM as a judge
2 parents 8566f48 + 4469a19 commit c963611

File tree

3 files changed

+92
-9
lines changed

3 files changed

+92
-9
lines changed

chat_with_docs/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def get_docs(user_message: str, k: int = 3) -> list[str]:
9292

9393

9494
@lilypad.trace(versioning="automatic")
95-
def bot_response(user_message: str) -> str:
95+
async def bot_response(user_message: str) -> str:
9696
docs = get_docs(user_message)
9797
return bot_response_with_docs(user_message, docs).content
9898

evaluate.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Evaluation script for processing query files."""
33

44
import argparse
5+
import asyncio
56
from pathlib import Path
67

78
from pydantic import BaseModel
@@ -41,7 +42,7 @@ def load_queries(queries_dir: str) -> list[Query]:
4142
return queries
4243

4344

44-
def evaluate_queries(queries_dir: str) -> None:
45+
async def evaluate_queries(queries_dir: str) -> list[str]:
4546
"""Load queries and process each one through bot_response.
4647
4748
Args:
@@ -50,15 +51,13 @@ def evaluate_queries(queries_dir: str) -> None:
5051
try:
5152
queries = load_queries(queries_dir)
5253
print(f"Loaded {len(queries)} queries from {queries_dir}")
53-
54-
for query in queries:
55-
print(f"\n--- Processing Query ID: {query.id} ---")
56-
57-
bot_response(query.content)
54+
tasks = [bot_response(query.content) for query in queries]
55+
results = await asyncio.gather(*tasks)
56+
return results
5857

5958
except Exception as e:
6059
print(f"Error during evaluation: {e}")
61-
return
60+
raise
6261

6362

6463
def main() -> None:
@@ -67,7 +66,7 @@ def main() -> None:
6766
parser.add_argument("queries_dir", help="Directory containing markdown query files")
6867

6968
args = parser.parse_args()
70-
evaluate_queries(args.queries_dir)
69+
asyncio.run(evaluate_queries(args.queries_dir))
7170

7271

7372
if __name__ == "__main__":

judge.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import argparse
2+
import asyncio
3+
import json
4+
import os
5+
from typing import Annotated, cast
6+
7+
import lilypad
8+
from dotenv import load_dotenv
9+
from lilypad.generated.types.span_more_details import SpanMoreDetails
10+
from lilypad.generated.types.span_public import SpanPublic
11+
from mirascope import llm, prompt_template
12+
from mirascope.core import FromCallArgs
13+
from pydantic import BaseModel, Field
14+
15+
# Runs at import time, before anything below reads os.environ; presumably this
# pulls LILYPAD_API_KEY / LILYPAD_PROJECT_ID from a local .env file — confirm.
load_dotenv()
16+
# Module-level Lilypad setup, executed before the @lilypad.trace-decorated
# `judge` function is defined further down.
lilypad.configure()
17+
18+
19+
class JudgeResults(BaseModel):
    # Structured verdict produced by `judge` (it is that call's response_model).
    # Documented with `#` comments on purpose: a class docstring would be picked
    # up by pydantic as the schema description and change what is generated.

    # Carried over from the call: each field is tagged with FromCallArgs and
    # shares its name with a `judge` parameter.
    # NOTE(review): presumably filled from the call arguments rather than by
    # the model — confirm against the Mirascope docs.
    question_md: Annotated[str, FromCallArgs()]
    answer_md: Annotated[str, FromCallArgs()]
    docs: Annotated[list[str], FromCallArgs()]

    # The verdict itself.
    reasoning: str = Field(description="The reasoning process of the judge")
    is_correct: bool = Field(description="Whether the answer is correct")
25+
26+
27+
# LLM-as-a-judge call. The prompt lives entirely in the `prompt_template`
# decorator, so the function body is intentionally empty (`...`). The call is
# configured with response_model=JudgeResults and is traced by Lilypad with
# automatic versioning. The template text below is kept byte-for-byte.
@lilypad.trace(versioning="automatic")
@llm.call(
    provider="openai",
    model="gpt-4o-mini",
    response_model=JudgeResults,
)
@prompt_template(
    """SYSTEM: You are a helpful assistant that judges the correctness of an answer to a question.

<criteria>
- Correct answers should be based on provided documents.
</criteria>

<question>
{question_md}
</question>
<answer>
{answer_md}
</answer>
<docs>
{docs}
</docs>
"""
)
async def judge(question_md: str, answer_md: str, docs: list[str]): ...  # noqa: ANN201
46+
47+
48+
async def get_child_span_by_name(
    trace: SpanPublic, name: str, client: lilypad.AsyncLilypad
) -> SpanMoreDetails:
    """Fetch the detailed record of the first child span whose function is ``name``.

    Args:
        trace: Span whose ``child_spans`` are searched, in order.
        name: Function name to match against each child's ``function.name``.
            Children whose ``function`` has no ``name`` attribute (for example
            ``None``) never match.
        client: Lilypad client used to load the matching span by its ``uuid_``.

    Returns:
        Whatever ``client.spans.get`` returns for the first matching child.

    Raises:
        RuntimeError: If no child span matches ``name``.
    """
    # A bare next(<generator>) would raise StopIteration here, which Python
    # turns into the opaque "coroutine raised StopIteration" RuntimeError
    # inside a coroutine. Use a default and raise a descriptive error of the
    # same type instead.
    match = next(
        (
            child
            for child in trace.child_spans
            if getattr(child.function, "name", None) == name
        ),
        None,
    )
    if match is None:
        raise RuntimeError(f"No child span named {name!r} found in trace")
    return await client.spans.get(match.uuid_)
55+
56+
57+
async def main() -> None:
    """Judge one recorded bot answer, selected by trace id, and print the verdict.

    Command-line arguments:
        --function_uuid: Lilypad function whose spans are listed (has a default).
        --trace_id: Required; id of the trace to judge.

    Reads ``LILYPAD_API_KEY`` and ``LILYPAD_PROJECT_ID`` from the environment.
    The question and answer are taken from the span attached to the trace's
    first annotation; the documents come from the output of its ``get_docs``
    child span.

    Raises:
        SystemExit: With an explanatory message when the trace, its
            annotation, or the ``user_message`` argument cannot be found.
        KeyError: If a required environment variable is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--function_uuid", type=str, default="dd75bac5-bde0-4e55-9bf6-7c9437e8b3a2"
    )
    parser.add_argument("--trace_id", type=str, required=True)
    args = parser.parse_args()

    client = lilypad.AsyncLilypad(
        base_url="https://lilypad-api.mirascope.com/v0",
        api_key=os.environ["LILYPAD_API_KEY"],
    )
    traces = await client.projects.functions.spans.list_paginated(
        os.environ["LILYPAD_PROJECT_ID"], args.function_uuid
    )
    # NOTE(review): only the items of this one response are searched; if the
    # endpoint pages its results, traces on later pages will be reported as
    # missing — confirm the paging behavior.
    trace = next((t for t in traces.items if t.trace_id == args.trace_id), None)
    if trace is None:
        # A bare next() would surface as "coroutine raised StopIteration".
        raise SystemExit(
            f"No trace with id {args.trace_id!r} found for function "
            f"{args.function_uuid!r}"
        )

    docs_span = await get_child_span_by_name(trace, "get_docs", client)
    docs = json.loads(docs_span.output or "[]")

    annotations = trace.annotations or []
    if not annotations:
        # Indexing [0] on the empty fallback list would raise a bare IndexError.
        raise SystemExit(f"Trace {args.trace_id!r} has no annotations to judge")
    annotation = annotations[0]

    arg_values = annotation.span.arg_values or {}
    if "user_message" not in arg_values:
        raise SystemExit(
            f"Annotated span of trace {args.trace_id!r} has no 'user_message' argument"
        )
    question_md: str = cast(str, arg_values["user_message"])
    answer_md: str = annotation.span.output or ""

    result = await judge(question_md, answer_md, docs)
    print(result.reasoning, "\n", f"Is correct: {result.is_correct}")  # noqa: T201
81+
82+
83+
if __name__ == "__main__":
    # Script entry point: drive the async `main` coroutine to completion.
    asyncio.run(main())

0 commit comments

Comments
 (0)