aboutsummaryrefslogtreecommitdiff
path: root/src/core/lib/promise/party.h
blob: 206cbd5af80d60d9600d8e734b1f333c4562459a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
#define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H

#include <grpc/support/port_platform.h>

#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"

#include <grpc/event_engine/event_engine.h>
#include <grpc/support/log.h>

#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/trace.h"
#include "src/core/lib/resource_quota/arena.h"

// Two implementations of party synchronization are provided: one using a single
// atomic, the other using a mutex and a set of state variables.
// The atomic implementation came first, but we found some race conditions on
// Arm that were not reported by our default TSAN implementation.
// The mutex implementation was added to see if it would fix the problem, and
// it did. We later found the race condition, so there's no known reason to use
// the mutex version - however we keep it around as a just-in-case measure.
// One idea is to fuzz the two implementations against each other as a
// correctness check of both, but that's not implemented yet.

#define GRPC_PARTY_SYNC_USING_ATOMICS
// #define GRPC_PARTY_SYNC_USING_MUTEX

#if defined(GRPC_PARTY_SYNC_USING_ATOMICS) +    \
        defined(GRPC_PARTY_SYNC_USING_MUTEX) != \
    1
#error Must define a party sync mechanism
#endif

namespace grpc_core {

namespace party_detail {

// Upper bound on the number of participants a party can hold at once. Each
// participant needs its own bit in the wakeup and allocated masks, and 16 bits
// are reserved for each of those.
constexpr size_t kMaxParticipants{16};

}  // namespace party_detail

class PartySyncUsingAtomics {
 public:
  explicit PartySyncUsingAtomics(size_t initial_refs)
      : state_(kOneRef * initial_refs) {}

  void IncrementRefCount() {
    state_.fetch_add(kOneRef, std::memory_order_relaxed);
  }
  GRPC_MUST_USE_RESULT bool RefIfNonZero();
  // Returns true if the ref count is now zero and the caller should call
  // PartyOver
  GRPC_MUST_USE_RESULT bool Unref() {
    uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel);
    if ((prev_state & kRefMask) == kOneRef) {
      return UnreffedLast();
    }
    return false;
  }
  void ForceImmediateRepoll(WakeupMask mask) {
    // Or in the bit for the currently polling participant.
    // Will be grabbed next round to force a repoll of this promise.
    state_.fetch_or(mask, std::memory_order_relaxed);
  }

