aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/public/pw_rpc/internal/fake_channel_output.h
blob: b60f5c59818c2ded2f71e775f040e0c7c3b73443 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#pragma once

#include <cstddef>
#include <iterator>
#include <limits>
#include <mutex>

#include "pw_bytes/span.h"
#include "pw_containers/vector.h"
#include "pw_function/function.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/lock.h"
#include "pw_rpc/internal/method_info.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/method_type.h"
#include "pw_rpc/payloads_view.h"
#include "pw_sync/lock_annotations.h"

namespace pw::rpc::internal {

// Declared (not defined) here so that FakeChannelOutput below can grant it
// friendship; the class template's full definition is not provided by this
// header.
template <class, size_t, size_t, size_t>
class ForwardingChannelOutput;

}  // namespace pw::rpc::internal

namespace pw::rpc {

class FakeServer;

namespace internal::test {

// A ChannelOutput implementation that stores outgoing packets.
class FakeChannelOutput : public ChannelOutput {
 public:
  FakeChannelOutput(const FakeChannelOutput&) = delete;
  FakeChannelOutput(FakeChannelOutput&&) = delete;

  FakeChannelOutput& operator=(const FakeChannelOutput&) = delete;
  FakeChannelOutput& operator=(FakeChannelOutput&&) = delete;

  Status last_status() const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    PW_ASSERT(total_response_packets_ > 0);
    return packets_.back().status();
  }

  // Returns a view of the payloads seen for this RPC.
  //
  // !!! WARNING !!!
  //
  // Access to the FakeChannelOutput through the PayloadsView is NOT
  // synchronized! The PayloadsView is immediately invalidated if any thread
  // accesses the FakeChannelOutput.
  template <auto kMethod>
  PayloadsView payloads(uint32_t channel_id = Channel::kUnassignedChannelId)
      const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return PayloadsView(packets_,
                        MethodInfo<kMethod>::kType,
                        channel_id,
                        MethodInfo<kMethod>::kServiceId,
                        MethodInfo<kMethod>::kMethodId);
  }

