Skip to content

Commit

Permalink
Make ssh_tunnel context manager handle nested calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-blanchard committed Apr 13, 2016
1 parent 6be150e commit e0ef07a
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions streamparse/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from glob import glob
from os.path import join
Expand Down Expand Up @@ -45,7 +46,7 @@ def _port_in_use(port, server_type="tcp"):
return False


_active_tunnel = None
_active_tunnels = defaultdict(int)


@contextmanager
Expand All @@ -58,44 +59,45 @@ def ssh_tunnel(env_config, local_port=6627, remote_port=None, quiet=False):
if remote_port is None:
remote_port = nimbus_port
if env_config.get('use_ssh_for_nimbus', True):
global _active_tunnel
if _port_in_use(local_port):
if local_port == _active_tunnel:
yield 'localhost', local_port
need_setup = True
while _port_in_use(local_port):
if local_port in _active_tunnels:
active_remote_port = _active_tunnels[local_port]
if active_remote_port == remote_port:
need_setup = False
break
local_port += 1

if need_setup:
user = env_config.get("user")
if user:
user_at_host = "{user}@{host}".format(user=user, host=host)
else:
raise IOError("Local port: {} already in use, unable to open "
"ssh tunnel to {}:{}.".format(local_port,
host,
remote_port))

user = env_config.get("user")
if user:
user_at_host = "{user}@{host}".format(user=user, host=host)
else:
user_at_host = host # Rely on SSH default or config to connect.

ssh_cmd = ["ssh",
"-NL",
"{local}:localhost:{remote}".format(local=local_port,
remote=remote_port),
user_at_host]
ssh_proc = subprocess.Popen(ssh_cmd, shell=False)
# Validate that the tunnel is actually running before yielding
while not _port_in_use(local_port):
# Periodically check to see if the ssh command failed and returned a
# value, then raise an Exception
if ssh_proc.poll() is not None:
raise IOError('Unable to open ssh tunnel via: "{}"'
.format(" ".join(ssh_cmd)))
time.sleep(0.2)
try:
user_at_host = host # Rely on SSH default or config to connect.

ssh_cmd = ["ssh",
"-NL",
"{local}:localhost:{remote}".format(local=local_port,
remote=remote_port),
user_at_host]
ssh_proc = subprocess.Popen(ssh_cmd, shell=False)
# Validate that the tunnel is actually running before yielding
while not _port_in_use(local_port):
# Periodically check to see if the ssh command failed and returned a
# value, then raise an Exception
if ssh_proc.poll() is not None:
raise IOError('Unable to open ssh tunnel via: "{}"'
.format(" ".join(ssh_cmd)))
time.sleep(0.2)
if not quiet:
print("ssh tunnel to Nimbus {}:{} established."
.format(host, remote_port))
_active_tunnel = local_port
yield 'localhost', local_port
finally:
_active_tunnels[local_port] = remote_port
yield 'localhost', local_port
# Clean up after we exit context
if need_setup:
ssh_proc.kill()
del _active_tunnels[local_port]
# Do nothing if we're not supposed to use ssh
else:
yield host, remote_port
Expand Down

0 comments on commit e0ef07a

Please sign in to comment.