aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArmando Montanez <amontanez@google.com>2023-11-15 19:04:41 +0000
committerCQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com>2023-11-15 19:04:41 +0000
commit3587ce2d020e9336a440146430727fdb01d9d60f (patch)
treedcee2fd881f5b3cbdd0003a1e3a1dfa8f3350543
parent3c3dc003812d05bb31a47b93ff1df5666ad3db7f (diff)
downloadpigweed-3587ce2d020e9336a440146430727fdb01d9d60f.tar.gz
Revert "pw_rpc_transport: Close sockets when stopping"
This reverts commit 50d8c114135f26bd613baa99615e685a869b50b2. Reason for revert: b/309680612 races manifesting on macOS. Original change's description: > pw_rpc_transport: Close sockets when stopping > > Also, when closing sockets, disconnect the underlying sockets by using > the socket shutdown API. This unblocks socket recv and accept calls. > > Bug: 309680612 > Test: Verified socket unit tests pass. > Test: See details in testing done comment in the code review. > Change-Id: If9e54bdb88fa36a3f831d32b4549d5a222fd3df8 > Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/179591 > Reviewed-by: Carlos Chinchilla <cachinchilla@google.com> > Commit-Queue: Erik Staats <estaats@google.com> > Presubmit-Verified: CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com> # Not skipping CQ checks because original CL landed > 1 day ago. Bug: 309680612 Change-Id: Ie8f0bc3d6665c9c11c680a964bb6842b5d570c43 Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/181130 Pigweed-Auto-Submit: Rob Mohr <mohrr@google.com> Commit-Queue: Rob Mohr <mohrr@google.com> Reviewed-by: Rob Mohr <mohrr@google.com> Reviewed-by: Armando Montanez <amontanez@google.com> Reviewed-by: Erik Staats <estaats@google.com>
-rw-r--r--pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h15
-rw-r--r--pw_rpc_transport/rpc_integration_test.cc6
-rw-r--r--pw_rpc_transport/socket_rpc_transport_test.cc16
-rw-r--r--pw_stream/BUILD.bazel1
-rw-r--r--pw_stream/BUILD.gn5
-rw-r--r--pw_stream/CMakeLists.txt1
-rw-r--r--pw_stream/public/pw_stream/socket_stream.h38
-rw-r--r--pw_stream/socket_stream.cc97
8 files changed, 43 insertions, 136 deletions
diff --git a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h
index 3d4121471..e3a84c5ea 100644
--- a/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h
+++ b/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h
@@ -106,10 +106,6 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore {
while (!stopped_) {
const auto read_status = ReadData();
- // Break if ReadData was cancelled after the transport was stopped.
- if (stopped_) {
- break;
- }
if (!read_status.ok()) {
internal::LogSocketReadError(read_status);
}
@@ -126,11 +122,7 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore {
}
}
- void Stop() {
- stopped_ = true;
- socket_stream_.Close();
- server_socket_.Close();
- }
+ void Stop() { stopped_ = true; }
private:
enum class ClientServerRole { kClient, kServer };
@@ -164,11 +156,6 @@ class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore {
NotifyReady();
Result<stream::SocketStream> stream = server_socket_.Accept();
- // If Accept was cancelled due to stopping the transport, return without
- // error.
- if (stopped_) {
- return OkStatus();
- }
if (!stream.ok()) {
internal::LogSocketAcceptError(stream.status());
return stream.status();
diff --git a/pw_rpc_transport/rpc_integration_test.cc b/pw_rpc_transport/rpc_integration_test.cc
index 633ff848f..c48672d3e 100644
--- a/pw_rpc_transport/rpc_integration_test.cc
+++ b/pw_rpc_transport/rpc_integration_test.cc
@@ -120,6 +120,12 @@ TEST(RpcIntegrationTest, SocketTransport) {
a.transport.Stop();
b.transport.Stop();
+ // Unblock socket transports by sending terminator packets.
+ const std::array<std::byte, 1> terminator_bytes{std::byte{0x42}};
+ RpcFrame terminator{.header = {}, .payload = terminator_bytes};
+ EXPECT_EQ(a.transport.Send(terminator), OkStatus());
+ EXPECT_EQ(b.transport.Send(terminator), OkStatus());
+
a_local_egress_thread.join();
b_local_egress_thread.join();
a_transport_thread.join();
diff --git a/pw_rpc_transport/socket_rpc_transport_test.cc b/pw_rpc_transport/socket_rpc_transport_test.cc
index 8f128726e..12c0d845c 100644
--- a/pw_rpc_transport/socket_rpc_transport_test.cc
+++ b/pw_rpc_transport/socket_rpc_transport_test.cc
@@ -115,12 +115,18 @@ class SocketSender {
}
}
+ // stream::SocketStream doesn't support read timeouts so we have to
+ // unblock socket reads by sending more data after the transport is stopped.
+ pw::Status Terminate() { return transport_.Send(terminator_); }
+
private:
SocketRpcTransport<kReadBufferSize>& transport_;
std::vector<std::byte> sent_;
std::array<std::byte, 256> data_{};
std::uniform_int_distribution<size_t> offset_dist_{0, 255};
std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize};
+ std::array<std::byte, 1> terminator_bytes_{std::byte{0x42}};
+ RpcFrame terminator_{.header = {}, .payload = terminator_bytes_};
};
class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore {
@@ -176,6 +182,10 @@ TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) {
server.Stop();
client.Stop();
+ // Unblock socket reads to propagate the stop signal.
+ EXPECT_EQ(server_sender.Terminate(), OkStatus());
+ EXPECT_EQ(client_sender.Terminate(), OkStatus());
+
server_thread.join();
client_thread.join();
@@ -232,6 +242,7 @@ TEST(SocketRpcTransportTest, ServerReconnects) {
// Stop the client but not the server: we're re-using the same server
// with a new client below.
client.Stop();
+ EXPECT_EQ(server_sender.Terminate(), OkStatus());
client_thread.join();
}
@@ -256,11 +267,13 @@ TEST(SocketRpcTransportTest, ServerReconnects) {
std::back_inserter(received));
client.Stop();
+ EXPECT_EQ(server_sender.Terminate(), OkStatus());
client_thread.join();
// This time stop the server as well.
SocketSender client_sender(client);
server.Stop();
+ EXPECT_EQ(client_sender.Terminate(), OkStatus());
server_thread.join();
}
@@ -309,6 +322,7 @@ TEST(SocketRpcTransportTest, ClientReconnects) {
server1_sent.end(),
std::back_inserter(sent_by_server));
+ EXPECT_EQ(client_sender.Terminate(), OkStatus());
server_thread.join();
server = nullptr;
@@ -331,9 +345,11 @@ TEST(SocketRpcTransportTest, ClientReconnects) {
server2_sent.end(),
std::back_inserter(sent_by_server));
+ EXPECT_EQ(client_sender.Terminate(), OkStatus());
server_thread.join();
client.Stop();
+ EXPECT_EQ(server2_sender.Terminate(), OkStatus());
client_thread.join();
server = nullptr;
diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel
index 733878ad0..1d40bfd84 100644
--- a/pw_stream/BUILD.bazel
+++ b/pw_stream/BUILD.bazel
@@ -52,7 +52,6 @@ pw_cc_library(
":pw_stream",
"//pw_log",
"//pw_string",
- "//pw_sync:mutex",
"//pw_sys_io",
],
)
diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn
index 45357a0ec..a55fb91a0 100644
--- a/pw_stream/BUILD.gn
+++ b/pw_stream/BUILD.gn
@@ -48,10 +48,7 @@ pw_source_set("pw_stream") {
pw_source_set("socket_stream") {
public_configs = [ ":public_include_path" ]
- public_deps = [
- ":pw_stream",
- "$dir_pw_sync:mutex",
- ]
+ public_deps = [ ":pw_stream" ]
deps = [
dir_pw_assert,
dir_pw_log,
diff --git a/pw_stream/CMakeLists.txt b/pw_stream/CMakeLists.txt
index 623ac4720..b04b440af 100644
--- a/pw_stream/CMakeLists.txt
+++ b/pw_stream/CMakeLists.txt
@@ -40,7 +40,6 @@ pw_add_library(pw_stream.socket_stream STATIC
public
PUBLIC_DEPS
pw_stream
- pw_sync.mutex
SOURCES
socket_stream.cc
PRIVATE_DEPS
diff --git a/pw_stream/public/pw_stream/socket_stream.h b/pw_stream/public/pw_stream/socket_stream.h
index e80739ea0..a9b7b16a6 100644
--- a/pw_stream/public/pw_stream/socket_stream.h
+++ b/pw_stream/public/pw_stream/socket_stream.h
@@ -18,32 +18,24 @@
#include "pw_result/result.h"
#include "pw_span/span.h"
#include "pw_stream/stream.h"
-#include "pw_sync/mutex.h"
namespace pw::stream {
class SocketStream : public NonSeekableReaderWriter {
public:
- SocketStream() = default;
+ constexpr SocketStream() = default;
// Construct a SocketStream directly from a file descriptor.
- explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) {
- // Take ownership of the connection fd by this object.
- TakeConnectionFd();
- }
+ explicit SocketStream(int connection_fd) : connection_fd_(connection_fd) {}
// SocketStream objects are moveable but not copyable.
SocketStream& operator=(SocketStream&& other) {
connection_fd_ = other.connection_fd_;
other.connection_fd_ = kInvalidFd;
- connection_fd_own_count_ = other.connection_fd_own_count_;
- other.connection_fd_own_count_ = 0;
return *this;
}
SocketStream(SocketStream&& other) noexcept
: connection_fd_(other.connection_fd_) {
other.connection_fd_ = kInvalidFd;
- connection_fd_own_count_ = other.connection_fd_own_count_;
- other.connection_fd_own_count_ = 0;
}
SocketStream(const SocketStream&) = delete;
SocketStream& operator=(const SocketStream&) = delete;
@@ -74,20 +66,7 @@ class SocketStream : public NonSeekableReaderWriter {
StatusWithSize DoRead(ByteSpan dest) override;
- // Take ownership of the connection fd. There may be multiple owners. Each
- // time TakeConnectionFd is called, ReleaseConnectionFd must be called to
- // release ownership, even if the connection fd is invalid.
- //
- // Returns the connection fd.
- int TakeConnectionFd();
-
- // Release ownership of the connection fd. If no owners remain, close and
- // clear the connection fd.
- void ReleaseConnectionFd();
-
- sync::Mutex connection_fd_mutex_;
int connection_fd_ = kInvalidFd;
- int connection_fd_own_count_ = 0;
};
/// `ServerSocket` wraps a POSIX-style server socket, producing a `SocketStream`
@@ -121,21 +100,8 @@ class ServerSocket {
private:
static constexpr int kInvalidFd = -1;
- // Take ownership of the socket fd. There may be multiple owners. Each time
- // TakeSocketFd is called, ReleaseReleaseFd must be called to release
- // ownership, even if the socket fd is invalid.
- //
- // Returns the socket fd.
- int TakeSocketFd();
-
- // Release ownership of the socket fd. If no owners remain, close and clear
- // the socket fd.
- void ReleaseSocketFd();
-
uint16_t port_ = -1;
- sync::Mutex socket_fd_mutex_;
int socket_fd_ = kInvalidFd;
- int socket_fd_own_count_ = 0;
};
} // namespace pw::stream
diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc
index 7ac374c22..b3125439c 100644
--- a/pw_stream/socket_stream.cc
+++ b/pw_stream/socket_stream.cc
@@ -96,8 +96,6 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
for (rp = res; rp != nullptr; rp = rp->ai_next) {
connection_fd_ = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (connection_fd_ != kInvalidFd) {
- // Take ownership of the connection fd by this object.
- TakeConnectionFd();
break;
}
}
@@ -110,8 +108,8 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
ConfigureSocket(connection_fd_);
if (connect(connection_fd_, rp->ai_addr, rp->ai_addrlen) == -1) {
- // Release ownership of the connection fd by this object.
- ReleaseConnectionFd();
+ close(connection_fd_);
+ connection_fd_ = kInvalidFd;
PW_LOG_ERROR(
"Failed to connect to %s:%d: %s", host, port, std::strerror(errno));
freeaddrinfo(res);
@@ -123,15 +121,10 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
}
void SocketStream::Close() {
- int fd = TakeConnectionFd();
- if (fd != kInvalidFd) {
- // Shutdown the connection to cancel any blocking calls.
- shutdown(fd, SHUT_RDWR);
-
- // Release ownership of the connection fd by this object.
- ReleaseConnectionFd();
+ if (connection_fd_ != kInvalidFd) {
+ close(connection_fd_);
+ connection_fd_ = kInvalidFd;
}
- ReleaseConnectionFd();
}
Status SocketStream::DoWrite(span<const std::byte> data) {
@@ -142,16 +135,10 @@ Status SocketStream::DoWrite(span<const std::byte> data) {
send_flags |= MSG_NOSIGNAL;
#endif // defined(__linux__)
- int fd = TakeConnectionFd();
- if (fd == kInvalidFd) {
- ReleaseConnectionFd();
- return Status::Unknown();
- }
- ssize_t bytes_sent = send(fd,
+ ssize_t bytes_sent = send(connection_fd_,
reinterpret_cast<const char*>(data.data()),
data.size_bytes(),
send_flags);
- ReleaseConnectionFd();
if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) {
if (errno == EPIPE) {
@@ -166,14 +153,10 @@ Status SocketStream::DoWrite(span<const std::byte> data) {
}
StatusWithSize SocketStream::DoRead(ByteSpan dest) {
- int fd = TakeConnectionFd();
- if (fd == kInvalidFd) {
- ReleaseConnectionFd();
- return StatusWithSize::Unknown();
- }
- ssize_t bytes_rcvd =
- recv(fd, reinterpret_cast<char*>(dest.data()), dest.size_bytes(), 0);
- ReleaseConnectionFd();
+ ssize_t bytes_rcvd = recv(connection_fd_,
+ reinterpret_cast<char*>(dest.data()),
+ dest.size_bytes(),
+ 0);
if (bytes_rcvd == 0) {
// Remote peer has closed the connection.
Close();
@@ -191,23 +174,6 @@ StatusWithSize SocketStream::DoRead(ByteSpan dest) {
return StatusWithSize(bytes_rcvd);
}
-int SocketStream::TakeConnectionFd() {
- std::lock_guard lock(connection_fd_mutex_);
- int fd = connection_fd_;
- ++connection_fd_own_count_;
- return fd;
-}
-
-void SocketStream::ReleaseConnectionFd() {
- std::lock_guard lock(connection_fd_mutex_);
- int fd = connection_fd_;
- --connection_fd_own_count_;
- if ((connection_fd_own_count_ <= 0) && (fd != kInvalidFd)) {
- connection_fd_ = kInvalidFd;
- close(fd);
- }
-}
-
// Listen for connections on the given port.
// If port is 0, a random unused port is chosen and can be retrieved with
// port().
@@ -216,8 +182,6 @@ Status ServerSocket::Listen(uint16_t port) {
if (socket_fd_ == kInvalidFd) {
return Status::Unknown();
}
- // Take ownership of the socket fd by this object.
- TakeSocketFd();
// Allow binding to an address that may still be in use by a closed socket.
constexpr int value = 1;
@@ -248,8 +212,7 @@ Status ServerSocket::Listen(uint16_t port) {
if (getsockname(socket_fd_, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
0 ||
static_cast<size_t>(addr_len) > sizeof(addr)) {
- // Release ownership of the socket fd by this object.
- ReleaseSocketFd();
+ close(socket_fd_);
return Status::Unknown();
}
@@ -264,49 +227,23 @@ Result<SocketStream> ServerSocket::Accept() {
struct sockaddr_in6 sockaddr_client_ = {};
socklen_t len = sizeof(sockaddr_client_);
- int fd = TakeSocketFd();
- if (fd == kInvalidFd) {
- ReleaseSocketFd();
- return Status::Unknown();
- }
int connection_fd =
- accept(fd, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
- ReleaseSocketFd();
+ accept(socket_fd_, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
if (connection_fd == kInvalidFd) {
return Status::Unknown();
}
ConfigureSocket(connection_fd);
- return SocketStream(connection_fd);
+ SocketStream client_stream;
+ client_stream.connection_fd_ = connection_fd;
+ return client_stream;
}
// Close the server socket, preventing further connections.
void ServerSocket::Close() {
- int fd = TakeSocketFd();
- if (fd != kInvalidFd) {
- // Shutdown the connection to cancel any blocking calls.
- shutdown(fd, SHUT_RDWR);
-
- // Release ownership of the socket fd by this object.
- ReleaseSocketFd();
- }
- ReleaseSocketFd();
-}
-
-int ServerSocket::TakeSocketFd() {
- std::lock_guard lock(socket_fd_mutex_);
- int fd = socket_fd_;
- ++socket_fd_own_count_;
- return fd;
-}
-
-void ServerSocket::ReleaseSocketFd() {
- std::lock_guard lock(socket_fd_mutex_);
- int fd = socket_fd_;
- --socket_fd_own_count_;
- if ((socket_fd_own_count_ <= 0) && (fd != kInvalidFd)) {
+ if (socket_fd_ != kInvalidFd) {
+ close(socket_fd_);
socket_fd_ = kInvalidFd;
- close(fd);
}
}