diff --git a/duckdb/__init__.py b/duckdb/__init__.py index c3ec0610..b5e994fa 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -18,6 +18,49 @@ def version(): "functional" ]) +class DBAPITypeObject: + def __init__(self, types: list[typing.DuckDBPyType]) -> None: + self.types = types + + def __eq__(self, other): + if isinstance(other, typing.DuckDBPyType): + return other in self.types + return False + + def __repr__(self): + return f"" + +# Define the standard DBAPI sentinels +STRING = DBAPITypeObject([typing.VARCHAR]) +NUMBER = DBAPITypeObject([ + typing.TINYINT, + typing.UTINYINT, + typing.SMALLINT, + typing.USMALLINT, + typing.INTEGER, + typing.UINTEGER, + typing.BIGINT, + typing.UBIGINT, + typing.HUGEINT, + typing.UHUGEINT, + typing.DuckDBPyType("BIGNUM"), + typing.DuckDBPyType("DECIMAL"), + typing.FLOAT, + typing.DOUBLE +]) +DATETIME = DBAPITypeObject([ + typing.DATE, + typing.TIME, + typing.TIME_TZ, + typing.TIMESTAMP, + typing.TIMESTAMP_TZ, + typing.TIMESTAMP_NS, + typing.TIMESTAMP_MS, + typing.TIMESTAMP_S +]) +BINARY = DBAPITypeObject([typing.BLOB]) +ROWID = None + # Classes from _duckdb import ( DuckDBPyRelation, diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index 5997d57b..a2607a12 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -514,72 +514,12 @@ py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) { return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor); } -py::str GetTypeToPython(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return py::str("bool"); - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: { - return py::str("NUMBER"); - } - case LogicalTypeId::VARCHAR: { - if (type.HasAlias() && type.GetAlias() == "JSON") { - return py::str("JSON"); - } else { - return py::str("STRING"); - } - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - return py::str("BINARY"); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_SEC: { - return py::str("DATETIME"); - } - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: { - return py::str("Time"); - } - case LogicalTypeId::DATE: { - return py::str("Date"); - } - case LogicalTypeId::STRUCT: - case LogicalTypeId::MAP: - return py::str("dict"); - case LogicalTypeId::LIST: { - return py::str("list"); - } - case LogicalTypeId::INTERVAL: { - return py::str("TIMEDELTA"); - } - case LogicalTypeId::UUID: { - return py::str("UUID"); - } - default: - return py::str(type.ToString()); - } -} - py::list DuckDBPyResult::GetDescription(const vector &names, const vector &types) { py::list desc; for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { auto py_name = py::str(names[col_idx]); - auto py_type = GetTypeToPython(types[col_idx]); + auto py_type = DuckDBPyType(types[col_idx]); desc.append(py::make_tuple(py_name, py_type, py::none(), py::none(), py::none(), py::none(), py::none())); } return desc; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index de03fa7d..009e3dab 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -326,8 +326,8 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); - type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other")); - type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other")); + type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { @@ -347,7 +347,7 @@ void DuckDBPyType::Initialize(py::handle &m) { return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); - type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); + type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), py::is_operator()); py::implicitly_convertible(); py::implicitly_convertible(); diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index da0d2015..2df3c156 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -14,6 +14,6 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ - ('name', 'STRING', None, None, None, None, None), + ('name', 'VARCHAR', None, None, None, None, None), ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index aec4b24a..1fbde602 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -1,19 +1,20 @@ # cursor description from datetime import datetime, date from pytest import mark +import duckdb class TestCursorDescription(object): @mark.parametrize( "query,column_name,string_type,real_type", [ - ["SELECT * FROM integers", "i", "NUMBER", int], - ["SELECT * FROM timestamps", "t", "DATETIME", datetime], - ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "Date", date], - ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BINARY", bytes], - ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "dict", dict], - ["SELECT [1, 2, 3] AS list_col", "list_col", "list", list], - ["SELECT 'Frank' AS str_col", "str_col", "STRING", str], + ["SELECT * FROM integers", "i", "INTEGER", int], + ["SELECT * FROM timestamps", "t", "TIMESTAMP", datetime], + ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date], + ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes], + ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "STRUCT(x INTEGER, y INTEGER, z INTEGER)", dict], + ["SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list], + ["SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str], ["SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str], ["SELECT union_value(tag := 1) AS union_col", "union_col", "UNION(tag INTEGER)", int], ], @@ -23,6 +24,24 @@ def test_description(self, query, column_name, string_type, real_type, duckdb_cu assert duckdb_cursor.description == [(column_name, string_type, None, None, None, None, None)] assert isinstance(duckdb_cursor.fetchone()[0], real_type) + def test_description_comparisons(self): + duckdb.execute("select 42 a, 'test' b, true c") + types = [x[1] for x in duckdb.description()] + + STRING = duckdb.STRING + NUMBER = duckdb.NUMBER + DATETIME = duckdb.DATETIME + + assert(types[1] == STRING) + assert(STRING == types[1]) + assert(types[0] != STRING) + assert((types[1] != STRING) == False) + assert((STRING != types[1]) == False) + + assert(types[1] in [STRING]) + assert(types[1] in [STRING, NUMBER]) + assert(types[1] not in [NUMBER, DATETIME]) + def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index c9f46021..4cb565c1 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,4 +1,5 @@ import duckdb +import duckdb.typing import pytest from conftest import NumpyPandas, ArrowPandas @@ -113,7 +114,7 @@ def test_readonly_properties(self): duckdb.execute("select 42") description = duckdb.description() rowcount = duckdb.rowcount() - assert description == [('42', 'NUMBER', None, None, None, None, None)] + assert description == [('42', 'INTEGER', None, None, None, None, None)] assert rowcount == -1 def test_execute(self): @@ -349,9 +350,6 @@ def test_view(self): assert [([0, 1, 2, 3, 4],)] == duckdb.view("vw").fetchall() duckdb.execute("drop view vw") - def test_description(self): - assert None != duckdb.description - def test_close(self): assert None != duckdb.close diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 8738b30a..01c8a460 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -9,7 +9,8 @@ def test_rapi_description(self, duckdb_cursor): names = [x[0] for x in desc] types = [x[1] for x in desc] assert names == ['a', 'b'] - assert types == ['NUMBER', 'NUMBER'] + assert types == ['INTEGER', 'BIGINT'] + assert (all([x == duckdb.NUMBER for x in types])) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 894f1050..4dbd1a36 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -154,9 +154,9 @@ def cast_to_string(df): con = duckdb.connect() rel = con.sql('select i from range (10) tbl(i)') - assert rel.types[0] == int + assert rel.types[0] == duckdb.NUMBER mapped_rel = rel.map(cast_to_string, schema={'i': str}) - assert mapped_rel.types[0] == str + assert mapped_rel.types[0] == duckdb.STRING def test_explicit_schema_returntype_mismatch(self): def does_nothing(df): diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index 34c8e187..af68e268 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -31,9 +31,9 @@ def test_result_describe_types(self, duckdb_cursor): rel = connection.table("test") res = rel.execute() assert res.description == [ - ('i', 'bool', None, None, None, None, None), - ('j', 'Time', None, None, None, None, None), - ('k', 'STRING', None, None, None, None, None), + ('i', 'BOOLEAN', None, None, None, None, None), + ('j', 'TIME', None, None, None, None, None), + ('k', 'VARCHAR', None, None, None, None, None), ] def test_result_timestamps(self, duckdb_cursor): @@ -64,7 +64,7 @@ def test_result_interval(self): rel = connection.table("intervals") res = rel.execute() - assert res.description == [('ivals', 'TIMEDELTA', None, None, None, None, None)] + assert res.description == [('ivals', 'INTERVAL', None, None, None, None, None)] assert res.fetchall() == [ (datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),),