aboutsummaryrefslogtreecommitdiff
path: root/include/perfetto/base/unix_socket.h
blob: f2b8003b791baebac2d02f5dbd944bba5ee08cd5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
#define INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_

#include <stdint.h>
#include <sys/types.h>

#include <memory>
#include <string>

#include "perfetto/base/logging.h"
#include "perfetto/base/scoped_file.h"
#include "perfetto/base/utils.h"
#include "perfetto/base/weak_ptr.h"

struct msghdr;

namespace perfetto {
namespace base {

class TaskRunner;

// The numeric values are intentionally arbitrary and high: this prevents code
// from accidentally assuming that these enumerators match the sysroot's
// SOCK_xxx defines instead of translating them through GetUnixSockType().
enum class SockType {
  kStream = 100,
  kDgram,
  kSeqPacket,
};

// UnixSocketRaw is a basic wrapper around UNIX sockets. It exposes wrapper
// methods that take care of most common pitfalls (e.g., marking fd as
// O_CLOEXEC, avoiding SIGPIPE, properly handling partial writes). It is used as
// a building block for the more sophisticated UnixSocket class.
class UnixSocketRaw {
 public:
  // Creates a new unconnected unix socket of type |t|.
  // NOTE(review): as the name suggests creation can fail; presumably the
  // returned object then evaluates to false (see operator bool below) --
  // confirm against the implementation.
  static UnixSocketRaw CreateMayFail(SockType t) { return UnixSocketRaw(t); }

  // Creates a pair of connected sockets.
  static std::pair<UnixSocketRaw, UnixSocketRaw> CreatePair(SockType);

  // Creates an uninitialized unix socket.
  UnixSocketRaw();

  // Creates a unix socket adopting an existing file descriptor. This is
  // typically used to inherit fds from init via environment variables.
  UnixSocketRaw(ScopedFile, SockType);

  // The type is move-only: moves are defaulted here, while the copy
  // constructor and copy assignment are deleted in the private section.
  ~UnixSocketRaw() = default;
  UnixSocketRaw(UnixSocketRaw&&) noexcept = default;
  UnixSocketRaw& operator=(UnixSocketRaw&&) = default;

  // NOTE(review): the bool-returning methods below presumably report success
  // with true -- confirm against the implementation.
  bool Bind(const std::string& socket_name);
  bool Listen();
  bool Connect(const std::string& socket_name);
  bool SetTxTimeout(uint32_t timeout_ms);
  bool SetRxTimeout(uint32_t timeout_ms);
  void Shutdown();
  void SetBlocking(bool);
  bool IsBlocking() const;
  void RetainOnExec();

  // Returns the stored socket type.
  SockType type() const { return type_; }

  // Returns the file descriptor held by |fd_|. The descriptor stays owned by
  // this object.
  int fd() const { return *fd_; }

  // Forwards to the boolean state of the underlying ScopedFile.
  explicit operator bool() const { return !!fd_; }

  // Hands the underlying ScopedFile over to the caller by moving it out of
  // this object.
  ScopedFile ReleaseFd() { return std::move(fd_); }

  // Sends |len| bytes starting at |msg|. |send_fds| / |num_fds| optionally
  // describe file descriptors to pass along; by default none are sent.
  // NOTE(review): the return value presumably follows the sendmsg()
  // convention (bytes sent, or -1 on error) -- confirm.
  ssize_t Send(const void* msg,
               size_t len,
               const int* send_fds = nullptr,
               size_t num_fds = 0);

  // Re-enter sendmsg until all the data has been sent or an error occurs.
  // TODO(fmayer): Figure out how to do timeouts here for heapprofd.
  ssize_t SendMsgAll(struct msghdr* msg);

  // Receives up to |len| bytes into |msg|. |fd_vec| / |max_files| optionally
  // describe where received file descriptors are stored; by default no
  // storage for them is provided.
  // NOTE(review): the return value presumably follows the recvmsg()
  // convention -- confirm.
  ssize_t Receive(void* msg,
                  size_t len,
                  ScopedFile* fd_vec = nullptr,
                  size_t max_files = 0);

  // Exposed for testing only.
  // Update msghdr so subsequent sendmsg will send data that remains after n
  // bytes have already been sent.
  static void ShiftMsgHdr(size_t n, struct msghdr* msg);

 private:
  // Used by CreateMayFail().
  explicit UnixSocketRaw(SockType);

  UnixSocketRaw(const UnixSocketRaw&) = delete;
  UnixSocketRaw& operator=(const UnixSocketRaw&) = delete;

