aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Wu <joshwu@google.com>2023-09-19 15:26:10 +0800
committerLucas Abel <22837557+uael@users.noreply.github.com>2023-09-20 23:13:08 +0200
commit2491b686fa4d1ab8a6f8f1bc3ecd358736e5cb9c (patch)
tree29a6b6262996f70278c2bdc875f961601f90f440
parentefd02b2f3e9092f13f03035917566784d84e39fb (diff)
downloadbumble-2491b686fa4d1ab8a6f8f1bc3ecd358736e5cb9c.tar.gz
Handle SMP_Security_Request
-rw-r--r--bumble/pandora/security.py60
-rw-r--r--bumble/smp.py27
2 files changed, 64 insertions, 23 deletions
diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py
index 96fce85..99695d9 100644
--- a/bumble/pandora/security.py
+++ b/bumble/pandora/security.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
+import contextlib
import grpc
import logging
@@ -27,8 +28,8 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
+from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
-from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
try:
self.log.debug('Pair...')
- if (
- connection.transport == BT_LE_TRANSPORT
- and connection.role == BT_PERIPHERAL_ROLE
- ):
- wait_for_security: asyncio.Future[
- bool
- ] = asyncio.get_running_loop().create_future()
- connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
- connection.on("pairing_failure", wait_for_security.set_exception)
+ security_result = asyncio.get_running_loop().create_future()
- connection.request_pairing()
+ with contextlib.closing(EventWatcher()) as watcher:
- await wait_for_security
- else:
- await connection.pair()
+ @watcher.on(connection, 'pairing')
+ def on_pairing(*_: Any) -> None:
+ security_result.set_result('success')
- self.log.debug('Paired')
+ @watcher.on(connection, 'pairing_failure')
+ def on_pairing_failure(*_: Any) -> None:
+ security_result.set_result('pairing_failure')
+
+ @watcher.on(connection, 'disconnection')
+ def on_disconnection(*_: Any) -> None:
+ security_result.set_result('connection_died')
+
+ if (
+ connection.transport == BT_LE_TRANSPORT
+ and connection.role == BT_PERIPHERAL_ROLE
+ ):
+ connection.request_pairing()
+ else:
+ await connection.pair()
+
+ result = await security_result
+
+ self.log.debug(f'Pairing session complete, status={result}')
+ if result != 'success':
+ return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
+ pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
+ def pair(*_: Any) -> None:
+ if self.need_pairing(connection, level):
+ pair_task = asyncio.create_task(connection.pair())
+
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
+ 'security_request': pair,
}
# register event handlers
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
pass
self.log.debug('Authenticated')
+ # wait for `pair` to finish if any
+ if pair_task is not None:
+ self.log.debug('Wait for authentication...')
+ try:
+ await pair_task # type: ignore
+ except:
+ pass
+ self.log.debug('paired')
+
return WaitSecurityResponse(**kwargs)
def reached_security_level(
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
- with suppress(KeyError):
+ with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()
diff --git a/bumble/smp.py b/bumble/smp.py
index 55b8359..f8bba40 100644
--- a/bumble/smp.py
+++ b/bumble/smp.py
@@ -37,6 +37,7 @@ from typing import (
Optional,
Tuple,
Type,
+ cast,
)
from pyee import EventEmitter
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
+ def on_smp_security_request_command(
+ self, connection: Connection, request: SMP_Security_Request_Command
+ ) -> None:
+ connection.emit('security_request', request.auth_req)
+
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
+ # Parse the L2CAP payload into an SMP Command object
+ command = SMP_Command.from_bytes(pdu)
+ logger.debug(
+ f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
+ f'{connection.peer_address}: {command}'
+ )
+
+ # Security request is more than just pairing, so let applications handle them
+ if command.code == SMP_SECURITY_REQUEST_COMMAND:
+ self.on_smp_security_request_command(
+ connection, cast(SMP_Security_Request_Command, command)
+ )
+ return
+
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
)
self.sessions[connection.handle] = session
- # Parse the L2CAP payload into an SMP Command object
- command = SMP_Command.from_bytes(pdu)
- logger.debug(
- f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
- f'{connection.peer_address}: {command}'
- )
-
# Delegate the handling of the command to the session
session.on_smp_command(command)