aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzxzxwu <92432172+zxzxwu@users.noreply.github.com>2023-09-21 16:09:36 +0800
committerGitHub <noreply@github.com>2023-09-21 16:09:36 +0800
commitd290df4aa92957aa09edb03210f29bb222527748 (patch)
tree1f6c7f4c1eba46721a60720ddbb692ad52fd9e2f
parent67418e649a44cb2f34fe685a88f7351a5bcc8558 (diff)
parente559744f3205eb73649b91e3b25559922c329aa8 (diff)
downloadbumble-d290df4aa92957aa09edb03210f29bb222527748.tar.gz
Merge pull request #278 from zxzxwu/gatt2
Typing GATT
-rw-r--r--bumble/att.py128
-rw-r--r--bumble/gatt.py42
-rw-r--r--bumble/gatt_client.py99
-rw-r--r--bumble/gatt_server.py106
-rw-r--r--tests/gatt_test.py8
5 files changed, 245 insertions, 138 deletions
diff --git a/bumble/att.py b/bumble/att.py
index 55ae8a5..db8d2ba 100644
--- a/bumble/att.py
+++ b/bumble/att.py
@@ -23,13 +23,14 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
+import enum
import functools
import struct
from pyee import EventEmitter
-from typing import Dict, Type, TYPE_CHECKING
+from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
-from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
-from bumble.hci import HCI_Object, key_with_value, HCI_Constant
+from bumble.core import UUID, name_or_number, ProtocolError
+from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long
# pylint: disable=invalid-name
+
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
@@ -209,7 +211,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
- name = None
+ name: str
@staticmethod
def from_bytes(pdu):
@@ -720,47 +722,67 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
-class Attribute(EventEmitter):
- # Permission flags
- READABLE = 0x01
- WRITEABLE = 0x02
- READ_REQUIRES_ENCRYPTION = 0x04
- WRITE_REQUIRES_ENCRYPTION = 0x08
- READ_REQUIRES_AUTHENTICATION = 0x10
- WRITE_REQUIRES_AUTHENTICATION = 0x20
- READ_REQUIRES_AUTHORIZATION = 0x40
- WRITE_REQUIRES_AUTHORIZATION = 0x80
-
- PERMISSION_NAMES = {
- READABLE: 'READABLE',
- WRITEABLE: 'WRITEABLE',
- READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
- WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
- READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
- WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
- READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
- WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
- }
+class ConnectionValue(Protocol):
+ def read(self, connection) -> bytes:
+ ...
- @staticmethod
- def string_to_permissions(permissions_str: str):
- try:
- return functools.reduce(
- lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
- permissions_str.split(","),
- 0,
- )
- except TypeError as exc:
- raise TypeError(
- f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
- ) from exc
+ def write(self, connection, value: bytes) -> None:
+ ...
- def __init__(self, attribute_type, permissions, value=b''):
+
+# -----------------------------------------------------------------------------
+class Attribute(EventEmitter):
+ class Permissions(enum.IntFlag):
+ READABLE = 0x01
+ WRITEABLE = 0x02
+ READ_REQUIRES_ENCRYPTION = 0x04
+ WRITE_REQUIRES_ENCRYPTION = 0x08
+ READ_REQUIRES_AUTHENTICATION = 0x10
+ WRITE_REQUIRES_AUTHENTICATION = 0x20
+ READ_REQUIRES_AUTHORIZATION = 0x40
+ WRITE_REQUIRES_AUTHORIZATION = 0x80
+
+ @classmethod
+ def from_string(cls, permissions_str: str) -> Attribute.Permissions:
+ try:
+ return functools.reduce(
+ lambda x, y: x | Attribute.Permissions[y],
+ permissions_str.replace('|', ',').split(","),
+ Attribute.Permissions(0),
+ )
+ except TypeError as exc:
+ # The check for `p.name is not None` here is needed because for InFlag
+ # enums, the .name property can be None, when the enum value is 0,
+ # so the type hint for .name is Optional[str].
+ enum_list: List[str] = [p.name for p in cls if p.name is not None]
+ enum_list_str = ",".join(enum_list)
+ raise TypeError(
+ f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
+ ) from exc
+
+ # Permission flags(legacy-use only)
+ READABLE = Permissions.READABLE
+ WRITEABLE = Permissions.WRITEABLE
+ READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
+ WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
+ READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
+ WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
+ READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
+ WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
+
+ value: Union[str, bytes, ConnectionValue]
+
+ def __init__(
+ self,
+ attribute_type: Union[str, bytes, UUID],
+ permissions: Union[str, Attribute.Permissions],
+ value: Union[str, bytes, ConnectionValue] = b'',
+ ) -> None:
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
if isinstance(permissions, str):
- self.permissions = self.string_to_permissions(permissions)
+ self.permissions = Attribute.Permissions.from_string(permissions)
else:
self.permissions = permissions
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
else:
self.value = value
- def encode_value(self, value):
+ def encode_value(self, value: Any) -> bytes:
return value
- def decode_value(self, value_bytes):
+ def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
- def read_value(self, connection: Connection):
+ def read_value(self, connection: Optional[Connection]) -> bytes:
if (
- self.permissions & self.READ_REQUIRES_ENCRYPTION
- ) and not connection.encryption:
+ (self.permissions & self.READ_REQUIRES_ENCRYPTION)
+ and connection is not None
+ and not connection.encryption
+ ):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
- self.permissions & self.READ_REQUIRES_AUTHENTICATION
- ) and not connection.authenticated:
+ (self.permissions & self.READ_REQUIRES_AUTHENTICATION)
+ and connection is not None
+ and not connection.authenticated
+ ):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
- if read := getattr(self.value, 'read', None):
+ if hasattr(self.value, 'read'):
try:
- value = read(connection) # pylint: disable=not-callable
+ value = self.value.read(connection)
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
- def write_value(self, connection: Connection, value_bytes):
+ def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes)
- if write := getattr(self.value, 'write', None):
+ if hasattr(self.value, 'write'):
try:
- write(connection, value) # pylint: disable=not-callable
+ self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
diff --git a/bumble/gatt.py b/bumble/gatt.py
index 067f31d..fe3e85c 100644
--- a/bumble/gatt.py
+++ b/bumble/gatt.py
@@ -28,7 +28,7 @@ import enum
import functools
import logging
import struct
-from typing import Optional, Sequence, List
+from typing import Optional, Sequence, Iterable, List, Union
from .colors import color
from .core import UUID, get_dict_key_by_value
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# -----------------------------------------------------------------------------
-def show_services(services):
+def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
@@ -210,11 +210,11 @@ class Service(Attribute):
def __init__(
self,
- uuid,
+ uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: List[Service] = [],
- ):
+ ) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
@@ -239,7 +239,7 @@ class Service(Attribute):
"""
return None
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -255,9 +255,11 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
- UUID: Optional[UUID] = None
+ UUID: UUID
- def __init__(self, characteristics, primary=True):
+ def __init__(
+ self, characteristics: List[Characteristic], primary: bool = True
+ ) -> None:
super().__init__(self.UUID, characteristics, primary)
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service
- def __init__(self, service):
+ def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
)
self.service = service
- def __str__(self):
+ def __str__(self) -> str:
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
- def __str__(self):
+ def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
def __init__(
self,
- uuid,
+ uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
- permissions,
- value=b'',
+ permissions: Union[str, Attribute.Permissions],
+ value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic
- def __init__(self, characteristic, value_handle):
+ def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle
self.characteristic = characteristic
- def __str__(self):
+ def __str__(self) -> str:
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber)
- def __str__(self):
+ def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
- def encode_value(self, value):
+ def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
- def decode_value(self, value):
+ def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py
index a33039e..e3b8bb2 100644
--- a/bumble/gatt_client.py
+++ b/bumble/gatt_client.py
@@ -28,7 +28,18 @@ import asyncio
import logging
import struct
from datetime import datetime
-from typing import List, Optional, Dict, Tuple, Callable, Union, Any
+from typing import (
+ List,
+ Optional,
+ Dict,
+ Tuple,
+ Callable,
+ Union,
+ Any,
+ Iterable,
+ Type,
+ TYPE_CHECKING,
+)
from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
+ TemplateService,
)
+if TYPE_CHECKING:
+ from bumble.device import Connection
+
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
- client: Client
-
- def __init__(self, client, handle, end_group_handle, attribute_type):
+ def __init__(
+ self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
+ ) -> None:
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
- async def read_value(self, no_long_read=False):
+ async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response
)
- def encode_value(self, value):
+ def encode_value(self, value: Any) -> bytes:
return value
- def decode_value(self, value_bytes):
+ def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
- def __str__(self):
+ def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
- def __str__(self):
+ def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
- subscribers: Dict[Any, Callable]
+ subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__(
self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(
- self, subscriber: Optional[Callable] = None, prefer_notify=True
+ self,
+ subscriber: Optional[Callable[[bytes], Any]] = None,
+ prefer_notify: bool = True,
):
if subscriber is not None:
if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
- def __str__(self):
+ def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
+ SERVICE_CLASS: Type[TemplateService]
+
@classmethod
- def from_client(cls, client):
+ def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
+ notification_subscribers: Dict[int, Callable[[bytes], Any]]
+ indication_subscribers: Dict[int, Callable[[bytes], Any]]
+ pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
+ pending_request: Optional[ATT_PDU]
- def __init__(self, connection):
+ def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = []
self.cached_values = {}
- def send_gatt_pdu(self, pdu):
+ def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
- async def send_command(self, command):
+ async def send_command(self, command: ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
- async def send_request(self, request):
+ async def send_request(self, request: ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
@@ -279,14 +302,14 @@ class Client:
return response
- def send_confirmation(self, confirmation):
+ def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
- async def request_mtu(self, mtu):
+ async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu
- def get_services_by_uuid(self, uuid):
+ def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
- def get_characteristics_by_uuid(self, uuid, service=None):
+ def get_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy] = None
+ ) -> List[CharacteristicProxy]:
services = [service] if service else self.services
return [
c
@@ -363,7 +388,7 @@ class Client:
if not already_known:
self.services.append(service)
- async def discover_services(self, uuids=None) -> List[ServiceProxy]:
+ async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -435,7 +460,7 @@ class Client:
return services
- async def discover_service(self, uuid):
+ async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
- return
+ return []
break
for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
- return
+ return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors
- async def discover_attributes(self):
+ async def discover_attributes(self) -> List[AttributeProxy]:
'''
Discover all attributes, regardless of type
'''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
- async def read_value(self, attribute, no_long_read=False):
+ async def read_value(
+ self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
+ ) -> Any:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes
return attribute_value
- async def read_characteristics_by_uuid(self, uuid, service):
+ async def read_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy]
+ ) -> List[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
'''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values
- async def write_value(self, attribute, value, with_response=False):
+ async def write_value(
+ self,
+ attribute: Union[int, AttributeProxy],
+ value: bytes,
+ with_response: bool = False,
+ ) -> None:
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
@@ -990,7 +1024,7 @@ class Client:
)
)
- def on_gatt_pdu(self, att_pdu):
+ def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
@@ -1013,6 +1047,7 @@ class Client:
return
# Return the response to the coroutine that is waiting for it
+ assert self.pending_response is not None
self.pending_response.set_result(att_pdu)
else:
handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
- def cache_value(self, attribute_handle: int, value: bytes):
+ def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = (
datetime.now(),
value,
diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py
index 3624905..cdf1b5e 100644
--- a/bumble/gatt_server.py
+++ b/bumble/gatt_server.py
@@ -23,11 +23,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
import struct
-from typing import List, Tuple, Optional, TypeVar, Type
+from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from .colors import color
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
+ ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
@@ -73,6 +75,8 @@ from .gatt import (
Service,
)
+if TYPE_CHECKING:
+ from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
+ services: List[Service]
+ attributes_by_handle: Dict[int, Attribute]
+ subscribers: Dict[int, Dict[int, bytes]]
+ indication_semaphores: defaultdict[int, asyncio.Semaphore]
+ pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
- def __init__(self, device):
+ def __init__(self, device: Device) -> None:
super().__init__()
self.device = device
self.services = []
@@ -107,16 +116,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
- def __str__(self):
+ def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
- def send_gatt_pdu(self, connection_handle, pdu):
+ def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
- def next_handle(self):
+ def next_handle(self) -> int:
return 1 + len(self.attributes)
- def get_advertising_service_data(self):
+ def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return {
attribute: data
for attribute in self.attributes
@@ -124,7 +133,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data())
}
- def get_attribute(self, handle):
+ def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -173,12 +182,17 @@ class Server(EventEmitter):
return next(
(
- (attribute, self.get_attribute(attribute.characteristic.handle))
+ (
+ attribute,
+ self.get_attribute(attribute.characteristic.handle),
+ ) # type: ignore
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
- if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
+ if attribute is not None
+ and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
+ and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid
),
None,
@@ -197,7 +211,7 @@ class Server(EventEmitter):
return next(
(
- attribute
+ attribute # type: ignore
for attribute in map(
self.get_attribute,
range(
@@ -205,12 +219,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1,
),
)
- if attribute.type == descriptor_uuid
+ if attribute is not None and attribute.type == descriptor_uuid
),
None,
)
- def add_attribute(self, attribute):
+ def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = (
@@ -220,7 +234,7 @@ class Server(EventEmitter):
# Add this attribute to the list
self.attributes.append(attribute)
- def add_service(self, service: Service):
+ def add_service(self, service: Service) -> None:
# Add the service attribute to the DB
self.add_attribute(service)
@@ -285,11 +299,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle
self.services.append(service)
- def add_services(self, services):
+ def add_services(self, services: Iterable[Service]) -> None:
for service in services:
self.add_service(service)
- def read_cccd(self, connection, characteristic):
+ def read_cccd(
+ self, connection: Optional[Connection], characteristic: Characteristic
+ ) -> bytes:
if connection is None:
return bytes([0, 0])
@@ -300,7 +316,12 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
- def write_cccd(self, connection, characteristic, value):
+ def write_cccd(
+ self,
+ connection: Connection,
+ characteristic: Characteristic,
+ value: bytes,
+ ) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
@@ -327,13 +348,19 @@ class Server(EventEmitter):
indicate_enabled,
)
- def send_response(self, connection, response):
+ def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
- async def notify_subscriber(self, connection, attribute, value=None, force=False):
+ async def notify_subscriber(
+ self,
+ connection: Connection,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -370,7 +397,13 @@ class Server(EventEmitter):
)
self.send_gatt_pdu(connection.handle, bytes(notification))
- async def indicate_subscriber(self, connection, attribute, value=None, force=False):
+ async def indicate_subscriber(
+ self,
+ connection: Connection,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -411,15 +444,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
- self.pending_confirmations[
+ pending_confirmation = self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
- await asyncio.wait_for(
- self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
- )
+ await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +458,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(
- self, indicate, attribute, value=None, force=False
- ):
+ self,
+ indicate: bool,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
@@ -450,13 +485,23 @@ class Server(EventEmitter):
]
)
- async def notify_subscribers(self, attribute, value=None, force=False):
+ async def notify_subscribers(
+ self,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
- async def indicate_subscribers(self, attribute, value=None, force=False):
+ async def indicate_subscribers(
+ self,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
- def on_disconnection(self, connection):
+ def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -464,7 +509,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
- def on_gatt_pdu(self, connection, att_pdu):
+ def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -506,7 +551,7 @@ class Server(EventEmitter):
#######################################################
# ATT handlers
#######################################################
- def on_att_request(self, connection, pdu):
+ def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
@@ -679,7 +724,6 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
-
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
diff --git a/tests/gatt_test.py b/tests/gatt_test.py
index dd0277e..d9f6d60 100644
--- a/tests/gatt_test.py
+++ b/tests/gatt_test.py
@@ -891,10 +891,10 @@ async def async_main():
# -----------------------------------------------------------------------------
-def test_attribute_string_to_permissions():
- assert Attribute.string_to_permissions('READABLE') == 1
- assert Attribute.string_to_permissions('WRITEABLE') == 2
- assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
+def test_permissions_from_string():
+ assert Attribute.Permissions.from_string('READABLE') == 1
+ assert Attribute.Permissions.from_string('WRITEABLE') == 2
+ assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
# -----------------------------------------------------------------------------