  // Run the update loop: poll_one_participant is called with an integral index
  // for the participant that should be polled. It should return true if the
  // participant completed and should be removed from the allocated set.
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    uint64_t prev_state;
    do {
      // Grab the current state, and clear the wakeup bits & add flag.
      prev_state = state_.fetch_and(kRefMask | kLocked | kAllocatedMask,
                                    std::memory_order_acquire);
      GPR_ASSERT(prev_state & kLocked);
      if (prev_state & kDestroying) return true;
      // From the previous state, extract which participants we're to wakeup.
      uint64_t wakeups = prev_state & kWakeupMask;
      // Now update prev_state to be what we want the CAS to see below.
      prev_state &= kRefMask | kLocked | kAllocatedMask;
      // For each wakeup bit...
      for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) {
        // If the bit is not set, skip.
        if ((wakeups & 1) == 0) continue;
        if (poll_one_participant(i)) {
          const uint64_t allocated_bit = (1u << i << kAllocatedShift);
          prev_state &= ~allocated_bit;
          state_.fetch_and(~allocated_bit, std::memory_order_release);
        }
      }
      // Try to CAS the state we expected to have (with no wakeups or adds)
      // back to unlocked (by masking in only the ref mask - sans locked bit).
      // If this succeeds then no wakeups were added, no adds were added, and we
      // have successfully unlocked.
      // Otherwise, we need to loop again.
      // Note that if an owning waker is created or the weak cas spuriously
      // fails we will also loop again, but in that case see no wakeups or adds
      // and so will get back here fairly quickly.
      // TODO(ctiller): consider mitigations for the accidental wakeup on owning
      // waker creation case -- I currently expect this will be more expensive
      // than this quick loop.
    } while (!state_.compare_exchange_weak(
        prev_state, (prev_state & (kRefMask | kAllocatedMask)),
        std::memory_order_acq_rel, std::memory_order_acquire));
    return false;
  }

  // Add new participants to the party. Returns true if the caller should run
  // the party. store is called with an array of indices of the new
  // participants. Adds a ref that should be dropped by the caller after
  // RunParty has been called (if that was required).
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    uint64_t state = state_.load(std::memory_order_acquire);
    uint64_t allocated;

    size_t slots[party_detail::kMaxParticipants];

    // Find slots for each new participant, ordering them from lowest available
    // slot upwards to ensure the same poll ordering as presentation ordering to
    // this function.
    WakeupMask wakeup_mask;
    do {
      wakeup_mask = 0;
      allocated = (state & kAllocatedMask) >> kAllocatedShift;
      size_t n = 0;
      for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
           bit++) {
        if (allocated & (1 << bit)) continue;
        wakeup_mask |= (1 << bit);
        slots[n++] = bit;
        allocated |= 1 << bit;
      }
      GPR_ASSERT(n == count);
      // Try to allocate this slot and take a ref (atomically).
      // Ref needs to be taken because once we store the participant it could be
      // spuriously woken up and unref the party.
    } while (!state_.compare_exchange_weak(
        state, (state | (allocated << kAllocatedShift)) + kOneRef,
        std::memory_order_acq_rel, std::memory_order_acquire));

    store(slots);

    // Now we need to wake up the party.
    state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release);

    // If the party was already locked, we're done.
    return ((state & kLocked) == 0);
  }

  // Schedule a wakeup for the given participant.
  // Returns true if the caller should run the party.
  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  bool UnreffedLast();

  // State bits:
  // The atomic state_ field is composed of the following:
  //   - 24 bits for ref counts
  //     1 is owned by the party prior to Orphan()
  //     All others are owned by owning wakers
  //   - 1 bit to indicate whether the party is locked
  //     The first thread to set this owns the party until it is unlocked
  //     That thread will run the main loop until no further work needs to
  //     be done.
  //   - 1 bit to indicate whether there are participants waiting to be
  //   added
  //   - 16 bits, one per participant, indicating which participants have
  //   been
  //     woken up and should be polled next time the main loop runs.

  // clang-format off
  // Bits used to store 16 bits of wakeups
  static constexpr uint64_t kWakeupMask    = 0x0000'0000'0000'ffff;
  // Bits used to store 16 bits of allocated participant slots.
  static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000;
  // Bit indicating destruction has begun (refs went to zero)
  static constexpr uint64_t kDestroying    = 0x0000'0001'0000'0000;
  // Bit indicating locked or not
  static constexpr uint64_t kLocked        = 0x0000'0008'0000'0000;
  // Bits used to store 24 bits of ref counts
  static constexpr uint64_t kRefMask       = 0xffff'ff00'0000'0000;
  // clang-format on

  // Shift to get from a participant mask to an allocated mask.
  static constexpr size_t kAllocatedShift = 16;
  // How far to shift to get the refcount
  static constexpr size_t kRefShift = 40;
  // One ref count
  static constexpr uint64_t kOneRef = 1ull << kRefShift;

  std::atomic<uint64_t> state_;
};

// Mutex-based party synchronization with the same interface as
// PartySyncUsingAtomics. The ref count is a separate RefCount. The slot
// allocation, pending wakeups and lock ownership are plain fields guarded by
// mu_.
class PartySyncUsingMutex {
 public:
  explicit PartySyncUsingMutex(size_t initial_refs) : refs_(initial_refs) {}

  void IncrementRefCount() { refs_.Ref(); }
  GRPC_MUST_USE_RESULT bool RefIfNonZero() { return refs_.RefIfNonZero(); }
  GRPC_MUST_USE_RESULT bool Unref() { return refs_.Unref(); }

  // Queue the participants in `mask` to be polled again on the next pass.
  void ForceImmediateRepoll(WakeupMask mask) {
    MutexLock lock(&mu_);
    wakeups_ |= mask;
  }

  // Main loop. The caller must own the party (locked_ must be set). Each pass
  // takes the pending wakeups under mu_, then polls those participants with
  // mu_ released. Slots of participants that completed are returned to the
  // free pool at the start of the next pass. Once no wakeups are pending,
  // locked_ is cleared and false is returned. This implementation never
  // returns true.
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    WakeupMask completed = 0;
    for (;;) {
      ReleasableMutexLock lock(&mu_);
      GPR_ASSERT(locked_);
      allocated_ &= ~completed;
      completed = 0;
      WakeupMask pending = wakeups_;
      wakeups_ = 0;
      if (pending == 0) {
        locked_ = false;
        return false;
      }
      lock.Release();
      size_t index = 0;
      while (pending != 0) {
        if ((pending & 1) != 0 && poll_one_participant(index)) {
          completed |= 1 << index;
        }
        pending >>= 1;
        ++index;
      }
    }
  }

  // Reserve `count` free slots (lowest index first), hand their indices to
  // `store`, and mark them woken. A ref is taken up front for the caller to
  // drop later. Returns true if the caller acquired the lock and must run the
  // party.
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    IncrementRefCount();
    MutexLock lock(&mu_);
    size_t slots[party_detail::kMaxParticipants];
    WakeupMask wakeup_mask = 0;
    size_t filled = 0;
    for (size_t bit = 0; bit < party_detail::kMaxParticipants && filled < count;
         ++bit) {
      if ((allocated_ & (1 << bit)) != 0) continue;
      slots[filled] = bit;
      ++filled;
      wakeup_mask |= 1 << bit;
      allocated_ |= 1 << bit;
    }
    GPR_ASSERT(filled == count);
    store(slots);
    wakeups_ |= wakeup_mask;
    const bool was_locked = locked_;
    locked_ = true;
    return !was_locked;
  }

  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  RefCount refs_;
  Mutex mu_;
  WakeupMask allocated_ ABSL_GUARDED_BY(mu_) = 0;
  WakeupMask wakeups_ ABSL_GUARDED_BY(mu_) = 0;
  bool locked_ ABSL_GUARDED_BY(mu_) = false;
};

