diff options
Diffstat (limited to 'pw_hdlc/py/rpc_test.py')
-rwxr-xr-x | pw_hdlc/py/rpc_test.py | 294 |
1 files changed, 294 insertions, 0 deletions
diff --git a/pw_hdlc/py/rpc_test.py b/pw_hdlc/py/rpc_test.py new file mode 100755 index 000000000..3ed388205 --- /dev/null +++ b/pw_hdlc/py/rpc_test.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""device module unit tests""" + +from contextlib import contextmanager +import logging +import queue +import threading +import time +import unittest + +from pw_hdlc.rpc import RpcClient, HdlcRpcClient, CancellableReader + + +class QueueFile: + """A fake file object backed by a queue for testing.""" + + EOF = object() + + def __init__(self): + # Operator puts; consumer gets + self._q = queue.Queue() + + # Consumer side access only! + self._readbuf = b'' + self._eof = False + + ############### + # Consumer side + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + def _read_from_buf(self, size: int) -> bytes: + data = self._readbuf[:size] + self._readbuf = self._readbuf[size:] + return data + + def read(self, size: int = 1) -> bytes: + """Reads data from the queue""" + # First try to get buffered data + data = self._read_from_buf(size) + assert len(data) <= size + size -= len(data) + + # if size == 0: + if data: + return data + + # No more data in the buffer + assert not self._readbuf + + if self._eof: + return data # may be empty + + # Not enough in the buffer; block on the queue + item = self._q.get() + + # NOTE: We can't call Queue.task_done() here because the reader hasn't + # actually *acted* on the read item yet. + + # Queued data + if isinstance(item, bytes): + self._readbuf = item + return self._read_from_buf(size) + + # Queued exception + if isinstance(item, Exception): + raise item + + # Report EOF + if item is self.EOF: + self._eof = True + return data # may be empty + + raise Exception('unexpected item type') + + def write(self, data: bytes) -> None: + pass + + ##################### + # Weird middle ground + + # It is a violation of most file-like object APIs for one thread to call + # close() while another thread is calling read(). The behavior is + # undefined. + # + # - On Linux, close() may wake up a select(), leaving the caller with a bad + # file descriptor (which could get reused!) + # - Or the read() could continue to block indefinitely. + # + # We choose to cause a subsequent/parallel read to receive an exception. + def close(self) -> None: + self.cause_read_exc(Exception('closed')) + + ############### + # Operator side + + def put_read_data(self, data: bytes) -> None: + self._q.put(data) + + def cause_read_exc(self, exc: Exception) -> None: + self._q.put(exc) + + def set_read_eof(self) -> None: + self._q.put(self.EOF) + + def wait_for_drain(self, timeout=None) -> None: + """Wait for the queue to drain (be fully consumed). + + Args: + timeout: The maximum time (in seconds) to wait, or wait forever + if None. + + Raises: + TimeoutError: If timeout is given and has elapsed. + """ + # It would be great to use Queue.join() here, but that requires the + # consumer to call Queue.task_done(), and we can't do that because + # the consumer of read() doesn't know anything about it. + # Instead, we poll. ¯\_(ツ)_/¯ + start_time = time.time() + while not self._q.empty(): + if timeout is not None: + elapsed = time.time() - start_time + if elapsed > timeout: + raise TimeoutError(f"Queue not empty after {elapsed} sec") + time.sleep(0.1) + + +class QueueFileTest(unittest.TestCase): + """Test the QueueFile class""" + + def test_read_data(self) -> None: + file = QueueFile() + file.put_read_data(b'hello') + self.assertEqual(file.read(5), b'hello') + + def test_read_data_multi_read(self) -> None: + file = QueueFile() + file.put_read_data(b'helloworld') + self.assertEqual(file.read(5), b'hello') + self.assertEqual(file.read(5), b'world') + + def test_read_data_multi_put(self) -> None: + file = QueueFile() + file.put_read_data(b'hello') + file.put_read_data(b'world') + self.assertEqual(file.read(5), b'hello') + self.assertEqual(file.read(5), b'world') + + def test_read_eof(self) -> None: + file = QueueFile() + file.set_read_eof() + result = file.read(5) + self.assertEqual(result, b'') + + def test_read_exception(self) -> None: + file = QueueFile() + message = 'test exception' + file.cause_read_exc(ValueError(message)) + with self.assertRaisesRegex(ValueError, message): + file.read(5) + + def test_wait_for_drain_works(self) -> None: + file = QueueFile() + file.put_read_data(b'hello') + file.read() + try: + # Timeout is arbitrary; will return immediately. + file.wait_for_drain(0.1) + except TimeoutError: + self.fail("wait_for_drain raised TimeoutError") + + def test_wait_for_drain_raises(self) -> None: + file = QueueFile() + file.put_read_data(b'hello') + # don't read + with self.assertRaises(TimeoutError): + # Timeout is arbitrary; it will raise no matter what. + file.wait_for_drain(0.1) + + +class Sentinel: + def __repr__(self): + return 'Sentinel' + + +class _QueueReader(CancellableReader): + def cancel_read(self) -> None: + self._base_obj.close() + + +def _get_client(file) -> RpcClient: + return HdlcRpcClient( + _QueueReader(file), + paths_or_modules=[], + channels=[], + ) + + +# This should take <10ms but we'll wait up to 1000x longer. +_QUEUE_DRAIN_TIMEOUT = 10.0 + + +class HdlcRpcClientTest(unittest.TestCase): + """Tests the pw_hdlc.rpc.HdlcRpcClient class.""" + + # NOTE: There is no test here for stream EOF because Serial.read() + # can return an empty result if configured with timeout != None. + # The reader thread will continue in this case. + + def test_clean_close_after_stream_close(self) -> None: + """Assert RpcClient closes cleanly when stream closes.""" + # See b/293595266. + file = QueueFile() + + with self.assert_no_hdlc_rpc_error_logs(): + with file: + with _get_client(file): + # We want to make sure the reader thread is blocked on + # read() and doesn't exit immediately. + file.put_read_data(b'') + file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT) + + # RpcClient.__exit__ calls stop() on the reader thread, but + # it is blocked on file.read(). + + # QueueFile.close() is called, triggering an exception in the + # blocking read() (by implementation choice). The reader should + # handle it by *not* logging it and exiting immediately. + + self.assert_no_background_threads_running() + + def test_device_handles_read_exception(self) -> None: + """Assert RpcClient closes cleanly when read raises an exception.""" + # See b/293595266. + file = QueueFile() + + logger = logging.getLogger('pw_hdlc.rpc') + test_exc = Exception('boom') + with self.assertLogs(logger, level=logging.ERROR) as ctx: + with _get_client(file): + # Cause read() to raise an exception. The reader should + # handle it by logging it and exiting immediately. + file.cause_read_exc(test_exc) + file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT) + + # Assert one exception was raised + self.assertEqual(len(ctx.records), 1) + rec = ctx.records[0] + self.assertIsNotNone(rec.exc_info) + assert rec.exc_info is not None # for mypy + self.assertEqual(rec.exc_info[1], test_exc) + + self.assert_no_background_threads_running() + + @contextmanager + def assert_no_hdlc_rpc_error_logs(self): + logger = logging.getLogger('pw_hdlc.rpc') + sentinel = Sentinel() + with self.assertLogs(logger, level=logging.ERROR) as ctx: + # TODO: b/294861320 - use assertNoLogs() in Python 3.10+ + # We actually want to assert there are no errors, but + # TestCase.assertNoLogs() is not available until Python 3.10. + # So we log one error to keep the test from failing and manually + # inspect the list of captured records. + logger.error(sentinel) + + yield ctx + + self.assertEqual([record.msg for record in ctx.records], [sentinel]) + + def assert_no_background_threads_running(self): + self.assertEqual(threading.enumerate(), [threading.current_thread()]) + + +if __name__ == '__main__': + unittest.main() |