diff options
author | zxzxwu <92432172+zxzxwu@users.noreply.github.com> | 2023-09-21 16:09:36 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-21 16:09:36 +0800 |
commit | d290df4aa92957aa09edb03210f29bb222527748 (patch) | |
tree | 1f6c7f4c1eba46721a60720ddbb692ad52fd9e2f | |
parent | 67418e649a44cb2f34fe685a88f7351a5bcc8558 (diff) | |
parent | e559744f3205eb73649b91e3b25559922c329aa8 (diff) | |
download | bumble-d290df4aa92957aa09edb03210f29bb222527748.tar.gz |
Merge pull request #278 from zxzxwu/gatt2
Typing GATT
-rw-r--r-- | bumble/att.py | 128 | ||||
-rw-r--r-- | bumble/gatt.py | 42 | ||||
-rw-r--r-- | bumble/gatt_client.py | 99 | ||||
-rw-r--r-- | bumble/gatt_server.py | 106 | ||||
-rw-r--r-- | tests/gatt_test.py | 8 |
5 files changed, 245 insertions, 138 deletions
diff --git a/bumble/att.py b/bumble/att.py index 55ae8a5..db8d2ba 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -23,13 +23,14 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations +import enum import functools import struct from pyee import EventEmitter -from typing import Dict, Type, TYPE_CHECKING +from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING -from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError -from bumble.hci import HCI_Object, key_with_value, HCI_Constant +from bumble.core import UUID, name_or_number, ProtocolError +from bumble.hci import HCI_Object, key_with_value from bumble.colors import color if TYPE_CHECKING: @@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 # pylint: enable=line-too-long # pylint: disable=invalid-name + # ----------------------------------------------------------------------------- # Exceptions # ----------------------------------------------------------------------------- @@ -209,7 +211,7 @@ class ATT_PDU: pdu_classes: Dict[int, Type[ATT_PDU]] = {} op_code = 0 - name = None + name: str @staticmethod def from_bytes(pdu): @@ -720,47 +722,67 @@ class ATT_Handle_Value_Confirmation(ATT_PDU): # ----------------------------------------------------------------------------- -class Attribute(EventEmitter): - # Permission flags - READABLE = 0x01 - WRITEABLE = 0x02 - READ_REQUIRES_ENCRYPTION = 0x04 - WRITE_REQUIRES_ENCRYPTION = 0x08 - READ_REQUIRES_AUTHENTICATION = 0x10 - WRITE_REQUIRES_AUTHENTICATION = 0x20 - READ_REQUIRES_AUTHORIZATION = 0x40 - WRITE_REQUIRES_AUTHORIZATION = 0x80 - - PERMISSION_NAMES = { - READABLE: 'READABLE', - WRITEABLE: 'WRITEABLE', - READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION', - WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION', - READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION', - WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION', - READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION', - WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION', - } +class ConnectionValue(Protocol): + def read(self, connection) -> bytes: + ... - @staticmethod - def string_to_permissions(permissions_str: str): - try: - return functools.reduce( - lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y), - permissions_str.split(","), - 0, - ) - except TypeError as exc: - raise TypeError( - f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}" - ) from exc + def write(self, connection, value: bytes) -> None: + ... - def __init__(self, attribute_type, permissions, value=b''): + +# ----------------------------------------------------------------------------- +class Attribute(EventEmitter): + class Permissions(enum.IntFlag): + READABLE = 0x01 + WRITEABLE = 0x02 + READ_REQUIRES_ENCRYPTION = 0x04 + WRITE_REQUIRES_ENCRYPTION = 0x08 + READ_REQUIRES_AUTHENTICATION = 0x10 + WRITE_REQUIRES_AUTHENTICATION = 0x20 + READ_REQUIRES_AUTHORIZATION = 0x40 + WRITE_REQUIRES_AUTHORIZATION = 0x80 + + @classmethod + def from_string(cls, permissions_str: str) -> Attribute.Permissions: + try: + return functools.reduce( + lambda x, y: x | Attribute.Permissions[y], + permissions_str.replace('|', ',').split(","), + Attribute.Permissions(0), + ) + except TypeError as exc: + # The check for `p.name is not None` here is needed because for InFlag + # enums, the .name property can be None, when the enum value is 0, + # so the type hint for .name is Optional[str]. + enum_list: List[str] = [p.name for p in cls if p.name is not None] + enum_list_str = ",".join(enum_list) + raise TypeError( + f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}" + ) from exc + + # Permission flags(legacy-use only) + READABLE = Permissions.READABLE + WRITEABLE = Permissions.WRITEABLE + READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION + WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION + READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION + WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION + READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION + WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION + + value: Union[str, bytes, ConnectionValue] + + def __init__( + self, + attribute_type: Union[str, bytes, UUID], + permissions: Union[str, Attribute.Permissions], + value: Union[str, bytes, ConnectionValue] = b'', + ) -> None: EventEmitter.__init__(self) self.handle = 0 self.end_group_handle = 0 if isinstance(permissions, str): - self.permissions = self.string_to_permissions(permissions) + self.permissions = Attribute.Permissions.from_string(permissions) else: self.permissions = permissions @@ -778,22 +800,26 @@ class Attribute(EventEmitter): else: self.value = value - def encode_value(self, value): + def encode_value(self, value: Any) -> bytes: return value - def decode_value(self, value_bytes): + def decode_value(self, value_bytes: bytes) -> Any: return value_bytes - def read_value(self, connection: Connection): + def read_value(self, connection: Optional[Connection]) -> bytes: if ( - self.permissions & self.READ_REQUIRES_ENCRYPTION - ) and not connection.encryption: + (self.permissions & self.READ_REQUIRES_ENCRYPTION) + and connection is not None + and not connection.encryption + ): raise ATT_Error( error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle ) if ( - self.permissions & self.READ_REQUIRES_AUTHENTICATION - ) and not connection.authenticated: + (self.permissions & self.READ_REQUIRES_AUTHENTICATION) + and connection is not None + and not connection.authenticated + ): raise ATT_Error( error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle ) @@ -803,9 +829,9 @@ class Attribute(EventEmitter): error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle ) - if read := getattr(self.value, 'read', None): + if hasattr(self.value, 'read'): try: - value = read(connection) # pylint: disable=not-callable + value = self.value.read(connection) except ATT_Error as error: raise ATT_Error( error_code=error.error_code, att_handle=self.handle @@ -815,7 +841,7 @@ class Attribute(EventEmitter): return self.encode_value(value) - def write_value(self, connection: Connection, value_bytes): + def write_value(self, connection: Connection, value_bytes: bytes) -> None: if ( self.permissions & self.WRITE_REQUIRES_ENCRYPTION ) and not connection.encryption: @@ -836,9 +862,9 @@ class Attribute(EventEmitter): value = self.decode_value(value_bytes) - if write := getattr(self.value, 'write', None): + if hasattr(self.value, 'write'): try: - write(connection, value) # pylint: disable=not-callable + self.value.write(connection, value) # pylint: disable=not-callable except ATT_Error as error: raise ATT_Error( error_code=error.error_code, att_handle=self.handle diff --git a/bumble/gatt.py b/bumble/gatt.py index 067f31d..fe3e85c 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -28,7 +28,7 @@ import enum import functools import logging import struct -from typing import Optional, Sequence, List +from typing import Optional, Sequence, Iterable, List, Union from .colors import color from .core import UUID, get_dict_key_by_value @@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi # ----------------------------------------------------------------------------- -def show_services(services): +def show_services(services: Iterable[Service]) -> None: for service in services: print(color(str(service), 'cyan')) @@ -210,11 +210,11 @@ class Service(Attribute): def __init__( self, - uuid, + uuid: Union[str, UUID], characteristics: List[Characteristic], primary=True, included_services: List[Service] = [], - ): + ) -> None: # Convert the uuid to a UUID object if it isn't already if isinstance(uuid, str): uuid = UUID(uuid) @@ -239,7 +239,7 @@ class Service(Attribute): """ return None - def __str__(self): + def __str__(self) -> str: return ( f'Service(handle=0x{self.handle:04X}, ' f'end=0x{self.end_group_handle:04X}, ' @@ -255,9 +255,11 @@ class TemplateService(Service): to expose their UUID as a class property ''' - UUID: Optional[UUID] = None + UUID: UUID - def __init__(self, characteristics, primary=True): + def __init__( + self, characteristics: List[Characteristic], primary: bool = True + ) -> None: super().__init__(self.UUID, characteristics, primary) @@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute): service: Service - def __init__(self, service): + def __init__(self, service: Service) -> None: declaration_bytes = struct.pack( '<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() ) @@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute): ) self.service = service - def __str__(self): + def __str__(self) -> str: return ( f'IncludedServiceDefinition(handle=0x{self.handle:04X}, ' f'group_starting_handle=0x{self.service.handle:04X}, ' @@ -326,7 +328,7 @@ class Characteristic(Attribute): f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}" ) - def __str__(self): + def __str__(self) -> str: # NOTE: we override this method to offer a consistent result between python # versions: the value returned by IntFlag.__str__() changed in version 11. return '|'.join( @@ -348,10 +350,10 @@ class Characteristic(Attribute): def __init__( self, - uuid, + uuid: Union[str, bytes, UUID], properties: Characteristic.Properties, - permissions, - value=b'', + permissions: Union[str, Attribute.Permissions], + value: Union[str, bytes, CharacteristicValue] = b'', descriptors: Sequence[Descriptor] = (), ): super().__init__(uuid, permissions, value) @@ -369,7 +371,7 @@ class Characteristic(Attribute): def has_properties(self, properties: Characteristic.Properties) -> bool: return self.properties & properties == properties - def __str__(self): + def __str__(self) -> str: return ( f'Characteristic(handle=0x{self.handle:04X}, ' f'end=0x{self.end_group_handle:04X}, ' @@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute): characteristic: Characteristic - def __init__(self, characteristic, value_handle): + def __init__(self, characteristic: Characteristic, value_handle: int) -> None: declaration_bytes = ( struct.pack('<BH', characteristic.properties, value_handle) + characteristic.uuid.to_pdu_bytes() @@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute): self.value_handle = value_handle self.characteristic = characteristic - def __str__(self): + def __str__(self) -> str: return ( f'CharacteristicDeclaration(handle=0x{self.handle:04X}, ' f'value_handle=0x{self.value_handle:04X}, ' @@ -520,7 +522,7 @@ class CharacteristicAdapter: return self.wrapped_characteristic.unsubscribe(subscriber) - def __str__(self): + def __str__(self) -> str: wrapped = str(self.wrapped_characteristic) return f'{self.__class__.__name__}({wrapped})' @@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter): Adapter that converts strings to/from bytes using UTF-8 encoding ''' - def encode_value(self, value): + def encode_value(self, value: str) -> bytes: return value.encode('utf-8') - def decode_value(self, value): + def decode_value(self, value: bytes) -> str: return value.decode('utf-8') @@ -613,7 +615,7 @@ class Descriptor(Attribute): See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations ''' - def __str__(self): + def __str__(self) -> str: return ( f'Descriptor(handle=0x{self.handle:04X}, ' f'type={self.type}, ' diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index a33039e..e3b8bb2 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -28,7 +28,18 @@ import asyncio import logging import struct from datetime import datetime -from typing import List, Optional, Dict, Tuple, Callable, Union, Any +from typing import ( + List, + Optional, + Dict, + Tuple, + Callable, + Union, + Any, + Iterable, + Type, + TYPE_CHECKING, +) from pyee import EventEmitter @@ -66,8 +77,12 @@ from .gatt import ( GATT_INCLUDE_ATTRIBUTE_TYPE, Characteristic, ClientCharacteristicConfigurationBits, + TemplateService, ) +if TYPE_CHECKING: + from bumble.device import Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -78,16 +93,16 @@ logger = logging.getLogger(__name__) # Proxies # ----------------------------------------------------------------------------- class AttributeProxy(EventEmitter): - client: Client - - def __init__(self, client, handle, end_group_handle, attribute_type): + def __init__( + self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID + ) -> None: EventEmitter.__init__(self) self.client = client self.handle = handle self.end_group_handle = end_group_handle self.type = attribute_type - async def read_value(self, no_long_read=False): + async def read_value(self, no_long_read: bool = False) -> bytes: return self.decode_value( await self.client.read_value(self.handle, no_long_read) ) @@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter): self.handle, self.encode_value(value), with_response ) - def encode_value(self, value): + def encode_value(self, value: Any) -> bytes: return value - def decode_value(self, value_bytes): + def decode_value(self, value_bytes: bytes) -> Any: return value_bytes - def __str__(self): + def __str__(self) -> str: return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' @@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy): def get_characteristics_by_uuid(self, uuid): return self.client.get_characteristics_by_uuid(uuid, self) - def __str__(self): + def __str__(self) -> str: return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' class CharacteristicProxy(AttributeProxy): properties: Characteristic.Properties descriptors: List[DescriptorProxy] - subscribers: Dict[Any, Callable] + subscribers: Dict[Any, Callable[[bytes], Any]] def __init__( self, @@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy): return await self.client.discover_descriptors(self) async def subscribe( - self, subscriber: Optional[Callable] = None, prefer_notify=True + self, + subscriber: Optional[Callable[[bytes], Any]] = None, + prefer_notify: bool = True, ): if subscriber is not None: if subscriber in self.subscribers: @@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy): return await self.client.unsubscribe(self, subscriber) - def __str__(self): + def __str__(self) -> str: return ( f'Characteristic(handle=0x{self.handle:04X}, ' f'uuid={self.uuid}, ' @@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy): def __init__(self, client, handle, descriptor_type): super().__init__(client, handle, 0, descriptor_type) - def __str__(self): + def __str__(self) -> str: return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})' @@ -216,8 +233,10 @@ class ProfileServiceProxy: Base class for profile-specific service proxies ''' + SERVICE_CLASS: Type[TemplateService] + @classmethod - def from_client(cls, client): + def from_client(cls, client: Client) -> ProfileServiceProxy: return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) @@ -227,8 +246,12 @@ class ProfileServiceProxy: class Client: services: List[ServiceProxy] cached_values: Dict[int, Tuple[datetime, bytes]] + notification_subscribers: Dict[int, Callable[[bytes], Any]] + indication_subscribers: Dict[int, Callable[[bytes], Any]] + pending_response: Optional[asyncio.futures.Future[ATT_PDU]] + pending_request: Optional[ATT_PDU] - def __init__(self, connection): + def __init__(self, connection: Connection) -> None: self.connection = connection self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1) @@ -241,16 +264,16 @@ class Client: self.services = [] self.cached_values = {} - def send_gatt_pdu(self, pdu): + def send_gatt_pdu(self, pdu: bytes) -> None: self.connection.send_l2cap_pdu(ATT_CID, pdu) - async def send_command(self, command): + async def send_command(self, command: ATT_PDU) -> None: logger.debug( f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' ) self.send_gatt_pdu(command.to_bytes()) - async def send_request(self, request): + async def send_request(self, request: ATT_PDU): logger.debug( f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' ) @@ -279,14 +302,14 @@ class Client: return response - def send_confirmation(self, confirmation): + def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None: logger.debug( f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'{confirmation}' ) self.send_gatt_pdu(confirmation.to_bytes()) - async def request_mtu(self, mtu): + async def request_mtu(self, mtu: int) -> int: # Check the range if mtu < ATT_DEFAULT_MTU: raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') @@ -313,10 +336,12 @@ class Client: return self.connection.att_mtu - def get_services_by_uuid(self, uuid): + def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]: return [service for service in self.services if service.uuid == uuid] - def get_characteristics_by_uuid(self, uuid, service=None): + def get_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] = None + ) -> List[CharacteristicProxy]: services = [service] if service else self.services return [ c @@ -363,7 +388,7 @@ class Client: if not already_known: self.services.append(service) - async def discover_services(self, uuids=None) -> List[ServiceProxy]: + async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.1 Discover All Primary Services ''' @@ -435,7 +460,7 @@ class Client: return services - async def discover_service(self, uuid): + async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID ''' @@ -468,7 +493,7 @@ class Client: f'{HCI_Constant.error_name(response.error_code)}' ) # TODO raise appropriate exception - return + return [] break for attribute_handle, end_group_handle in response.handles_information: @@ -480,7 +505,7 @@ class Client: logger.warning( f'bogus handle values: {attribute_handle} {end_group_handle}' ) - return + return [] # Create a service proxy for this service service = ServiceProxy( @@ -721,7 +746,7 @@ class Client: return descriptors - async def discover_attributes(self): + async def discover_attributes(self) -> List[AttributeProxy]: ''' Discover all attributes, regardless of type ''' @@ -844,7 +869,9 @@ class Client: # No more subscribers left await self.write_value(cccd, b'\x00\x00', with_response=True) - async def read_value(self, attribute, no_long_read=False): + async def read_value( + self, attribute: Union[int, AttributeProxy], no_long_read: bool = False + ) -> Any: ''' See Vol 3, Part G - 4.8.1 Read Characteristic Value @@ -905,7 +932,9 @@ class Client: # Return the value as bytes return attribute_value - async def read_characteristics_by_uuid(self, uuid, service): + async def read_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] + ) -> List[bytes]: ''' See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID ''' @@ -960,7 +989,12 @@ class Client: return characteristics_values - async def write_value(self, attribute, value, with_response=False): + async def write_value( + self, + attribute: Union[int, AttributeProxy], + value: bytes, + with_response: bool = False, + ) -> None: ''' See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value @@ -990,7 +1024,7 @@ class Client: ) ) - def on_gatt_pdu(self, att_pdu): + def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' ) @@ -1013,6 +1047,7 @@ class Client: return # Return the response to the coroutine that is waiting for it + assert self.pending_response is not None self.pending_response.set_result(att_pdu) else: handler_name = f'on_{att_pdu.name.lower()}' @@ -1060,7 +1095,7 @@ class Client: # Confirm that we received the indication self.send_confirmation(ATT_Handle_Value_Confirmation()) - def cache_value(self, attribute_handle: int, value: bytes): + def cache_value(self, attribute_handle: int, value: bytes) -> None: self.cached_values[attribute_handle] = ( datetime.now(), value, diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 3624905..cdf1b5e 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -23,11 +23,12 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging from collections import defaultdict import struct -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from pyee import EventEmitter from .colors import color @@ -42,6 +43,7 @@ from .att import ( ATT_INVALID_OFFSET_ERROR, ATT_REQUEST_NOT_SUPPORTED_ERROR, ATT_REQUESTS, + ATT_PDU, ATT_UNLIKELY_ERROR_ERROR, ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ATT_Error, @@ -73,6 +75,8 @@ from .gatt import ( Service, ) +if TYPE_CHECKING: + from bumble.device import Device, Connection # ----------------------------------------------------------------------------- # Logging @@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517 # ----------------------------------------------------------------------------- class Server(EventEmitter): attributes: List[Attribute] + services: List[Service] + attributes_by_handle: Dict[int, Attribute] + subscribers: Dict[int, Dict[int, bytes]] + indication_semaphores: defaultdict[int, asyncio.Semaphore] + pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]] - def __init__(self, device): + def __init__(self, device: Device) -> None: super().__init__() self.device = device self.services = [] @@ -107,16 +116,16 @@ class Server(EventEmitter): self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.pending_confirmations = defaultdict(lambda: None) - def __str__(self): + def __str__(self) -> str: return "\n".join(map(str, self.attributes)) - def send_gatt_pdu(self, connection_handle, pdu): + def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None: self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) - def next_handle(self): + def next_handle(self) -> int: return 1 + len(self.attributes) - def get_advertising_service_data(self): + def get_advertising_service_data(self) -> Dict[Attribute, bytes]: return { attribute: data for attribute in self.attributes @@ -124,7 +133,7 @@ class Server(EventEmitter): and (data := attribute.get_advertising_data()) } - def get_attribute(self, handle): + def get_attribute(self, handle: int) -> Optional[Attribute]: attribute = self.attributes_by_handle.get(handle) if attribute: return attribute @@ -173,12 +182,17 @@ class Server(EventEmitter): return next( ( - (attribute, self.get_attribute(attribute.characteristic.handle)) + ( + attribute, + self.get_attribute(attribute.characteristic.handle), + ) # type: ignore for attribute in map( self.get_attribute, range(service_handle.handle, service_handle.end_group_handle + 1), ) - if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE + if attribute is not None + and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE + and isinstance(attribute, CharacteristicDeclaration) and attribute.characteristic.uuid == characteristic_uuid ), None, @@ -197,7 +211,7 @@ class Server(EventEmitter): return next( ( - attribute + attribute # type: ignore for attribute in map( self.get_attribute, range( @@ -205,12 +219,12 @@ class Server(EventEmitter): characteristic_value.end_group_handle + 1, ), ) - if attribute.type == descriptor_uuid + if attribute is not None and attribute.type == descriptor_uuid ), None, ) - def add_attribute(self, attribute): + def add_attribute(self, attribute: Attribute) -> None: # Assign a handle to this attribute attribute.handle = self.next_handle() attribute.end_group_handle = ( @@ -220,7 +234,7 @@ class Server(EventEmitter): # Add this attribute to the list self.attributes.append(attribute) - def add_service(self, service: Service): + def add_service(self, service: Service) -> None: # Add the service attribute to the DB self.add_attribute(service) @@ -285,11 +299,13 @@ class Server(EventEmitter): service.end_group_handle = self.attributes[-1].handle self.services.append(service) - def add_services(self, services): + def add_services(self, services: Iterable[Service]) -> None: for service in services: self.add_service(service) - def read_cccd(self, connection, characteristic): + def read_cccd( + self, connection: Optional[Connection], characteristic: Characteristic + ) -> bytes: if connection is None: return bytes([0, 0]) @@ -300,7 +316,12 @@ class Server(EventEmitter): return cccd or bytes([0, 0]) - def write_cccd(self, connection, characteristic, value): + def write_cccd( + self, + connection: Connection, + characteristic: Characteristic, + value: bytes, + ) -> None: logger.debug( f'Subscription update for connection=0x{connection.handle:04X}, ' f'handle=0x{characteristic.handle:04X}: {value.hex()}' @@ -327,13 +348,19 @@ class Server(EventEmitter): indicate_enabled, ) - def send_response(self, connection, response): + def send_response(self, connection: Connection, response: ATT_PDU) -> None: logger.debug( f'GATT Response from server: [0x{connection.handle:04X}] {response}' ) self.send_gatt_pdu(connection.handle, response.to_bytes()) - async def notify_subscriber(self, connection, attribute, value=None, force=False): + async def notify_subscriber( + self, + connection: Connection, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Check if there's a subscriber if not force: subscribers = self.subscribers.get(connection.handle) @@ -370,7 +397,13 @@ class Server(EventEmitter): ) self.send_gatt_pdu(connection.handle, bytes(notification)) - async def indicate_subscriber(self, connection, attribute, value=None, force=False): + async def indicate_subscriber( + self, + connection: Connection, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Check if there's a subscriber if not force: subscribers = self.subscribers.get(connection.handle) @@ -411,15 +444,13 @@ class Server(EventEmitter): assert self.pending_confirmations[connection.handle] is None # Create a future value to hold the eventual response - self.pending_confirmations[ + pending_confirmation = self.pending_confirmations[ connection.handle ] = asyncio.get_running_loop().create_future() try: self.send_gatt_pdu(connection.handle, indication.to_bytes()) - await asyncio.wait_for( - self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT - ) + await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) except asyncio.TimeoutError as error: logger.warning(color('!!! GATT Indicate timeout', 'red')) raise TimeoutError(f'GATT timeout for {indication.name}') from error @@ -427,8 +458,12 @@ class Server(EventEmitter): self.pending_confirmations[connection.handle] = None async def notify_or_indicate_subscribers( - self, indicate, attribute, value=None, force=False - ): + self, + indicate: bool, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Get all the connections for which there's at least one subscription connections = [ connection @@ -450,13 +485,23 @@ class Server(EventEmitter): ] ) - async def notify_subscribers(self, attribute, value=None, force=False): + async def notify_subscribers( + self, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ): return await self.notify_or_indicate_subscribers(False, attribute, value, force) - async def indicate_subscribers(self, attribute, value=None, force=False): + async def indicate_subscribers( + self, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ): return await self.notify_or_indicate_subscribers(True, attribute, value, force) - def on_disconnection(self, connection): + def on_disconnection(self, connection: Connection) -> None: if connection.handle in self.subscribers: del self.subscribers[connection.handle] if connection.handle in self.indication_semaphores: @@ -464,7 +509,7 @@ class Server(EventEmitter): if connection.handle in self.pending_confirmations: del self.pending_confirmations[connection.handle] - def on_gatt_pdu(self, connection, att_pdu): + def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None: logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') handler_name = f'on_{att_pdu.name.lower()}' handler = getattr(self, handler_name, None) @@ -506,7 +551,7 @@ class Server(EventEmitter): ####################################################### # ATT handlers ####################################################### - def on_att_request(self, connection, pdu): + def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None: ''' Handler for requests without a more specific handler ''' @@ -679,7 +724,6 @@ class Server(EventEmitter): and attribute.handle <= request.ending_handle and pdu_space_available ): - try: attribute_value = attribute.read_value(connection) except ATT_Error as error: diff --git a/tests/gatt_test.py b/tests/gatt_test.py index dd0277e..d9f6d60 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -891,10 +891,10 @@ async def async_main(): # ----------------------------------------------------------------------------- -def test_attribute_string_to_permissions(): - assert Attribute.string_to_permissions('READABLE') == 1 - assert Attribute.string_to_permissions('WRITEABLE') == 2 - assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3 +def test_permissions_from_string(): + assert Attribute.Permissions.from_string('READABLE') == 1 + assert Attribute.Permissions.from_string('WRITEABLE') == 2 + assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3 # ----------------------------------------------------------------------------- |