diff --git a/CHANGES.rst b/CHANGES.rst index d068962c8..521e00214 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -668,6 +668,8 @@ New Features in 0.4.0 a jumper. - AndroidFastbootDriver now supports booting/flashing images preconfigured in the environment configuration. +- CommandProtocol now supports a start_process() method to run a process on the + target asynchronously. Bug fixes in 0.4.0 ~~~~~~~~~~~~~~~~~~ diff --git a/doc/development.rst b/doc/development.rst index 4283912dc..d08d38e8b 100644 --- a/doc/development.rst +++ b/doc/development.rst @@ -783,12 +783,3 @@ By writing these events to a file (or sqlite database) as a trace, we can collect data over multiple runs for later analysis. This would become more useful by passing recognized events (stack traces, crashes, ...) and benchmark results via the Step infrastructure. - -CommandProtocol Support for Background Processes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Currently the CommandProtocol does not support long running -processes well. -An implementation should start a new process, -return a handle and forbid running other processes in the foreground. -The handle can be used to retrieve output from a command. diff --git a/labgrid/driver/bareboxdriver.py b/labgrid/driver/bareboxdriver.py index d87b5e0aa..61ccc78c3 100644 --- a/labgrid/driver/bareboxdriver.py +++ b/labgrid/driver/bareboxdriver.py @@ -1,12 +1,14 @@ import shlex +from contextlib import contextmanager import attr from pexpect import TIMEOUT +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, LinuxBootProtocol from ..step import step -from ..util import gen_marker, Timeout, re_vt100 +from ..util import gen_marker, Timeout, re_vt100, ConsoleMarkerProcess from .common import Driver from .commandmixin import CommandMixin @@ -44,6 +46,7 @@ def __attrs_post_init__(self): self._status = 0 # barebox' default log level, used as fallback if no log level can be saved self.saved_log_level = 7 + self._process = None def on_activate(self): """Activate the BareboxDriver @@ -59,6 +62,34 @@ def on_deactivate(self): Simply sets the internal status to 0 """ self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" + + @contextmanager + def _start_process(self, cmd: str, *, adjust_log_level: bool = True): + if self._process is not None: + raise CommandProcessBusy() + + # FIXME: use codec, decodeerrors + marker = gen_marker() + # hide marker from expect + hidden_marker = f'"{marker[:4]}""{marker[4:]}"' + # generate command with marker and log level adjustment + cmp_command = f'echo -o /cmd {shlex.quote(cmd)}; echo {hidden_marker};' + if self.saved_log_level and adjust_log_level: + cmp_command += f' global.loglevel={self.saved_log_level};' + cmp_command += f' sh /cmd; echo {hidden_marker} $?;' + if self.saved_log_level and adjust_log_level: + cmp_command += ' global.loglevel=0;' + + self.console.sendline(cmp_command) + self.console.expect(marker) + + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None @Driver.check_active @step(args=['cmd']) @@ -76,32 +107,23 @@ def _run(self, cmd: str, *, timeout: int = 30, adjust_log_level: bool = True, co Returns: Tuple[List[str],List[str], int]: if successful, None otherwise """ - # FIXME: use codec, decodeerrors - marker = gen_marker() - # hide marker from expect - hidden_marker = f'"{marker[:4]}""{marker[4:]}"' - # generate command with marker and log level adjustment - cmp_command = f'echo -o /cmd {shlex.quote(cmd)}; echo {hidden_marker};' - if self.saved_log_level and adjust_log_level: - cmp_command += f' global.loglevel={self.saved_log_level};' - cmp_command += f' sh /cmd; echo {hidden_marker} $?;' - if self.saved_log_level and adjust_log_level: - cmp_command += ' global.loglevel=0;' - if self._status == 1: - self.console.sendline(cmp_command) - _, _, match, _ = self.console.expect( - rf'{marker}(.*){marker}\s+(\d+)\s+.*{self.prompt}', - timeout=timeout) - # Remove VT100 Codes and split by newline - data = re_vt100.sub('', match.group(1).decode('utf-8')).split('\r\n')[1:-1] - self.logger.debug("Received Data: %s", data) - # Get exit code - exitcode = int(match.group(2)) - return (data, [], exitcode) + with self._start_process(cmd, adjust_log_level=adjust_log_level) as p: + output = p.read_to_end(timeout=timeout) + # Remove VT100 Codes and split by newline + data = re_vt100.sub('', output.decode('utf-8')).split('\r\n')[1:-1] + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) return None + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + @Driver.check_active @step() def reset(self): diff --git a/labgrid/driver/consoleexpectmixin.py b/labgrid/driver/consoleexpectmixin.py index cc11ac6bd..b05f44cd8 100644 --- a/labgrid/driver/consoleexpectmixin.py +++ b/labgrid/driver/consoleexpectmixin.py @@ -55,6 +55,14 @@ def sendline(self, line): def sendcontrol(self, char): self._expect.sendcontrol(char) + @Driver.check_active + def expect_no_step(self, pattern, timeout=-1): + """ + Expect a pattern without logging the output. + """ + index = self._expect.expect(pattern, timeout=timeout) + return index, self._expect.before, self._expect.match, self._expect.after + @Driver.check_active @step(args=['pattern'], result=True) def expect(self, pattern, timeout=-1): diff --git a/labgrid/driver/shelldriver.py b/labgrid/driver/shelldriver.py index 7d476f536..a72aa7a0d 100644 --- a/labgrid/driver/shelldriver.py +++ b/labgrid/driver/shelldriver.py @@ -5,15 +5,17 @@ import re import shlex import ipaddress +from contextlib import contextmanager import attr from pexpect import TIMEOUT import xmodem +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, FileTransferProtocol from ..step import step -from ..util import gen_marker, Timeout, re_vt100 +from ..util import gen_marker, Timeout, re_vt100, ConsoleMarkerProcess from .commandmixin import CommandMixin from .common import Driver from .exception import ExecutionError @@ -61,6 +63,7 @@ def __attrs_post_init__(self): self._xmodem_cached_rx_cmd = "" self._xmodem_cached_sx_cmd = "" + self._process = None def on_activate(self): if self._status == 0: @@ -76,38 +79,58 @@ def on_activate(self): def on_deactivate(self): self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" - def _run(self, cmd, *, timeout=30.0, codec="utf-8", decodeerrors="strict"): - """ - Runs the specified cmd on the shell and returns the output. + @contextmanager + def _start_process(self, cmd: str): + if self._process is not None: + raise CommandProcessBusy() - Arguments: - cmd - cmd to run on the shell - """ # FIXME: Handle pexpect Timeout self._check_prompt() marker = gen_marker() + # hide marker from expect cmp_command = f'''MARKER='{marker[:4]}''{marker[4:]}' run {shlex.quote(cmd)}''' self.console.sendline(cmp_command) - _, _, match, _ = self.console.expect( - rf'{marker}(.*){marker}\s+(\d+)\s+{self.prompt}', - timeout=timeout - ) - # Remove VT100 Codes, split by newline and remove surrounding newline - data = re_vt100.sub('', match.group(1).decode(codec, decodeerrors)).split('\r\n') - if data and not data[-1]: - del data[-1] - self.logger.debug("Received Data: %s", data) - # Get exit code - exitcode = int(match.group(2)) - return (data, [], exitcode) + self.console.expect(marker) + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None + + def _run(self, cmd, *, timeout=30.0, codec="utf-8", decodeerrors="strict"): + """ + Runs the specified cmd on the shell and returns the output. + + Arguments: + cmd - cmd to run on the shell + """ + with self._start_process(cmd) as p: + output = p.read_to_end(timeout=timeout) + + # Remove VT100 Codes, split by newline and remove surrounding newline + data = re_vt100.sub('', output.decode(codec, decodeerrors)).split('\r\n') + if data and not data[-1]: + del data[-1] + + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) @Driver.check_active @step(args=['cmd'], result=True) def run(self, cmd, timeout=30.0, codec="utf-8", decodeerrors="strict"): return self._run(cmd, timeout=timeout, codec=codec, decodeerrors=decodeerrors) + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + @step() def _await_login(self): """Awaits the login prompt and logs the user in""" diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 110a4f707..b65ff4256 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -9,11 +9,13 @@ import tempfile import time from functools import cached_property +from contextlib import contextmanager import attr +import pexpect from ..factory import target_factory -from ..protocol import CommandProtocol, FileTransferProtocol +from ..protocol import CommandProtocol, CommandProcessProtocol, FileTransferProtocol from .commandmixin import CommandMixin from .common import Driver from ..step import step @@ -22,6 +24,56 @@ from ..util.proxy import proxymanager from ..util.timeout import Timeout from ..util.ssh import get_ssh_connect_timeout +from ..util.readmixin import ReadMixIn + + +class SSHDriverProcess(CommandProcessProtocol, ReadMixIn): + def __init__(self, sub): + self._sub = sub + + @property + def exitcode(self): + if self._sub.isalive(): + return None + + if self._sub.exitstatus is None: + return -self._sub.signalstatus + return self._sub.exitstatus + + def read(self, size=1, timeout=-1): + return self._sub.read_nonblocking(size, timeout) + + @step(args=['data']) + def write(self, data): + self._sub.write(data) + + @step(result=True) + def poll(self): + return self.exitcode + + @step(result=True) + def stop(self): + self._sub.close(True) + + @step(args=['pattern', 'timeout'], result=True) + def expect(self, pattern, *, timeout=-1): + index = self._sub.expect(pattern, timeout=timeout) + return index, self._sub.before, self._sub.match, self._sub.after + + @step(result=True) + def wait(self): + return self._sub.wait() + + @step(args=['char']) + def sendcontrol(self, char): + self._sub.sendcontrol(char) + + def __enter__(self): + return self + + def __exit__(self, typ, value, traceback): + self.stop() + return False @target_factory.reg_driver @@ -45,6 +97,7 @@ def __attrs_post_init__(self): self._scp = self._get_tool("scp") self._sshfs = self._get_tool("sshfs") self._rsync = self._get_tool("rsync") + self._processes = [] def _get_tool(self, name): if self.target.env: @@ -80,6 +133,7 @@ def on_activate(self): self._start_keepalive() def on_deactivate(self): + assert not self._processes, "Deactivating while a command process is running is not allowed" try: self._stop_keepalive() finally: @@ -242,6 +296,35 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): stderr.pop() return (stdout, stderr, sub.returncode) + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + if not self._check_keepalive(): + raise ExecutionError("Keepalive no longer running") + + cmd = f"stty -echo; {cmd}" # Disable input echo from ssh + complete_cmd = ["ssh", "-o", "LogLevel=QUIET", "-x", *self.ssh_prefix, + "-p", str(self.networkservice.port), "-l", self.networkservice.username, + self.networkservice.address, "-tt", "--", '/bin/sh -c {}'.format(shlex.quote(cmd)), + ] + self.logger.debug("Sending command: %s", complete_cmd) + + try: + sub = pexpect.spawn(complete_cmd[0], complete_cmd[1:]) + sub.setecho(False) # Disable input echo from pexpect + except: + raise ExecutionError( + "error executing command: {}".format(complete_cmd) + ) + + with SSHDriverProcess(sub) as p: + self._processes.append(p) + try: + yield p + finally: + self._processes.remove(p) + def interact(self, cmd=None): assert cmd is None or isinstance(cmd, list) @@ -377,7 +460,7 @@ def scp(self, *, src, dst): "-o", f"ControlPath={self.control.replace('%', '%%')}", src, dst, ] - + if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode(): complete_cmd.insert(1, "-s") if self.explicit_scp_mode and self._scp_supports_explicit_scp_mode(): diff --git a/labgrid/driver/ubootdriver.py b/labgrid/driver/ubootdriver.py index 8cef4c0ec..0acfa7bbe 100644 --- a/labgrid/driver/ubootdriver.py +++ b/labgrid/driver/ubootdriver.py @@ -1,11 +1,13 @@ """The U-Boot Module contains the UBootDriver""" import re +from contextlib import contextmanager import attr from pexpect import TIMEOUT +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, LinuxBootProtocol -from ..util import gen_marker, re_vt100, Timeout +from ..util import ConsoleMarkerProcess, gen_marker, re_vt100, Timeout from ..step import step from .common import Driver from .commandmixin import CommandMixin @@ -51,6 +53,7 @@ class UBootDriver(CommandMixin, Driver, CommandProtocol, LinuxBootProtocol): def __attrs_post_init__(self): super().__attrs_post_init__() self._status = 0 + self._process = None if self.boot_expression: import warnings @@ -70,32 +73,38 @@ def on_deactivate(self): Simply sets the internal status to 0 """ self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" + + @contextmanager + def _start_process(self, cmd: str): + if self._process is not None: + raise CommandProcessBusy() - def _run(self, cmd: str, *, timeout: int = 30, codec: str = "utf-8", decodeerrors: str = "strict"): # pylint: disable=unused-argument,line-too-long # TODO: use codec, decodeerrors # TODO: Shell Escaping for the U-Boot Shell marker = gen_marker() - cmp_command = f"""echo '{marker[:4]}''{marker[4:]}'; {cmd}; echo "$?"; echo '{marker[:4]}''{marker[4:]}';""" # pylint: disable=line-too-long - if self._status == 1: - self.console.sendline(cmp_command) - _, before, _, _ = self.console.expect(self.prompt, timeout=timeout) - # Remove VT100 Codes and split by newline - data = re_vt100.sub( - '', before.decode('utf-8'), count=1000000 - ).replace("\r", "").split("\n") - - # Strip possible U-Boot timestamps from the line - if self.strip_timestamp: - data = [re_uboot_timestamp.sub('', line) for line in data] - - self.logger.debug("Received Data: %s", data) - # Remove first element, the invoked cmd - data = data[data.index(marker) + 1:] - data = data[:data.index(marker)] - exitcode = int(data[-1]) - del data[-1] - return (data, [], exitcode) + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None + def _run(self, cmd: str, *, timeout: int = 30, codec: str = "utf-8", decodeerrors: str = "strict"): # pylint: disable=unused-argument,line-too-long + if self._status == 1: + with self._start_process(cmd) as p: + output = p.read_to_end(timeout=timeout) + # Remove VT100 Codes and split by newline + data = re_vt100.sub( + '', output.decode('utf-8'), count=1000000 + ).replace("\r", "").split("\n") + + # Strip possible U-Boot timestamps from the line + if self.strip_timestamp: + data = [re_uboot_timestamp.sub('', line) for line in data] + + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) return None @Driver.check_active @@ -113,6 +122,13 @@ def run(self, cmd, timeout=30): """ return self._run(cmd, timeout=timeout) + @Driver.check_active + @step(args=['cmd'], result=True) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + def get_status(self): """Retrieve status of the UBootDriver. 0 means inactive, 1 means active. diff --git a/labgrid/exceptions.py b/labgrid/exceptions.py index df49be3dd..bc45f3ff1 100644 --- a/labgrid/exceptions.py +++ b/labgrid/exceptions.py @@ -41,3 +41,12 @@ class NoStrategyFoundError(NoSupplierFoundError): @attr.s(eq=False) class RegistrationError(Exception): msg = attr.ib(validator=attr.validators.instance_of(str)) + + +@attr.s(eq=False) +class CommandProcessBusy(Exception): + """ + This exception is raised if it is not possible to execute multiple + CommandProcessProtocol at the same time. + """ + pass diff --git a/labgrid/protocol/__init__.py b/labgrid/protocol/__init__.py index 0ac225622..613ac2265 100644 --- a/labgrid/protocol/__init__.py +++ b/labgrid/protocol/__init__.py @@ -1,5 +1,5 @@ from .bootstrapprotocol import BootstrapProtocol -from .commandprotocol import CommandProtocol +from .commandprotocol import CommandProtocol, CommandProcessProtocol from .consoleprotocol import ConsoleProtocol from .linuxbootprotocol import LinuxBootProtocol from .powerprotocol import PowerProtocol diff --git a/labgrid/protocol/commandprotocol.py b/labgrid/protocol/commandprotocol.py index dff1daaf3..92e599fae 100644 --- a/labgrid/protocol/commandprotocol.py +++ b/labgrid/protocol/commandprotocol.py @@ -1,6 +1,75 @@ import abc +class CommandProcessProtocol(abc.ABC): + """Abstract class for a running command""" + + def read(self, size=1, timeout=-1): + """ + Reads up to size bytes from the remote process. If no output is + available to read, will block up to timeout seconds waiting for output, + then return any available (up to size bytes). If no output occurs + before the timeout, a TIMEOUT exception is raised. + + If there is no more output because the process has exited, raises EOF + + If timeout is -1, the default timeout value will be used. + + This operates the same as read_nonblocking() from pexpect; You may want + to look at the operations from the ReadMixIn class for more convenient + methods of reading data from remote processes + """ + raise NotImplementedError + + @abc.abstractmethod + def write(self, data): + """ + Write data to the process + """ + raise NotImplementedError + + @abc.abstractmethod + def poll(self): + """ + Check if the process is alive. If the process is alive, None will be + returned. If the process exited normally, the return code will be + returned. If the process died from a signal, the negative value of the + signal number will be returned. + """ + raise NotImplementedError + + @abc.abstractmethod + def stop(self): + """ + Stops the child process + """ + raise NotImplementedError + + @abc.abstractmethod + def expect(self, pattern, *, timeout=-1): + """ + Seeks through the process output until a pattern is matched. See + pexpect.spawn.expect() + """ + raise NotImplementedError + + @abc.abstractmethod + def wait(self): + """ + Wait for the process to exit. Note that this may wait forever if the + process generates output that is not read + """ + raise NotImplementedError + + @abc.abstractmethod + def sendcontrol(self, char): + """ + Helper method that wraps write() with mnemonic access for sending + control character to the child + """ + raise NotImplementedError + + class CommandProtocol(abc.ABC): """Abstract class for the CommandProtocol""" @@ -18,6 +87,16 @@ def run_check(self, command: str): """ raise NotImplementedError + @abc.abstractmethod + def start_process(self, cmd: str): + """ + Start a new command process. Returns CommandProcessProtocol. + + If another command process is already running and the driver doesn't + support multiple at the same time, CommandProcessBusy will be raised + """ + raise NotImplementedError + @abc.abstractmethod def get_status(self): """ diff --git a/labgrid/util/__init__.py b/labgrid/util/__init__.py index 042162c72..f5e4994ef 100644 --- a/labgrid/util/__init__.py +++ b/labgrid/util/__init__.py @@ -2,7 +2,7 @@ from .dict import diff_dict, flat_dict, filter_dict, find_dict from .expect import PtxExpect from .timeout import Timeout -from .marker import gen_marker +from .marker import gen_marker, ConsoleMarkerProcess from .yaml import load, dump from .ssh import sshmanager from .helper import get_free_port, get_user, re_vt100 diff --git a/labgrid/util/marker.py b/labgrid/util/marker.py index f3b1f42d3..453bc7bf6 100644 --- a/labgrid/util/marker.py +++ b/labgrid/util/marker.py @@ -1,5 +1,12 @@ import random import string +import re + +from pexpect import EOF, TIMEOUT, spawn + +from ..protocol import CommandProcessProtocol +from ..step import step +from ..util.readmixin import ReadMixIn # Remove RID to avoid markers containing substrings like ERROR, FAIL, WARN, INFO or DEBUG @@ -7,3 +14,165 @@ def gen_marker(): return ''.join(random.choice(MARKER_POOL) for i in range(10)) + + +class ConsoleProcessExpect(spawn): + """labgrid Wrapper of the pexpect module. + + This class provides pexpect functionality for the ConsoleMarkerProcess + classes. This allows users use the expect API on the output of a process + without needing to worry about if the console is using markers (e.g. + ShellDriver, UBootDriver). This means that these drivers behave the same as + drivers that directly use spawn (e.g. SSHDriver), even though they are + using markers. + """ + + def __init__(self, process, timeout): + "Initializes a pexpect spawn instanse with required configuration" + super().__init__(None, timeout=timeout) + self.process = process + + def send(self, s): + self.process.write(s) + + def read_nonblocking(self, size=1, timeout=-1): + return self.process.read(size, timeout) + + def sendcontrol(self, char): + self.process.sendcontrol(char) + + +class ConsoleMarkerProcess(CommandProcessProtocol, ReadMixIn): + def __init__(self, console, marker, prompt, *, encoding="utf-8", timeout=30.0, on_exit=None): + self._console = console + self._alive = True + self.exitcode = None + self._on_exit = on_exit + + # Build up the Regex to capture output from the command. The regex will + # only capture a single character from the output, and only if the + # current buffer doesn't start with any prefix of the output buffer + # (using negative lookahead). Partial prefixes are anchored to then end + # of the string so as soon as it is clear that the current first + # character of the output buffer can't possibly be part of the output + # marker, it will be consumed. + partials = [] + p = "" + for m in marker[:-1]: + p = p + m + partials.append(r"{}$".format(p)) + + # Add the complete marker, which doesn't need to be anchored to the end + # of the string to prevent a character from being consumed. + partials.append(marker) + partials.reverse() + + self._output_re = re.compile( + r"^{}.".format("".join(r"(?!{})".format(p) for p in partials)).encode( + encoding + ), + re.DOTALL, + ) + + self._prompt_re = re.compile(prompt.encode(encoding)) + + self._eof_re = re.compile( + r"^{marker}\s+(\d+)\s+.*{prompt}".format( + marker=marker, prompt=prompt + ).encode(encoding) + ) + + self._expect = ConsoleProcessExpect(self, timeout=timeout) + + def _handle_eof(self, code): + self.exitcode = code + self._alive = False + if self._on_exit: + self._on_exit(self) + + def read(self, size=1, timeout=-1): + # Wait up to timeout for the first byte of data + buf = self._read_byte(timeout) + + # Read as much remaining data as possible without blocking + while size < 0 or len(buf) < size: + try: + buf = buf + self._read_byte(0) + except (TIMEOUT, EOF): + break + + return buf + + def _read_byte(self, timeout): + if not self._alive: + raise EOF("ConsoleMarkerProcess end-of-file") + + index, _, match, after = self._console.expect_no_step( + [self._output_re, self._eof_re, TIMEOUT], + timeout=timeout, + ) + if index == 0: + return after + elif index == 1: + self._handle_eof(int(match.group(1))) + raise EOF("ConsoleMarkerProcess end-of-file") + elif index == 2: + raise TIMEOUT("ConsoleMarkerProcess timeout") + + @step(args=["data"]) + def write(self, data): + if self._alive: + return self._console.write(data) + return 0 + + @step(result=True) + def poll(self): + if not self._alive: + return self.exitcode + + index, _, match, _ = self._console.expect([self._eof_re, TIMEOUT], timeout=1) + + if index == 0: + self._handle_eof(int(match.group(1))) + return self.exitcode + + return None + + @step(result=True) + def stop(self): + if self._alive: + self._console.sendcontrol("c") + # Not all shells will emit the marker and exit code on interrupt. + # Check for a bare prompt in addition to EOF for these shells + index, _, _, _ = self.expect([EOF, self._prompt_re], timeout=60) + if index == 1: + # If a prompt was seen with no marker, emulate the exit code + # for "died by SIGINT" (130) + self._handle_eof(130) + + @step(args=["pattern", "timeout"], result=True) + def expect(self, pattern, *, timeout=-1): + return ( + self._expect.expect(pattern, timeout=timeout), + self._expect.before, + self._expect.match, + self._expect.after, + ) + + @step(result=True) + def wait(self): + while True: + index, _, _, _ = self.expect([EOF, TIMEOUT]) + if index == 0: + break + + @step(args=["char"]) + def sendcontrol(self, char): + self._console.sendcontrol(char) + + def __enter__(self): + return self + + def __exit__(self, typ, value, traceback): + self.stop() + return False diff --git a/labgrid/util/readmixin.py b/labgrid/util/readmixin.py new file mode 100644 index 000000000..b01f60bc3 --- /dev/null +++ b/labgrid/util/readmixin.py @@ -0,0 +1,81 @@ +from pexpect import EOF, TIMEOUT +from .timeout import Timeout + +class ReadMixIn: + """ + This class make it more convenient to deal with reading from devices. A + typical read() will either return immediately if there is buffered data, or + wait up to some timeout for data to appear, then return whatever happens to + be buffered at that time. In both of these cases, less data that requested + may be returned before the timeout expires, even if no EOF is encountered. + + The function in this class help deal with sources like this by trying harder + to read data from the device. + + Inspired by the function of the same name present in Rust + """ + def read(self, size, timeout): + """ + Stub for mixin. Must be implemented by subclass. + """ + raise NotImplementedError + + def read_full(self, size=-1, *, timeout=30): + """ + Reads bytes until either size bytes have been read, timeout seconds + have elapsed, or EOF is encountered. Returns as many bytes as were read + until that happens. + + If size is -1, read as much data as possible until either the timeout + or EOF is encountered. + """ + t = Timeout(timeout) + buf = b"" + + while not t.expired: + read_size = size - len(buf) if size >= 0 else 64 + if read_size <= 0: + break + + try: + buf += self.read(read_size, t.remaining) + except EOF: + break + except TIMEOUT: + pass + + return buf + + def read_to_end(self, *, timeout=30): + """ + Read until EOF is encountered and return the resulting data. If the + timeout expires before EOF, a TIMEOUT error is raised + + If an exception is raised, any data read is lost + """ + t = Timeout(timeout) + buf = b"" + + while True: + try: + buf += self.read(64, t.remaining) + except EOF: + break + + return buf + + def read_exact(self, size, *, timeout=30): + """ + Read exactly size bytes. If the timeout elapses before size bytes are + read, a TIMEOUT error is raised. If EOF is encountered before size + bytes are read, an EOF error is raised + + If an exception is raised, any data read is lost + """ + t = Timeout(timeout) + buf = b"" + + while len(buf) < size: + buf += self.read(size - len(buf), t.remaining) + + return buf diff --git a/labgrid/util/timeout.py b/labgrid/util/timeout.py index d1d792efa..2ebc167d9 100644 --- a/labgrid/util/timeout.py +++ b/labgrid/util/timeout.py @@ -7,7 +7,7 @@ class Timeout: """Reperents a timeout (as a deadline)""" timeout = attr.ib( - default=120.0, validator=attr.validators.instance_of(float) + default=120.0, validator=attr.validators.instance_of((float, int)) ) def __attrs_post_init__(self): diff --git a/tests/test_marker.py b/tests/test_marker.py new file mode 100644 index 000000000..827d15658 --- /dev/null +++ b/tests/test_marker.py @@ -0,0 +1,198 @@ +import pytest +import attr +import logging +from pexpect import TIMEOUT, EOF + +from labgrid.factory import target_factory +from labgrid.util import ConsoleMarkerProcess +from labgrid.protocol import ConsoleProtocol +from labgrid.driver import Driver +from labgrid.driver.consoleexpectmixin import ConsoleExpectMixin + + +@target_factory.reg_driver +@attr.s(eq=False) +class EchoConsoleDriver(ConsoleExpectMixin, Driver, ConsoleProtocol): + prompt = attr.ib(validator=attr.validators.instance_of(str)) + marker = attr.ib(default="ABC", validator=attr.validators.instance_of(str)) + txdelay = attr.ib(default=0.0, validator=attr.validators.instance_of(float)) + timeout = attr.ib(default=1.0, validator=attr.validators.instance_of(float)) + + def __attrs_post_init__(self): + super().__attrs_post_init__() + self.logger = logging.getLogger("{}({})".format(self, self.target)) + self.buffer = b"" + + def _read(self, size=-1, timeout=0.0, max_size=None): + if not self.buffer: + raise TIMEOUT("Timeout reading data") + elif size < 0 or size >= len(self.buffer): + data = self.buffer + self.buffer = b"" + else: + data = self.buffer[:size] + self.buffer = self.buffer[size:] + print("READ: %r %r" % (data, self.buffer)) + return data + + def _write(self, data, *_): + if b"\x03" in data: + a, b = data.split(b"\x03", 1) + self.buffer = self.buffer + a + self._end_process_string(-1) + b + else: + self.buffer = self.buffer + data + print("WRITE: %r" % self.buffer) + + def _end_process_string(self, retcode): + return "{}\n{}\n{}".format(self.marker, retcode, self.prompt).encode("utf-8") + + def open(self): + pass + + def close(self): + pass + + def end_process(self, retcode): + self.buffer = self.buffer + self._end_process_string(retcode) + + +@pytest.fixture(scope="function") +def console(target): + d = EchoConsoleDriver(target, "console", prompt="PROMPT>") + target.activate(d) + return d + + +@pytest.fixture(scope="function") +def process(console): + return ConsoleMarkerProcess(console, console.marker, console.prompt, timeout=0.1) + + +def test_create(console): + ConsoleMarkerProcess(console, console.marker, console.prompt) + + +def test_partial_read(console, process): + process.write(b"Hello World") + console.end_process(0) + + assert process.read_full(6, timeout=0.1) == b"Hello " + assert process.poll() is None + + assert process.read_full(5) == b"World" + assert process.poll() == 0 + + assert process.read_full(1, timeout=0.1) == b"" + with pytest.raises(EOF): + process.read(1) + + +def test_read_timeout(console, process): + process.write(b"Hello World") + + assert process.read_full(100, timeout=0.1) == b"Hello World" + + process.read_full(100, timeout=0.1) == b"" + + with pytest.raises(TIMEOUT): + process.read(100) + + process.write(b"Hello World") + + with pytest.raises(TIMEOUT): + process.expect("Never found", timeout=0.1) + + +def test_expect(console, process): + process.write(b"Hello World") + console.end_process(0) + + index, before, match, after = process.expect([b"Hello", "World"], timeout=0.1) + assert index == 0 + assert before == b"" + assert match.group(0) == b"Hello" + assert after == b"Hello" + + index, before, match, after = process.expect([b"Hello", "World"], timeout=0.1) + assert index == 1 + assert before == b" " + assert match.group(0) == b"World" + assert after == b"World" + + with pytest.raises(EOF): + process.expect([b"Hello", "World"], timeout=0.1) + + index, before, match, after = process.expect([r"Hello", r"World", EOF], timeout=0.1) + assert index == 2 + assert before == b"" + + +def test_marker(console, process): + for i in range(len(console.marker) - 1): + partial_marker = console.marker[: i + 1].encode("utf-8") + process.write(partial_marker) + # Any match on a partial marker should not consume any output until we + # are certain it's not actually a marker + assert process.read_full(100, timeout=0.1) == b"" + + with pytest.raises(TIMEOUT): + process.read(100) + + with pytest.raises(TIMEOUT): + process.expect(r".+", timeout=0.1) + + assert process.expect([r".+", TIMEOUT], timeout=0.1) == (1, b"", TIMEOUT, TIMEOUT) + + # Write the partial marker again, which means the original partial can + # be returned since it is guaranteed to not be a match for the complete + # marker + process.write(partial_marker) + assert process.read_full(100, timeout=0.1) == partial_marker + + # Do it again for read_nonblocking + process.write(partial_marker) + assert process.read(100) == partial_marker + + # Do it again for expect + process.write(partial_marker) + index, before, match, after = process.expect([r".+"], timeout=0.1) + assert index == 0 + assert before == b"" + assert match.group(0) == partial_marker + assert after == partial_marker + + # Write a space, which makes the last partial no longer match the marker + process.write(b" ") + assert process.read_full(100, timeout=0.1) == partial_marker + b" " + + +def test_read_to_end(console, process): + process.write(b"Hello World") + console.end_process(0) + + process.read_to_end() == b"Hello World" + process.read_to_end() == b"" + + +def test_read_to_end_timeout(console, process): + process.write(b"Hello World") + with pytest.raises(TIMEOUT): + process.read_to_end(timeout=0.1) + + console.end_process(0) + assert process.read_to_end() == b"" + + +def test_read_exact(console, process): + process.write(b"Hello World") + assert process.read_exact(6) == b"Hello " + process.write(b" Exact") + assert process.read_exact(11) == b"World Exact" + + with pytest.raises(TIMEOUT): + process.read_exact(1) + + console.end_process(0) + + with pytest.raises(EOF): + process.read_exact(1) diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index 4c233a834..dfaf47ac7 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -1,5 +1,6 @@ import pytest import socket +from pexpect import TIMEOUT, EOF from labgrid import Environment from labgrid.driver import SSHDriver, ExecutionError @@ -218,3 +219,92 @@ def test_unix_socket_forward(ssh_localhost, tmpdir): send_socket.send(test_string.encode("utf-8")) assert client_socket.recv(16).decode("utf-8") == test_string + +@pytest.mark.sshusername +def test_start_process_simple(ssh_localhost): + with ssh_localhost.start_process("echo Hello World") as p: + assert p.read_full(6, timeout=10.0) == b"Hello " + assert p.read(7) == b"World\r\n" + assert p.read_full(1, timeout=10.0) == b"" + assert p.read_full(timeout=10.0) == b"" + with pytest.raises(EOF): + p.read(10) + + with ssh_localhost.start_process("echo Hello World") as p: + p.expect("Hello World") + p.expect(EOF) + with pytest.raises(EOF): + p.expect(r".") + + +@pytest.mark.sshusername +def test_start_process_timeout(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + with pytest.raises(TIMEOUT): + p.read(100, timeout=5) == b"" + + with pytest.raises(TIMEOUT): + p.expect("Never found", timeout=5) + +@pytest.mark.sshusername +def test_start_process_stream(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + data = p.read_full(timeout=10.0) + data = data.decode("utf-8").splitlines() + # Only one hello world expected. Input echo is disabled + assert data == ["Hello World"] + + assert p.poll() is None + p.sendcontrol('d') + p.expect(EOF) + + assert p.poll() == 0 + + with ssh_localhost.start_process("cat") as p: + p.write(b"ABCDEF\n") + + p.expect(b"ABCDEF", timeout=10.0) + + assert p.poll() is None + p.stop() + + code = p.poll() + assert code is not None + assert code != 0 + + +@pytest.mark.sshusername +def test_start_process_read_to_end(ssh_localhost): + with ssh_localhost.start_process("echo Hello World") as p: + p.read_to_end() == b"Hello World" + p.read_to_end() == b"" + + +@pytest.mark.sshusername +def test_start_process_read_to_end_timeout(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + with pytest.raises(TIMEOUT): + p.read_to_end(timeout=5) + + p.sendcontrol('d') + assert p.read_to_end() == b"" + + +@pytest.mark.sshusername +def test_start_process_read_exact(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + assert p.read_exact(6) == b"Hello " + assert p.read_exact(5) == b"World" + + # Discard the remaining data (echoed data) + p.read_full(-1, timeout=1) + + with pytest.raises(TIMEOUT): + p.read_exact(1, timeout=5) + p.sendcontrol('d') + + with pytest.raises(EOF): + p.read_exact(1, timeout=5) diff --git a/tests/test_timeout.py b/tests/test_timeout.py index 101a88f36..8c2b6c0bb 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -9,12 +9,17 @@ def test_create(self): assert (isinstance(t, Timeout)) t = Timeout(5.0) assert (isinstance(t, Timeout)) + t = Timeout(5) + assert (isinstance(t, Timeout)) + with pytest.raises(TypeError): - t = Timeout(10) + t = Timeout('123') with pytest.raises(ValueError): t = Timeout(-1.0) + with pytest.raises(ValueError): + t = Timeout(-1) - def test_expire(self, mocker): + def test_expire_float(self, mocker): m = mocker.patch('time.monotonic') m.return_value = 0.0 @@ -29,3 +34,19 @@ def test_expire(self, mocker): m.return_value += 3.0 assert t.expired assert t.remaining == 0.0 + + def test_expire_int(self, mocker): + m = mocker.patch('time.monotonic') + m.return_value = 0.0 + + t = Timeout(5) + assert not t.expired + assert t.remaining == 5.0 + + m.return_value += 3.0 + assert not t.expired + assert t.remaining == 2.0 + + m.return_value += 3.0 + assert t.expired + assert t.remaining == 0.0