Skip to content

Commit

Permalink
resource: add ability to provide password for NetworkResource's
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Zwettler <[email protected]>
  • Loading branch information
flxzt committed Jun 18, 2024
1 parent 41a19d4 commit 5a8a805
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
5 changes: 4 additions & 1 deletion labgrid/resource/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,21 @@ class NetworkResource(Resource):
Args:
host (str): remote host the resource is available on
sshpassword (str): remote host ssh password
"""
host = attr.ib(validator=attr.validators.instance_of(str))
sshpassword = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), kw_only=True)

@property
def command_prefix(self):
host = self.host
sshpassword = self.sshpassword

if hasattr(self, 'extra'):
if self.extra.get('proxy_required'):
host = self.extra.get('proxy')

conn = sshmanager.get(host)
conn = sshmanager.get(host, sshpassword)
prefix = conn.get_prefix()

return prefix + ['--']
Expand Down
36 changes: 31 additions & 5 deletions labgrid/util/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import shutil
import subprocess
import os
import stat
import shlex
from select import select
from functools import wraps
from typing import Dict
Expand Down Expand Up @@ -35,18 +37,19 @@ def __attrs_post_init__(self):
self.logger = logging.getLogger(f"{self}")
atexit.register(self.close_all)

def get(self, host: str):
def get(self, host: str, sshpassword: str | None = None):
"""Retrieve or create a new connection to a given host
Arguments:
host (str): host to retrieve the connection for
sshpassword (str): remote host ssh password
Returns:
:obj:`SSHConnection`: the SSHConnection for the host"""
instance = self._connections.get(host)
if instance is None:
self.logger.debug("Creating SSHConnection for %s", host)
instance = SSHConnection(host)
instance = SSHConnection(host, sshpassword=sshpassword)
instance.connect()
self._connections[host] = instance
return instance
Expand Down Expand Up @@ -130,6 +133,7 @@ class SSHConnection:
A public identity infrastructure is assumed, no extra username or passwords
are supported."""
host = attr.ib(validator=attr.validators.instance_of(str))
sshpassword = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), kw_only=True)
_connected = attr.ib(
default=False, init=False, validator=attr.validators.instance_of(bool)
)
Expand All @@ -150,15 +154,20 @@ def __attrs_post_init__(self):

@staticmethod
def _get_ssh_base_args():
return ["-x", "-o", "LogLevel=ERROR", "-o", "PasswordAuthentication=no"]
return ["-x", "-o", "LogLevel=ERROR"]

def _get_ssh_control_args(self):
args = []
if self._socket:
return [
args += [
"-o", "ControlMaster=no",
"-o", f"ControlPath={self._socket}",
]
return []
if not self.sshpassword:
args += [
"-o", "PasswordAuthentication=no"
]
return args

def _get_ssh_args(self):
args = SSHConnection._get_ssh_base_args()
Expand Down Expand Up @@ -445,11 +454,25 @@ def _start_own_master(self):

self._logger.debug("Master Start command: %s", " ".join(args))
assert self._master is None

env = os.environ.copy()
pass_file = ''
if self.sshpassword:
fd, pass_file = tempfile.mkstemp()
os.fchmod(fd, stat.S_IRWXU)
#with openssh>=8.4 SSH_ASKPASS_REQUIRE can be used to force SSH_ASK_PASS
#openssh<8.4 requires the DISPLAY var and a detached process with start_new_session=True
env = {'SSH_ASKPASS': pass_file, 'DISPLAY':'', 'SSH_ASKPASS_REQUIRE':'force'}
with open(fd, 'w') as f:
f.write("#!/bin/sh\necho " + shlex.quote(self.sshpassword))

self._master = subprocess.Popen(
args,
env=env,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=True
)

try:
Expand All @@ -464,6 +487,9 @@ def _start_own_master(self):
raise ExecutionError(
f"failed to connect (timeout) to {self.host} with args {args}, process killed, got {stdout},{stderr}" # pylint: disable=line-too-long
)
finally:
if self.sshpassword and os.path.exists(pass_file):
os.remove(pass_file)

if not os.path.exists(control):
raise ExecutionError(f"no control socket to {self.host}")
Expand Down
11 changes: 11 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,17 @@ def test_proxymanager_local_forced_proxy(target):

assert (host, port) != (nhost, nport)

@pytest.mark.localsshmanager
def test_remote_networkresource(target, tmpdir):
name = "test"
host = "localhost"
sshpassword = "foo"
res = NetworkResource(target, name, host, sshpassword=sshpassword)

assert res.name == name
assert res.host == host
assert res.sshpassword == sshpassword

@pytest.mark.localsshmanager
def test_remote_managedfile(target, tmpdir):
import hashlib
Expand Down

0 comments on commit 5a8a805

Please sign in to comment.