diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 3092f944..5f0fdff0 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for the `--draft` option when deploying content, this allows to deploy a new bundle for the content without exposing it as a the activated one. +- Improved support for Posit Connect deployments + hosted in Snowpark Container Services. ### Fixed diff --git a/docs/overrides/partials/header.html b/docs/overrides/partials/header.html index bbead205..90bd21de 100644 --- a/docs/overrides/partials/header.html +++ b/docs/overrides/partials/header.html @@ -80,7 +80,7 @@ {% endif %}
diff --git a/rsconnect/api.py b/rsconnect/api.py index a095d273..c8dac161 100644 --- a/rsconnect/api.py +++ b/rsconnect/api.py @@ -76,7 +76,7 @@ TaskStatusV1, UserRecord, ) -from .snowflake import generate_jwt, get_connection_parameters +from .snowflake import generate_jwt, get_parameters from .timeouts import get_task_timeout, get_task_timeout_help_message if TYPE_CHECKING: @@ -260,27 +260,53 @@ def __init__( self.bootstrap_jwt = None def token_endpoint(self) -> str: - params = get_connection_parameters(self.snowflake_connection_name) + params = get_parameters(self.snowflake_connection_name) if params is None: raise RSConnectException("No Snowflake connection found.") return "https://{}.snowflakecomputing.com/".format(params["account"]) - def fmt_payload(self) -> str: - params = get_connection_parameters(self.snowflake_connection_name) + def fmt_payload(self): + params = get_parameters(self.snowflake_connection_name) if params is None: raise RSConnectException("No Snowflake connection found.") - spcs_url = urlparse(self.url) - scope = "session:role:{} {}".format(params["role"], spcs_url.netloc) - jwt = generate_jwt(self.snowflake_connection_name) - grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer" - - payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type} - payload = urlencode(payload) - return payload + authenticator = params.get("authenticator") + if authenticator == "SNOWFLAKE_JWT": + spcs_url = urlparse(self.url) + scope = ( + "session:role:{} {}".format(params["role"], spcs_url.netloc) if params.get("role") else spcs_url.netloc + ) + jwt = generate_jwt(self.snowflake_connection_name) + grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type} + payload = urlencode(payload) + return { + "body": payload, + "headers": {"Content-Type": "application/x-www-form-urlencoded"}, + "path": "/oauth/token", + } + elif authenticator == "oauth": + payload = { + "data": { + "AUTHENTICATOR": "OAUTH", + "TOKEN": params["token"], + } + } + return { + "body": payload, + "headers": { + "Content-Type": "application/json", + "Authorization": "Bearer %s" % params["token"], + "X-Snowflake-Authorization-Token-Type": "OAUTH", + }, + "path": "/session/v1/login-request", + } + else: + raise NotImplementedError("Unsupported authenticator for SPCS Connect: %s" % authenticator) def exchange_token(self) -> str: try: @@ -288,12 +314,8 @@ def exchange_token(self) -> str: payload = self.fmt_payload() response = server.request( - method="POST", - path="/oauth/token", - body=payload, - headers={"Content-Type": "application/x-www-form-urlencoded"}, + method="POST", **payload # type: ignore[arg-type] # fmt_payload returns a dict with body and headers ) - response = cast(HTTPResponse, response) # borrowed from AbstractRemoteServer.handle_bad_response @@ -313,10 +335,24 @@ def exchange_token(self) -> str: if not response.response_body: raise RSConnectException("Token exchange returned empty response") - # Ensure we return a string + # Ensure response body is decoded to string on the object if isinstance(response.response_body, bytes): - return response.response_body.decode("utf-8") - return response.response_body + response.response_body = response.response_body.decode("utf-8") + + # Try to parse as JSON first + try: + import json + + json_data = json.loads(response.response_body) + # If it's JSON, extract the token from data.token + if isinstance(json_data, dict) and "data" in json_data and "token" in json_data["data"]: + return json_data["data"]["token"] + else: + # JSON format doesn't match expected structure, return raw response + return response.response_body + except (json.JSONDecodeError, ValueError): + # Not JSON, return the raw response body + return response.response_body except RSConnectException as e: raise RSConnectException(f"Failed to exchange Snowflake token: {str(e)}") from e diff --git a/rsconnect/snowflake.py b/rsconnect/snowflake.py index 0701d735..8fbf9d7d 100644 --- a/rsconnect/snowflake.py +++ b/rsconnect/snowflake.py @@ -40,27 +40,43 @@ def list_connections() -> List[Dict[str, Any]]: raise RSConnectException("Could not list snowflake connections.") -def get_connection_parameters(name: Optional[str] = None) -> Optional[Dict[str, Any]]: +def get_parameters(name: Optional[str] = None) -> Dict[str, Any]: + """Get Snowflake connection parameters. + Args: + name: The name of the connection to retrieve. If None, returns the default connection. + + Returns: + A dictionary of connection parameters. + """ + try: + from snowflake.connector.config_manager import CONFIG_MANAGER + except ImportError: + raise RSConnectException("snowflake-cli is not installed.") + try: + connections = CONFIG_MANAGER["connections"] + if not isinstance(connections, dict): + raise TypeError("connections is not a dictionary") + + if name is None: + def_connection_name = CONFIG_MANAGER["default_connection_name"] + if not isinstance(def_connection_name, str): + raise TypeError("default_connection_name is not a string") + params = connections[def_connection_name] + else: + params = connections[name] - connection_list = list_connections() - # return parameters for default connection if configured - # otherwise return named connection + if not isinstance(params, dict): + raise TypeError("connection parameters is not a dictionary") - if not connection_list: - raise RSConnectException("No Snowflake connections found.") + return {str(k): v for k, v in params.items()} - try: - if not name: - return next((x["parameters"] for x in connection_list if x.get("is_default")), None) - else: - return next((x["parameters"] for x in connection_list if x.get("connection_name") == name)) - except StopIteration: - raise RSConnectException(f"No Snowflake connection found with name '{name}'.") + except (KeyError, AttributeError) as e: + raise RSConnectException(f"Could not get Snowflake connection: {e}") def generate_jwt(name: Optional[str] = None) -> str: - _ = get_connection_parameters(name) + _ = get_parameters(name) connection_name = "" if name is None else name try: diff --git a/tests/test_api.py b/tests/test_api.py index 910efb09..1ea3bbd9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -526,41 +526,48 @@ def test_token_endpoint(self, mock_token_endpoint): endpoint = server.token_endpoint() assert endpoint == "https://example.snowflakecomputing.com/" - @patch("rsconnect.api.get_connection_parameters") - def test_token_endpoint_with_account(self, mock_get_connection_parameters): + @patch("rsconnect.api.get_parameters") + def test_token_endpoint_with_account(self, mock_get_parameters): server = SPCSConnectServer("https://spcs.example.com", "example_connection") - mock_get_connection_parameters.return_value = {"account": "test_account"} + mock_get_parameters.return_value = {"account": "test_account"} endpoint = server.token_endpoint() assert endpoint == "https://test_account.snowflakecomputing.com/" - mock_get_connection_parameters.assert_called_once_with("example_connection") + mock_get_parameters.assert_called_once_with("example_connection") - @patch("rsconnect.api.get_connection_parameters") - def test_token_endpoint_with_none_params(self, mock_get_connection_parameters): + @patch("rsconnect.api.get_parameters") + def test_token_endpoint_with_none_params(self, mock_get_parameters): server = SPCSConnectServer("https://spcs.example.com", "example_connection") - mock_get_connection_parameters.return_value = None + mock_get_parameters.return_value = None with pytest.raises(RSConnectException, match="No Snowflake connection found."): server.token_endpoint() - @patch("rsconnect.api.get_connection_parameters") - def test_fmt_payload(self, mock_get_connection_parameters): + @patch("rsconnect.api.get_parameters") + def test_fmt_payload(self, mock_get_parameters): server = SPCSConnectServer("https://spcs.example.com", "example_connection") - mock_get_connection_parameters.return_value = {"account": "test_account", "role": "test_role"} + mock_get_parameters.return_value = { + "account": "test_account", + "role": "test_role", + "authenticator": "SNOWFLAKE_JWT", + } with patch("rsconnect.api.generate_jwt") as mock_generate_jwt: mock_generate_jwt.return_value = "mocked_jwt" payload = server.fmt_payload() - assert "scope=session%3Arole%3Atest_role+spcs.example.com" in payload - assert "assertion=mocked_jwt" in payload - assert "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" in payload + assert ( + payload["body"] + == "scope=session%3Arole%3Atest_role+spcs.example.com&assertion=mocked_jwt&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" # noqa + ) + assert payload["headers"] == {"Content-Type": "application/x-www-form-urlencoded"} + assert payload["path"] == "/oauth/token" - mock_get_connection_parameters.assert_called_once_with("example_connection") + mock_get_parameters.assert_called_once_with("example_connection") mock_generate_jwt.assert_called_once_with("example_connection") - @patch("rsconnect.api.get_connection_parameters") - def test_fmt_payload_with_none_params(self, mock_get_connection_parameters): + @patch("rsconnect.api.get_parameters") + def test_fmt_payload_with_none_params(self, mock_get_parameters): server = SPCSConnectServer("https://spcs.example.com", "example_connection") - mock_get_connection_parameters.return_value = None + mock_get_parameters.return_value = None with pytest.raises(RSConnectException, match="No Snowflake connection found."): server.fmt_payload() @@ -579,7 +586,11 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc # Mock the token endpoint and payload mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" - mock_fmt_payload.return_value = "mocked_payload" + mock_fmt_payload.return_value = { + "body": "mocked_payload_body", + "headers": {"Content-Type": "application/x-www-form-urlencoded"}, + "path": "/oauth/token", + } # Call the method result = server.exchange_token() @@ -589,9 +600,9 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc mock_http_server.assert_called_once_with(url="https://example.snowflakecomputing.com/") mock_server_instance.request.assert_called_once_with( method="POST", - path="/oauth/token", - body="mocked_payload", + body="mocked_payload_body", headers={"Content-Type": "application/x-www-form-urlencoded"}, + path="/oauth/token", ) @patch("rsconnect.api.HTTPServer") @@ -610,7 +621,11 @@ def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint # Mock the token endpoint and payload mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" - mock_fmt_payload.return_value = "mocked_payload" + mock_fmt_payload.return_value = { + "body": "mocked_payload_body", + "headers": {"Content-Type": "application/x-www-form-urlencoded"}, + "path": "/oauth/token", + } # Call the method and verify it raises the expected exception with pytest.raises(RSConnectException, match="Failed to exchange Snowflake token"): @@ -631,7 +646,11 @@ def test_exchange_token_empty_response(self, mock_fmt_payload, mock_token_endpoi # Mock the token endpoint and payload mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" - mock_fmt_payload.return_value = "mocked_payload" + mock_fmt_payload.return_value = { + "body": "mocked_payload_body", + "headers": {"Content-Type": "application/x-www-form-urlencoded"}, + "path": "/oauth/token", + } # Call the method and verify it raises the expected exception with pytest.raises( diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py index 2f62fcf7..763eeecf 100644 --- a/tests/test_snowflake.py +++ b/tests/test_snowflake.py @@ -11,7 +11,7 @@ from rsconnect.snowflake import ( ensure_snow_installed, generate_jwt, - get_connection_parameters, + get_parameters, list_connections, ) @@ -187,49 +187,93 @@ def mock_snow(*args): assert connections[1]["is_default"] is True -def test_get_connection_noname_default(monkeypatch: MonkeyPatch): - # Test that get_connection_parameters() returns parameters from +def test_get_parameters_noname_default(monkeypatch: MonkeyPatch): + # Test that get_parameters() returns parameters from # the default connection when no name is provided - monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) - monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) + mock_config_manager = { + "default_connection_name": "prod", + "connections": {"prod": {"account": "example-prod-acct", "role": "DEVELOPER"}}, + } - connection = get_connection_parameters() + # Mock the import inside get_parameters + def mock_import(name, *args, **kwargs): + if name == "snowflake.connector.config_manager": + # Create a mock module with CONFIG_MANAGER + mock_module = type("mock_module", (), {}) + mock_module.CONFIG_MANAGER = mock_config_manager + return mock_module + return original_import(name, *args, **kwargs) - assert connection["account"] == "example-prod-acct" - assert connection["role"] == "DEVELOPER" + monkeypatch.setattr("builtins.__import__", mock_import) + params = get_parameters() -def test_get_connection_named(monkeypatch: MonkeyPatch): - # Test that get_connection_parameters() returns the specified connection when a name is provided + assert params["account"] == "example-prod-acct" + assert params["role"] == "DEVELOPER" - monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) - monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) - connection = get_connection_parameters("dev") +def test_get_parameters_named(monkeypatch: MonkeyPatch): + # Test that get_parameters() returns the specified connection when a name is provided + + mock_config_manager = {"connections": {"dev": {"account": "example-dev-acct", "role": "ACCOUNTADMIN"}}} + + # Mock the import inside get_parameters + def mock_import(name, *args, **kwargs): + if name == "snowflake.connector.config_manager": + # Create a mock module with CONFIG_MANAGER + mock_module = type("mock_module", (), {}) + mock_module.CONFIG_MANAGER = mock_config_manager + return mock_module + return original_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", mock_import) + + params = get_parameters("dev") # Should return the connection with the specified name - assert connection["account"] == "example-dev-acct" - assert connection["role"] == "ACCOUNTADMIN" + assert params["account"] == "example-dev-acct" + assert params["role"] == "ACCOUNTADMIN" + + +def test_get_parameters_errs_if_none(monkeypatch: MonkeyPatch): + # Test that get_parameters() raises an exception when no matching connection is found + # Test with invalid default connection + mock_config_manager = {"default_connection_name": "non_existent", "connections": {}} -def test_get_connection_errs_if_none(monkeypatch: MonkeyPatch): - # Test that get_connection_parameters() raises an exception when no matching connection is found + # Mock the import inside get_parameters + def mock_import(name, *args, **kwargs): + if name == "snowflake.connector.config_manager": + # Create a mock module with CONFIG_MANAGER + mock_module = type("mock_module", (), {}) + mock_module.CONFIG_MANAGER = mock_config_manager + return mock_module + return original_import(name, *args, **kwargs) - # Test with empty connections list - monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: []) - monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) + monkeypatch.setattr("builtins.__import__", mock_import) with pytest.raises(RSConnectException) as excinfo: - get_connection_parameters() - assert "No Snowflake connections found" in str(excinfo.value) + get_parameters() + assert "Could not get Snowflake connection" in str(excinfo.value) # Test with connections but non-existent name - monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + mock_config_manager = {"connections": {"prod": {"account": "example-prod-acct"}}} + + # Update the mock with new config + def mock_import(name, *args, **kwargs): + if name == "snowflake.connector.config_manager": + # Create a mock module with CONFIG_MANAGER + mock_module = type("mock_module", (), {}) + mock_module.CONFIG_MANAGER = mock_config_manager + return mock_module + return original_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", mock_import) with pytest.raises(RSConnectException) as excinfo: - get_connection_parameters("nexiste") - assert "No Snowflake connection found with name 'nexiste'" in str(excinfo.value) + get_parameters("nexiste") + assert "Could not get Snowflake connection" in str(excinfo.value) def test_generate_jwt(monkeypatch: MonkeyPatch): @@ -263,7 +307,9 @@ def mock_snow(*args): ) monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) - monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + + # Mock get_parameters to return empty dict (we just need it not to fail) + monkeypatch.setattr("rsconnect.snowflake.get_parameters", lambda name=None: {}) # Case 1: Test with default connection (no name parameter) jwt = generate_jwt() @@ -274,9 +320,16 @@ def mock_snow(*args): assert jwt == "header.payload.signature" # Case 3: Test with an invalid connection name + def mock_get_parameters_with_error(name=None): + if name == "nexiste": + raise RSConnectException(f"Could not get Snowflake connection: Key '{name}' does not exist.") + return {} + + monkeypatch.setattr("rsconnect.snowflake.get_parameters", mock_get_parameters_with_error) + with pytest.raises(RSConnectException) as excinfo: generate_jwt("nexiste") - assert "No Snowflake connection found with name 'nexiste'" in str(excinfo.value) + assert "Could not get Snowflake connection" in str(excinfo.value) def test_generate_jwt_command_failure(monkeypatch: MonkeyPatch): @@ -288,7 +341,7 @@ def mock_snow(*args): ) monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) - monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + monkeypatch.setattr("rsconnect.snowflake.get_parameters", lambda name=None: {}) with pytest.raises(RSConnectException) as excinfo: generate_jwt() @@ -306,7 +359,7 @@ def mock_snow(*args): return MockProcessInvalidJSON() monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) - monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + monkeypatch.setattr("rsconnect.snowflake.get_parameters", lambda name=None: {}) with pytest.raises(RSConnectException) as excinfo: generate_jwt() @@ -324,7 +377,7 @@ def mock_snow(*args): return MockProcessNoMessage() monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) - monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + monkeypatch.setattr("rsconnect.snowflake.get_parameters", lambda name=None: {}) with pytest.raises(RSConnectException) as excinfo: generate_jwt()