aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/py/pw_rpc/callback_client/impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/py/pw_rpc/callback_client/impl.py')
-rw-r--r--pw_rpc/py/pw_rpc/callback_client/impl.py40
1 files changed, 33 insertions, 7 deletions
diff --git a/pw_rpc/py/pw_rpc/callback_client/impl.py b/pw_rpc/py/pw_rpc/callback_client/impl.py
index 474757373..df9129861 100644
--- a/pw_rpc/py/pw_rpc/callback_client/impl.py
+++ b/pw_rpc/py/pw_rpc/callback_client/impl.py
@@ -18,6 +18,7 @@ import logging
import textwrap
from typing import Any, Callable, Dict, Iterable, Optional, Type
+from dataclasses import dataclass
from pw_status import Status
from google.protobuf.message import Message
@@ -44,6 +45,15 @@ from pw_rpc.callback_client.call import (
_LOG = logging.getLogger(__package__)
+@dataclass(eq=True, frozen=True)
+class CallInfo:
+ method: Method
+
+ @property
+ def service(self) -> Service:
+ return self.method.service
+
+
class _MethodClient:
"""A method that can be invoked for a particular channel."""
@@ -57,20 +67,21 @@ class _MethodClient:
) -> None:
self._impl = client_impl
self._rpcs = rpcs
- self._rpc = PendingRpc(channel, method.service, method)
+ self._channel = channel
+ self._method = method
self.default_timeout_s: Optional[float] = default_timeout_s
@property
def channel(self) -> Channel:
- return self._rpc.channel
+ return self._channel
@property
def method(self) -> Method:
- return self._rpc.method
+ return self._method
@property
def service(self) -> Service:
- return self._rpc.service
+ return self._method.service
@property
def request(self) -> type:
@@ -118,8 +129,17 @@ class _MethodClient:
if timeout_s is UseDefault.VALUE:
timeout_s = self.default_timeout_s
+ if self._impl.on_call_hook:
+ self._impl.on_call_hook(CallInfo(self._method))
+
+ rpc = PendingRpc(
+ self._channel,
+ self.service,
+ self.method,
+ self._rpcs.allocate_call_id(),
+ )
call = call_type(
- self._rpcs, self._rpc, timeout_s, on_next, on_completed, on_error
+ self._rpcs, rpc, timeout_s, on_next, on_completed, on_error
)
call._invoke(request, ignore_errors) # pylint: disable=protected-access
return call
@@ -378,18 +398,24 @@ asynchronously using the invoke method.
class Impl(client.ClientImpl):
- """Callback-based ClientImpl, for use with pw_rpc.Client."""
+ """Callback-based ClientImpl, for use with pw_rpc.Client.
+
+ Args:
+ on_call_hook: A callable object to handle RPC method calls.
+ If hook is set, it will be called before RPC execution.
+ """
def __init__(
self,
default_unary_timeout_s: Optional[float] = None,
default_stream_timeout_s: Optional[float] = None,
+ on_call_hook: Optional[Callable[[CallInfo], Any]] = None,
cancel_duplicate_calls: Optional[bool] = True,
) -> None:
super().__init__()
self._default_unary_timeout_s = default_unary_timeout_s
self._default_stream_timeout_s = default_stream_timeout_s
-
+ self.on_call_hook = on_call_hook
# Temporary workaround for clients that rely on mulitple in-flight
# instances of an RPC on the same channel, which is not supported.
# TODO(hepler): Remove this option when clients have updated.