Skip to content

Commit 1133517

Browse files
committed
Merge branch 'next'
2 parents adc02ff + d055bb2 commit 1133517

File tree

7 files changed

+337
-25
lines changed

7 files changed

+337
-25
lines changed

src/langdiff/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
String,
66
Field,
77
Parser,
8+
StreamingValue,
9+
PydanticType,
810
)
911
from .tracker import (
1012
ChangeTracker,
@@ -23,6 +25,8 @@
2325
"String",
2426
"Field",
2527
"Parser",
28+
"StreamingValue",
29+
"PydanticType",
2630
# tracker
2731
"ChangeTracker",
2832
"JSONPatchChangeTracker",

src/langdiff/parser/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
Object,
55
String,
66
Field,
7+
StreamingValue,
8+
PydanticType,
79
)
810
from .parser import Parser
911

@@ -14,4 +16,6 @@
1416
"String",
1517
"Field",
1618
"Parser",
19+
"StreamingValue",
20+
"PydanticType",
1721
]

src/langdiff/parser/decoder.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import functools
2+
from typing import Callable, Any
3+
4+
from pydantic import TypeAdapter
5+
6+
7+
def _build_type_adapter_cache(max_size: int):
8+
return functools.lru_cache(maxsize=max_size)(lambda t: TypeAdapter(t))
9+
10+
11+
_DEFAULT_CACHE_SIZE = 100
12+
_CACHE = _build_type_adapter_cache(_DEFAULT_CACHE_SIZE)
13+
14+
15+
def set_type_adapter_cache_size(max_size: int):
    """Replace the module-level TypeAdapter cache with one of a new size.

    The cache is rebuilt rather than resized, so the new cache starts empty.

    Args:
        max_size (int): The new maximum size for the cache.
    """
    global _CACHE
    replacement = _build_type_adapter_cache(max_size)
    _CACHE = replacement
24+
25+
26+
def get_cached_type_adapter(key: type) -> TypeAdapter:
    """Get a TypeAdapter for the given key type, cached when possible.

    Args:
        key (type): The type for which to get the TypeAdapter.

    Returns:
        TypeAdapter: The cached TypeAdapter for the given type, or a freshly
        built (uncached) one when ``key`` is not hashable.
    """
    # functools.lru_cache requires hashable arguments. Some valid type hints
    # (e.g. Annotated[...] carrying dict/list metadata) are not hashable, so
    # build an uncached adapter for them instead of failing with TypeError.
    try:
        hash(key)
    except TypeError:
        return TypeAdapter(key)
    return _CACHE(key)
37+
38+
39+
def get_decoder(type_hint: Any) -> Callable | None:
40+
# fast path for common types
41+
if type_hint is str or type_hint is int or type_hint is float or type_hint is bool:
42+
return None
43+
44+
return get_cached_type_adapter(type_hint).validate_python

src/langdiff/parser/model.py

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,37 @@
11
import typing
2-
from typing import Generic, Callable, Any, TypeVar
2+
from typing import Generic, Callable, Any, TypeVar, Annotated
33

44
import pydantic
55
from pydantic import BaseModel
66

7+
from langdiff.parser.decoder import get_decoder
8+
79
T = TypeVar("T")
810

911
Field = pydantic.Field
1012

1113

14+
class PydanticType:
    """A hint that specifies the Pydantic type to use when converting to Pydantic models.

    This is used with typing.Annotated to provide custom type hints for Pydantic model derivation.

    Example:
        class Item(Object):
            field: Annotated[String, PydanticType(UUID)]

        When Item.to_pydantic() is called, the generated field will have type UUID instead of str.
    """

    def __init__(self, pydantic_type: Any):
        """Initialize with the desired Pydantic type.

        Args:
            pydantic_type: The type to use in the generated Pydantic model
        """
        self.pydantic_type = pydantic_type

    def __repr__(self) -> str:
        """Return a readable form, e.g. ``PydanticType(<class 'uuid.UUID'>)``.

        Without this, an annotated hint prints as an opaque
        ``<PydanticType object at 0x...>`` in debug output and error messages.
        """
        return f"{type(self).__name__}({self.pydantic_type!r})"
