diff --git a/example.py b/example.py
index 929af56..ae5525a 100644
--- a/example.py
+++ b/example.py
@@ -24,7 +24,10 @@
     },
 )
 
-class Base(DeclarativeBase): pass
+
+class Base(DeclarativeBase):
+    pass
+
 
 class Person(Base):
     __tablename__ = "people"
diff --git a/pyproject.toml b/pyproject.toml
index 5d7bf33..551cbe7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,2 +1,7 @@
 [tool.isort]
 profile = "black"
+
+[tool.pytest.ini_options]
+pythonpath = [
+  "src"
+]
diff --git a/setup.py b/setup.py
index a3ec982..83474e9 100644
--- a/setup.py
+++ b/setup.py
@@ -15,16 +15,14 @@
     entry_points={
         "sqlalchemy.dialects": [
             "rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
-            "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect"
+            "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
         ]
     },
-    install_requires=[
-        "rockset>=1.0.0",
-        "sqlalchemy>=1.4.0"
-    ],
+    install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"],
+    tests_require=["pytest>=8.2.1"],
     classifiers=[
-        'Programming Language :: Python :: 3',
-        'License :: OSI Approved :: Apache Software License',
-        'Operating System :: OS Independent',
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
     ],
 )
diff --git a/src/rockset_sqlalchemy/cursor.py b/src/rockset_sqlalchemy/cursor.py
index 9e6ecf1..59f2450 100644
--- a/src/rockset_sqlalchemy/cursor.py
+++ b/src/rockset_sqlalchemy/cursor.py
@@ -162,12 +162,13 @@ def description(self):
             return None
 
         desc = []
-        for field_name, field_value in self._response.results[0].items():
-            name, type_ = field_name, Cursor.__convert_to_rockset_type(field_value)
-            null_ok = name != "_id" and "__id" not in name
+        if len(self._response.results) > 0:
+            for field_name, field_value in self._response.results[0].items():
+                name, type_ = field_name, Cursor.__convert_to_rockset_type(field_value)
+                null_ok = name != "_id" and "__id" not in name
 
-            # name, type_code, display_size, internal_size, precision, scale, null_ok
-            desc.append((name, type_, None, None, None, None, null_ok))
+                # name, type_code, display_size, internal_size, precision, scale, null_ok
+                desc.append((name, type_, None, None, None, None, null_ok))
         return desc
 
     def __iter__(self):
diff --git a/src/rockset_sqlalchemy/sqlalchemy/dialect.py b/src/rockset_sqlalchemy/sqlalchemy/dialect.py
index 85da2bd..84ab034 100644
--- a/src/rockset_sqlalchemy/sqlalchemy/dialect.py
+++ b/src/rockset_sqlalchemy/sqlalchemy/dialect.py
@@ -1,4 +1,6 @@
-from sqlalchemy import exc, types, util
+from typing import List, Any, Sequence
+
+from sqlalchemy import exc, types, util, Engine
 from sqlalchemy.engine import default, reflection
 from sqlalchemy.sql import compiler
 
@@ -80,6 +82,7 @@ def get_table_names(self, connection, schema=None, **kw):
         return [w["name"] for w in tables]
 
     def _get_table_columns(self, connection, table_name, schema):
+        """Gets table columns based on a retrieved row from the collection"""
         schema = self.identifier_preparer.quote_identifier(schema)
         table_name = self.identifier_preparer.quote_identifier(table_name)
 
@@ -87,9 +90,7 @@ def _get_table_columns(self, connection, table_name, schema):
         # This assumes the whole collection has a fixed schema of course.
         q = f"SELECT * FROM {schema}.{table_name} LIMIT 1"
         try:
-            cursor = connection.connect().connection.cursor()
-            cursor.execute(q)
-            fields = cursor.description
+            fields = self._exec_query_description(connection, q)
             if not fields:
                 # Return a fake schema if the collection is empty.
                 return [("null", "null")]
@@ -115,10 +116,72 @@ def _get_table_columns(self, connection, table_name, schema):
             raise e
         return columns
 
+    def _validate_query(self, connection: Engine, query: str):
+        import rockset.models
+
+        query_request_sql = rockset.models.QueryRequestSql(query=query)
+        # raises rockset.exceptions.BadRequestException if DESCRIBE is invalid on this collection e.g. rollups
+        connection.connect().connection._client.Queries.validate(sql=query_request_sql)
+
+    def _exec_query(self, connection: Engine, query: str) -> Sequence[Any]:
+        cursor = connection.connect().connection.cursor()
+        cursor.execute(query)
+        return cursor.fetchall()
+
+    def _exec_query_description(self, connection: Engine, query: str) -> Sequence[Any]:
+        cursor = connection.connect().connection.cursor()
+        cursor.execute(query)
+        return cursor.description
+
+    def _get_table_columns_describe(self, connection, table_name, schema, **kw):
+        """Gets table columns based on the query DESCRIBE SomeCollection"""
+        schema = self.identifier_preparer.quote_identifier(schema)
+        table_name = self.identifier_preparer.quote_identifier(table_name)
+
+        max_field_depth = kw["max_field_depth"] if "max_field_depth" in kw else 1
+        if not isinstance(max_field_depth, int):
+            raise ValueError("Query option max_field_depth, must be of type 'int'")
+
+        q = f"DESCRIBE {table_name} OPTION(max_field_depth={max_field_depth})"
+        self._validate_query(connection, q)
+
+        try:
+            results = self._exec_query(connection, q)
+            columns = list()
+            for result in results:
+                field_name = ".".join(result[0])
+                field_type = result[1]
+                if field_type not in type_map.keys():
+                    raise exc.SQLAlchemyError(
+                        "Query returned unsupported type {} in field {} in table {}.{}".format(
+                            field_type, field_name, schema, table_name
+                        )
+                    )
+                nullable = (
+                    False if result[0][0] == "_id" else True
+                )  # _id is the only field that's not nullable
+                columns.append(
+                    {
+                        "name": field_name,
+                        "type": type_map[result[1]],
+                        "nullable": nullable,
+                        "default": None,
+                    }
+                )
+        except Exception as e:
+            # TODO: more graceful handling of exceptions.
+            raise e
+        return columns
+
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
         if schema is None:
             schema = RocksetDialect.default_schema_name
