Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Generic,
Literal,
NamedTuple,
Expand Down Expand Up @@ -41,52 +42,58 @@


@overload
def dtype(value: type[int] | Literal["int"], nullable: bool = True) -> Int64: ...
def dtype(value: type[int] | Literal["int"], nullable: bool | None = None) -> Int64: ...
@overload
def dtype(
value: type[str] | Literal["str", "string"], nullable: bool = True
value: type[str] | Literal["str", "string"], nullable: bool | None = None
) -> String: ...
@overload
def dtype(
value: type[bool] | Literal["bool", "boolean"], nullable: bool = True
value: type[bool] | Literal["bool", "boolean"], nullable: bool | None = None
) -> Boolean: ...
@overload
def dtype(value: type[bytes] | Literal["bytes"], nullable: bool = True) -> Binary: ...
def dtype(
value: type[bytes] | Literal["bytes"], nullable: bool | None = None
) -> Binary: ...
@overload
def dtype(value: type[Real] | Literal["float"], nullable: bool = True) -> Float64: ...
def dtype(
value: type[Real] | Literal["float"], nullable: bool | None = None
) -> Float64: ...
@overload
def dtype(
value: type[pydecimal.Decimal] | Literal["decimal"], nullable: bool = True
value: type[pydecimal.Decimal] | Literal["decimal"], nullable: bool | None = None
) -> Decimal: ...
@overload
def dtype(
value: type[pydatetime.datetime] | Literal["timestamp"], nullable: bool = True
value: type[pydatetime.datetime] | Literal["timestamp"],
nullable: bool | None = None,
) -> Timestamp: ...
@overload
def dtype(
value: type[pydatetime.date] | Literal["date"], nullable: bool = True
value: type[pydatetime.date] | Literal["date"], nullable: bool | None = None
) -> Date: ...
@overload
def dtype(
value: type[pydatetime.time] | Literal["time"], nullable: bool = True
value: type[pydatetime.time] | Literal["time"], nullable: bool | None = None
) -> Time: ...
@overload
def dtype(
value: type[pydatetime.timedelta] | Literal["interval"], nullable: bool = True
value: type[pydatetime.timedelta] | Literal["interval"],
nullable: bool | None = None,
) -> Interval: ...
@overload
def dtype(
value: type[pyuuid.UUID] | Literal["uuid"], nullable: bool = True
value: type[pyuuid.UUID] | Literal["uuid"], nullable: bool | None = None
) -> UUID: ...
@overload
def dtype(
value: DataType | str | np.dtype | ExtensionDtype | pl.DataType | pa.DataType,
nullable: bool = True,
nullable: bool | None = None,
) -> DataType: ...


@lazy_singledispatch
def dtype(value, nullable=True) -> DataType:
def dtype(value, nullable: bool | None = None) -> DataType:
"""Create a DataType object.

Parameters
Expand All @@ -96,7 +103,11 @@ def dtype(value, nullable=True) -> DataType:
strings, python type annotations, numpy dtypes, pandas dtypes, and
pyarrow types.
nullable
Whether the type should be nullable. Defaults to True.
Whether the type should be nullable. By default:

- if the value is already a DataType, its nullability is preserved.
- If the value is not a DataType, it is treated as nullable.

If `value` is a string prefixed by "!", the type is always non-nullable.

Examples
Expand Down Expand Up @@ -124,13 +135,19 @@ def dtype(value, nullable=True) -> DataType:

"""
if isinstance(value, DataType):
return value
if nullable is None:
return value
return value.copy(nullable=nullable)
elif getattr(value, "__dtype__", None) is not None:
return dtype(value.__dtype__, nullable=nullable)
else:
if nullable is None:
nullable = True
return DataType.from_typehint(value, nullable)


@dtype.register(str)
def from_string(value, nullable: bool = True):
def from_string(value, nullable=True):
return DataType.from_string(value, nullable)


Expand Down Expand Up @@ -290,6 +307,14 @@ def from_typehint(cls, typ, nullable=True) -> Self:
elif issubclass(typ, pyuuid.UUID):
return UUID(nullable=nullable)
elif annots := get_type_hints(typ):
from ibis.expr import types as ir

