Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chat_with_docs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_docs(user_message: str, k: int = 3) -> list[str]:


@lilypad.trace(versioning="automatic")
def bot_response(user_message: str) -> str:
async def bot_response(user_message: str) -> str:
docs = get_docs(user_message)
return bot_response_with_docs(user_message, docs).content

Expand Down
15 changes: 7 additions & 8 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Evaluation script for processing query files."""

import argparse
import asyncio
from pathlib import Path

from pydantic import BaseModel
Expand Down Expand Up @@ -41,7 +42,7 @@ def load_queries(queries_dir: str) -> list[Query]:
return queries


def evaluate_queries(queries_dir: str) -> None:
async def evaluate_queries(queries_dir: str) -> list[str]:
"""Load queries and process each one through bot_response.

Args:
Expand All @@ -50,15 +51,13 @@ def evaluate_queries(queries_dir: str) -> None:
try:
queries = load_queries(queries_dir)
print(f"Loaded {len(queries)} queries from {queries_dir}")

for query in queries:
print(f"\n--- Processing Query ID: {query.id} ---")

bot_response(query.content)
tasks = [bot_response(query.content) for query in queries]
results = await asyncio.gather(*tasks)
return results

except Exception as e:
print(f"Error during evaluation: {e}")
return
raise


def main() -> None:
Expand All @@ -67,7 +66,7 @@ def main() -> None:
parser.add_argument("queries_dir", help="Directory containing markdown query files")

args = parser.parse_args()
evaluate_queries(args.queries_dir)
asyncio.run(evaluate_queries(args.queries_dir))


if __name__ == "__main__":
Expand Down
84 changes: 84 additions & 0 deletions judge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import argparse
import asyncio
import json
import os
from typing import Annotated, cast

import lilypad
from dotenv import load_dotenv
from lilypad.generated.types.span_more_details import SpanMoreDetails
from lilypad.generated.types.span_public import SpanPublic
from mirascope import llm, prompt_template
from mirascope.core import FromCallArgs
from pydantic import BaseModel, Field

# Load variables from a local .env file into the process environment first;
# main() below reads LILYPAD_API_KEY and LILYPAD_PROJECT_ID from os.environ.
load_dotenv()
# Configure the Lilypad SDK at import time, before the @lilypad.trace
# decorator further down is applied.
lilypad.configure()


# Structured verdict that the `judge` LLM call is parsed into (it is passed as
# `response_model=JudgeResults` to `llm.call`).
#
# NOTE: documented with `#` comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its generated JSON schema, so adding
# one could change what is sent to the model.
class JudgeResults(BaseModel):
    # Inputs of the judged call. The `FromCallArgs()` marker presumably tells
    # mirascope to fill these from the same-named arguments of `judge` instead
    # of asking the model to generate them -- TODO confirm against mirascope docs.
    question_md: Annotated[str, FromCallArgs()]
    answer_md: Annotated[str, FromCallArgs()]
    docs: Annotated[list[str], FromCallArgs()]
    # Fields produced by the judge; the `description` strings are part of the
    # response schema, so treat them as prompt text.
    reasoning: str = Field(description="The reasoning process of the judge")
    is_correct: bool = Field(description="Whether the answer is correct")


# LLM-as-judge call: sends the prompt below to OpenAI's gpt-4o-mini through
# mirascope and parses the reply into a `JudgeResults`; the call is traced by
# Lilypad with automatic versioning.
#
# The function body is intentionally `...`: the request is built by the
# decorators from the prompt template, whose `{question_md}`, `{answer_md}` and
# `{docs}` placeholders are named after the function's parameters.
#
# NOTE(review): the whole template sits under the `SYSTEM:` marker and there is
# no separate user section; `{docs}` is a list[str] and how it is rendered
# depends on mirascope's template formatting -- confirm both are intended.
@lilypad.trace(versioning="automatic")
@llm.call(provider="openai", model="gpt-4o-mini", response_model=JudgeResults)
@prompt_template("""SYSTEM: You are a helpful assistant that judges the correctness of an answer to a question.

<criteria>
- Correct answers should be based on provided documents.
</criteria>

<question>
{question_md}
</question>
<answer>
{answer_md}
</answer>
<docs>
{docs}
</docs>
""")
async def judge(question_md: str, answer_md: str, docs: list[str]): ... # noqa: ANN201