+        try:
+            return self._get_table_columns_describe(connection, table_name, schema)
+        except Exception as e:
+            # likely a rollup collection, so revert to old behavior
+            pass
         return self._get_table_columns(connection, table_name, schema)
 
     @reflection.cache
diff --git a/test/__init__.py b/tests/__init__.py
similarity index 100%
rename from test/__init__.py
rename to tests/__init__.py
diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unittests/sqlalchemy/__init__.py b/tests/unittests/sqlalchemy/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unittests/sqlalchemy/conftest.py b/tests/unittests/sqlalchemy/conftest.py
new file mode 100644
index 0000000..b40d21a
--- /dev/null
+++ b/tests/unittests/sqlalchemy/conftest.py
@@ -0,0 +1,20 @@
+import os
+
+import pytest
+from sqlalchemy import create_engine
+
+# unittests are local so ROCKSET_API_KEY can be set to anything
+host = "https://api.usw2a1.rockset.com"
+api_key = "abc123"
+
+
+@pytest.fixture
+def engine():
+
+    return create_engine(
+        "rockset://",
+        connect_args={
+            "api_server": host,
+            "api_key": api_key,
+        },
+    )
diff --git a/tests/unittests/sqlalchemy/test_dialect.py b/tests/unittests/sqlalchemy/test_dialect.py
new file mode 100644
index 0000000..f06cb3e
--- /dev/null
+++ b/tests/unittests/sqlalchemy/test_dialect.py
@@ -0,0 +1,87 @@
+from unittest import mock
+
+import pytest
+import rockset
+
+from rockset_sqlalchemy.sqlalchemy import dialect
+from rockset_sqlalchemy.sqlalchemy.types import Timestamp, String, Object, NullType
+
+
+@pytest.fixture
+def columns_from_describe_results():
+    """Schema results based on Person in example.py"""
+    return [
+        (["_event_time"], "timestamp", 8, 8),
+        (["_id"], "string", 8, 8),
+        (["info"], "object", 8, 8),
+        (["name"], "string", 8, 8),
+    ]
+
+
+@pytest.fixture
+def columns_from_select_row_results():
+    """Schema results based on Person in example.py"""
+    return [
+        ("_id", "string", None, None, None, None, False),
+        ("_event_time", "string", None, None, None, None, True),
+        ("_meta", "null", None, None, None, None, True),
+        ("name", "string", None, None, None, None, True),
+        ("info", "object", None, None, None, None, True),
+    ]
+
+
+def test_get_columns_with_describe_returns_expected_structure(
+    engine, columns_from_describe_results
+):
+    with mock.patch(
+        "rockset_sqlalchemy.cursor.Cursor.execute_query", mock.Mock()
+    ) as _, mock.patch(
+        "rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._validate_query",
+        mock.Mock(),
+    ) as _, mock.patch(
+        "rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._exec_query",
+        mock.Mock(return_value=columns_from_describe_results),
+    ) as _:
+        expected_results = [
+            {
+                "name": "_event_time",
+                "type": Timestamp,
+                "nullable": True,
+                "default": None,
+            },
+            {"name": "_id", "type": String, "nullable": False, "default": None},
+            {"name": "info", "type": Object, "nullable": True, "default": None},
+            {"name": "name", "type": String, "nullable": True, "default": None},
+        ]
+
+        rs_dialect = dialect.RocksetDialect()
+        columns = rs_dialect.get_columns(engine, "people")
+        assert expected_results == columns
+
+
+def test_get_columns_falls_back_on_rockset_exception_to_select_one_row(
+    engine, columns_from_select_row_results
+):
+    with mock.patch(
+        "rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._get_table_columns_describe",
+        mock.Mock(side_effect=rockset.exceptions.RocksetException()),
+    ) as _, mock.patch(
+        "rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._exec_query_description",
+        mock.Mock(return_value=columns_from_select_row_results),
+    ) as _:
+        expected_results = [
+            {"name": "_id", "type": String, "nullable": False, "default": None},
+            {"name": "_event_time", "type": String, "nullable": True, "default": None},
+            {
+                "name": "_meta",
+                "type": NullType,
+                "nullable": True,
+                "default": None,
+            },
+            {"name": "name", "type": String, "nullable": True, "default": None},
+            {"name": "info", "type": Object, "nullable": True, "default": None},
+        ]
+
+        rs_dialect = dialect.RocksetDialect()
+        columns = rs_dialect.get_columns(engine, "people")
+        assert expected_results == columns