33+
34+
1235
class StreamingValue(Generic[T]):
1336
"""A generic base class for a value that is streamed incrementally.
1437
@@ -65,12 +88,17 @@ def __init__(self):
6588
for key, type_hint in type(self).__annotations__.items():
6689
self._keys.append(key)
6790

68-
# handle StreamingList[T], CompleteValue[T]
69-
if hasattr(type_hint, "__origin__"):
70-
item_cls = typing.get_args(type_hint)[0]
71-
setattr(self, key, type_hint.__origin__(item_cls))
91+
# Extract base type from Annotated[T, PydanticType(...), ...]
92+
base_type = type_hint
93+
if typing.get_origin(type_hint) is Annotated:
94+
base_type = typing.get_args(type_hint)[0]
95+
96+
# handle List[T], Atom[T]
97+
if hasattr(base_type, "__origin__"):
98+
item_cls = typing.get_args(base_type)[0]
99+
setattr(self, key, base_type.__origin__(item_cls))
72100
else:
73-
setattr(self, key, type_hint())
101+
setattr(self, key, base_type())
74102

75103
def on_update(self, func: Callable[[dict], Any]):
76104
"""Register a callback that is called whenever the object is updated."""
@@ -121,7 +149,7 @@ def to_pydantic(cls) -> type[BaseModel]:
121149
model = getattr(cls, "_pydantic_model", None)
122150
if model is not None: # use cached model if available
123151
return model
124-
fields = {}
152+
fields: dict[str, Any] = {}
125153
for name, type_hint in cls.__annotations__.items():
126154
type_hint = unwrap_raw_type(type_hint)
127155
field = getattr(cls, name, None)
@@ -130,15 +158,15 @@ def to_pydantic(cls) -> type[BaseModel]:
130158
else:
131159
fields[name] = type_hint
132160
model = pydantic.create_model(cls.__name__, **fields, __doc__=cls.__doc__)
133-
cls._pydantic_model = model
161+
setattr(cls, "_pydantic_model", model)
134162
return model
135163

136164

137165
class List(Generic[T], StreamingValue[list]):
138166
"""Represents a JSON array that is streamed.
139167
140168
This class can handle a list of items that are themselves `StreamingValue`s
141-
(like `StreamingObject` or `StreamingString`) or complete values. It provides
169+
(like `langdiff.Object` or `langdiff.String`) or complete values. It provides
142170
an `on_append` callback that is fired when a new item is added to the list.
143171
"""
144172

@@ -154,9 +182,7 @@ def __init__(self, item_cls: type[T]):
154182
self._value = []
155183
self._item_cls = item_cls
156184
self._item_streaming = issubclass(item_cls, StreamingValue)
157-
self._decode = (
158-
item_cls.model_validate if issubclass(item_cls, BaseModel) else None
159-
)
185+
self._decode = get_decoder(item_cls) if not self._item_streaming else None
160186
self._streaming_values = []
161187
self._on_append_funcs = []
162188

@@ -270,7 +296,7 @@ def update(self, value: str | None):
270296
else:
271297
if value is None or not value.startswith(self._value):
272298
raise ValueError(
273-
"StreamingString can only be updated with a continuation of the current value."
299+
"langdiff.String can only be updated with a continuation of the current value."
274300
)
275301
if len(value) == len(self._value):
276302
return
@@ -290,18 +316,16 @@ class Atom(Generic[T], StreamingValue[T]):
290316
291317
This is useful for types like numbers, booleans, or even entire objects/lists
292318
that are not streamed part-by-part but are present completely once available.
293-
The `on_complete` callback is triggered when the parent `StreamingObject` or
294-
`StreamingList` determines that this value is complete.
319+
The `on_complete` callback is triggered when the parent `langdiff.Object` or
320+
`langdiff.List` determines that this value is complete.
295321
"""
296322

