diff options
Diffstat (limited to 'pw_transfer/py/pw_transfer/client.py')
-rw-r--r-- | pw_transfer/py/pw_transfer/client.py | 245 |
1 files changed, 166 insertions, 79 deletions
diff --git a/pw_transfer/py/pw_transfer/client.py b/pw_transfer/py/pw_transfer/client.py index b97f2cef1..85bfc4575 100644 --- a/pw_transfer/py/pw_transfer/client.py +++ b/pw_transfer/py/pw_transfer/client.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Pigweed Authors +# Copyright 2022 The Pigweed Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -21,9 +21,20 @@ from typing import Any, Dict, Optional, Union from pw_rpc.callback_client import BidirectionalStreamingCall from pw_status import Status -from pw_transfer.transfer import (ProgressCallback, ReadTransfer, Transfer, - WriteTransfer) -from pw_transfer.transfer_pb2 import Chunk +from pw_transfer.transfer import ( + ProgressCallback, + ProtocolVersion, + ReadTransfer, + Transfer, + WriteTransfer, +) +from pw_transfer.chunk import Chunk + +try: + from pw_transfer import transfer_pb2 +except ImportError: + # For the bazel build, which puts generated protos in a different location. + from pigweed.pw_transfer import transfer_pb2 # type: ignore _LOG = logging.getLogger(__package__) @@ -40,12 +51,17 @@ class Manager: # pylint: disable=too-many-instance-attributes When created, a Manager starts a separate thread in which transfer communications and events are handled. """ - def __init__(self, - rpc_transfer_service, - *, - default_response_timeout_s: float = 2.0, - initial_response_timeout_s: float = 4.0, - max_retries: int = 3): + + def __init__( + self, + rpc_transfer_service, + *, + default_response_timeout_s: float = 2.0, + initial_response_timeout_s: float = 4.0, + max_retries: int = 3, + max_lifetime_retries: int = 1500, + default_protocol_version=ProtocolVersion.LATEST, + ): """Initializes a Manager on top of a TransferService. Args: @@ -53,14 +69,18 @@ class Manager: # pylint: disable=too-many-instance-attributes default_response_timeout_s: max time to wait between receiving packets initial_response_timeout_s: timeout for the first packet; may be longer to account for transfer handler initialization - max_retires: number of times to retry after a timeout + max_retires: number of times to retry a single package after a timeout + max_lifetime_retires: Cumulative maximum number of times to retry over + the course of the transfer before giving up. """ self._service: Any = rpc_transfer_service self._default_response_timeout_s = default_response_timeout_s self._initial_response_timeout_s = initial_response_timeout_s self.max_retries = max_retries + self.max_lifetime_retries = max_lifetime_retries + self._default_protocol_version = default_protocol_version - # Ongoing transfers in the service by ID. + # Ongoing transfers in the service by resource ID. self._read_transfers: _TransferDict = {} self._write_transfers: _TransferDict = {} @@ -70,6 +90,8 @@ class Manager: # pylint: disable=too-many-instance-attributes self._write_stream: Optional[BidirectionalStreamingCall] = None self._loop = asyncio.new_event_loop() + # Set the event loop for the current thread. + asyncio.set_event_loop(self._loop) # Queues are used for communication between the Manager context and the # dedicated asyncio transfer thread. @@ -78,8 +100,9 @@ class Manager: # pylint: disable=too-many-instance-attributes self._write_chunk_queue: asyncio.Queue = asyncio.Queue() self._quit_event = asyncio.Event() - self._thread = threading.Thread(target=self._start_event_loop_thread, - daemon=True) + self._thread = threading.Thread( + target=self._start_event_loop_thread, daemon=True + ) self._thread.start() @@ -90,42 +113,63 @@ class Manager: # pylint: disable=too-many-instance-attributes self._loop.call_soon_threadsafe(self._quit_event.set) self._thread.join() - def read(self, - transfer_id: int, - progress_callback: ProgressCallback = None) -> bytes: + def read( + self, + resource_id: int, + progress_callback: Optional[ProgressCallback] = None, + protocol_version: Optional[ProtocolVersion] = None, + ) -> bytes: """Receives ("downloads") data from the server. + Args: + resource_id: ID of the resource from which to read. + progress_callback: Optional callback periodically invoked throughout + the transfer with the transfer state. Can be used to provide user- + facing status updates such as progress bars. + Raises: Error: the transfer failed to complete """ - if transfer_id in self._read_transfers: - raise ValueError(f'Read transfer {transfer_id} already exists') - - transfer = ReadTransfer(transfer_id, - self._send_read_chunk, - self._end_read_transfer, - self._default_response_timeout_s, - self._initial_response_timeout_s, - self.max_retries, - progress_callback=progress_callback) + if resource_id in self._read_transfers: + raise ValueError( + f'Read transfer for resource {resource_id} already exists' + ) + + if protocol_version is None: + protocol_version = self._default_protocol_version + + transfer = ReadTransfer( + resource_id, + self._send_read_chunk, + self._end_read_transfer, + self._default_response_timeout_s, + self._initial_response_timeout_s, + self.max_retries, + self.max_lifetime_retries, + protocol_version, + progress_callback=progress_callback, + ) self._start_read_transfer(transfer) transfer.done.wait() if not transfer.status.ok(): - raise Error(transfer.id, transfer.status) + raise Error(transfer.resource_id, transfer.status) return transfer.data - def write(self, - transfer_id: int, - data: Union[bytes, str], - progress_callback: ProgressCallback = None) -> None: + def write( + self, + resource_id: int, + data: Union[bytes, str], + progress_callback: Optional[ProgressCallback] = None, + protocol_version: Optional[ProtocolVersion] = None, + ) -> None: """Transmits ("uploads") data to the server. Args: - transfer_id: ID of the write transfer + resource_id: ID of the resource to which to write. data: Data to send to the server. progress_callback: Optional callback periodically invoked throughout the transfer with the transfer state. Can be used to provide user- @@ -138,31 +182,40 @@ class Manager: # pylint: disable=too-many-instance-attributes if isinstance(data, str): data = data.encode() - if transfer_id in self._write_transfers: - raise ValueError(f'Write transfer {transfer_id} already exists') - - transfer = WriteTransfer(transfer_id, - data, - self._send_write_chunk, - self._end_write_transfer, - self._default_response_timeout_s, - self._initial_response_timeout_s, - self.max_retries, - progress_callback=progress_callback) + if resource_id in self._write_transfers: + raise ValueError( + f'Write transfer for resource {resource_id} already exists' + ) + + if protocol_version is None: + protocol_version = self._default_protocol_version + + transfer = WriteTransfer( + resource_id, + data, + self._send_write_chunk, + self._end_write_transfer, + self._default_response_timeout_s, + self._initial_response_timeout_s, + self.max_retries, + self.max_lifetime_retries, + protocol_version, + progress_callback=progress_callback, + ) self._start_write_transfer(transfer) transfer.done.wait() if not transfer.status.ok(): - raise Error(transfer.id, transfer.status) + raise Error(transfer.resource_id, transfer.status) def _send_read_chunk(self, chunk: Chunk) -> None: assert self._read_stream is not None - self._read_stream.send(chunk) + self._read_stream.send(chunk.to_message()) def _send_write_chunk(self, chunk: Chunk) -> None: assert self._write_stream is not None - self._write_stream.send(chunk) + self._write_stream.send(chunk.to_message()) def _start_event_loop_thread(self): """Entry point for event loop thread that starts an asyncio context.""" @@ -189,7 +242,8 @@ class Manager: # pylint: disable=too-many-instance-attributes # Perform a select(2)-like wait for one of several events to occur. done, _ = await asyncio.wait( (exit_thread, new_transfer, read_chunk, write_chunk), - return_when=asyncio.FIRST_COMPLETED) + return_when=asyncio.FIRST_COMPLETED, + ) if exit_thread in done: break @@ -197,26 +251,35 @@ class Manager: # pylint: disable=too-many-instance-attributes if new_transfer in done: await new_transfer.result().begin() new_transfer = self._loop.create_task( - self._new_transfer_queue.get()) + self._new_transfer_queue.get() + ) if read_chunk in done: self._loop.create_task( - self._handle_chunk(self._read_transfers, - read_chunk.result())) + self._handle_chunk( + self._read_transfers, read_chunk.result() + ) + ) read_chunk = self._loop.create_task( - self._read_chunk_queue.get()) + self._read_chunk_queue.get() + ) if write_chunk in done: self._loop.create_task( - self._handle_chunk(self._write_transfers, - write_chunk.result())) + self._handle_chunk( + self._write_transfers, write_chunk.result() + ) + ) write_chunk = self._loop.create_task( - self._write_chunk_queue.get()) + self._write_chunk_queue.get() + ) self._loop.stop() @staticmethod - async def _handle_chunk(transfers: _TransferDict, chunk: Chunk) -> None: + async def _handle_chunk( + transfers: _TransferDict, message: transfer_pb2.Chunk + ) -> None: """Processes an incoming chunk from a stream. The chunk is dispatched to an active transfer based on its ID. If the @@ -224,12 +287,23 @@ class Manager: # pylint: disable=too-many-instance-attributes is invoked. """ + chunk = Chunk.from_message(message) + + # Find a transfer for the chunk in the list of active transfers. try: - transfer = transfers[chunk.transfer_id] - except KeyError: + if chunk.resource_id is not None: + # Prioritize a resource_id if one is set. + transfer = transfers[chunk.resource_id] + else: + # Otherwise, match against either resource or session ID. + transfer = next( + t for t in transfers.values() if t.id == chunk.id() + ) + except (KeyError, StopIteration): _LOG.error( 'TransferManager received chunk for unknown transfer %d', - chunk.transfer_id) + chunk.id(), + ) # TODO(frolv): What should be done here, if anything? return @@ -238,8 +312,10 @@ class Manager: # pylint: disable=too-many-instance-attributes def _open_read_stream(self) -> None: self._read_stream = self._service.Read.invoke( lambda _, chunk: self._loop.call_soon_threadsafe( - self._read_chunk_queue.put_nowait, chunk), - on_error=lambda _, status: self._on_read_error(status)) + self._read_chunk_queue.put_nowait, chunk + ), + on_error=lambda _, status: self._on_read_error(status), + ) def _on_read_error(self, status: Status) -> None: """Callback for an RPC error in the read stream.""" @@ -265,8 +341,10 @@ class Manager: # pylint: disable=too-many-instance-attributes def _open_write_stream(self) -> None: self._write_stream = self._service.Write.invoke( lambda _, chunk: self._loop.call_soon_threadsafe( - self._write_chunk_queue.put_nowait, chunk), - on_error=lambda _, status: self._on_write_error(status)) + self._write_chunk_queue.put_nowait, chunk + ), + on_error=lambda _, status: self._on_write_error(status), + ) def _on_write_error(self, status: Status) -> None: """Callback for an RPC error in the write stream.""" @@ -292,22 +370,26 @@ class Manager: # pylint: disable=too-many-instance-attributes def _start_read_transfer(self, transfer: Transfer) -> None: """Begins a new read transfer, opening the stream if it isn't.""" - self._read_transfers[transfer.id] = transfer + self._read_transfers[transfer.resource_id] = transfer if not self._read_stream: self._open_read_stream() _LOG.debug('Starting new read transfer %d', transfer.id) - self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait, - transfer) + self._loop.call_soon_threadsafe( + self._new_transfer_queue.put_nowait, transfer + ) def _end_read_transfer(self, transfer: Transfer) -> None: """Completes a read transfer.""" - del self._read_transfers[transfer.id] + del self._read_transfers[transfer.resource_id] if not transfer.status.ok(): - _LOG.error('Read transfer %d terminated with status %s', - transfer.id, transfer.status) + _LOG.error( + 'Read transfer %d terminated with status %s', + transfer.id, + transfer.status, + ) # TODO(frolv): This doesn't seem to work. Investigate why. # If no more transfers are using the read stream, close it. @@ -318,22 +400,26 @@ class Manager: # pylint: disable=too-many-instance-attributes def _start_write_transfer(self, transfer: Transfer) -> None: """Begins a new write transfer, opening the stream if it isn't.""" - self._write_transfers[transfer.id] = transfer + self._write_transfers[transfer.resource_id] = transfer if not self._write_stream: self._open_write_stream() _LOG.debug('Starting new write transfer %d', transfer.id) - self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait, - transfer) + self._loop.call_soon_threadsafe( + self._new_transfer_queue.put_nowait, transfer + ) def _end_write_transfer(self, transfer: Transfer) -> None: """Completes a write transfer.""" - del self._write_transfers[transfer.id] + del self._write_transfers[transfer.resource_id] if not transfer.status.ok(): - _LOG.error('Write transfer %d terminated with status %s', - transfer.id, transfer.status) + _LOG.error( + 'Write transfer %d terminated with status %s', + transfer.id, + transfer.status, + ) # TODO(frolv): This doesn't seem to work. Investigate why. # If no more transfers are using the write stream, close it. @@ -345,9 +431,10 @@ class Manager: # pylint: disable=too-many-instance-attributes class Error(Exception): """Exception raised when a transfer fails. - Stores the ID of the failed transfer and the error that occurred. + Stores the ID of the failed transfer resource and the error that occurred. """ - def __init__(self, transfer_id: int, status: Status): - super().__init__(f'Transfer {transfer_id} failed with status {status}') - self.transfer_id = transfer_id + + def __init__(self, resource_id: int, status: Status): + super().__init__(f'Transfer {resource_id} failed with status {status}') + self.resource_id = resource_id self.status = status |