Diffstat (limited to 'pw_tokenizer/py/detokenize_test.py')
-rwxr-xr-x pw_tokenizer/py/detokenize_test.py | 345 +++++++++++++++++++++-----
 1 file changed, 315 insertions(+), 30 deletions(-)
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index df710c7e9..36bb1fa6e 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -15,12 +15,15 @@
"""Tests for detokenize."""
import base64
+import concurrent.futures
import datetime as dt
+import functools
import io
import os
from pathlib import Path
import struct
import tempfile
+from typing import Any, Callable, NamedTuple, Tuple
import unittest
from unittest import mock
@@ -451,6 +454,35 @@ class DetokenizeWithCollisions(unittest.TestCase):
self.assertIn('#0 -1', repr(unambiguous))
+class ManualPoolExecutor(concurrent.futures.Executor):
+ """A stubbed pool executor that captures the most recent work request
+ and holds it until the public process method is manually called."""
+
+ def __init__(self):
+ super().__init__()
+ self._func = None
+
+ # pylint: disable=arguments-differ
+ def submit(self, func, *args, **kwargs):
+ """Submits work to the pool, stashing the partial for later use."""
+ self._func = functools.partial(func, *args, **kwargs)
+
+ def process(self):
+ """Processes the latest func submitted to the pool."""
+ if self._func is not None:
+ self._func()
+ self._func = None
+
+
+class InlinePoolExecutor(concurrent.futures.Executor):
+ """A stubbed pool executor that runs work immediately, inline."""
+
+ # pylint: disable=arguments-differ
+ def submit(self, func, *args, **kwargs):
+        """Submits work to the pool and runs it immediately, inline."""
+ func(*args, **kwargs)
+
+
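# A rough sketch of how these stub executors are meant to be driven
# (hypothetical standalone usage; the tests below pass them to
# AutoUpdatingDetokenizer via the new `pool=` argument):
#
#   pool = ManualPoolExecutor()
#   pool.submit(print, 'reload')  # Stashed; nothing runs yet.
#   pool.process()                # The stashed call runs now, inline.
#
# Unlike concurrent.futures.Executor.submit(), these stubs do not return a
# Future; the code under test presumably ignores submit()'s return value.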
@mock.patch('os.path.getmtime')
class AutoUpdatingDetokenizerTest(unittest.TestCase):
"""Tests the AutoUpdatingDetokenizer class."""
@@ -478,18 +510,79 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
try:
file.close()
+ pool = ManualPoolExecutor()
detok = detokenize.AutoUpdatingDetokenizer(
- file.name, min_poll_period_s=0
+ file.name, min_poll_period_s=0, pool=pool
)
self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
with open(file.name, 'wb') as fd:
tokens.write_binary(db, fd)
+ # After the change but before the pool runs in another thread,
+ # the token should not exist.
+ self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+ # After the pool is allowed to process, it should.
+ pool.process()
self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
finally:
os.unlink(file.name)
+ def test_update_with_directory(self, mock_getmtime):
+ """Tests the update command with a directory format database."""
+ db = database.load_token_database(
+ io.BytesIO(ELF_WITH_TOKENIZER_SECTIONS)
+ )
+ self.assertEqual(len(db), TOKENS_IN_ELF)
+
+ the_time = [100]
+
+ def move_back_time_if_file_exists(path):
+ if os.path.exists(path):
+ the_time[0] -= 1
+ return the_time[0]
+
+ raise FileNotFoundError
+
+ mock_getmtime.side_effect = move_back_time_if_file_exists
+
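        # With this side effect, every os.path.getmtime() call on an existing
        # database file returns a new (strictly decreasing) timestamp, so each
        # poll presumably looks like a modification and schedules a reload;
        # missing paths raise FileNotFoundError, as the real getmtime would.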
+ with tempfile.TemporaryDirectory() as dbdir:
+ with tempfile.NamedTemporaryFile(
+ 'wb', delete=False, suffix='.pw_tokenizer.csv', dir=dbdir
+ ) as matching_suffix_file, tempfile.NamedTemporaryFile(
+ 'wb', delete=False, suffix='.not.right', dir=dbdir
+ ) as mismatched_suffix_file:
+ try:
+ matching_suffix_file.close()
+ mismatched_suffix_file.close()
+
+ pool = ManualPoolExecutor()
+ detok = detokenize.AutoUpdatingDetokenizer(
+ dbdir, min_poll_period_s=0, pool=pool
+ )
+ self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+ with open(mismatched_suffix_file.name, 'wb') as fd:
+ tokens.write_csv(db, fd)
+ pool.process()
+ self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+
+ with open(matching_suffix_file.name, 'wb') as fd:
+ tokens.write_csv(db, fd)
+
+ # After the change but before the pool runs in another
+ # thread, the token should not exist.
+ self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+ pool.process()
+
+ # After the pool is allowed to process, it should.
+ self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
+ finally:
+ os.unlink(mismatched_suffix_file.name)
+ os.unlink(matching_suffix_file.name)
+ os.rmdir(dbdir)
+
# The database stays around if the file is deleted.
self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
@@ -507,7 +600,7 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
file.close()
detok = detokenize.AutoUpdatingDetokenizer(
- file.name, min_poll_period_s=0
+ file.name, min_poll_period_s=0, pool=InlinePoolExecutor()
)
self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok())
@@ -527,7 +620,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
def test_token_domain_in_str(self, _) -> None:
"""Tests a str containing a domain"""
detok = detokenize.AutoUpdatingDetokenizer(
- f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*', min_poll_period_s=0
+ f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*',
+ min_poll_period_s=0,
+ pool=InlinePoolExecutor(),
)
self.assertEqual(
len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS
@@ -536,7 +631,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
def test_token_domain_in_path(self, _) -> None:
"""Tests a Path() containing a domain"""
detok = detokenize.AutoUpdatingDetokenizer(
- Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'), min_poll_period_s=0
+ Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'),
+ min_poll_period_s=0,
+ pool=InlinePoolExecutor(),
)
self.assertEqual(
len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS
@@ -545,14 +642,18 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase):
def test_token_no_domain_in_str(self, _) -> None:
"""Tests a str without a domain"""
detok = detokenize.AutoUpdatingDetokenizer(
- str(ELF_WITH_TOKENIZER_SECTIONS_PATH), min_poll_period_s=0
+ str(ELF_WITH_TOKENIZER_SECTIONS_PATH),
+ min_poll_period_s=0,
+ pool=InlinePoolExecutor(),
)
self.assertEqual(len(detok.database), TOKENS_IN_ELF)
def test_token_no_domain_in_path(self, _) -> None:
"""Tests a Path() without a domain"""
detok = detokenize.AutoUpdatingDetokenizer(
- ELF_WITH_TOKENIZER_SECTIONS_PATH, min_poll_period_s=0
+ ELF_WITH_TOKENIZER_SECTIONS_PATH,
+ min_poll_period_s=0,
+ pool=InlinePoolExecutor(),
)
self.assertEqual(len(detok.database), TOKENS_IN_ELF)
@@ -561,39 +662,173 @@ def _next_char(message: bytes) -> bytes:
return bytes(b + 1 for b in message)
-class PrefixedMessageDecoderTest(unittest.TestCase):
- def setUp(self):
- super().setUp()
- self.decode = detokenize.PrefixedMessageDecoder('$', 'abcdefg')
+class NestedMessageParserTest(unittest.TestCase):
+ """Tests parsing prefixed messages."""
+
+ class _Case(NamedTuple):
+ data: bytes
+ expected: bytes
+ title: str
+ transform: Callable[[bytes], bytes] = _next_char
+
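    # Parsing model assumed by these cases: '$' starts a nested message, the
    # message continues while bytes stay in the charset 'abcdefg', and the
    # transform is applied only to complete messages (prefix included). E.g.
    # with _next_char, b'$abc$defgh' -> '$abc' => '%bcd', '$defg' => '%efgh',
    # and the trailing 'h' passes through unchanged: b'%bcd%efghh'.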
+ TRANSFORM_TEST_CASES = (
+ _Case(b'$abcd', b'%bcde', 'single message'),
+ _Case(
+ b'$$WHAT?$abc$WHY? is this $ok $',
+ b'%%WHAT?%bcd%WHY? is this %ok %',
+ 'message and non-message',
+ ),
+ _Case(b'$1$', b'%1%', 'empty message'),
+ _Case(b'$abc$defgh', b'%bcd%efghh', 'sequential message'),
+ _Case(
+ b'w$abcx$defygh$$abz',
+ b'w$ABCx$DEFygh$$ABz',
+ 'interspersed start/end non-message',
+ bytes.upper,
+ ),
+ _Case(
+ b'$abcx$defygh$$ab',
+ b'$ABCx$DEFygh$$AB',
+            'interspersed start/end message',
+ bytes.upper,
+ ),
+ )
+
+ def setUp(self) -> None:
+ self.decoder = detokenize.NestedMessageParser('$', 'abcdefg')
+
+ def test_transform_io(self) -> None:
+ for data, expected, title, transform in self.TRANSFORM_TEST_CASES:
+ self.assertEqual(
+ expected,
+ b''.join(
+ self.decoder.transform_io(io.BytesIO(data), transform)
+ ),
+ f'{title}: {data!r}',
+ )
+
+ def test_transform_bytes_with_flush(self) -> None:
+ for data, expected, title, transform in self.TRANSFORM_TEST_CASES:
+ self.assertEqual(
+ expected,
+ self.decoder.transform(data, transform, flush=True),
+ f'{title}: {data!r}',
+ )
+
+ def test_transform_bytes_sequential(self) -> None:
+ transform = lambda message: message.upper().replace(b'$', b'*')
- def test_transform_single_message(self):
+ self.assertEqual(self.decoder.transform(b'abc$abcd', transform), b'abc')
+ self.assertEqual(self.decoder.transform(b'$', transform), b'*ABCD')
+ self.assertEqual(self.decoder.transform(b'$b', transform), b'*')
+ self.assertEqual(self.decoder.transform(b'', transform), b'')
+ self.assertEqual(self.decoder.transform(b' ', transform), b'*B ')
+ self.assertEqual(self.decoder.transform(b'hello', transform), b'hello')
+ self.assertEqual(self.decoder.transform(b'?? $ab', transform), b'?? ')
self.assertEqual(
- b'%bcde',
- b''.join(self.decode.transform(io.BytesIO(b'$abcd'), _next_char)),
+ self.decoder.transform(b'123$ab4$56$a', transform), b'*AB123*AB4*56'
)
+ self.assertEqual(
+ self.decoder.transform(b'bc', transform, flush=True), b'*ABC'
+ )
+
+ MESSAGES_TEST: Any = (
+ (b'123$abc456$a', (False, b'123'), (True, b'$abc'), (False, b'456')),
+ (b'7$abcd', (True, b'$a'), (False, b'7')),
+ (b'e',),
+ (b'',),
+ (b'$', (True, b'$abcde')),
+ (b'$', (True, b'$')),
+ (b'$a$b$c', (True, b'$'), (True, b'$a'), (True, b'$b')),
+ (b'1', (True, b'$c'), (False, b'1')),
+ (b'',),
+ (b'?', (False, b'?')),
+ (b'!@', (False, b'!@')),
+ (b'%^&', (False, b'%^&')),
+ )
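    # Each MESSAGES_TEST step is (input, *(is_message, piece) pairs). The
    # parser keeps state between steps, so a message that spans inputs (e.g.
    # b'7$abcd', b'e', b'', b'$') is only emitted once its end is seen.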
- def test_transform_message_amidst_other_only_affects_message(self):
+ def test_read_messages(self) -> None:
+ for step in self.MESSAGES_TEST:
+ data: bytes = step[0]
+ pieces: Tuple[Tuple[bool, bytes], ...] = step[1:]
+ self.assertEqual(tuple(self.decoder.read_messages(data)), pieces)
+
+ def test_read_messages_flush(self) -> None:
self.assertEqual(
- b'%%WHAT?%bcd%WHY? is this %ok %',
- b''.join(
- self.decode.transform(
- io.BytesIO(b'$$WHAT?$abc$WHY? is this $ok $'), _next_char
- )
- ),
+ list(self.decoder.read_messages(b'123$a')), [(False, b'123')]
)
+ self.assertEqual(list(self.decoder.read_messages(b'b')), [])
+ self.assertEqual(
+ list(self.decoder.read_messages(b'', flush=True)), [(True, b'$ab')]
+ )
+
+ def test_read_messages_io(self) -> None:
+ # Rework the read_messages test data for stream input.
+ data = io.BytesIO(b''.join(step[0] for step in self.MESSAGES_TEST))
+ expected_pieces = sum((step[1:] for step in self.MESSAGES_TEST), ())
+
+ result = self.decoder.read_messages_io(data)
+ for expected_is_message, expected_data in expected_pieces:
+ if expected_is_message:
+ is_message, piece = next(result)
+ self.assertTrue(is_message)
+ self.assertEqual(expected_data, piece)
+ else: # the IO version yields non-messages byte by byte
+ for byte in expected_data:
+ is_message, piece = next(result)
+ self.assertFalse(is_message)
+ self.assertEqual(bytes([byte]), piece)
- def test_transform_empty_message(self):
+
+class DetokenizeNested(unittest.TestCase):
+ """Tests detokenizing nested tokens"""
+
+ def test_nested_hashed_arg(self):
+ detok = detokenize.Detokenizer(
+ tokens.Database(
+ [
+ tokens.TokenizedStringEntry(0xA, 'tokenized argument'),
+ tokens.TokenizedStringEntry(
+ 2,
+ 'This is a ' + '$#%08x',
+ ),
+ ]
+ )
+ )
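        # Payload sketch (assuming pw_tokenizer's usual encoding): the first
        # four bytes are the little-endian token (2); the trailing 0x14 is a
        # zigzag varint argument decoding to 10 (0xA), so '$#%08x' produces
        # '$#0000000a', which resolves to 'tokenized argument'.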
self.assertEqual(
- b'%1%',
- b''.join(self.decode.transform(io.BytesIO(b'$1$'), _next_char)),
+ str(detok.detokenize(b'\x02\0\0\0\x14')),
+ 'This is a tokenized argument',
)
- def test_transform_sequential_messages(self):
+ def test_nested_base64_arg(self):
+ detok = detokenize.Detokenizer(
+ tokens.Database(
+ [
+ tokens.TokenizedStringEntry(1, 'base64 argument'),
+ tokens.TokenizedStringEntry(2, 'This is a %s'),
+ ]
+ )
+ )
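        # Here the argument is a string (assuming the length-prefixed string
        # encoding): 0x09 is the length of '$AQAAAA==', whose Base64 payload
        # decodes to token 1, i.e. 'base64 argument'.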
self.assertEqual(
- b'%bcd%efghh',
- b''.join(
- self.decode.transform(io.BytesIO(b'$abc$defgh'), _next_char)
- ),
+ str(detok.detokenize(b'\x02\0\0\0\x09$AQAAAA==')), # token for 1
+ 'This is a base64 argument',
+ )
+
+ def test_deeply_nested_arg(self):
+ detok = detokenize.Detokenizer(
+ tokens.Database(
+ [
+ tokens.TokenizedStringEntry(1, '$10#0000000005'),
+ tokens.TokenizedStringEntry(2, 'This is a $#%08x'),
+ tokens.TokenizedStringEntry(3, 'deeply nested argument'),
+ tokens.TokenizedStringEntry(4, '$AQAAAA=='),
+ tokens.TokenizedStringEntry(5, '$AwAAAA=='),
+ ]
+ )
+ )
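        # Expected resolution chain: arg 0x08 -> zigzag 4 -> '$#00000004'
        # (token 4) -> '$AQAAAA==' (token 1) -> '$10#0000000005' (token 5)
        # -> '$AwAAAA==' (token 3) -> 'deeply nested argument'.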
+ self.assertEqual(
+ str(detok.detokenize(b'\x02\0\0\0\x08')), # token for 4
+ 'This is a deeply nested argument',
)
@@ -627,6 +862,10 @@ class DetokenizeBase64(unittest.TestCase):
(JELLO + b'$a' + JELLO + b'bcd', b'Jello, world!$aJello, world!bcd'),
(b'$3141', b'$3141'),
(JELLO + b'$3141', b'Jello, world!$3141'),
+ (
+ JELLO + b'$a' + JELLO + b'b' + JELLO + b'c',
+ b'Jello, world!$aJello, world!bJello, world!c',
+ ),
(RECURSION, b'The secret message is "Jello, world!"'),
(
RECURSION_2,
@@ -650,7 +889,7 @@ class DetokenizeBase64(unittest.TestCase):
output = io.BytesIO()
self.detok.detokenize_base64_live(io.BytesIO(data), output, '$')
- self.assertEqual(expected, output.getvalue())
+ self.assertEqual(expected, output.getvalue(), f'Input: {data!r}')
def test_detokenize_base64_to_file(self):
for data, expected in self.TEST_CASES:
@@ -670,6 +909,52 @@ class DetokenizeBase64(unittest.TestCase):
)
+class DetokenizeInfiniteRecursion(unittest.TestCase):
+ """Tests that infinite Base64 token recursion resolves."""
+
+ def setUp(self):
+ super().setUp()
+ self.detok = detokenize.Detokenizer(
+ tokens.Database(
+ [
+ tokens.TokenizedStringEntry(0, '$AAAAAA=='), # token for 0
+ tokens.TokenizedStringEntry(1, '$AgAAAA=='), # token for 2
+ tokens.TokenizedStringEntry(2, '$#00000003'), # token for 3
+ tokens.TokenizedStringEntry(3, '$AgAAAA=='), # token for 2
+ ]
+ )
+ )
+
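    # Token 0 expands to its own '$AAAAAA==' form, so extra passes change
    # nothing; token 1 enters a two-entry cycle ('$AgAAAA==' <-> '$#00000003'),
    # so the final text depends on the parity of the recursion limit.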
+ def test_detokenize_self_recursion(self):
+ for depth in range(5):
+ self.assertEqual(
+ self.detok.detokenize_text(
+ b'This one is deep: $AAAAAA==', recursion=depth
+ ),
+ b'This one is deep: $AAAAAA==',
+ )
+
+ def test_detokenize_self_recursion_default(self):
+ self.assertEqual(
+ self.detok.detokenize_text(
+ b'This one is deep: $AAAAAA==',
+ ),
+ b'This one is deep: $AAAAAA==',
+ )
+
+ def test_detokenize_cyclic_recursion_even(self):
+ self.assertEqual(
+ self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=6),
+ b'I said "$AgAAAA=="',
+ )
+
+ def test_detokenize_cyclic_recursion_odd(self):
+ self.assertEqual(
+ self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=7),
+ b'I said "$#00000003"',
+ )
+
+
class DetokenizeBase64InfiniteRecursion(unittest.TestCase):
"""Tests that infinite Bas64 token recursion resolves."""
@@ -697,7 +982,7 @@ class DetokenizeBase64InfiniteRecursion(unittest.TestCase):
def test_detokenize_self_recursion_default(self):
self.assertEqual(
- self.detok.detokenize_base64(b'This one is deep: $AAAAAA=='),
+ self.detok.detokenize_base64(b'This one is deep: $64#AAAAAA=='),
b'This one is deep: $AAAAAA==',
)