Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
},
)

class Base(DeclarativeBase): pass

class Base(DeclarativeBase):
pass


class Person(Base):
__tablename__ = "people"
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
[tool.isort]
profile = "black"

[tool.pytest.ini_options]
pythonpath = [
"src"
]
14 changes: 6 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@
entry_points={
"sqlalchemy.dialects": [
"rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect"
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The version needs to be bumped, right?

]
},
install_requires=[
"rockset>=1.0.0",
"sqlalchemy>=1.4.0"
],
install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"],
tests_require=["pytest>=8.2.1"],
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
)
11 changes: 6 additions & 5 deletions src/rockset_sqlalchemy/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,13 @@ def description(self):
return None

desc = []
for field_name, field_value in self._response.results[0].items():
name, type_ = field_name, Cursor.__convert_to_rockset_type(field_value)
null_ok = name != "_id" and "__id" not in name
if len(self._response.results) > 0:
for field_name, field_value in self._response.results[0].items():
name, type_ = field_name, Cursor.__convert_to_rockset_type(field_value)
null_ok = name != "_id" and "__id" not in name

# name, type_code, display_size, internal_size, precision, scale, null_ok
desc.append((name, type_, None, None, None, None, null_ok))
# name, type_code, display_size, internal_size, precision, scale, null_ok
desc.append((name, type_, None, None, None, None, null_ok))
return desc

def __iter__(self):
Expand Down
71 changes: 67 additions & 4 deletions src/rockset_sqlalchemy/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from sqlalchemy import exc, types, util
from typing import List, Any, Sequence

from sqlalchemy import exc, types, util, Engine
from sqlalchemy.engine import default, reflection
from sqlalchemy.sql import compiler

Expand Down Expand Up @@ -80,16 +82,15 @@ def get_table_names(self, connection, schema=None, **kw):
return [w["name"] for w in tables]

def _get_table_columns(self, connection, table_name, schema):
"""Gets table columns based on a retrieved row from the collection"""
schema = self.identifier_preparer.quote_identifier(schema)
table_name = self.identifier_preparer.quote_identifier(table_name)

# Get a single row and determine the schema from that.
# This assumes the whole collection has a fixed schema of course.
q = f"SELECT * FROM {schema}.{table_name} LIMIT 1"
try:
cursor = connection.connect().connection.cursor()
cursor.execute(q)
fields = cursor.description
fields = self._exec_query_description(connection, q)
if not fields:
# Return a fake schema if the collection is empty.
return [("null", "null")]
Expand All @@ -115,10 +116,72 @@ def _get_table_columns(self, connection, table_name, schema):
raise e
return columns

def _validate_query(self, connection: Engine, query: str):
import rockset.models

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to have this import at the top of the file?


query_request_sql = rockset.models.QueryRequestSql(query=query)
# raises rockset.exceptions.BadRequestException if DESCRIBE is invalid on this collection e.g. rollups
connection.connect().connection._client.Queries.validate(sql=query_request_sql)

def _exec_query(self, connection: Engine, query: str) -> Sequence[Any]:
cursor = connection.connect().connection.cursor()
cursor.execute(query)
return cursor.fetchall()

def _exec_query_description(self, connection: Engine, query: str) -> Sequence[Any]:
cursor = connection.connect().connection.cursor()
cursor.execute(query)
return cursor.description

def _get_table_columns_describe(self, connection, table_name, schema, **kw):
"""Gets table columns based on the query DESCRIBE SomeCollection"""
schema = self.identifier_preparer.quote_identifier(schema)
table_name = self.identifier_preparer.quote_identifier(table_name)

max_field_depth = kw["max_field_depth"] if "max_field_depth" in kw else 1
if not isinstance(max_field_depth, int):
raise ValueError("Query option max_field_depth, must be of type 'int'")

q = f"DESCRIBE {table_name} OPTION(max_field_depth={max_field_depth})"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema needs to be in the DESCRIBE query, no?

self._validate_query(connection, q)

