diff --git a/labgrid/resource/common.py b/labgrid/resource/common.py index d2b190750..51e15e36f 100644 --- a/labgrid/resource/common.py +++ b/labgrid/resource/common.py @@ -86,18 +86,21 @@ class NetworkResource(Resource): Args: host (str): remote host the resource is available on + sshpassword (str): remote host ssh password """ host = attr.ib(validator=attr.validators.instance_of(str)) + sshpassword = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), kw_only=True) @property def command_prefix(self): host = self.host + sshpassword = self.sshpassword if hasattr(self, 'extra'): if self.extra.get('proxy_required'): host = self.extra.get('proxy') - conn = sshmanager.get(host) + conn = sshmanager.get(host, sshpassword) prefix = conn.get_prefix() return prefix + ['--'] diff --git a/labgrid/util/ssh.py b/labgrid/util/ssh.py index 62bbf4cbb..b33871bdd 100644 --- a/labgrid/util/ssh.py +++ b/labgrid/util/ssh.py @@ -4,9 +4,11 @@ import shutil import subprocess import os +import stat +import shlex from select import select from functools import wraps -from typing import Dict +from typing import Dict, Union import attr from ..driver.exception import ExecutionError @@ -35,18 +37,19 @@ def __attrs_post_init__(self): self.logger = logging.getLogger(f"{self}") atexit.register(self.close_all) - def get(self, host: str): + def get(self, host: str, sshpassword: Union[str, None] = None): """Retrieve or create a new connection to a given host Arguments: host (str): host to retrieve the connection for + sshpassword (str): remote host ssh password Returns: :obj:`SSHConnection`: the SSHConnection for the host""" instance = self._connections.get(host) if instance is None: self.logger.debug("Creating SSHConnection for %s", host) - instance = SSHConnection(host) + instance = SSHConnection(host, sshpassword=sshpassword) instance.connect() self._connections[host] = instance return instance @@ -130,6 +133,7 @@ class SSHConnection: A public identity infrastructure is assumed, no extra username or passwords are supported.""" host = attr.ib(validator=attr.validators.instance_of(str)) + sshpassword = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), kw_only=True) _connected = attr.ib( default=False, init=False, validator=attr.validators.instance_of(bool) ) @@ -150,15 +154,20 @@ def __attrs_post_init__(self): @staticmethod def _get_ssh_base_args(): - return ["-x", "-o", "LogLevel=ERROR", "-o", "PasswordAuthentication=no"] + return ["-x", "-o", "LogLevel=ERROR"] def _get_ssh_control_args(self): + args = [] if self._socket: - return [ + args += [ "-o", "ControlMaster=no", "-o", f"ControlPath={self._socket}", ] - return [] + if not self.sshpassword: + args += [ + "-o", "PasswordAuthentication=no" + ] + return args def _get_ssh_args(self): args = SSHConnection._get_ssh_base_args() @@ -445,11 +454,25 @@ def _start_own_master(self): self._logger.debug("Master Start command: %s", " ".join(args)) assert self._master is None + + env = os.environ.copy() + pass_file = '' + if self.sshpassword: + fd, pass_file = tempfile.mkstemp() + os.fchmod(fd, stat.S_IRWXU) + #with openssh>=8.4 SSH_ASKPASS_REQUIRE can be used to force SSH_ASK_PASS + #openssh<8.4 requires the DISPLAY var and a detached process with start_new_session=True + env = {'SSH_ASKPASS': pass_file, 'DISPLAY':'', 'SSH_ASKPASS_REQUIRE':'force'} + with open(fd, 'w') as f: + f.write("#!/bin/sh\necho " + shlex.quote(self.sshpassword)) + self._master = subprocess.Popen( args, + env=env, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + start_new_session=True ) try: @@ -464,6 +487,9 @@ def _start_own_master(self): raise ExecutionError( f"failed to connect (timeout) to {self.host} with args {args}, process killed, got {stdout},{stderr}" # pylint: disable=line-too-long ) + finally: + if self.sshpassword and os.path.exists(pass_file): + os.remove(pass_file) if not os.path.exists(control): raise ExecutionError(f"no control socket to {self.host}") diff --git a/tests/test_util.py b/tests/test_util.py index ab83a8a5c..ecf59f9f8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -301,6 +301,17 @@ def test_proxymanager_local_forced_proxy(target): assert (host, port) != (nhost, nport) +@pytest.mark.localsshmanager +def test_remote_networkresource(target, tmpdir): + name = "test" + host = "localhost" + sshpassword = "foo" + res = NetworkResource(target, name, host, sshpassword=sshpassword) + + assert res.name == name + assert res.host == host + assert res.sshpassword == sshpassword + @pytest.mark.localsshmanager def test_remote_managedfile(target, tmpdir): import hashlib