1
1
import typing
2
- from typing import Generic , Callable , Any , TypeVar
2
+ from typing import Generic , Callable , Any , TypeVar , Annotated
3
3
4
4
import pydantic
5
5
from pydantic import BaseModel
6
6
7
+ from langdiff .parser .decoder import get_decoder
8
+
7
9
T = TypeVar ("T" )
8
10
9
11
Field = pydantic .Field
10
12
11
13
14
+ class PydanticType :
15
+ """A hint that specifies the Pydantic type to use when converting to Pydantic models.
16
+
17
+ This is used with typing.Annotated to provide custom type hints for Pydantic model derivation.
18
+
19
+ Example:
20
+ class Item(Object):
21
+ field: Annotated[String, PydanticType(UUID)]
22
+
23
+ When Item.to_pydantic() is called, the generated field will have type UUID instead of str.
24
+ """
25
+
26
+ def __init__ (self , pydantic_type : Any ):
27
+ """Initialize with the desired Pydantic type.
28
+
29
+ Args:
30
+ pydantic_type: The type to use in the generated Pydantic model
31
+ """
32
+ self .pydantic_type = pydantic_type
33
+
34
+
12
35
class StreamingValue (Generic [T ]):
13
36
"""A generic base class for a value that is streamed incrementally.
14
37
@@ -65,12 +88,17 @@ def __init__(self):
65
88
for key , type_hint in type (self ).__annotations__ .items ():
66
89
self ._keys .append (key )
67
90
68
- # handle StreamingList[T], CompleteValue[T]
69
- if hasattr (type_hint , "__origin__" ):
70
- item_cls = typing .get_args (type_hint )[0 ]
71
- setattr (self , key , type_hint .__origin__ (item_cls ))
91
+ # Extract base type from Annotated[T, PydanticType(...), ...]
92
+ base_type = type_hint
93
+ if typing .get_origin (type_hint ) is Annotated :
94
+ base_type = typing .get_args (type_hint )[0 ]
95
+
96
+ # handle List[T], Atom[T]
97
+ if hasattr (base_type , "__origin__" ):
98
+ item_cls = typing .get_args (base_type )[0 ]
99
+ setattr (self , key , base_type .__origin__ (item_cls ))
72
100
else :
73
- setattr (self , key , type_hint ())
101
+ setattr (self , key , base_type ())
74
102
75
103
def on_update (self , func : Callable [[dict ], Any ]):
76
104
"""Register a callback that is called whenever the object is updated."""
@@ -121,7 +149,7 @@ def to_pydantic(cls) -> type[BaseModel]:
121
149
model = getattr (cls , "_pydantic_model" , None )
122
150
if model is not None : # use cached model if available
123
151
return model
124
- fields = {}
152
+ fields : dict [ str , Any ] = {}
125
153
for name , type_hint in cls .__annotations__ .items ():
126
154
type_hint = unwrap_raw_type (type_hint )
127
155
field = getattr (cls , name , None )
@@ -130,15 +158,15 @@ def to_pydantic(cls) -> type[BaseModel]:
130
158
else :
131
159
fields [name ] = type_hint
132
160
model = pydantic .create_model (cls .__name__ , ** fields , __doc__ = cls .__doc__ )
133
- cls . _pydantic_model = model
161
+ setattr ( cls , " _pydantic_model" , model )
134
162
return model
135
163
136
164
137
165
class List (Generic [T ], StreamingValue [list ]):
138
166
"""Represents a JSON array that is streamed.
139
167
140
168
This class can handle a list of items that are themselves `StreamingValue`s
141
- (like `StreamingObject ` or `StreamingString `) or complete values. It provides
169
+ (like `langdiff.Object ` or `langdiff.String `) or complete values. It provides
142
170
an `on_append` callback that is fired when a new item is added to the list.
143
171
"""
144
172
@@ -154,9 +182,7 @@ def __init__(self, item_cls: type[T]):
154
182
self ._value = []
155
183
self ._item_cls = item_cls
156
184
self ._item_streaming = issubclass (item_cls , StreamingValue )
157
- self ._decode = (
158
- item_cls .model_validate if issubclass (item_cls , BaseModel ) else None
159
- )
185
+ self ._decode = get_decoder (item_cls ) if not self ._item_streaming else None
160
186
self ._streaming_values = []
161
187
self ._on_append_funcs = []
162
188
@@ -270,7 +296,7 @@ def update(self, value: str | None):
270
296
else :
271
297
if value is None or not value .startswith (self ._value ):
272
298
raise ValueError (
273
- "StreamingString can only be updated with a continuation of the current value."
299
+ "langdiff.String can only be updated with a continuation of the current value."
274
300
)
275
301
if len (value ) == len (self ._value ):
276
302
return
@@ -290,18 +316,16 @@ class Atom(Generic[T], StreamingValue[T]):
290
316
291
317
This is useful for types like numbers, booleans, or even entire objects/lists
292
318
that are not streamed part-by-part but are present completely once available.
293
- The `on_complete` callback is triggered when the parent `StreamingObject ` or
294
- `StreamingList ` determines that this value is complete.
319
+ The `on_complete` callback is triggered when the parent `langdiff.Object ` or
320
+ `langdiff.List ` determines that this value is complete.
295
321
"""
296
322
297
323
_value : T | None
298
324
299
325
def __init__ (self , item_cls : type [T ]):
300
326
super ().__init__ ()
301
327
self ._value = None
302
- self ._decode = (
303
- item_cls .model_validate if issubclass (item_cls , BaseModel ) else None
304
- )
328
+ self ._decode = get_decoder (item_cls )
305
329
306
330
def update (self , value : T ):
307
331
self ._trigger_start ()
@@ -320,23 +344,53 @@ def value(self) -> T | None:
320
344
return self ._value
321
345
322
346
323
- def unwrap_raw_type (type_hint : Any ) -> type :
347
+ def _extract_pydantic_hint (type_hint : Any ) -> type | None :
348
+ """Extract PydanticType from Annotated type if present."""
349
+ if typing .get_origin (type_hint ) is Annotated :
350
+ args = typing .get_args (type_hint )
351
+ if len (args ) >= 2 :
352
+ # Look for PydanticType in the metadata
353
+ for metadata in args [1 :]:
354
+ if isinstance (metadata , PydanticType ):
355
+ return metadata .pydantic_type
356
+ return None
357
+
358
+
359
+ def unwrap_raw_type (type_hint : Any ):
324
360
# Possible types:
361
+ # - Annotated[T, PydanticType(U)] => U (custom Pydantic type)
325
362
# - Atom[T] => T
326
363
# - List[T] => list[unwrap(T)]
327
364
# - String => str
328
- # - T extends StreamableModel => T.to_pydantic()
365
+ # - T extends Object => T.to_pydantic()
366
+
367
+ # First check for PydanticType in Annotated types
368
+ pydantic_hint = _extract_pydantic_hint (type_hint )
369
+ if pydantic_hint is not None :
370
+ return pydantic_hint
371
+
372
+ # Handle Annotated[T, ...] by extracting the base type
373
+ if typing .get_origin (type_hint ) is Annotated :
374
+ type_hint = typing .get_args (type_hint )[0 ]
375
+
329
376
if hasattr (type_hint , "__origin__" ):
330
377
origin = type_hint .__origin__
331
378
if origin is Atom :
332
379
return typing .get_args (type_hint )[0 ]
333
380
elif origin is List :
334
381
item_type = typing .get_args (type_hint )[0 ]
335
- return list [unwrap_raw_type (item_type )]
382
+ return list [unwrap_raw_type (item_type )] # type: ignore[misc]
336
383
elif type_hint is String :
337
384
return str
338
385
elif issubclass (type_hint , Object ):
339
386
return type_hint .to_pydantic ()
387
+ elif issubclass (type_hint , StreamingValue ):
388
+ to_pydantic = getattr (type_hint , "to_pydantic" , None )
389
+ if to_pydantic is None or not callable (to_pydantic ):
390
+ raise ValueError (
391
+ f"Custom StreamingValue type { type_hint } must implement to_pydantic() method."
392
+ )
393
+ return to_pydantic ()
340
394
elif (
341
395
type_hint is str
342
396
or type_hint is int
@@ -346,5 +400,5 @@ def unwrap_raw_type(type_hint: Any) -> type:
346
400
):
347
401
return type_hint
348
402
raise ValueError (
349
- f"Unsupported type hint: { type_hint } . Expected Atom, List, String, or StreamableModel subclass."
403
+ f"Unsupported type hint: { type_hint } . Expected LangDiff Atom, List, String, or Object subclass."
350
404
)
0 commit comments