  PayloadsView payloads(MethodType type,
                        uint32_t channel_id,
                        uint32_t service_id,
                        uint32_t method_id) const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return PayloadsView(packets_, type, channel_id, service_id, method_id);
  }

  // Returns a number of the payloads seen for this RPC.
  template <auto kMethod>
  size_t total_payloads(uint32_t channel_id = Channel::kUnassignedChannelId)
      const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return PayloadsView(packets_,
                        MethodInfo<kMethod>::kType,
                        channel_id,
                        MethodInfo<kMethod>::kServiceId,
                        MethodInfo<kMethod>::kMethodId)
        .size();
  }

  // Returns a number of the payloads seen for this RPC.
  size_t total_payloads(MethodType type,
                        uint32_t channel_id,
                        uint32_t service_id,
                        uint32_t method_id) const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return PayloadsView(packets_, type, channel_id, service_id, method_id)
        .size();
  }

  // Returns a view of the final statuses seen for this RPC. Only relevant for
  // checking packets sent by a server.
  //
  // !!! WARNING !!!
  //
  // Access to the FakeChannelOutput through the StatusView is NOT
  // synchronized! The StatusView is immediately invalidated if any thread
  // accesses the FakeChannelOutput.
  template <auto kMethod>
  StatusView completions(uint32_t channel_id = Channel::kUnassignedChannelId)
      const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return StatusView(packets_,
                      internal::pwpb::PacketType::RESPONSE,
                      internal::pwpb::PacketType::RESPONSE,
                      channel_id,
                      MethodInfo<kMethod>::kServiceId,
                      MethodInfo<kMethod>::kMethodId);
  }

  // Returns a view of the pw_rpc server or client errors seen for this RPC.
  //
  // !!! WARNING !!!
  //
  // Access to the FakeChannelOutput through the StatusView is NOT
  // synchronized! The StatusView is immediately invalidated if any thread
  // accesses the FakeChannelOutput.
  template <auto kMethod>
  StatusView errors(uint32_t channel_id = Channel::kUnassignedChannelId) const
      PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return StatusView(packets_,
                      internal::pwpb::PacketType::CLIENT_ERROR,
                      internal::pwpb::PacketType::SERVER_ERROR,
                      channel_id,
                      MethodInfo<kMethod>::kServiceId,
                      MethodInfo<kMethod>::kMethodId);
  }

  // Returns a view of the client stream end packets seen for this RPC. Only
  // relevant for checking packets sent by a client.
  template <auto kMethod>
  size_t client_stream_end_packets(
      uint32_t channel_id = Channel::kUnassignedChannelId) const
      PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return internal::test::PacketsView(
               packets_,
               internal::test::PacketFilter(
                   internal::pwpb::PacketType::CLIENT_STREAM_END,
                   internal::pwpb::PacketType::CLIENT_STREAM_END,
                   channel_id,
                   MethodInfo<kMethod>::kServiceId,
                   MethodInfo<kMethod>::kMethodId))
        .size();
  }

  // The maximum number of packets this FakeChannelOutput can store. Attempting
  // to store more packets than this is an error.
  size_t max_packets() const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return packets_.max_size();
  }

  // The total number of packets that have been sent.
  size_t total_packets() const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return packets_.size();
  }

  // Set to true if a RESPONSE packet is seen.
  bool done() const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    return total_response_packets_ > 0;
  }

  // Clears and resets the FakeChannelOutput.
  void clear() PW_LOCKS_EXCLUDED(mutex_);

  // Returns `status` for all future Send calls. Enables packet processing if
  // `status` is OK.
  void set_send_status(Status status) PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    send_status_ = status;
    return_after_packet_count_ = status.ok() ? -1 : 0;
  }

  // Returns `status` once after the specified positive number of packets.
  void set_send_status(Status status, int return_after_packet_count)
      PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    PW_ASSERT(!status.ok());
    PW_ASSERT(return_after_packet_count > 0);
    send_status_ = status;
    return_after_packet_count_ = return_after_packet_count;
  }

  // Logs which packets have been sent for debugging purposes.
  void LogPackets() const PW_LOCKS_EXCLUDED(mutex_);

  // Processes buffer according to packet type and `return_after_packet_count_`
  // value as follows:
  // When positive, returns `send_status_` once,
  // When equals 0, returns `send_status_` in all future calls,
  // When negative, ignores `send_status_` processes buffer.
  Status Send(ConstByteSpan buffer) final PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    const Status status = HandlePacket(buffer);
    if (on_send_ != nullptr) {
      on_send_(buffer, status);
    }
    return status;
  }

  // Gives access to the last received internal::Packet. This is hidden by the
  // raw/Nanopb implementations, since it gives access to an internal class.
  const Packet& last_packet() const PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    PW_ASSERT(!packets_.empty());
    return packets_.back();
  }

  // The on_send callback is called every time Send() is called. It is passed
  // the contents of the packet and the status to be returned from Send().
  //
  // DANGER: Do NOT call any FakeChannelOutput functions or functions that call
  // FakeChannelOutput functions. That will result in infinite recursion or
  // deadlocks.
  void set_on_send(Function<void(ConstByteSpan, Status)>&& on_send)
      PW_LOCKS_EXCLUDED(mutex_) {
    std::lock_guard lock(mutex_);
    on_send_ = std::move(on_send);
  }

 protected:
  FakeChannelOutput(Vector<Packet>& packets, Vector<std::byte>& payloads)
      : ChannelOutput("pw::rpc::internal::test::FakeChannelOutput"),
        packets_(packets),
        payloads_(payloads) {}

  const Vector<Packet>& packets() const PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
    return packets_;
  }

  RpcLock& mutex() const { return mutex_; }

 private:
  friend class rpc::FakeServer;
  template <class, size_t, size_t, size_t>
  friend class internal::ForwardingChannelOutput;

  Status HandlePacket(ConstByteSpan buffer) PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
  void CopyPayloadToBuffer(Packet& packet) PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  int return_after_packet_count_ PW_GUARDED_BY(mutex_) = -1;
  unsigned total_response_packets_ PW_GUARDED_BY(mutex_) = 0;

  Vector<Packet>& packets_ PW_GUARDED_BY(mutex_);
  Vector<std::byte>& payloads_ PW_GUARDED_BY(mutex_);
  Status send_status_ PW_GUARDED_BY(mutex_) = OkStatus();
  Function<void(ConstByteSpan, Status)> on_send_ PW_GUARDED_BY(mutex_);

  mutable RpcLock mutex_;
};

// Backing storage for FakeChannelOutputBuffer. It is a separate base class,
// listed before FakeChannelOutput in FakeChannelOutputBuffer's base list, so
// that both vectors are fully constructed before FakeChannelOutput's
// constructor binds its `Vector<Packet>&` / `Vector<std::byte>&` references to
// them. Converting a not-yet-constructed member to a base-class reference (as a
// member declared after the base would require) is undefined behavior.
template <size_t kMaxPackets, size_t kPayloadsBufferSizeBytes>
struct FakeChannelOutputStorage {
  Vector<std::byte, kPayloadsBufferSizeBytes> payloads_array_{};
  Vector<Packet, kMaxPackets> packets_array_;
};

// Adds the packet output buffer to a FakeChannelOutput.
//
// kMaxPackets: capacity of the packet vector.
// kPayloadsBufferSizeBytes: capacity of the payload byte vector.
//
// The storage members (`payloads_array_`, `packets_array_`) remain protected
// in this class, as before.
template <size_t kMaxPackets, size_t kPayloadsBufferSizeBytes>
class FakeChannelOutputBuffer
    : protected FakeChannelOutputStorage<kMaxPackets, kPayloadsBufferSizeBytes>,
      public FakeChannelOutput {
 protected:
  // The storage base is initialized first (base classes are initialized in
  // declaration order), so the references handed to FakeChannelOutput refer to
  // live objects.
  FakeChannelOutputBuffer()
      : FakeChannelOutput(this->packets_array_, this->payloads_array_) {}
};

}  // namespace internal::test
}  // namespace pw::rpc