aboutsummaryrefslogtreecommitdiff
path: root/src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.cc
blob: 7a86440cd7efa2eb2180978967b3484af300f511 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
//
//
// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//

#include <grpc/support/port_platform.h>

#include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"

#include <string.h>

#include <algorithm>
#include <memory>
#include <utility>

#include <grpc/support/alloc.h>
#include <grpc/support/log.h>

#include "src/core/tsi/alts/crypt/gsec.h"
#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.h"
#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.h"
#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol.h"
#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h"
#include "src/core/tsi/transport_security_grpc.h"

constexpr size_t kMinFrameLength = 1024;
constexpr size_t kDefaultFrameLength = 16 * 1024;
constexpr size_t kMaxFrameLength = 16 * 1024 * 1024;

///
/// Main struct for alts_zero_copy_grpc_protector.
/// We choose to have two alts_grpc_record_protocol objects and two sets of
/// slice buffers: one for protect and the other for unprotect, so that protect
/// and unprotect can be executed in parallel. Implementations of this object
/// must be thread compatible.
///
typedef struct alts_zero_copy_grpc_protector {
  tsi_zero_copy_grpc_protector base;
  alts_grpc_record_protocol* record_protocol;
  alts_grpc_record_protocol* unrecord_protocol;
  size_t max_protected_frame_size;
  size_t max_unprotected_data_size;
  grpc_slice_buffer unprotected_staging_sb;
  grpc_slice_buffer protected_sb;
  grpc_slice_buffer protected_staging_sb;
  uint32_t parsed_frame_size;
} alts_zero_copy_grpc_protector;

///
/// Given a slice buffer, parses the first 4 bytes little-endian unsigned frame
/// size and returns the total frame size including the frame field. Caller
/// needs to make sure the input slice buffer has at least 4 bytes. Returns true
/// on success and false on failure.
///
static bool read_frame_size(const grpc_slice_buffer* sb,
                            uint32_t* total_frame_size) {
  if (sb == nullptr || sb->length < kZeroCopyFrameLengthFieldSize) {
    return false;
  }
  uint8_t frame_size_buffer[kZeroCopyFrameLengthFieldSize];
  uint8_t* buf = frame_size_buffer;
  // Copies the first 4 bytes to a temporary buffer.
  size_t remaining = kZeroCopyFrameLengthFieldSize;
  for (size_t i = 0; i < sb->count; i++) {
    size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]);
    if (remaining <= slice_length) {
      memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining);
      remaining = 0;
      break;
    } else {
      memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length);
      buf += slice_length;
      remaining -= slice_length;
    }
  }
  GPR_ASSERT(remaining == 0);
  // Gets little-endian frame size.
  uint32_t frame_size = (static_cast<uint32_t>(frame_size_buffer[3]) << 24) |
                        (static_cast<uint32_t>(frame_size_buffer[2]) << 16) |
                        (static_cast<uint32_t>(frame_size_buffer[1]) << 8) |
                        static_cast<uint32_t>(frame_size_buffer[0]);
  if (frame_size > kMaxFrameLength) {
    gpr_log(GPR_ERROR, "Frame size is larger than maximum frame size");
    return false;
  }
  // Returns frame size including frame length field.
  *total_frame_size =
      static_cast<uint32_t>(frame_size + kZeroCopyFrameLengthFieldSize);
  return true;
}

///
/// Creates an alts_grpc_record_protocol object, given key, key size, and flags
/// to indicate whether the record_protocol object uses the rekeying AEAD,
/// whether the object is for client or server, whether the object is for
/// integrity-only or privacy-integrity mode, and whether the object is used
/// for protect or unprotect.
///
static tsi_result create_alts_grpc_record_protocol(
    std::unique_ptr<grpc_core::GsecKeyInterface> key, bool is_client,
    bool is_integrity_only, bool is_protect, bool enable_extra_copy,
    alts_grpc_record_protocol** record_protocol) {
  if (key == nullptr || record_protocol == nullptr) {
    return TSI_INVALID_ARGUMENT;
  }
  grpc_status_code status;
  gsec_aead_crypter* crypter = nullptr;
  char* error_details = nullptr;
  bool is_rekey = key->IsRekey();
  status = gsec_aes_gcm_aead_crypter_create(std::move(key), kAesGcmNonceLength,
                                            kAesGcmTagLength, &crypter,
                                            &error_details);
  if (status != GRPC_STATUS_OK) {
    gpr_log(GPR_ERROR, "Failed to create AEAD crypter, %s", error_details);
    gpr_free(error_details);
    return TSI_INTERNAL_ERROR;
  }
  size_t overflow_limit = is_rekey ? kAltsRecordProtocolRekeyFrameLimit
                                   : kAltsRecordProtocolFrameLimit;
  // Creates alts_grpc_record_protocol with AEAD crypter ownership transferred.
  //
  tsi_result result = is_integrity_only
                          ? alts_grpc_integrity_only_record_protocol_create(
                                crypter, overflow_limit, is_client, is_protect,
                                enable_extra_copy, record_protocol)
                          : alts_grpc_privacy_integrity_record_protocol_create(
                                crypter, overflow_limit, is_client, is_protect,
                                record_protocol);
  if (result != TSI_OK) {
    gsec_aead_crypter_destroy(crypter);
    return result;
  }
  return TSI_OK;
}

