Diffstat (limited to 'pw_tokenizer/py/detokenize_test.py')
-rwxr-xr-x  pw_tokenizer/py/detokenize_test.py  345

1 file changed, 315 insertions, 30 deletions
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index df710c7e9..36bb1fa6e 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -15,12 +15,15 @@
 """Tests for detokenize."""
 
 import base64
+import concurrent
 import datetime as dt
+import functools
 import io
 import os
 from pathlib import Path
 import struct
 import tempfile
+from typing import Any, Callable, NamedTuple, Tuple
 import unittest
 from unittest import mock
@@ -451,6 +454,35 @@ class DetokenizeWithCollisions(unittest.TestCase):
         self.assertIn('#0 -1', repr(unambiguous))
 
 
+class ManualPoolExecutor(concurrent.futures.Executor):
+    """A stubbed pool executor that captures the most recent work request
+    and holds it until the public process method is manually called."""
+
+    def __init__(self):
+        super().__init__()
+        self._func = None
+
+    # pylint: disable=arguments-differ
+    def submit(self, func, *args, **kwargs):
+        """Submits work to the pool, stashing the partial for later use."""
+        self._func = functools.partial(func, *args, **kwargs)
+
+    def process(self):
+        """Processes the latest func submitted to the pool."""
+        if self._func is not None:
+            self._func()
+            self._func = None
+
+
+class InlinePoolExecutor(concurrent.futures.Executor):
+    """A stubbed pool executor that runs work immediately, inline."""
+
+    # pylint: disable=arguments-differ
+    def submit(self, func, *args, **kwargs):
+        """Submits work to the pool and runs it immediately, inline."""
+        func(*args, **kwargs)
+
+
 @mock.patch('os.path.getmtime')
 class AutoUpdatingDetokenizerTest(unittest.TestCase):
     """Tests the AutoUpdatingDetokenizer class."""
@@ -478,18 +510,79 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
         try:
             file.close()
 
+            pool = ManualPoolExecutor()
             detok = detokenize.AutoUpdatingDetokenizer(
-                file.name, min_poll_period_s=0
+                file.name, min_poll_period_s=0, pool=pool
            )
             self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
 
             with open(file.name, 'wb') as fd:
                 tokens.write_binary(db, fd)
 
+            # After the change but before the pool runs in another thread,
+            # the token should not exist.
+            self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+            # After the pool is allowed to process, it should.
+            pool.process()
             self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
         finally:
             os.unlink(file.name)
 
+    def test_update_with_directory(self, mock_getmtime):
+        """Tests the update command with a directory format database."""
+        db = database.load_token_database(
+            io.BytesIO(ELF_WITH_TOKENIZER_SECTIONS)
+        )
+        self.assertEqual(len(db), TOKENS_IN_ELF)
+
+        the_time = [100]
+
+        def move_back_time_if_file_exists(path):
+            if os.path.exists(path):
+                the_time[0] -= 1
+                return the_time[0]
+
+            raise FileNotFoundError
+
+        mock_getmtime.side_effect = move_back_time_if_file_exists
+
+        with tempfile.TemporaryDirectory() as dbdir:
+            with tempfile.NamedTemporaryFile(
+                'wb', delete=False, suffix='.pw_tokenizer.csv', dir=dbdir
+            ) as matching_suffix_file, tempfile.NamedTemporaryFile(
+                'wb', delete=False, suffix='.not.right', dir=dbdir
+            ) as mismatched_suffix_file:
+                try:
+                    matching_suffix_file.close()
+                    mismatched_suffix_file.close()
+
+                    pool = ManualPoolExecutor()
+                    detok = detokenize.AutoUpdatingDetokenizer(
+                        dbdir, min_poll_period_s=0, pool=pool
+                    )
+                    self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+                    with open(mismatched_suffix_file.name, 'wb') as fd:
+                        tokens.write_csv(db, fd)
+                    pool.process()
+                    self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+                    with open(matching_suffix_file.name, 'wb') as fd:
+                        tokens.write_csv(db, fd)
+
+                    # After the change but before the pool runs in another
+                    # thread, the token should not exist.
+                    self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+                    pool.process()
+
+                    # After the pool is allowed to process, it should.
+                    self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+                finally:
+                    os.unlink(mismatched_suffix_file.name)
+                    os.unlink(matching_suffix_file.name)
+                    os.rmdir(dbdir)
+
         # The database stays around if the file is deleted.
         self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
 
@@ -507,7 +600,7 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
             file.close()
 
             detok = detokenize.AutoUpdatingDetokenizer(
-                file.name, min_poll_period_s=0
+                file.name, min_poll_period_s=0, pool=InlinePoolExecutor()
             )
             self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
 
@@ -527,7 +620,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
     def test_token_domain_in_str(self, _) -> None:
         """Tests a str containing a domain"""
         detok = detokenize.AutoUpdatingDetokenizer(
-            f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*', min_poll_period_s=0
+            f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*',
+            min_poll_period_s=0,
+            pool=InlinePoolExecutor(),
         )
         self.assertEqual(
             len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS
@@ -536,7 +631,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
     def test_token_domain_in_path(self, _) -> None:
         """Tests a Path() containing a domain"""
         detok = detokenize.AutoUpdatingDetokenizer(
-            Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'), min_poll_period_s=0
+            Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'),
+            min_poll_period_s=0,
+            pool=InlinePoolExecutor(),
         )
         self.assertEqual(
             len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS
@@ -545,14 +642,18 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
     def test_token_no_domain_in_str(self, _) -> None:
         """Tests a str without a domain"""
         detok = detokenize.AutoUpdatingDetokenizer(
-            str(ELF_WITH_TOKENIZER_SECTIONS_PATH), min_poll_period_s=0
+            str(ELF_WITH_TOKENIZER_SECTIONS_PATH),
+            min_poll_period_s=0,
+            pool=InlinePoolExecutor(),
         )
         self.assertEqual(len(detok.database), TOKENS_IN_ELF)
 
     def test_token_no_domain_in_path(self, _) -> None:
         """Tests a Path() without a domain"""
         detok = detokenize.AutoUpdatingDetokenizer(
-            ELF_WITH_TOKENIZER_SECTIONS_PATH, min_poll_period_s=0
+            ELF_WITH_TOKENIZER_SECTIONS_PATH,
+            min_poll_period_s=0,
+            pool=InlinePoolExecutor(),
         )
         self.assertEqual(len(detok.database), TOKENS_IN_ELF)
 
@@ -561,39 +662,173 @@ def _next_char(message: bytes) -> bytes:
     return bytes(b + 1 for b in message)
 
 
-class PrefixedMessageDecoderTest(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.decode = detokenize.PrefixedMessageDecoder('$', 'abcdefg')
+class NestedMessageParserTest(unittest.TestCase):
+    """Tests parsing prefixed messages."""
+
+    class _Case(NamedTuple):
+        data: bytes
+        expected: bytes
+        title: str
+        transform: Callable[[bytes], bytes] = _next_char
+
+    TRANSFORM_TEST_CASES = (
+        _Case(b'$abcd', b'%bcde', 'single message'),
+        _Case(
+            b'$$WHAT?$abc$WHY? is this $ok $',
+            b'%%WHAT?%bcd%WHY? is this %ok %',
+            'message and non-message',
+        ),
+        _Case(b'$1$', b'%1%', 'empty message'),
+        _Case(b'$abc$defgh', b'%bcd%efghh', 'sequential message'),
+        _Case(
+            b'w$abcx$defygh$$abz',
+            b'w$ABCx$DEFygh$$ABz',
+            'interspersed start/end non-message',
+            bytes.upper,
+        ),
+        _Case(
+            b'$abcx$defygh$$ab',
+            b'$ABCx$DEFygh$$AB',
+            'interspersed start/end message',
+            bytes.upper,
+        ),
+    )
+
+    def setUp(self) -> None:
+        self.decoder = detokenize.NestedMessageParser('$', 'abcdefg')
+
+    def test_transform_io(self) -> None:
+        for data, expected, title, transform in self.TRANSFORM_TEST_CASES:
+            self.assertEqual(
+                expected,
+                b''.join(
+                    self.decoder.transform_io(io.BytesIO(data), transform)
+                ),
+                f'{title}: {data!r}',
+            )
+
+    def test_transform_bytes_with_flush(self) -> None:
+        for data, expected, title, transform in self.TRANSFORM_TEST_CASES:
+            self.assertEqual(
+                expected,
+                self.decoder.transform(data, transform, flush=True),
+                f'{title}: {data!r}',
+            )
+
+    def test_transform_bytes_sequential(self) -> None:
+        transform = lambda message: message.upper().replace(b'$', b'*')
 
-    def test_transform_single_message(self):
+        self.assertEqual(self.decoder.transform(b'abc$abcd', transform), b'abc')
+        self.assertEqual(self.decoder.transform(b'$', transform), b'*ABCD')
+        self.assertEqual(self.decoder.transform(b'$b', transform), b'*')
+        self.assertEqual(self.decoder.transform(b'', transform), b'')
+        self.assertEqual(self.decoder.transform(b' ', transform), b'*B ')
+        self.assertEqual(self.decoder.transform(b'hello', transform), b'hello')
+        self.assertEqual(self.decoder.transform(b'?? $ab', transform), b'?? ')
         self.assertEqual(
-            b'%bcde',
-            b''.join(self.decode.transform(io.BytesIO(b'$abcd'), _next_char)),
+            self.decoder.transform(b'123$ab4$56$a', transform), b'*AB123*AB4*56'
         )
+        self.assertEqual(
+            self.decoder.transform(b'bc', transform, flush=True), b'*ABC'
+        )
+
+    MESSAGES_TEST: Any = (
+        (b'123$abc456$a', (False, b'123'), (True, b'$abc'), (False, b'456')),
+        (b'7$abcd', (True, b'$a'), (False, b'7')),
+        (b'e',),
+        (b'',),
+        (b'$', (True, b'$abcde')),
+        (b'$', (True, b'$')),
+        (b'$a$b$c', (True, b'$'), (True, b'$a'), (True, b'$b')),
+        (b'1', (True, b'$c'), (False, b'1')),
+        (b'',),
+        (b'?', (False, b'?')),
+        (b'!@', (False, b'!@')),
+        (b'%^&', (False, b'%^&')),
+    )
 
-    def test_transform_message_amidst_other_only_affects_message(self):
+    def test_read_messages(self) -> None:
+        for step in self.MESSAGES_TEST:
+            data: bytes = step[0]
+            pieces: Tuple[Tuple[bool, bytes], ...] = step[1:]
+            self.assertEqual(tuple(self.decoder.read_messages(data)), pieces)
+
+    def test_read_messages_flush(self) -> None:
         self.assertEqual(
-            b'%%WHAT?%bcd%WHY? is this %ok %',
-            b''.join(
-                self.decode.transform(
-                    io.BytesIO(b'$$WHAT?$abc$WHY? is this $ok $'), _next_char
-                )
-            ),
+            list(self.decoder.read_messages(b'123$a')), [(False, b'123')]
         )
+        self.assertEqual(list(self.decoder.read_messages(b'b')), [])
+        self.assertEqual(
+            list(self.decoder.read_messages(b'', flush=True)), [(True, b'$ab')]
+        )
+
+    def test_read_messages_io(self) -> None:
+        # Rework the read_messages test data for stream input.
+        data = io.BytesIO(b''.join(step[0] for step in self.MESSAGES_TEST))
+        expected_pieces = sum((step[1:] for step in self.MESSAGES_TEST), ())
+
+        result = self.decoder.read_messages_io(data)
+        for expected_is_message, expected_data in expected_pieces:
+            if expected_is_message:
+                is_message, piece = next(result)
+                self.assertTrue(is_message)
+                self.assertEqual(expected_data, piece)
+            else:  # the IO version yields non-messages byte by byte
+                for byte in expected_data:
+                    is_message, piece = next(result)
+                    self.assertFalse(is_message)
+                    self.assertEqual(bytes([byte]), piece)
 
-    def test_transform_empty_message(self):
+
+class DetokenizeNested(unittest.TestCase):
+    """Tests detokenizing nested tokens"""
+
+    def test_nested_hashed_arg(self):
+        detok = detokenize.Detokenizer(
+            tokens.Database(
+                [
+                    tokens.TokenizedStringEntry(0xA, 'tokenized argument'),
+                    tokens.TokenizedStringEntry(
+                        2,
+                        'This is a ' + '$#%08x',
+                    ),
+                ]
+            )
+        )
         self.assertEqual(
-            b'%1%',
-            b''.join(self.decode.transform(io.BytesIO(b'$1$'), _next_char)),
+            str(detok.detokenize(b'\x02\0\0\0\x14')),
+            'This is a tokenized argument',
         )
 
-    def test_transform_sequential_messages(self):
+    def test_nested_base64_arg(self):
+        detok = detokenize.Detokenizer(
+            tokens.Database(
+                [
+                    tokens.TokenizedStringEntry(1, 'base64 argument'),
+                    tokens.TokenizedStringEntry(2, 'This is a %s'),
+                ]
+            )
+        )
         self.assertEqual(
-            b'%bcd%efghh',
-            b''.join(
-                self.decode.transform(io.BytesIO(b'$abc$defgh'), _next_char)
-            ),
+            str(detok.detokenize(b'\x02\0\0\0\x09$AQAAAA==')),  # token for 1
+            'This is a base64 argument',
+        )
+
+    def test_deeply_nested_arg(self):
+        detok = detokenize.Detokenizer(
+            tokens.Database(
+                [
+                    tokens.TokenizedStringEntry(1, '$10#0000000005'),
+                    tokens.TokenizedStringEntry(2, 'This is a $#%08x'),
+                    tokens.TokenizedStringEntry(3, 'deeply nested argument'),
+                    tokens.TokenizedStringEntry(4, '$AQAAAA=='),
+                    tokens.TokenizedStringEntry(5, '$AwAAAA=='),
+                ]
+            )
+        )
+        self.assertEqual(
+            str(detok.detokenize(b'\x02\0\0\0\x08')),  # token for 4
+            'This is a deeply nested argument',
         )
 
@@ -627,6 +862,10 @@ class DetokenizeBase64(unittest.TestCase):
         (JELLO + b'$a' + JELLO + b'bcd', b'Jello, world!$aJello, world!bcd'),
         (b'$3141', b'$3141'),
         (JELLO + b'$3141', b'Jello, world!$3141'),
+        (
+            JELLO + b'$a' + JELLO + b'b' + JELLO + b'c',
+            b'Jello, world!$aJello, world!bJello, world!c',
+        ),
         (RECURSION, b'The secret message is "Jello, world!"'),
         (
             RECURSION_2,
@@ -650,7 +889,7 @@ class DetokenizeBase64(unittest.TestCase):
             output = io.BytesIO()
             self.detok.detokenize_base64_live(io.BytesIO(data), output, '$')
 
-            self.assertEqual(expected, output.getvalue())
+            self.assertEqual(expected, output.getvalue(), f'Input: {data!r}')
 
     def test_detokenize_base64_to_file(self):
         for data, expected in self.TEST_CASES:
@@ -670,6 +909,52 @@ class DetokenizeBase64(unittest.TestCase):
         )
 
 
+class DetokenizeInfiniteRecursion(unittest.TestCase):
+    """Tests that infinite Base64 token recursion resolves."""
+
+    def setUp(self):
+        super().setUp()
+        self.detok = detokenize.Detokenizer(
+            tokens.Database(
+                [
+                    tokens.TokenizedStringEntry(0, '$AAAAAA=='),  # token for 0
+                    tokens.TokenizedStringEntry(1, '$AgAAAA=='),  # token for 2
+                    tokens.TokenizedStringEntry(2, '$#00000003'),  # token for 3
+                    tokens.TokenizedStringEntry(3, '$AgAAAA=='),  # token for 2
+                ]
+            )
+        )
+
+    def test_detokenize_self_recursion(self):
+        for depth in range(5):
+            self.assertEqual(
+                self.detok.detokenize_text(
+                    b'This one is deep: $AAAAAA==', recursion=depth
+                ),
+                b'This one is deep: $AAAAAA==',
+            )
+
+    def test_detokenize_self_recursion_default(self):
+        self.assertEqual(
+            self.detok.detokenize_text(
+                b'This one is deep: $AAAAAA==',
+            ),
+            b'This one is deep: $AAAAAA==',
+        )
+
+    def test_detokenize_cyclic_recursion_even(self):
+        self.assertEqual(
+            self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=6),
+            b'I said "$AgAAAA=="',
+        )
+
+    def test_detokenize_cyclic_recursion_odd(self):
+        self.assertEqual(
+            self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=7),
+            b'I said "$#00000003"',
+        )
+
+
 class DetokenizeBase64InfiniteRecursion(unittest.TestCase):
     """Tests that infinite Base64 token recursion resolves."""
@@ -697,7 +982,7 @@ class DetokenizeBase64InfiniteRecursion(unittest.TestCase):
     def test_detokenize_self_recursion_default(self):
         self.assertEqual(
-            self.detok.detokenize_base64(b'This one is deep: $AAAAAA=='),
+            self.detok.detokenize_base64(b'This one is deep: $64#AAAAAA=='),
             b'This one is deep: $AAAAAA==',
         )
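For context on the APIs these tests drive, the sketch below shows the two pieces this change revolves around: NestedMessageParser, which incrementally splits a byte stream into nested-token messages and surrounding plain text, and the new pool argument to AutoUpdatingDetokenizer, which lets callers supply the executor used for background database reloads. This is a rough sketch inferred from the calls in the tests above, not authoritative usage documentation; 'tokens.csv' is a hypothetical database path.

    import concurrent.futures

    from pw_tokenizer import detokenize

    # '$' marks the start of a nested message; 'abcdefg' is the toy message
    # alphabet the tests use in place of Base64 characters.
    parser = detokenize.NestedMessageParser('$', 'abcdefg')

    # read_messages() yields (is_message, data) pairs and buffers a partial
    # trailing message until more data arrives or flush=True forces it out.
    print(list(parser.read_messages(b'123$abc456$a')))
    # Per the test data above: [(False, b'123'), (True, b'$abc'), (False, b'456')]
    print(list(parser.read_messages(b'', flush=True)))
    # The buffered remainder: [(True, b'$a')]

    # The pool argument makes reload scheduling explicit. The tests inject
    # stub executors; a real caller might pass a single-threaded pool.
    detok = detokenize.AutoUpdatingDetokenizer(
        'tokens.csv',  # hypothetical token database path
        min_poll_period_s=0,
        pool=concurrent.futures.ThreadPoolExecutor(max_workers=1),
    )
    print(detok.detokenize(b'\x01\0\0\0').ok())  # token 1, little-endian

Injecting the executor is what makes these tests deterministic: ManualPoolExecutor holds the database reload until pool.process() is called, so the "before reload" and "after reload" assertions cannot race, while InlinePoolExecutor runs the reload synchronously for tests that only care about the end state.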