Skip to content

ShellDriver: add optional arg dest_authorized_keys #1631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,8 @@ Arguments:
Can be an empty string.
- keyfile (str): optional, keyfile to upload after login, making the
`SSHDriver`_ usable
- dest_authorized_keys (str): optional, default="~/.ssh/authorized_keys",
filename of the authorized_keys file
- login_timeout (int, default=60): timeout for login prompt detection in
seconds
- await_login_timeout (int, default=2): time in seconds of silence that needs
Expand Down
40 changes: 23 additions & 17 deletions labgrid/driver/shelldriver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=unused-argument
"""The ShellDriver provides the CommandProtocol, ConsoleProtocol and
InfoProtocol on top of a SerialPort."""
import os
import io
import re
import shlex
Expand Down Expand Up @@ -34,6 +35,7 @@ class ShellDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
username (str): username to login with
password (str): password to login with
keyfile (str): keyfile to bind mount over users authorized keys
dest_authorized_keys (str): optional, default="~/.ssh/authorized_keys", filename of the authorized_keys file
login_timeout (int): optional, timeout for login prompt detection
console_ready (regex): optional, pattern used by the kernel to inform the user that a
console can be activated by pressing enter.
Expand All @@ -49,6 +51,7 @@ class ShellDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
username = attr.ib(validator=attr.validators.instance_of(str))
password = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)))
keyfile = attr.ib(default="", validator=attr.validators.instance_of(str))
dest_authorized_keys = attr.ib(default="~/.ssh/authorized_keys", validator=attr.validators.instance_of(str))
login_timeout = attr.ib(default=60, validator=attr.validators.instance_of(int))
console_ready = attr.ib(default="", validator=attr.validators.instance_of(str))
await_login_timeout = attr.ib(default=2, validator=attr.validators.instance_of(int))
Expand All @@ -72,7 +75,7 @@ def on_activate(self):
if self.target.env:
keyfile_path = self.target.env.config.resolve_path(self.keyfile)

self._put_ssh_key(keyfile_path)
self._put_ssh_key(keyfile_path, self.dest_authorized_keys)

def on_deactivate(self):
self._status = 0
Expand Down Expand Up @@ -210,8 +213,9 @@ def _inject_run(self):
)
self.console.expect(self.prompt)

@step(args=['keyfile_path'])
def _put_ssh_key(self, keyfile_path):
@step(args=['keyfile_path', 'dest_authorized_keys'])
def _put_ssh_key(self, keyfile_path, dest_authorized_keys):
dest_authorized_keys_dir = os.path.dirname(dest_authorized_keys)
"""Upload an SSH Key to a target"""
regex = re.compile(
r"""ssh-(rsa|ed25519)
Expand All @@ -229,7 +233,7 @@ def _put_ssh_key(self, keyfile_path):
f"Could not parse SSH-Key from file: {keyfile}"
)
self.logger.debug("Read Key: %s", new_key)
auth_keys, _, read_keys = self._run("cat ~/.ssh/authorized_keys")
auth_keys, _, read_keys = self._run(f"""cat {self.dest_authorized_keys}""")
self.logger.debug("Exitcode trying to read keys: %s, keys: %s", read_keys, auth_keys)
result = []
_, _, test_write = self._run("touch ~/.test")
Expand All @@ -251,34 +255,36 @@ def _put_ssh_key(self, keyfile_path):

if test_write == 0 and read_keys == 0:
self.logger.debug("Key not on target and writeable, concatenating...")
self._run_check(f'echo "{keyline}" >> ~/.ssh/authorized_keys')
self._run_check(f"""echo "{keyline}" >> {dest_authorized_keys}""")
self._run_check("rm ~/.test")
return

if test_write == 0:
self.logger.debug("Key not on target, testing for .ssh directory")
_, _, ssh_dir = self._run("[ -d ~/.ssh/ ]")
_, _, ssh_dir = self._run(f"""[ -d {dest_authorized_keys_dir} ]""")
if ssh_dir != 0:
self.logger.debug("~/.ssh did not exist, creating")
self._run("mkdir ~/.ssh/")
self._run_check("chmod 700 ~/.ssh/")
self.logger.debug("Creating ~/.ssh/authorized_keys")
self._run_check(f'echo "{keyline}" > ~/.ssh/authorized_keys')
self.logger.debug(f"""{dest_authorized_keys_dir} did not exist, creating""")
self._run(f"""mkdir -p {dest_authorized_keys_dir}""")
self._run_check(f"""chmod 700 {dest_authorized_keys_dir}""")
self.logger.debug(f"""Creating {dest_authorized_keys}""")
self._run_check(f"""echo "{keyline}" > {dest_authorized_keys}""")
self._run_check("rm ~/.test")
return

self.logger.debug("Key not on target and not writeable, using bind mount...")
self._run_check('mkdir -m 700 /tmp/labgrid-ssh/')
self._run("cp -a ~/.ssh/* /tmp/labgrid-ssh/")
self._run_check(f'echo "{keyline}" >> /tmp/labgrid-ssh/authorized_keys')
self._run(f"""cp -a {dest_authorized_keys_dir}/* /tmp/labgrid-ssh/""")
self._run_check(f"""echo "{keyline}" >> /tmp/labgrid-ssh/authorized_keys""")
self._run_check('chmod 600 /tmp/labgrid-ssh/authorized_keys')
out, err, exitcode = self._run('mount --bind /tmp/labgrid-ssh/ ~/.ssh/')
out, err, exitcode = self._run(f"""mount --bind /tmp/labgrid-ssh/ {dest_authorized_keys_dir}""")
if exitcode != 0:
self.logger.warning("Could not bind mount ~/.ssh directory: %s %s", out, err)
self.logger.warning(f"""Could not bind mount {dest_authorized_keys_dir} directory: {out} {err}""")

@Driver.check_active
def put_ssh_key(self, keyfile_path):
self._put_ssh_key(keyfile_path)
def put_ssh_key(self, keyfile_path, dest_authorized_keys = None):
if dest_authorized_keys is None:
dest_authorized_keys = self.dest_authorized_keys
self._put_ssh_key(keyfile_path, dest_authorized_keys)

def _xmodem_getc(self, size, timeout=10):
""" called by the xmodem.XMODEM instance to read protocol data from the console """
Expand Down
Loading