aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Sanche <sanche@google.com>2023-11-17 10:29:55 -0800
committerGitHub <noreply@github.com>2023-11-17 10:29:55 -0800
commitfc12b40bfc6e0c4bb313196e2e3a9c9374ce1c45 (patch)
tree16019fb3b32cf9e6350ad3a442a19ed934636041
parent448923acf277a70e8704c949311bf4feaef8cab6 (diff)
downloadpython-api-core-fc12b40bfc6e0c4bb313196e2e3a9c9374ce1c45.tar.gz
feat: add type annotations to wrapped grpc calls (#554)
* add types to grpc call wrappers * fixed tests * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * changed type * changed async types * added tests * fixed lint issues * Update tests/asyncio/test_grpc_helpers_async.py Co-authored-by: Anthonios Partheniou <partheniou@google.com> * turned GrpcStream into a type alias * added test for GrpcStream * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * added comment * reordered types * changed type var to P --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com> Co-authored-by: Anthonios Partheniou <partheniou@google.com>
-rw-r--r--google/api_core/grpc_helpers.py14
-rw-r--r--google/api_core/grpc_helpers_async.py30
-rw-r--r--tests/asyncio/test_grpc_helpers_async.py22
-rw-r--r--tests/unit/test_grpc_helpers.py17
4 files changed, 70 insertions, 13 deletions
diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py
index f52e180..793c884 100644
--- a/google/api_core/grpc_helpers.py
+++ b/google/api_core/grpc_helpers.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Helpers for :mod:`grpc`."""
+from typing import Generic, TypeVar, Iterator
import collections
import functools
@@ -54,6 +55,9 @@ _STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCal
_LOGGER = logging.getLogger(__name__)
+# denotes the proto response type for grpc calls
+P = TypeVar("P")
+
def _patch_callable_name(callable_):
"""Fix-up gRPC callable attributes.
@@ -79,7 +83,7 @@ def _wrap_unary_errors(callable_):
return error_remapped_callable
-class _StreamingResponseIterator(grpc.Call):
+class _StreamingResponseIterator(Generic[P], grpc.Call):
def __init__(self, wrapped, prefetch_first_result=True):
self._wrapped = wrapped
@@ -97,11 +101,11 @@ class _StreamingResponseIterator(grpc.Call):
# ignore stop iteration at this time. This should be handled outside of retry.
pass
- def __iter__(self):
+ def __iter__(self) -> Iterator[P]:
"""This iterator is also an iterable that returns itself."""
return self
- def __next__(self):
+ def __next__(self) -> P:
"""Get the next response from the stream.
Returns:
@@ -144,6 +148,10 @@ class _StreamingResponseIterator(grpc.Call):
return self._wrapped.trailing_metadata()
+# public type alias denoting the return type of streaming gapic calls
+GrpcStream = _StreamingResponseIterator[P]
+
+
def _wrap_stream_errors(callable_):
"""Wrap errors for Unary-Stream and Stream-Stream gRPC callables.
diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py
index d1f69d9..5685e6f 100644
--- a/google/api_core/grpc_helpers_async.py
+++ b/google/api_core/grpc_helpers_async.py
@@ -21,11 +21,15 @@ functions. This module is implementing the same surface with AsyncIO semantics.
import asyncio
import functools
+from typing import Generic, Iterator, AsyncGenerator, TypeVar
+
import grpc
from grpc import aio
from google.api_core import exceptions, grpc_helpers
+# denotes the proto response type for grpc calls
+P = TypeVar("P")
# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
# automatic patching for us. But that means the overhead of creating an
@@ -75,8 +79,8 @@ class _WrappedCall(aio.Call):
raise exceptions.from_grpc_error(rpc_error) from rpc_error
-class _WrappedUnaryResponseMixin(_WrappedCall):
- def __await__(self):
+class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
+ def __await__(self) -> Iterator[P]:
try:
response = yield from self._call.__await__()
return response
@@ -84,17 +88,17 @@ class _WrappedUnaryResponseMixin(_WrappedCall):
raise exceptions.from_grpc_error(rpc_error) from rpc_error
-class _WrappedStreamResponseMixin(_WrappedCall):
+class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
def __init__(self):
self._wrapped_async_generator = None
- async def read(self):
+ async def read(self) -> P:
try:
return await self._call.read()
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error
- async def _wrapped_aiter(self):
+ async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
try:
# NOTE(lidiz) coverage doesn't understand the exception raised from
# __anext__ method. It is covered by test case:
@@ -104,7 +108,7 @@ class _WrappedStreamResponseMixin(_WrappedCall):
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error
- def __aiter__(self):
+ def __aiter__(self) -> AsyncGenerator[P, None]:
if not self._wrapped_async_generator:
self._wrapped_async_generator = self._wrapped_aiter()
return self._wrapped_async_generator
@@ -127,26 +131,32 @@ class _WrappedStreamRequestMixin(_WrappedCall):
# NOTE(lidiz) Implementing each individual class separately, so we don't
# expose any API that should not be seen. E.g., __aiter__ in unary-unary
# RPC, or __await__ in stream-stream RPC.
-class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall):
+class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
"""Wrapped UnaryUnaryCall to map exceptions."""
-class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall):
+class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
"""Wrapped UnaryStreamCall to map exceptions."""
class _WrappedStreamUnaryCall(
- _WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall
+ _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
"""Wrapped StreamUnaryCall to map exceptions."""
class _WrappedStreamStreamCall(
- _WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall
+ _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
):
"""Wrapped StreamStreamCall to map exceptions."""
+# public type alias denoting the return type of async streaming gapic calls
+GrpcAsyncStream = _WrappedStreamResponseMixin[P]
+# public type alias denoting the return type of unary gapic calls
+AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]
+
+
def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
grpc_helpers._patch_callable_name(callable_)
diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py
index 95242f6..67c9b33 100644
--- a/tests/asyncio/test_grpc_helpers_async.py
+++ b/tests/asyncio/test_grpc_helpers_async.py
@@ -266,6 +266,28 @@ def test_wrap_errors_non_streaming(wrap_unary_errors):
wrap_unary_errors.assert_called_once_with(callable_)
+def test_grpc_async_stream():
+ """
+ GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call.
+ """
+ instance = grpc_helpers_async.GrpcAsyncStream[int]()
+ assert isinstance(instance, grpc.aio.Call)
+ # should implement __aiter__ and __anext__
+ assert hasattr(instance, "__aiter__")
+ it = instance.__aiter__()
+ assert hasattr(it, "__anext__")
+
+
+def test_awaitable_grpc_call():
+ """
+ AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call.
+ """
+ instance = grpc_helpers_async.AwaitableGrpcCall[int]()
+ assert isinstance(instance, grpc.aio.Call)
+ # should implement __await__
+ assert hasattr(instance, "__await__")
+
+
@mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors")
def test_wrap_errors_streaming(wrap_stream_errors):
callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable)
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
index 4eccbca..58a6a32 100644
--- a/tests/unit/test_grpc_helpers.py
+++ b/tests/unit/test_grpc_helpers.py
@@ -195,6 +195,23 @@ class Test_StreamingResponseIterator:
wrapped.trailing_metadata.assert_called_once_with()
+class TestGrpcStream(Test_StreamingResponseIterator):
+ @staticmethod
+ def _make_one(wrapped, **kw):
+ return grpc_helpers.GrpcStream(wrapped, **kw)
+
+ def test_grpc_stream_attributes(self):
+ """
+ Should be both a grpc.Call and an iterable
+ """
+ call = self._make_one(None)
+ assert isinstance(call, grpc.Call)
+ # should implement __iter__
+ assert hasattr(call, "__iter__")
+ it = call.__iter__()
+ assert hasattr(it, "__next__")
+
+
def test_wrap_stream_okay():
expected_responses = [1, 2, 3]
callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))