if issubclass(typ, ir.Table):
annots = {
k: v
for k, v in annots.items()
if k not in get_type_hints(ir.Table)
}
return Struct(toolz.valmap(dtype, annots), nullable=nullable)
else:
raise TypeError(
Expand Down
63 changes: 63 additions & 0 deletions ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from dataclasses import dataclass
from typing import Annotated, NamedTuple, Optional, Union

import parsy
import pytest
from pytest import param

import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis.common.annotations import ValidationError
from ibis.common.patterns import As, Attrs, NoMatch, Pattern
from ibis.common.temporal import TimestampUnit, TimeUnit
Expand Down Expand Up @@ -360,6 +362,67 @@ def test_dtype_from_newer_typehints(hint, expected):
assert dt.dtype(hint) == expected


def test_dtype_from_string_expr_class():
assert dt.dtype(ir.StringValue) == dt.String(nullable=True)
assert dt.dtype(ir.StringColumn) == dt.String(nullable=True)
assert dt.dtype(ir.StringScalar) == dt.String(nullable=True)

assert dt.dtype(ir.StringValue["string"]) == dt.String(nullable=True)
assert ir.StringValue["!string"].__dtype__ == dt.String(nullable=False)
assert dt.dtype(ir.StringValue["!string"]) == dt.String(nullable=False)

len_string = dt.String(nullable=False, length=10)
assert dt.dtype(ir.StringValue[len_string]) == len_string

with pytest.raises(TypeError):
ir.StringValue["int64"]
with pytest.raises(parsy.ParseError):
ir.StringValue["bogus"]


def test_dtype_from_integer_expr_class():
assert dt.dtype(ir.IntegerValue) == dt.Int64(nullable=True)
assert dt.dtype(ir.IntegerValue["int64"]) == dt.Int64(nullable=True)
assert dt.dtype(ir.IntegerValue["!int64"]) == dt.Int64(nullable=False)
assert dt.dtype(ir.IntegerValue["!int64"]) == dt.Int64(nullable=False)

with pytest.raises(TypeError):
ir.IntegerValue["float64"]
with pytest.raises(parsy.ParseError):
ir.IntegerValue["bogus"]


def test_dtype_from_struct_expr_class():
with pytest.raises(TypeError):
dt.dtype(ir.StructValue)
assert dt.dtype(ir.StructValue["struct<a: string, b: !int64>"]) == dt.Struct(
{"a": dt.string, "b": dt.Int64(nullable=False)}
)


def test_dtype_from_struct_subclass():
class MyStruct(ir.StructValue):
a: ir.StringValue
b: ir.IntegerValue["!int64"]

expected = dt.Struct({"a": dt.string, "b": dt.Int64(nullable=False)})
actual = dt.dtype(MyStruct)
assert actual == expected


def test_dtype_from_abstract_expr_class_fails():
with pytest.raises(TypeError):
dt.dtype(ir.Value)
with pytest.raises(TypeError):
dt.dtype(ir.Column)
with pytest.raises(TypeError):
dt.dtype(ir.Scalar)
with pytest.raises(TypeError):
dt.dtype(ir.ArrayValue)
with pytest.raises(TypeError):
dt.dtype(ir.MapValue)


def test_dtype_from_invalid_python_value():
msg = "Cannot construct an ibis datatype from python value `1.0`"
with pytest.raises(TypeError, match=msg):
Expand Down
11 changes: 11 additions & 0 deletions ibis/expr/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.common.exceptions import IntegrityError
from ibis.common.grounds import Annotable
from ibis.common.patterns import CoercedTo
Expand Down Expand Up @@ -602,3 +603,13 @@ def test_schema_from_sqlglot():
)

assert ibis_schema == expected


def test_schema_from_Table_subclass():
class MyTable(ir.Table):
a: ir.StringValue
b: ir.IntegerValue["!int64"]

expected = sch.Schema({"a": dt.string, "b": dt.Int64(nullable=False)})
actual = sch.schema(MyTable)
assert actual == expected
2 changes: 2 additions & 0 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class ArrayValue(Value):
└──────────────────────┘
"""

__dtype_supertype__ = dt.Array

def length(self) -> ir.IntegerValue:
"""Compute the length of an array.

Expand Down
3 changes: 3 additions & 0 deletions ibis/expr/types/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

from public import public

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.expr.types.generic import Column, Scalar, Value


@public
class BinaryValue(Value):
__dtype__ = dt.binary

def hashbytes(
self, how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", /
) -> ir.BinaryValue:
Expand Down
45 changes: 41 additions & 4 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Union, overload

from public import public

Expand All @@ -10,6 +10,7 @@
import ibis.expr.builders as bl
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.bases import AbstractMeta
from ibis.common.deferred import Deferred, _, deferrable
from ibis.common.grounds import Singleton
from ibis.expr.rewrites import rewrite_window_input
Expand Down Expand Up @@ -37,10 +38,46 @@
_SENTINEL = object()


class ValueMeta(AbstractMeta):
def __new__(
metacls: type,
clsname: str,
bases: tuple[type, ...],
dct: dict[str, Any],
**kwargs: Any,
) -> type:
def __class_getitem__(cls: Value, item: type | str) -> type:
dtype_supertype = None
if cls.__dtype__ is not None:
dtype_supertype = cls.__dtype__.__class__
if cls.__dtype_supertype__ is not None:
dtype_supertype = cls.__dtype_supertype__
if dtype_supertype is None:
raise TypeError(f"{cls.__name__} does not support type parameters")
dtype_obj = dt.dtype(item)
if not isinstance(dtype_obj, dtype_supertype):
raise TypeError(
f"invalid type parameter {item!r} for {cls.__name__}, "
f"expected a subtype of {dtype_supertype.__name__}"
)

class ParameterizedValue(cls):
__dtype__ = dtype_obj

return ParameterizedValue

new_dct = {**dct, "__class_getitem__": classmethod(__class_getitem__)}
new_type = super().__new__(metacls, clsname, bases, new_dct, **kwargs)
return new_type


@public
class Value(Expr):
class Value(Expr, metaclass=ValueMeta):
"""Base class for a data generating expression having a known type."""

__dtype__: ClassVar[Union[dt.DataType, None]] = None
__dtype_supertype__: ClassVar[Union[type[dt.DataType], None]] = None

def name(self, name: str, /) -> Value:
"""Rename an expression to `name`.

Expand Down Expand Up @@ -2964,7 +3001,7 @@ def to_list(self, **kwargs) -> list:

@public
class UnknownValue(Value):
pass
__dtype__ = dt.unknown


@public
Expand All @@ -2979,7 +3016,7 @@ class UnknownColumn(Column):

@public
class NullValue(Value):
pass
__dtype__ = dt.null


@public
Expand Down
Loading
Loading