Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1825621 OAuth code flow support #2135

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Updated the log level for cursor's chunk rowcount from INFO to DEBUG.
- Added a feature to verify if the connection is still good enough to send queries over.
- Added support for base64-encoded DER private key strings in the `private_key` authentication type.
- Added support for OAuth authorization code flow.

- v3.12.4(December 3,2024)
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .idtoken import AuthByIdToken
from .keypair import AuthByKeyPair
from .oauth import AuthByOAuth
from .oauth_code import AuthByOauthCode
from .okta import AuthByOkta
from .pat import AuthByPAT
from .usrpwdmfa import AuthByUsrPwdMfa
Expand All @@ -20,6 +21,7 @@
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
AuthByOauthCode,
AuthByOkta,
AuthByUsrPwdMfa,
AuthByWebBrowser,
Expand All @@ -34,6 +36,7 @@
"AuthByKeyPair",
"AuthByPAT",
"AuthByOAuth",
"AuthByOauthCode",
"AuthByOkta",
"AuthByUsrPwdMfa",
"AuthByWebBrowser",
Expand Down
128 changes: 128 additions & 0 deletions src/snowflake/connector/auth/_http_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import logging
import os
import select
import socket
import time

from ..compat import IS_WINDOWS
from ..errorcode import ER_NO_HOSTNAME_FOUND
from ..errors import OperationalError

logger = logging.getLogger(__name__)


class AuthHttpServer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we have this to avoid taking additional dependencies in the driver?

(As we move to introduce additional forms of CSP-specific authentication I suspect we'll want to introduce dependencies on the various CSP libraries, which is in part why I ask.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I copied this code all from externalbrowser auth. So it's tried and tested given that I have made no mistakes. I didn't feel like I needed to introduce any dependencies for this OAuth implementation and I'm not sure I'd have time to really investigate different options anymore. I'd like to leave this up to the new owners of the library if you don't mind

"""Simple HTTP server to receive callbacks through for auth purposes."""

def __init__(
self,
hostname: str = "localhost",
sfc-gh-eworoshow marked this conversation as resolved.
Show resolved Hide resolved
buf_size: int = 16384,
) -> None:
self.buf_size = buf_size
self._socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
if IS_WINDOWS:
logger.warning(
"Configuration SNOWFLAKE_AUTH_SOCKET_REUSE_PORT is not available in Windows. Ignoring."
)
else:
self._socket_connection.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT, 1
)

try:
self._socket_connection.bind(
(
os.getenv("SF_AUTH_SOCKET_ADDR", hostname),
int(os.getenv("SF_AUTH_SOCKET_PORT", 0)),
)
)
except socket.gaierror as ex:
if ex.args[0] == socket.EAI_NONAME and hostname == "localhost":
raise OperationalError(
msg="localhost is not found. Ensure /etc/hosts has "
sfc-gh-eworoshow marked this conversation as resolved.
Show resolved Hide resolved
"localhost entry.",
errno=ER_NO_HOSTNAME_FOUND,
)
raise
try:
self._socket_connection.listen(0) # no backlog
self.port = self._socket_connection.getsockname()[1]
except Exception:
self._socket_connection.close()

def receive_block(
self,
max_attempts: int = 15,
) -> tuple[list[str], socket.socket]:
"""Receive a message with a maximum attempt count, blocking."""
socket_client = None
while True:
try:
attempts = 0
raw_data = bytearray()

msg_dont_wait = (
os.getenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", "false").lower()
== "true"
)
if IS_WINDOWS:
if msg_dont_wait:
logger.warning(
"Configuration SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT is not available in Windows. Ignoring."
)
msg_dont_wait = False

# when running in a containerized environment, socket_client.recv ocassionally returns an empty byte array
# an immediate successive call to socket_client.recv gets the actual data
while len(raw_data) == 0 and attempts < max_attempts:
attempts += 1
read_sockets, _write_sockets, _exception_sockets = select.select(
[self._socket_connection], [], []
)

if read_sockets[0] is not None:
# Receive the data in small chunks and retransmit it
socket_client, _ = self._socket_connection.accept()

try:
if msg_dont_wait:
# WSL containerized environment sometimes causes socket_client.recv to hang indefinetly
# To avoid this, passing the socket.MSG_DONTWAIT flag which raises BlockingIOError if
# operation would block
logger.debug(
"Calling socket_client.recv with MSG_DONTWAIT flag due to SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT env var"
)
raw_data = socket_client.recv(
BUF_SIZE, socket.MSG_DONTWAIT
)
else:
raw_data = socket_client.recv(self.buf_size)

except BlockingIOError:
logger.debug(
"BlockingIOError raised from socket.recv while attempting to retrieve callback request"
)
if attempts < max_attempts:
sleep_time = 0.25
logger.debug(
f"Waiting {sleep_time} seconds before trying again"
)
time.sleep(sleep_time)
else:
logger.debug("Exceeded retry count")

assert socket_client is not None
return raw_data.decode("utf-8").split("\r\n"), socket_client
except Exception:
if socket_client is not None:
socket_client.shutdown(socket.SHUT_RDWR)
socket_client.close()
Loading
Loading