Skip to content

Commit b0d5f12

Browse files
Merge pull request #3 from Cloudrisk/develop
allow ..WithMeta data types to be initialized with objects of the bas…
2 parents 05a841b + ce66557 commit b0d5f12

File tree

2 files changed

+93
-55
lines changed

2 files changed

+93
-55
lines changed

src/rune/runtime/metadata.py

Lines changed: 70 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from decimal import Decimal
77
from typing import Any, Never, get_args
88
import datetime
9-
from typing_extensions import Self
9+
from typing_extensions import Self, Tuple
1010
from pydantic import (PlainSerializer, PlainValidator, WrapValidator,
1111
WrapSerializer)
1212
# from rune.runtime.object_registry import get_object
@@ -404,6 +404,17 @@ def validator(cls, allowed_meta: tuple[str] | tuple[Never, ...] = tuple()):
404404

405405
class BasicTypeMetaDataMixin(BaseMetaDataMixin):
406406
'''holds the metadata associated with an instance'''
407+
_INPUT_TYPES: Any | Tuple[Any, ...] = str # to be overridden by subclasses
408+
_OUTPUT_TYPE: Any = str # to be overridden by subclasses
409+
_JSON_OUTPUT = str | dict
410+
411+
@classmethod
412+
def _check_type(cls, value):
413+
if not isinstance(value, cls._INPUT_TYPES):
414+
raise ValueError(f'{cls.__name__} can be instantiated only with '
415+
f'one of the following type(s): {cls._INPUT_TYPES},'
416+
f' however the value is of type {type(value)}')
417+
407418
@classmethod
408419
def serialise(cls, obj, base_type) -> dict:
409420
'''used as serialisation method with pydantic'''
@@ -431,7 +442,7 @@ def deserialize(cls, obj, handler, base_types, allowed_meta: set[str]):
431442
@lru_cache
432443
def serializer(cls):
433444
'''should return the validator for the specific class'''
434-
ser_fn = partial(cls.serialise, base_type=str)
445+
ser_fn = partial(cls.serialise, base_type=cls._OUTPUT_TYPE)
435446
return PlainSerializer(ser_fn, return_type=dict)
436447

437448
@classmethod
@@ -440,49 +451,62 @@ def validator(cls, allowed_meta: tuple[str]):
440451
'''default validator for the specific class'''
441452
allowed = set(allowed_meta)
442453
return WrapValidator(partial(cls.deserialize,
443-
base_types=str,
454+
base_types=cls._INPUT_TYPES,
444455
allowed_meta=allowed),
445-
json_schema_input_type=str | dict)
456+
json_schema_input_type=cls._JSON_OUTPUT)
446457

447458

448459
class DateWithMeta(datetime.date, BasicTypeMetaDataMixin):
449460
'''date with metadata'''
461+
_INPUT_TYPES = (datetime.date, str)
462+
450463
def __new__(cls, value, **kwds): # pylint: disable=signature-differs
451-
ymd = datetime.date.fromisoformat(value).timetuple()[:3]
464+
cls._check_type(value)
465+
if isinstance(value, str):
466+
value = datetime.date.fromisoformat(value)
467+
ymd = value.timetuple()[:3]
452468
obj = datetime.date.__new__(cls, *ymd)
453469
obj.set_meta(check_allowed=False, **kwds)
454470
return obj
455471

456472

457473
class TimeWithMeta(datetime.time, BasicTypeMetaDataMixin):
458474
'''annotated time'''
475+
_INPUT_TYPES = (datetime.time, str)
476+
459477
def __new__(cls, value, **kwds): # pylint: disable=signature-differs
460-
aux = datetime.time.fromisoformat(value)
478+
cls._check_type(value)
479+
if isinstance(value, str):
480+
value = datetime.time.fromisoformat(value)
461481
obj = datetime.time.__new__(cls,
462-
aux.hour,
463-
aux.minute,
464-
aux.second,
465-
aux.microsecond,
466-
aux.tzinfo,
467-
fold=aux.fold)
482+
value.hour,
483+
value.minute,
484+
value.second,
485+
value.microsecond,
486+
value.tzinfo,
487+
fold=value.fold)
468488
obj.set_meta(check_allowed=False, **kwds)
469489
return obj
470490

471491

472492
class DateTimeWithMeta(datetime.datetime, BasicTypeMetaDataMixin):
473493
'''annotated datetime'''
494+
_INPUT_TYPES = (datetime.datetime, str)
495+
474496
def __new__(cls, value, **kwds): # pylint: disable=signature-differs
475-
aux = datetime.datetime.fromisoformat(value)
497+
cls._check_type(value)
498+
if isinstance(value, str):
499+
value = datetime.datetime.fromisoformat(value)
476500
obj = datetime.datetime.__new__(cls,
477-
aux.year,
478-
aux.month,
479-
aux.day,
480-
aux.hour,
481-
aux.minute,
482-
aux.second,
483-
aux.microsecond,
484-
aux.tzinfo,
485-
fold=aux.fold)
501+
value.year,
502+
value.month,
503+
value.day,
504+
value.hour,
505+
value.minute,
506+
value.second,
507+
value.microsecond,
508+
value.tzinfo,
509+
fold=value.fold)
486510
obj.set_meta(check_allowed=False, **kwds)
487511
return obj
488512