  ScopedFile fd_;
  // Defaults to kStream unless a constructor sets it otherwise.
  SockType type_{SockType::kStream};
};

// A non-blocking UNIX domain socket that can also transfer file descriptors.
// None of the methods in this class are blocking.
// The main design goal is making strong guarantees on the EventListener
// callbacks, in order to avoid ending in some undefined state.
// In case of any error it will aggressively just shut down the socket and
// notify the failure with OnConnect(false) or OnDisconnect() depending on the
// state of the socket (see below).
// EventListener callbacks stop happening as soon as the instance is destroyed.
//
// Lifecycle of a client socket:
//
//                           Connect()
//                               |
//            +------------------+------------------+
//            | (success)                           | (failure or Shutdown())
//            V                                     V
//     OnConnect(true)                         OnConnect(false)
//            |
//            V
//    OnDataAvailable()
//            |
//            V
//     OnDisconnect()  (failure or shutdown)
//
//
// Lifecycle of a server socket:
//
//                          Listen()  --> returns false in case of errors.
//                             |
//                             V
//              OnNewIncomingConnection(new_socket)
//
//          (|new_socket| inherits the same EventListener)
//                             |
//                             V
//                     OnDataAvailable()
//                             | (failure or Shutdown())
//                             V
//                       OnDisconnect()
class UnixSocket {
 public:
  // Callback interface through which a UnixSocket reports its events. None of
  // the methods is pure virtual, so a listener may override only the callbacks
  // it cares about.
  class EventListener {
   public:
    virtual ~EventListener();

    // After Listen(). Ownership of |new_connection| is handed over to the
    // listener (it is passed as a std::unique_ptr by value).
    virtual void OnNewIncomingConnection(
        UnixSocket* self,
        std::unique_ptr<UnixSocket> new_connection);

    // After Connect(), whether successful or not.
    virtual void OnConnect(UnixSocket* self, bool connected);

    // After a successful Connect() or OnNewIncomingConnection(). Either the
    // other endpoint did disconnect or some other error happened.
    virtual void OnDisconnect(UnixSocket* self);

    // Whenever there is data available to Receive(). Note that spurious FD
    // watch events are possible, so it is possible that Receive() soon after
    // OnDataAvailable() returns 0 (just ignore those).
    virtual void OnDataAvailable(UnixSocket* self);
  };

  enum class State {
    kDisconnected = 0,  // Failed connection, peer disconnection or Shutdown().
    kConnecting,  // Soon after Connect(), before it either succeeds or fails.
    kConnected,   // After a successful Connect().
    kListening    // After Listen(), until Shutdown().
  };

  // Selects the blocking behavior of Send(); see the warning on Send() below.
  enum class BlockingMode { kNonBlocking, kBlocking };

  // Creates a Unix domain socket and starts listening. If |socket_name|
  // starts with a '@', an abstract socket will be created (Linux/Android only).
  // Always returns an instance. In case of failure (e.g., another socket
  // with the same name is already listening) the returned socket will have
  // is_listening() == false and last_error() will contain the failure reason.
  static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
                                            EventListener*,
                                            TaskRunner*,
                                            SockType = SockType::kStream);

  // Attaches to a pre-existing socket. The socket must have been created in
  // SOCK_STREAM mode and the caller must have called bind() on it.
  // NOTE(review): this overload also takes a SockType argument; confirm
  // whether the SOCK_STREAM requirement above still holds for other types.
  static std::unique_ptr<UnixSocket> Listen(ScopedFile,
                                            EventListener*,
                                            TaskRunner*,
                                            SockType = SockType::kStream);

  // Creates a Unix domain socket and connects to the listening endpoint.
  // Always returns an instance. EventListener::OnConnect(bool success) will
  // be called always, whether the connection succeeded or not.
  static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
                                             EventListener*,
                                             TaskRunner*,
                                             SockType = SockType::kStream);

  // Constructs a UnixSocket using the given connected socket.
  static std::unique_ptr<UnixSocket> AdoptConnected(
      ScopedFile fd,
      EventListener* event_listener,
      TaskRunner* task_runner,
      SockType sock_type);

  UnixSocket(const UnixSocket&) = delete;
  UnixSocket& operator=(const UnixSocket&) = delete;
  // Cannot be easily moved because of tasks from the FileDescriptorWatch.
  UnixSocket(UnixSocket&&) = delete;
  UnixSocket& operator=(UnixSocket&&) = delete;

  // This class gives the hard guarantee that no callback is called on the
  // passed EventListener immediately after the object has been destroyed.
  // Any queued callback will be silently dropped.
  ~UnixSocket();

  // Shuts down the current connection, if any. If the socket was Listen()-ing,
  // stops listening. The socket goes back to the kDisconnected state (see
  // State above).
  // NOTE(review): |notify| presumably controls whether the EventListener is
  // informed about the shutdown -- confirm against the implementation.
  void Shutdown(bool notify);

