diff --git a/python/pyhive/sqlalchemy_presto.py b/python/pyhive/sqlalchemy_presto.py index 33a41bae3e7..f5a256fb8d7 100644 --- a/python/pyhive/sqlalchemy_presto.py +++ b/python/pyhive/sqlalchemy_presto.py @@ -23,7 +23,7 @@ from sqlalchemy.dialects import mysql mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, bindparam from sqlalchemy.sql.compiler import SQLCompiler from pyhive import presto @@ -204,12 +204,45 @@ def get_indexes(self, connection, table_name, schema=None, **kw): else: return [] + def _get_default_schema_name(self, connection): + #'SELECT CURRENT_SCHEMA()' + return super()._get_default_schema_name(connection) + def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' + # N.B. This is incorrect, if no schema is provided, the current/default schema should be used + # with a call to an overridden self._get_default_schema_name(connection), but I could not + # see how to implement that as there is no CURRENT_SCHEMA function + # default_schema = self._get_default_schema_name(connection) + if schema: query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) return [row.Table for row in connection.execute(text(query))] + def get_view_names(self, connection, schema=None, **kw): + if schema: + view_name_query = """ + SELECT table_name + FROM information_schema.views + WHERE table_schema = :schema + """ + query = text(view_name_query).bindparams( + bindparam("schema", type_=types.Unicode) + ) + else: + # N.B. This is incorrect, if no schema is provided, the current/default schema should + # be used with a call to self._get_default_schema_name(connection), but I could not + # see how to implement that + # default_schema = self._get_default_schema_name(connection) + view_name_query = """ + SELECT table_name + FROM information_schema.views + """ + query = text(view_name_query) + + result = connection.execute(query, dict(schema=schema)) + return [row[0] for row in result] + def do_rollback(self, dbapi_connection): # No transactions for Presto pass diff --git a/python/pyhive/tests/test_sqlalchemy_presto.py b/python/pyhive/tests/test_sqlalchemy_presto.py index 336dd12e243..e8b04ea249c 100644 --- a/python/pyhive/tests/test_sqlalchemy_presto.py +++ b/python/pyhive/tests/test_sqlalchemy_presto.py @@ -102,4 +102,31 @@ def test_hash_table(self, engine, connection): self.assertFalse(insp.has_table("THIS_TABLE_DOSE_not_exist")) else: self.assertFalse(Table('THIS_TABLE_DOSE_NOT_EXIST', MetaData(bind=engine)).exists()) - self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists()) \ No newline at end of file + self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists()) + + @with_engine_connection + def test_reflect_table_names(self, engine, connection): + sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + table_names = insp.get_table_names() + self.assertIn("one_row", table_names) + self.assertIn("one_row_complex", table_names) + self.assertIn("many_rows", table_names) + self.assertNotIn("THIS_TABLE_DOES_not_exist", table_names) + else: + self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) + self.assertTrue(Table('one_row_complex', MetaData(bind=engine)).exists()) + self.assertTrue(Table('many_rows', MetaData(bind=engine)).exists()) + self.assertFalse(Table('THIS_TABLE_DOES_not_exist', MetaData(bind=engine)).exists()) + + @with_engine_connection + def test_reflect_view_names(self, engine, connection): + sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + view_names = insp.get_view_names() + self.assertNotIn("one_row", view_names) + self.assertNotIn("one_row_complex", view_names) + self.assertNotIn("many_rows", view_names) + self.assertNotIn("THIS_TABLE_DOES_not_exist", view_names)