aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/raw/client_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/raw/client_test.cc')
-rw-r--r--pw_rpc/raw/client_test.cc182
1 files changed, 89 insertions, 93 deletions
diff --git a/pw_rpc/raw/client_test.cc b/pw_rpc/raw/client_test.cc
index a7009d649..b78fa2307 100644
--- a/pw_rpc/raw/client_test.cc
+++ b/pw_rpc/raw/client_test.cc
@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "pw_rpc/internal/client_call.h"
#include "pw_rpc/internal/packet.h"
+#include "pw_rpc/raw/client_reader_writer.h"
#include "pw_rpc/raw/client_testing.h"
namespace pw::rpc {
@@ -42,113 +43,103 @@ struct internal::MethodInfo<BidirectionalStreamMethod> {
namespace {
-template <auto kMethod, typename Call, typename Context>
-Call StartCall(Context& context,
- std::optional<uint32_t> channel_id = std::nullopt)
- PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
- internal::rpc_lock().lock();
- Call call(static_cast<internal::Endpoint&>(context.client()).ClaimLocked(),
- channel_id.value_or(context.channel().id()),
- internal::MethodInfo<kMethod>::kServiceId,
- internal::MethodInfo<kMethod>::kMethodId,
- internal::MethodInfo<kMethod>::kType);
- call.SendInitialClientRequest({});
- // As in the real implementations, immediately clean up aborted calls.
- static_cast<internal::Endpoint&>(context.client()).CleanUpCalls();
- return call;
-}
-
-class TestStreamCall : public internal::StreamResponseClientCall {
- public:
- TestStreamCall(internal::LockedEndpoint& client,
- uint32_t channel_id,
- uint32_t service_id,
- uint32_t method_id,
- MethodType type)
- PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
- : StreamResponseClientCall(
- client,
- channel_id,
- service_id,
- method_id,
- internal::CallProperties(
- type, internal::kClientCall, internal::kRawProto)),
- payload(nullptr) {
- set_on_next_locked([this](ConstByteSpan string) {
+// Captures payload from on_next and statuses from on_error and on_completed.
+// Payloads are assumed to be null-terminated strings.
+template <typename CallType>
+struct CallContext {
+ auto OnNext() {
+ return [this](ConstByteSpan string) {
payload = reinterpret_cast<const char*>(string.data());
- });
- set_on_completed_locked([this](Status status) { completed = status; });
- set_on_error_locked([this](Status status) { error = status; });
+ };
}
- const char* payload;
- std::optional<Status> completed;
- std::optional<Status> error;
-};
-
-class TestUnaryCall : public internal::UnaryResponseClientCall {
- public:
- TestUnaryCall() = default;
-
- TestUnaryCall(internal::LockedEndpoint& client,
- uint32_t channel_id,
- uint32_t service_id,
- uint32_t method_id,
- MethodType type)
- PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock())
- : UnaryResponseClientCall(
- client,
- channel_id,
- service_id,
- method_id,
- internal::CallProperties(
- type, internal::kClientCall, internal::kRawProto)),
- payload(nullptr) {
- set_on_completed_locked([this](ConstByteSpan string, Status status) {
+ auto UnaryOnCompleted() {
+ return [this](ConstByteSpan string, Status status) {
payload = reinterpret_cast<const char*>(string.data());
completed = status;
- });
- set_on_error_locked([this](Status status) { error = status; });
+ };
+ }
+
+ auto StreamOnCompleted() {
+ return [this](Status status) { completed = status; };
}
- using Call::set_on_error;
- using UnaryResponseClientCall::set_on_completed;
+ auto OnError() {
+ return [this](Status status) { error = status; };
+ }
+
+ CallType call;
const char* payload;
std::optional<Status> completed;
std::optional<Status> error;
};
+template <auto kMethod, typename Context>
+CallContext<RawUnaryReceiver> StartUnaryCall(
+ Context& context, std::optional<uint32_t> channel_id = std::nullopt)
+ PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
+ CallContext<RawUnaryReceiver> call_context;
+ call_context.call =
+ internal::UnaryResponseClientCall::Start<RawUnaryReceiver>(
+ context.client(),
+ channel_id.value_or(context.channel().id()),
+ internal::MethodInfo<kMethod>::kServiceId,
+ internal::MethodInfo<kMethod>::kMethodId,
+ call_context.UnaryOnCompleted(),
+ call_context.OnError(),
+ {});
+ return call_context;
+}
+
+template <auto kMethod, typename Context>
+CallContext<RawClientReaderWriter> StartStreamCall(
+ Context& context, std::optional<uint32_t> channel_id = std::nullopt)
+ PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
+ CallContext<RawClientReaderWriter> call_context;
+ call_context.call =
+ internal::StreamResponseClientCall::Start<RawClientReaderWriter>(
+ context.client(),
+ channel_id.value_or(context.channel().id()),
+ internal::MethodInfo<kMethod>::kServiceId,
+ internal::MethodInfo<kMethod>::kMethodId,
+ call_context.OnNext(),
+ call_context.StreamOnCompleted(),
+ call_context.OnError(),
+ {});
+ return call_context;
+}
+
TEST(Client, ProcessPacket_InvokesUnaryCallbacks) {
RawClientTestContext context;
- TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(context);
+ CallContext call_context = StartUnaryCall<UnaryMethod>(context);
- ASSERT_NE(call.completed, OkStatus());
+ ASSERT_NE(call_context.completed, OkStatus());
context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")),
OkStatus());
- ASSERT_NE(call.payload, nullptr);
- EXPECT_STREQ(call.payload, "you nary?!?");
- EXPECT_EQ(call.completed, OkStatus());
- EXPECT_FALSE(call.active());
+ ASSERT_NE(call_context.payload, nullptr);
+ EXPECT_STREQ(call_context.payload, "you nary?!?");
+ EXPECT_EQ(call_context.completed, OkStatus());
+ EXPECT_FALSE(call_context.call.active());
}
TEST(Client, ProcessPacket_NoCallbackSet) {
RawClientTestContext context;
- TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(context);
- call.set_on_completed(nullptr);
+ CallContext call_context = StartUnaryCall<UnaryMethod>(context);
+ call_context.call.set_on_completed(nullptr);
- ASSERT_NE(call.completed, OkStatus());
+ ASSERT_NE(call_context.completed, OkStatus());
context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")),
OkStatus());
- EXPECT_FALSE(call.active());
+ EXPECT_FALSE(call_context.call.active());
}
TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
RawClientTestContext context;
- auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context);
+ auto call = StartStreamCall<BidirectionalStreamMethod>(context);
context.server().SendServerStream<BidirectionalStreamMethod>(
as_bytes(span("<=>")));
@@ -163,7 +154,7 @@ TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) {
RawClientTestContext context;
- auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context);
+ auto call_cts = StartStreamCall<BidirectionalStreamMethod>(context);
std::byte encoded[64];
Result<span<const std::byte>> result =
@@ -180,7 +171,7 @@ TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) {
TEST(Client, ProcessPacket_InvokesErrorCallback) {
RawClientTestContext context;
- auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context);
+ auto call = StartStreamCall<BidirectionalStreamMethod>(context);
context.server().SendServerError<BidirectionalStreamMethod>(
Status::Aborted());
@@ -247,9 +238,9 @@ TEST(Client, CloseChannel_UnknownChannel) {
TEST(Client, CloseChannel_CallsErrorCallback) {
RawClientTestContext ctx;
- TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(ctx);
+ CallContext call_ctx = StartUnaryCall<UnaryMethod>(ctx);
- ASSERT_NE(call.completed, OkStatus());
+ ASSERT_NE(call_ctx.completed, OkStatus());
ASSERT_EQ(1u,
static_cast<internal::Endpoint&>(ctx.client()).active_call_count());
@@ -257,25 +248,25 @@ TEST(Client, CloseChannel_CallsErrorCallback) {
EXPECT_EQ(0u,
static_cast<internal::Endpoint&>(ctx.client()).active_call_count());
- ASSERT_EQ(call.error, Status::Aborted()); // set by the on_error callback
+ ASSERT_EQ(call_ctx.error, Status::Aborted()); // set by the on_error callback
}
TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForCallOnClosedChannel) {
struct {
RawClientTestContext<> ctx;
- TestUnaryCall call;
+ CallContext<RawUnaryReceiver> call_ctx;
} context;
- context.call = StartCall<UnaryMethod, TestUnaryCall>(context.ctx);
- context.call.set_on_error([&context](Status error) {
- context.call = StartCall<UnaryMethod, TestUnaryCall>(context.ctx, 1);
- context.call.error = error;
+ context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx);
+ context.call_ctx.call.set_on_error([&context](Status error) {
+ context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx, 1);
+ context.call_ctx.error = error;
});
EXPECT_EQ(OkStatus(), context.ctx.client().CloseChannel(1));
- EXPECT_EQ(context.call.error, Status::Aborted());
+ EXPECT_EQ(context.call_ctx.error, Status::Aborted());
- EXPECT_FALSE(context.call.active());
+ EXPECT_FALSE(context.call_ctx.call.active());
EXPECT_EQ(0u,
static_cast<internal::Endpoint&>(context.ctx.client())
.active_call_count());
@@ -293,7 +284,12 @@ TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) {
Channel& channel() { return channels_[0]; }
Client& client() { return client_; }
- TestUnaryCall& call() { return call_; }
+ CallContext<RawUnaryReceiver>& call_ctx() { return call_context_; }
+ RawUnaryReceiver& call() { return call_context_.call; }
+
+ void StartCall(uint32_t channel_id) {
+ call_context_ = StartUnaryCall<UnaryMethod>(*this, channel_id);
+ }
private:
RawFakeChannelOutput<10, 256> channel_output_;
@@ -302,17 +298,17 @@ TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) {
std::byte packet_buffer[64];
FakeServer fake_server_;
- TestUnaryCall call_;
+ CallContext<RawUnaryReceiver> call_context_;
} context;
- context.call() = StartCall<UnaryMethod, TestUnaryCall>(context, 1);
+ context.StartCall(1);
context.call().set_on_error([&context](Status error) {
- context.call() = StartCall<UnaryMethod, TestUnaryCall>(context, 2);
- context.call().error = error;
+ context.StartCall(2);
+ context.call_ctx().error = error;
});
EXPECT_EQ(OkStatus(), context.client().CloseChannel(1));
- EXPECT_EQ(context.call().error, Status::Aborted());
+ EXPECT_EQ(context.call_ctx().error, Status::Aborted());
EXPECT_TRUE(context.call().active());
EXPECT_EQ(