Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: default ports for SSH tunnel #32403

Merged
merged 2 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +73,9 @@ def validate(self) -> None:
"private_key_password"
)
url = make_url_safe(self._database.sqlalchemy_uri)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,5 +76,7 @@ def validate(self) -> None:
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
5 changes: 4 additions & 1 deletion superset/extensions/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ def create_tunnel(
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> sshtunnel.SSHTunnelForwarder:
from superset.utils.ssh_tunnel import get_default_port

url = make_url_safe(sqlalchemy_database_uri)
backend = url.get_backend_name()
params = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (url.host, url.port),
"remote_bind_address": (url.host, url.port or get_default_port(backend)),
"local_bind_address": (self.local_bind_address,),
"debug_level": logging.getLogger("flask_appbuilder").level,
}
Expand Down
14 changes: 14 additions & 0 deletions superset/utils/ssh_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from superset.constants import PASSWORD_MASK
from superset.databases.ssh_tunnel.models import SSHTunnel

DEFAULT_PORTS: dict[str, int] = {
"postgresql": 5432,
"mysql": 3306,
"oracle": 1521,
"mssql": 1433,
}


def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
Expand All @@ -41,3 +48,10 @@ def unmask_password_info(
if ssh_tunnel.get("private_key_password") == PASSWORD_MASK:
ssh_tunnel["private_key_password"] = model.private_key_password
return ssh_tunnel


def get_default_port(backend: str) -> int | None:
"""
Get the default port for the given backend.
"""
return DEFAULT_PORTS.get(backend)
121 changes: 115 additions & 6 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,18 @@ def test_create_database_with_ssh_tunnel(
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_missing_port_raises_error(
def test_create_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
Database API: Test create with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
Expand All @@ -369,13 +369,58 @@ def test_create_database_with_missing_port_raises_error(
"username": "foo",
"password": "bar",
}

database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
Expand Down Expand Up @@ -459,7 +504,71 @@ def test_update_database_with_ssh_tunnel(
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_missing_port_raises_error(
def test_update_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test update Database with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201

uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200

model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
Expand All @@ -477,7 +586,7 @@ def test_update_database_with_missing_port_raises_error(
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
Expand Down
57 changes: 51 additions & 6 deletions tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@


import pytest
from sqlalchemy.orm.session import Session

from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)


def test_create_ssh_tunnel_command() -> None:
def test_create_ssh_tunnel_command(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
Expand All @@ -49,12 +53,15 @@ def test_create_ssh_tunnel_command() -> None:
assert isinstance(result, SSHTunnel)


def test_create_ssh_tunnel_command_invalid_params() -> None:
def test_create_ssh_tunnel_command_invalid_params(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
Expand All @@ -76,12 +83,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")


def test_create_ssh_tunnel_command_no_port() -> None:
def test_create_ssh_tunnel_command_no_port(session: Session) -> None:
"""
Test that SSH Tunnel can be created without explicit port but with a default one.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
Expand All @@ -94,6 +108,37 @@ def test_create_ssh_tunnel_command_no_port() -> None:
"password": "bar",
}

result = CreateSSHTunnelCommand(database, properties).run()

assert result is not None
assert isinstance(result, SSHTunnel)


def test_create_ssh_tunnel_command_no_port_no_default(session: Session) -> None:
"""
Test that error is raised when creating SSH Tunnel without explicit/default ports.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="weird+db://u:p@localhost/db",
)

properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}

command = CreateSSHTunnelCommand(database, properties)

with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,37 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
"""
Test that SSH Tunnel can be updated without explicit port but with a default one.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel

result = DatabaseDAO.get_ssh_tunnel(1)

assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address

update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()

result = DatabaseDAO.get_ssh_tunnel(1)

assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address


@pytest.mark.parametrize(
"session_with_data", ["weird+db://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port_no_default(session_with_data: Session) -> None:
"""
Test that error is raised when updating SSH Tunnel without explicit/default ports.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
Expand Down
Loading