// --- tsi_zero_copy_grpc_protector methods implementation. ---

static tsi_result alts_zero_copy_grpc_protector_protect(
    tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices,
    grpc_slice_buffer* protected_slices) {
  if (self == nullptr || unprotected_slices == nullptr ||
      protected_slices == nullptr) {
    gpr_log(GPR_ERROR, "Invalid nullptr arguments to zero-copy grpc protect.");
    return TSI_INVALID_ARGUMENT;
  }
  alts_zero_copy_grpc_protector* protector =
      reinterpret_cast<alts_zero_copy_grpc_protector*>(self);
  // Calls alts_grpc_record_protocol protect repeatly.
  while (unprotected_slices->length > protector->max_unprotected_data_size) {
    grpc_slice_buffer_move_first(unprotected_slices,
                                 protector->max_unprotected_data_size,
                                 &protector->unprotected_staging_sb);
    tsi_result status = alts_grpc_record_protocol_protect(
        protector->record_protocol, &protector->unprotected_staging_sb,
        protected_slices);
    if (status != TSI_OK) {
      return status;
    }
  }
  return alts_grpc_record_protocol_protect(
      protector->record_protocol, unprotected_slices, protected_slices);
}

static tsi_result alts_zero_copy_grpc_protector_unprotect(
    tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices,
    grpc_slice_buffer* unprotected_slices, int* min_progress_size) {
  if (self == nullptr || unprotected_slices == nullptr ||
      protected_slices == nullptr) {
    gpr_log(GPR_ERROR,
            "Invalid nullptr arguments to zero-copy grpc unprotect.");
    return TSI_INVALID_ARGUMENT;
  }
  alts_zero_copy_grpc_protector* protector =
      reinterpret_cast<alts_zero_copy_grpc_protector*>(self);
  grpc_slice_buffer_move_into(protected_slices, &protector->protected_sb);
  // Keep unprotecting each frame if possible.
  while (protector->protected_sb.length >= kZeroCopyFrameLengthFieldSize) {
    if (protector->parsed_frame_size == 0) {
      // We have not parsed frame size yet. Parses frame size.
      if (!read_frame_size(&protector->protected_sb,
                           &protector->parsed_frame_size)) {
        grpc_slice_buffer_reset_and_unref(&protector->protected_sb);
        return TSI_DATA_CORRUPTED;
      }
    }
    if (protector->protected_sb.length < protector->parsed_frame_size) break;
    // At this point, protected_sb contains at least one frame of data.
    tsi_result status;
    if (protector->protected_sb.length == protector->parsed_frame_size) {
      status = alts_grpc_record_protocol_unprotect(protector->unrecord_protocol,
                                                   &protector->protected_sb,
                                                   unprotected_slices);
    } else {
      grpc_slice_buffer_move_first(&protector->protected_sb,
                                   protector->parsed_frame_size,
                                   &protector->protected_staging_sb);
      status = alts_grpc_record_protocol_unprotect(
          protector->unrecord_protocol, &protector->protected_staging_sb,
          unprotected_slices);
    }
    protector->parsed_frame_size = 0;
    if (status != TSI_OK) {
      grpc_slice_buffer_reset_and_unref(&protector->protected_sb);
      return status;
    }
  }
  if (min_progress_size != nullptr) {
    if (protector->parsed_frame_size > kZeroCopyFrameLengthFieldSize) {
      *min_progress_size =
          protector->parsed_frame_size - protector->protected_sb.length;
    } else {
      *min_progress_size = 1;
    }
  }
  return TSI_OK;
}

static void alts_zero_copy_grpc_protector_destroy(
    tsi_zero_copy_grpc_protector* self) {
  if (self == nullptr) {
    return;
  }
  alts_zero_copy_grpc_protector* protector =
      reinterpret_cast<alts_zero_copy_grpc_protector*>(self);
  alts_grpc_record_protocol_destroy(protector->record_protocol);
  alts_grpc_record_protocol_destroy(protector->unrecord_protocol);
  grpc_slice_buffer_destroy(&protector->unprotected_staging_sb);
  grpc_slice_buffer_destroy(&protector->protected_sb);
  grpc_slice_buffer_destroy(&protector->protected_staging_sb);
  gpr_free(protector);
}