@@ -500,54 +524,45 @@ def __new__(cls, value, **kwds):
500524

501525
class IntWithMeta(int, BasicTypeMetaDataMixin):
502526
'''annotated integer'''
527+
_INPUT_TYPES = int
528+
_OUTPUT_TYPE = int
529+
_JSON_OUTPUT = int | dict
530+
503531
def __new__(cls, value, **kwds):
504532
obj = int.__new__(cls, value)
505533
obj.set_meta(check_allowed=False, **kwds)
506534
return obj
507535

508-
@classmethod
509-
@lru_cache
510-
def serializer(cls):
511-
'''should return the validator for the specific class'''
512-
ser_fn = partial(cls.serialise, base_type=int)
513-
return PlainSerializer(ser_fn, return_type=dict)
514-
515-
@classmethod
516-
@lru_cache
517-
def validator(cls, allowed_meta: tuple[str]):
518-
'''default validator for the specific class'''
519-
allowed = set(allowed_meta)
520-
return WrapValidator(partial(cls.deserialize,
521-
base_types=int,
522-
allowed_meta=allowed),
523-
json_schema_input_type=int | dict)
524-
525536

526537
class NumberWithMeta(Decimal, BasicTypeMetaDataMixin):
527538
'''annotated number'''
539+
_INPUT_TYPES = (Decimal, float, int, str)
540+
_OUTPUT_TYPE = Decimal
541+
_JSON_OUTPUT = float | int | str | dict
542+
528543
def __new__(cls, value, **kwds):
529544
# NOTE: it could be necessary to convert the value to str if it is a
530545
# float
531546
obj = Decimal.__new__(cls, value)
532547
obj.set_meta(check_allowed=False, **kwds)
533548
return obj
534549

535-
@classmethod
536-
@lru_cache
537-
def serializer(cls):
538-
'''should return the validator for the specific class'''
539-
ser_fn = partial(cls.serialise, base_type=Decimal)
540-
return PlainSerializer(ser_fn, return_type=dict)
541-
542-
@classmethod
543-
@lru_cache
544-
def validator(cls, allowed_meta: tuple[str]):
545-
'''default validator for the specific class'''
546-
allowed = set(allowed_meta)
547-
return WrapValidator(partial(cls.deserialize,
548-
base_types=(Decimal, float, int, str),
549-
allowed_meta=allowed),
550-
json_schema_input_type=float | int | str | dict)
550+
# @classmethod
551+
# @lru_cache
552+
# def serializer(cls):
553+
# '''should return the validator for the specific class'''
554+
# ser_fn = partial(cls.serialise, base_type=Decimal)
555+
# return PlainSerializer(ser_fn, return_type=dict)
556+
557+
# @classmethod
558+
# @lru_cache
559+
# def validator(cls, allowed_meta: tuple[str]):
560+
# '''default validator for the specific class'''
561+
# allowed = set(allowed_meta)
562+
# return WrapValidator(partial(cls.deserialize,
563+
# base_types=(Decimal, float, int, str),
564+
# allowed_meta=allowed),
565+
# json_schema_input_type=float | int | str | dict)
551566

552567

553568
class _EnumWrapperDefaultVal(Enum):

test/test_basic_types_with_meta.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,29 @@ def test_dump_annotated_date_simple():
216216
assert json_str == '{"date":{"@data":"2024-10-10"}}'
217217

218218

219+
def test_dump_annotated_date_date():
220+
'''test the annotated string'''
221+
model = AnnotatedDateModel(date=date(2024, 10, 10))
222+
json_str = model.model_dump_json(exclude_unset=True)
223+
assert json_str == '{"date":{"@data":"2024-10-10"}}'
224+
225+
model = AnnotatedDateModel(date=DateWithMeta(date(2024, 10, 10)))
226+
json_str = model.model_dump_json(exclude_unset=True)
227+
assert json_str == '{"date":{"@data":"2024-10-10"}}'
228+
229+
230+
def test_annotated_date_fail():
231+
'''test instantiation failure with an incorrect type'''
232+
with pytest.raises(AttributeError):
233+
AnnotatedDateModel(date=10)
234+
235+
236+
def test_date_with_meta_fail():
237+
'''test instantiation failure with an incorrect type'''
238+
with pytest.raises(ValueError):
239+
DateWithMeta(10)
240+
241+
219242
def test_load_annotated_date_scheme():
220243
'''test the loading of annotated with a scheme strings'''
221244
scheme_json = '{"date":{"@data":"2024-10-10","@scheme":"http://fpml.org"}}'

0 commit comments

Comments
 (0)