Skip to content

Commit 9776ec5

Browse files
authored
refactor protocol model extraction to only check extra parameters (#436)
1 parent a0bf68c commit 9776ec5

File tree

3 files changed

+211
-26
lines changed

3 files changed

+211
-26
lines changed

ragna/core/_components.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
import enum
55
import functools
66
import inspect
7-
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union
7+
from typing import (
8+
AsyncIterable,
9+
AsyncIterator,
10+
Iterator,
11+
Optional,
12+
Type,
13+
Union,
14+
get_type_hints,
15+
)
816

917
import pydantic
1018
import pydantic.utils
@@ -42,30 +50,47 @@ def __repr__(self) -> str:
4250
def _protocol_models(
4351
cls,
4452
) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]:
53+
# This method dynamically builds a pydantic.BaseModel for the extra parameters
54+
# of each method that is listed in the __ragna_protocol_methods__ class
55+
# variable. These models are used by ragna.core.Chat._unpack_chat_params to
56+
# validate and distribute the **params passed by the user.
57+
58+
# Walk up the MRO until we find the __ragna_protocol_methods__ variable. This is
59+
# the indicator that we found the protocol class. We use this as a reference for
60+
# which params of a protocol method are part of the protocol (think positional
61+
# parameters) and which are requested by the concrete class (think keyword
62+
# parameters).
4563
protocol_cls, protocol_methods = next(
4664
(cls_, cls_.__ragna_protocol_methods__) # type: ignore[attr-defined]
4765
for cls_ in cls.__mro__
4866
if "__ragna_protocol_methods__" in cls_.__dict__
4967
)
5068
models = {}
5169
for method_name in protocol_methods:
70+
num_protocol_params = len(
71+
inspect.signature(getattr(protocol_cls, method_name)).parameters
72+
)
5273
method = getattr(cls, method_name)
53-
concrete_params = inspect.signature(method).parameters
54-
protocol_params = inspect.signature(
55-
getattr(protocol_cls, method_name)
56-
).parameters
57-
extra_param_names = concrete_params.keys() - protocol_params.keys()
58-
59-
models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload]
74+
params = iter(inspect.signature(method).parameters.values())
75+
annotations = get_type_hints(method)
76+
# Skip over the protocol parameters in order for the model below to only
77+
# comprise concrete parameters.
78+
for _ in range(num_protocol_params):
79+
next(params)
80+
81+
models[(cls, method_name)] = pydantic.create_model(
82+
# type: ignore[call-overload]
6083
f"{cls.__name__}.{method_name}",
6184
**{
62-
(param := concrete_params[param_name]).name: (
63-
param.annotation,
64-
param.default
65-
if param.default is not inspect.Parameter.empty
66-
else ...,
85+
param.name: (
86+
annotations[param.name],
87+
(
88+
param.default
89+
if param.default is not inspect.Parameter.empty
90+
else ...
91+
),
6792
)
68-
for param_name in extra_param_names
93+
for param in params
6994
},
7095
)
7196
return models

ragna/core/_rag.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import datetime
45
import inspect
6+
import itertools
57
import uuid
8+
from collections import defaultdict
69
from typing import (
710
Any,
811
AsyncIterator,
@@ -19,6 +22,7 @@
1922
)
2023

2124
import pydantic
25+
import pydantic_core
2226
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
2327

2428
from ._components import Assistant, Component, Message, MessageRole, SourceStorage
@@ -251,27 +255,120 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:
251255
def _unpack_chat_params(
252256
self, params: dict[str, Any]
253257
) -> dict[Callable, dict[str, Any]]:
258+
# This method does two things:
259+
# 1. Validate the **params against the signatures of the protocol methods of the
260+
# used components. This makes sure that
261+
# - No parameter is passed that isn't used by at least one component
262+
# - No parameter is missing that is needed by at least one component
263+
# - No parameter is passed in the wrong type
264+
# 2. Prepare the distribution of the parameters to the protocol method that
265+
# requested them. The actual distribution happens in self._run and
266+
# self._run_gen, but is only a dictionary lookup by then.
254267
component_models = {
255268
getattr(component, name): model
256269
for component in [self.source_storage, self.assistant]
257270
for (_, name), model in component._protocol_models().items()
258271
}
259272

260273
ChatModel = merge_models(
261-
str(self.params["chat_id"]),
274+
f"{self.__module__}.{type(self).__name__}-{self.params['chat_id']}",
262275
SpecialChatParams,
263276
*component_models.values(),
264277
config=pydantic.ConfigDict(extra="forbid"),
265278
)
266279

267-
chat_params = ChatModel.model_validate(params, strict=True).model_dump(
268-
exclude_none=True
269-
)
280+
with self._format_validation_error(ChatModel):
281+
chat_model = ChatModel.model_validate(params, strict=True)
282+
283+
chat_params = chat_model.model_dump(exclude_none=True)
270284
return {
271285
fn: model(**chat_params).model_dump()
272286
for fn, model in component_models.items()
273287
}
274288

289+
@contextlib.contextmanager
290+
def _format_validation_error(
291+
self, model_cls: type[pydantic.BaseModel]
292+
) -> Iterator[None]:
293+
try:
294+
yield
295+
except pydantic.ValidationError as validation_error:
296+
errors = defaultdict(list)
297+
for error in validation_error.errors():
298+
errors[error["type"]].append(error)
299+
300+
parts = [
301+
f"Validating the Chat parameters resulted in {validation_error.error_count()} errors:",
302+
"",
303+
]
304+
305+
def format_error(
306+
error: pydantic_core.ErrorDetails,
307+
*,
308+
annotation: bool = False,
309+
value: bool = False,
310+
) -> str:
311+
param = cast(str, error["loc"][0])
312+
313+
formatted_error = f"- {param}"
314+
if annotation:
315+
annotation_ = cast(
316+
type, model_cls.model_fields[param].annotation
317+
).__name__
318+
formatted_error += f": {annotation_}"
319+
320+
if value:
321+
value_ = error["input"]
322+
formatted_error += (
323+
f" = {value_!r}" if annotation else f"={value_!r}"
324+
)
325+
326+
return formatted_error
327+
328+
if "extra_forbidden" in errors:
329+
parts.extend(
330+
[
331+
"The following parameters are unknown:",
332+
"",
333+
*[
334+
format_error(error, value=True)
335+
for error in errors["extra_forbidden"]
336+
],
337+
"",
338+
]
339+
)
340+
341+
if "missing" in errors:
342+
parts.extend(
343+
[
344+
"The following parameters are missing:",
345+
"",
346+
*[
347+
format_error(error, annotation=True)
348+
for error in errors["missing"]
349+
],
350+
"",
351+
]
352+
)
353+
354+
type_errors = ["string_type", "int_type", "float_type", "bool_type"]
355+
if any(type_error in errors for type_error in type_errors):
356+
parts.extend(
357+
[
358+
"The following parameters have the wrong type:",
359+
"",
360+
*[
361+
format_error(error, annotation=True, value=True)
362+
for error in itertools.chain.from_iterable(
363+
errors[type_error] for type_error in type_errors
364+
)
365+
],
366+
"",
367+
]
368+
)
369+
370+
raise RagnaException("\n".join(parts))
371+
275372
async def _run(
276373
self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any
277374
) -> T:

tests/core/test_rag.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import pydantic
21
import pytest
32

43
from ragna import Rag, assistants, source_storages
5-
from ragna.core import LocalDocument
4+
from ragna.core import Assistant, LocalDocument, RagnaException
65

76

87
@pytest.fixture()
@@ -14,20 +13,84 @@ def demo_document(tmp_path, request):
1413

1514

1615
class TestChat:
17-
def chat(self, documents, **params):
16+
def chat(
17+
self,
18+
documents,
19+
source_storage=source_storages.RagnaDemoSourceStorage,
20+
assistant=assistants.RagnaDemoAssistant,
21+
**params,
22+
):
1823
return Rag().chat(
1924
documents=documents,
20-
source_storage=source_storages.RagnaDemoSourceStorage,
21-
assistant=assistants.RagnaDemoAssistant,
25+
source_storage=source_storage,
26+
assistant=assistant,
2227
**params,
2328
)
2429

25-
def test_extra_params(self, demo_document):
26-
with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"):
30+
def test_params_validation_unknown(self, demo_document):
31+
params = {
32+
"bool_param": False,
33+
"int_param": 1,
34+
"float_param": 0.5,
35+
"string_param": "arbitrary_value",
36+
}
37+
with pytest.raises(RagnaException, match="unknown") as exc_info:
38+
self.chat(documents=[demo_document], **params)
39+
40+
msg = str(exc_info.value)
41+
for param, value in params.items():
42+
assert f"{param}={value!r}" in msg
43+
44+
def test_params_validation_missing(self, demo_document):
45+
class ValidationAssistant(Assistant):
46+
def answer(
47+
self,
48+
prompt,
49+
sources,
50+
bool_param: bool,
51+
int_param: int,
52+
float_param: float,
53+
string_param: str,
54+
):
55+
pass
56+
57+
with pytest.raises(RagnaException, match="missing") as exc_info:
58+
self.chat(documents=[demo_document], assistant=ValidationAssistant)
59+
60+
msg = str(exc_info.value)
61+
for param, annotation in ValidationAssistant.answer.__annotations__.items():
62+
assert f"{param}: {annotation.__name__}" in msg
63+
64+
def test_params_validation_wrong_type(self, demo_document):
65+
class ValidationAssistant(Assistant):
66+
def answer(
67+
self,
68+
prompt,
69+
sources,
70+
bool_param: bool,
71+
int_param: int,
72+
float_param: float,
73+
string_param: str,
74+
):
75+
pass
76+
77+
params = {
78+
"bool_param": 1,
79+
"int_param": 0.5,
80+
"float_param": "arbitrary_value",
81+
"string_param": False,
82+
}
83+
84+
with pytest.raises(RagnaException, match="wrong type") as exc_info:
2785
self.chat(
28-
documents=[demo_document], not_supported_parameter="arbitrary_value"
86+
documents=[demo_document], assistant=ValidationAssistant, **params
2987
)
3088

89+
msg = str(exc_info.value)
90+
for param, value in params.items():
91+
annotation = ValidationAssistant.answer.__annotations__[param]
92+
assert f"{param}: {annotation.__name__} = {value!r}" in msg
93+
3194
def test_document_path(self, demo_document):
3295
chat = self.chat(documents=[demo_document.path])
3396

0 commit comments

Comments
 (0)