From 32fbf691767e2410a94147c0ca73c8ac2adba894 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 16:53:34 -0500 Subject: [PATCH 01/11] timeout: Allow integer timeouts Allowing the class to take an integer for a timeout instead of a strict float seems more convenient to users than forcing it at a higher level, or breaking usage by changing other APIs to require floats (i.e. run_check()) Signed-off-by: Joshua Watt --- labgrid/util/timeout.py | 2 +- tests/test_timeout.py | 25 +++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/labgrid/util/timeout.py b/labgrid/util/timeout.py index d1d792efa..2ebc167d9 100644 --- a/labgrid/util/timeout.py +++ b/labgrid/util/timeout.py @@ -7,7 +7,7 @@ class Timeout: """Reperents a timeout (as a deadline)""" timeout = attr.ib( - default=120.0, validator=attr.validators.instance_of(float) + default=120.0, validator=attr.validators.instance_of((float, int)) ) def __attrs_post_init__(self): diff --git a/tests/test_timeout.py b/tests/test_timeout.py index 101a88f36..8c2b6c0bb 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -9,12 +9,17 @@ def test_create(self): assert (isinstance(t, Timeout)) t = Timeout(5.0) assert (isinstance(t, Timeout)) + t = Timeout(5) + assert (isinstance(t, Timeout)) + with pytest.raises(TypeError): - t = Timeout(10) + t = Timeout('123') with pytest.raises(ValueError): t = Timeout(-1.0) + with pytest.raises(ValueError): + t = Timeout(-1) - def test_expire(self, mocker): + def test_expire_float(self, mocker): m = mocker.patch('time.monotonic') m.return_value = 0.0 @@ -29,3 +34,19 @@ def test_expire(self, mocker): m.return_value += 3.0 assert t.expired assert t.remaining == 0.0 + + def test_expire_int(self, mocker): + m = mocker.patch('time.monotonic') + m.return_value = 0.0 + + t = Timeout(5) + assert not t.expired + assert t.remaining == 5.0 + + m.return_value += 3.0 + assert not t.expired + assert t.remaining == 2.0 + + m.return_value += 3.0 + assert t.expired + assert t.remaining == 0.0 From 6e7a1e91b2e5e3caf7a00482a912ee2bac5fb24e Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:31:05 -0500 Subject: [PATCH 02/11] commandprotocol: Add Process Protocol Extends the command protocol to support "background" processes. Signed-off-by: Joshua Watt --- labgrid/exceptions.py | 9 ++++ labgrid/protocol/__init__.py | 2 +- labgrid/protocol/commandprotocol.py | 79 +++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) diff --git a/labgrid/exceptions.py b/labgrid/exceptions.py index df49be3dd..bc45f3ff1 100644 --- a/labgrid/exceptions.py +++ b/labgrid/exceptions.py @@ -41,3 +41,12 @@ class NoStrategyFoundError(NoSupplierFoundError): @attr.s(eq=False) class RegistrationError(Exception): msg = attr.ib(validator=attr.validators.instance_of(str)) + + +@attr.s(eq=False) +class CommandProcessBusy(Exception): + """ + This exception is raised if it is not possible to execute multiple + CommandProcessProtocol at the same time. + """ + pass diff --git a/labgrid/protocol/__init__.py b/labgrid/protocol/__init__.py index 0ac225622..613ac2265 100644 --- a/labgrid/protocol/__init__.py +++ b/labgrid/protocol/__init__.py @@ -1,5 +1,5 @@ from .bootstrapprotocol import BootstrapProtocol -from .commandprotocol import CommandProtocol +from .commandprotocol import CommandProtocol, CommandProcessProtocol from .consoleprotocol import ConsoleProtocol from .linuxbootprotocol import LinuxBootProtocol from .powerprotocol import PowerProtocol diff --git a/labgrid/protocol/commandprotocol.py b/labgrid/protocol/commandprotocol.py index dff1daaf3..92e599fae 100644 --- a/labgrid/protocol/commandprotocol.py +++ b/labgrid/protocol/commandprotocol.py @@ -1,6 +1,75 @@ import abc +class CommandProcessProtocol(abc.ABC): + """Abstract class for a running command""" + + def read(self, size=1, timeout=-1): + """ + Reads up to size bytes from the remote process. If no output is + available to read, will block up to timeout seconds waiting for output, + then return any available (up to size bytes). If no output occurs + before the timeout, a TIMEOUT exception is raised. + + If there is no more output because the process has exited, raises EOF + + If timeout is -1, the default timeout value will be used. + + This operates the same as read_nonblocking() from pexpect; You may want + to look at the operations from the ReadMixIn class for more convenient + methods of reading data from remote processes + """ + raise NotImplementedError + + @abc.abstractmethod + def write(self, data): + """ + Write data to the process + """ + raise NotImplementedError + + @abc.abstractmethod + def poll(self): + """ + Check if the process is alive. If the process is alive, None will be + returned. If the process exited normally, the return code will be + returned. If the process died from a signal, the negative value of the + signal number will be returned. + """ + raise NotImplementedError + + @abc.abstractmethod + def stop(self): + """ + Stops the child process + """ + raise NotImplementedError + + @abc.abstractmethod + def expect(self, pattern, *, timeout=-1): + """ + Seeks through the process output until a pattern is matched. See + pexpect.spawn.expect() + """ + raise NotImplementedError + + @abc.abstractmethod + def wait(self): + """ + Wait for the process to exit. Note that this may wait forever if the + process generates output that is not read + """ + raise NotImplementedError + + @abc.abstractmethod + def sendcontrol(self, char): + """ + Helper method that wraps write() with mnemonic access for sending + control character to the child + """ + raise NotImplementedError + + class CommandProtocol(abc.ABC): """Abstract class for the CommandProtocol""" @@ -18,6 +87,16 @@ def run_check(self, command: str): """ raise NotImplementedError + @abc.abstractmethod + def start_process(self, cmd: str): + """ + Start a new command process. Returns CommandProcessProtocol. + + If another command process is already running and the driver doesn't + support multiple at the same time, CommandProcessBusy will be raised + """ + raise NotImplementedError + @abc.abstractmethod def get_status(self): """ From cb2d152c0d149ab2ddc6197ca056ca050353c934 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Fri, 5 Nov 2021 11:05:17 -0500 Subject: [PATCH 03/11] util: Add readmixin class Adds a read mix in class that implements read helper functions to deal with reading from sources that will return only buffered data, instead of trying harder to read up to the user requested size. These functions are modeled after the functions provided by Rust Signed-off-by: Joshua Watt --- labgrid/util/readmixin.py | 81 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 labgrid/util/readmixin.py diff --git a/labgrid/util/readmixin.py b/labgrid/util/readmixin.py new file mode 100644 index 000000000..b01f60bc3 --- /dev/null +++ b/labgrid/util/readmixin.py @@ -0,0 +1,81 @@ +from pexpect import EOF, TIMEOUT +from .timeout import Timeout + +class ReadMixIn: + """ + This class make it more convenient to deal with reading from devices. A + typical read() will either return immediately if there is buffered data, or + wait up to some timeout for data to appear, then return whatever happens to + be buffered at that time. In both of these cases, less data that requested + may be returned before the timeout expires, even if no EOF is encountered. + + The function in this class help deal with sources like this by trying harder + to read data from the device. + + Inspired by the function of the same name present in Rust + """ + def read(self, size, timeout): + """ + Stub for mixin. Must be implemented by subclass. + """ + raise NotImplementedError + + def read_full(self, size=-1, *, timeout=30): + """ + Reads bytes until either size bytes have been read, timeout seconds + have elapsed, or EOF is encountered. Returns as many bytes as were read + until that happens. + + If size is -1, read as much data as possible until either the timeout + or EOF is encountered. + """ + t = Timeout(timeout) + buf = b"" + + while not t.expired: + read_size = size - len(buf) if size >= 0 else 64 + if read_size <= 0: + break + + try: + buf += self.read(read_size, t.remaining) + except EOF: + break + except TIMEOUT: + pass + + return buf + + def read_to_end(self, *, timeout=30): + """ + Read until EOF is encountered and return the resulting data. If the + timeout expires before EOF, a TIMEOUT error is raised + + If an exception is raised, any data read is lost + """ + t = Timeout(timeout) + buf = b"" + + while True: + try: + buf += self.read(64, t.remaining) + except EOF: + break + + return buf + + def read_exact(self, size, *, timeout=30): + """ + Read exactly size bytes. If the timeout elapses before size bytes are + read, a TIMEOUT error is raised. If EOF is encountered before size + bytes are read, an EOF error is raised + + If an exception is raised, any data read is lost + """ + t = Timeout(timeout) + buf = b"" + + while len(buf) < size: + buf += self.read(size - len(buf), t.remaining) + + return buf From b08a32865ee93c895a9067d4658d87a7fd4425cd Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:32:22 -0500 Subject: [PATCH 04/11] marker: Add ConsoleMarkerProcess helper class Many different drivers need to implement "marker" style command processes where the command output is terminated with a random marker then the command status code and a prompt. Instead of making each driver implement this logic over and over again, implement a common class that handles this type of command output. Signed-off-by: Joshua Watt --- labgrid/driver/consoleexpectmixin.py | 8 ++ labgrid/util/__init__.py | 2 +- labgrid/util/marker.py | 169 +++++++++++++++++++++++ tests/test_marker.py | 198 +++++++++++++++++++++++++++ 4 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 tests/test_marker.py diff --git a/labgrid/driver/consoleexpectmixin.py b/labgrid/driver/consoleexpectmixin.py index 9bb1cc15e..0534d2384 100644 --- a/labgrid/driver/consoleexpectmixin.py +++ b/labgrid/driver/consoleexpectmixin.py @@ -53,6 +53,14 @@ def sendline(self, line): def sendcontrol(self, char): self._expect.sendcontrol(char) + @Driver.check_active + def expect_no_step(self, pattern, timeout=-1): + """ + Expect a pattern without logging the output. + """ + index = self._expect.expect(pattern, timeout=timeout) + return index, self._expect.before, self._expect.match, self._expect.after + @Driver.check_active @step(args=['pattern'], result=True) def expect(self, pattern, timeout=-1): diff --git a/labgrid/util/__init__.py b/labgrid/util/__init__.py index 042162c72..f5e4994ef 100644 --- a/labgrid/util/__init__.py +++ b/labgrid/util/__init__.py @@ -2,7 +2,7 @@ from .dict import diff_dict, flat_dict, filter_dict, find_dict from .expect import PtxExpect from .timeout import Timeout -from .marker import gen_marker +from .marker import gen_marker, ConsoleMarkerProcess from .yaml import load, dump from .ssh import sshmanager from .helper import get_free_port, get_user, re_vt100 diff --git a/labgrid/util/marker.py b/labgrid/util/marker.py index f3b1f42d3..453bc7bf6 100644 --- a/labgrid/util/marker.py +++ b/labgrid/util/marker.py @@ -1,5 +1,12 @@ import random import string +import re + +from pexpect import EOF, TIMEOUT, spawn + +from ..protocol import CommandProcessProtocol +from ..step import step +from ..util.readmixin import ReadMixIn # Remove RID to avoid markers containing substrings like ERROR, FAIL, WARN, INFO or DEBUG @@ -7,3 +14,165 @@ def gen_marker(): return ''.join(random.choice(MARKER_POOL) for i in range(10)) + + +class ConsoleProcessExpect(spawn): + """labgrid Wrapper of the pexpect module. + + This class provides pexpect functionality for the ConsoleMarkerProcess + classes. This allows users use the expect API on the output of a process + without needing to worry about if the console is using markers (e.g. + ShellDriver, UBootDriver). This means that these drivers behave the same as + drivers that directly use spawn (e.g. SSHDriver), even though they are + using markers. + """ + + def __init__(self, process, timeout): + "Initializes a pexpect spawn instanse with required configuration" + super().__init__(None, timeout=timeout) + self.process = process + + def send(self, s): + self.process.write(s) + + def read_nonblocking(self, size=1, timeout=-1): + return self.process.read(size, timeout) + + def sendcontrol(self, char): + self.process.sendcontrol(char) + + +class ConsoleMarkerProcess(CommandProcessProtocol, ReadMixIn): + def __init__(self, console, marker, prompt, *, encoding="utf-8", timeout=30.0, on_exit=None): + self._console = console + self._alive = True + self.exitcode = None + self._on_exit = on_exit + + # Build up the Regex to capture output from the command. The regex will + # only capture a single character from the output, and only if the + # current buffer doesn't start with any prefix of the output buffer + # (using negative lookahead). Partial prefixes are anchored to then end + # of the string so as soon as it is clear that the current first + # character of the output buffer can't possibly be part of the output + # marker, it will be consumed. + partials = [] + p = "" + for m in marker[:-1]: + p = p + m + partials.append(r"{}$".format(p)) + + # Add the complete marker, which doesn't need to be anchored to the end + # of the string to prevent a character from being consumed. + partials.append(marker) + partials.reverse() + + self._output_re = re.compile( + r"^{}.".format("".join(r"(?!{})".format(p) for p in partials)).encode( + encoding + ), + re.DOTALL, + ) + + self._prompt_re = re.compile(prompt.encode(encoding)) + + self._eof_re = re.compile( + r"^{marker}\s+(\d+)\s+.*{prompt}".format( + marker=marker, prompt=prompt + ).encode(encoding) + ) + + self._expect = ConsoleProcessExpect(self, timeout=timeout) + + def _handle_eof(self, code): + self.exitcode = code + self._alive = False + if self._on_exit: + self._on_exit(self) + + def read(self, size=1, timeout=-1): + # Wait up to timeout for the first byte of data + buf = self._read_byte(timeout) + + # Read as much remaining data as possible without blocking + while size < 0 or len(buf) < size: + try: + buf = buf + self._read_byte(0) + except (TIMEOUT, EOF): + break + + return buf + + def _read_byte(self, timeout): + if not self._alive: + raise EOF("ConsoleMarkerProcess end-of-file") + + index, _, match, after = self._console.expect_no_step( + [self._output_re, self._eof_re, TIMEOUT], + timeout=timeout, + ) + if index == 0: + return after + elif index == 1: + self._handle_eof(int(match.group(1))) + raise EOF("ConsoleMarkerProcess end-of-file") + elif index == 2: + raise TIMEOUT("ConsoleMarkerProcess timeout") + + @step(args=["data"]) + def write(self, data): + if self._alive: + return self._console.write(data) + return 0 + + @step(result=True) + def poll(self): + if not self._alive: + return self.exitcode + + index, _, match, _ = self._console.expect([self._eof_re, TIMEOUT], timeout=1) + + if index == 0: + self._handle_eof(int(match.group(1))) + return self.exitcode + + return None + + @step(result=True) + def stop(self): + if self._alive: + self._console.sendcontrol("c") + # Not all shells will emit the marker and exit code on interrupt. + # Check for a bare prompt in addition to EOF for these shells + index, _, _, _ = self.expect([EOF, self._prompt_re], timeout=60) + if index == 1: + # If a prompt was seen with no marker, emulate the exit code + # for "died by SIGINT" (130) + self._handle_eof(130) + + @step(args=["pattern", "timeout"], result=True) + def expect(self, pattern, *, timeout=-1): + return ( + self._expect.expect(pattern, timeout=timeout), + self._expect.before, + self._expect.match, + self._expect.after, + ) + + @step(result=True) + def wait(self): + while True: + index, _, _, _ = self.expect([EOF, TIMEOUT]) + if index == 0: + break + + @step(args=["char"]) + def sendcontrol(self, char): + self._console.sendcontrol(char) + + def __enter__(self): + return self + + def __exit__(self, typ, value, traceback): + self.stop() + return False diff --git a/tests/test_marker.py b/tests/test_marker.py new file mode 100644 index 000000000..827d15658 --- /dev/null +++ b/tests/test_marker.py @@ -0,0 +1,198 @@ +import pytest +import attr +import logging +from pexpect import TIMEOUT, EOF + +from labgrid.factory import target_factory +from labgrid.util import ConsoleMarkerProcess +from labgrid.protocol import ConsoleProtocol +from labgrid.driver import Driver +from labgrid.driver.consoleexpectmixin import ConsoleExpectMixin + + +@target_factory.reg_driver +@attr.s(eq=False) +class EchoConsoleDriver(ConsoleExpectMixin, Driver, ConsoleProtocol): + prompt = attr.ib(validator=attr.validators.instance_of(str)) + marker = attr.ib(default="ABC", validator=attr.validators.instance_of(str)) + txdelay = attr.ib(default=0.0, validator=attr.validators.instance_of(float)) + timeout = attr.ib(default=1.0, validator=attr.validators.instance_of(float)) + + def __attrs_post_init__(self): + super().__attrs_post_init__() + self.logger = logging.getLogger("{}({})".format(self, self.target)) + self.buffer = b"" + + def _read(self, size=-1, timeout=0.0, max_size=None): + if not self.buffer: + raise TIMEOUT("Timeout reading data") + elif size < 0 or size >= len(self.buffer): + data = self.buffer + self.buffer = b"" + else: + data = self.buffer[:size] + self.buffer = self.buffer[size:] + print("READ: %r %r" % (data, self.buffer)) + return data + + def _write(self, data, *_): + if b"\x03" in data: + a, b = data.split(b"\x03", 1) + self.buffer = self.buffer + a + self._end_process_string(-1) + b + else: + self.buffer = self.buffer + data + print("WRITE: %r" % self.buffer) + + def _end_process_string(self, retcode): + return "{}\n{}\n{}".format(self.marker, retcode, self.prompt).encode("utf-8") + + def open(self): + pass + + def close(self): + pass + + def end_process(self, retcode): + self.buffer = self.buffer + self._end_process_string(retcode) + + +@pytest.fixture(scope="function") +def console(target): + d = EchoConsoleDriver(target, "console", prompt="PROMPT>") + target.activate(d) + return d + + +@pytest.fixture(scope="function") +def process(console): + return ConsoleMarkerProcess(console, console.marker, console.prompt, timeout=0.1) + + +def test_create(console): + ConsoleMarkerProcess(console, console.marker, console.prompt) + + +def test_partial_read(console, process): + process.write(b"Hello World") + console.end_process(0) + + assert process.read_full(6, timeout=0.1) == b"Hello " + assert process.poll() is None + + assert process.read_full(5) == b"World" + assert process.poll() == 0 + + assert process.read_full(1, timeout=0.1) == b"" + with pytest.raises(EOF): + process.read(1) + + +def test_read_timeout(console, process): + process.write(b"Hello World") + + assert process.read_full(100, timeout=0.1) == b"Hello World" + + process.read_full(100, timeout=0.1) == b"" + + with pytest.raises(TIMEOUT): + process.read(100) + + process.write(b"Hello World") + + with pytest.raises(TIMEOUT): + process.expect("Never found", timeout=0.1) + + +def test_expect(console, process): + process.write(b"Hello World") + console.end_process(0) + + index, before, match, after = process.expect([b"Hello", "World"], timeout=0.1) + assert index == 0 + assert before == b"" + assert match.group(0) == b"Hello" + assert after == b"Hello" + + index, before, match, after = process.expect([b"Hello", "World"], timeout=0.1) + assert index == 1 + assert before == b" " + assert match.group(0) == b"World" + assert after == b"World" + + with pytest.raises(EOF): + process.expect([b"Hello", "World"], timeout=0.1) + + index, before, match, after = process.expect([r"Hello", r"World", EOF], timeout=0.1) + assert index == 2 + assert before == b"" + + +def test_marker(console, process): + for i in range(len(console.marker) - 1): + partial_marker = console.marker[: i + 1].encode("utf-8") + process.write(partial_marker) + # Any match on a partial marker should not consume any output until we + # are certain it's not actually a marker + assert process.read_full(100, timeout=0.1) == b"" + + with pytest.raises(TIMEOUT): + process.read(100) + + with pytest.raises(TIMEOUT): + process.expect(r".+", timeout=0.1) + + assert process.expect([r".+", TIMEOUT], timeout=0.1) == (1, b"", TIMEOUT, TIMEOUT) + + # Write the partial marker again, which means the original partial can + # be returned since it is guaranteed to not be a match for the complete + # marker + process.write(partial_marker) + assert process.read_full(100, timeout=0.1) == partial_marker + + # Do it again for read_nonblocking + process.write(partial_marker) + assert process.read(100) == partial_marker + + # Do it again for expect + process.write(partial_marker) + index, before, match, after = process.expect([r".+"], timeout=0.1) + assert index == 0 + assert before == b"" + assert match.group(0) == partial_marker + assert after == partial_marker + + # Write a space, which makes the last partial no longer match the marker + process.write(b" ") + assert process.read_full(100, timeout=0.1) == partial_marker + b" " + + +def test_read_to_end(console, process): + process.write(b"Hello World") + console.end_process(0) + + process.read_to_end() == b"Hello World" + process.read_to_end() == b"" + + +def test_read_to_end_timeout(console, process): + process.write(b"Hello World") + with pytest.raises(TIMEOUT): + process.read_to_end(timeout=0.1) + + console.end_process(0) + assert process.read_to_end() == b"" + + +def test_read_exact(console, process): + process.write(b"Hello World") + assert process.read_exact(6) == b"Hello " + process.write(b" Exact") + assert process.read_exact(11) == b"World Exact" + + with pytest.raises(TIMEOUT): + process.read_exact(1) + + console.end_process(0) + + with pytest.raises(EOF): + process.read_exact(1) From 37c63a5fd8e1df86318ccc1629d2a627a0938f43 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:34:56 -0500 Subject: [PATCH 05/11] bareboxdriver: Implement start_process() Implements the start_process() API for "background" processes using the ConsoleMarkerProcess helper class. Rework the _run() API to use a background process behind the scenes Signed-off-by: Joshua Watt --- labgrid/driver/bareboxdriver.py | 68 ++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/labgrid/driver/bareboxdriver.py b/labgrid/driver/bareboxdriver.py index d87b5e0aa..61ccc78c3 100644 --- a/labgrid/driver/bareboxdriver.py +++ b/labgrid/driver/bareboxdriver.py @@ -1,12 +1,14 @@ import shlex +from contextlib import contextmanager import attr from pexpect import TIMEOUT +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, LinuxBootProtocol from ..step import step -from ..util import gen_marker, Timeout, re_vt100 +from ..util import gen_marker, Timeout, re_vt100, ConsoleMarkerProcess from .common import Driver from .commandmixin import CommandMixin @@ -44,6 +46,7 @@ def __attrs_post_init__(self): self._status = 0 # barebox' default log level, used as fallback if no log level can be saved self.saved_log_level = 7 + self._process = None def on_activate(self): """Activate the BareboxDriver @@ -59,6 +62,34 @@ def on_deactivate(self): Simply sets the internal status to 0 """ self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" + + @contextmanager + def _start_process(self, cmd: str, *, adjust_log_level: bool = True): + if self._process is not None: + raise CommandProcessBusy() + + # FIXME: use codec, decodeerrors + marker = gen_marker() + # hide marker from expect + hidden_marker = f'"{marker[:4]}""{marker[4:]}"' + # generate command with marker and log level adjustment + cmp_command = f'echo -o /cmd {shlex.quote(cmd)}; echo {hidden_marker};' + if self.saved_log_level and adjust_log_level: + cmp_command += f' global.loglevel={self.saved_log_level};' + cmp_command += f' sh /cmd; echo {hidden_marker} $?;' + if self.saved_log_level and adjust_log_level: + cmp_command += ' global.loglevel=0;' + + self.console.sendline(cmp_command) + self.console.expect(marker) + + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None @Driver.check_active @step(args=['cmd']) @@ -76,32 +107,23 @@ def _run(self, cmd: str, *, timeout: int = 30, adjust_log_level: bool = True, co Returns: Tuple[List[str],List[str], int]: if successful, None otherwise """ - # FIXME: use codec, decodeerrors - marker = gen_marker() - # hide marker from expect - hidden_marker = f'"{marker[:4]}""{marker[4:]}"' - # generate command with marker and log level adjustment - cmp_command = f'echo -o /cmd {shlex.quote(cmd)}; echo {hidden_marker};' - if self.saved_log_level and adjust_log_level: - cmp_command += f' global.loglevel={self.saved_log_level};' - cmp_command += f' sh /cmd; echo {hidden_marker} $?;' - if self.saved_log_level and adjust_log_level: - cmp_command += ' global.loglevel=0;' - if self._status == 1: - self.console.sendline(cmp_command) - _, _, match, _ = self.console.expect( - rf'{marker}(.*){marker}\s+(\d+)\s+.*{self.prompt}', - timeout=timeout) - # Remove VT100 Codes and split by newline - data = re_vt100.sub('', match.group(1).decode('utf-8')).split('\r\n')[1:-1] - self.logger.debug("Received Data: %s", data) - # Get exit code - exitcode = int(match.group(2)) - return (data, [], exitcode) + with self._start_process(cmd, adjust_log_level=adjust_log_level) as p: + output = p.read_to_end(timeout=timeout) + # Remove VT100 Codes and split by newline + data = re_vt100.sub('', output.decode('utf-8')).split('\r\n')[1:-1] + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) return None + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + @Driver.check_active @step() def reset(self): From 5583db268412703378dafd14da0c37bbda285adb Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:34:56 -0500 Subject: [PATCH 06/11] shelldriver: Implement start_process() Implements the start_process() API for "background" processes using the ConsoleMarkerProcess helper class. Rework the _run() API to use a background process behind the scenes Signed-off-by: Joshua Watt [rouven.czerwinski@linaro.org: adjusted to UBoot timestamp removal] Signed-off-by: Rouven Czerwinski --- labgrid/driver/shelldriver.py | 61 ++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/labgrid/driver/shelldriver.py b/labgrid/driver/shelldriver.py index 7d476f536..a72aa7a0d 100644 --- a/labgrid/driver/shelldriver.py +++ b/labgrid/driver/shelldriver.py @@ -5,15 +5,17 @@ import re import shlex import ipaddress +from contextlib import contextmanager import attr from pexpect import TIMEOUT import xmodem +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, FileTransferProtocol from ..step import step -from ..util import gen_marker, Timeout, re_vt100 +from ..util import gen_marker, Timeout, re_vt100, ConsoleMarkerProcess from .commandmixin import CommandMixin from .common import Driver from .exception import ExecutionError @@ -61,6 +63,7 @@ def __attrs_post_init__(self): self._xmodem_cached_rx_cmd = "" self._xmodem_cached_sx_cmd = "" + self._process = None def on_activate(self): if self._status == 0: @@ -76,38 +79,58 @@ def on_activate(self): def on_deactivate(self): self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" - def _run(self, cmd, *, timeout=30.0, codec="utf-8", decodeerrors="strict"): - """ - Runs the specified cmd on the shell and returns the output. + @contextmanager + def _start_process(self, cmd: str): + if self._process is not None: + raise CommandProcessBusy() - Arguments: - cmd - cmd to run on the shell - """ # FIXME: Handle pexpect Timeout self._check_prompt() marker = gen_marker() + # hide marker from expect cmp_command = f'''MARKER='{marker[:4]}''{marker[4:]}' run {shlex.quote(cmd)}''' self.console.sendline(cmp_command) - _, _, match, _ = self.console.expect( - rf'{marker}(.*){marker}\s+(\d+)\s+{self.prompt}', - timeout=timeout - ) - # Remove VT100 Codes, split by newline and remove surrounding newline - data = re_vt100.sub('', match.group(1).decode(codec, decodeerrors)).split('\r\n') - if data and not data[-1]: - del data[-1] - self.logger.debug("Received Data: %s", data) - # Get exit code - exitcode = int(match.group(2)) - return (data, [], exitcode) + self.console.expect(marker) + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None + + def _run(self, cmd, *, timeout=30.0, codec="utf-8", decodeerrors="strict"): + """ + Runs the specified cmd on the shell and returns the output. + + Arguments: + cmd - cmd to run on the shell + """ + with self._start_process(cmd) as p: + output = p.read_to_end(timeout=timeout) + + # Remove VT100 Codes, split by newline and remove surrounding newline + data = re_vt100.sub('', output.decode(codec, decodeerrors)).split('\r\n') + if data and not data[-1]: + del data[-1] + + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) @Driver.check_active @step(args=['cmd'], result=True) def run(self, cmd, timeout=30.0, codec="utf-8", decodeerrors="strict"): return self._run(cmd, timeout=timeout, codec=codec, decodeerrors=decodeerrors) + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + @step() def _await_login(self): """Awaits the login prompt and logs the user in""" From fdd14bf853cfdf44180a7237a80b8d3f816fbad1 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:34:56 -0500 Subject: [PATCH 07/11] ubootdriver: Implement start_process() Implements the start_process() API for "background" processes using the ConsoleMarkerProcess helper class. Rework the _run() API to use a background process behind the scenes Signed-off-by: Joshua Watt --- labgrid/driver/ubootdriver.py | 60 ++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/labgrid/driver/ubootdriver.py b/labgrid/driver/ubootdriver.py index 8cef4c0ec..0acfa7bbe 100644 --- a/labgrid/driver/ubootdriver.py +++ b/labgrid/driver/ubootdriver.py @@ -1,11 +1,13 @@ """The U-Boot Module contains the UBootDriver""" import re +from contextlib import contextmanager import attr from pexpect import TIMEOUT +from ..exceptions import CommandProcessBusy from ..factory import target_factory from ..protocol import CommandProtocol, ConsoleProtocol, LinuxBootProtocol -from ..util import gen_marker, re_vt100, Timeout +from ..util import ConsoleMarkerProcess, gen_marker, re_vt100, Timeout from ..step import step from .common import Driver from .commandmixin import CommandMixin @@ -51,6 +53,7 @@ class UBootDriver(CommandMixin, Driver, CommandProtocol, LinuxBootProtocol): def __attrs_post_init__(self): super().__attrs_post_init__() self._status = 0 + self._process = None if self.boot_expression: import warnings @@ -70,32 +73,38 @@ def on_deactivate(self): Simply sets the internal status to 0 """ self._status = 0 + assert not self._process, "Deactivating while a command process is running is not allowed" + + @contextmanager + def _start_process(self, cmd: str): + if self._process is not None: + raise CommandProcessBusy() - def _run(self, cmd: str, *, timeout: int = 30, codec: str = "utf-8", decodeerrors: str = "strict"): # pylint: disable=unused-argument,line-too-long # TODO: use codec, decodeerrors # TODO: Shell Escaping for the U-Boot Shell marker = gen_marker() - cmp_command = f"""echo '{marker[:4]}''{marker[4:]}'; {cmd}; echo "$?"; echo '{marker[:4]}''{marker[4:]}';""" # pylint: disable=line-too-long - if self._status == 1: - self.console.sendline(cmp_command) - _, before, _, _ = self.console.expect(self.prompt, timeout=timeout) - # Remove VT100 Codes and split by newline - data = re_vt100.sub( - '', before.decode('utf-8'), count=1000000 - ).replace("\r", "").split("\n") - - # Strip possible U-Boot timestamps from the line - if self.strip_timestamp: - data = [re_uboot_timestamp.sub('', line) for line in data] - - self.logger.debug("Received Data: %s", data) - # Remove first element, the invoked cmd - data = data[data.index(marker) + 1:] - data = data[:data.index(marker)] - exitcode = int(data[-1]) - del data[-1] - return (data, [], exitcode) + with ConsoleMarkerProcess(self.console, marker, self.prompt) as p: + self._process = p + try: + yield p + finally: + self._process = None + def _run(self, cmd: str, *, timeout: int = 30, codec: str = "utf-8", decodeerrors: str = "strict"): # pylint: disable=unused-argument,line-too-long + if self._status == 1: + with self._start_process(cmd) as p: + output = p.read_to_end(timeout=timeout) + # Remove VT100 Codes and split by newline + data = re_vt100.sub( + '', output.decode('utf-8'), count=1000000 + ).replace("\r", "").split("\n") + + # Strip possible U-Boot timestamps from the line + if self.strip_timestamp: + data = [re_uboot_timestamp.sub('', line) for line in data] + + self.logger.debug("Received Data: %s", data) + return (data, [], p.exitcode) return None @Driver.check_active @@ -113,6 +122,13 @@ def run(self, cmd, timeout=30): """ return self._run(cmd, timeout=timeout) + @Driver.check_active + @step(args=['cmd'], result=True) + @contextmanager + def start_process(self, cmd: str): + with self._start_process(cmd) as p: + yield p + def get_status(self): """Retrieve status of the UBootDriver. 0 means inactive, 1 means active. From 0247155ab3cdb0252891fd5e7a17efcd92cb561d Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 16 Sep 2021 10:38:21 -0500 Subject: [PATCH 08/11] sshdriver: Implement start_process() Implements the start_process() API for SSH connections. The SSHDriverProcess is a wrapper around pexpect.spawn() Signed-off-by: Joshua Watt --- labgrid/driver/sshdriver.py | 85 +++++++++++++++++++++++++++++++++- tests/test_sshdriver.py | 91 +++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 2 deletions(-) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index fdee16b72..040279b9b 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -9,11 +9,13 @@ import tempfile import time from functools import cached_property +from contextlib import contextmanager import attr +import pexpect from ..factory import target_factory -from ..protocol import CommandProtocol, FileTransferProtocol +from ..protocol import CommandProtocol, CommandProcessProtocol, FileTransferProtocol from .commandmixin import CommandMixin from .common import Driver from ..step import step @@ -22,6 +24,56 @@ from ..util.proxy import proxymanager from ..util.timeout import Timeout from ..util.ssh import get_ssh_connect_timeout +from ..util.readmixin import ReadMixIn + + +class SSHDriverProcess(CommandProcessProtocol, ReadMixIn): + def __init__(self, sub): + self._sub = sub + + @property + def exitcode(self): + if self._sub.isalive(): + return None + + if self._sub.exitstatus is None: + return -self._sub.signalstatus + return self._sub.exitstatus + + def read(self, size=1, timeout=-1): + return self._sub.read_nonblocking(size, timeout) + + @step(args=['data']) + def write(self, data): + self._sub.write(data) + + @step(result=True) + def poll(self): + return self.exitcode + + @step(result=True) + def stop(self): + self._sub.close(True) + + @step(args=['pattern', 'timeout'], result=True) + def expect(self, pattern, *, timeout=-1): + index = self._sub.expect(pattern, timeout=timeout) + return index, self._sub.before, self._sub.match, self._sub.after + + @step(result=True) + def wait(self): + return self._sub.wait() + + @step(args=['char']) + def sendcontrol(self, char): + self._sub.sendcontrol(char) + + def __enter__(self): + return self + + def __exit__(self, typ, value, traceback): + self.stop() + return False @target_factory.reg_driver @@ -45,6 +97,7 @@ def __attrs_post_init__(self): self._scp = self._get_tool("scp") self._sshfs = self._get_tool("sshfs") self._rsync = self._get_tool("rsync") + self._processes = [] def _get_tool(self, name): if self.target.env: @@ -78,6 +131,7 @@ def on_activate(self): self._start_keepalive() def on_deactivate(self): + assert not self._processes, "Deactivating while a command process is running is not allowed" try: self._stop_keepalive() finally: @@ -239,6 +293,33 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): stderr.pop() return (stdout, stderr, sub.returncode) + @Driver.check_active + @step(args=['cmd']) + @contextmanager + def start_process(self, cmd: str): + if not self._check_keepalive(): + raise ExecutionError("Keepalive no longer running") + + complete_cmd = ["ssh", "-o", "LogLevel=QUIET", "-x", *self.ssh_prefix, + "-p", str(self.networkservice.port), "-l", self.networkservice.username, + self.networkservice.address, "-T", "--", '/bin/sh -c {}'.format(shlex.quote(cmd)), + ] + self.logger.debug("Sending command: %s", complete_cmd) + + try: + sub = pexpect.spawn(complete_cmd[0], complete_cmd[1:]) + except: + raise ExecutionError( + "error executing command: {}".format(complete_cmd) + ) + + with SSHDriverProcess(sub) as p: + self._processes.append(p) + try: + yield p + finally: + self._processes.remove(p) + def interact(self, cmd=None): assert cmd is None or isinstance(cmd, list) @@ -374,7 +455,7 @@ def scp(self, *, src, dst): "-o", f"ControlPath={self.control.replace('%', '%%')}", src, dst, ] - + if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode(): complete_cmd.insert(1, "-s") if self.explicit_scp_mode and self._scp_supports_explicit_scp_mode(): diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index 4c233a834..f6128581c 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -1,5 +1,6 @@ import pytest import socket +from pexpect import TIMEOUT, EOF from labgrid import Environment from labgrid.driver import SSHDriver, ExecutionError @@ -218,3 +219,93 @@ def test_unix_socket_forward(ssh_localhost, tmpdir): send_socket.send(test_string.encode("utf-8")) assert client_socket.recv(16).decode("utf-8") == test_string + +@pytest.mark.sshusername +def test_start_process_simple(ssh_localhost): + with ssh_localhost.start_process("echo Hello World") as p: + assert p.read_full(6, timeout=10.0) == b"Hello " + assert p.read(7) == b"World\r\n" + assert p.read_full(1, timeout=10.0) == b"" + assert p.read_full(timeout=10.0) == b"" + with pytest.raises(EOF): + p.read(10) + + with ssh_localhost.start_process("echo Hello World") as p: + p.expect("Hello World") + p.expect(EOF) + with pytest.raises(EOF): + p.expect(r".") + + +@pytest.mark.sshusername +def test_start_process_timeout(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + with pytest.raises(TIMEOUT): + p.read(100, timeout=5) == b"" + + with pytest.raises(TIMEOUT): + p.expect("Never found", timeout=5) + +@pytest.mark.sshusername +def test_start_process_stream(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + data = p.read_full(timeout=10.0) + data = data.decode("utf-8").splitlines() + # Two lines are expected; one for the echoed input and one for the output + assert data == ["Hello World"] * 2 + + assert p.poll() is None + p.sendcontrol('d') + p.expect(EOF) + + assert p.poll() == 0 + + with ssh_localhost.start_process("cat") as p: + p.write(b"ABCDEF\n") + + p.expect(b"ABCDEF", timeout=10.0) + p.expect(b"ABCDEF", timeout=10.0) + + assert p.poll() is None + p.stop() + + code = p.poll() + assert code is not None + assert code != 0 + + +@pytest.mark.sshusername +def test_start_process_read_to_end(ssh_localhost): + with ssh_localhost.start_process("echo Hello World") as p: + p.read_to_end() == b"Hello World" + p.read_to_end() == b"" + + +@pytest.mark.sshusername +def test_start_process_read_to_end_timeout(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + with pytest.raises(TIMEOUT): + p.read_to_end(timeout=5) + + p.sendcontrol('d') + assert p.read_to_end() == b"" + + +@pytest.mark.sshusername +def test_start_process_read_exact(ssh_localhost): + with ssh_localhost.start_process("cat") as p: + p.write(b"Hello World\n") + assert p.read_exact(6) == b"Hello " + assert p.read_exact(5) == b"World" + + # Discard the remaining data (echoed data) + p.read_full(-1, timeout=1) + + with pytest.raises(TIMEOUT): + p.read_exact(1, timeout=5) + p.sendcontrol('d') + + with pytest.raises(EOF): + p.read_exact(1, timeout=5) From ebf3e7c75b56b74a51c3bda26ffa432e8e6c5b13 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Mon, 20 Sep 2021 08:18:09 -0500 Subject: [PATCH 09/11] Update documentation for asynchronous processes Signed-off-by: Joshua Watt --- CHANGES.rst | 2 ++ doc/development.rst | 9 --------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 4bf84ef97..ad1715b41 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -592,6 +592,8 @@ New Features in 0.4.0 a jumper. - AndroidFastbootDriver now supports booting/flashing images preconfigured in the environment configuration. +- CommandProtocol now supports a start_process() method to run a process on the + target asynchronously. Bug fixes in 0.4.0 ~~~~~~~~~~~~~~~~~~ diff --git a/doc/development.rst b/doc/development.rst index 4283912dc..d08d38e8b 100644 --- a/doc/development.rst +++ b/doc/development.rst @@ -783,12 +783,3 @@ By writing these events to a file (or sqlite database) as a trace, we can collect data over multiple runs for later analysis. This would become more useful by passing recognized events (stack traces, crashes, ...) and benchmark results via the Step infrastructure. - -CommandProtocol Support for Background Processes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Currently the CommandProtocol does not support long running -processes well. -An implementation should start a new process, -return a handle and forbid running other processes in the foreground. -The handle can be used to retrieve output from a command. From b6da387b3b0e5eecb643827b67de2395f954ce0d Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 3 Jul 2025 13:48:38 -0600 Subject: [PATCH 10/11] sshdriver: Run start_process() commands in a PTY Instructs ssh to open a pseudo terminal when running start_process() commands. This allows signals to propagate to the remote process (instead of being intercepted by the ssh client process), emulating the start_process() signal behavior on other drivers. Signed-off-by: Joshua Watt --- labgrid/driver/sshdriver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 040279b9b..855d1ff4a 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -302,7 +302,7 @@ def start_process(self, cmd: str): complete_cmd = ["ssh", "-o", "LogLevel=QUIET", "-x", *self.ssh_prefix, "-p", str(self.networkservice.port), "-l", self.networkservice.username, - self.networkservice.address, "-T", "--", '/bin/sh -c {}'.format(shlex.quote(cmd)), + self.networkservice.address, "-tt", "--", '/bin/sh -c {}'.format(shlex.quote(cmd)), ] self.logger.debug("Sending command: %s", complete_cmd) From 84402d4a76c76ea73f86d7ff8bdcca052edf67c6 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Thu, 3 Jul 2025 13:48:38 -0600 Subject: [PATCH 11/11] sshdriver: Disable input echo Disables echo in the shell when using start_process() with the ssh driver. This prevents duplicate output so it behaves as the other drivers. Signed-off-by: Joshua Watt --- labgrid/driver/sshdriver.py | 2 ++ tests/test_sshdriver.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 855d1ff4a..d895e8b8c 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -300,6 +300,7 @@ def start_process(self, cmd: str): if not self._check_keepalive(): raise ExecutionError("Keepalive no longer running") + cmd = f"stty -echo; {cmd}" # Disable input echo from ssh complete_cmd = ["ssh", "-o", "LogLevel=QUIET", "-x", *self.ssh_prefix, "-p", str(self.networkservice.port), "-l", self.networkservice.username, self.networkservice.address, "-tt", "--", '/bin/sh -c {}'.format(shlex.quote(cmd)), @@ -308,6 +309,7 @@ def start_process(self, cmd: str): try: sub = pexpect.spawn(complete_cmd[0], complete_cmd[1:]) + sub.setecho(False) # Disable input echo from pexpect except: raise ExecutionError( "error executing command: {}".format(complete_cmd) diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index f6128581c..dfaf47ac7 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -252,8 +252,8 @@ def test_start_process_stream(ssh_localhost): p.write(b"Hello World\n") data = p.read_full(timeout=10.0) data = data.decode("utf-8").splitlines() - # Two lines are expected; one for the echoed input and one for the output - assert data == ["Hello World"] * 2 + # Only one hello world expected. Input echo is disabled + assert data == ["Hello World"] assert p.poll() is None p.sendcontrol('d') @@ -264,7 +264,6 @@ def test_start_process_stream(ssh_localhost): with ssh_localhost.start_process("cat") as p: p.write(b"ABCDEF\n") - p.expect(b"ABCDEF", timeout=10.0) p.expect(b"ABCDEF", timeout=10.0) assert p.poll() is None