From 2334739ed2c09f01cfd00ce86a48129c534b1cdf Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 26 Feb 2025 17:07:42 -0500 Subject: [PATCH] feature: default ports for SSH tunnel --- superset/commands/database/ssh_tunnel/create.py | 5 ++++- superset/commands/database/ssh_tunnel/update.py | 5 ++++- superset/utils/ssh_tunnel.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 89e607ba67aed..9e9161ea51f0c 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -33,6 +33,7 @@ from superset.extensions import event_logger from superset.models.core import Database from superset.utils.decorators import on_error, transaction +from superset.utils.ssh_tunnel import get_default_port logger = logging.getLogger(__name__) @@ -72,7 +73,9 @@ def validate(self) -> None: "private_key_password" ) url = make_url_safe(self._database.sqlalchemy_uri) - if not url.port: + backend = url.get_backend_name() + port = url.port or get_default_port(backend) + if not port: raise SSHTunnelDatabasePortError() if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py index b2fa416bd597e..763d36e89a0b3 100644 --- a/superset/commands/database/ssh_tunnel/update.py +++ b/superset/commands/database/ssh_tunnel/update.py @@ -32,6 +32,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.utils.decorators import on_error, transaction +from superset.utils.ssh_tunnel import get_default_port logger = logging.getLogger(__name__) @@ -75,5 +76,7 @@ def validate(self) -> None: raise SSHTunnelInvalidError( exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] ) - if not url.port: + backend = url.get_backend_name() + port = url.port or get_default_port(backend) + if not port: raise SSHTunnelDatabasePortError() diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py index 8421350f8c140..1471d54f4b124 100644 --- a/superset/utils/ssh_tunnel.py +++ b/superset/utils/ssh_tunnel.py @@ -20,6 +20,13 @@ from superset.constants import PASSWORD_MASK from superset.databases.ssh_tunnel.models import SSHTunnel +DEFAULT_PORTS: dict[str, int] = { + "postgresql": 5432, + "mysql": 3306, + "oracle": 1521, + "mssql": 1433, +} + def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]: if ssh_tunnel.pop("password", None) is not None: @@ -41,3 +48,10 @@ def unmask_password_info( if ssh_tunnel.get("private_key_password") == PASSWORD_MASK: ssh_tunnel["private_key_password"] = model.private_key_password return ssh_tunnel + + +def get_default_port(backend: str) -> int | None: + """ + Get the default port for the given backend. + """ + return DEFAULT_PORTS.get(backend)