297323
_value: T | None
298324

299325
def __init__(self, item_cls: type[T]):
300326
super().__init__()
301327
self._value = None
302-
self._decode = (
303-
item_cls.model_validate if issubclass(item_cls, BaseModel) else None
304-
)
328+
self._decode = get_decoder(item_cls)
305329

306330
def update(self, value: T):
307331
self._trigger_start()
@@ -320,23 +344,53 @@ def value(self) -> T | None:
320344
return self._value
321345

322346

323-
def unwrap_raw_type(type_hint: Any) -> type:
347+
def _extract_pydantic_hint(type_hint: Any) -> type | None:
348+
"""Extract PydanticType from Annotated type if present."""
349+
if typing.get_origin(type_hint) is Annotated:
350+
args = typing.get_args(type_hint)
351+
if len(args) >= 2:
352+
# Look for PydanticType in the metadata
353+
for metadata in args[1:]:
354+
if isinstance(metadata, PydanticType):
355+
return metadata.pydantic_type
356+
return None
357+
358+
359+
def unwrap_raw_type(type_hint: Any):
324360
# Possible types:
361+
# - Annotated[T, PydanticType(U)] => U (custom Pydantic type)
325362
# - Atom[T] => T
326363
# - List[T] => list[unwrap(T)]
327364
# - String => str
328-
# - T extends StreamableModel => T.to_pydantic()
365+
# - T extends Object => T.to_pydantic()
366+
367+
# First check for PydanticType in Annotated types
368+
pydantic_hint = _extract_pydantic_hint(type_hint)
369+
if pydantic_hint is not None:
370+
return pydantic_hint
371+
372+
# Handle Annotated[T, ...] by extracting the base type
373+
if typing.get_origin(type_hint) is Annotated:
374+
type_hint = typing.get_args(type_hint)[0]
375+
329376
if hasattr(type_hint, "__origin__"):
330377
origin = type_hint.__origin__
331378
if origin is Atom:
332379
return typing.get_args(type_hint)[0]
333380
elif origin is List:
334381
item_type = typing.get_args(type_hint)[0]
335-
return list[unwrap_raw_type(item_type)]
382+
return list[unwrap_raw_type(item_type)] # type: ignore[misc]
336383
elif type_hint is String:
337384
return str
338385
elif issubclass(type_hint, Object):
339386
return type_hint.to_pydantic()
387+
elif issubclass(type_hint, StreamingValue):
388+
to_pydantic = getattr(type_hint, "to_pydantic", None)
389+
if to_pydantic is None or not callable(to_pydantic):
390+
raise ValueError(
391+
f"Custom StreamingValue type {type_hint} must implement to_pydantic() method."
392+
)
393+
return to_pydantic()
340394
elif (
341395
type_hint is str
342396
or type_hint is int
@@ -346,5 +400,5 @@ def unwrap_raw_type(type_hint: Any) -> type:
346400
):
347401
return type_hint
348402
raise ValueError(
349-
f"Unsupported type hint: {type_hint}. Expected Atom, List, String, or StreamableModel subclass."
403+
f"Unsupported type hint: {type_hint}. Expected LangDiff Atom, List, String, or Object subclass."
350404
)

src/langdiff/tracker/changes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TypedDict, Any, NotRequired
22

3-
import jsonpatch
3+
import jsonpatch # type: ignore[import-untyped]
44

55
__all__ = [
66
"Operation",

src/langdiff/tracker/impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from jsonpointer import JsonPointer
3+
from jsonpointer import JsonPointer # type: ignore[import-untyped]
44
from pydantic import BaseModel
55

66
from .change_tracker import ChangeTracker, TrackedObject, TrackedList, TrackedDict, Path

0 commit comments

Comments
 (0)