Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion python/pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, bindparam
from sqlalchemy.sql.compiler import SQLCompiler

from pyhive import presto
Expand Down Expand Up @@ -204,12 +204,44 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
else:
return []

def _get_default_schema_name(self, connection):
#'SELECT CURRENT_SCHEMA()'
return super()._get_default_schema_name(connection)

def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
# N.B. This is incorrect, if no schema is provided, the current/default schema should be used
# with a call to an overridden self._get_default_schema_name(connection), but I could not
# see how to implement that as there is no CURRENT_SCHEMA function
# default_schema = self._get_default_schema_name(connection)

if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
return [row.Table for row in connection.execute(text(query))]

def get_view_names(self, connection, schema=None, **kw):
if schema:
view_name_query = """
SELECT table_name
FROM information_schema.views
WHERE table_schema = :schema
"""
query = text(view_name_query).bindparams(
bindparam("schema", type_=types.Unicode)
)
else:
# N.B. This is incorrect, if no schema is provided, the current/default schema should be used
# with a call to self._get_default_schema_name(connection), but I could not see how to implement that
# default_schema = self._get_default_schema_name(connection)
view_name_query = """
SELECT table_name
FROM information_schema.views
"""
query = text(view_name_query)

result = connection.execute(query, dict(schema=schema))
return [row[0] for row in result]

def do_rollback(self, dbapi_connection):
# No transactions for Presto
pass
Expand Down
29 changes: 28 additions & 1 deletion python/pyhive/tests/test_sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,31 @@ def test_hash_table(self, engine, connection):
self.assertFalse(insp.has_table("THIS_TABLE_DOSE_not_exist"))
else:
self.assertFalse(Table('THIS_TABLE_DOSE_NOT_EXIST', MetaData(bind=engine)).exists())
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())

@with_engine_connection
def test_reflect_table_names(self, engine, connection):
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
if sqlalchemy_version >= 1.4:
insp = sqlalchemy.inspect(engine)
table_names = insp.get_table_names()
self.assertIn("one_row", table_names)
self.assertIn("one_row_complex", table_names)
self.assertIn("many_rows", table_names)
self.assertNotIn("THIS_TABLE_DOES_not_exist", table_names)
else:
self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
self.assertTrue(Table('one_row_complex', MetaData(bind=engine)).exists())
self.assertTrue(Table('many_rows', MetaData(bind=engine)).exists())
self.assertFalse(Table('THIS_TABLE_DOES_not_exist', MetaData(bind=engine)).exists())

@with_engine_connection
def test_reflect_view_names(self, engine, connection):
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
if sqlalchemy_version >= 1.4:
insp = sqlalchemy.inspect(engine)
view_names = insp.get_view_names()
self.assertNotIn("one_row", view_names)
self.assertNotIn("one_row_complex", view_names)
self.assertNotIn("many_rows", view_names)
self.assertNotIn("THIS_TABLE_DOES_not_exist", view_names)