| 
1 | 1 | from __future__ import annotations  | 
2 | 2 | 
 
  | 
 | 3 | +import contextlib  | 
3 | 4 | import datetime  | 
4 | 5 | import inspect  | 
 | 6 | +import itertools  | 
5 | 7 | import uuid  | 
 | 8 | +from collections import defaultdict  | 
6 | 9 | from typing import (  | 
7 | 10 |     Any,  | 
8 | 11 |     AsyncIterator,  | 
 | 
19 | 22 | )  | 
20 | 23 | 
 
  | 
21 | 24 | import pydantic  | 
 | 25 | +import pydantic_core  | 
22 | 26 | from starlette.concurrency import iterate_in_threadpool, run_in_threadpool  | 
23 | 27 | 
 
  | 
24 | 28 | from ._components import Assistant, Component, Message, MessageRole, SourceStorage  | 
@@ -251,27 +255,120 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:  | 
251 | 255 |     def _unpack_chat_params(  | 
252 | 256 |         self, params: dict[str, Any]  | 
253 | 257 |     ) -> dict[Callable, dict[str, Any]]:  | 
 | 258 | +        # This method does two things:  | 
 | 259 | +        # 1. Validate the **params against the signatures of the protocol methods of the  | 
 | 260 | +        #    used components. This makes sure that  | 
 | 261 | +        #    - No parameter is passed that isn't used by at least one component  | 
 | 262 | +        #    - No parameter is missing that is needed by at least one component  | 
 | 263 | +        #    - No parameter is passed in the wrong type  | 
 | 264 | +        # 2. Prepare the distribution of the parameters to the protocol method that  | 
 | 265 | +        #    requested them. The actual distribution happens in self._run and  | 
 | 266 | +        #    self._run_gen, but is only a dictionary lookup by then.  | 
254 | 267 |         component_models = {  | 
255 | 268 |             getattr(component, name): model  | 
256 | 269 |             for component in [self.source_storage, self.assistant]  | 
257 | 270 |             for (_, name), model in component._protocol_models().items()  | 
258 | 271 |         }  | 
259 | 272 | 
 
  | 
260 | 273 |         ChatModel = merge_models(  | 
261 |  | -            str(self.params["chat_id"]),  | 
 | 274 | +            f"{self.__module__}.{type(self).__name__}-{self.params['chat_id']}",  | 
262 | 275 |             SpecialChatParams,  | 
263 | 276 |             *component_models.values(),  | 
264 | 277 |             config=pydantic.ConfigDict(extra="forbid"),  | 
265 | 278 |         )  | 
266 | 279 | 
 
  | 
267 |  | -        chat_params = ChatModel.model_validate(params, strict=True).model_dump(  | 
268 |  | -            exclude_none=True  | 
269 |  | -        )  | 
 | 280 | +        with self._format_validation_error(ChatModel):  | 
 | 281 | +            chat_model = ChatModel.model_validate(params, strict=True)  | 
 | 282 | + | 
 | 283 | +        chat_params = chat_model.model_dump(exclude_none=True)  | 
270 | 284 |         return {  | 
271 | 285 |             fn: model(**chat_params).model_dump()  | 
272 | 286 |             for fn, model in component_models.items()  | 
273 | 287 |         }  | 
274 | 288 | 
 
  | 
 | 289 | +    @contextlib.contextmanager  | 
 | 290 | +    def _format_validation_error(  | 
 | 291 | +        self, model_cls: type[pydantic.BaseModel]  | 
 | 292 | +    ) -> Iterator[None]:  | 
 | 293 | +        try:  | 
 | 294 | +            yield  | 
 | 295 | +        except pydantic.ValidationError as validation_error:  | 
 | 296 | +            errors = defaultdict(list)  | 
 | 297 | +            for error in validation_error.errors():  | 
 | 298 | +                errors[error["type"]].append(error)  | 
 | 299 | + | 
 | 300 | +            parts = [  | 
 | 301 | +                f"Validating the Chat parameters resulted in {validation_error.error_count()} errors:",  | 
 | 302 | +                "",  | 
 | 303 | +            ]  | 
 | 304 | + | 
 | 305 | +            def format_error(  | 
 | 306 | +                error: pydantic_core.ErrorDetails,  | 
 | 307 | +                *,  | 
 | 308 | +                annotation: bool = False,  | 
 | 309 | +                value: bool = False,  | 
 | 310 | +            ) -> str:  | 
 | 311 | +                param = cast(str, error["loc"][0])  | 
 | 312 | + | 
 | 313 | +                formatted_error = f"- {param}"  | 
 | 314 | +                if annotation:  | 
 | 315 | +                    annotation_ = cast(  | 
 | 316 | +                        type, model_cls.model_fields[param].annotation  | 
 | 317 | +                    ).__name__  | 
 | 318 | +                    formatted_error += f": {annotation_}"  | 
 | 319 | + | 
 | 320 | +                if value:  | 
 | 321 | +                    value_ = error["input"]  | 
 | 322 | +                    formatted_error += (  | 
 | 323 | +                        f" = {value_!r}" if annotation else f"={value_!r}"  | 
 | 324 | +                    )  | 
 | 325 | + | 
 | 326 | +                return formatted_error  | 
 | 327 | + | 
 | 328 | +            if "extra_forbidden" in errors:  | 
 | 329 | +                parts.extend(  | 
 | 330 | +                    [  | 
 | 331 | +                        "The following parameters are unknown:",  | 
 | 332 | +                        "",  | 
 | 333 | +                        *[  | 
 | 334 | +                            format_error(error, value=True)  | 
 | 335 | +                            for error in errors["extra_forbidden"]  | 
 | 336 | +                        ],  | 
 | 337 | +                        "",  | 
 | 338 | +                    ]  | 
 | 339 | +                )  | 
 | 340 | + | 
 | 341 | +            if "missing" in errors:  | 
 | 342 | +                parts.extend(  | 
 | 343 | +                    [  | 
 | 344 | +                        "The following parameters are missing:",  | 
 | 345 | +                        "",  | 
 | 346 | +                        *[  | 
 | 347 | +                            format_error(error, annotation=True)  | 
 | 348 | +                            for error in errors["missing"]  | 
 | 349 | +                        ],  | 
 | 350 | +                        "",  | 
 | 351 | +                    ]  | 
 | 352 | +                )  | 
 | 353 | + | 
 | 354 | +            type_errors = ["string_type", "int_type", "float_type", "bool_type"]  | 
 | 355 | +            if any(type_error in errors for type_error in type_errors):  | 
 | 356 | +                parts.extend(  | 
 | 357 | +                    [  | 
 | 358 | +                        "The following parameters have the wrong type:",  | 
 | 359 | +                        "",  | 
 | 360 | +                        *[  | 
 | 361 | +                            format_error(error, annotation=True, value=True)  | 
 | 362 | +                            for error in itertools.chain.from_iterable(  | 
 | 363 | +                                errors[type_error] for type_error in type_errors  | 
 | 364 | +                            )  | 
 | 365 | +                        ],  | 
 | 366 | +                        "",  | 
 | 367 | +                    ]  | 
 | 368 | +                )  | 
 | 369 | + | 
 | 370 | +            raise RagnaException("\n".join(parts))  | 
 | 371 | + | 
275 | 372 |     async def _run(  | 
276 | 373 |         self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any  | 
277 | 374 |     ) -> T:  | 
 | 
0 commit comments