From f5be1d061215318a2ed88333a146335fc703b6d7 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Wed, 10 Sep 2025 17:37:51 -0600 Subject: [PATCH] feat: support dtype(ir.Value) and ir.Value[] This allows for pleasant UX such as - ibis.dtype(ibis.ir.StringValue) - ibis.dtype(ibis.ir.IntegerValue["!uint32"]) ```python class Users(ibis.Table): name: ibis.ir.StringColumn age: ibis.ir.IntegerColumn["uint8"] def my_api(users: Users): # Runtime coercion of types works as expected! users = ibis.cast(users, Users) # IDE type hints work too! reveal_type(users.age) # It's an IntegerColumn! ``` I also adjusted the ibis.dtype() function so that the nullable kwarg defaults to None, which means "Don't mess with the nullability of something that is already a dtype". I can put this into a separate PR if desired, but I think this is a good change to make anyway. The current behavior of `dtype(nonnullable_dtype, nullable=True)` returning the original input, still nonnullable, is a footgun. --- ibis/expr/datatypes/core.py | 57 ++++++++++++++++------- ibis/expr/datatypes/tests/test_core.py | 63 ++++++++++++++++++++++++++ ibis/expr/tests/test_schema.py | 11 +++++ ibis/expr/types/arrays.py | 2 + ibis/expr/types/binary.py | 3 ++ ibis/expr/types/generic.py | 45 ++++++++++++++++-- ibis/expr/types/geospatial.py | 15 +++--- ibis/expr/types/inet.py | 5 +- ibis/expr/types/json.py | 2 + ibis/expr/types/logical.py | 3 ++ ibis/expr/types/maps.py | 3 ++ ibis/expr/types/numeric.py | 10 +++- ibis/expr/types/strings.py | 3 ++ ibis/expr/types/structs.py | 4 +- ibis/expr/types/temporal.py | 8 ++++ ibis/expr/types/uuid.py | 3 +- ibis/tests/expr/test_literal.py | 5 ++ ibis/tests/expr/test_table.py | 10 ++++ 18 files changed, 221 insertions(+), 31 deletions(-) diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index 4264b00168f6..20597d72670b 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -11,6 +11,7 @@ TYPE_CHECKING, Annotated, Any, + ClassVar, Generic, 
Literal, NamedTuple, @@ -41,52 +42,58 @@ @overload -def dtype(value: type[int] | Literal["int"], nullable: bool = True) -> Int64: ... +def dtype(value: type[int] | Literal["int"], nullable: bool | None = None) -> Int64: ... @overload def dtype( - value: type[str] | Literal["str", "string"], nullable: bool = True + value: type[str] | Literal["str", "string"], nullable: bool | None = None ) -> String: ... @overload def dtype( - value: type[bool] | Literal["bool", "boolean"], nullable: bool = True + value: type[bool] | Literal["bool", "boolean"], nullable: bool | None = None ) -> Boolean: ... @overload -def dtype(value: type[bytes] | Literal["bytes"], nullable: bool = True) -> Binary: ... +def dtype( + value: type[bytes] | Literal["bytes"], nullable: bool | None = None +) -> Binary: ... @overload -def dtype(value: type[Real] | Literal["float"], nullable: bool = True) -> Float64: ... +def dtype( + value: type[Real] | Literal["float"], nullable: bool | None = None +) -> Float64: ... @overload def dtype( - value: type[pydecimal.Decimal] | Literal["decimal"], nullable: bool = True + value: type[pydecimal.Decimal] | Literal["decimal"], nullable: bool | None = None ) -> Decimal: ... @overload def dtype( - value: type[pydatetime.datetime] | Literal["timestamp"], nullable: bool = True + value: type[pydatetime.datetime] | Literal["timestamp"], + nullable: bool | None = None, ) -> Timestamp: ... @overload def dtype( - value: type[pydatetime.date] | Literal["date"], nullable: bool = True + value: type[pydatetime.date] | Literal["date"], nullable: bool | None = None ) -> Date: ... @overload def dtype( - value: type[pydatetime.time] | Literal["time"], nullable: bool = True + value: type[pydatetime.time] | Literal["time"], nullable: bool | None = None ) -> Time: ... @overload def dtype( - value: type[pydatetime.timedelta] | Literal["interval"], nullable: bool = True + value: type[pydatetime.timedelta] | Literal["interval"], + nullable: bool | None = None, ) -> Interval: ... 
@overload def dtype( - value: type[pyuuid.UUID] | Literal["uuid"], nullable: bool = True + value: type[pyuuid.UUID] | Literal["uuid"], nullable: bool | None = None ) -> UUID: ... @overload def dtype( value: DataType | str | np.dtype | ExtensionDtype | pl.DataType | pa.DataType, - nullable: bool = True, + nullable: bool | None = None, ) -> DataType: ... @lazy_singledispatch -def dtype(value, nullable=True) -> DataType: +def dtype(value, nullable: bool | None = None) -> DataType: """Create a DataType object. Parameters @@ -96,7 +103,11 @@ def dtype(value, nullable=True) -> DataType: strings, python type annotations, numpy dtypes, pandas dtypes, and pyarrow types. nullable - Whether the type should be nullable. Defaults to True. + Whether the type should be nullable. By default: + + - if the value is already a DataType, its nullability is preserved. + - If the value is not a DataType, it is treated as nullable. + If `value` is a string prefixed by "!", the type is always non-nullable. Examples @@ -124,13 +135,19 @@ def dtype(value, nullable=True) -> DataType: """ if isinstance(value, DataType): - return value + if nullable is None: + return value + return value.copy(nullable=nullable) + elif getattr(value, "__dtype__", None) is not None: + return dtype(value.__dtype__, nullable=nullable) else: + if nullable is None: + nullable = True return DataType.from_typehint(value, nullable) @dtype.register(str) -def from_string(value, nullable: bool = True): +def from_string(value, nullable=True): return DataType.from_string(value, nullable) @@ -290,6 +307,14 @@ def from_typehint(cls, typ, nullable=True) -> Self: elif issubclass(typ, pyuuid.UUID): return UUID(nullable=nullable) elif annots := get_type_hints(typ): + from ibis.expr import types as ir + + if issubclass(typ, ir.Table): + annots = { + k: v + for k, v in annots.items() + if k not in get_type_hints(ir.Table) + } return Struct(toolz.valmap(dtype, annots), nullable=nullable) else: raise TypeError( diff --git 
a/ibis/expr/datatypes/tests/test_core.py b/ibis/expr/datatypes/tests/test_core.py index 900f93065002..cb1168a09c27 100644 --- a/ibis/expr/datatypes/tests/test_core.py +++ b/ibis/expr/datatypes/tests/test_core.py @@ -7,10 +7,12 @@ from dataclasses import dataclass from typing import Annotated, NamedTuple, Optional, Union +import parsy import pytest from pytest import param import ibis.expr.datatypes as dt +import ibis.expr.types as ir from ibis.common.annotations import ValidationError from ibis.common.patterns import As, Attrs, NoMatch, Pattern from ibis.common.temporal import TimestampUnit, TimeUnit @@ -360,6 +362,67 @@ def test_dtype_from_newer_typehints(hint, expected): assert dt.dtype(hint) == expected +def test_dtype_from_string_expr_class(): + assert dt.dtype(ir.StringValue) == dt.String(nullable=True) + assert dt.dtype(ir.StringColumn) == dt.String(nullable=True) + assert dt.dtype(ir.StringScalar) == dt.String(nullable=True) + + assert dt.dtype(ir.StringValue["string"]) == dt.String(nullable=True) + assert ir.StringValue["!string"].__dtype__ == dt.String(nullable=False) + assert dt.dtype(ir.StringValue["!string"]) == dt.String(nullable=False) + + len_string = dt.String(nullable=False, length=10) + assert dt.dtype(ir.StringValue[len_string]) == len_string + + with pytest.raises(TypeError): + ir.StringValue["int64"] + with pytest.raises(parsy.ParseError): + ir.StringValue["bogus"] + + +def test_dtype_from_integer_expr_class(): + assert dt.dtype(ir.IntegerValue) == dt.Int64(nullable=True) + assert dt.dtype(ir.IntegerValue["int64"]) == dt.Int64(nullable=True) + assert dt.dtype(ir.IntegerValue["!int64"]) == dt.Int64(nullable=False) + assert dt.dtype(ir.IntegerValue["!int64"]) == dt.Int64(nullable=False) + + with pytest.raises(TypeError): + ir.IntegerValue["float64"] + with pytest.raises(parsy.ParseError): + ir.IntegerValue["bogus"] + + +def test_dtype_from_struct_expr_class(): + with pytest.raises(TypeError): + dt.dtype(ir.StructValue) + assert 
dt.dtype(ir.StructValue["struct"]) == dt.Struct( + {"a": dt.string, "b": dt.Int64(nullable=False)} + ) + + +def test_dtype_from_struct_subclass(): + class MyStruct(ir.StructValue): + a: ir.StringValue + b: ir.IntegerValue["!int64"] + + expected = dt.Struct({"a": dt.string, "b": dt.Int64(nullable=False)}) + actual = dt.dtype(MyStruct) + assert actual == expected + + +def test_dtype_from_abstract_expr_class_fails(): + with pytest.raises(TypeError): + dt.dtype(ir.Value) + with pytest.raises(TypeError): + dt.dtype(ir.Column) + with pytest.raises(TypeError): + dt.dtype(ir.Scalar) + with pytest.raises(TypeError): + dt.dtype(ir.ArrayValue) + with pytest.raises(TypeError): + dt.dtype(ir.MapValue) + + def test_dtype_from_invalid_python_value(): msg = "Cannot construct an ibis datatype from python value `1.0`" with pytest.raises(TypeError, match=msg): diff --git a/ibis/expr/tests/test_schema.py b/ibis/expr/tests/test_schema.py index 69761874ac21..2b94ad4a9d1a 100644 --- a/ibis/expr/tests/test_schema.py +++ b/ibis/expr/tests/test_schema.py @@ -9,6 +9,7 @@ import ibis.expr.datatypes as dt import ibis.expr.schema as sch +import ibis.expr.types as ir from ibis.common.exceptions import IntegrityError from ibis.common.grounds import Annotable from ibis.common.patterns import CoercedTo @@ -602,3 +603,13 @@ def test_schema_from_sqlglot(): ) assert ibis_schema == expected + + +def test_schema_from_Table_subclass(): + class MyTable(ir.Table): + a: ir.StringValue + b: ir.IntegerValue["!int64"] + + expected = sch.Schema({"a": dt.string, "b": dt.Int64(nullable=False)}) + actual = sch.schema(MyTable) + assert actual == expected diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 84f9fb796f93..9d5a376775b7 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -41,6 +41,8 @@ class ArrayValue(Value): └──────────────────────┘ """ + __dtype_supertype__ = dt.Array + def length(self) -> ir.IntegerValue: """Compute the length of an array. 
diff --git a/ibis/expr/types/binary.py b/ibis/expr/types/binary.py index f07aad33cf5e..22ff8f5016f9 100644 --- a/ibis/expr/types/binary.py +++ b/ibis/expr/types/binary.py @@ -7,12 +7,15 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.expr.types.generic import Column, Scalar, Value @public class BinaryValue(Value): + __dtype__ = dt.binary + def hashbytes( self, how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", / ) -> ir.BinaryValue: diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 63b1c264623d..b7302a916e64 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Union, overload from public import public @@ -10,6 +10,7 @@ import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.bases import AbstractMeta from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input @@ -37,10 +38,46 @@ _SENTINEL = object() +class ValueMeta(AbstractMeta): + def __new__( + metacls: type, + clsname: str, + bases: tuple[type, ...], + dct: dict[str, Any], + **kwargs: Any, + ) -> type: + def __class_getitem__(cls: Value, item: type | str) -> type: + dtype_supertype = None + if cls.__dtype__ is not None: + dtype_supertype = cls.__dtype__.__class__ + if cls.__dtype_supertype__ is not None: + dtype_supertype = cls.__dtype_supertype__ + if dtype_supertype is None: + raise TypeError(f"{cls.__name__} does not support type parameters") + dtype_obj = dt.dtype(item) + if not isinstance(dtype_obj, dtype_supertype): + raise TypeError( + f"invalid type parameter {item!r} for {cls.__name__}, " + f"expected a subtype of 
{dtype_supertype.__name__}" + ) + + class ParameterizedValue(cls): + __dtype__ = dtype_obj + + return ParameterizedValue + + new_dct = {**dct, "__class_getitem__": classmethod(__class_getitem__)} + new_type = super().__new__(metacls, clsname, bases, new_dct, **kwargs) + return new_type + + @public -class Value(Expr): +class Value(Expr, metaclass=ValueMeta): """Base class for a data generating expression having a known type.""" + __dtype__: ClassVar[Union[dt.DataType, None]] = None + __dtype_supertype__: ClassVar[Union[type[dt.DataType], None]] = None + def name(self, name: str, /) -> Value: """Rename an expression to `name`. @@ -2964,7 +3001,7 @@ def to_list(self, **kwargs) -> list: @public class UnknownValue(Value): - pass + __dtype__ = dt.unknown @public @@ -2979,7 +3016,7 @@ class UnknownColumn(Column): @public class NullValue(Value): - pass + __dtype__ = dt.null @public diff --git a/ibis/expr/types/geospatial.py b/ibis/expr/types/geospatial.py index c8093d093dbb..113a6999ad15 100644 --- a/ibis/expr/types/geospatial.py +++ b/ibis/expr/types/geospatial.py @@ -4,6 +4,7 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.expr.types.numeric import NumericColumn, NumericScalar, NumericValue @@ -13,6 +14,8 @@ @public class GeoSpatialValue(NumericValue): + __dtype__ = dt.geometry + def area(self) -> ir.FloatingValue: """Compute the area of a geospatial value. 
@@ -1666,7 +1669,7 @@ def unary_union( @public class PointValue(GeoSpatialValue): - pass + __dtype__ = dt.point @public @@ -1681,7 +1684,7 @@ class PointColumn(GeoSpatialColumn, PointValue): @public class LineStringValue(GeoSpatialValue): - pass + __dtype__ = dt.linestring @public @@ -1696,7 +1699,7 @@ class LineStringColumn(GeoSpatialColumn, LineStringValue): @public class PolygonValue(GeoSpatialValue): - pass + __dtype__ = dt.polygon @public @@ -1711,7 +1714,7 @@ class PolygonColumn(GeoSpatialColumn, PolygonValue): @public class MultiLineStringValue(GeoSpatialValue): - pass + __dtype__ = dt.multilinestring @public @@ -1726,7 +1729,7 @@ class MultiLineStringColumn(GeoSpatialColumn, MultiLineStringValue): @public class MultiPointValue(GeoSpatialValue): - pass + __dtype__ = dt.multipoint @public @@ -1741,7 +1744,7 @@ class MultiPointColumn(GeoSpatialColumn, MultiPointValue): @public class MultiPolygonValue(GeoSpatialValue): - pass + __dtype__ = dt.multipolygon @public diff --git a/ibis/expr/types/inet.py b/ibis/expr/types/inet.py index 6869487bd4e3..1cdd7508d6f5 100644 --- a/ibis/expr/types/inet.py +++ b/ibis/expr/types/inet.py @@ -2,12 +2,13 @@ from public import public +from ibis.expr import datatypes as dt from ibis.expr.types.generic import Column, Scalar, Value @public class MACADDRValue(Value): - pass + __dtype__ = dt.macaddr @public @@ -22,7 +23,7 @@ class MACADDRColumn(Column, MACADDRValue): @public class INETValue(Value): - pass + __dtype__ = dt.inet @public diff --git a/ibis/expr/types/json.py b/ibis/expr/types/json.py index d99c04af474f..c98531795590 100644 --- a/ibis/expr/types/json.py +++ b/ibis/expr/types/json.py @@ -86,6 +86,8 @@ class JSONValue(Value): └─────────────────────┘ """ + __dtype__ = dt.json + def __getitem__( self, key: str | int | ir.StringValue | ir.IntegerValue ) -> JSONValue: diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index d727438e1d11..3fe1d92b00c9 100644 --- a/ibis/expr/types/logical.py +++ 
b/ibis/expr/types/logical.py @@ -5,6 +5,7 @@ from public import public import ibis +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util from ibis.expr.types.core import _binop @@ -16,6 +17,8 @@ @public class BooleanValue(NumericValue): + __dtype__ = dt.boolean + def ifelse(self, true_expr: ir.Value, false_expr: ir.Value, /) -> ir.Value: """Construct a ternary conditional expression. diff --git a/ibis/expr/types/maps.py b/ibis/expr/types/maps.py index a49c1c6e087f..aa01a980a007 100644 --- a/ibis/expr/types/maps.py +++ b/ibis/expr/types/maps.py @@ -4,6 +4,7 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.deferred import deferrable from ibis.expr.types.generic import Column, Scalar, Value @@ -80,6 +81,8 @@ class MapValue(Value): └───────────────────┘ """ + __dtype_supertype__ = dt.Map + def get(self, key: ir.Value, default: ir.Value | None = None, /) -> ir.Value: """Return the value for `key` from `expr`. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 1f78a07d8b6c..77c46eef9303 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -6,6 +6,7 @@ from public import public import ibis +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.exceptions import IbisTypeError from ibis.expr.types.core import _binop @@ -1468,6 +1469,9 @@ def approx_quantile( @public class IntegerValue(NumericValue): + __dtype__ = dt.int64 + __dtype_supertype__ = dt.Integer + def as_timestamp(self, unit: Literal["s", "ms", "us"], /) -> ir.TimestampValue: """Convert an integral UNIX timestamp to a timestamp expression. @@ -1693,6 +1697,9 @@ def bit_xor(self, *, where: ir.BooleanValue | None = None) -> IntegerScalar: @public class FloatingValue(NumericValue): + __dtype__ = dt.float64 + __dtype_supertype__ = dt.Floating + def isnan(self) -> ir.BooleanValue: """Return whether the value is NaN. 
Does NOT detect `NULL` and `inf` values. @@ -1772,7 +1779,8 @@ class FloatingColumn(NumericColumn, FloatingValue): @public class DecimalValue(NumericValue): - pass + __dtype_supertype__ = dt.Decimal + __dtype__ = dt.Decimal(38, 10) @public diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index eac19d9252d6..9afc8c35dd8f 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -8,6 +8,7 @@ import ibis.expr.operations as ops from ibis import util +from ibis.expr import datatypes as dt from ibis.expr.types.core import _binop from ibis.expr.types.generic import Column, Scalar, Value @@ -19,6 +20,8 @@ @public class StringValue(Value): + __dtype__ = dt.string + def __getitem__(self, key: slice | int | ir.IntegerScalar) -> StringValue: """Index or slice a string expression. diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index 7639a38f0516..f37bda7fc139 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -6,6 +6,7 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.deferred import deferrable from ibis.common.exceptions import IbisError @@ -14,7 +15,6 @@ if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - import ibis.expr.datatypes as dt import ibis.expr.types as ir @@ -140,6 +140,8 @@ class StructValue(Value): └───────┘ """ + __dtype_supertype__ = dt.Struct + def __dir__(self): out = set(dir(type(self))) out.update( diff --git a/ibis/expr/types/temporal.py b/ibis/expr/types/temporal.py index 181ffd7ec4ef..c6251344b217 100644 --- a/ibis/expr/types/temporal.py +++ b/ibis/expr/types/temporal.py @@ -329,6 +329,8 @@ def between( class TimeValue(_TimeComponentMixin, Value): """A time of day from 0:00:00 to 23:59:59.999999999.""" + __dtype__ = dt.time + def strftime(self, format_str: str, /) -> ir.StringValue: """Format a time according to `format_str`. 
@@ -505,6 +507,8 @@ class TimeColumn(Column, TimeValue): class DateValue(Value, _DateComponentMixin): """A date (without time), eg 2024-12-31.""" + __dtype__ = dt.date + def strftime(self, format_str: str, /) -> ir.StringValue: """Format a date according to `format_str`. @@ -785,6 +789,8 @@ class DateColumn(Column, DateValue): class TimestampValue(_DateComponentMixin, _TimeComponentMixin, Value): """A date and time, eg 2024-12-31 23:59:59.999999.""" + __dtype__ = dt.timestamp + def strftime(self, format_str: str, /) -> ir.StringValue: """Format a timestamp according to `format_str`. @@ -1255,6 +1261,8 @@ class IntervalValue(Value): which results in a new timestamp expression. """ + __dtype__ = dt.Interval + def as_unit(self, target_unit: str, /) -> IntervalValue: """Convert this interval to units of `target_unit`.""" # TODO(kszucs): should use a separate operation for unit conversion diff --git a/ibis/expr/types/uuid.py b/ibis/expr/types/uuid.py index 772e74da5186..a1f9c241c970 100644 --- a/ibis/expr/types/uuid.py +++ b/ibis/expr/types/uuid.py @@ -2,12 +2,13 @@ from public import public +from ibis.expr import datatypes as dt from ibis.expr.types.generic import Column, Scalar, Value @public class UUIDValue(Value): - pass + __dtype__ = dt.uuid @public diff --git a/ibis/tests/expr/test_literal.py b/ibis/tests/expr/test_literal.py index 4da79536aaf4..8c44107c840c 100644 --- a/ibis/tests/expr/test_literal.py +++ b/ibis/tests/expr/test_literal.py @@ -175,3 +175,8 @@ def test_deferred(table): deferred = ibis.literal(expr, type=dtype) result = deferred.resolve(table) assert result.op().value == "g" + + +def test_literal_from_expr_for_type(): + expr = ibis.literal(1, ibis.ir.IntegerValue["!int32"]) + assert expr.type() == dt.Int32(nullable=False) diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 80321d98c273..5a9325112ac2 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1879,6 +1879,16 @@ def test_cast(): ) 
+def test_cast_subclass(): + class Users(ibis.Table): + name: ibis.ir.StringColumn + age: ibis.ir.IntegerColumn["uint8"] + + inp = ibis.table(dict(name="string", age="int64")) + result = inp.cast(Users) + assert result.schema() == ibis.schema(dict(name="string", age="uint8")) + + def test_pivot_longer(): diamonds = ibis.table( {