From 0a6df41d1ceb8e5ceddb3097690ae1d0e3ea7e68 Mon Sep 17 00:00:00 2001 From: AD Date: Fri, 20 Jun 2025 20:59:02 +0530 Subject: [PATCH] Add type annotations to snowflake.connector.connect --- src/snowflake/connector/__init__.py | 19 ++++++++++++++++--- test/unit/test_type_check.py | 13 +++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 test/unit/test_type_check.py diff --git a/src/snowflake/connector/__init__.py b/src/snowflake/connector/__init__.py index 41b5288ac7..ce570d3072 100644 --- a/src/snowflake/connector/__init__.py +++ b/src/snowflake/connector/__init__.py @@ -45,13 +45,26 @@ from .log_configuration import EasyLoggingConfigPython from .version import VERSION +from typing import TypeVar, ParamSpec, Unpack + +P = ParamSpec("P") +T = TypeVar("T", bound=SnowflakeConnection) + logging.getLogger(__name__).addHandler(NullHandler()) setup_external_libraries() - @wraps(SnowflakeConnection.__init__) -def Connect(**kwargs) -> SnowflakeConnection: - return SnowflakeConnection(**kwargs) +def connect( + __cls: type[T] = SnowflakeConnection, + /, + *args: P.args, + **kwargs: Unpack[P.kwargs] +) -> T: + return __cls(*args, **kwargs) + +# @wraps(SnowflakeConnection.__init__) +# def Connect(**kwargs) -> SnowflakeConnection: +# return SnowflakeConnection(**kwargs) connect = Connect diff --git a/test/unit/test_type_check.py b/test/unit/test_type_check.py new file mode 100644 index 0000000000..84f7976505 --- /dev/null +++ b/test/unit/test_type_check.py @@ -0,0 +1,13 @@ +import snowflake.connector as conn + +c = conn.connect( + user="user", + password="pass", + account="account" +) + +invalid = conn.connect( + user="user", + password=123, + account="account" +)