// A Party is an Activity with multiple participant promises.
// Each participant occupies one of party_detail::kMaxParticipants slots. The
// slot index is also that participant's bit in a WakeupMask.
class Party : public Activity, private Wakeable {
 private:
  // Non-owning wakeup handle.
  class Handle;

  // One participant in the party.
  class Participant {
   public:
    explicit Participant(absl::string_view name) : name_(name) {}
    // Poll the participant. Return true if complete.
    // Participant should take care of its own deallocation in this case.
    virtual bool Poll() = 0;

    // Destroy the participant before finishing.
    virtual void Destroy() = 0;

    // Return a non-owning Wakeable (backed by a Handle) for this participant.
    Wakeable* MakeNonOwningWakeable(Party* party);

    // Debug name. Stored as a view (not copied), so the characters passed at
    // construction must outlive the participant.
    absl::string_view name() const { return name_; }

   protected:
    ~Participant();

   private:
    Handle* handle_ = nullptr;
    absl::string_view name_;
  };

 public:
  Party(const Party&) = delete;
  Party& operator=(const Party&) = delete;

  // Spawn one promise into the party.
  // The promise will be polled until it is resolved, or until the party is shut
  // down.
  // The on_complete callback will be called with the result of the promise if
  // it completes.
  // At most party_detail::kMaxParticipants (sixteen) promises can be live on a
  // party at once; a participant's slot is released when it completes.
  // `name` is kept as a view, so its characters must outlive the promise.
  template <typename Factory, typename OnComplete>
  void Spawn(absl::string_view name, Factory promise_factory,
             OnComplete on_complete);

  void Orphan() final { Crash("unused"); }

  // Activity implementation: not allowed to be overridden by derived types.
  void ForceImmediateRepoll(WakeupMask mask) final;
  WakeupMask CurrentParticipant() const final {
    GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
    return 1u << currently_polling_;
  }
  Waker MakeOwningWaker() final;
  Waker MakeNonOwningWaker() final;
  std::string ActivityDebugTag(WakeupMask wakeup_mask) const final;

  void IncrementRefCount() { sync_.IncrementRefCount(); }
  void Unref() {
    if (sync_.Unref()) PartyIsOver();
  }
  RefCountedPtr<Party> Ref() {
    IncrementRefCount();
    return RefCountedPtr<Party>(this);
  }

  Arena* arena() const { return arena_; }

  // Collects participants via Spawn() and adds them all to the party in one
  // AddParticipants call when the spawner is destroyed.
  // At most party_detail::kMaxParticipants participants may be queued.
  class BulkSpawner {
   public:
    explicit BulkSpawner(Party* party) : party_(party) {}
    // The destructor publishes the queued participants, so a copy would
    // publish the same participants twice.
    BulkSpawner(const BulkSpawner&) = delete;
    BulkSpawner& operator=(const BulkSpawner&) = delete;
    ~BulkSpawner() {
      party_->AddParticipants(participants_, num_participants_);
    }

    template <typename Factory, typename OnComplete>
    void Spawn(absl::string_view name, Factory promise_factory,
               OnComplete on_complete);

   private:
    Party* const party_;
    size_t num_participants_ = 0;
    // Only the first num_participants_ entries are meaningful.
    Participant* participants_[party_detail::kMaxParticipants];
  };

 protected:
  explicit Party(Arena* arena, size_t initial_refs)
      : sync_(initial_refs), arena_(arena) {}
  ~Party() override;

  // Main run loop. Must be locked.
  // Polls woken participants until there is no work left to be done.
  // Derived types will likely want to override this to set up their
  // contexts before polling.
  // Should not be called by derived types except as a tail call to the base
  // class RunParty when overriding this method to add custom context.
  // Returns true if the party is over.
  GRPC_MUST_USE_RESULT virtual bool RunParty();

