Skip to content

Commit

Permalink
feature: default ports for SSH tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Feb 26, 2025
1 parent 0042955 commit 2334739
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +73,9 @@ def validate(self) -> None:
"private_key_password"
)
url = make_url_safe(self._database.sqlalchemy_uri)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,5 +76,7 @@ def validate(self) -> None:
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
14 changes: 14 additions & 0 deletions superset/utils/ssh_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from superset.constants import PASSWORD_MASK
from superset.databases.ssh_tunnel.models import SSHTunnel

DEFAULT_PORTS: dict[str, int] = {
"postgresql": 5432,
"mysql": 3306,
"oracle": 1521,
"mssql": 1433,
}


def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
Expand All @@ -41,3 +48,10 @@ def unmask_password_info(
if ssh_tunnel.get("private_key_password") == PASSWORD_MASK:
ssh_tunnel["private_key_password"] = model.private_key_password
return ssh_tunnel


def get_default_port(backend: str) -> int | None:
"""
Get the default port for the given backend.
"""
return DEFAULT_PORTS.get(backend)

0 comments on commit 2334739

Please sign in to comment.