#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
r"""Decodes and detokenizes strings from binary or Base64 input.

The main class provided by this module is the Detokenizer class. To use it,
construct it with the path to an ELF or CSV database, a tokens.Database, or a
file object for an ELF file or CSV. Then, call the detokenize method with
encoded messages, one at a time. The detokenize method returns a
DetokenizedString object with the result.

For example::

  from pw_tokenizer import detokenize

  detok = detokenize.Detokenizer('path/to/firmware/image.elf')
  print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!'))
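
  # The same message, Base64-encoded and prefixed with '$', can be replaced
  # directly within text using detokenize_text:
  print(detok.detokenize_text('Log: $EjRWeANoaSE='))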

This module also provides a command line interface for decoding and detokenizing
messages from a file or stdin.
"""

import argparse
import base64
import binascii
from concurrent.futures import Executor, ThreadPoolExecutor
import enum
import io
import logging
import os
from pathlib import Path
import re
import string
import struct
import sys
import threading
import time
from typing import (
    AnyStr,
    BinaryIO,
    Callable,
    Dict,
    List,
    Iterable,
    Iterator,
    Match,
    NamedTuple,
    Optional,
    Pattern,
    Tuple,
    Union,
)

try:
    from pw_tokenizer import database, decode, encode, tokens
except ImportError:
    # Append this path to the module search path to allow running this module
    # without installing the pw_tokenizer package.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from pw_tokenizer import database, decode, encode, tokens

_LOG = logging.getLogger('pw_tokenizer')

ENCODED_TOKEN = struct.Struct('<I')
_BASE64_CHARS = string.ascii_letters + string.digits + '+/-_='
DEFAULT_RECURSION = 9
NESTED_TOKEN_PREFIX = encode.NESTED_TOKEN_PREFIX.encode()
NESTED_TOKEN_BASE_PREFIX = encode.NESTED_TOKEN_BASE_PREFIX.encode()

_BASE8_TOKEN_REGEX = rb'(?P<base8>[0-7]{11})'
_BASE10_TOKEN_REGEX = rb'(?P<base10>[0-9]{10})'
_BASE16_TOKEN_REGEX = rb'(?P<base16>[A-Fa-f0-9]{8})'
_BASE64_TOKEN_REGEX = (
    rb'(?P<base64>'
    # Tokenized Base64 contains 0 or more blocks of four Base64 chars.
    rb'(?:[A-Za-z0-9+/\-_]{4})*'
    # The last block of 4 chars may have one or two padding chars (=).
    rb'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'
    rb')'
)
_NESTED_TOKEN_FORMATS = (
    _BASE8_TOKEN_REGEX,
    _BASE10_TOKEN_REGEX,
    _BASE16_TOKEN_REGEX,
    _BASE64_TOKEN_REGEX,
)

_RawIo = Union[io.RawIOBase, BinaryIO]
_RawIoOrBytes = Union[_RawIo, bytes]


def _token_regex(prefix: bytes) -> Pattern[bytes]:
    """Returns a regular expression for prefixed tokenized strings."""
    return re.compile(
        # Tokenized strings start with the prefix character ($).
        re.escape(prefix)
        # Optional; no base specifier defaults to Base64.
        # Hash (#) with no number specified defaults to Base-16.
        + rb'(?P<basespec>(?P<base>[0-9]*)?'
        + NESTED_TOKEN_BASE_PREFIX
        + rb')?'
        # Match one of the following token formats.
        + rb'('
        + rb'|'.join(_NESTED_TOKEN_FORMATS)
        + rb')'
    )


class DetokenizedString:
    """A detokenized string, with all results if there are collisions."""

    def __init__(
        self,
        token: Optional[int],
        format_string_entries: Iterable[tuple],
        encoded_message: bytes,
        show_errors: bool = False,
        recursive_detokenize: Optional[Callable[[str], str]] = None,
    ):
        self.token = token
        self.encoded_message = encoded_message
        self._show_errors = show_errors

        self.successes: List[decode.FormattedString] = []
        self.failures: List[decode.FormattedString] = []

        decode_attempts: List[Tuple[Tuple, decode.FormattedString]] = []

        for entry, fmt in format_string_entries:
            result = fmt.format(
                encoded_message[ENCODED_TOKEN.size :], show_errors
            )
            if recursive_detokenize:
                result = decode.FormattedString(
                    recursive_detokenize(result.value),
                    result.args,
                    result.remaining,
                )
            decode_attempts.append((result.score(entry.date_removed), result))

        # Sort the attempts by the score so the most likely results are first.
        decode_attempts.sort(key=lambda value: value[0], reverse=True)

        # Split out the successful decodes from the failures.
        for score, result in decode_attempts:
            if score[0]:
                self.successes.append(result)
            else:
                self.failures.append(result)

    def ok(self) -> bool:
        """True if exactly one string decoded the arguments successfully."""
        return len(self.successes) == 1

    def matches(self) -> List[decode.FormattedString]:
        """Returns the strings that matched the token, best matches first."""
        return self.successes + self.failures

    def best_result(self) -> Optional[decode.FormattedString]:
        """Returns the string and args for the most likely decoded string."""
        for string_and_args in self.matches():
            return string_and_args

        return None

    def error_message(self) -> str:
        """If detokenization failed, returns a descriptive message."""
        if self.ok():
            return ''

        if not self.matches():
            if self.token is None:
                return 'missing token'

            return 'unknown token {:08x}'.format(self.token)

        if len(self.matches()) == 1:
            return 'decoding failed for {!r}'.format(self.matches()[0].value)

        return '{} matches'.format(len(self.matches()))

    def __str__(self) -> str:
        """Returns the string for the most likely result."""
        result = self.best_result()
        if result:
            return result[0]

        if self._show_errors:
            return '<[ERROR: {}|{!r}]>'.format(
                self.error_message(), self.encoded_message
            )

        # Display the string as prefixed Base64 if it cannot be decoded.
        return encode.prefixed_base64(self.encoded_message)

    def __repr__(self) -> str:
        if self.ok():
            message = repr(str(self))
        else:
            message = 'ERROR: {}|{!r}'.format(
                self.error_message(), self.encoded_message
            )

        return '{}({})'.format(type(self).__name__, message)


class _TokenizedFormatString(NamedTuple):
    entry: tokens.TokenizedStringEntry
    format: decode.FormatString


class Detokenizer:
    """Main detokenization class; detokenizes strings and caches results."""

    def __init__(self, *token_database_or_elf, show_errors: bool = False):
        """Decodes and detokenizes binary messages.

        Args:
          *token_database_or_elf: a path or file object for an ELF or CSV
              database, a tokens.Database, or an elf_reader.Elf
          show_errors: if True, an error message is used in place of the %
              conversion specifier when an argument fails to decode
        """
        self.show_errors = show_errors

        self._database_lock = threading.Lock()

        # Cache FormatStrings for faster lookup & formatting.
        self._cache: Dict[int, List[_TokenizedFormatString]] = {}

        self._initialize_database(token_database_or_elf)

    def _initialize_database(self, token_sources: Iterable) -> None:
        with self._database_lock:
            self.database = database.load_token_database(*token_sources)
            self._cache.clear()

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        """Returns (TokenizedStringEntry, FormatString) list for matches."""
        with self._database_lock:
            try:
                return self._cache[token]
            except KeyError:
                format_strings = [
                    _TokenizedFormatString(
                        entry, decode.FormatString(str(entry))
                    )
                    for entry in self.database.token_to_entries[token]
                ]
                self._cache[token] = format_strings
                return format_strings

    def detokenize(
        self,
        encoded_message: bytes,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> DetokenizedString:
        """Decodes and detokenizes a message as a DetokenizedString."""
        if not encoded_message:
            return DetokenizedString(
                None, (), encoded_message, self.show_errors
            )

        # Pad messages smaller than ENCODED_TOKEN.size with zeroes to support
        # tokens smaller than a uint32. Messages with arguments must always use
        # a full 32-bit token.
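        # For example, the 2-byte message b'\x12\x34' is padded to
        # b'\x12\x34\x00\x00', which unpacks to token 0x00003412
        # (little-endian).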
        missing_token_bytes = ENCODED_TOKEN.size - len(encoded_message)
        if missing_token_bytes > 0:
            encoded_message += b'\0' * missing_token_bytes

        (token,) = ENCODED_TOKEN.unpack_from(encoded_message)

        recursive_detokenize = None
        if recursion > 0:
            recursive_detokenize = self._detokenize_nested_callback(
                prefix, recursion
            )

        return DetokenizedString(
            token,
            self.lookup(token),
            encoded_message,
            self.show_errors,
            recursive_detokenize,
        )

    def detokenize_text(
        self,
        data: AnyStr,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> AnyStr:
        """Decodes and replaces prefixed Base64 messages in the provided data.

        Args:
          data: the text or binary data to decode
          prefix: one-character byte string that signals the start of a message
          recursion: how many levels to recursively decode

        Returns:
          copy of the data with all recognized tokens decoded
        """
        return self._detokenize_nested_callback(prefix, recursion)(data)

    # TODO(gschen): remove unnecessary function
    def detokenize_base64(
        self,
        data: AnyStr,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> AnyStr:
        """Alias of detokenize_text for backwards compatibility."""
        return self.detokenize_text(data, prefix, recursion)

    def detokenize_text_to_file(
        self,
        data: AnyStr,
        output: BinaryIO,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Decodes prefixed Base64 messages in data; decodes to output file."""
        output.write(self._detokenize_nested(data, prefix, recursion))

    # TODO(gschen): remove unnecessary function
    def detokenize_base64_to_file(
        self,
        data: AnyStr,
        output: BinaryIO,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Alias of detokenize_text_to_file for backwards compatibility."""
        self.detokenize_text_to_file(data, output, prefix, recursion)

    def detokenize_text_live(
        self,
        input_file: _RawIo,
        output: BinaryIO,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Reads chars one-at-a-time, decoding messages; SLOW for big files."""

        def transform(data: bytes) -> bytes:
            return self._detokenize_nested(data.decode(), prefix, recursion)

        for message in NestedMessageParser(prefix, _BASE64_CHARS).transform_io(
            input_file, transform
        ):
            output.write(message)

            # Flush each line to prevent delays when piping between processes.
            if b'\n' in message:
                output.flush()

    # TODO(gschen): remove unnecessary function
    def detokenize_base64_live(
        self,
        input_file: _RawIo,
        output: BinaryIO,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Alias of detokenize_text_live for backwards compatibility."""
        self.detokenize_text_live(input_file, output, prefix, recursion)

    def _detokenize_nested_callback(
        self,
        prefix: Union[str, bytes],
        recursion: int,
    ) -> Callable[[AnyStr], AnyStr]:
        """Returns a function that replaces all tokens for a given string."""

        def detokenize(message: AnyStr) -> AnyStr:
            result = self._detokenize_nested(message, prefix, recursion)
            return result.decode() if isinstance(message, str) else result

        return detokenize

    def _detokenize_nested(
        self,
        message: Union[str, bytes],
        prefix: Union[str, bytes],
        recursion: int,
    ) -> bytes:
        """Returns the message with recognized tokens replaced.

        Message data is internally handled as bytes regardless of input message
        type and returns the result as bytes.
        """
        # Convert the message and prefix to bytes so a single bytes-based
        # regex handles both str and bytes input consistently.
        message = message.encode() if isinstance(message, str) else message
        prefix = prefix.encode() if isinstance(prefix, str) else prefix

        if not self.database:
            return message

        result = message
        for _ in range(recursion - 1):
            result = _token_regex(prefix).sub(self._detokenize_scan, result)

            if result == message:
                return result
        return result

    def _detokenize_scan(self, match: Match[bytes]) -> bytes:
        """Decodes prefixed tokens for one of multiple formats."""
        basespec = match.group('basespec')
        base = match.group('base')

        if not basespec or (base == b'64'):
            return self._detokenize_once_base64(match)

        if not base:
            base = b'16'

        return self._detokenize_once(match, base)

    def _detokenize_once(
        self,
        match: Match[bytes],
        base: bytes,
    ) -> bytes:
        """Performs lookup on a plain token"""
        original = match.group(0)
        token = match.group('base' + base.decode())
        if not token:
            return original

        token = int(token, int(base))
        entries = self.database.token_to_entries[token]

        if len(entries) == 1:
            return str(entries[0]).encode()

        # TODO(gschen): improve token collision reporting

        return original

    def _detokenize_once_base64(
        self,
        match: Match[bytes],
    ) -> bytes:
        """Performs lookup on a Base64 token"""
        original = match.group(0)

        try:
            encoded_token = match.group('base64')
            if not encoded_token:
                return original

            detokenized_string = self.detokenize(
                base64.b64decode(encoded_token, validate=True), recursion=0
            )

            if detokenized_string.matches():
                return str(detokenized_string).encode()

        except binascii.Error:
            pass

        return original


_PathOrStr = Union[Path, str]


# TODO: b/265334753 - Reuse this function in database.py:LoadTokenDatabases
def _parse_domain(path: _PathOrStr) -> Tuple[Path, Optional[Pattern[str]]]:
    """Extracts an optional domain regex pattern suffix from a path"""

    if isinstance(path, Path):
        path = str(path)

    delimiters = path.count('#')

    if delimiters == 0:
        return Path(path), None

    if delimiters == 1:
        path, domain = path.split('#')
        return Path(path), re.compile(domain)

    raise ValueError(
        f'Too many # delimiters. Expected 0 or 1, found {delimiters}'
    )


class AutoUpdatingDetokenizer(Detokenizer):
    """Loads and updates a detokenizer from database paths."""

    class _DatabasePath:
        """Tracks the modified time of a path or file object."""

        def __init__(self, path: _PathOrStr) -> None:
            self.path, self.domain = _parse_domain(path)
            self._modified_time: Optional[float] = self._last_modified_time()

        def updated(self) -> bool:
            """True if the path has been updated since the last call."""
            modified_time = self._last_modified_time()
            if modified_time is None or modified_time == self._modified_time:
                return False

            self._modified_time = modified_time
            return True

        def _last_modified_time(self) -> Optional[float]:
            if self.path.is_dir():
                mtime = -1.0
                for child in self.path.glob(tokens.DIR_DB_GLOB):
                    mtime = max(mtime, os.path.getmtime(child))
                return mtime if mtime >= 0 else None

            try:
                return os.path.getmtime(self.path)
            except FileNotFoundError:
                return None

        def load(self) -> tokens.Database:
            try:
                if self.domain is not None:
                    return database.load_token_database(
                        self.path, domain=self.domain
                    )
                return database.load_token_database(self.path)
            except FileNotFoundError:
                return database.load_token_database()

    def __init__(
        self,
        *paths_or_files: _PathOrStr,
        min_poll_period_s: float = 1.0,
        pool: Executor = ThreadPoolExecutor(max_workers=1),
    ) -> None:
        self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
        self.min_poll_period_s = min_poll_period_s
        self._last_checked_time: float = time.time()
        # Thread pool to use for loading the databases. Limit to a single
        # worker since this is low volume and not time critical.
        self._pool = pool
        super().__init__(*(path.load() for path in self.paths))

    def __del__(self) -> None:
        self._pool.shutdown(wait=False)

    def _reload_paths(self) -> None:
        self._initialize_database([path.load() for path in self.paths])

    def _reload_if_changed(self) -> None:
        if time.time() - self._last_checked_time >= self.min_poll_period_s:
            self._last_checked_time = time.time()

            if any(path.updated() for path in self.paths):
                _LOG.info('Changes detected; reloading token database')
                self._pool.submit(self._reload_paths)

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        self._reload_if_changed()
        return super().lookup(token)


class NestedMessageParser:
    """Parses nested tokenized messages from a byte stream or string."""

    class _State(enum.Enum):
        MESSAGE = 1
        NON_MESSAGE = 2

    def __init__(
        self,
        prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
        chars: Union[str, bytes] = _BASE64_CHARS,
    ) -> None:
        """Initializes a parser.

        Args:
            prefix: one character that signifies the start of a message (``$``).
            chars: characters allowed in a message
        """
        self._prefix = ord(prefix)

        if isinstance(chars, str):
            chars = chars.encode()

        # Store the valid message bytes as a set of byte values.
        self._message_bytes = frozenset(chars)

        if len(prefix) != 1 or self._prefix in self._message_bytes:
            raise ValueError(
                f'Invalid prefix {prefix!r}: the prefix must be a single '
                'character that is not a valid message character.'
            )

        self._buffer = bytearray()
        self._state: NestedMessageParser._State = self._State.NON_MESSAGE

    def read_messages_io(
        self, binary_io: _RawIo
    ) -> Iterator[Tuple[bool, bytes]]:
        """Reads prefixed messages from a byte stream (BinaryIO object).

        Reads until EOF. If the stream is nonblocking (``read(1)`` returns
        ``None``), then this function returns and may be called again with the
        same IO object to continue parsing. Partial messages are preserved
        between calls.

        Yields:
            ``(is_message, contents)`` chunks.
        """
        # The read may block indefinitely, depending on the IO object.
        while (read_byte := binary_io.read(1)) != b'':
            # Handle non-blocking IO by returning when no bytes are available.
            if read_byte is None:
                return

            for byte in read_byte:
                yield from self._handle_byte(byte)

            if self._state is self._State.NON_MESSAGE:  # yield non-message byte
                yield from self._flush()

        yield from self._flush()  # Always flush after EOF
        self._state = self._State.NON_MESSAGE

    def read_messages(
        self, chunk: bytes, *, flush: bool = False
    ) -> Iterator[Tuple[bool, bytes]]:
        """Reads prefixed messages from a byte string.

        This function may be called repeatedly with chunks of a stream. Partial
        messages are preserved between calls, unless ``flush=True``.

        Args:
            chunk: byte string that may contain nested messages
            flush: whether to flush any incomplete messages after processing
                this chunk

        Yields:
            ``(is_message, contents)`` chunks.
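
        Example (illustrative, assuming the default ``$`` prefix)::

            parser = NestedMessageParser()
            list(parser.read_messages(b'Hi $abcd!', flush=True))
            # -> [(False, b'Hi '), (True, b'$abcd'), (False, b'!')]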
        """
        for byte in chunk:
            yield from self._handle_byte(byte)

        if flush or self._state is self._State.NON_MESSAGE:
            yield from self._flush()

    def _handle_byte(self, byte: int) -> Iterator[Tuple[bool, bytes]]:
        if self._state is self._State.MESSAGE:
            if byte not in self._message_bytes:
                yield from self._flush()
                if byte != self._prefix:
                    self._state = self._State.NON_MESSAGE
        elif self._state is self._State.NON_MESSAGE:
            if byte == self._prefix:
                yield from self._flush()
                self._state = self._State.MESSAGE
        else:
            raise NotImplementedError(f'Unsupported state: {self._state}')

        self._buffer.append(byte)

    def _flush(self) -> Iterator[Tuple[bool, bytes]]:
        data = bytes(self._buffer)
        self._buffer.clear()
        if data:
            yield self._state is self._State.MESSAGE, data

    def transform_io(
        self,
        binary_io: _RawIo,
        transform: Callable[[bytes], bytes],
    ) -> Iterator[bytes]:
        """Yields the file with a transformation applied to the messages."""
        for is_message, chunk in self.read_messages_io(binary_io):
            yield transform(chunk) if is_message else chunk

    def transform(
        self,
        chunk: bytes,
        transform: Callable[[bytes], bytes],
        *,
        flush: bool = False,
    ) -> bytes:
        """Yields the chunk with a transformation applied to the messages.

        Partial messages are preserved between calls unless ``flush=True``.
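
        For example, ``transform(b'Hi $abcd!', bytes.upper, flush=True)`` on a
        fresh parser returns ``b'Hi $ABCD!'`` (an illustrative message).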
        """
        return b''.join(
            transform(data) if is_message else data
            for is_message, data in self.read_messages(chunk, flush=flush)
        )


# TODO(hepler): Remove this unnecessary function.
def detokenize_base64(
    detokenizer: Detokenizer,
    data: bytes,
    prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX,
    recursion: int = DEFAULT_RECURSION,
) -> bytes:
    """Alias for detokenizer.detokenize_base64 for backwards compatibility.

    This function is deprecated; do not call it.
    """
    return detokenizer.detokenize_base64(data, prefix, recursion)


def _follow_and_detokenize_file(
    detokenizer: Detokenizer,
    file: BinaryIO,
    output: BinaryIO,
    prefix: Union[str, bytes],
    poll_period_s: float = 0.01,
) -> None:
    """Polls a file to detokenize it and any appended data."""

    try:
        while True:
            data = file.read()
            if data:
                detokenizer.detokenize_base64_to_file(data, output, prefix)
                output.flush()
            else:
                time.sleep(poll_period_s)
    except KeyboardInterrupt:
        pass


def _handle_base64(
    databases,
    input_file: BinaryIO,
    output: BinaryIO,
    prefix: str,
    show_errors: bool,
    follow: bool,
) -> None:
    """Handles the base64 command line option."""
    # argparse.FileType doesn't correctly handle - for binary files.
    if input_file is sys.stdin:
        input_file = sys.stdin.buffer

    if output is sys.stdout:
        output = sys.stdout.buffer

    detokenizer = Detokenizer(
        tokens.Database.merged(*databases), show_errors=show_errors
    )

    if follow:
        _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
    elif input_file.seekable():
        # Process seekable files all at once, which is MUCH faster.
        detokenizer.detokenize_base64_to_file(input_file.read(), output, prefix)
    else:
        # For non-seekable inputs (e.g. pipes), read one character at a time.
        detokenizer.detokenize_base64_live(input_file, output, prefix)


def _parse_args() -> argparse.Namespace:
    """Parses and return command line arguments."""

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(help='Encoding of the input.')

    base64_help = 'Detokenize Base64-encoded data from a file or stdin.'
    subparser = subparsers.add_parser(
        'base64',
        description=base64_help,
        parents=[database.token_databases_parser()],
        help=base64_help,
    )
    subparser.set_defaults(handler=_handle_base64)
    subparser.add_argument(
        '-i',
        '--input',
        dest='input_file',
        type=argparse.FileType('rb'),
        default=sys.stdin.buffer,
        help='The file from which to read; provide - or omit for stdin.',
    )
    subparser.add_argument(
        '-f',
        '--follow',
        action='store_true',
        help=(
            'Detokenize data appended to input_file as it grows; similar to '
            'tail -f.'
        ),
    )
    subparser.add_argument(
        '-o',
        '--output',
        type=argparse.FileType('wb'),
        default=sys.stdout.buffer,
        help=(
            'The file to which to write the output; '
            'provide - or omit for stdout.'
        ),
    )
    subparser.add_argument(
        '-p',
        '--prefix',
        default=NESTED_TOKEN_PREFIX,
        help=(
            'The one-character prefix that signals the start of a '
            'nested tokenized message. (default: $)'
        ),
    )
    subparser.add_argument(
        '-s',
        '--show_errors',
        action='store_true',
        help=(
            'Show error messages instead of conversion specifiers when '
            'arguments cannot be decoded.'
        ),
    )

    return parser.parse_args()


def main() -> int:
    args = _parse_args()

    handler = args.handler
    del args.handler

    handler(**vars(args))
    return 0


if __name__ == '__main__':
    if sys.version_info[0] < 3:
        sys.exit('ERROR: The detokenizer command line tools require Python 3.')
    sys.exit(main())