try:
results = self._exec_query(connection, q)
columns = list()
for result in results:
field_name = ".".join(result[0])
field_type = result[1]
if field_type not in type_map.keys():
raise exc.SQLAlchemyError(
"Query returned unsupported type {} in field {} in table {}.{}".format(
field_type, field_name, schema, table_name
)
)
nullable = (
False if result[0][0] == "_id" else True
) # _id is the only field that's not nullable
columns.append(
{
"name": field_name,
"type": type_map[result[1]],
"nullable": nullable,
"default": None,
}
)
except Exception as e:
# TODO: more graceful handling of exceptions.
raise e
return columns

@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
if schema is None:
schema = RocksetDialect.default_schema_name
try:
return self._get_table_columns_describe(connection, table_name, schema)
except Exception as e:
# likely a rollup collection, so revert to old behavior
pass
return self._get_table_columns(connection, table_name, schema)

@reflection.cache
Expand Down
File renamed without changes.
Empty file added tests/unittests/__init__.py
Empty file.
Empty file.
20 changes: 20 additions & 0 deletions tests/unittests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import pytest
from sqlalchemy import create_engine

# unittests are local so ROCKSET_API_KEY can be set to anything
host = "https://api.usw2a1.rockset.com"
api_key = "abc123"


@pytest.fixture
def engine():

return create_engine(
"rockset://",
connect_args={
"api_server": host,
"api_key": api_key,
},
)
87 changes: 87 additions & 0 deletions tests/unittests/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from unittest import mock

import pytest
import rockset

from rockset_sqlalchemy.sqlalchemy import dialect
from rockset_sqlalchemy.sqlalchemy.types import Timestamp, String, Object, NullType


@pytest.fixture
def columns_from_describe_results():
"""Schema results based on Person in example.py"""
return [
(["_event_time"], "timestamp", 8, 8),
(["_id"], "string", 8, 8),
(["info"], "object", 8, 8),
(["name"], "string", 8, 8),
]


@pytest.fixture
def columns_from_select_row_results():
"""Schema results based on Person in example.py"""
return [
("_id", "string", None, None, None, None, False),
("_event_time", "string", None, None, None, None, True),
("_meta", "null", None, None, None, None, True),
("name", "string", None, None, None, None, True),
("info", "object", None, None, None, None, True),
]


def test_get_columns_with_describe_returns_expected_structure(
engine, columns_from_describe_results
):
with mock.patch(
"rockset_sqlalchemy.cursor.Cursor.execute_query", mock.Mock()
) as _, mock.patch(
"rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._validate_query",
mock.Mock(),
) as _, mock.patch(
"rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._exec_query",
mock.Mock(return_value=columns_from_describe_results),
) as _:
expected_results = [
{
"name": "_event_time",
"type": Timestamp,
"nullable": True,
"default": None,
},
{"name": "_id", "type": String, "nullable": False, "default": None},
{"name": "info", "type": Object, "nullable": True, "default": None},
{"name": "name", "type": String, "nullable": True, "default": None},
]

rs_dialect = dialect.RocksetDialect()
columns = rs_dialect.get_columns(engine, "people")
assert expected_results == columns


def test_get_columns_falls_back_on_rockset_exception_to_select_one_row(
engine, columns_from_select_row_results
):
with mock.patch(
"rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._get_table_columns_describe",
mock.Mock(side_effect=rockset.exceptions.RocksetException()),
) as _, mock.patch(
"rockset_sqlalchemy.sqlalchemy.dialect.RocksetDialect._exec_query_description",
mock.Mock(return_value=columns_from_select_row_results),
) as _:
expected_results = [
{"name": "_id", "type": String, "nullable": False, "default": None},
{"name": "_event_time", "type": String, "nullable": True, "default": None},
{
"name": "_meta",
"type": NullType,
"nullable": True,
"default": None,
},
{"name": "name", "type": String, "nullable": True, "default": None},
{"name": "info", "type": Object, "nullable": True, "default": None},
]

rs_dialect = dialect.RocksetDialect()
columns = rs_dialect.get_columns(engine, "people")
assert expected_results == columns