21
21
22
22
import logging
23
23
from datetime import datetime , date
24
+ from types import ModuleType
24
25
25
26
from sqlalchemy import types as sqltypes
26
27
from sqlalchemy .engine import default , reflection
@@ -207,6 +208,12 @@ def initialize(self, connection):
207
208
self .default_schema_name = \
208
209
self ._get_default_schema_name (connection )
209
210
211
+ def set_isolation_level (self , dbapi_connection , level ):
212
+ """
213
+ For CrateDB, this is implemented as a noop.
214
+ """
215
+ pass
216
+
210
217
def do_rollback (self , connection ):
211
218
# if any exception is raised by the dbapi, sqlalchemy by default
212
219
# attempts to do a rollback crate doesn't support rollbacks.
@@ -225,7 +232,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
225
232
use_ssl = asbool (kwargs .pop ("ssl" , False ))
226
233
if use_ssl :
227
234
servers = ["https://" + server for server in servers ]
228
- return self .dbapi .connect (servers = servers , ** kwargs )
235
+
236
+ is_module = isinstance (self .dbapi , ModuleType )
237
+ if is_module :
238
+ driver_name = self .dbapi .__name__
239
+ else :
240
+ driver_name = self .dbapi .__class__ .__name__
241
+ if driver_name == "crate.client" :
242
+ if "database" in kwargs :
243
+ del kwargs ["database" ]
244
+ return self .dbapi .connect (servers = servers , ** kwargs )
245
+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
246
+ return self .dbapi .connect (host = host , port = port , ** kwargs )
247
+ else :
248
+ raise ValueError (f"Unknown driver variant: { driver_name } " )
249
+
229
250
return self .dbapi .connect (** kwargs )
230
251
231
252
def _get_default_schema_name (self , connection ):
@@ -271,11 +292,11 @@ def get_schema_names(self, connection, **kw):
271
292
def get_table_names (self , connection , schema = None , ** kw ):
272
293
if schema is None :
273
294
schema = self ._get_effective_schema_name (connection )
274
- cursor = connection .exec_driver_sql (
295
+ cursor = connection .exec_driver_sql (self . _format_query (
275
296
"SELECT table_name FROM information_schema.tables "
276
297
"WHERE {0} = ? "
277
298
"AND table_type = 'BASE TABLE' "
278
- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
299
+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
279
300
(schema or self .default_schema_name , )
280
301
)
281
302
return [row [0 ] for row in cursor .fetchall ()]
@@ -297,7 +318,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
297
318
"AND column_name !~ ?" \
298
319
.format (self .schema_column )
299
320
cursor = connection .exec_driver_sql (
300
- query ,
321
+ self . _format_query ( query ) ,
301
322
(table_name ,
302
323
schema or self .default_schema_name ,
303
324
r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -336,7 +357,7 @@ def result_fun(result):
336
357
return set (rows [0 ] if rows else [])
337
358
338
359
pk_result = engine .exec_driver_sql (
339
- query ,
360
+ self . _format_query ( query ) ,
340
361
(table_name , schema or self .default_schema_name )
341
362
)
342
363
pks = result_fun (pk_result )
@@ -377,6 +398,17 @@ def has_ilike_operator(self):
377
398
server_version_info = self .server_version_info
378
399
return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
379
400
401
+ def _format_query (self , query ):
402
+ """
403
+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
404
+ the paramstyle is not `qmark`, but `pyformat`.
405
+
406
+ TODO: Review: Is it legit and sane? Are there alternatives?
407
+ """
408
+ if self .paramstyle == "pyformat" :
409
+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
410
+ return query
411
+
380
412
381
413
class DateTrunc (functions .GenericFunction ):
382
414
name = "date_trunc"
0 commit comments