diff options
author | Armando Montanez <amontanez@google.com> | 2023-11-15 19:04:41 +0000 |
---|---|---|
committer | CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2023-11-15 19:04:41 +0000 |
commit | 3587ce2d020e9336a440146430727fdb01d9d60f (patch) | |
tree | dcee2fd881f5b3cbdd0003a1e3a1dfa8f3350543 | |
parent | 3c3dc003812d05bb31a47b93ff1df5666ad3db7f (diff) | |
download | pigweed-3587ce2d020e9336a440146430727fdb01d9d60f.tar.gz |
Revert "pw_rpc_transport: Close sockets when stopping"
This reverts commit 50d8c114135f26bd613baa99615e685a869b50b2.
Reason for revert: b/309680612 races manifesting on macOS.
Original change's description:
> pw_rpc_transport: Close sockets when stopping
>
> Also, when closing sockets, disconnect the underlying sockets by using
> the socket shutdown API. This unblocks socket recv and accept calls.
>
> Bug: 309680612
> Test: Verified socket unit tests pass.
> Test: See details in testing done comment in the code review.
> Change-Id: If9e54bdb88fa36a3f831d32b4549d5a222fd3df8
> Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/179591
> Reviewed-by: Carlos Chinchilla <cachinchilla@google.com>
> Commit-Queue: Erik Staats <estaats@google.com>
> Presubmit-Verified: CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com>
# Not skipping CQ checks because original CL landed > 1 day ago.
Bug: 309680612
Change-Id: Ie8f0bc3d6665c9c11c680a964bb6842b5d570c43
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/181130
Pigweed-Auto-Submit: Rob Mohr <mohrr@google.com>
Commit-Queue: Rob Mohr <mohrr@google.com>
Reviewed-by: Rob Mohr <mohrr@google.com>
Reviewed-by: Armando Montanez <amontanez@google.com>
Reviewed-by: Erik Staats <estaats@google.com>
-rw-r--r-- | pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h | 15 | ||||
-rw-r--r-- | pw_rpc_transport/rpc_integration_test.cc | 6 | ||||
-rw-r--r-- | pw_rpc_transport/socket_rpc_transport_test.cc | 16 | ||||
-rw-r--r-- | pw_stream/BUILD.bazel | 1 | ||||
-rw-r--r-- | pw_stream/BUILD.gn | 5 | ||||
-rw-r--r-- | pw_stream/CMakeLists.txt | 1 | ||||
-rw-r--r-- | pw_stream/public/pw_stream/socket_stream.h | 38 | ||||
-rw-r--r-- | pw_stream/socket_stream.cc | 97 |
8 files changed, 43 insertions, 136 deletions
diff --git a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h index 3d4121471..e3a84c5ea 100644 --- a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h +++ b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h @@ -106,10 +106,6 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { while (!stopped_) { const auto read_status = ReadData(); - // Break if ReadData was cancelled after the transport was stopped. - if (stopped_) { - break; - } if (!read_status.ok()) { internal::LogSocketReadError(read_status); } @@ -126,11 +122,7 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { } } - void Stop() { - stopped_ = true; - socket_stream_.Close(); - server_socket_.Close(); - } + void Stop() { stopped_ = true; } private: enum class ClientServerRole { kClient, kServer }; @@ -164,11 +156,6 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore { NotifyReady(); Result<stream::SocketStream> stream = server_socket_.Accept(); - // If Accept was cancelled due to stopping the transport, return without - // error. - if (stopped_) { - return OkStatus(); - } if (!stream.ok()) { internal::LogSocketAcceptError(stream.status()); return stream.status(); diff --git a/pw_rpc_transport/rpc_integration_test.cc b/pw_rpc_transport/rpc_integration_test.cc index 633ff848f..c48672d3e 100644 --- a/pw_rpc_transport/rpc_integration_test.cc +++ b/pw_rpc_transport/rpc_integration_test.cc @@ -120,6 +120,12 @@ TEST(RpcIntegrationTest, SocketTransport) { a.transport.Stop(); b.transport.Stop(); + // Unblock socket transports by sending terminator packets. + const std::array<std::byte, 1> terminator_bytes{std::byte{0x42}}; + RpcFrame terminator{.header = {}, .payload = terminator_bytes}; + EXPECT_EQ(a.transport.Send(terminator), OkStatus()); + EXPECT_EQ(b.transport.Send(terminator), OkStatus()); + a_local_egress_thread.join(); b_local_egress_thread.join(); a_transport_thread.join(); diff --git a/pw_rpc_transport/socket_rpc_transport_test.cc b/pw_rpc_transport/socket_rpc_transport_test.cc index 8f128726e..12c0d845c 100644 --- a/pw_rpc_transport/socket_rpc_transport_test.cc +++ b/pw_rpc_transport/socket_rpc_transport_test.cc @@ -115,12 +115,18 @@ class SocketSender { } } + // stream::SocketStream doesn't support read timeouts so we have to + // unblock socket reads by sending more data after the transport is stopped. + pw::Status Terminate() { return transport_.Send(terminator_); } + private: SocketRpcTransport<kReadBufferSize>& transport_; std::vector<std::byte> sent_; std::array<std::byte, 256> data_{}; std::uniform_int_distribution<size_t> offset_dist_{0, 255}; std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize}; + std::array<std::byte, 1> terminator_bytes_{std::byte{0x42}}; + RpcFrame terminator_{.header = {}, .payload = terminator_bytes_}; }; class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore { @@ -176,6 +182,10 @@ TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) { server.Stop(); client.Stop(); + // Unblock socket reads to propagate the stop signal. + EXPECT_EQ(server_sender.Terminate(), OkStatus()); + EXPECT_EQ(client_sender.Terminate(), OkStatus()); + server_thread.join(); client_thread.join(); @@ -232,6 +242,7 @@ TEST(SocketRpcTransportTest, ServerReconnects) { // Stop the client but not the server: we're re-using the same server // with a new client below. client.Stop(); + EXPECT_EQ(server_sender.Terminate(), OkStatus()); client_thread.join(); } @@ -256,11 +267,13 @@ TEST(SocketRpcTransportTest, ServerReconnects) { std::back_inserter(received)); client.Stop(); + EXPECT_EQ(server_sender.Terminate(), OkStatus()); client_thread.join(); // This time stop the server as well. SocketSender client_sender(client); server.Stop(); + EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); } @@ -309,6 +322,7 @@ TEST(SocketRpcTransportTest, ClientReconnects) { server1_sent.end(), std::back_inserter(sent_by_server)); + EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); server = nullptr; @@ -331,9 +345,11 @@ TEST(SocketRpcTransportTest, ClientReconnects) { server2_sent.end(), std::back_inserter(sent_by_server)); + EXPECT_EQ(client_sender.Terminate(), OkStatus()); server_thread.join(); client.Stop(); + EXPECT_EQ(server2_sender.Terminate(), OkStatus()); client_thread.join(); server = nullptr; diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel index 733878ad0..1d40bfd84 100644 --- a/pw_stream/BUILD.bazel +++ b/pw_stream/BUILD.bazel @@ -52,7 +52,6 @@ pw_cc_library( ":pw_stream", "//pw_log", "//pw_string", - "//pw_sync:mutex", "//pw_sys_io", ], ) diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn index 45357a0ec..a55fb91a0 100644 --- a/pw_stream/BUILD.gn +++ b/pw_stream/BUILD.gn @@ -48,10 +48,7 @@ pw_source_set("pw_stream") { pw_source_set("socket_stream") { public_configs = [ ":public_include_path" ] - public_deps = [ - ":pw_stream", - "$dir_pw_sync:mutex", - ] + public_deps = [ ":pw_stream" ] deps = [ dir_pw_assert, dir_pw_log, diff --git a/pw_stream/CMakeLists.txt b/pw_stream/CMakeLists.txt index 623ac4720..b04b440af 100644 --- a/pw_stream/CMakeLists.txt +++ b/pw_stream/CMakeLists.txt @@ -40,7 +40,6 @@ pw_add_library(pw_stream.socket_stream STATIC public PUBLIC_DEPS pw_stream - pw_sync.mutex SOURCES socket_stream.cc PRIVATE_DEPS diff --git a/pw_stream/public/pw_stream/socket_stream.h b/pw_stream/public/pw_stream/socket_stream.h index e80739ea0..a9b7b16a6 100644 --- a/pw_stream/public/pw_stream/socket_stream.h +++ b/pw_stream/public/pw_stream/socket_stream.h @@ -18,32 +18,24 @@ #include "pw_result/result.h" #include "pw_span/span.h" #include "pw_stream/stream.h" -#include "pw_sync/mutex.h" namespace pw::stream { class SocketStream : public NonSeekableReaderWriter { public: - SocketStream() = default; + constexpr SocketStream() = default; // Construct a SocketStream directly from a file descriptor. - explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) { - // Take ownership of the connection fd by this object. - TakeConnectionFd(); - } + explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) {} // SocketStream objects are moveable but not copyable. SocketStream& operator=(SocketStream&& other) { connection_fd_ = other.connection_fd_; other.connection_fd_ = kInvalidFd; - connection_fd_own_count_ = other.connection_fd_own_count_; - other.connection_fd_own_count_ = 0; return *this; } SocketStream(SocketStream&& other) noexcept : connection_fd_(other.connection_fd_) { other.connection_fd_ = kInvalidFd; - connection_fd_own_count_ = other.connection_fd_own_count_; - other.connection_fd_own_count_ = 0; } SocketStream(const SocketStream&) = delete; SocketStream& operator=(const SocketStream&) = delete; @@ -74,20 +66,7 @@ class SocketStream : public NonSeekableReaderWriter { StatusWithSize DoRead(ByteSpan dest) override; - // Take ownership of the connection fd. There may be multiple owners. Each - // time TakeConnectionFd is called, ReleaseConnectionFd must be called to - // release ownership, even if the connection fd is invalid. - // - // Returns the connection fd. - int TakeConnectionFd(); - - // Release ownership of the connection fd. If no owners remain, close and - // clear the connection fd. - void ReleaseConnectionFd(); - - sync::Mutex connection_fd_mutex_; int connection_fd_ = kInvalidFd; - int connection_fd_own_count_ = 0; }; /// `ServerSocket` wraps a POSIX-style server socket, producing a `SocketStream` @@ -121,21 +100,8 @@ class ServerSocket { private: static constexpr int kInvalidFd = -1; - // Take ownership of the socket fd. There may be multiple owners. Each time - // TakeSocketFd is called, ReleaseReleaseFd must be called to release - // ownership, even if the socket fd is invalid. - // - // Returns the socket fd. - int TakeSocketFd(); - - // Release ownership of the socket fd. If no owners remain, close and clear - // the socket fd. - void ReleaseSocketFd(); - uint16_t port_ = -1; - sync::Mutex socket_fd_mutex_; int socket_fd_ = kInvalidFd; - int socket_fd_own_count_ = 0; }; } // namespace pw::stream diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc index 7ac374c22..b3125439c 100644 --- a/pw_stream/socket_stream.cc +++ b/pw_stream/socket_stream.cc @@ -96,8 +96,6 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) { for (rp = res; rp != nullptr; rp = rp->ai_next) { connection_fd_ = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (connection_fd_ != kInvalidFd) { - // Take ownership of the connection fd by this object. - TakeConnectionFd(); break; } } @@ -110,8 +108,8 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) { ConfigureSocket(connection_fd_); if (connect(connection_fd_, rp->ai_addr, rp->ai_addrlen) == -1) { - // Release ownership of the connection fd by this object. - ReleaseConnectionFd(); + close(connection_fd_); + connection_fd_ = kInvalidFd; PW_LOG_ERROR( "Failed to connect to %s:%d: %s", host, port, std::strerror(errno)); freeaddrinfo(res); @@ -123,15 +121,10 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) { } void SocketStream::Close() { - int fd = TakeConnectionFd(); - if (fd != kInvalidFd) { - // Shutdown the connection to cancel any blocking calls. - shutdown(fd, SHUT_RDWR); - - // Release ownership of the connection fd by this object. - ReleaseConnectionFd(); + if (connection_fd_ != kInvalidFd) { + close(connection_fd_); + connection_fd_ = kInvalidFd; } - ReleaseConnectionFd(); } Status SocketStream::DoWrite(span<const std::byte> data) { @@ -142,16 +135,10 @@ Status SocketStream::DoWrite(span<const std::byte> data) { send_flags |= MSG_NOSIGNAL; #endif // defined(__linux__) - int fd = TakeConnectionFd(); - if (fd == kInvalidFd) { - ReleaseConnectionFd(); - return Status::Unknown(); - } - ssize_t bytes_sent = send(fd, + ssize_t bytes_sent = send(connection_fd_, reinterpret_cast<const char*>(data.data()), data.size_bytes(), send_flags); - ReleaseConnectionFd(); if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) { if (errno == EPIPE) { @@ -166,14 +153,10 @@ Status SocketStream::DoWrite(span<const std::byte> data) { } StatusWithSize SocketStream::DoRead(ByteSpan dest) { - int fd = TakeConnectionFd(); - if (fd == kInvalidFd) { - ReleaseConnectionFd(); - return StatusWithSize::Unknown(); - } - ssize_t bytes_rcvd = - recv(fd, reinterpret_cast<char*>(dest.data()), dest.size_bytes(), 0); - ReleaseConnectionFd(); + ssize_t bytes_rcvd = recv(connection_fd_, + reinterpret_cast<char*>(dest.data()), + dest.size_bytes(), + 0); if (bytes_rcvd == 0) { // Remote peer has closed the connection. Close(); @@ -191,23 +174,6 @@ StatusWithSize SocketStream::DoRead(ByteSpan dest) { return StatusWithSize(bytes_rcvd); } -int SocketStream::TakeConnectionFd() { - std::lock_guard lock(connection_fd_mutex_); - int fd = connection_fd_; - ++connection_fd_own_count_; - return fd; -} - -void SocketStream::ReleaseConnectionFd() { - std::lock_guard lock(connection_fd_mutex_); - int fd = connection_fd_; - --connection_fd_own_count_; - if ((connection_fd_own_count_ <= 0) && (fd != kInvalidFd)) { - connection_fd_ = kInvalidFd; - close(fd); - } -} - // Listen for connections on the given port. // If port is 0, a random unused port is chosen and can be retrieved with // port(). @@ -216,8 +182,6 @@ Status ServerSocket::Listen(uint16_t port) { if (socket_fd_ == kInvalidFd) { return Status::Unknown(); } - // Take ownership of the socket fd by this object. - TakeSocketFd(); // Allow binding to an address that may still be in use by a closed socket. constexpr int value = 1; @@ -248,8 +212,7 @@ Status ServerSocket::Listen(uint16_t port) { if (getsockname(socket_fd_, reinterpret_cast<sockaddr*>(&addr), &addr_len) < 0 || static_cast<size_t>(addr_len) > sizeof(addr)) { - // Release ownership of the socket fd by this object. - ReleaseSocketFd(); + close(socket_fd_); return Status::Unknown(); } @@ -264,49 +227,23 @@ Result<SocketStream> ServerSocket::Accept() { struct sockaddr_in6 sockaddr_client_ = {}; socklen_t len = sizeof(sockaddr_client_); - int fd = TakeSocketFd(); - if (fd == kInvalidFd) { - ReleaseSocketFd(); - return Status::Unknown(); - } int connection_fd = - accept(fd, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len); - ReleaseSocketFd(); + accept(socket_fd_, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len); if (connection_fd == kInvalidFd) { return Status::Unknown(); } ConfigureSocket(connection_fd); - return SocketStream(connection_fd); + SocketStream client_stream; + client_stream.connection_fd_ = connection_fd; + return client_stream; } // Close the server socket, preventing further connections. void ServerSocket::Close() { - int fd = TakeSocketFd(); - if (fd != kInvalidFd) { - // Shutdown the connection to cancel any blocking calls. - shutdown(fd, SHUT_RDWR); - - // Release ownership of the socket fd by this object. - ReleaseSocketFd(); - } - ReleaseSocketFd(); -} - -int ServerSocket::TakeSocketFd() { - std::lock_guard lock(socket_fd_mutex_); - int fd = socket_fd_; - ++socket_fd_own_count_; - return fd; -} - -void ServerSocket::ReleaseSocketFd() { - std::lock_guard lock(socket_fd_mutex_); - int fd = socket_fd_; - --socket_fd_own_count_; - if ((socket_fd_own_count_ <= 0) && (fd != kInvalidFd)) { + if (socket_fd_ != kInvalidFd) { + close(socket_fd_); socket_fd_ = kInvalidFd; - close(fd); } } |