Skip to content

Commit 282d60e

Browse files
committed
Fix - Presto SQLAlchemy dialect did not implement get_view_names
1 parent ea75fa8 commit 282d60e

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

python/pyhive/sqlalchemy_presto.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sqlalchemy.dialects import mysql
2424
mysql_tinyinteger = mysql.base.MSTinyInteger
2525
from sqlalchemy.engine import default
26-
from sqlalchemy.sql import compiler
26+
from sqlalchemy.sql import compiler, bindparam
2727
from sqlalchemy.sql.compiler import SQLCompiler
2828

2929
from pyhive import presto
@@ -205,10 +205,44 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
205205
return []
206206

207207
def get_table_names(self, connection, schema=None, **kw):
208-
query = 'SHOW TABLES'
209208
if schema:
210-
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
211-
return [row.Table for row in connection.execute(text(query))]
209+
table_name_query = """
210+
SELECT table_name FROM information_schema.tables
211+
WHERE table_type = 'TABLE'
212+
AND table_schema = :schema
213+
"""
214+
query = text(table_name_query).bindparams(
215+
bindparam("schema", type_=types.Unicode)
216+
)
217+
else:
218+
table_name_query = """
219+
SELECT table_name FROM information_schema.tables
220+
WHERE table_type = 'VIEW'
221+
"""
222+
query = text(table_name_query)
223+
224+
result = connection.execute(query, dict(schema=schema))
225+
return [row[0] for row in result]
226+
227+
def get_view_names(self, connection, schema=None, **kw):
228+
if schema:
229+
view_name_query = """
230+
SELECT table_name FROM information_schema.tables
231+
WHERE table_type = 'VIEW'
232+
AND table_schema = :schema
233+
"""
234+
query = text(view_name_query).bindparams(
235+
bindparam("schema", type_=types.Unicode)
236+
)
237+
else:
238+
view_name_query = """
239+
SELECT table_name FROM information_schema.tables
240+
WHERE table_type = 'VIEW'
241+
"""
242+
query = text(view_name_query)
243+
244+
result = connection.execute(query, dict(schema=schema))
245+
return [row[0] for row in result]
212246

213247
def do_rollback(self, dbapi_connection):
214248
# No transactions for Presto

python/pyhive/tests/test_sqlalchemy_presto.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,31 @@ def test_hash_table(self, engine, connection):
102102
self.assertFalse(insp.has_table("THIS_TABLE_DOSE_not_exist"))
103103
else:
104104
self.assertFalse(Table('THIS_TABLE_DOSE_NOT_EXIST', MetaData(bind=engine)).exists())
105-
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())
105+
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())
106+
107+
@with_engine_connection
108+
def test_reflect_table_names(self, engine, connection):
109+
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
110+
if sqlalchemy_version >= 1.4:
111+
insp = sqlalchemy.inspect(engine)
112+
table_names = insp.get_table_names()
113+
self.assertIn("one_row", table_names)
114+
self.assertIn("one_row_complex", table_names)
115+
self.assertIn("many_rows", table_names)
116+
self.assertNotIn("THIS_TABLE_DOES_not_exist", table_names)
117+
else:
118+
self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
119+
self.assertTrue(Table('one_row_complex', MetaData(bind=engine)).exists())
120+
self.assertTrue(Table('many_rows', MetaData(bind=engine)).exists())
121+
self.assertFalse(Table('THIS_TABLE_DOES_not_exist', MetaData(bind=engine)).exists())
122+
123+
@with_engine_connection
124+
def test_reflect_view_names(self, engine, connection):
125+
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
126+
if sqlalchemy_version >= 1.4:
127+
insp = sqlalchemy.inspect(engine)
128+
view_names = insp.get_view_names()
129+
self.assertNotIn("one_row", view_names)
130+
self.assertNotIn("one_row_complex", view_names)
131+
self.assertNotIn("many_rows", view_names)
132+
self.assertNotIn("THIS_TABLE_DOES_not_exist", view_names)

0 commit comments

Comments
 (0)