aboutsummaryrefslogtreecommitdiff
path: root/pw_hdlc/py/rpc_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_hdlc/py/rpc_test.py')
-rwxr-xr-xpw_hdlc/py/rpc_test.py294
1 files changed, 294 insertions, 0 deletions
diff --git a/pw_hdlc/py/rpc_test.py b/pw_hdlc/py/rpc_test.py
new file mode 100755
index 000000000..3ed388205
--- /dev/null
+++ b/pw_hdlc/py/rpc_test.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+# Copyright 2023 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""device module unit tests"""
+
+from contextlib import contextmanager
+import logging
+import queue
+import threading
+import time
+import unittest
+
+from pw_hdlc.rpc import RpcClient, HdlcRpcClient, CancellableReader
+
+
+class QueueFile:
+ """A fake file object backed by a queue for testing."""
+
+ EOF = object()
+
+ def __init__(self):
+ # Operator puts; consumer gets
+ self._q = queue.Queue()
+
+ # Consumer side access only!
+ self._readbuf = b''
+ self._eof = False
+
+ ###############
+ # Consumer side
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_info):
+ self.close()
+
+ def _read_from_buf(self, size: int) -> bytes:
+ data = self._readbuf[:size]
+ self._readbuf = self._readbuf[size:]
+ return data
+
+ def read(self, size: int = 1) -> bytes:
+ """Reads data from the queue"""
+ # First try to get buffered data
+ data = self._read_from_buf(size)
+ assert len(data) <= size
+ size -= len(data)
+
+ # if size == 0:
+ if data:
+ return data
+
+ # No more data in the buffer
+ assert not self._readbuf
+
+ if self._eof:
+ return data # may be empty
+
+ # Not enough in the buffer; block on the queue
+ item = self._q.get()
+
+ # NOTE: We can't call Queue.task_done() here because the reader hasn't
+ # actually *acted* on the read item yet.
+
+ # Queued data
+ if isinstance(item, bytes):
+ self._readbuf = item
+ return self._read_from_buf(size)
+
+ # Queued exception
+ if isinstance(item, Exception):
+ raise item
+
+ # Report EOF
+ if item is self.EOF:
+ self._eof = True
+ return data # may be empty
+
+ raise Exception('unexpected item type')
+
+ def write(self, data: bytes) -> None:
+ pass
+
+ #####################
+ # Weird middle ground
+
+ # It is a violation of most file-like object APIs for one thread to call
+ # close() while another thread is calling read(). The behavior is
+ # undefined.
+ #
+ # - On Linux, close() may wake up a select(), leaving the caller with a bad
+ # file descriptor (which could get reused!)
+ # - Or the read() could continue to block indefinitely.
+ #
+ # We choose to cause a subsequent/parallel read to receive an exception.
+ def close(self) -> None:
+ self.cause_read_exc(Exception('closed'))
+
+ ###############
+ # Operator side
+
+ def put_read_data(self, data: bytes) -> None:
+ self._q.put(data)
+
+ def cause_read_exc(self, exc: Exception) -> None:
+ self._q.put(exc)
+
+ def set_read_eof(self) -> None:
+ self._q.put(self.EOF)
+
+ def wait_for_drain(self, timeout=None) -> None:
+ """Wait for the queue to drain (be fully consumed).
+
+ Args:
+ timeout: The maximum time (in seconds) to wait, or wait forever
+ if None.
+
+ Raises:
+ TimeoutError: If timeout is given and has elapsed.
+ """
+ # It would be great to use Queue.join() here, but that requires the
+ # consumer to call Queue.task_done(), and we can't do that because
+ # the consumer of read() doesn't know anything about it.
+ # Instead, we poll. ¯\_(ツ)_/¯
+ start_time = time.time()
+ while not self._q.empty():
+ if timeout is not None:
+ elapsed = time.time() - start_time
+ if elapsed > timeout:
+ raise TimeoutError(f"Queue not empty after {elapsed} sec")
+ time.sleep(0.1)
+
+
+class QueueFileTest(unittest.TestCase):
+ """Test the QueueFile class"""
+
+ def test_read_data(self) -> None:
+ file = QueueFile()
+ file.put_read_data(b'hello')
+ self.assertEqual(file.read(5), b'hello')
+
+ def test_read_data_multi_read(self) -> None:
+ file = QueueFile()
+ file.put_read_data(b'helloworld')
+ self.assertEqual(file.read(5), b'hello')
+ self.assertEqual(file.read(5), b'world')
+
+ def test_read_data_multi_put(self) -> None:
+ file = QueueFile()
+ file.put_read_data(b'hello')
+ file.put_read_data(b'world')
+ self.assertEqual(file.read(5), b'hello')
+ self.assertEqual(file.read(5), b'world')
+
+ def test_read_eof(self) -> None:
+ file = QueueFile()
+ file.set_read_eof()
+ result = file.read(5)
+ self.assertEqual(result, b'')
+
+ def test_read_exception(self) -> None:
+ file = QueueFile()
+ message = 'test exception'
+ file.cause_read_exc(ValueError(message))
+ with self.assertRaisesRegex(ValueError, message):
+ file.read(5)
+
+ def test_wait_for_drain_works(self) -> None:
+ file = QueueFile()
+ file.put_read_data(b'hello')
+ file.read()
+ try:
+ # Timeout is arbitrary; will return immediately.
+ file.wait_for_drain(0.1)
+ except TimeoutError:
+ self.fail("wait_for_drain raised TimeoutError")
+
+ def test_wait_for_drain_raises(self) -> None:
+ file = QueueFile()
+ file.put_read_data(b'hello')
+ # don't read
+ with self.assertRaises(TimeoutError):
+ # Timeout is arbitrary; it will raise no matter what.
+ file.wait_for_drain(0.1)
+
+
+class Sentinel:
+ def __repr__(self):
+ return 'Sentinel'
+
+
+class _QueueReader(CancellableReader):
+ def cancel_read(self) -> None:
+ self._base_obj.close()
+
+
+def _get_client(file) -> RpcClient:
+ return HdlcRpcClient(
+ _QueueReader(file),
+ paths_or_modules=[],
+ channels=[],
+ )
+
+
+# This should take <10ms but we'll wait up to 1000x longer.
+_QUEUE_DRAIN_TIMEOUT = 10.0
+
+
+class HdlcRpcClientTest(unittest.TestCase):
+ """Tests the pw_hdlc.rpc.HdlcRpcClient class."""
+
+ # NOTE: There is no test here for stream EOF because Serial.read()
+ # can return an empty result if configured with timeout != None.
+ # The reader thread will continue in this case.
+
+ def test_clean_close_after_stream_close(self) -> None:
+ """Assert RpcClient closes cleanly when stream closes."""
+ # See b/293595266.
+ file = QueueFile()
+
+ with self.assert_no_hdlc_rpc_error_logs():
+ with file:
+ with _get_client(file):
+ # We want to make sure the reader thread is blocked on
+ # read() and doesn't exit immediately.
+ file.put_read_data(b'')
+ file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)
+
+ # RpcClient.__exit__ calls stop() on the reader thread, but
+ # it is blocked on file.read().
+
+ # QueueFile.close() is called, triggering an exception in the
+ # blocking read() (by implementation choice). The reader should
+ # handle it by *not* logging it and exiting immediately.
+
+ self.assert_no_background_threads_running()
+
+ def test_device_handles_read_exception(self) -> None:
+ """Assert RpcClient closes cleanly when read raises an exception."""
+ # See b/293595266.
+ file = QueueFile()
+
+ logger = logging.getLogger('pw_hdlc.rpc')
+ test_exc = Exception('boom')
+ with self.assertLogs(logger, level=logging.ERROR) as ctx:
+ with _get_client(file):
+ # Cause read() to raise an exception. The reader should
+ # handle it by logging it and exiting immediately.
+ file.cause_read_exc(test_exc)
+ file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)
+
+ # Assert one exception was raised
+ self.assertEqual(len(ctx.records), 1)
+ rec = ctx.records[0]
+ self.assertIsNotNone(rec.exc_info)
+ assert rec.exc_info is not None # for mypy
+ self.assertEqual(rec.exc_info[1], test_exc)
+
+ self.assert_no_background_threads_running()
+
+ @contextmanager
+ def assert_no_hdlc_rpc_error_logs(self):
+ logger = logging.getLogger('pw_hdlc.rpc')
+ sentinel = Sentinel()
+ with self.assertLogs(logger, level=logging.ERROR) as ctx:
+ # TODO: b/294861320 - use assertNoLogs() in Python 3.10+
+ # We actually want to assert there are no errors, but
+ # TestCase.assertNoLogs() is not available until Python 3.10.
+ # So we log one error to keep the test from failing and manually
+ # inspect the list of captured records.
+ logger.error(sentinel)
+
+ yield ctx
+
+ self.assertEqual([record.msg for record in ctx.records], [sentinel])
+
+ def assert_no_background_threads_running(self):
+ self.assertEqual(threading.enumerate(), [threading.current_thread()])
+
+
+if __name__ == '__main__':
+ unittest.main()