async def get_child_span_by_name(
    trace: SpanPublic, name: str, client: lilypad.AsyncLilypad
) -> SpanMoreDetails:
    """Fetch the details of the first child span produced by function ``name``.

    Args:
        trace: Parent span whose ``child_spans`` are searched in order.
        name: Function name compared against each child's ``function.name``;
            children without a ``function`` (or without a ``name``) never match.
        client: Lilypad client used to fetch the matching span by its UUID.

    Returns:
        Whatever ``client.spans.get`` returns for the first matching child span.

    Raises:
        LookupError: If no child span of ``trace`` belongs to a function
            called ``name``.
    """
    # A bare ``next(gen)`` raises StopIteration on a miss, and inside a
    # coroutine that surfaces as an opaque "coroutine raised StopIteration"
    # RuntimeError. Use a default and report the miss explicitly instead.
    span = next(
        (t for t in trace.child_spans if getattr(t.function, "name", None) == name),
        None,
    )
    if span is None:
        raise LookupError(f"No child span named {name!r} found in trace")
    return await client.spans.get(span.uuid_)


async def main() -> None:
    """Judge the recorded bot answer for a single Lilypad trace and print the verdict.

    Command-line arguments:
        --function_uuid: UUID of the traced function whose spans are listed
            (defaults to a hard-coded UUID).
        --trace_id: ID of the trace to judge (required).

    The retrieved documents are read from the output of the trace's
    ``get_docs`` child span; the question and answer are read from the span
    attached to the trace's first annotation.

    Raises:
        KeyError: If ``LILYPAD_API_KEY`` or ``LILYPAD_PROJECT_ID`` is not set.
        SystemExit: If the trace is not found, has no annotations, or its
            annotated span has no ``user_message`` argument.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--function_uuid", type=str, default="dd75bac5-bde0-4e55-9bf6-7c9437e8b3a2"
    )
    parser.add_argument("--trace_id", type=str, required=True)
    args = parser.parse_args()

    client = lilypad.AsyncLilypad(
        base_url="https://lilypad-api.mirascope.com/v0",
        api_key=os.environ["LILYPAD_API_KEY"],
    )
    traces = await client.projects.functions.spans.list_paginated(
        os.environ["LILYPAD_PROJECT_ID"], args.function_uuid
    )
    # NOTE(review): only the items returned by this single call are searched;
    # if the API paginates, a trace on a later page will not be found -- confirm.
    # A default is passed to next() because a StopIteration escaping a
    # coroutine is reported as an unhelpful RuntimeError.
    trace = next((t for t in traces.items if t.trace_id == args.trace_id), None)
    if trace is None:
        raise SystemExit(
            f"No trace with id {args.trace_id!r} found for function "
            f"{args.function_uuid!r}"
        )

    docs_span = await get_child_span_by_name(trace, "get_docs", client)
    # A missing/empty output is treated as "no documents retrieved".
    docs = json.loads(docs_span.output or "[]")

    annotations = trace.annotations or []
    if not annotations:
        raise SystemExit(f"Trace {args.trace_id!r} has no annotations to judge")
    annotation = annotations[0]

    arg_values = annotation.span.arg_values or {}
    if "user_message" not in arg_values:
        raise SystemExit(
            f"Annotated span of trace {args.trace_id!r} has no 'user_message' argument"
        )
    question_md: str = cast(str, arg_values["user_message"])
    answer_md: str = annotation.span.output or ""

    result = await judge(question_md, answer_md, docs)
    print(result.reasoning, "\n", f"Is correct: {result.is_correct}")  # noqa: T201


# Script entry point: when executed directly, run the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())