aboutsummaryrefslogtreecommitdiff
path: root/pw_unit_test/py/pw_unit_test/serial_test_runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_unit_test/py/pw_unit_test/serial_test_runner.py')
-rw-r--r--pw_unit_test/py/pw_unit_test/serial_test_runner.py196
1 files changed, 196 insertions, 0 deletions
diff --git a/pw_unit_test/py/pw_unit_test/serial_test_runner.py b/pw_unit_test/py/pw_unit_test/serial_test_runner.py
new file mode 100644
index 000000000..cdb879ba7
--- /dev/null
+++ b/pw_unit_test/py/pw_unit_test/serial_test_runner.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# Copyright 2023 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""This library facilitates automating unit tests on devices with serial ports.
+
+This library assumes that the on-device test runner will emit the test results
+as plain-text over a serial port, and tests will be triggered by a pre-defined
+input (DEFAULT_TEST_START_CHARACTER) over the same serial port that results
+are emitted from.
+"""
+
+import abc
+import logging
+from pathlib import Path
+
+import serial # type: ignore
+
+
+_LOG = logging.getLogger("serial_test_runner")
+
+# Verification of test pass/failure depends on these strings. If the formatting
+# or output of the simple_printing_event_handler changes, this may need to be
+# updated.
+_TESTS_STARTING_STRING = b'[==========] Running all tests.'
+_TESTS_DONE_STRING = b'[==========] Done running all tests.'
+_TEST_FAILURE_STRING = b'[ FAILED ]'
+
+# Character used to trigger test start.
+DEFAULT_TEST_START_CHARACTER = ' '.encode('utf-8')
+
+
+class FlashingFailure(Exception):
+ """A simple exception to be raised when flashing fails."""
+
+
+class TestingFailure(Exception):
+ """A simple exception to be raised when a testing step fails."""
+
+
+class DeviceNotFound(Exception):
+ """A simple exception to be raised when unable to connect to a device."""
+
+
+class SerialTestingDevice(abc.ABC):
+ """A device that supports automated testing via parsing serial output."""
+
+ @abc.abstractmethod
+ def load_binary(self, binary: Path) -> None:
+ """Flashes the specified binary to this device.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ FlashingFailure: The binary could not be flashed.
+ """
+
+ @abc.abstractmethod
+ def serial_port(self) -> str:
+ """Returns the name of the com port this device is enumerated on.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ """
+
+ @abc.abstractmethod
+ def baud_rate(self) -> int:
+ """Returns the baud rate to use when connecting to this device.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ """
+
+
+def _log_subprocess_output(level, output: bytes, logger: logging.Logger):
+ """Logs subprocess output line-by-line."""
+
+ lines = output.decode('utf-8', errors='replace').splitlines()
+ for line in lines:
+ logger.log(level, line)
+
+
+def trigger_test_run(
+ port: str,
+ baud_rate: int,
+ test_timeout: float,
+ trigger_data: bytes = DEFAULT_TEST_START_CHARACTER,
+) -> bytes:
+ """Triggers a test run, and returns captured test results."""
+
+ serial_data = bytearray()
+ device = serial.Serial(baudrate=baud_rate, port=port, timeout=test_timeout)
+ if not device.is_open:
+ raise TestingFailure('Failed to open device')
+
+ # Flush input buffer and trigger the test start.
+ device.reset_input_buffer()
+ device.write(trigger_data)
+
+ # Block and wait for the first byte.
+ serial_data += device.read()
+ if not serial_data:
+ raise TestingFailure('Device not producing output')
+
+ # Read with a reasonable timeout until we stop getting characters.
+ while True:
+ bytes_read = device.readline()
+ if not bytes_read:
+ break
+ serial_data += bytes_read
+ if serial_data.rfind(_TESTS_DONE_STRING) != -1:
+ # Set to much more aggressive timeout since the last one or two
+ # lines should print out immediately. (one line if all fails or all
+ # passes, two lines if mixed.)
+ device.timeout = 0.01
+
+ # Remove carriage returns.
+ serial_data = serial_data.replace(b"\r", b"")
+
+ # Try to trim captured results to only contain most recent test run.
+ test_start_index = serial_data.rfind(_TESTS_STARTING_STRING)
+ return (
+ serial_data
+ if test_start_index == -1
+ else serial_data[test_start_index:]
+ )
+
+
+def handle_test_results(
+ test_output: bytes, logger: logging.Logger = _LOG
+) -> None:
+ """Parses test output to determine whether tests passed or failed.
+
+ Raises:
+ TestingFailure if any tests fail or if test results are incomplete.
+ """
+
+ if test_output.find(_TESTS_STARTING_STRING) == -1:
+ raise TestingFailure('Failed to find test start')
+
+ if test_output.rfind(_TESTS_DONE_STRING) == -1:
+ _log_subprocess_output(logging.INFO, test_output, logger)
+ raise TestingFailure('Tests did not complete')
+
+ if test_output.rfind(_TEST_FAILURE_STRING) != -1:
+ _log_subprocess_output(logging.INFO, test_output, logger)
+ raise TestingFailure('Test suite had one or more failures')
+
+ _log_subprocess_output(logging.DEBUG, test_output, logger)
+
+ logger.info('Test passed!')
+
+
+def run_device_test(
+ device: SerialTestingDevice,
+ binary: Path,
+ test_timeout: float,
+ logger: logging.Logger = _LOG,
+) -> bool:
+ """Runs tests on a device.
+
+ When a unit test run fails, results will be logged as an error.
+
+ Args:
+ device: The device to run tests on.
+ binary: The binary containing tests that will be flashed to the device.
+ test_timeout: If the device stops producing output longer than this
+ timeout, the test will be considered stuck and the test will be aborted.
+
+ Returns:
+ True if all tests passed.
+ """
+
+ logger.info('Flashing binary to device')
+ device.load_binary(binary)
+ try:
+ logger.info('Running test')
+ test_output = trigger_test_run(
+ device.serial_port(), device.baud_rate(), test_timeout
+ )
+ if test_output:
+ handle_test_results(test_output, logger)
+ except TestingFailure as err:
+ logger.error(err)
+ return False
+
+ return True