Skip to content

Commit

Permalink
[dagster-ssh] Update to Pythonic resources (#15180)
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow authored Nov 10, 2023
1 parent 525c3cc commit 8978f18
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 81 deletions.
175 changes: 108 additions & 67 deletions python_modules/libraries/dagster-ssh/dagster_ssh/resources.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import getpass
import logging
import os
from io import StringIO
from typing import Optional

import paramiko
from dagster import (
BoolSource,
Field,
Field as DagsterField,
IntSource,
StringSource,
_check as check,
resource,
)
from dagster._config.pythonic_config import ConfigurableResource
from dagster._core.definitions.resource_definition import dagster_maintained_resource
from dagster._core.execution.context.init import InitResourceContext
from dagster._utils import mkdir_p
from dagster._utils.merger import merge_dicts
from paramiko.client import SSHClient
from paramiko.config import SSH_PORT
from pydantic import (
Field,
PrivateAttr,
)
from sshtunnel import SSHTunnelForwarder


Expand All @@ -29,70 +37,99 @@ def key_from_str(key_str):
return result


class SSHResource:
class SSHResource(ConfigurableResource):
"""Resource for ssh remote execution using Paramiko.
ref: https://github.com/paramiko/paramiko
"""

def __init__(
self,
remote_host,
remote_port,
username=None,
password=None,
key_file=None,
key_string=None,
timeout=10,
keepalive_interval=30,
compress=True,
no_host_key_check=True,
allow_host_key_change=False,
logger=None,
):
self.remote_host = check.str_param(remote_host, "remote_host")
self.remote_port = check.opt_int_param(remote_port, "remote_port")
self.username = check.opt_str_param(username, "username")
self.password = check.opt_str_param(password, "password")
self.key_file = check.opt_str_param(key_file, "key_file")
self.timeout = check.opt_int_param(timeout, "timeout")
self.keepalive_interval = check.opt_int_param(keepalive_interval, "keepalive_interval")
self.compress = check.opt_bool_param(compress, "compress")
self.no_host_key_check = check.opt_bool_param(no_host_key_check, "no_host_key_check")
self.log = logger

self.host_proxy = None
remote_host: str = Field(description="Remote host to connect to")
remote_port: Optional[int] = Field(default=None, description="Port of remote host to connect")
username: Optional[str] = Field(default=None, description="Username to connect to remote host")
password: Optional[str] = Field(
default=None, description="Password of the username to connect to remote host"
)
key_file: Optional[str] = Field(
default=None, description="Key file to use to connect to remote host"
)
key_string: Optional[str] = Field(
default=None, description="Key string to use to connect to remote host"
)
timeout: int = Field(
default=10, description="Timeout for the attempt to connect to remote host"
)
keepalive_interval: int = Field(
default=30,
description="Send a keepalive packet to remote host every keepalive_interval seconds",
)
compress: bool = Field(default=True, description="Compress the transport stream")
no_host_key_check: bool = Field(
default=True,
description=(
"If True, the host key will not be verified. This is unsafe and not recommended"
),
)
allow_host_key_change: bool = Field(
default=False,
description="If True, allow connecting to hosts whose host key has changed",
)

_logger: Optional[logging.Logger] = PrivateAttr(default=None)
_host_proxy: Optional[paramiko.ProxyCommand] = PrivateAttr(default=None)
_key_obj: Optional[paramiko.RSAKey] = PrivateAttr(default=None)

def set_logger(self, logger: logging.Logger) -> None:
self._logger = logger

def setup_for_execution(self, context: InitResourceContext) -> None:
self._logger = context.log
self._host_proxy = None

# Create RSAKey object from private key string
self.key_obj = key_from_str(key_string) if key_string is not None else None
self._key_obj = key_from_str(self.key_string) if self.key_string is not None else None

# Auto detecting username values from system
if not self.username:
logger.debug(
"username to ssh to host: %s is not specified. Using system's default provided by"
" getpass.getuser()" % self.remote_host
)
if self._logger:
self._logger.debug(
"username to ssh to host: %s is not specified. Using system's default provided"
" by getpass.getuser()" % self.remote_host
)
self.username = getpass.getuser()

user_ssh_config_filename = os.path.expanduser("~/.ssh/config")
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename, encoding="utf8"))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get("proxycommand"):
self.host_proxy = paramiko.ProxyCommand(host_info.get("proxycommand"))

proxy_command = host_info.get("proxycommand")
if host_info and proxy_command:
self._host_proxy = paramiko.ProxyCommand(proxy_command)

if not (self.password or self.key_file):
if host_info and host_info.get("identityfile"):
self.key_file = host_info.get("identityfile")[0]
identify_file = host_info.get("identityfile")
if host_info and identify_file:
self.key_file = identify_file[0]

