Skip to content

Commit

Permalink
Support base64-encoded DER private keys (#2134)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-eworoshow authored Jan 13, 2025
1 parent 55f831e commit 4b3ded1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 5 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Added a feature to limit the sizes of IO-bound ThreadPoolExecutors during PUT and GET commands.
- Adding support for the new PAT authentication method.
- Updated README.md to include instructions on how to verify package signatures using `cosign`.
- Updated the log level for cursor's chunk rowcount from INFO to DEBUG
- Updated the log level for cursor's chunk rowcount from INFO to DEBUG.
- Added a feature to verify if the connection is still good enough to send queries over.
- Added support for base64-encoded DER private key strings in the `private_key` authentication type.

- v3.12.4(December 3,2024)
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.
Expand Down
15 changes: 13 additions & 2 deletions src/snowflake/connector/auth/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class AuthByKeyPair(AuthByPlugin):

def __init__(
self,
private_key: bytes | RSAPrivateKey,
private_key: bytes | str | RSAPrivateKey,
lifetime_in_seconds: int = LIFETIME,
**kwargs,
) -> None:
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
).total_seconds()
)

self._private_key: bytes | RSAPrivateKey | None = private_key
self._private_key: bytes | str | RSAPrivateKey | None = private_key
self._jwt_token = ""
self._jwt_token_exp = 0
self._lifetime = timedelta(
Expand Down Expand Up @@ -105,6 +105,17 @@ def prepare(

now = datetime.now(timezone.utc).replace(tzinfo=None)

if isinstance(self._private_key, str):
try:
self._private_key = base64.b64decode(self._private_key)
except Exception as e:
raise ProgrammingError(
msg=f"Failed to decode private key: {e}\nPlease provide a valid "
"unencrypted rsa private key in base64-encoded DER format as a "
"str object",
errno=ER_INVALID_PRIVATE_KEY,
)

if isinstance(self._private_key, bytes):
try:
private_key = load_der_private_key(
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _get_private_bytes_from_file(
"backoff_policy": (DEFAULT_BACKOFF_POLICY, Callable),
"passcode_in_password": (False, bool), # Snowflake MFA
"passcode": (None, (type(None), str)), # Snowflake MFA
"private_key": (None, (type(None), bytes, RSAPrivateKey)),
"private_key": (None, (type(None), bytes, str, RSAPrivateKey)),
"private_key_file": (None, (type(None), str)),
"private_key_file_pwd": (None, (type(None), str, bytes)),
"token": (None, (type(None), str)), # OAuth/JWT/PAT Token
Expand Down
6 changes: 6 additions & 0 deletions test/integ/test_key_pair_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import base64
import uuid
from datetime import datetime, timedelta, timezone
from os import path
Expand Down Expand Up @@ -126,6 +127,11 @@ def fin():
with snowflake.connector.connect(**db_config) as _:
pass

# Ensure the base64-encoded version also works
db_config["private_key"] = base64.b64encode(private_key_der)
with snowflake.connector.connect(**db_config) as _:
pass


@pytest.mark.skipolddriver
def test_multiple_key_pair(is_public_test, request, conn_cnx, db_parameters):
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_auth_keypair_bad_type():
class Bad:
pass

for bad_private_key in ("abcd", 1234, Bad()):
for bad_private_key in (1234, Bad()):
auth_instance = AuthByKeyPair(private_key=bad_private_key)
with raises(TypeError) as ex:
auth_instance.prepare(account=account, user=user)
Expand Down

0 comments on commit 4b3ded1

Please sign in to comment.