  bool RefIfNonZero() { return sync_.RefIfNonZero(); }

  // Destroy any remaining participants.
  // Should be called by derived types in response to PartyOver.
  // Needs to have normal context setup before calling.
  void CancelRemainingParticipants();

 private:
  // Concrete implementation of a participant for some promise & oncomplete
  // type.
  template <typename SuppliedFactory, typename OnComplete>
  class ParticipantImpl final : public Participant {
    using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
    using Promise = typename Factory::Promise;

   public:
    ParticipantImpl(absl::string_view name, SuppliedFactory promise_factory,
                    OnComplete on_complete)
        : Participant(name), on_complete_(std::move(on_complete)) {
      Construct(&factory_, std::move(promise_factory));
    }
    ~ParticipantImpl() {
      if (!started_) {
        Destruct(&factory_);
      } else {
        Destruct(&promise_);
      }
    }

    bool Poll() override {
      if (!started_) {
        auto p = factory_.Make();
        Destruct(&factory_);
        Construct(&promise_, std::move(p));
        started_ = true;
      }
      auto p = promise_();
      if (auto* r = p.value_if_ready()) {
        on_complete_(std::move(*r));
        GetContext<Arena>()->DeletePooled(this);
        return true;
      }
      return false;
    }

    void Destroy() override { GetContext<Arena>()->DeletePooled(this); }

   private:
    // Exactly one member is live: factory_ until the first Poll() turns it
    // into promise_ (tracked by started_).
    union {
      GPR_NO_UNIQUE_ADDRESS Factory factory_;
      GPR_NO_UNIQUE_ADDRESS Promise promise_;
    };
    GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_;
    bool started_ = false;
  };

  // Notification that the party has finished and this instance can be deleted.
  // Derived types should arrange to call CancelRemainingParticipants during
  // this sequence.
  virtual void PartyOver() = 0;

  // Run the locked part of the party until it is unlocked.
  void RunLocked();
  // Called in response to Unref() hitting zero - ultimately calls PartyOver,
  // but needs to set some stuff up.
  // Here so it gets compiled out of line.
  void PartyIsOver();

  // Wakeable implementation
  void Wakeup(WakeupMask wakeup_mask) final;
  void WakeupAsync(WakeupMask wakeup_mask) final;
  void Drop(WakeupMask wakeup_mask) final;

  // Add `count` type-erased participants (backs Spawn and BulkSpawner).
  void AddParticipants(Participant** participant, size_t count);

  virtual grpc_event_engine::experimental::EventEngine* event_engine()
      const = 0;

  // Sentinel value for currently_polling_ when no participant is being polled.
  static constexpr uint8_t kNotPolling = 255;

#ifdef GRPC_PARTY_SYNC_USING_ATOMICS
  PartySyncUsingAtomics sync_;
#elif defined(GRPC_PARTY_SYNC_USING_MUTEX)
  PartySyncUsingMutex sync_;
#else
#error No synchronization method defined
#endif

  Arena* const arena_;
  uint8_t currently_polling_ = kNotPolling;
  // Current participants, one entry per slot (the slot index is the
  // participant's wakeup bit).
  // NOTE(review): an older comment described a tagged
  // Participant*/ParticipantFactory* encoding here, but no ParticipantFactory
  // type exists any more - confirm against party.cc that no tagging remains.
  std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {};
};

// Queue one participant on this bulk spawner. The participant is allocated
// from the party's arena now and handed to the party when the spawner is
// destroyed.
template <typename Factory, typename OnComplete>
void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory,
                               OnComplete on_complete) {
  if (grpc_trace_promise_primitives.enabled()) {
    gpr_log(GPR_DEBUG, "%s[bulk_spawn] On %p queue %s",
            party_->DebugTag().c_str(), this, std::string(name).c_str());
  }
  // participants_ is a fixed-size array; refuse to write past its end.
  // AddParticipantsAndRef would also reject more than kMaxParticipants new
  // participants, but only after this array had already been overrun.
  GPR_ASSERT(num_participants_ < party_detail::kMaxParticipants);
  participants_[num_participants_++] =
      party_->arena_->NewPooled<ParticipantImpl<Factory, OnComplete>>(
          name, std::move(promise_factory), std::move(on_complete));
}

// Single-promise spawn, implemented as a bulk spawn of one participant. The
// participant is added to the party when `spawner` goes out of scope at the
// end of this function.
template <typename Factory, typename OnComplete>
void Party::Spawn(absl::string_view name, Factory promise_factory,
                  OnComplete on_complete) {
  BulkSpawner spawner(this);
  spawner.Spawn(name, std::move(promise_factory), std::move(on_complete));
}

}  // namespace grpc_core

#endif  // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H