3
3
from contextlib import contextmanager
4
4
from dataclasses import dataclass
5
5
from itertools import chain , repeat
6
- from typing import Callable , Dict , Mapping , Optional
6
+ from typing import Any , Callable , Dict , Mapping , Optional , Tuple
7
7
8
+ import agate
8
9
import dbt .exceptions
9
10
import pyodbc
10
11
from azure .core .credentials import AccessToken
17
18
)
18
19
from dbt .adapters .base import Credentials
19
20
from dbt .adapters .sql import SQLConnectionManager
20
- from dbt .contracts .connection import AdapterResponse
21
+ from dbt .clients .agate_helper import empty_table
22
+ from dbt .contracts .connection import AdapterResponse , Connection , ConnectionState
21
23
from dbt .events import AdapterLogger
22
24
23
25
from dbt .adapters .sqlserver import __version__
@@ -46,6 +48,7 @@ class SQLServerCredentials(Credentials):
46
48
authentication : Optional [str ] = "sql"
47
49
encrypt : Optional [bool ] = False
48
50
trust_cert : Optional [bool ] = False
51
+ retries : int = 1
49
52
50
53
_ALIASES = {
51
54
"user" : "UID" ,
@@ -287,7 +290,6 @@ def exception_handler(self, sql):
287
290
self .release ()
288
291
except pyodbc .Error :
289
292
logger .debug ("Failed to release connection!" )
290
- pass
291
293
292
294
raise dbt .exceptions .DatabaseException (str (e ).strip ()) from e
293
295
@@ -304,69 +306,73 @@ def exception_handler(self, sql):
304
306
raise dbt .exceptions .RuntimeException (e )
305
307
306
308
@classmethod
307
- def open (cls , connection ) :
309
+ def open (cls , connection : Connection ) -> Connection :
308
310
309
- if connection .state == "open" :
311
+ if connection .state == ConnectionState . OPEN :
310
312
logger .debug ("Connection is already open, skipping open." )
311
313
return connection
312
314
313
- credentials = connection .credentials
315
+ credentials = cls . get_credentials ( connection .credentials )
314
316
315
- try :
316
- con_str = []
317
- con_str .append (f"DRIVER={{{ credentials .driver } }}" )
318
-
319
- if "\\ " in credentials .host :
317
+ con_str = [f"DRIVER={{{ credentials .driver } }}" ]
320
318
321
- # If there is a backslash \ in the host name, the host is a
322
- # SQL Server named instance. In this case then port number has to be omitted.
323
- con_str .append (f"SERVER={ credentials .host } " )
324
- else :
325
- con_str .append (f"SERVER={ credentials .host } ,{ credentials .port } " )
319
+ if "\\ " in credentials .host :
326
320
327
- con_str .append (f"Database={ credentials .database } " )
321
+ # If there is a backslash \ in the host name, the host is a
322
+ # SQL Server named instance. In this case then port number has to be omitted.
323
+ con_str .append (f"SERVER={ credentials .host } " )
324
+ else :
325
+ con_str .append (f"SERVER={ credentials .host } ,{ credentials .port } " )
328
326
329
- type_auth = getattr ( credentials , "authentication" , "sql " )
327
+ con_str . append ( f"Database= { credentials . database } " )
330
328
331
- if "ActiveDirectory" in type_auth :
332
- con_str .append (f"Authentication={ credentials .authentication } " )
329
+ type_auth = getattr (credentials , "authentication" , "sql" )
333
330
334
- if type_auth == "ActiveDirectoryPassword" :
335
- con_str .append (f"UID={{{ credentials .UID } }}" )
336
- con_str .append (f"PWD={{{ credentials .PWD } }}" )
337
- elif type_auth == "ActiveDirectoryInteractive" :
338
- con_str .append (f"UID={{{ credentials .UID } }}" )
331
+ if "ActiveDirectory" in type_auth :
332
+ con_str .append (f"Authentication={ credentials .authentication } " )
339
333
340
- elif getattr (credentials , "windows_login" , False ):
341
- con_str .append ("trusted_connection=yes" )
342
- elif type_auth == "sql" :
334
+ if type_auth == "ActiveDirectoryPassword" :
343
335
con_str .append (f"UID={{{ credentials .UID } }}" )
344
336
con_str .append (f"PWD={{{ credentials .PWD } }}" )
337
+ elif type_auth == "ActiveDirectoryInteractive" :
338
+ con_str .append (f"UID={{{ credentials .UID } }}" )
339
+
340
+ elif getattr (credentials , "windows_login" , False ):
341
+ con_str .append ("trusted_connection=yes" )
342
+ elif type_auth == "sql" :
343
+ con_str .append (f"UID={{{ credentials .UID } }}" )
344
+ con_str .append (f"PWD={{{ credentials .PWD } }}" )
345
+
346
+ # still confused whether to use "Yes", "yes", "True", or "true"
347
+ # to learn more visit
348
+ # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
349
+ if getattr (credentials , "encrypt" , False ) is True :
350
+ con_str .append ("Encrypt=Yes" )
351
+ if getattr (credentials , "trust_cert" , False ) is True :
352
+ con_str .append ("TrustServerCertificate=Yes" )
345
353
346
- # still confused whether to use "Yes", "yes", "True", or "true"
347
- # to learn more visit
348
- # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
349
- if getattr (credentials , "encrypt" , False ) is True :
350
- con_str .append ("Encrypt=Yes" )
351
- if getattr (credentials , "trust_cert" , False ) is True :
352
- con_str .append ("TrustServerCertificate=Yes" )
354
+ plugin_version = __version__ .version
355
+ application_name = f"dbt-{ credentials .type } /{ plugin_version } "
356
+ con_str .append (f"Application Name={ application_name } " )
353
357
354
- plugin_version = __version__ .version
355
- application_name = f"dbt-{ credentials .type } /{ plugin_version } "
356
- con_str .append (f"Application Name={ application_name } " )
358
+ con_str_concat = ";" .join (con_str )
357
359
358
- con_str_concat = ";" .join (con_str )
360
+ index = []
361
+ for i , elem in enumerate (con_str ):
362
+ if "pwd=" in elem .lower ():
363
+ index .append (i )
359
364
360
- index = []
361
- for i , elem in enumerate (con_str ):
362
- if "pwd=" in elem .lower ():
363
- index .append (i )
365
+ if len (index ) != 0 :
366
+ con_str [index [0 ]] = "PWD=***"
364
367
365
- if len (index ) != 0 :
366
- con_str [index [0 ]] = "PWD=***"
368
+ con_str_display = ";" .join (con_str )
367
369
368
- con_str_display = ";" .join (con_str )
370
+ retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions
371
+ pyodbc .InternalError , # not used according to docs, but defined in PEP-249
372
+ pyodbc .OperationalError ,
373
+ ]
369
374
375
+ def connect ():
370
376
logger .debug (f"Using connection string: { con_str_display } " )
371
377
372
378
attrs_before = get_pyodbc_attrs_before (credentials )
@@ -375,24 +381,19 @@ def open(cls, connection):
375
381
attrs_before = attrs_before ,
376
382
autocommit = True ,
377
383
)
378
-
379
- connection .state = "open"
380
- connection .handle = handle
381
384
logger .debug (f"Connected to db: { credentials .database } " )
385
+ return handle
386
+
387
+ return cls .retry_connection (
388
+ connection ,
389
+ connect = connect ,
390
+ logger = logger ,
391
+ retry_limit = credentials .retries ,
392
+ retryable_exceptions = retryable_exceptions ,
393
+ )
382
394
383
- except pyodbc .Error as e :
384
- logger .debug (f"Could not connect to db: { e } " )
385
-
386
- connection .handle = None
387
- connection .state = "fail"
388
-
389
- raise dbt .exceptions .FailedToConnectException (str (e ))
390
-
391
- return connection
392
-
393
- def cancel (self , connection ):
395
+ def cancel (self , connection : Connection ):
394
396
logger .debug ("Cancel query" )
395
- pass
396
397
397
398
def add_begin_query (self ):
398
399
# return self.add_query('BEGIN TRANSACTION', auto_begin=False)
@@ -402,7 +403,13 @@ def add_commit_query(self):
402
403
# return self.add_query('COMMIT TRANSACTION', auto_begin=False)
403
404
pass
404
405
405
- def add_query (self , sql , auto_begin = True , bindings = None , abridge_sql_log = False ):
406
+ def add_query (
407
+ self ,
408
+ sql : str ,
409
+ auto_begin : bool = True ,
410
+ bindings : Optional [Any ] = None ,
411
+ abridge_sql_log : bool = False ,
412
+ ) -> Tuple [Connection , Any ]:
406
413
407
414
connection = self .get_thread_connection ()
408
415
@@ -435,11 +442,11 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
435
442
return connection , cursor
436
443
437
444
@classmethod
438
- def get_credentials (cls , credentials ) :
445
+ def get_credentials (cls , credentials : SQLServerCredentials ) -> SQLServerCredentials :
439
446
return credentials
440
447
441
448
@classmethod
442
- def get_response (cls , cursor ) -> AdapterResponse :
449
+ def get_response (cls , cursor : Any ) -> AdapterResponse :
443
450
# message = str(cursor.statusmessage)
444
451
message = "OK"
445
452
rows = cursor .rowcount
@@ -456,7 +463,9 @@ def get_response(cls, cursor) -> AdapterResponse:
456
463
rows_affected = rows ,
457
464
)
458
465
459
- def execute (self , sql , auto_begin = True , fetch = False ):
466
+ def execute (
467
+ self , sql : str , auto_begin : bool = True , fetch : bool = False
468
+ ) -> Tuple [AdapterResponse , agate .Table ]:
460
469
_ , cursor = self .add_query (sql , auto_begin )
461
470
response = self .get_response (cursor )
462
471
if fetch :
@@ -466,7 +475,7 @@ def execute(self, sql, auto_begin=True, fetch=False):
466
475
break
467
476
table = self .get_result_from_cursor (cursor )
468
477
else :
469
- table = dbt . clients . agate_helper . empty_table ()
478
+ table = empty_table ()
470
479
# Step through all result sets so we process all errors
471
480
while cursor .nextset ():
472
481
pass
0 commit comments