diff options
Diffstat (limited to 'pw_rpc/raw/client_test.cc')
-rw-r--r-- | pw_rpc/raw/client_test.cc | 182 |
1 files changed, 89 insertions, 93 deletions
diff --git a/pw_rpc/raw/client_test.cc b/pw_rpc/raw/client_test.cc index a7009d649..b78fa2307 100644 --- a/pw_rpc/raw/client_test.cc +++ b/pw_rpc/raw/client_test.cc @@ -19,6 +19,7 @@ #include "gtest/gtest.h" #include "pw_rpc/internal/client_call.h" #include "pw_rpc/internal/packet.h" +#include "pw_rpc/raw/client_reader_writer.h" #include "pw_rpc/raw/client_testing.h" namespace pw::rpc { @@ -42,113 +43,103 @@ struct internal::MethodInfo<BidirectionalStreamMethod> { namespace { -template <auto kMethod, typename Call, typename Context> -Call StartCall(Context& context, - std::optional<uint32_t> channel_id = std::nullopt) - PW_LOCKS_EXCLUDED(internal::rpc_lock()) { - internal::rpc_lock().lock(); - Call call(static_cast<internal::Endpoint&>(context.client()).ClaimLocked(), - channel_id.value_or(context.channel().id()), - internal::MethodInfo<kMethod>::kServiceId, - internal::MethodInfo<kMethod>::kMethodId, - internal::MethodInfo<kMethod>::kType); - call.SendInitialClientRequest({}); - // As in the real implementations, immediately clean up aborted calls. - static_cast<internal::Endpoint&>(context.client()).CleanUpCalls(); - return call; -} - -class TestStreamCall : public internal::StreamResponseClientCall { - public: - TestStreamCall(internal::LockedEndpoint& client, - uint32_t channel_id, - uint32_t service_id, - uint32_t method_id, - MethodType type) - PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) - : StreamResponseClientCall( - client, - channel_id, - service_id, - method_id, - internal::CallProperties( - type, internal::kClientCall, internal::kRawProto)), - payload(nullptr) { - set_on_next_locked([this](ConstByteSpan string) { +// Captures payload from on_next and statuses from on_error and on_completed. +// Payloads are assumed to be null-terminated strings. +template <typename CallType> +struct CallContext { + auto OnNext() { + return [this](ConstByteSpan string) { payload = reinterpret_cast<const char*>(string.data()); - }); - set_on_completed_locked([this](Status status) { completed = status; }); - set_on_error_locked([this](Status status) { error = status; }); + }; } - const char* payload; - std::optional<Status> completed; - std::optional<Status> error; -}; - -class TestUnaryCall : public internal::UnaryResponseClientCall { - public: - TestUnaryCall() = default; - - TestUnaryCall(internal::LockedEndpoint& client, - uint32_t channel_id, - uint32_t service_id, - uint32_t method_id, - MethodType type) - PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) - : UnaryResponseClientCall( - client, - channel_id, - service_id, - method_id, - internal::CallProperties( - type, internal::kClientCall, internal::kRawProto)), - payload(nullptr) { - set_on_completed_locked([this](ConstByteSpan string, Status status) { + auto UnaryOnCompleted() { + return [this](ConstByteSpan string, Status status) { payload = reinterpret_cast<const char*>(string.data()); completed = status; - }); - set_on_error_locked([this](Status status) { error = status; }); + }; + } + + auto StreamOnCompleted() { + return [this](Status status) { completed = status; }; } - using Call::set_on_error; - using UnaryResponseClientCall::set_on_completed; + auto OnError() { + return [this](Status status) { error = status; }; + } + + CallType call; const char* payload; std::optional<Status> completed; std::optional<Status> error; }; +template <auto kMethod, typename Context> +CallContext<RawUnaryReceiver> StartUnaryCall( + Context& context, std::optional<uint32_t> channel_id = std::nullopt) + PW_LOCKS_EXCLUDED(internal::rpc_lock()) { + CallContext<RawUnaryReceiver> call_context; + call_context.call = + internal::UnaryResponseClientCall::Start<RawUnaryReceiver>( + context.client(), + channel_id.value_or(context.channel().id()), + internal::MethodInfo<kMethod>::kServiceId, + internal::MethodInfo<kMethod>::kMethodId, + call_context.UnaryOnCompleted(), + call_context.OnError(), + {}); + return call_context; +} + +template <auto kMethod, typename Context> +CallContext<RawClientReaderWriter> StartStreamCall( + Context& context, std::optional<uint32_t> channel_id = std::nullopt) + PW_LOCKS_EXCLUDED(internal::rpc_lock()) { + CallContext<RawClientReaderWriter> call_context; + call_context.call = + internal::StreamResponseClientCall::Start<RawClientReaderWriter>( + context.client(), + channel_id.value_or(context.channel().id()), + internal::MethodInfo<kMethod>::kServiceId, + internal::MethodInfo<kMethod>::kMethodId, + call_context.OnNext(), + call_context.StreamOnCompleted(), + call_context.OnError(), + {}); + return call_context; +} + TEST(Client, ProcessPacket_InvokesUnaryCallbacks) { RawClientTestContext context; - TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(context); + CallContext call_context = StartUnaryCall<UnaryMethod>(context); - ASSERT_NE(call.completed, OkStatus()); + ASSERT_NE(call_context.completed, OkStatus()); context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")), OkStatus()); - ASSERT_NE(call.payload, nullptr); - EXPECT_STREQ(call.payload, "you nary?!?"); - EXPECT_EQ(call.completed, OkStatus()); - EXPECT_FALSE(call.active()); + ASSERT_NE(call_context.payload, nullptr); + EXPECT_STREQ(call_context.payload, "you nary?!?"); + EXPECT_EQ(call_context.completed, OkStatus()); + EXPECT_FALSE(call_context.call.active()); } TEST(Client, ProcessPacket_NoCallbackSet) { RawClientTestContext context; - TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(context); - call.set_on_completed(nullptr); + CallContext call_context = StartUnaryCall<UnaryMethod>(context); + call_context.call.set_on_completed(nullptr); - ASSERT_NE(call.completed, OkStatus()); + ASSERT_NE(call_context.completed, OkStatus()); context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")), OkStatus()); - EXPECT_FALSE(call.active()); + EXPECT_FALSE(call_context.call.active()); } TEST(Client, ProcessPacket_InvokesStreamCallbacks) { RawClientTestContext context; - auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context); + auto call = StartStreamCall<BidirectionalStreamMethod>(context); context.server().SendServerStream<BidirectionalStreamMethod>( as_bytes(span("<=>"))); @@ -163,7 +154,7 @@ TEST(Client, ProcessPacket_InvokesStreamCallbacks) { TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) { RawClientTestContext context; - auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context); + auto call_cts = StartStreamCall<BidirectionalStreamMethod>(context); std::byte encoded[64]; Result<span<const std::byte>> result = @@ -180,7 +171,7 @@ TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) { TEST(Client, ProcessPacket_InvokesErrorCallback) { RawClientTestContext context; - auto call = StartCall<BidirectionalStreamMethod, TestStreamCall>(context); + auto call = StartStreamCall<BidirectionalStreamMethod>(context); context.server().SendServerError<BidirectionalStreamMethod>( Status::Aborted()); @@ -247,9 +238,9 @@ TEST(Client, CloseChannel_UnknownChannel) { TEST(Client, CloseChannel_CallsErrorCallback) { RawClientTestContext ctx; - TestUnaryCall call = StartCall<UnaryMethod, TestUnaryCall>(ctx); + CallContext call_ctx = StartUnaryCall<UnaryMethod>(ctx); - ASSERT_NE(call.completed, OkStatus()); + ASSERT_NE(call_ctx.completed, OkStatus()); ASSERT_EQ(1u, static_cast<internal::Endpoint&>(ctx.client()).active_call_count()); @@ -257,25 +248,25 @@ TEST(Client, CloseChannel_CallsErrorCallback) { EXPECT_EQ(0u, static_cast<internal::Endpoint&>(ctx.client()).active_call_count()); - ASSERT_EQ(call.error, Status::Aborted()); // set by the on_error callback + ASSERT_EQ(call_ctx.error, Status::Aborted()); // set by the on_error callback } TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForCallOnClosedChannel) { struct { RawClientTestContext<> ctx; - TestUnaryCall call; + CallContext<RawUnaryReceiver> call_ctx; } context; - context.call = StartCall<UnaryMethod, TestUnaryCall>(context.ctx); - context.call.set_on_error([&context](Status error) { - context.call = StartCall<UnaryMethod, TestUnaryCall>(context.ctx, 1); - context.call.error = error; + context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx); + context.call_ctx.call.set_on_error([&context](Status error) { + context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx, 1); + context.call_ctx.error = error; }); EXPECT_EQ(OkStatus(), context.ctx.client().CloseChannel(1)); - EXPECT_EQ(context.call.error, Status::Aborted()); + EXPECT_EQ(context.call_ctx.error, Status::Aborted()); - EXPECT_FALSE(context.call.active()); + EXPECT_FALSE(context.call_ctx.call.active()); EXPECT_EQ(0u, static_cast<internal::Endpoint&>(context.ctx.client()) .active_call_count()); @@ -293,7 +284,12 @@ TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) { Channel& channel() { return channels_[0]; } Client& client() { return client_; } - TestUnaryCall& call() { return call_; } + CallContext<RawUnaryReceiver>& call_ctx() { return call_context_; } + RawUnaryReceiver& call() { return call_context_.call; } + + void StartCall(uint32_t channel_id) { + call_context_ = StartUnaryCall<UnaryMethod>(*this, channel_id); + } private: RawFakeChannelOutput<10, 256> channel_output_; @@ -302,17 +298,17 @@ TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) { std::byte packet_buffer[64]; FakeServer fake_server_; - TestUnaryCall call_; + CallContext<RawUnaryReceiver> call_context_; } context; - context.call() = StartCall<UnaryMethod, TestUnaryCall>(context, 1); + context.StartCall(1); context.call().set_on_error([&context](Status error) { - context.call() = StartCall<UnaryMethod, TestUnaryCall>(context, 2); - context.call().error = error; + context.StartCall(2); + context.call_ctx().error = error; }); EXPECT_EQ(OkStatus(), context.client().CloseChannel(1)); - EXPECT_EQ(context.call().error, Status::Aborted()); + EXPECT_EQ(context.call_ctx().error, Status::Aborted()); EXPECT_TRUE(context.call().active()); EXPECT_EQ( |