diff options
Diffstat (limited to 'pw_stream/mpsc_stream_test.cc')
-rw-r--r-- | pw_stream/mpsc_stream_test.cc | 601 |
1 files changed, 601 insertions, 0 deletions
diff --git a/pw_stream/mpsc_stream_test.cc b/pw_stream/mpsc_stream_test.cc new file mode 100644 index 000000000..1a08bed2f --- /dev/null +++ b/pw_stream/mpsc_stream_test.cc @@ -0,0 +1,601 @@ +// Copyright 2023 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_stream/mpsc_stream.h" + +#include "gtest/gtest.h" +#include "pw_containers/vector.h" +#include "pw_fuzzer/fuzztest.h" +#include "pw_random/xor_shift.h" +#include "pw_thread/test_thread_context.h" +#include "pw_thread/thread.h" + +namespace pw::stream { +namespace { + +using namespace std::chrono_literals; +using namespace pw::fuzzer; + +//////////////////////////////////////////////////////////////////////////////// +// Test fixtures. + +/// Capacity in bytes for data buffers. +constexpr size_t kBufSize = 512; + +/// Fills a byte span with random data. +void Fill(std::byte* buf, size_t len) { + ByteSpan data(buf, len); + random::XorShiftStarRng64 rng(1); + rng.Get(data); +} + +/// FNV-1a offset basis. +constexpr uint64_t kOffsetBasis = 0xcbf29ce484222325ULL; + +/// FNV-1a prime value. +constexpr uint64_t kPrimeValue = 0x100000001b3ULL; + +/// Quick implementation of public-domain Fowler-Noll-Vo hashing algorithm. +/// +/// This is used in the tests below to verify equality of two sequences of bytes +/// that are too large to compare directly. +/// +/// See http://www.isthe.com/chongo/tech/comp/fnv/index.html +void fnv1a(ConstByteSpan bytes, uint64_t& hash) { + for (const auto& b : bytes) { + hash = (hash ^ static_cast<uint8_t>(b)) * kPrimeValue; + } +} + +/// MpscStream test context that uses a generic reader. +/// +/// This struct associates a reader and writer with their parameters and return +/// values. This is useful for communicating with threads spawned to call a +/// blocking method. +struct MpscTestContext { + MpscWriter writer; + MpscReader reader; + + ConstByteSpan data; + std::byte write_buffer[kBufSize]; + uint64_t write_hash = kOffsetBasis; + Status write_status; + + ByteSpan destination; + std::byte read_buffer[kBufSize]; + Result<ByteSpan> read_result; + uint64_t read_hash = kOffsetBasis; + size_t total_read = 0; + + MpscTestContext() { + data = ConstByteSpan(write_buffer); + destination = ByteSpan(read_buffer); + } + + void Connect() { CreateMpscStream(reader, writer); } + + // Fills a byte span with random data. + void Fill() { pw::stream::Fill(write_buffer, sizeof(write_buffer)); } + + // Writes data using the writer. + void Write() { + fnv1a(data, write_hash); + write_status = writer.Write(data); + } + + // Writes data repeatedly up to the writer's limit. + void WriteAll() { + size_t limit = writer.ConservativeWriteLimit(); + ASSERT_NE(limit, 0U); + ASSERT_NE(limit, Stream::kUnlimited); + while (limit != 0) { + if (limit < kBufSize) { + data = data.subspan(0, limit); + } + Fill(); + Write(); + if (!write_status.ok()) { + break; + } + limit = writer.ConservativeWriteLimit(); + } + } + + // Reads data using the reader. + void Read() { + read_result = reader.Read(destination); + if (read_result.ok()) { + fnv1a(*read_result, write_hash); + total_read += read_result->size(); + } + } + + // Run the given function on a dedicated thread. + using ThreadBody = Function<void(MpscTestContext* ctx)>; + void Spawn(ThreadBody func) { + body_ = std::move(func); + thread_ = thread::Thread( + context_.options(), + [](void* arg) { + auto* base = static_cast<MpscTestContext*>(arg); + base->body_(base); + }, + this); + } + + // Waits for the spawned thread to complete. + void Join() { thread_.join(); } + + private: + thread::Thread thread_; + thread::test::TestThreadContext context_; + ThreadBody body_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Unit tests. + +TEST(MpscStreamTest, CopyWriters) { + MpscTestContext ctx; + ctx.Connect(); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(ctx.writer.connected()); + + MpscWriter writer2(ctx.writer); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(ctx.writer.connected()); + EXPECT_TRUE(writer2.connected()); + + MpscWriter writer3 = writer2; + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(ctx.writer.connected()); + EXPECT_TRUE(writer2.connected()); + EXPECT_TRUE(writer3.connected()); + + ctx.writer.Close(); + writer2.Close(); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_FALSE(ctx.writer.connected()); + EXPECT_FALSE(writer2.connected()); + EXPECT_TRUE(writer3.connected()); +} + +TEST(MpscStreamTest, MoveWriters) { + MpscTestContext ctx; + ctx.Connect(); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(ctx.writer.connected()); + + MpscWriter writer2(std::move(ctx.writer)); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(writer2.connected()); + + MpscWriter writer3 = std::move(writer2); + EXPECT_TRUE(ctx.reader.connected()); + EXPECT_TRUE(writer3.connected()); + + // Only writer3 should be connected. + writer3.Close(); + EXPECT_FALSE(writer3.connected()); + EXPECT_FALSE(ctx.reader.connected()); +} + +TEST(MpscStreamTest, ReadFailsIfDisconnected) { + MpscTestContext ctx; + ctx.Connect(); + + ctx.writer.Close(); + ctx.Read(); + EXPECT_EQ(ctx.read_result.status(), Status::OutOfRange()); +} + +TEST(MpscStreamTest, ReadBlocksWhenEmpty) { + MpscTestContext ctx; + ctx.Connect(); + ctx.reader.SetTimeout(10ms); + + auto start = chrono::SystemClock::now(); + ctx.Read(); + auto elapsed = chrono::SystemClock::now() - start; + + EXPECT_EQ(ctx.read_result.status(), Status::ResourceExhausted()); + EXPECT_GE(elapsed, 10ms); +} + +TEST(MpscStreamTest, ReadReturnsAfterReaderClose) { + MpscTestContext ctx; + ctx.Connect(); + + ctx.Spawn([](MpscTestContext* inner) { inner->Read(); }); + ctx.reader.Close(); + ctx.Join(); + + EXPECT_EQ(ctx.read_result.status(), Status::OutOfRange()); +} + +TEST(MpscStreamTest, WriteBlocksUntilTimeout) { + MpscTestContext ctx; + ctx.Connect(); + ctx.writer.SetTimeout(10ms); + ctx.Fill(); + + auto start = chrono::SystemClock::now(); + ctx.Write(); + auto elapsed = chrono::SystemClock::now() - start; + + EXPECT_EQ(ctx.write_status, Status::ResourceExhausted()); + EXPECT_GE(elapsed, 10ms); +} + +TEST(MpscStreamTest, WriteReturnsAfterClose) { + MpscTestContext ctx; + ctx.Connect(); + + ctx.Fill(); + ctx.Spawn([](MpscTestContext* inner) { inner->Write(); }); + ctx.reader.Close(); + ctx.Join(); + + EXPECT_EQ(ctx.write_status, Status::OutOfRange()); +} + +void VerifyRoundtripImpl(const Vector<std::byte>& data, ByteSpan buffer) { + MpscTestContext ctx; + ctx.Connect(); + + ctx.reader.SetBuffer(buffer); + ctx.data = ConstByteSpan(data.data(), data.size()); + ctx.Spawn([](MpscTestContext* inner) { inner->Write(); }); + size_t offset = 0; + while (offset < data.size()) { + ctx.Read(); + ASSERT_EQ(ctx.read_result.status(), OkStatus()); + size_t num_read = ctx.read_result->size(); + EXPECT_EQ(memcmp(ctx.read_buffer, &data[offset], num_read), 0); + offset += num_read; + } + ctx.Join(); +} + +template <size_t kCapacity> +void FillAndVerifyRoundtripImpl(ByteSpan buffer) { + Vector<std::byte, kCapacity> data; + Fill(data.data(), data.size()); + VerifyRoundtripImpl(data, buffer); +} + +TEST(MpscStreamTest, VerifyRoundtripWithoutBufferSmall) { + FillAndVerifyRoundtripImpl<kBufSize / 2>(ByteSpan()); +} + +TEST(MpscStreamTest, VerifyRoundtripWithoutBufferLarge) { + FillAndVerifyRoundtripImpl<kBufSize * 2>(ByteSpan()); +} + +void VerifyRoundtripWithoutBuffer(const Vector<std::byte>& data) { + VerifyRoundtripImpl(data, ByteSpan()); +} +FUZZ_TEST(MpscStreamTest, VerifyRoundtripWithoutBuffer) + .WithDomains(VectorOf<kBufSize * 2>(Arbitrary<std::byte>()).WithMinSize(1)); + +TEST(MpscStreamTest, VerifyRoundtripWithBufferSmall) { + std::byte buffer[kBufSize]; + FillAndVerifyRoundtripImpl<kBufSize / 2>(buffer); +} + +TEST(MpscStreamTest, VerifyRoundtripWithBufferLarge) { + std::byte buffer[kBufSize]; + FillAndVerifyRoundtripImpl<kBufSize * 2>(buffer); +} + +void VerifyRoundtripWithBuffer(const Vector<std::byte>& data) { + std::byte buffer[kBufSize]; + VerifyRoundtripImpl(data, buffer); +} +FUZZ_TEST(MpscStreamTest, VerifyRoundtripWithBuffer) + .WithDomains(VectorOf<kBufSize * 2>(Arbitrary<std::byte>()).WithMinSize(1)); + +TEST(MpscStreamTest, CanRetryAfterPartialWrite) { + constexpr size_t kChunk = kBufSize - 4; + MpscTestContext ctx; + ctx.Connect(); + ctx.writer.SetTimeout(10ms); + ByteSpan destination = ctx.destination; + + ctx.Spawn([](MpscTestContext* inner) { + inner->Fill(); + inner->Write(); + }); + ctx.destination = destination.subspan(0, kChunk); + ctx.Read(); + ctx.Join(); + EXPECT_EQ(ctx.read_result.status(), OkStatus()); + EXPECT_EQ(ctx.read_result->size(), kChunk); + EXPECT_EQ(ctx.write_status, Status::ResourceExhausted()); + EXPECT_EQ(ctx.writer.last_write(), kChunk); + + ctx.Spawn([](MpscTestContext* inner) { + inner->data = inner->data.subspan(kChunk); + inner->Write(); + }); + ctx.destination = destination.subspan(kChunk); + ctx.Read(); + ctx.Join(); + EXPECT_EQ(ctx.read_result.status(), OkStatus()); + EXPECT_EQ(ctx.read_result->size(), 4U); + EXPECT_EQ(ctx.write_status, OkStatus()); + EXPECT_EQ(ctx.writer.last_write(), 4U); + + EXPECT_EQ(memcmp(ctx.write_buffer, ctx.read_buffer, kBufSize), 0); +} + +TEST(MpscStreamTest, CannotReadAfterReaderClose) { + MpscTestContext ctx; + ctx.Connect(); + ctx.reader.Close(); + ctx.Read(); + EXPECT_EQ(ctx.read_result.status(), Status::OutOfRange()); +} + +TEST(MpscStreamTest, CanReadAfterWriterCloses) { + MpscTestContext ctx; + ctx.Connect(); + std::byte buffer[kBufSize]; + ctx.reader.SetBuffer(buffer); + ctx.Fill(); + ctx.Write(); + EXPECT_EQ(ctx.write_status, OkStatus()); + ctx.writer.Close(); + + ctx.Read(); + ASSERT_EQ(ctx.read_result.status(), OkStatus()); + ASSERT_EQ(ctx.read_result->size(), kBufSize); + EXPECT_EQ(memcmp(ctx.write_buffer, ctx.read_buffer, kBufSize), 0); +} + +TEST(MpscStreamTest, CannotWriteAfterWriterClose) { + MpscTestContext ctx; + ctx.Connect(); + ctx.Fill(); + ctx.writer.Close(); + ctx.Write(); + EXPECT_EQ(ctx.write_status, Status::OutOfRange()); +} + +TEST(MpscStreamTest, CannotWriteAfterReaderClose) { + MpscTestContext ctx; + ctx.Connect(); + ctx.Fill(); + ctx.reader.Close(); + ctx.Write(); + EXPECT_EQ(ctx.write_status, Status::OutOfRange()); +} + +TEST(MpscStreamTest, MultipleWriters) { + MpscTestContext ctx1; + ctx1.Connect(); + Vector<std::byte, kBufSize + 1> data1(kBufSize + 1, std::byte(1)); + ctx1.data = ByteSpan(data1.data(), data1.size()); + + MpscTestContext ctx2; + ctx2.writer = ctx1.writer; + Vector<std::byte, kBufSize / 2> data2(kBufSize / 2, std::byte(2)); + ctx2.data = ByteSpan(data2.data(), data2.size()); + + MpscTestContext ctx3; + ctx3.writer = ctx1.writer; + Vector<std::byte, kBufSize * 3> data3(kBufSize * 3, std::byte(3)); + ctx3.data = ByteSpan(data3.data(), data3.size()); + + // Start all threads. + ctx1.Spawn([](MpscTestContext* ctx) { ctx->Write(); }); + ctx2.Spawn([](MpscTestContext* ctx) { ctx->Write(); }); + ctx3.Spawn([](MpscTestContext* ctx) { ctx->Write(); }); + + // The loop below keeps track of how many contiguous values are read, in order + // to verify that writes are not split or interleaved. + size_t expected[4] = {0, data1.size(), data2.size(), data3.size()}; + size_t actual[4] = {0}; + + size_t total_read = 0; + auto current = std::byte(0); + size_t num_current = 0; + while (total_read < data1.size() + data2.size() + data3.size()) { + ctx1.Read(); + if (!ctx1.read_result.ok()) { + break; + } + size_t num_read = ctx1.read_result->size(); + for (size_t i = 0; i < num_read; ++i) { + if (current == ctx1.read_buffer[i]) { + ++num_current; + continue; + } + actual[size_t(current)] = num_current; + current = ctx1.read_buffer[i]; + num_current = 1; + } + actual[size_t(current)] = num_current; + total_read += num_read; + } + ctx1.reader.Close(); + ctx1.Join(); + ctx2.Join(); + ctx3.Join(); + ASSERT_EQ(ctx1.read_result.status(), OkStatus()); + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(actual[i], expected[i]); + } +} + +TEST(MpscStreamTest, GetAndSetLimits) { + MpscReader reader; + EXPECT_EQ(reader.ConservativeReadLimit(), 0U); + + MpscWriter writer; + EXPECT_EQ(writer.ConservativeWriteLimit(), 0U); + + CreateMpscStream(reader, writer); + EXPECT_EQ(reader.ConservativeReadLimit(), Stream::kUnlimited); + EXPECT_EQ(writer.ConservativeWriteLimit(), Stream::kUnlimited); + + writer.SetLimit(10); + EXPECT_EQ(reader.ConservativeReadLimit(), 10U); + EXPECT_EQ(writer.ConservativeWriteLimit(), 10U); + + writer.Close(); + EXPECT_EQ(reader.ConservativeReadLimit(), 0U); + EXPECT_EQ(writer.ConservativeWriteLimit(), 0U); +} + +TEST(MpscStreamTest, ReaderAggregatesLimit) { + MpscTestContext ctx; + ctx.Connect(); + ctx.writer.SetLimit(10); + + MpscWriter writer2 = ctx.writer; + writer2.SetLimit(20); + + EXPECT_EQ(ctx.reader.ConservativeReadLimit(), 30U); + + ctx.writer.SetLimit(Stream::kUnlimited); + EXPECT_EQ(ctx.reader.ConservativeReadLimit(), Stream::kUnlimited); + + writer2.SetLimit(40); + EXPECT_EQ(ctx.reader.ConservativeReadLimit(), Stream::kUnlimited); + + ctx.writer.SetLimit(0); + EXPECT_EQ(ctx.reader.ConservativeReadLimit(), 40U); +} + +TEST(MpscStreamTest, ReadingUpdatesLimit) { + MpscTestContext ctx; + ctx.Connect(); + + constexpr size_t kChunk = kBufSize - 4; + std::byte buffer[kBufSize]; + ctx.reader.SetBuffer(buffer); + ctx.Fill(); + ctx.writer.SetLimit(kBufSize); + ctx.Write(); + EXPECT_EQ(ctx.write_status, OkStatus()); + + ctx.destination = ByteSpan(ctx.read_buffer, kChunk); + ctx.Read(); + EXPECT_EQ(ctx.read_result.status(), OkStatus()); + EXPECT_EQ(ctx.read_result->size(), kChunk); + EXPECT_EQ(ctx.reader.ConservativeReadLimit(), kBufSize - kChunk); +} + +TEST(MpscStreamTest, CannotWriteMoreThanLimit) { + MpscTestContext ctx; + ctx.Connect(); + + std::byte buffer[kBufSize]; + ctx.reader.SetBuffer(buffer); + ctx.writer.SetLimit(kBufSize - 1); + ctx.Fill(); + ctx.Write(); + EXPECT_EQ(ctx.write_status, Status::ResourceExhausted()); +} + +TEST(MpscStreamTest, WritersCanCloseAutomatically) { + MpscTestContext ctx1; + ctx1.Connect(); + Vector<std::byte, kBufSize + 1> data1(kBufSize + 1, std::byte(1)); + ctx1.writer.SetLimit(data1.size()); + ctx1.data = ByteSpan(data1.data(), data1.size()); + + MpscTestContext ctx2; + ctx2.writer = ctx1.writer; + Vector<std::byte, kBufSize / 2> data2(kBufSize / 2, std::byte(2)); + ctx2.writer.SetLimit(data2.size()); + ctx2.data = ByteSpan(data2.data(), data2.size()); + + // Start all threads. + EXPECT_TRUE(ctx1.reader.connected()); + EXPECT_TRUE(ctx1.writer.connected()); + EXPECT_TRUE(ctx2.writer.connected()); + + ctx1.Spawn([](MpscTestContext* ctx) { ctx->Write(); }); + ctx2.Spawn([](MpscTestContext* ctx) { ctx->Write(); }); + + size_t total = 0; + while (ctx1.reader.ConservativeReadLimit() != 0) { + ctx1.Read(); + EXPECT_EQ(ctx1.read_result.status(), OkStatus()); + if (!ctx1.read_result.ok()) { + ctx1.reader.Close(); + break; + } + total += ctx1.read_result->size(); + } + EXPECT_EQ(total, data1.size() + data2.size()); + ctx1.Join(); + ctx2.Join(); + EXPECT_FALSE(ctx1.reader.connected()); + EXPECT_FALSE(ctx1.writer.connected()); + EXPECT_FALSE(ctx2.writer.connected()); +} + +TEST(MpscStreamTest, ReadAllWithoutBuffer) { + MpscTestContext ctx; + Status status = ctx.reader.ReadAll([](ConstByteSpan) { return OkStatus(); }); + EXPECT_EQ(status, Status::FailedPrecondition()); +} + +TEST(MpscStreamTest, ReadAll) { + MpscTestContext ctx; + ctx.Connect(); + + std::byte buffer[kBufSize]; + ctx.reader.SetBuffer(buffer); + ctx.writer.SetLimit(kBufSize * 100); + ctx.Spawn([](MpscTestContext* inner) { inner->WriteAll(); }); + + Status status = ctx.reader.ReadAll([&ctx](ConstByteSpan data) { + ctx.total_read += data.size(); + fnv1a(data, ctx.read_hash); + return OkStatus(); + }); + ctx.Join(); + + EXPECT_EQ(status, OkStatus()); + EXPECT_FALSE(ctx.reader.connected()); + EXPECT_EQ(ctx.total_read, kBufSize * 100); + EXPECT_EQ(ctx.read_hash, ctx.write_hash); +} + +TEST(MpscStreamTest, BufferedMpscReader) { + BufferedMpscReader<kBufSize> reader; + MpscWriter writer; + CreateMpscStream(reader, writer); + + // `kBufSize` writes of 1 byte each should fit without blocking. + for (size_t i = 0; i < kBufSize; ++i) { + std::byte b{static_cast<uint8_t>(i)}; + EXPECT_EQ(writer.Write(ConstByteSpan(&b, 1)), OkStatus()); + } + + std::byte rx_buffer[kBufSize]; + auto result = reader.Read(ByteSpan(rx_buffer)); + ASSERT_EQ(result.status(), OkStatus()); + ASSERT_EQ(result->size(), kBufSize); + for (size_t i = 0; i < kBufSize; ++i) { + EXPECT_EQ(rx_buffer[i], std::byte(i)); + } +} + +} // namespace +} // namespace pw::stream |