  // Returns true if the message was queued, false if there was no space in the
  // output buffer, in which case the client should retry or give up.
  // If any other error happens the socket will be shutdown and
  // EventListener::OnDisconnect() will be called.
  // If the socket is not connected, Send() will just return false.
  // Does not append a null string terminator to msg in any case.
  //
  // DO NOT PASS kNonBlocking, it is broken.
  // NOTE(review): kNonBlocking is nevertheless the default value of
  // |blocking| in all the Send() overloads below, so callers that want to
  // honor the warning above have to pass kBlocking explicitly.
  bool Send(const void* msg,
            size_t len,
            const int* send_fds,
            size_t num_fds,
            BlockingMode blocking = BlockingMode::kNonBlocking);

  // Convenience overload that attaches at most one file descriptor. A
  // |send_fd| of -1 (the default) means that no descriptor is sent.
  inline bool Send(const void* msg,
                   size_t len,
                   int send_fd = -1,
                   BlockingMode blocking = BlockingMode::kNonBlocking) {
    if (send_fd != -1)
      return Send(msg, len, &send_fd, 1, blocking);
    return Send(msg, len, nullptr, 0, blocking);
  }

  // Convenience overload for strings. Sends msg.size() + 1 bytes, i.e. the
  // characters of |msg| followed by the terminating null byte of c_str(). No
  // file descriptor is sent.
  inline bool Send(const std::string& msg,
                   BlockingMode blocking = BlockingMode::kNonBlocking) {
    return Send(msg.c_str(), msg.size() + 1, -1, blocking);
  }

  // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
  // is no data in the buffer to read or an error occurs (in which case a
  // EventListener::OnDisconnect() will follow).
  // If the ScopedFile pointer is not null and a FD is received, it moves the
  // received FD into that. If a FD is received but the ScopedFile pointer is
  // null, the FD will be automatically closed.
  size_t Receive(void* msg, size_t len, ScopedFile*, size_t max_files = 1);

  // Convenience overload for callers that do not want to receive FDs: it
  // forwards a null ScopedFile pointer and max_files == 0.
  inline size_t Receive(void* msg, size_t len) {
    return Receive(msg, len, nullptr, 0);
  }

  // Only for tests. This is slower than Receive() as it requires a heap
  // allocation and a copy for the std::string. Guarantees that the returned
  // string is null terminated even if the underlying message sent by the peer
  // is not.
  std::string ReceiveString(size_t max_length = 1024);

  // State accessors. Both are false in the kDisconnected and kConnecting
  // states.
  bool is_connected() const { return state_ == State::kConnected; }
  bool is_listening() const { return state_ == State::kListening; }
  // File descriptor of the underlying UnixSocketRaw.
  int fd() const { return sock_raw_.fd(); }
  // Failure reason of the last failed operation (see Listen() above).
  int last_error() const { return last_error_; }

  // User ID of the peer, as returned by the kernel. If the client disconnects
  // and the socket goes into the kDisconnected state, it retains the uid of
  // the last peer.
  // Must not be called on a listening socket or while the uid is still
  // kInvalidUid (both are DCHECK-ed).
  uid_t peer_uid() const {
    PERFETTO_DCHECK(!is_listening() && peer_uid_ != kInvalidUid);
    return peer_uid_;
  }

#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
  // Process ID of the peer, as returned by the kernel. If the client
  // disconnects and the socket goes into the kDisconnected state, it
  // retains the pid of the last peer.
  //
  // This is only available on Linux / Android.
  // Must not be called on a listening socket or while the pid is still
  // kInvalidPid (both are DCHECK-ed).
  pid_t peer_pid() const {
    PERFETTO_DCHECK(!is_listening() && peer_pid_ != kInvalidPid);
    return peer_pid_;
  }
#endif

  // This makes the UnixSocket unusable.
  UnixSocketRaw ReleaseSocket();

 private:
  // Instances are only created through the public static factory methods.
  UnixSocket(EventListener*, TaskRunner*, SockType);
  UnixSocket(EventListener*, TaskRunner*, ScopedFile, State, SockType);

  // Called once by the corresponding public static factory methods.
  void DoConnect(const std::string& socket_name);
  void ReadPeerCredentials();

  void OnEvent();
  void NotifyConnectionState(bool success);

  // The underlying raw socket; fd() forwards to it.
  UnixSocketRaw sock_raw_;
  State state_ = State::kDisconnected;
  // Value reported by last_error().
  int last_error_ = 0;
  // Peer credentials start out as invalid; see peer_uid() / peer_pid().
  uid_t peer_uid_ = kInvalidUid;
#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
  pid_t peer_pid_ = kInvalidPid;
#endif
  // Both pointers are const: they cannot be re-pointed after construction.
  EventListener* const event_listener_;
  TaskRunner* const task_runner_;
  WeakPtrFactory<UnixSocket> weak_ptr_factory_;  // Keep last.
};

}  // namespace base
}  // namespace perfetto

#endif  // INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_