@property
def log(self) -> logging.Logger:
return check.not_none(self._logger)

def get_connection(self):
def get_connection(self) -> SSHClient:
"""Opens a SSH connection to the remote host.
:rtype: paramiko.client.SSHClient
"""
client = paramiko.SSHClient()
client.load_system_host_keys()

if not self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. This won't protect against "
"Man-In-The-Middle attacks"
)
client.load_system_host_keys()
if self.no_host_key_check:
self.log.warning(
"No Host Key Verification. This won't protect against Man-In-The-Middle attacks"
Expand All @@ -106,31 +143,33 @@ def get_connection(self):
username=self.username,
password=self.password,
key_filename=self.key_file,
pkey=self.key_obj,
pkey=self._key_obj,
timeout=self.timeout,
compress=self.compress,
port=self.remote_port,
sock=self.host_proxy,
port=self.remote_port, # type: ignore
sock=self._host_proxy, # type: ignore
look_for_keys=False,
)
else:
client.connect(
hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
pkey=self.key_obj,
pkey=self._key_obj,
timeout=self.timeout,
compress=self.compress,
port=self.remote_port,
sock=self.host_proxy,
port=self.remote_port, # type: ignore
sock=self._host_proxy, # type: ignore
)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)
client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore

return client

def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
def get_tunnel(
self, remote_port, remote_host="localhost", local_port=None
) -> SSHTunnelForwarder:
check.int_param(remote_port, "remote_port")
check.str_param(remote_host, "remote_host")
check.opt_int_param(local_port, "local_port")
Expand All @@ -141,7 +180,11 @@ def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
local_bind_address = ("localhost",)

# Will prefer key string if specified, otherwise use the key file
pkey = self.key_obj if self.key_obj else self.key_file
if self._key_obj and self.key_file:
self.log.warning(
"SSHResource: key_string and key_file both specified as config. Using key_string."
)
pkey = self._key_obj if self._key_obj else self.key_file

if self.password and self.password.strip():
client = SSHTunnelForwarder(
Expand All @@ -150,22 +193,22 @@ def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
ssh_username=self.username,
ssh_password=self.password,
ssh_pkey=pkey,
ssh_proxy=self.host_proxy,
ssh_proxy=self._host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
logger=self.log,
logger=self._logger,
)
else:
client = SSHTunnelForwarder(
self.remote_host,
ssh_port=self.remote_port,
ssh_username=self.username,
ssh_pkey=pkey,
ssh_proxy=self.host_proxy,
ssh_proxy=self._host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
host_pkey_directories=[],
logger=self.log,
logger=self._logger,
)

return client
Expand Down Expand Up @@ -203,53 +246,51 @@ def sftp_put(self, remote_filepath, local_filepath, confirm=True):
@dagster_maintained_resource
@resource(
config_schema={
"remote_host": Field(
"remote_host": DagsterField(
StringSource, description="remote host to connect to", is_required=True
),
"remote_port": Field(
"remote_port": DagsterField(
IntSource,
description="port of remote host to connect (Default is paramiko SSH_PORT)",
is_required=False,
default_value=SSH_PORT,
),
"username": Field(
"username": DagsterField(
StringSource, description="username to connect to the remote_host", is_required=False
),
"password": Field(
"password": DagsterField(
StringSource,
description="password of the username to connect to the remote_host",
is_required=False,
),
"key_file": Field(
"key_file": DagsterField(
StringSource,
description="key file to use to connect to the remote_host.",
is_required=False,
),
"key_string": Field(
"key_string": DagsterField(
StringSource,
description="key string to use to connect to remote_host",
is_required=False,
),
"timeout": Field(
"timeout": DagsterField(
IntSource,
description="timeout for the attempt to connect to the remote_host.",
is_required=False,
default_value=10,
),
"keepalive_interval": Field(
"keepalive_interval": DagsterField(
IntSource,
description="send a keepalive packet to remote host every keepalive_interval seconds",
is_required=False,
default_value=30,
),
"compress": Field(BoolSource, is_required=False, default_value=True),
"no_host_key_check": Field(BoolSource, is_required=False, default_value=True),
"allow_host_key_change": Field(
"compress": DagsterField(BoolSource, is_required=False, default_value=True),
"no_host_key_check": DagsterField(BoolSource, is_required=False, default_value=True),
"allow_host_key_change": DagsterField(
BoolSource, description="[Deprecated]", is_required=False, default_value=False
),
}
)
def ssh_resource(init_context):
args = init_context.resource_config
args = merge_dicts(init_context.resource_config, {"logger": init_context.log})
return SSHResource(**args)
return SSHResource.from_resource_context(init_context)
Loading

0 comments on commit 8978f18

Please sign in to comment.