aboutsummaryrefslogtreecommitdiff
path: root/py/pica/pica/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'py/pica/pica/__init__.py')
-rw-r--r--py/pica/pica/__init__.py67
1 files changed, 60 insertions, 7 deletions
diff --git a/py/pica/pica/__init__.py b/py/pica/pica/__init__.py
index ba2b155..423824f 100644
--- a/py/pica/pica/__init__.py
+++ b/py/pica/pica/__init__.py
@@ -1,6 +1,10 @@
import asyncio
-from typing import Union
+from typing import Union, Type, TypeVar
+
from .packets import uci
+from .packets.uci import CommonPacketHeader, ControlPacketHeader, DataPacketHeader
+
+UciPacket = TypeVar("UciPacket", uci.DataPacket, uci.ControlPacket)
class Host:
@@ -49,14 +53,24 @@ class Host:
while True:
# Read the common packet header.
header_bytes = await self._read_exact(4)
- header = uci.ControlPacketHeader.parse_all(header_bytes)
+ common_header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(
+ header_bytes[0:1]
+ ) # type: ignore
+
+ if common_header.mt == uci.MessageType.DATA:
+ # Read the packet payload.
+ data_header: DataPacketHeader = uci.DataPacketHeader.parse_all(header_bytes) # type: ignore
+ payload_bytes = await self._read_exact(data_header.payload_length)
+
+ else:
+ # Read the packet payload.
+ control_header: ControlPacketHeader = uci.ControlPacketHeader.parse_all(header_bytes) # type: ignore
+ payload_bytes = await self._read_exact(control_header.payload_length)
- # Read the packet payload.
- payload_bytes = await self._read_exact(header.payload_length)
complete_packet_bytes += payload_bytes
# Check the Packet Boundary Flag.
- match header.pbf:
+ match common_header.pbf:
case uci.PacketBoundaryFlag.COMPLETE:
return header_bytes + complete_packet_bytes
case uci.PacketBoundaryFlag.NOT_COMPLETE:
@@ -68,13 +82,20 @@ class Host:
try:
while True:
packet = await self._read_packet()
- await self.control_queue.put(packet)
+ header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(packet[0:1]) # type: ignore
+ if header.mt == uci.MessageType.DATA:
+ await self.data_queue.put(packet)
+ else:
+ await self.control_queue.put(packet)
except Exception as exn:
print(f"reader task closed")
async def _recv_control(self) -> bytes:
return await self.control_queue.get()
+ async def _recv_data(self) -> bytes:
+ return await self.data_queue.get()
+
def send_control(self, packet: uci.ControlPacket):
# TODO packet fragmentation.
packet = bytearray(packet.serialize())
@@ -90,7 +111,9 @@ class Host:
self.writer.write(packet)
async def expect_control(
- self, expected: Union[type, uci.ControlPacket], timeout: float = 1.0
+ self,
+ expected: Union[Type[uci.ControlPacket], uci.ControlPacket],
+ timeout: float = 1.0,
) -> uci.ControlPacket:
"""Wait for a control packet being sent from the controller.
@@ -116,3 +139,33 @@ class Host:
)
return received
+
+ async def expect_data(
+ self,
+ expected: Union[Type[uci.DataPacket], uci.DataPacket],
+ timeout: float = 1.0,
+ ) -> uci.DataPacket:
+ """Wait for a data packet being sent from the controller.
+
+ Raises ValueError if the packet is not well formatted.
+ Raises ValueError if the packet does not match the expected type or value.
+ Raises TimeoutError if no packet is received after `timeout` seconds.
+ Returns the received packet on success.
+ """
+
+ packet = await asyncio.wait_for(self._recv_data(), timeout=timeout)
+ received = uci.DataPacket.parse_all(packet)
+
+ if isinstance(expected, type) and not isinstance(received, expected):
+ raise ValueError(
+ f"received unexpected packet {received.__class__.__name__},"
+ + f" expected {expected.__name__}"
+ )
+
+ if isinstance(expected, uci.DataPacket) and received != expected:
+ raise ValueError(
+ f"received unexpected packet {received.__class__.__name__},"
+ + f" expected {expected.__class__.__name__}"
+ )
+
+ return received