aboutsummaryrefslogtreecommitdiff
path: root/include/perfetto/base/unix_socket.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/perfetto/base/unix_socket.h')
-rw-r--r--include/perfetto/base/unix_socket.h332
1 files changed, 332 insertions, 0 deletions
diff --git a/include/perfetto/base/unix_socket.h b/include/perfetto/base/unix_socket.h
new file mode 100644
index 000000000..f2b8003b7
--- /dev/null
+++ b/include/perfetto/base/unix_socket.h
@@ -0,0 +1,332 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
+#define INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include <memory>
+#include <string>
+
+#include "perfetto/base/logging.h"
+#include "perfetto/base/scoped_file.h"
+#include "perfetto/base/utils.h"
+#include "perfetto/base/weak_ptr.h"
+
+struct msghdr;
+
+namespace perfetto {
+namespace base {
+
+class TaskRunner;
+
+// Use arbitrarily high values to avoid that some code accidentally ends up
+// assuming that these enum values match the sysroot's SOCK_xxx defines rather
+// than using GetUnixSockType().
+enum class SockType { kStream = 100, kDgram, kSeqPacket };
+
+// UnixSocketRaw is a basic wrapper around UNIX sockets. It exposes wrapper
+// methods that take care of most common pitfalls (e.g., marking fd as
+// O_CLOEXEC, avoiding SIGPIPE, properly handling partial writes). It is used as
+// a building block for the more sophisticated UnixSocket class.
+class UnixSocketRaw {
+ public:
+ // Creates a new unconnected unix socket.
+ static UnixSocketRaw CreateMayFail(SockType t) { return UnixSocketRaw(t); }
+
+ // Crates a pair of connected sockets.
+ static std::pair<UnixSocketRaw, UnixSocketRaw> CreatePair(SockType);
+
+ // Creates an uninitialized unix socket.
+ UnixSocketRaw();
+
+ // Creates a unix socket adopting an existing file descriptor. This is
+ // typically used to inherit fds from init via environment variables.
+ UnixSocketRaw(ScopedFile, SockType);
+
+ ~UnixSocketRaw() = default;
+ UnixSocketRaw(UnixSocketRaw&&) noexcept = default;
+ UnixSocketRaw& operator=(UnixSocketRaw&&) = default;
+
+ bool Bind(const std::string& socket_name);
+ bool Listen();
+ bool Connect(const std::string& socket_name);
+ bool SetTxTimeout(uint32_t timeout_ms);
+ bool SetRxTimeout(uint32_t timeout_ms);
+ void Shutdown();
+ void SetBlocking(bool);
+ bool IsBlocking() const;
+ void RetainOnExec();
+ SockType type() const { return type_; }
+ int fd() const { return *fd_; }
+ explicit operator bool() const { return !!fd_; }
+
+ ScopedFile ReleaseFd() { return std::move(fd_); }
+
+ ssize_t Send(const void* msg,
+ size_t len,
+ const int* send_fds = nullptr,
+ size_t num_fds = 0);
+
+ // Re-enter sendmsg until all the data has been sent or an error occurs.
+ // TODO(fmayer): Figure out how to do timeouts here for heapprofd.
+ ssize_t SendMsgAll(struct msghdr* msg);
+
+ ssize_t Receive(void* msg,
+ size_t len,
+ ScopedFile* fd_vec = nullptr,
+ size_t max_files = 0);
+
+ // Exposed for testing only.
+ // Update msghdr so subsequent sendmsg will send data that remains after n
+ // bytes have already been sent.
+ static void ShiftMsgHdr(size_t n, struct msghdr* msg);
+
+ private:
+ explicit UnixSocketRaw(SockType);
+
+ UnixSocketRaw(const UnixSocketRaw&) = delete;
+ UnixSocketRaw& operator=(const UnixSocketRaw&) = delete;
+
+ ScopedFile fd_;
+ SockType type_{SockType::kStream};
+};
+
+// A non-blocking UNIX domain socket. Allows also to transfer file descriptors.
+// None of the methods in this class are blocking.
+// The main design goal is making strong guarantees on the EventListener
+// callbacks, in order to avoid ending in some undefined state.
+// In case of any error it will aggressively just shut down the socket and
+// notify the failure with OnConnect(false) or OnDisconnect() depending on the
+// state of the socket (see below).
+// EventListener callbacks stop happening as soon as the instance is destroyed.
+//
+// Lifecycle of a client socket:
+//
+// Connect()
+// |
+// +------------------+------------------+
+// | (success) | (failure or Shutdown())
+// V V
+// OnConnect(true) OnConnect(false)
+// |
+// V
+// OnDataAvailable()
+// |
+// V
+// OnDisconnect() (failure or shutdown)
+//
+//
+// Lifecycle of a server socket:
+//
+// Listen() --> returns false in case of errors.
+// |
+// V
+// OnNewIncomingConnection(new_socket)
+//
+// (|new_socket| inherits the same EventListener)
+// |
+// V
+// OnDataAvailable()
+// | (failure or Shutdown())
+// V
+// OnDisconnect()
+class UnixSocket {
+ public:
+ class EventListener {
+ public:
+ virtual ~EventListener();
+
+ // After Listen().
+ virtual void OnNewIncomingConnection(
+ UnixSocket* self,
+ std::unique_ptr<UnixSocket> new_connection);
+
+ // After Connect(), whether successful or not.
+ virtual void OnConnect(UnixSocket* self, bool connected);
+
+ // After a successful Connect() or OnNewIncomingConnection(). Either the
+ // other endpoint did disconnect or some other error happened.
+ virtual void OnDisconnect(UnixSocket* self);
+
+ // Whenever there is data available to Receive(). Note that spurious FD
+ // watch events are possible, so it is possible that Receive() soon after
+ // OnDataAvailable() returns 0 (just ignore those).
+ virtual void OnDataAvailable(UnixSocket* self);
+ };
+
+ enum class State {
+ kDisconnected = 0, // Failed connection, peer disconnection or Shutdown().
+ kConnecting, // Soon after Connect(), before it either succeeds or fails.
+ kConnected, // After a successful Connect().
+ kListening // After Listen(), until Shutdown().
+ };
+
+ enum class BlockingMode { kNonBlocking, kBlocking };
+
+ // Creates a Unix domain socket and starts listening. If |socket_name|
+ // starts with a '@', an abstract socket will be created (Linux/Android only).
+ // Returns always an instance. In case of failure (e.g., another socket
+ // with the same name is already listening) the returned socket will have
+ // is_listening() == false and last_error() will contain the failure reason.
+ static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
+ EventListener*,
+ TaskRunner*,
+ SockType = SockType::kStream);
+
+ // Attaches to a pre-existing socket. The socket must have been created in
+ // SOCK_STREAM mode and the caller must have called bind() on it.
+ static std::unique_ptr<UnixSocket> Listen(ScopedFile,
+ EventListener*,
+ TaskRunner*,
+ SockType = SockType::kStream);
+
+ // Creates a Unix domain socket and connects to the listening endpoint.
+ // Returns always an instance. EventListener::OnConnect(bool success) will
+ // be called always, whether the connection succeeded or not.
+ static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
+ EventListener*,
+ TaskRunner*,
+ SockType = SockType::kStream);
+
+ // Constructs a UnixSocket using the given connected socket.
+ static std::unique_ptr<UnixSocket> AdoptConnected(
+ ScopedFile fd,
+ EventListener* event_listener,
+ TaskRunner* task_runner,
+ SockType sock_type);
+
+ UnixSocket(const UnixSocket&) = delete;
+ UnixSocket& operator=(const UnixSocket&) = delete;
+ // Cannot be easily moved because of tasks from the FileDescriptorWatch.
+ UnixSocket(UnixSocket&&) = delete;
+ UnixSocket& operator=(UnixSocket&&) = delete;
+
+ // This class gives the hard guarantee that no callback is called on the
+ // passed EventListener immediately after the object has been destroyed.
+ // Any queued callback will be silently dropped.
+ ~UnixSocket();
+
+ // Shuts down the current connection, if any. If the socket was Listen()-ing,
+ // stops listening. The socket goes back to kNotInitialized state, so it can
+ // be reused with Listen() or Connect().
+ void Shutdown(bool notify);
+
+ // Returns true is the message was queued, false if there was no space in the
+ // output buffer, in which case the client should retry or give up.
+ // If any other error happens the socket will be shutdown and
+ // EventListener::OnDisconnect() will be called.
+ // If the socket is not connected, Send() will just return false.
+ // Does not append a null string terminator to msg in any case.
+ //
+ // DO NOT PASS kNonBlocking, it is broken.
+ bool Send(const void* msg,
+ size_t len,
+ const int* send_fds,
+ size_t num_fds,
+ BlockingMode blocking = BlockingMode::kNonBlocking);
+
+ inline bool Send(const void* msg,
+ size_t len,
+ int send_fd = -1,
+ BlockingMode blocking = BlockingMode::kNonBlocking) {
+ if (send_fd != -1)
+ return Send(msg, len, &send_fd, 1, blocking);
+ return Send(msg, len, nullptr, 0, blocking);
+ }
+
+ inline bool Send(const std::string& msg,
+ BlockingMode blocking = BlockingMode::kNonBlocking) {
+ return Send(msg.c_str(), msg.size() + 1, -1, blocking);
+ }
+
+ // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
+ // is no data in the buffer to read or an error occurs (in which case a
+ // EventListener::OnDisconnect() will follow).
+ // If the ScopedFile pointer is not null and a FD is received, it moves the
+ // received FD into that. If a FD is received but the ScopedFile pointer is
+ // null, the FD will be automatically closed.
+ size_t Receive(void* msg, size_t len, ScopedFile*, size_t max_files = 1);
+
+ inline size_t Receive(void* msg, size_t len) {
+ return Receive(msg, len, nullptr, 0);
+ }
+
+ // Only for tests. This is slower than Receive() as it requires a heap
+ // allocation and a copy for the std::string. Guarantees that the returned
+ // string is null terminated even if the underlying message sent by the peer
+ // is not.
+ std::string ReceiveString(size_t max_length = 1024);
+
+ bool is_connected() const { return state_ == State::kConnected; }
+ bool is_listening() const { return state_ == State::kListening; }
+ int fd() const { return sock_raw_.fd(); }
+ int last_error() const { return last_error_; }
+
+ // User ID of the peer, as returned by the kernel. If the client disconnects
+ // and the socket goes into the kDisconnected state, it retains the uid of
+ // the last peer.
+ uid_t peer_uid() const {
+ PERFETTO_DCHECK(!is_listening() && peer_uid_ != kInvalidUid);
+ return peer_uid_;
+ }
+
+#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
+ PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
+ // Process ID of the peer, as returned by the kernel. If the client
+ // disconnects and the socket goes into the kDisconnected state, it
+ // retains the pid of the last peer.
+ //
+ // This is only available on Linux / Android.
+ pid_t peer_pid() const {
+ PERFETTO_DCHECK(!is_listening() && peer_pid_ != kInvalidPid);
+ return peer_pid_;
+ }
+#endif
+
+ // This makes the UnixSocket unusable.
+ UnixSocketRaw ReleaseSocket();
+
+ private:
+ UnixSocket(EventListener*, TaskRunner*, SockType);
+ UnixSocket(EventListener*, TaskRunner*, ScopedFile, State, SockType);
+
+ // Called once by the corresponding public static factory methods.
+ void DoConnect(const std::string& socket_name);
+ void ReadPeerCredentials();
+
+ void OnEvent();
+ void NotifyConnectionState(bool success);
+
+ UnixSocketRaw sock_raw_;
+ State state_ = State::kDisconnected;
+ int last_error_ = 0;
+ uid_t peer_uid_ = kInvalidUid;
+#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
+ PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
+ pid_t peer_pid_ = kInvalidPid;
+#endif
+ EventListener* const event_listener_;
+ TaskRunner* const task_runner_;
+ WeakPtrFactory<UnixSocket> weak_ptr_factory_; // Keep last.
+};
+
+} // namespace base
+} // namespace perfetto
+
+#endif // INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_