diff options
Diffstat (limited to 'py/pica/pica/__init__.py')
-rw-r--r-- | py/pica/pica/__init__.py | 67 |
1 files changed, 60 insertions, 7 deletions
diff --git a/py/pica/pica/__init__.py b/py/pica/pica/__init__.py index ba2b155..423824f 100644 --- a/py/pica/pica/__init__.py +++ b/py/pica/pica/__init__.py @@ -1,6 +1,10 @@ import asyncio -from typing import Union +from typing import Union, Type, TypeVar + from .packets import uci +from .packets.uci import CommonPacketHeader, ControlPacketHeader, DataPacketHeader + +UciPacket = TypeVar("UciPacket", uci.DataPacket, uci.ControlPacket) class Host: @@ -49,14 +53,24 @@ class Host: while True: # Read the common packet header. header_bytes = await self._read_exact(4) - header = uci.ControlPacketHeader.parse_all(header_bytes) + common_header: CommonPacketHeader = uci.CommonPacketHeader.parse_all( + header_bytes[0:1] + ) # type: ignore + + if common_header.mt == uci.MessageType.DATA: + # Read the packet payload. + data_header: DataPacketHeader = uci.DataPacketHeader.parse_all(header_bytes) # type: ignore + payload_bytes = await self._read_exact(data_header.payload_length) + + else: + # Read the packet payload. + control_header: ControlPacketHeader = uci.ControlPacketHeader.parse_all(header_bytes) # type: ignore + payload_bytes = await self._read_exact(control_header.payload_length) - # Read the packet payload. - payload_bytes = await self._read_exact(header.payload_length) complete_packet_bytes += payload_bytes # Check the Packet Boundary Flag. - match header.pbf: + match common_header.pbf: case uci.PacketBoundaryFlag.COMPLETE: return header_bytes + complete_packet_bytes case uci.PacketBoundaryFlag.NOT_COMPLETE: @@ -68,13 +82,20 @@ class Host: try: while True: packet = await self._read_packet() - await self.control_queue.put(packet) + header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(packet[0:1]) # type: ignore + if header.mt == uci.MessageType.DATA: + await self.data_queue.put(packet) + else: + await self.control_queue.put(packet) except Exception as exn: print(f"reader task closed") async def _recv_control(self) -> bytes: return await self.control_queue.get() + async def _recv_data(self) -> bytes: + return await self.data_queue.get() + def send_control(self, packet: uci.ControlPacket): # TODO packet fragmentation. packet = bytearray(packet.serialize()) @@ -90,7 +111,9 @@ class Host: self.writer.write(packet) async def expect_control( - self, expected: Union[type, uci.ControlPacket], timeout: float = 1.0 + self, + expected: Union[Type[uci.ControlPacket], uci.ControlPacket], + timeout: float = 1.0, ) -> uci.ControlPacket: """Wait for a control packet being sent from the controller. @@ -116,3 +139,33 @@ class Host: ) return received + + async def expect_data( + self, + expected: Union[Type[uci.DataPacket], uci.DataPacket], + timeout: float = 1.0, + ) -> uci.DataPacket: + """Wait for a data packet being sent from the controller. + + Raises ValueError if the packet is not well formatted. + Raises ValueError if the packet does not match the expected type or value. + Raises TimeoutError if no packet is received after `timeout` seconds. + Returns the received packet on success. + """ + + packet = await asyncio.wait_for(self._recv_data(), timeout=timeout) + received = uci.DataPacket.parse_all(packet) + + if isinstance(expected, type) and not isinstance(received, expected): + raise ValueError( + f"received unexpected packet {received.__class__.__name__}," + + f" expected {expected.__name__}" + ) + + if isinstance(expected, uci.DataPacket) and received != expected: + raise ValueError( + f"received unexpected packet {received.__class__.__name__}," + + f" expected {expected.__class__.__name__}" + ) + + return received |