static tsi_result alts_zero_copy_grpc_protector_max_frame_size(
    tsi_zero_copy_grpc_protector* self, size_t* max_frame_size) {
  if (self == nullptr || max_frame_size == nullptr) return TSI_INVALID_ARGUMENT;
  alts_zero_copy_grpc_protector* protector =
      reinterpret_cast<alts_zero_copy_grpc_protector*>(self);
  *max_frame_size = protector->max_protected_frame_size;
  return TSI_OK;
}

// Dispatch table that plugs the functions above into the generic
// tsi_zero_copy_grpc_protector interface. The initializers are positional, so
// their order must match the field order of
// tsi_zero_copy_grpc_protector_vtable. A pointer to this table is stored in
// every protector built by alts_zero_copy_grpc_protector_create().
static const tsi_zero_copy_grpc_protector_vtable
    alts_zero_copy_grpc_protector_vtable = {
        alts_zero_copy_grpc_protector_protect,
        alts_zero_copy_grpc_protector_unprotect,
        alts_zero_copy_grpc_protector_destroy,
        alts_zero_copy_grpc_protector_max_frame_size};

// Creates an ALTS zero-copy grpc protector.
//
// - key_factory: asked for one key per direction; each key backs its own
//   record protocol object (one for protect, one for unprotect).
// - is_client / is_integrity_only / enable_extra_copy: forwarded to the record
//   protocol creation.
// - max_protected_frame_size: optional in/out parameter. When non-null, the
//   requested value is clamped into [kMinFrameLength, kMaxFrameLength], the
//   clamped value is written back, and it becomes the protector's maximum
//   protected frame size. When nullptr, kDefaultFrameLength is used.
// - protector: receives the new protector on success; it is left untouched on
//   failure.
//
// Returns TSI_INVALID_ARGUMENT if protector is nullptr, TSI_INTERNAL_ERROR if
// either record protocol object cannot be created (whatever the underlying
// status was), and TSI_OK otherwise.
tsi_result alts_zero_copy_grpc_protector_create(
    const grpc_core::GsecKeyFactoryInterface& key_factory, bool is_client,
    bool is_integrity_only, bool enable_extra_copy,
    size_t* max_protected_frame_size,
    tsi_zero_copy_grpc_protector** protector) {
  if (protector == nullptr) {
    gpr_log(
        GPR_ERROR,
        "Invalid nullptr arguments to alts_zero_copy_grpc_protector create.");
    return TSI_INVALID_ARGUMENT;
  }
  // Zero-initialized, so both record protocol pointers start out as nullptr.
  alts_zero_copy_grpc_protector* impl =
      static_cast<alts_zero_copy_grpc_protector*>(
          gpr_zalloc(sizeof(alts_zero_copy_grpc_protector)));

  // Creates the record protocol objects: protect first, then unprotect. The
  // second one is only attempted if the first one succeeded.
  tsi_result status = create_alts_grpc_record_protocol(
      key_factory.Create(), is_client, is_integrity_only,
      /*is_protect=*/true, enable_extra_copy, &impl->record_protocol);
  if (status == TSI_OK) {
    status = create_alts_grpc_record_protocol(
        key_factory.Create(), is_client, is_integrity_only,
        /*is_protect=*/false, enable_extra_copy, &impl->unrecord_protocol);
  }
  if (status != TSI_OK) {
    // Cleanup if create failed.
    alts_grpc_record_protocol_destroy(impl->record_protocol);
    alts_grpc_record_protocol_destroy(impl->unrecord_protocol);
    gpr_free(impl);
    return TSI_INTERNAL_ERROR;
  }

  // Sets maximum frame size, clamping the caller's request (std::min and
  // std::max come from <algorithm>) and reporting the clamped value back.
  size_t max_protected_frame_size_to_set = kDefaultFrameLength;
  if (max_protected_frame_size != nullptr) {
    *max_protected_frame_size =
        std::min(*max_protected_frame_size, kMaxFrameLength);
    *max_protected_frame_size =
        std::max(*max_protected_frame_size, kMinFrameLength);
    max_protected_frame_size_to_set = *max_protected_frame_size;
  }
  impl->max_protected_frame_size = max_protected_frame_size_to_set;
  impl->max_unprotected_data_size =
      alts_grpc_record_protocol_max_unprotected_data_size(
          impl->record_protocol, max_protected_frame_size_to_set);
  GPR_ASSERT(impl->max_unprotected_data_size > 0);

  // Allocates internal slice buffers.
  grpc_slice_buffer_init(&impl->unprotected_staging_sb);
  grpc_slice_buffer_init(&impl->protected_sb);
  grpc_slice_buffer_init(&impl->protected_staging_sb);
  impl->parsed_frame_size = 0;
  impl->base.vtable = &alts_zero_copy_grpc_protector_vtable;
  *protector = &impl->base;
  return TSI_OK;
}