aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorqiaoli <qiaoli@google.com>2023-05-18 19:31:22 +0000
committerqiaoli <qiaoli@google.com>2023-05-18 19:31:32 +0000
commitff80e6501f1733ebdc17626f82f8410cd5b986e5 (patch)
tree2fa20e4647f3b103a78b39688d56458b8454f2b7
parentd8aa2b4d14b31ffcf024535bc463b932dc5e2dc1 (diff)
downloadfederated-compute-ff80e6501f1733ebdc17626f82f8410cd5b986e5.tar.gz
Add initial federatedcompute code
Test: TH Bug: 242229007 Change-Id: I0b491758abe69ec502cc1276a131623afed3bf71
-rw-r--r--CONTRIBUTING7
-rw-r--r--GETTING_STARTED.md81
-rw-r--r--LICENSE202
-rw-r--r--METADATA12
-rw-r--r--MODULE_LICENSE_APACHE20
-rw-r--r--OWNERS3
-rw-r--r--README.md64
-rw-r--r--WORKSPACE287
-rw-r--r--fcp/BUILD21
-rw-r--r--fcp/TEST_MAPPING7
-rw-r--r--fcp/aggregation/BUILD27
-rw-r--r--fcp/aggregation/core/BUILD278
-rw-r--r--fcp/aggregation/core/agg_vector.h77
-rw-r--r--fcp/aggregation/core/agg_vector_aggregator.h150
-rw-r--r--fcp/aggregation/core/agg_vector_aggregator_test.cc174
-rw-r--r--fcp/aggregation/core/agg_vector_iterator.h113
-rw-r--r--fcp/aggregation/core/agg_vector_test.cc81
-rw-r--r--fcp/aggregation/core/aggregator.h74
-rw-r--r--fcp/aggregation/core/composite_key_combiner.cc283
-rw-r--r--fcp/aggregation/core/composite_key_combiner.h129
-rw-r--r--fcp/aggregation/core/composite_key_combiner_test.cc257
-rw-r--r--fcp/aggregation/core/datatype.h128
-rw-r--r--fcp/aggregation/core/federated_sum.cc88
-rw-r--r--fcp/aggregation/core/federated_sum_bench.cc60
-rw-r--r--fcp/aggregation/core/federated_sum_test.cc106
-rw-r--r--fcp/aggregation/core/input_tensor_list.cc98
-rw-r--r--fcp/aggregation/core/input_tensor_list.h97
-rw-r--r--fcp/aggregation/core/input_tensor_list_test.cc366
-rw-r--r--fcp/aggregation/core/mutable_vector_data.h47
-rw-r--r--fcp/aggregation/core/mutable_vector_data_test.cc38
-rw-r--r--fcp/aggregation/core/one_dim_grouping_aggregator.h209
-rw-r--r--fcp/aggregation/core/one_dim_grouping_aggregator_test.cc582
-rw-r--r--fcp/aggregation/core/tensor.cc257
-rw-r--r--fcp/aggregation/core/tensor.h126
-rw-r--r--fcp/aggregation/core/tensor.proto95
-rw-r--r--fcp/aggregation/core/tensor_aggregator.cc46
-rw-r--r--fcp/aggregation/core/tensor_aggregator.h70
-rw-r--r--fcp/aggregation/core/tensor_aggregator_factory.h47
-rw-r--r--fcp/aggregation/core/tensor_aggregator_registry.cc115
-rw-r--r--fcp/aggregation/core/tensor_aggregator_registry.h54
-rw-r--r--fcp/aggregation/core/tensor_aggregator_registry_test.cc49
-rw-r--r--fcp/aggregation/core/tensor_data.cc47
-rw-r--r--fcp/aggregation/core/tensor_data.h81
-rw-r--r--fcp/aggregation/core/tensor_data_test.cc71
-rw-r--r--fcp/aggregation/core/tensor_shape.cc65
-rw-r--r--fcp/aggregation/core/tensor_shape.h85
-rw-r--r--fcp/aggregation/core/tensor_shape_test.cc60
-rw-r--r--fcp/aggregation/core/tensor_spec.h48
-rw-r--r--fcp/aggregation/core/tensor_test.cc200
-rw-r--r--fcp/aggregation/core/vector_string_data.h53
-rw-r--r--fcp/aggregation/core/vector_string_data_test.cc39
-rw-r--r--fcp/aggregation/protocol/BUILD118
-rw-r--r--fcp/aggregation/protocol/aggregation_protocol.h169
-rw-r--r--fcp/aggregation/protocol/aggregation_protocol_messages.proto119
-rw-r--r--fcp/aggregation/protocol/checkpoint_builder.h54
-rw-r--r--fcp/aggregation/protocol/checkpoint_parser.h53
-rw-r--r--fcp/aggregation/protocol/configuration.proto68
-rw-r--r--fcp/aggregation/protocol/python/BUILD22
-rw-r--r--fcp/aggregation/protocol/python/aggregation_protocol.cc100
-rw-r--r--fcp/aggregation/protocol/resource_resolver.h42
-rw-r--r--fcp/aggregation/protocol/simple_aggregation/BUILD83
-rw-r--r--fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.cc558
-rw-r--r--fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h217
-rw-r--r--fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol_test.cc972
-rw-r--r--fcp/aggregation/protocol/testing/BUILD19
-rw-r--r--fcp/aggregation/protocol/testing/test_callback.h42
-rw-r--r--fcp/aggregation/tensorflow/BUILD191
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_reader.cc92
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_reader.h66
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_reader_test.cc83
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_writer.cc93
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_writer.h60
-rw-r--r--fcp/aggregation/tensorflow/checkpoint_writer_test.cc82
-rw-r--r--fcp/aggregation/tensorflow/converters.cc137
-rw-r--r--fcp/aggregation/tensorflow/converters.h58
-rw-r--r--fcp/aggregation/tensorflow/converters_test.cc152
-rw-r--r--fcp/aggregation/tensorflow/python/BUILD36
-rw-r--r--fcp/aggregation/tensorflow/python/aggregation_protocols.cc67
-rw-r--r--fcp/aggregation/tensorflow/python/aggregation_protocols_test.py119
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.cc103
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h35
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory_test.cc116
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.cc100
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h38
-rw-r--r--fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory_test.cc63
-rw-r--r--fcp/aggregation/testing/BUILD61
-rw-r--r--fcp/aggregation/testing/test_data.h36
-rw-r--r--fcp/aggregation/testing/testing.cc95
-rw-r--r--fcp/aggregation/testing/testing.h169
-rw-r--r--fcp/artifact_building/BUILD261
-rw-r--r--fcp/artifact_building/artifact_constants.py34
-rw-r--r--fcp/artifact_building/checkpoint_type.py29
-rw-r--r--fcp/artifact_building/checkpoint_utils.py520
-rw-r--r--fcp/artifact_building/checkpoint_utils_test.py364
-rw-r--r--fcp/artifact_building/data_spec.py150
-rw-r--r--fcp/artifact_building/data_spec_test.py71
-rw-r--r--fcp/artifact_building/federated_compute_plan_builder.py1802
-rw-r--r--fcp/artifact_building/graph_helpers.py659
-rw-r--r--fcp/artifact_building/graph_helpers_test.py409
-rw-r--r--fcp/artifact_building/plan_utils.py161
-rw-r--r--fcp/artifact_building/plan_utils_test.py252
-rw-r--r--fcp/artifact_building/proto_helpers.py129
-rw-r--r--fcp/artifact_building/proto_helpers_test.py184
-rw-r--r--fcp/artifact_building/tensor_utils.py153
-rw-r--r--fcp/artifact_building/tensor_utils_test.py157
-rw-r--r--fcp/artifact_building/test_utils.py40
-rw-r--r--fcp/artifact_building/test_utils_test.py44
-rw-r--r--fcp/artifact_building/type_checks.py103
-rw-r--r--fcp/artifact_building/type_checks_test.py109
-rw-r--r--fcp/artifact_building/variable_helpers.py460
-rw-r--r--fcp/artifact_building/variable_helpers_test.py328
-rw-r--r--fcp/base/BUILD556
-rw-r--r--fcp/base/base_name.cc41
-rw-r--r--fcp/base/base_name.h45
-rw-r--r--fcp/base/bounds.h180
-rw-r--r--fcp/base/bounds_test.cc316
-rw-r--r--fcp/base/clock.cc225
-rw-r--r--fcp/base/clock.h107
-rw-r--r--fcp/base/error.h53
-rw-r--r--fcp/base/future.cc15
-rw-r--r--fcp/base/future.h290
-rw-r--r--fcp/base/future_test.cc201
-rw-r--r--fcp/base/golden_file.bzl65
-rw-r--r--fcp/base/match.h292
-rw-r--r--fcp/base/match_test.cc260
-rw-r--r--fcp/base/meta.h444
-rw-r--r--fcp/base/meta_test.cc397
-rw-r--r--fcp/base/monitoring.cc131
-rw-r--r--fcp/base/monitoring.h587
-rw-r--r--fcp/base/monitoring_test.cc269
-rw-r--r--fcp/base/move_to_lambda.h90
-rw-r--r--fcp/base/move_to_lambda_test.cc56
-rw-r--r--fcp/base/new.h13
-rw-r--r--fcp/base/platform.cc131
-rw-r--r--fcp/base/platform.h76
-rw-r--r--fcp/base/platform_test.cc89
-rw-r--r--fcp/base/process_unique_id.cc30
-rw-r--r--fcp/base/process_unique_id.h39
-rw-r--r--fcp/base/process_unique_id_test.cc50
-rw-r--r--fcp/base/random_token.cc66
-rw-r--r--fcp/base/random_token.h105
-rw-r--r--fcp/base/random_token_test.cc84
-rw-r--r--fcp/base/realtime_clock_test.cc88
-rw-r--r--fcp/base/reentrancy_guard.h72
-rw-r--r--fcp/base/reentrancy_guard_test.cc86
-rw-r--r--fcp/base/result.cc74
-rw-r--r--fcp/base/result.h401
-rw-r--r--fcp/base/result_test.cc317
-rw-r--r--fcp/base/scheduler.cc205
-rw-r--r--fcp/base/scheduler.h110
-rw-r--r--fcp/base/scheduler_test.cc128
-rw-r--r--fcp/base/simulated_clock.cc66
-rw-r--r--fcp/base/simulated_clock.h63
-rw-r--r--fcp/base/simulated_clock_test.cc193
-rw-r--r--fcp/base/source_location.h57
-rw-r--r--fcp/base/source_location_test.cc36
-rw-r--r--fcp/base/status_converters.cc94
-rw-r--r--fcp/base/status_converters.h50
-rw-r--r--fcp/base/string_stream.cc54
-rw-r--r--fcp/base/string_stream.h41
-rw-r--r--fcp/base/string_stream_test.cc20
-rw-r--r--fcp/base/time_util.cc67
-rw-r--r--fcp/base/time_util.h56
-rw-r--r--fcp/base/time_util_test.cc60
-rw-r--r--fcp/base/tracing_schema.fbs55
-rw-r--r--fcp/base/tracing_schema.h134
-rw-r--r--fcp/base/tracing_schema_generated.h586
-rw-r--r--fcp/base/unique_value.h123
-rw-r--r--fcp/base/unique_value_test.cc167
-rw-r--r--fcp/base/wall_clock_stopwatch.cc105
-rw-r--r--fcp/base/wall_clock_stopwatch.h89
-rw-r--r--fcp/base/wall_clock_stopwatch_test.cc251
-rw-r--r--fcp/client/BUILD773
-rw-r--r--fcp/client/README.md17
-rw-r--r--fcp/client/cache/BUILD124
-rw-r--r--fcp/client/cache/cache_manifest.proto41
-rw-r--r--fcp/client/cache/file_backed_resource_cache.cc500
-rw-r--r--fcp/client/cache/file_backed_resource_cache.h147
-rw-r--r--fcp/client/cache/file_backed_resource_cache_test.cc569
-rw-r--r--fcp/client/cache/resource_cache.h73
-rw-r--r--fcp/client/cache/temp_files.cc129
-rw-r--r--fcp/client/cache/temp_files.h79
-rw-r--r--fcp/client/cache/temp_files_test.cc135
-rw-r--r--fcp/client/cache/test_helpers.h43
-rw-r--r--fcp/client/client_runner.h238
-rw-r--r--fcp/client/client_runner_example_data.proto28
-rw-r--r--fcp/client/client_runner_main.cc135
-rw-r--r--fcp/client/diag_codes.proto335
-rw-r--r--fcp/client/engine/BUILD305
-rw-r--r--fcp/client/engine/caching_error_reporter.cc53
-rw-r--r--fcp/client/engine/caching_error_reporter.h47
-rw-r--r--fcp/client/engine/caching_error_reporter_test.cc49
-rw-r--r--fcp/client/engine/common.cc103
-rw-r--r--fcp/client/engine/common.h84
-rw-r--r--fcp/client/engine/data/BUILD24
-rw-r--r--fcp/client/engine/data/README.md11
-rw-r--r--fcp/client/engine/data/join_model.flatbufferbin0 -> 656 bytes
-rw-r--r--fcp/client/engine/data/length_model.flatbufferbin0 -> 592 bytes
-rw-r--r--fcp/client/engine/engine.proto55
-rw-r--r--fcp/client/engine/example_iterator_factory.h118
-rw-r--r--fcp/client/engine/example_query_plan_engine.cc247
-rw-r--r--fcp/client/engine/example_query_plan_engine.h52
-rw-r--r--fcp/client/engine/example_query_plan_engine_test.cc547
-rw-r--r--fcp/client/engine/plan_engine_helpers.cc285
-rw-r--r--fcp/client/engine/plan_engine_helpers.h190
-rw-r--r--fcp/client/engine/simple_plan_engine.cc184
-rw-r--r--fcp/client/engine/simple_plan_engine.h106
-rw-r--r--fcp/client/engine/tf_wrapper.cc190
-rw-r--r--fcp/client/engine/tf_wrapper.h106
-rw-r--r--fcp/client/engine/tf_wrapper_test.cc131
-rw-r--r--fcp/client/engine/tflite_plan_engine.cc155
-rw-r--r--fcp/client/engine/tflite_plan_engine.h80
-rw-r--r--fcp/client/engine/tflite_plan_engine_test.cc224
-rw-r--r--fcp/client/engine/tflite_wrapper.cc210
-rw-r--r--fcp/client/engine/tflite_wrapper.h121
-rw-r--r--fcp/client/engine/tflite_wrapper_test.cc140
-rw-r--r--fcp/client/event_publisher.h274
-rw-r--r--fcp/client/example_query_result.proto73
-rw-r--r--fcp/client/fake_event_publisher.h398
-rw-r--r--fcp/client/fake_log_manager.h44
-rw-r--r--fcp/client/fake_server.cc121
-rw-r--r--fcp/client/fake_server.h89
-rw-r--r--fcp/client/federated_protocol.h397
-rw-r--r--fcp/client/federated_protocol_util.cc116
-rw-r--r--fcp/client/federated_protocol_util.h66
-rw-r--r--fcp/client/federated_protocol_util_test.cc80
-rw-r--r--fcp/client/federated_select.cc306
-rw-r--r--fcp/client/federated_select.h162
-rw-r--r--fcp/client/federated_select_test.cc451
-rw-r--r--fcp/client/files.h46
-rw-r--r--fcp/client/fl_runner.cc1638
-rw-r--r--fcp/client/fl_runner.h112
-rw-r--r--fcp/client/fl_runner.proto74
-rw-r--r--fcp/client/flags.h216
-rw-r--r--fcp/client/grpc_bidi_channel.h105
-rw-r--r--fcp/client/grpc_bidi_channel_test.cc39
-rw-r--r--fcp/client/grpc_bidi_stream.cc139
-rw-r--r--fcp/client/grpc_bidi_stream.h160
-rw-r--r--fcp/client/grpc_bidi_stream_test.cc165
-rw-r--r--fcp/client/grpc_federated_protocol.cc1074
-rw-r--r--fcp/client/grpc_federated_protocol.h269
-rw-r--r--fcp/client/grpc_federated_protocol_test.cc1771
-rw-r--r--fcp/client/histogram_counters.proto178
-rw-r--r--fcp/client/http/BUILD276
-rw-r--r--fcp/client/http/README.md6
-rw-r--r--fcp/client/http/curl/BUILD65
-rw-r--r--fcp/client/http/curl/curl_api.cc100
-rw-r--r--fcp/client/http/curl/curl_api.h113
-rw-r--r--fcp/client/http/curl/curl_header_parser.cc108
-rw-r--r--fcp/client/http/curl/curl_header_parser.h62
-rw-r--r--fcp/client/http/curl/curl_header_parser_test.cc118
-rw-r--r--fcp/client/http/curl/curl_http_client.cc113
-rw-r--r--fcp/client/http/curl/curl_http_client.h60
-rw-r--r--fcp/client/http/curl/curl_http_client_test.cc564
-rw-r--r--fcp/client/http/curl/curl_http_request_handle.cc386
-rw-r--r--fcp/client/http/curl/curl_http_request_handle.h102
-rw-r--r--fcp/client/http/curl/curl_http_response.cc31
-rw-r--r--fcp/client/http/curl/curl_http_response.h41
-rw-r--r--fcp/client/http/http_client.h468
-rw-r--r--fcp/client/http/http_client_util.cc246
-rw-r--r--fcp/client/http/http_client_util.h135
-rw-r--r--fcp/client/http/http_client_util_test.cc323
-rw-r--r--fcp/client/http/http_federated_protocol.cc1428
-rw-r--r--fcp/client/http/http_federated_protocol.h306
-rw-r--r--fcp/client/http/http_federated_protocol_test.cc3062
-rw-r--r--fcp/client/http/http_resource_metadata.proto29
-rw-r--r--fcp/client/http/http_secagg_send_to_server_impl.cc452
-rw-r--r--fcp/client/http/http_secagg_send_to_server_impl.h181
-rw-r--r--fcp/client/http/http_secagg_send_to_server_impl_test.cc756
-rw-r--r--fcp/client/http/in_memory_request_response.cc607
-rw-r--r--fcp/client/http/in_memory_request_response.h251
-rw-r--r--fcp/client/http/in_memory_request_response_test.cc1576
-rw-r--r--fcp/client/http/java/BUILD62
-rw-r--r--fcp/client/http/java/java_http_client.cc531
-rw-r--r--fcp/client/http/java/java_http_client.h181
-rw-r--r--fcp/client/http/java/jni.proto62
-rw-r--r--fcp/client/http/protocol_request_helper.cc377
-rw-r--r--fcp/client/http/protocol_request_helper.h170
-rw-r--r--fcp/client/http/protocol_request_helper_test.cc762
-rw-r--r--fcp/client/http/testing/BUILD40
-rw-r--r--fcp/client/http/testing/http_test_server.cc105
-rw-r--r--fcp/client/http/testing/http_test_server.h41
-rw-r--r--fcp/client/http/testing/test_helpers.cc250
-rw-r--r--fcp/client/http/testing/test_helpers.h173
-rw-r--r--fcp/client/interruptible_runner.cc94
-rw-r--r--fcp/client/interruptible_runner.h94
-rw-r--r--fcp/client/interruptible_runner_test.cc258
-rw-r--r--fcp/client/lc_runner.cc362
-rw-r--r--fcp/client/lc_runner.h68
-rw-r--r--fcp/client/log_manager.h62
-rw-r--r--fcp/client/opstats/BUILD171
-rw-r--r--fcp/client/opstats/opstats_db.h54
-rw-r--r--fcp/client/opstats/opstats_example_store.cc254
-rw-r--r--fcp/client/opstats/opstats_example_store.h87
-rw-r--r--fcp/client/opstats/opstats_example_store_test.cc601
-rw-r--r--fcp/client/opstats/opstats_logger.h104
-rw-r--r--fcp/client/opstats/opstats_logger_impl.cc147
-rw-r--r--fcp/client/opstats/opstats_logger_impl.h112
-rw-r--r--fcp/client/opstats/opstats_logger_impl_test.cc574
-rw-r--r--fcp/client/opstats/opstats_utils.cc84
-rw-r--r--fcp/client/opstats/opstats_utils.h47
-rw-r--r--fcp/client/opstats/opstats_utils_test.cc154
-rw-r--r--fcp/client/opstats/pds_backed_opstats_db.cc297
-rw-r--r--fcp/client/opstats/pds_backed_opstats_db.h91
-rw-r--r--fcp/client/opstats/pds_backed_opstats_db_test.cc519
-rw-r--r--fcp/client/parsing_utils.h43
-rw-r--r--fcp/client/phase_logger.h222
-rw-r--r--fcp/client/phase_logger_impl.cc638
-rw-r--r--fcp/client/phase_logger_impl.h214
-rw-r--r--fcp/client/phase_logger_impl_test.cc916
-rw-r--r--fcp/client/secagg_event_publisher.h55
-rw-r--r--fcp/client/secagg_runner.cc224
-rw-r--r--fcp/client/secagg_runner.h120
-rw-r--r--fcp/client/selector_context.proto127
-rw-r--r--fcp/client/simple_task_environment.cc33
-rw-r--r--fcp/client/simple_task_environment.h102
-rw-r--r--fcp/client/simple_task_environment_test.cc84
-rw-r--r--fcp/client/stats.h73
-rw-r--r--fcp/client/test_helpers.cc175
-rw-r--r--fcp/client/test_helpers.h861
-rw-r--r--fcp/client/testing/BUILD51
-rw-r--r--fcp/client/testing/utils.h127
-rw-r--r--fcp/client/testing/utils_test.cc89
-rw-r--r--fcp/config.bzl33
-rw-r--r--fcp/demo/BUILD338
-rw-r--r--fcp/demo/README.md242
-rw-r--r--fcp/demo/__init__.py25
-rw-r--r--fcp/demo/aggregations.py554
-rw-r--r--fcp/demo/aggregations_test.py783
-rw-r--r--fcp/demo/checkpoint_tensor_reference.py66
-rw-r--r--fcp/demo/checkpoint_tensor_reference_test.py88
-rw-r--r--fcp/demo/eligibility_eval_tasks.py138
-rw-r--r--fcp/demo/eligibility_eval_tasks_test.py169
-rw-r--r--fcp/demo/federated_computation.py79
-rw-r--r--fcp/demo/federated_computation_test.py158
-rw-r--r--fcp/demo/federated_context.py314
-rw-r--r--fcp/demo/federated_context_test.py438
-rw-r--r--fcp/demo/federated_data_source.py141
-rw-r--r--fcp/demo/federated_data_source_test.py128
-rw-r--r--fcp/demo/federated_program_test.py172
-rw-r--r--fcp/demo/http_actions.py295
-rw-r--r--fcp/demo/http_actions_test.py216
-rw-r--r--fcp/demo/media.py135
-rw-r--r--fcp/demo/media_test.py188
-rw-r--r--fcp/demo/plan_utils.py203
-rw-r--r--fcp/demo/plan_utils_test.py350
-rw-r--r--fcp/demo/server.py164
-rw-r--r--fcp/demo/server_test.py284
-rw-r--r--fcp/demo/task_assignments.py230
-rw-r--r--fcp/demo/task_assignments_test.py453
-rw-r--r--fcp/demo/test_utils.py44
-rw-r--r--fcp/demo/test_utils_test.py52
-rw-r--r--fcp/dictionary/BUILD67
-rw-r--r--fcp/dictionary/dictionary.cc184
-rw-r--r--fcp/dictionary/dictionary.h72
-rw-r--r--fcp/dictionary/dictionary.proto61
-rw-r--r--fcp/dictionary/dictionary_test.cc114
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/BUILD27
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/CallFromNativeWrapper.java76
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/http/BUILD48
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNative.java233
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNativeImpl.java114
-rw-r--r--fcp/java_src/main/java/com/google/fcp/client/http/HttpRequestHandleImpl.java1052
-rw-r--r--fcp/java_src/test/java/com/google/fcp/client/http/BUILD33
-rw-r--r--fcp/java_src/test/java/com/google/fcp/client/http/HttpClientForNativeImplTest.java1723
-rw-r--r--fcp/jni/BUILD33
-rw-r--r--fcp/jni/jni_util.h178
-rw-r--r--fcp/patches/BUILD1
-rw-r--r--fcp/patches/googleapis_longrunning.patch16
-rw-r--r--fcp/patches/googletest.patch55
-rw-r--r--fcp/patches/tensorflow_googleapis_proto_rules.patch11
-rw-r--r--fcp/patches/tensorflow_llvm_url.patch23
-rw-r--r--fcp/patches/tensorflow_pybind11_osx.patch11
-rw-r--r--fcp/patches/tensorflow_serving.patch25
-rw-r--r--fcp/patches/tensorflow_tf_custom_op_py_library.patch10
-rw-r--r--fcp/patches/tensorflow_zlib.patch11
-rw-r--r--fcp/protocol/BUILD55
-rw-r--r--fcp/protocol/grpc_chunked_bidi_stream.h484
-rw-r--r--fcp/protocol/grpc_chunked_bidi_stream_test.cc330
-rw-r--r--fcp/protos/BUILD126
-rw-r--r--fcp/protos/federated_api.proto809
-rw-r--r--fcp/protos/federatedcompute/BUILD60
-rw-r--r--fcp/protos/federatedcompute/aggregations.proto144
-rw-r--r--fcp/protos/federatedcompute/common.proto313
-rw-r--r--fcp/protos/federatedcompute/eligibility_eval_tasks.proto208
-rw-r--r--fcp/protos/federatedcompute/secure_aggregations.proto343
-rw-r--r--fcp/protos/federatedcompute/task_assignments.proto282
-rw-r--r--fcp/protos/opstats.proto346
-rw-r--r--fcp/protos/plan.proto1380
-rw-r--r--fcp/protos/task_eligibility_context.proto51
-rw-r--r--fcp/secagg/client/BUILD113
-rw-r--r--fcp/secagg/client/other_client_state.h36
-rw-r--r--fcp/secagg/client/secagg_client.cc129
-rw-r--r--fcp/secagg/client/secagg_client.h171
-rw-r--r--fcp/secagg/client/secagg_client_aborted_state.cc50
-rw-r--r--fcp/secagg/client/secagg_client_aborted_state.h58
-rw-r--r--fcp/secagg/client/secagg_client_aborted_state_test.cc123
-rw-r--r--fcp/secagg/client/secagg_client_alive_base_state.cc55
-rw-r--r--fcp/secagg/client/secagg_client_alive_base_state.h60
-rw-r--r--fcp/secagg/client/secagg_client_completed_state.cc46
-rw-r--r--fcp/secagg/client/secagg_client_completed_state.h56
-rw-r--r--fcp/secagg/client/secagg_client_completed_state_test.cc115
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.cc125
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h83
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state_test.cc404
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.cc112
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h81
-rw-r--r--fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state_test.cc290
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_base_state.cc202
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_base_state.h98
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.cc154
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h89
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state_test.cc1004
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.cc133
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h87
-rw-r--r--fcp/secagg/client/secagg_client_r1_share_keys_input_set_state_test.cc901
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.cc218
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h94
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.cc155
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h98
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state_test.cc735
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.cc137
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h94
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state_test.cc612
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.cc122
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h92
-rw-r--r--fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state_test.cc482
-rw-r--r--fcp/secagg/client/secagg_client_r3_unmasking_state.cc168
-rw-r--r--fcp/secagg/client/secagg_client_r3_unmasking_state.h76
-rw-r--r--fcp/secagg/client/secagg_client_r3_unmasking_state_test.cc480
-rw-r--r--fcp/secagg/client/secagg_client_state.cc112
-rw-r--r--fcp/secagg/client/secagg_client_state.h108
-rw-r--r--fcp/secagg/client/secagg_client_test.cc262
-rw-r--r--fcp/secagg/client/send_to_server_interface.h40
-rw-r--r--fcp/secagg/client/state_transition_listener_interface.h71
-rw-r--r--fcp/secagg/server/BUILD360
-rw-r--r--fcp/secagg/server/aes/BUILD32
-rw-r--r--fcp/secagg/server/aes/aes_secagg_server_protocol_impl.cc223
-rw-r--r--fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h109
-rw-r--r--fcp/secagg/server/distribution_utilities.cc140
-rw-r--r--fcp/secagg/server/distribution_utilities.h62
-rw-r--r--fcp/secagg/server/distribution_utilities_test.cc166
-rw-r--r--fcp/secagg/server/experiments_interface.h38
-rw-r--r--fcp/secagg/server/experiments_names.h34
-rw-r--r--fcp/secagg/server/graph_parameter_finder.cc335
-rw-r--r--fcp/secagg/server/graph_parameter_finder.h53
-rw-r--r--fcp/secagg/server/graph_parameter_finder_test.cc572
-rw-r--r--fcp/secagg/server/secagg_scheduler.cc34
-rw-r--r--fcp/secagg/server/secagg_scheduler.h346
-rw-r--r--fcp/secagg/server/secagg_scheduler_test.cc287
-rw-r--r--fcp/secagg/server/secagg_server.cc369
-rw-r--r--fcp/secagg/server/secagg_server.h344
-rw-r--r--fcp/secagg/server/secagg_server_aborted_state.cc51
-rw-r--r--fcp/secagg/server/secagg_server_aborted_state.h59
-rw-r--r--fcp/secagg/server/secagg_server_aborted_state_test.cc309
-rw-r--r--fcp/secagg/server/secagg_server_completed_state.cc70
-rw-r--r--fcp/secagg/server/secagg_server_completed_state.h57
-rw-r--r--fcp/secagg/server/secagg_server_completed_state_test.cc261
-rw-r--r--fcp/secagg/server/secagg_server_enums.proto149
-rw-r--r--fcp/secagg/server/secagg_server_messages.proto44
-rw-r--r--fcp/secagg/server/secagg_server_metrics_listener.h102
-rw-r--r--fcp/secagg/server/secagg_server_prng_running_state.cc181
-rw-r--r--fcp/secagg/server/secagg_server_prng_running_state.h96
-rw-r--r--fcp/secagg/server/secagg_server_prng_running_state_test.cc997
-rw-r--r--fcp/secagg/server/secagg_server_protocol_impl.cc403
-rw-r--r--fcp/secagg/server/secagg_server_protocol_impl.h383
-rw-r--r--fcp/secagg/server/secagg_server_r0_advertise_keys_state.cc178
-rw-r--r--fcp/secagg/server/secagg_server_r0_advertise_keys_state.h67
-rw-r--r--fcp/secagg/server/secagg_server_r0_advertise_keys_state_test.cc795
-rw-r--r--fcp/secagg/server/secagg_server_r1_share_keys_state.cc163
-rw-r--r--fcp/secagg/server/secagg_server_r1_share_keys_state.h68
-rw-r--r--fcp/secagg/server/secagg_server_r1_share_keys_state_test.cc829
-rw-r--r--fcp/secagg/server/secagg_server_r2_masked_input_coll_state.cc211
-rw-r--r--fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h85
-rw-r--r--fcp/secagg/server/secagg_server_r2_masked_input_coll_state_test.cc931
-rw-r--r--fcp/secagg/server/secagg_server_r3_unmasking_state.cc167
-rw-r--r--fcp/secagg/server/secagg_server_r3_unmasking_state.h72
-rw-r--r--fcp/secagg/server/secagg_server_r3_unmasking_state_test.cc924
-rw-r--r--fcp/secagg/server/secagg_server_state.cc320
-rw-r--r--fcp/secagg/server/secagg_server_state.h314
-rw-r--r--fcp/secagg/server/secagg_server_test.cc404
-rw-r--r--fcp/secagg/server/secagg_trace_utility.cc173
-rw-r--r--fcp/secagg/server/secagg_trace_utility.h53
-rw-r--r--fcp/secagg/server/secret_sharing_complete_graph.h97
-rw-r--r--fcp/secagg/server/secret_sharing_complete_graph_test.cc106
-rw-r--r--fcp/secagg/server/secret_sharing_graph.h89
-rw-r--r--fcp/secagg/server/secret_sharing_graph_factory.h82
-rw-r--r--fcp/secagg/server/secret_sharing_harary_graph.cc94
-rw-r--r--fcp/secagg/server/secret_sharing_harary_graph.h135
-rw-r--r--fcp/secagg/server/secret_sharing_harary_graph_test.cc315
-rw-r--r--fcp/secagg/server/send_to_clients_interface.h44
-rw-r--r--fcp/secagg/server/ssl_bit_gen.cc43
-rw-r--r--fcp/secagg/server/ssl_bit_gen.h72
-rw-r--r--fcp/secagg/server/tracing_schema.fbs248
-rw-r--r--fcp/secagg/shared/BUILD255
-rw-r--r--fcp/secagg/shared/add_maps_bench.cc75
-rw-r--r--fcp/secagg/shared/aes_ctr_prng.cc89
-rw-r--r--fcp/secagg/shared/aes_ctr_prng.h106
-rw-r--r--fcp/secagg/shared/aes_ctr_prng_factory.cc33
-rw-r--r--fcp/secagg/shared/aes_ctr_prng_factory.h47
-rw-r--r--fcp/secagg/shared/aes_ctr_prng_test.cc222
-rw-r--r--fcp/secagg/shared/aes_gcm_encryption.cc92
-rw-r--r--fcp/secagg/shared/aes_gcm_encryption.h50
-rw-r--r--fcp/secagg/shared/aes_gcm_encryption_test.cc171
-rw-r--r--fcp/secagg/shared/aes_key.cc80
-rw-r--r--fcp/secagg/shared/aes_key.h52
-rw-r--r--fcp/secagg/shared/aes_key_test.cc95
-rw-r--r--fcp/secagg/shared/aes_prng_factory.h42
-rw-r--r--fcp/secagg/shared/async_abort.h76
-rw-r--r--fcp/secagg/shared/compute_session_id.cc53
-rw-r--r--fcp/secagg/shared/compute_session_id.h58
-rw-r--r--fcp/secagg/shared/compute_session_id_test.cc99
-rw-r--r--fcp/secagg/shared/crypto_rand_prng.cc41
-rw-r--r--fcp/secagg/shared/crypto_rand_prng.h43
-rw-r--r--fcp/secagg/shared/ecdh_key_agreement.cc153
-rw-r--r--fcp/secagg/shared/ecdh_key_agreement.h103
-rw-r--r--fcp/secagg/shared/ecdh_key_agreement_test.cc172
-rw-r--r--fcp/secagg/shared/ecdh_keys.h67
-rw-r--r--fcp/secagg/shared/input_vector_specification.cc37
-rw-r--r--fcp/secagg/shared/input_vector_specification.h57
-rw-r--r--fcp/secagg/shared/input_vector_specification_test.cc60
-rw-r--r--fcp/secagg/shared/key.h56
-rw-r--r--fcp/secagg/shared/map_of_masks.cc372
-rw-r--r--fcp/secagg/shared/map_of_masks.h84
-rw-r--r--fcp/secagg/shared/map_of_masks_bench.cc169
-rw-r--r--fcp/secagg/shared/map_of_masks_test.cc553
-rw-r--r--fcp/secagg/shared/math.h122
-rw-r--r--fcp/secagg/shared/math_test.cc205
-rw-r--r--fcp/secagg/shared/prng.h51
-rw-r--r--fcp/secagg/shared/secagg_messages.proto272
-rw-r--r--fcp/secagg/shared/secagg_vector.cc386
-rw-r--r--fcp/secagg/shared/secagg_vector.h311
-rw-r--r--fcp/secagg/shared/secagg_vector_bench.cc106
-rw-r--r--fcp/secagg/shared/secagg_vector_test.cc477
-rw-r--r--fcp/secagg/shared/shamir_secret_sharing.cc295
-rw-r--r--fcp/secagg/shared/shamir_secret_sharing.h139
-rw-r--r--fcp/secagg/shared/shamir_secret_sharing_test.cc190
-rw-r--r--fcp/secagg/testing/BUILD57
-rw-r--r--fcp/secagg/testing/ecdh_pregenerated_test_keys.cc168
-rw-r--r--fcp/secagg/testing/ecdh_pregenerated_test_keys.h62
-rw-r--r--fcp/secagg/testing/fake_prng.h48
-rw-r--r--fcp/secagg/testing/mock_send_to_server_interface.h37
-rw-r--r--fcp/secagg/testing/mock_state_transition_listener.h39
-rw-r--r--fcp/secagg/testing/server/BUILD45
-rw-r--r--fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h71
-rw-r--r--fcp/secagg/testing/server/mock_send_to_clients_interface.h42
-rw-r--r--fcp/secagg/testing/server/test_async_runner.h49
-rw-r--r--fcp/secagg/testing/server/test_secagg_experiments.h51
-rw-r--r--fcp/secagg/testing/test_matchers.cc83
-rw-r--r--fcp/secagg/testing/test_matchers.h54
-rw-r--r--fcp/tensorflow/BUILD974
-rw-r--r--fcp/tensorflow/append_slices.py74
-rw-r--r--fcp/tensorflow/append_slices_op.cc591
-rw-r--r--fcp/tensorflow/append_slices_test.py183
-rw-r--r--fcp/tensorflow/crc32.py40
-rw-r--r--fcp/tensorflow/crc32_op.cc63
-rw-r--r--fcp/tensorflow/crc32_test.py53
-rw-r--r--fcp/tensorflow/delete_file.py50
-rw-r--r--fcp/tensorflow/delete_file_op.cc127
-rw-r--r--fcp/tensorflow/delete_file_test.py100
-rw-r--r--fcp/tensorflow/dictionary_ops.cc252
-rw-r--r--fcp/tensorflow/dictionary_ops.py372
-rw-r--r--fcp/tensorflow/dictionary_ops_test.py110
-rw-r--r--fcp/tensorflow/example_selector_fuser.py50
-rw-r--r--fcp/tensorflow/example_selector_fuser_op.cc104
-rw-r--r--fcp/tensorflow/example_selector_fuser_test.py60
-rw-r--r--fcp/tensorflow/external_dataset.h179
-rw-r--r--fcp/tensorflow/external_dataset.py54
-rw-r--r--fcp/tensorflow/external_dataset_op.cc224
-rw-r--r--fcp/tensorflow/external_dataset_op_test.cc320
-rw-r--r--fcp/tensorflow/host_object.cc63
-rw-r--r--fcp/tensorflow/host_object.h162
-rw-r--r--fcp/tensorflow/host_object_test.cc80
-rw-r--r--fcp/tensorflow/make_external_dataset_test_graph.py58
-rw-r--r--fcp/tensorflow/make_serve_slices_test_graph.py68
-rw-r--r--fcp/tensorflow/make_slices_selector_example_selector.py34
-rw-r--r--fcp/tensorflow/make_slices_selector_example_selector_op.cc85
-rw-r--r--fcp/tensorflow/make_slices_selector_example_selector_test.py42
-rw-r--r--fcp/tensorflow/python/BUILD15
-rw-r--r--fcp/tensorflow/python/serve_slices_registry.cc126
-rw-r--r--fcp/tensorflow/python/serve_slices_registry_test.py76
-rw-r--r--fcp/tensorflow/serve_slices.py107
-rw-r--r--fcp/tensorflow/serve_slices_op.cc192
-rw-r--r--fcp/tensorflow/serve_slices_op_test.cc178
-rw-r--r--fcp/tensorflow/serve_slices_registry.h107
-rw-r--r--fcp/tensorflow/serve_slices_registry_test.cc82
-rw-r--r--fcp/tensorflow/status.cc45
-rw-r--r--fcp/tensorflow/status.h37
-rw-r--r--fcp/tensorflow/status_test.cc40
-rw-r--r--fcp/tensorflow/system_provided_tf/BUILD1
-rw-r--r--fcp/tensorflow/system_provided_tf/README.md20
-rw-r--r--fcp/tensorflow/system_provided_tf/system_provided_tf.bzl161
-rw-r--r--fcp/tensorflow/system_provided_tf/templates/BUILD.tpl36
-rw-r--r--fcp/tensorflow/system_provided_tf/templates/system_provided_tf.bzl.tpl125
-rw-r--r--fcp/tensorflow/task_eligibility_info_ops.cc103
-rw-r--r--fcp/tensorflow/task_eligibility_info_ops.py58
-rw-r--r--fcp/tensorflow/task_eligibility_info_ops_test.py96
-rw-r--r--fcp/tensorflow/tensor_crc32.cc33
-rw-r--r--fcp/tensorflow/tensor_crc32.h33
-rw-r--r--fcp/tensorflow/tensor_name.py33
-rw-r--r--fcp/tensorflow/tensor_name_op.cc66
-rw-r--r--fcp/tensorflow/tensor_name_test.py52
-rw-r--r--fcp/tensorflow/test_selector.proto43
-rw-r--r--fcp/tensorflow/testing/BUILD39
-rw-r--r--fcp/tensorflow/testing/tf_helper.cc29
-rw-r--r--fcp/tensorflow/testing/tf_helper.h42
-rw-r--r--fcp/tensorflow/tf_py_smoke_test.py38
-rw-r--r--fcp/tensorflow/tf_session.cc191
-rw-r--r--fcp/tensorflow/tf_session.h115
-rw-r--r--fcp/tensorflow/tf_session_test.cc296
-rw-r--r--fcp/tensorflow/tf_smoke_test.cc64
-rw-r--r--fcp/tensorflow/tracing_schema.fbs45
-rw-r--r--fcp/testdata/federation_client_only_plan.pbbin0 -> 194838 bytes
-rw-r--r--fcp/testdata/federation_proxy_train_examples.pbbin0 -> 15854 bytes
-rw-r--r--fcp/testdata/federation_test_checkpoint.client.ckpbin0 -> 31846 bytes
-rw-r--r--fcp/testdata/federation_test_select_checkpoints.pbbin0 -> 15854 bytes
-rw-r--r--fcp/testing/BUILD120
-rw-r--r--fcp/testing/parse_text_proto.h65
-rw-r--r--fcp/testing/result_matchers.h132
-rw-r--r--fcp/testing/result_matchers_test.cc66
-rw-r--r--fcp/testing/test_messages.proto25
-rw-r--r--fcp/testing/testdata/verify_baseline_test.baseline1
-rw-r--r--fcp/testing/testing.cc264
-rw-r--r--fcp/testing/testing.h240
-rw-r--r--fcp/testing/testing_test.cc101
-rw-r--r--fcp/testing/tracing_schema.fbs20
-rw-r--r--fcp/tracing/BUILD130
-rw-r--r--fcp/tracing/build_defs.bzl104
-rw-r--r--fcp/tracing/scoped_tracing_recorder.h46
-rw-r--r--fcp/tracing/test/BUILD99
-rw-r--r--fcp/tracing/test/test_api_message.proto41
-rw-r--r--fcp/tracing/test/testdata/Basic.baseline18
-rw-r--r--fcp/tracing/test/testdata/ChangeThreadLocal1.baseline4
-rw-r--r--fcp/tracing/test/testdata/ChangeThreadLocal2.baseline4
-rw-r--r--fcp/tracing/test/testdata/PerThread1.baseline9
-rw-r--r--fcp/tracing/test/testdata/PerThread2.baseline9
-rw-r--r--fcp/tracing/test/text_tracing_test.cc99
-rw-r--r--fcp/tracing/test/thread_local_tracing_recorder_test.cc191
-rw-r--r--fcp/tracing/test/tracing_context_utils_test.cc98
-rw-r--r--fcp/tracing/test/tracing_schema.fbs83
-rw-r--r--fcp/tracing/test/tracing_test.cc393
-rw-r--r--fcp/tracing/test_tracing_recorder.cc162
-rw-r--r--fcp/tracing/test_tracing_recorder.h471
-rw-r--r--fcp/tracing/test_tracing_recorder_impl.cc66
-rw-r--r--fcp/tracing/test_tracing_recorder_impl.h61
-rw-r--r--fcp/tracing/test_tracing_span_impl.cc46
-rw-r--r--fcp/tracing/test_tracing_span_impl.h49
-rw-r--r--fcp/tracing/text_tracing_recorder.h54
-rw-r--r--fcp/tracing/text_tracing_recorder_impl.cc100
-rw-r--r--fcp/tracing/text_tracing_recorder_impl.h91
-rw-r--r--fcp/tracing/text_tracing_span_impl.cc42
-rw-r--r--fcp/tracing/text_tracing_span_impl.h82
-rw-r--r--fcp/tracing/tools/BUILD56
-rw-r--r--fcp/tracing/tools/README.md16
-rwxr-xr-xfcp/tracing/tools/test_codegen_runner.sh32
-rw-r--r--fcp/tracing/tools/testdata/AllTypes.baseline79
-rw-r--r--fcp/tracing/tools/testdata/AllTypes.fbs16
-rw-r--r--fcp/tracing/tools/testdata/DeprecatedField.baseline109
-rw-r--r--fcp/tracing/tools/testdata/DeprecatedField.fbs11
-rw-r--r--fcp/tracing/tools/testdata/DuplicateTags.baseline17
-rw-r--r--fcp/tracing/tools/testdata/DuplicateTags.fbs11
-rw-r--r--fcp/tracing/tools/testdata/EmptyTable.baseline65
-rw-r--r--fcp/tracing/tools/testdata/EmptyTable.fbs3
-rw-r--r--fcp/tracing/tools/testdata/EnumType.baseline71
-rw-r--r--fcp/tracing/tools/testdata/EnumType.fbs8
-rw-r--r--fcp/tracing/tools/testdata/FieldsOfDifferentTypes.baseline111
-rw-r--r--fcp/tracing/tools/testdata/FieldsOfDifferentTypes.fbs11
-rw-r--r--fcp/tracing/tools/testdata/NoAttributes.baseline9
-rw-r--r--fcp/tracing/tools/testdata/NoAttributes.fbs3
-rw-r--r--fcp/tracing/tools/testdata/NoTag.baseline9
-rw-r--r--fcp/tracing/tools/testdata/NoTag.fbs3
-rw-r--r--fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.baseline78
-rw-r--r--fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.fbs15
-rw-r--r--fcp/tracing/tools/testdata/OrderWithIds.baseline69
-rw-r--r--fcp/tracing/tools/testdata/OrderWithIds.fbs7
-rw-r--r--fcp/tracing/tools/testdata/TableWithNamespace.baseline72
-rw-r--r--fcp/tracing/tools/testdata/TableWithNamespace.fbs9
-rw-r--r--fcp/tracing/tools/testdata/TagTooLong.baseline9
-rw-r--r--fcp/tracing/tools/testdata/TagTooLong.fbs3
-rw-r--r--fcp/tracing/tools/testdata/UnsupportedType.baseline18
-rw-r--r--fcp/tracing/tools/testdata/UnsupportedType.fbs12
-rw-r--r--fcp/tracing/tools/tracing_traits_generator.cc424
-rw-r--r--fcp/tracing/tools/tracing_traits_generator_test.cc125
-rw-r--r--fcp/tracing/tracing_context_utils.cc40
-rw-r--r--fcp/tracing/tracing_context_utils.h70
-rw-r--r--fcp/tracing/tracing_recorder.h62
-rw-r--r--fcp/tracing/tracing_recorder_impl.cc146
-rw-r--r--fcp/tracing/tracing_recorder_impl.h68
-rw-r--r--fcp/tracing/tracing_schema_common.fbs28
-rw-r--r--fcp/tracing/tracing_schema_common_generated.h16
-rw-r--r--fcp/tracing/tracing_severity.h24
-rw-r--r--fcp/tracing/tracing_span.h250
-rw-r--r--fcp/tracing/tracing_span_id.cc26
-rw-r--r--fcp/tracing/tracing_span_id.h63
-rw-r--r--fcp/tracing/tracing_span_impl.cc38
-rw-r--r--fcp/tracing/tracing_span_impl.h68
-rw-r--r--fcp/tracing/tracing_span_ref.cc30
-rw-r--r--fcp/tracing/tracing_span_ref.h55
-rw-r--r--fcp/tracing/tracing_tag.h70
-rw-r--r--fcp/tracing/tracing_tag_test.cc42
-rw-r--r--fcp/tracing/tracing_traits.cc63
-rw-r--r--fcp/tracing/tracing_traits.h91
-rw-r--r--requirements.txt28
-rw-r--r--third_party/BUILD0
-rw-r--r--third_party/curl.BUILD.bzl618
705 files changed, 124411 insertions, 0 deletions
diff --git a/CONTRIBUTING b/CONTRIBUTING
new file mode 100644
index 0000000..08025f3
--- /dev/null
+++ b/CONTRIBUTING
@@ -0,0 +1,7 @@
+# Contributing
+
+This is a copy of Google's internal code, intended to assist developers and
+researchers who are interested in the implementation details of Google's
+federated compute client. Changes to this code are made by Google first, and
+then mirrored to this repository. External contributions are currently not
+accepted.
diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
new file mode 100644
index 0000000..b9d5370
--- /dev/null
+++ b/GETTING_STARTED.md
@@ -0,0 +1,81 @@
+# Instructions for getting the Federated Compute Platform code up and running on your own machine.
+
+## Download and install build dependencies
+
+### Basic tools
+
+There are some basic tools and packages you will need on your machine:
+
+* Git
+* A C++ compiler (e.g., Clang or GCC, but see note about GCC below)
+* Python 3.9 or greater, including the `venv` module
+
+For example, on Debian:
+
+```
+sudo apt install -y git gcc python3 python3-dev python3-venv
+```
+
+> ⚠️ The project maintainers internally test with Clang only, so support for
+> GCC-based builds is provided only on a best-effort basis and may at times be
+> broken.
+>
+> If using GCC then we recommend using a recent version (e.g., at least as
+> recent as what Debian stable uses, preferably newer than that).
+>
+> If using Clang then please see [Building with Clang](#building-with-clang) for
+> further Clang-specific instructions.
+
+### Install Bazelisk
+
+Bazelisk is used to fetch the correct Bazel binaries necessary to build and run
+Federated Compute code.
+
+Please read https://github.com/bazelbuild/bazelisk#installation.
+
+## Set up your Python environment
+
+Setting up a virtual Python environment will ensure that Python dependencies
+don't conflict or overwrite your existing Python installation. If you have
+multiple installed versions of Python, replace `python3` in the following
+instructions with the desired version (e.g., `python3.X`).
+
+```
+python3 -m venv venv
+source venv/bin/activate
+pip install --upgrade pip
+```
+
+Note: To exit the virtual environment, run `deactivate`.
+
+## Clone the Federated Compute repository and install Python requirements
+
+```
+git clone https://github.com/google/federated-compute.git
+cd federated-compute
+pip install -r requirements.txt
+```
+
+## Build and run the federated program test!
+
+> ⚠️ Many Federated Compute targets depend on TensorFlow, which can take several
+> hours to build for the first time. Consider running builds in `screen` or
+> `tmux` if you're worried about your terminal closing during this time.
+>
+> While not required, Bazel's
+> [remote build execution](https://bazel.build/remote/rbe) and
+> [remote caching](https://bazel.build/remote/caching) features can speed up
+> builds.
+
+```
+bazelisk test //fcp/demo:federated_program_test
+```
+
+### Building with Clang
+
+Use `--config=clang` to build with clang and libc++. On Debian, this requires
+installing several additional packages:
+
+```
+sudo apt install -y clang lld libc++-dev libc++abi-dev`
+```
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/METADATA b/METADATA
new file mode 100644
index 0000000..50996e9
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,12 @@
+name: "federated-compute"
+description:
+ "This repository contains client code for executing federated computations."
+third_party {
+ url {
+ type: GIT
+ value: "https://github.com/google/federated-compute"
+ }
+ version: "30c3b60781f6d89cf5c401b2c392408ddcd04eab"
+ last_upgrade_date { year: 2023 month: 5 day: 4 }
+ license_type: NOTICE
+}
diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/MODULE_LICENSE_APACHE2
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..f73eb47
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1,3 @@
+qiaoli@google.com
+tarading@google.com
+ymu@google.com
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..44e0432
--- /dev/null
+++ b/README.md
@@ -0,0 +1,64 @@
+# Federated Compute Platform
+
+This repository hosts code for executing federated programs and computations.
+
+## Definitions
+
+A *federated computation* is a set of processing steps that run on a server and
+set of clients, where each step is either
+
+* local processing of values on the client or server or
+* transport which moves values via
+ * broadcast (server-to-clients)
+ * select (client-server response)
+ * aggregate (clients-to-server). A federated computation invoked by a
+ central coordinator returns one or more aggregates.
+
+A *federated program* is a set of processing steps run by a central coordinator
+that include one or more invocations of federated computation, transformations
+of computation results, and releases of aggregate values to the engineer/analyst
+who invoked the program.
+
+To learn more about these concepts, check out:
+
+- [Federated learning comic book from Google AI](http://g.co/federated)
+- [Federated Learning: Collaborative Machine Learning without Centralized
+ Training
+ Data](https://ai.googleblog.com/2017/04/federated-learning-collaborative.html)
+- [Federated Analytics: Collaborative Data Science without Data Collection](https://ai.googleblog.com/2020/05/federated-analytics-collaborative-data.html)
+- [Towards Federated Learning at Scale: System Design](https://arxiv.org/abs/1902.01046)
+ (SysML 2019)
+- [Federated Program API in TFF](https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/program/README.md)
+
+## Infrastructure
+
+At Google, federated programs and computations are authored in
+[TensorFlow Federated](http://tensorflow.org/federated), compiled to deployable
+artifacts, and run in a distributed system consisting of a central coordinator,
+and a set of devices such as phones. The TFF repository contains infrastructure
+for authoring and simulating federated programs and computations.
+
+This repository hosts infrastructure for compiling and running federated
+programs and computations in the cross-device setting. We are actively working
+on open sourcing the core components of our production infrastructure, with a
+focus on privacy-sensitive code-paths such as the pipeline for compiling
+deployable artifacts from TFF computations, client-side processing, and
+server-side aggregation logic.
+
+As of 12/7/2022, parts of the repository - in particular, code in the `client/`
+directory, and the service & data format definitions in `proto/` - are used in
+production in Google's federated learning infrastructure. Other parts - notably,
+production server side infrastructure - have not yet been open sourced due to
+its dependencies on proprietary infrastructure, and we instead provide a
+reference / example server implementation in `demo/` for demonstration purposes.
+
+The best way to get started is to run the end-to-end demo
+`//fcp/demo:federated_program_test`, which will spin up example services,
+clients, and run a federated program; this test will cover the majority of the
+code in this repository.
+
+This is not an officially supported Google product.
+
+## Getting Started
+
+Please refer to the instructions in GETTING_STARTED.md.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..b837760
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,287 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file uses repository rules to fetch external dependencies.
+#
+# We require that all uses of repository rules are effectively deterministic -
+# they may fail, but otherwise must always produce exactly the same output for
+# given parameters. For the http_archive rule, this means that you must specify
+# a sha256 parameter; consequently, make sure to use stable URLs, rather than
+# e.g. one that names a git branch.
+#
+# This is so that all(*) inputs to the build are fully determined (i.e. a
+# change to build inputs requires a change to our sources), which avoids
+# confusing outcomes from caching. If it is ever productive to clear your Bazel
+# cache, that's a bug.
+#
+# (*) A Bazel build depends on local compiler toolchains (and Bazel itself), so
+# it can be useful to pick a particular container image too (like some version
+# of http://l.gcr.io/bazel).
+#
+# The repository namespace
+# ------------------------
+#
+# Bazel's handling of @repository_names// is very broken. There is a single,
+# global namespace for repositories. Conflicts are silently ignored. This is
+# problematic for common dependencies. As much as possible, we use TensorFlow's
+# dependencies. Things become especially difficult if we try to use our own
+# version of grpc or protobuf. Current overrides:
+#
+# - @com_github_gflags_gflags: tf uses 2.2.1 and we need 2.2.2 apparently.
+# - @com_google_googletest: Need to patch in support for absl flags
+
+workspace(name = "com_google_fcp")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+# Needed for user-defined configs
+http_archive(
+ name = "bazel_skylib",
+ sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz",
+ ],
+)
+
+load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
+
+bazel_skylib_workspace()
+
+# GoogleTest/GoogleMock framework. Used by most unit-tests.
+http_archive(
+ name = "com_google_googletest",
+ patches = ["//fcp/patches:googletest.patch"],
+ sha256 = "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2",
+ strip_prefix = "googletest-release-1.12.1",
+ urls = ["https://github.com/google/googletest/archive/refs/tags/release-1.12.1.tar.gz"],
+)
+
+http_archive(
+ name = "com_github_gflags_gflags",
+ sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
+ strip_prefix = "gflags-2.2.2",
+ urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"],
+)
+
+http_archive(
+ name = "com_github_google_glog",
+ sha256 = "21bc744fb7f2fa701ee8db339ded7dce4f975d0d55837a97be7d46e8382dea5a",
+ strip_prefix = "glog-0.5.0",
+ urls = ["https://github.com/google/glog/archive/v0.5.0.zip"],
+)
+
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "76900ab068da86378395a8e125b5cc43dfae671e09ff6462ddfef18676e2165a",
+ strip_prefix = "grpc-1.50.0",
+ urls = ["https://github.com/grpc/grpc/archive/refs/tags/v1.50.0.tar.gz"],
+)
+
+# TensorFlow 2.11.0 pins an old version of upb that's compatible with their old
+# version of gRPC, but not with the newer version we use. Pin the version that
+# would be added by gRPC 1.50.0.
+http_archive(
+ name = "upb",
+ sha256 = "017a7e8e4e842d01dba5dc8aa316323eee080cd1b75986a7d1f94d87220e6502",
+ strip_prefix = "upb-e4635f223e7d36dfbea3b722a4ca4807a7e882e2",
+ urls = [
+ "https://storage.googleapis.com/grpc-bazel-mirror/github.com/protocolbuffers/upb/archive/e4635f223e7d36dfbea3b722a4ca4807a7e882e2.tar.gz",
+ "https://github.com/protocolbuffers/upb/archive/e4635f223e7d36dfbea3b722a4ca4807a7e882e2.tar.gz",
+ ],
+)
+
+# The version provided by TensorFlow 2.11 doesn't support equality checks for
+# absl::Status.
+http_archive(
+ name = "pybind11_abseil",
+ sha256 = "6481888831cd548858c09371ea892329b36c8d4d961f559876c64e009d0bc630",
+ strip_prefix = "pybind11_abseil-3922b3861a2b27d4111e3ac971e6697ea030a36e",
+ url = "https://github.com/pybind/pybind11_abseil/archive/3922b3861a2b27d4111e3ac971e6697ea030a36e.tar.gz",
+)
+
+http_archive(
+ name = "pybind11_protobuf",
+ sha256 = "fe2b8bf12a65997b853709a5e719f7561b2e86a4cdbb9d8b051e654dd0fd8d11",
+ strip_prefix = "pybind11_protobuf-a50899c2eb604fc5f25deeb8901eff6231b8b3c0",
+ url = "https://github.com/pybind/pybind11_protobuf/archive/a50899c2eb604fc5f25deeb8901eff6231b8b3c0.tar.gz",
+)
+
+# Define the @io_grpc_grpc_java repository, which is used by the
+# @com_google_googleapis repository to define the Java protobuf targets such as
+# @com_google_googleapis//google/rpc:rpc_java_proto). The pattern we use here is
+# the same as @com_google_googleapis' WORKSPACE file uses.
+#
+# Note that the @com_google_googleapis repository is actually defined
+# transitively by the @org_tensorflow workspace rules.
+http_archive(
+ name = "com_google_api_gax_java",
+ sha256 = "7c172c20dc52c09f42b3077a5195dc5fbcb30e023831918593d1b81a1aea650e",
+ strip_prefix = "gax-java-2.20.0",
+ urls = ["https://github.com/googleapis/gax-java/archive/v2.20.0.zip"],
+)
+
+load("@com_google_api_gax_java//:repository_rules.bzl", "com_google_api_gax_java_properties")
+
+com_google_api_gax_java_properties(
+ name = "com_google_api_gax_java_properties",
+ file = "@com_google_api_gax_java//:dependencies.properties",
+)
+
+load("@com_google_api_gax_java//:repositories.bzl", "com_google_api_gax_java_repositories")
+
+com_google_api_gax_java_repositories()
+
+# Tensorflow v2.12.0
+http_archive(
+ name = "org_tensorflow",
+ patch_tool = "patch",
+ patches = [
+ # This patch enables googleapi Java and Python proto rules such as
+ # @com_google_googleapis//google/rpc:rpc_java_proto.
+ "//fcp/patches:tensorflow_googleapis_proto_rules.patch",
+ # This patch works around failures in GitHub infrastructure to
+ # download versions of LLVM pointed to by non-HEAD TensorFlow.
+ # TODO(team): Remove this patch when resolved.
+ "//fcp/patches:tensorflow_llvm_url.patch",
+ # TensorFlow's custom pybind11 BUILD file is missing the osx config
+ # setting expected by pybind11_bazel.
+ "//fcp/patches:tensorflow_pybind11_osx.patch",
+ # This patch removes tf_custom_op_py_library's dependency on the Bazel
+ # version of TensorFlow since for all of our Python code, we rely on a
+ # system-provided TensorFlow.
+ "//fcp/patches:tensorflow_tf_custom_op_py_library.patch",
+ # gRPC v1.48.0-pre1 and later include zconf.h in addition to zlib.h;
+ # TensorFlow's build rule for zlib only exports the latter.
+ "//fcp/patches:tensorflow_zlib.patch",
+ ],
+ sha256 = "c030cb1905bff1d2446615992aad8d8d85cbe90c4fb625cee458c63bf466bc8e",
+ strip_prefix = "tensorflow-2.12.0",
+ urls = [
+ "https://github.com/tensorflow/tensorflow/archive/v2.12.0.tar.gz",
+ ],
+)
+
+# The following is copied from TensorFlow's own WORKSPACE, see
+# https://github.com/tensorflow/tensorflow/blob/v2.8.0/WORKSPACE#L9
+
+load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
+
+tf_workspace3()
+
+load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
+
+tf_workspace2()
+
+load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
+
+tf_workspace1()
+
+load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
+
+tf_workspace0()
+
+load("//fcp/tensorflow/system_provided_tf:system_provided_tf.bzl", "system_provided_tf")
+
+system_provided_tf(
+ name = "system_provided_tf",
+)
+
+http_archive(
+ name = "com_google_benchmark",
+ sha256 = "23082937d1663a53b90cb5b61df4bcc312f6dee7018da78ba00dd6bd669dfef2",
+ strip_prefix = "benchmark-1.5.1",
+ urls = [
+ "https://github.com/google/benchmark/archive/v1.5.1.tar.gz",
+ ],
+)
+
+# Cpp ProtoDataStore
+http_archive(
+ name = "protodatastore_cpp",
+ sha256 = "d2627231fce0c9944100812b42d33203f7e03e78215d10123e78041b491005c3",
+ strip_prefix = "protodatastore-cpp-0cd8b124bc65bcac105bce4a706923218ae5d625",
+ url = "https://github.com/google/protodatastore-cpp/archive/0cd8b124bc65bcac105bce4a706923218ae5d625.zip",
+)
+
+# libcurl is used in the http client.
+http_archive(
+ name = "libcurl",
+ build_file = "//third_party:curl.BUILD.bzl",
+ sha256 = "e9d74b8586e0d2e6b45dc948bbe77525a1fa7f7c004ad5192f12e72c365e376e",
+ strip_prefix = "curl-curl-7_84_0",
+ url = "https://github.com/curl/curl/archive/refs/tags/curl-7_84_0.zip",
+)
+
+# We use only the http test server.
+http_archive(
+ name = "tensorflow_serving",
+ patch_tool = "patch",
+ patches = [
+ "//fcp/patches:tensorflow_serving.patch",
+ ],
+ sha256 = "6b428100be7ec4bb34fc7910b7f549b6556854016a3c4cbda5613b5c114797b3",
+ strip_prefix = "serving-2.9.0",
+ url = "https://github.com/tensorflow/serving/archive/refs/tags/2.9.0.zip",
+)
+
+load("@tensorflow_serving//tensorflow_serving:workspace.bzl", "tf_serving_workspace")
+
+tf_serving_workspace()
+
+# Java Maven-based repositories.
+http_archive(
+ name = "rules_jvm_external",
+ sha256 = "cd1a77b7b02e8e008439ca76fd34f5b07aecb8c752961f9640dea15e9e5ba1ca",
+ strip_prefix = "rules_jvm_external-4.2",
+ url = "https://github.com/bazelbuild/rules_jvm_external/archive/4.2.zip",
+)
+
+load("@rules_jvm_external//:repositories.bzl", "rules_jvm_external_deps")
+
+rules_jvm_external_deps()
+
+load("@rules_jvm_external//:setup.bzl", "rules_jvm_external_setup")
+
+rules_jvm_external_setup()
+
+load("@rules_jvm_external//:defs.bzl", "maven_install")
+
+maven_install(
+ name = "fcp_maven",
+ artifacts = [
+ "com.google.code.findbugs:jsr305:3.0.2",
+ "com.google.errorprone:error_prone_annotations:2.11.0",
+ "com.google.guava:guava:31.0.1-jre",
+ "com.google.truth:truth:1.1.3",
+ "junit:junit:4.13",
+ "org.mockito:mockito-core:4.3.1",
+ ],
+ repositories = [
+ "https://maven.google.com",
+ "https://repo1.maven.org/maven2",
+ ],
+)
+
+# The version of googleapis imported by TensorFlow doesn't provide
+# `py_proto_library` targets for //google/longrunning.
+http_archive(
+ name = "googleapis_for_longrunning",
+ patches = [
+ "//fcp/patches:googleapis_longrunning.patch",
+ ],
+ sha256 = "c1db0b022cdfc5b5ce5f05b0f00568e2d927c9890429ec9c35bda12f52d93065",
+ strip_prefix = "googleapis-2d8030c4102f97bc6be4ddab74c7cbfe88d8c016",
+ url = "https://github.com/googleapis/googleapis/archive/2d8030c4102f97bc6be4ddab74c7cbfe88d8c016.tar.gz",
+)
diff --git a/fcp/BUILD b/fcp/BUILD
new file mode 100644
index 0000000..99f975e
--- /dev/null
+++ b/fcp/BUILD
@@ -0,0 +1,21 @@
+# Description
+# Federated Computation Platform (FCP) is a computational framework,
+# which allows to orchestrate decentralized computations across
+# large set of dynamically available nodes (e.g. mobile devices).
+
+package(
+ default_visibility = [":internal"],
+)
+
+exports_files([
+ "LICENSE",
+])
+
+package_group(
+ name = "internal",
+ includes = [
+ ],
+ packages = [
+ "//fcp/...",
+ ],
+)
diff --git a/fcp/TEST_MAPPING b/fcp/TEST_MAPPING
new file mode 100644
index 0000000..3040c8d
--- /dev/null
+++ b/fcp/TEST_MAPPING
@@ -0,0 +1,7 @@
+{
+ "presubmit": [
+ {
+ "name": "fl_runner_test"
+ }
+ ]
+} \ No newline at end of file
diff --git a/fcp/aggregation/BUILD b/fcp/aggregation/BUILD
new file mode 100644
index 0000000..3fe68cc
--- /dev/null
+++ b/fcp/aggregation/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [":internal"],
+)
+
+package_group(
+ name = "internal",
+ includes = [
+ ],
+ packages = [
+ "//fcp/aggregation/...",
+ "//fcp/demo/...",
+ ],
+)
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD
new file mode 100644
index 0000000..da85cc2
--- /dev/null
+++ b/fcp/aggregation/core/BUILD
@@ -0,0 +1,278 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Aggregation Core API
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# TODO(team): Create a "core" library that bundles all core libraries together.
+
+TENSOR_SRCS = [
+ "tensor.cc",
+ "tensor_data.cc",
+ "tensor_shape.cc",
+ "input_tensor_list.cc",
+]
+
+TENSOR_HDRS = [
+ "agg_vector.h",
+ "agg_vector_iterator.h",
+ "datatype.h",
+ "tensor.h",
+ "tensor_data.h",
+ "tensor_shape.h",
+ "tensor_spec.h",
+ "mutable_vector_data.h",
+ "input_tensor_list.h",
+]
+
+proto_library(
+ name = "tensor_proto",
+ srcs = ["tensor.proto"],
+)
+
+cc_proto_library(
+ name = "tensor_cc_proto",
+ deps = [":tensor_proto"],
+)
+
+cc_library(
+ name = "tensor",
+ srcs = TENSOR_SRCS,
+ hdrs = TENSOR_HDRS,
+ copts = FCP_COPTS,
+ deps = [
+ ":tensor_cc_proto",
+ "//fcp/base",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+AGGREGATOR_SRCS = [
+ "tensor_aggregator.cc",
+ "tensor_aggregator_registry.cc",
+]
+
+AGGREGATOR_HDRS = [
+ "agg_vector_aggregator.h",
+ "one_dim_grouping_aggregator.h",
+ "aggregator.h",
+ "tensor_aggregator.h",
+ "tensor_aggregator_factory.h",
+ "tensor_aggregator_registry.h",
+]
+
+cc_library(
+ name = "aggregator",
+ srcs = AGGREGATOR_SRCS,
+ hdrs = AGGREGATOR_HDRS,
+ copts = FCP_COPTS,
+ deps = [
+ ":tensor",
+ "//fcp/base",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "tensor_test",
+ srcs = [
+ "tensor_data_test.cc",
+ "tensor_shape_test.cc",
+ "tensor_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":tensor",
+ ":tensor_cc_proto",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "agg_vector_test",
+ srcs = ["agg_vector_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":tensor",
+ "//fcp/aggregation/testing:test_data",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "aggregator_test",
+ srcs = [
+ "agg_vector_aggregator_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":aggregator",
+ ":tensor",
+ ":tensor_cc_proto",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "tensor_aggregator_registry_test",
+ srcs = [
+ "tensor_aggregator_registry_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":aggregator",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "federated_sum",
+ srcs = [
+ "federated_sum.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":aggregator",
+ ":tensor",
+ "//fcp/base",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "composite_key_combiner",
+ srcs = ["composite_key_combiner.cc"],
+ hdrs = ["composite_key_combiner.h"],
+ deps = [
+ ":tensor",
+ ":tensor_cc_proto",
+ ":vector_string_data",
+ "//fcp/base",
+ ],
+)
+
+# Separate target from :tensor is required for vector_string_data as string_view is not yet
+# supported for nanolibc.
+cc_library(
+ name = "vector_string_data",
+ hdrs = ["vector_string_data.h"],
+ deps = [":tensor"],
+)
+
+cc_test(
+ name = "federated_sum_test",
+ srcs = ["federated_sum_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":aggregator",
+ ":federated_sum",
+ ":tensor",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "input_tensor_list_test",
+ srcs = ["input_tensor_list_test.cc"],
+ deps = [
+ ":tensor",
+ "//fcp/aggregation/testing:test_data",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "composite_key_combiner_test",
+ srcs = ["composite_key_combiner_test.cc"],
+ deps = [
+ ":composite_key_combiner",
+ ":tensor",
+ ":tensor_cc_proto",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "mutable_vector_data_test",
+ srcs = ["mutable_vector_data_test.cc"],
+ deps = [
+ ":tensor",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "vector_string_data_test",
+ srcs = ["vector_string_data_test.cc"],
+ deps = [
+ ":tensor",
+ ":vector_string_data",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "one_dim_grouping_aggregator_test",
+ srcs = ["one_dim_grouping_aggregator_test.cc"],
+ deps = [
+ ":aggregator",
+ ":tensor",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_binary(
+ name = "federated_sum_bench",
+ testonly = 1,
+ srcs = ["federated_sum_bench.cc"],
+ copts = FCP_COPTS,
+ linkstatic = 1,
+ deps = [
+ ":aggregator",
+ ":federated_sum",
+ ":tensor",
+ "@com_google_benchmark//:benchmark_main",
+ ],
+)
diff --git a/fcp/aggregation/core/agg_vector.h b/fcp/aggregation/core/agg_vector.h
new file mode 100644
index 0000000..60cb24a
--- /dev/null
+++ b/fcp/aggregation/core/agg_vector.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_H_
+#define FCP_AGGREGATION_CORE_AGG_VECTOR_H_
+
+#include <cstddef>
+#include <memory>
+#include <utility>
+
+#include "fcp/aggregation/core/agg_vector_iterator.h"
+#include "fcp/aggregation/core/tensor_data.h"
+
+namespace fcp {
+namespace aggregation {
+
+// AggVector is flattened one-dimensional strongly typed view of tensor that
+// provides immutable access to the values.
+//
+// AggVector hides the actual data organization of the tensor. The only
+// way to access the tensor values is through the iterator that returns
+// {index, value} pairs where each index is the dense index corresponding to
+// the value.
+//
+// Example:
+//
+// template <typename T>
+// void Iterate(const AggVector<T>& agg_vector) {
+// for (const auto& [index, value] : agg_vector) {
+// // Aggregate the `value` at the given `index`.
+// }
+// }
+//
+template <typename T>
+class AggVector final {
+ public:
+ using value_type = typename AggVectorIterator<T>::value_type;
+ using const_iterator = AggVectorIterator<T>;
+
+ // Iterator begin() function.
+ const_iterator begin() const { return AggVectorIterator<T>(data_); }
+
+ // Iterator end() function.
+ const_iterator end() const { return AggVectorIterator<T>::end(); }
+
+ // Entire AggVector length.
+ size_t size() const { return size_; }
+
+ private:
+ // AggVector can be created only by Tensor::AsAggVector() method.
+ friend class Tensor;
+ explicit AggVector(const TensorData* data)
+ : size_(data->byte_size() / sizeof(T)), data_(data) {}
+
+ // The total length of the vector (in elements).
+ size_t size_;
+ // Tensor data, owned by the tensor object.
+ const TensorData* data_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_AGG_VECTOR_H_
diff --git a/fcp/aggregation/core/agg_vector_aggregator.h b/fcp/aggregation/core/agg_vector_aggregator.h
new file mode 100644
index 0000000..992bb0f
--- /dev/null
+++ b/fcp/aggregation/core/agg_vector_aggregator.h
@@ -0,0 +1,150 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
+#define FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_data.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// AggVectorAggregator class is a specialization of TensorAggregator which
+// operates on AggVector<T> instances rather than tensors.
+template <typename T>
+class AggVectorAggregator : public TensorAggregator {
+ public:
+ AggVectorAggregator(DataType dtype, TensorShape shape)
+ : AggVectorAggregator(dtype, shape,
+ new MutableVectorData<T>(shape.NumElements())) {}
+
+ // Provides mutable access to the aggregator data as a vector<T>
+ inline std::vector<T>& data() { return data_vector_; }
+
+ int GetNumInputs() const override { return num_inputs_; }
+
+ Status MergeWith(TensorAggregator&& other) override {
+ FCP_RETURN_IF_ERROR(CheckValid());
+ FCP_ASSIGN_OR_RETURN(AggVectorAggregator<T> * other_ptr, CastOther(other));
+ FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
+ int64_t other_num_inputs = other.GetNumInputs();
+ OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs();
+ FCP_CHECK(output_tensors.size() == 1)
+ << "AggVectorAggregator::MergeOutputTensors: AggVectorAggregator "
+ "should produce a single output tensor";
+ const Tensor& output = output_tensors[0];
+ if (output.shape() != result_tensor_.shape()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "AggVectorAggregator::MergeOutputTensors: tensor shape "
+ "mismatch";
+ }
+ // Delegate the actual aggregation to the specific aggregation
+ // intrinsic implementation.
+ AggregateVector(output.AsAggVector<T>());
+ num_inputs_ += other_num_inputs;
+ return FCP_STATUS(OK);
+ }
+
+ protected:
+ // Implementation of the tensor aggregation.
+ Status AggregateTensors(InputTensorList tensors) override {
+ FCP_CHECK(tensors.size() == 1)
+ << "AggVectorAggregator should operate on a single input tensor";
+
+ const Tensor* tensor = tensors[0];
+ if (tensor->dtype() != internal::TypeTraits<T>::kDataType) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "AggVectorAggregator::AggregateTensors: dtype mismatch";
+ }
+ if (tensor->shape() != result_tensor_.shape()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "AggVectorAggregator::AggregateTensors: tensor shape mismatch";
+ }
+ // Delegate the actual aggregation to the specific aggregation
+ // intrinsic implementation.
+ AggregateVector(tensor->AsAggVector<T>());
+ num_inputs_++;
+ return FCP_STATUS(OK);
+ }
+
+ Status CheckValid() const override { return result_tensor_.CheckValid(); }
+
+ OutputTensorList TakeOutputs() && override {
+ OutputTensorList outputs = std::vector<Tensor>();
+ outputs.push_back(std::move(result_tensor_));
+ return outputs;
+ }
+
+ // Delegates AggVector aggregation to a derived class.
+ virtual void AggregateVector(const AggVector<T>& agg_vector) = 0;
+
+ private:
+ AggVectorAggregator(DataType dtype, TensorShape shape,
+ MutableVectorData<T>* data)
+ : result_tensor_(
+ Tensor::Create(dtype, shape, std::unique_ptr<TensorData>(data))
+ .value()),
+ data_vector_(*data),
+ num_inputs_(0) {
+ FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
+ << "Incompatible dtype";
+ }
+
+ StatusOr<AggVectorAggregator<T>*> CastOther(TensorAggregator& other) {
+#ifndef FCP_NANOLIBC
+ AggVectorAggregator<T>* other_ptr =
+ dynamic_cast<AggVectorAggregator<T>*>(&other);
+ if (other_ptr == nullptr) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "AggVectorAggregator::MergeOutputTensors: Can only merge with"
+ << "another AggVectorAggregator operating on the same dtype "
+ << internal::TypeTraits<T>::kDataType;
+ }
+ return other_ptr;
+#else /* FCP_NANOLIBC */
+ // When compiling in nanolibc we do not have access to runtime type
+ // information or std::type_traits. Thus we cannot use dynamic cast and use
+ // static_cast instead.
+ // This means we are relying on the caller to always call the MergeWith
+ // method on two TensorAggregators of the same underlying type, or the
+ // program will have undefined behavior due to a static_cast to the wrong
+ // type.
+ return static_cast<AggVectorAggregator<T>*>(&other);
+#endif
+ }
+
+ Tensor result_tensor_;
+ std::vector<T>& data_vector_;
+ int num_inputs_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
diff --git a/fcp/aggregation/core/agg_vector_aggregator_test.cc b/fcp/aggregation/core/agg_vector_aggregator_test.cc
new file mode 100644
index 0000000..bf68ef5
--- /dev/null
+++ b/fcp/aggregation/core/agg_vector_aggregator_test.cc
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/agg_vector_aggregator.h"
+
+#include <cstdint>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+#ifndef FCP_NANOLIBC
+#include "fcp/aggregation/core/tensor.pb.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Eq;
+using testing::IsFalse;
+using testing::IsTrue;
+
+// A simple Sum Aggregator
+template <typename T>
+class SumAggregator final : public AggVectorAggregator<T> {
+ public:
+ using AggVectorAggregator<T>::AggVectorAggregator;
+ using AggVectorAggregator<T>::data;
+
+ private:
+ void AggregateVector(const AggVector<T>& agg_vector) override {
+ for (auto [i, v] : agg_vector) {
+ data()[i] += v;
+ }
+ }
+};
+
+TEST(AggVectorAggregatorTest, ScalarAggregation_Succeeds) {
+ SumAggregator<int32_t> aggregator(DT_INT32, {});
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
+}
+
+TEST(AggVectorAggregatorTest, DenseAggregation_Succeeds) {
+ const TensorShape shape = {4};
+ SumAggregator<int32_t> aggregator(DT_INT32, shape);
+ Tensor t1 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(AggVectorAggregationTest, Merge_Succeeds) {
+ SumAggregator<int32_t> aggregator1(DT_INT32, {});
+ SumAggregator<int32_t> aggregator2(DT_INT32, {});
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator1.Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator2.Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator2.Accumulate(t3), IsOk());
+
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
+}
+
+TEST(AggVectorAggregationTest, Aggregate_IncompatibleDataType) {
+ SumAggregator<int32_t> aggregator(DT_INT32, {});
+ Tensor t = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();
+ EXPECT_THAT(aggregator.Accumulate(t), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(AggVectorAggregationTest, Aggregate_IncompatibleShape) {
+ SumAggregator<int32_t> aggregator(DT_INT32, {});
+ Tensor t = Tensor::Create(DT_INT32, {2, 1}, CreateTestData({0, 1})).value();
+ EXPECT_THAT(aggregator.Accumulate(t), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(AggVectorAggregationTest, Merge_IncompatibleDataType) {
+ SumAggregator<int32_t> aggregator1(DT_INT32, {});
+ SumAggregator<float> aggregator2(DT_FLOAT, {});
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(AggVectorAggregationTest, Merge_IncompatibleShape) {
+ SumAggregator<int32_t> aggregator1(DT_INT32, {3, 5});
+ SumAggregator<int32_t> aggregator2(DT_INT32, {5, 3});
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(AggVectorAggregationTest, FailsAfterBeingConsumed) {
+ SumAggregator<int32_t> aggregator(DT_INT32, {});
+ EXPECT_THAT(std::move(aggregator).Report(), IsOk());
+
+ // Now the aggregator instance has been consumed and should fail any
+ // further operations.
+ EXPECT_THAT(aggregator.CanReport(), IsFalse()); // NOLINT
+ EXPECT_THAT(std::move(aggregator).Report(),
+ IsCode(FAILED_PRECONDITION)); // NOLINT
+ EXPECT_THAT(aggregator.Accumulate( // NOLINT
+ Tensor::Create(DT_INT32, {}, CreateTestData({0})).value()),
+ IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(
+ aggregator.MergeWith(SumAggregator<int32_t>(DT_INT32, {})), // NOLINT
+ IsCode(FAILED_PRECONDITION));
+
+ // Passing this aggregator as an argument to another MergeWith must fail too.
+ SumAggregator<int32_t> aggregator2(DT_INT32, {});
+ EXPECT_THAT(aggregator2.MergeWith(std::move(aggregator)), // NOLINT
+ IsCode(FAILED_PRECONDITION));
+}
+
+TEST(AggVectorAggregatorTest, TypeCheckFailure) {
+ EXPECT_DEATH(new SumAggregator<float>(DT_INT32, {}), "Incompatible dtype");
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/agg_vector_iterator.h b/fcp/aggregation/core/agg_vector_iterator.h
new file mode 100644
index 0000000..960778f
--- /dev/null
+++ b/fcp/aggregation/core/agg_vector_iterator.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
+#define FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
+
+#include "fcp/aggregation/core/tensor_data.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Iterator for AggVector which allows to iterate over sparse values
+// as a collection of {index, value} pairs.
+//
+// This allows a simple iteration loops like the following:
+// for (auto [index, value] : agg_vector) {
+// ... aggregate the value at the given dense index
+// }
+template <typename T>
+struct AggVectorIterator {
+ struct IndexValuePair {
+ size_t index;
+ T value;
+
+ friend bool operator==(const IndexValuePair& a, const IndexValuePair& b) {
+ return a.index == b.index && a.value == b.value;
+ }
+
+ friend bool operator!=(const IndexValuePair& a, const IndexValuePair& b) {
+ return a.index != b.index || a.value != b.value;
+ }
+ };
+
+ using value_type = IndexValuePair;
+ using pointer = value_type*;
+ using reference = value_type&;
+
+ explicit AggVectorIterator(const TensorData* data)
+ : AggVectorIterator(get_start_ptr(data), get_end_ptr(data), 0) {}
+
+ // Current dense index corresponding to the current value.
+ size_t index() const { return dense_index; }
+ // Current value.
+ T value() const { return *ptr; }
+ // The current interator {index, value} pair value. This is used by
+ // for loop iterators.
+ IndexValuePair operator*() const { return {dense_index, *ptr}; }
+
+ AggVectorIterator& operator++() {
+ FCP_CHECK(ptr != end_ptr);
+ if (++ptr == end_ptr) {
+ *this = end();
+ } else {
+ dense_index++;
+ }
+ return *this;
+ }
+
+ AggVectorIterator operator++(int) {
+ AggVectorIterator tmp = *this;
+ ++(*this);
+ return tmp;
+ }
+
+ friend bool operator==(const AggVectorIterator& a,
+ const AggVectorIterator& b) {
+ return a.ptr == b.ptr;
+ }
+
+ friend bool operator!=(const AggVectorIterator& a,
+ const AggVectorIterator& b) {
+ return a.ptr != b.ptr;
+ }
+
+ static AggVectorIterator end() {
+ return AggVectorIterator(nullptr, nullptr, 0);
+ }
+
+ private:
+ AggVectorIterator(const T* ptr, const T* end_ptr, size_t dense_index)
+ : ptr(ptr), end_ptr(end_ptr), dense_index(dense_index) {}
+
+ static const T* get_start_ptr(const TensorData* data) {
+ return static_cast<const T*>(data->data());
+ }
+
+ static const T* get_end_ptr(const TensorData* data) {
+ return get_start_ptr(data) + data->byte_size() / sizeof(T);
+ }
+
+ const T* ptr;
+ const T* end_ptr;
+ size_t dense_index;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
diff --git a/fcp/aggregation/core/agg_vector_test.cc b/fcp/aggregation/core/agg_vector_test.cc
new file mode 100644
index 0000000..6bdb0c6
--- /dev/null
+++ b/fcp/aggregation/core/agg_vector_test.cc
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/agg_vector.h"
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/testing/test_data.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+
+template <typename T>
+using Pair = typename AggVectorIterator<T>::IndexValuePair;
+
+TEST(AggVectorTest, Size) {
+ auto t1 = Tensor::Create(DT_INT32, {}, CreateTestData({0}));
+ EXPECT_EQ(t1->AsAggVector<int>().size(), 1);
+
+ auto t2 = Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({0, 1, 2}));
+ EXPECT_EQ(t2->AsAggVector<float>().size(), 3);
+}
+
+TEST(AggVectorTest, PostIncrementIterator_ScalarTensor) {
+ auto t = Tensor::Create(DT_INT32, {}, CreateTestData({5}));
+ EXPECT_THAT(t->AsAggVector<int>(), ElementsAre(Pair<int>{0, 5}));
+}
+
+TEST(AggVectorTest, PostIncrementIterator_DenseTensor) {
+ auto t = Tensor::Create(DT_INT32, {2}, CreateTestData({3, 14}));
+ EXPECT_THAT(t->AsAggVector<int>(),
+ ElementsAre(Pair<int>{0, 3}, Pair<int>{1, 14}));
+}
+
+TEST(AggVectorTest, PostIncrementIterator_ForLoopIterator) {
+ auto t = Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({2, 3, 4, 5}));
+ float sum = 0;
+ size_t expected_index = 0;
+ for (auto [index, value] : t->AsAggVector<float>()) {
+ EXPECT_THAT(index, Eq(expected_index++));
+ sum += value;
+ }
+ EXPECT_THAT(sum, Eq(14));
+}
+
+TEST(AggVectorTest, PreIncrementIterator) {
+ auto t = Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({2, 3, 4, 5}));
+ auto agg_vector = t->AsAggVector<float>();
+ float sum = 0;
+ size_t expected_index = 0;
+ for (auto it = agg_vector.begin(); it != agg_vector.end(); it++) {
+ EXPECT_THAT(it.index(), Eq(expected_index++));
+ sum += it.value();
+ }
+ EXPECT_THAT(sum, Eq(14));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/aggregator.h b/fcp/aggregation/core/aggregator.h
new file mode 100644
index 0000000..3872285
--- /dev/null
+++ b/fcp/aggregation/core/aggregator.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_AGGREGATOR_H_
+#define FCP_AGGREGATION_CORE_AGGREGATOR_H_
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Abstract base for aggregators that compute an aggregate of input items of
+// type T into a final aggregate of type R using a multi-stage process in which
+// items are first partially aggregated at an intermediate layer, then the
+// partial aggregates are further combined, and finally projected into the
+// result. This multi-stage process consists of the following:
+// a) The aggregator is created with a zero value of an arbitrary intermediate
+// type U. Please note that the type U is never surfaced and considered an
+// implementation detail, so it doesn't need to be explicitlty parameterized.
+// b) The method Accumulate is used to accumulate T-typed client items into the
+// U-typed partial aggregate.
+// c) The method Merge is used to merge the intermediate U-typed aggregates of
+// the two aggregator instances producing a merged U-typed aggregate.
+// d) The method Report is used to project the top-level U-typed aggregate into
+// the final R-typed result.
+// The typename Self is used to specify the actual derived class.
+template <typename T, typename R, typename Self>
+class Aggregator {
+ public:
+ Aggregator() = default;
+ virtual ~Aggregator() = default;
+
+ // Aggregator derived classes are not copyable.
+ Aggregator(const Aggregator&) = delete;
+
+ // Accumulates an input into the intermediate aggregate.
+ // The method may fail if the input isn't compatible with the current
+ // Aggregator or if the Aggregator instance has already been 'consumed'.
+ virtual Status Accumulate(T input) = 0;
+
+ // Merges intermediate aggregates from the other Aggregator instance into the
+ // current Aggregator instance. Doing so 'consumes' the other Aggregator
+ // instance.
+ // The method may fail if the two Aggregator instances aren't compatible.
+ virtual Status MergeWith(Self&& other) = 0;
+
+ // Returns true if the current Aggregator instance can produce a report, for
+ // example if a sufficient number of inputs has been accumulated.
+ virtual bool CanReport() const = 0;
+
+ // Produces the final report, 'consuming' the current Aggregator instance.
+ // Once the current instance is consumed it can no longer perform any
+ // operations.
+ // This method fails when CanReport method returns false.
+ virtual StatusOr<R> Report() && = 0;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_AGGREGATOR_H_
diff --git a/fcp/aggregation/core/composite_key_combiner.cc b/fcp/aggregation/core/composite_key_combiner.cc
new file mode 100644
index 0000000..dc47cf1
--- /dev/null
+++ b/fcp/aggregation/core/composite_key_combiner.cc
@@ -0,0 +1,283 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/composite_key_combiner.h"
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/core/vector_string_data.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+namespace {
+
+template <typename T>
+bool CheckDataTypeSupported() {
+ return sizeof(T) <= sizeof(uint64_t);
+}
+
+template <>
+bool CheckDataTypeSupported<string_view>() {
+ // We will store the representation of a pointer to the string as an integer,
+ // so ensure the size of a pointer is less than or equal to the size of a
+ // 64-bit integer.
+ return sizeof(intptr_t) == sizeof(uint64_t);
+}
+
+// Copies the bytes pointed to by source_ptr to the destination pointed to by
+// dest_ptr and advances source_ptr to the next T.
+//
+// The number of bytes copied will be the size of the type T.
+//
+// It is the responsibility of the caller to ensure that source_ptr is only used
+// in subsequent code if it still points to a valid T after being incremented.
+template <typename T>
+void CopyToDest(const void*& source_ptr, uint64_t* dest_ptr,
+ std::unordered_set<std::string>& intern_pool) {
+ auto typed_source_ptr = static_cast<const T*>(source_ptr);
+ // Cast the pointer to the destination storage to a pointer to type T and set
+ // the value it points to to the source value. This allows us to copy the
+ // number of bytes in T to the destination storage even if T is smaller than
+ // uint64_t, without copying extra bytes.
+ // It would be problematic if T is larger than uint64_t, but the Create method
+ // validated that this was not the case.
+ T* typed_dest_ptr = reinterpret_cast<T*>(dest_ptr);
+ *typed_dest_ptr = *typed_source_ptr;
+ // Set source_ptr to point to the next T assuming that it points to
+ // an array of T.
+ source_ptr = static_cast<const void*>(++typed_source_ptr);
+}
+
+// Specialization of CopyToDest for DT_STRING data type that interns the
+// string_view pointed to by value_ptr. The address of the string in the
+// intern pool is then converted to a 64 bit integer and copied to the
+// destination pointed to by dest_ptr. Finally source_ptr is incremented to the
+// next string_view.
+//
+// It is the responsibility of the caller to ensure that source_ptr is only used
+// in subsequent code if it still points to a valid string_view after being
+// incremented.
+template <>
+void CopyToDest<string_view>(const void*& source_ptr, uint64_t* dest_ptr,
+ std::unordered_set<std::string>& intern_pool) {
+ auto string_view_ptr = static_cast<const string_view*>(source_ptr);
+ // Insert the string into the intern pool if it does not already exist. This
+ // makes a copy of the string so that the intern pool owns the storage.
+ auto it = intern_pool.emplace(*string_view_ptr).first;
+ // The iterator of an unordered set may be invalidated by inserting more
+ // elements, but the pointer to the underlying element is guaranteed to be
+ // stable. https://en.cppreference.com/w/cpp/container/unordered_set
+ // Thus, get the address of the string after dereferencing the iterator.
+ const std::string* interned_string_ptr = &*it;
+ // The stable address of the string can be interpreted as a 64-bit integer.
+ intptr_t ptr_int = reinterpret_cast<intptr_t>(interned_string_ptr);
+ // Set the destination storage to the integer representation of the string
+ // address.
+ *dest_ptr = static_cast<uint64_t>(ptr_int);
+ // Set the source_ptr to point to the next string_view assuming that it points
+ // to an array of string_view.
+ source_ptr = static_cast<const void*>(++string_view_ptr);
+}
+
+// Given a vector of uint64_t pointers, where the data pointed to can be safely
+// interpreted as type T, returns a Tensor of underlying data type
+// corresponding to T and the same length as the input vector. Each element of
+// the tensor is created by interpreting the data pointed to by the uint64_t
+// pointer at that index as type T.
+template <typename T>
+StatusOr<Tensor> GetTensorForType(
+ const std::vector<const uint64_t*>& key_iters) {
+ auto output_tensor_data = std::make_unique<MutableVectorData<T>>();
+ output_tensor_data->reserve(key_iters.size());
+ for (const uint64_t* key_it : key_iters) {
+ const T* ptr = reinterpret_cast<const T*>(key_it);
+ output_tensor_data->push_back(*ptr);
+ }
+ return Tensor::Create(internal::TypeTraits<T>::kDataType,
+ TensorShape{key_iters.size()},
+ std::move(output_tensor_data));
+}
+
+// Specialization of GetTensorForType for DT_STRING data type.
+// Given a vector of char pointers, where the data pointed to can be safely
+// interpreted as a pointer to a string, returns a tensor of type DT_STRING
+// and the same length as the input vector containing these strings.
+// The returned tensor will own all strings it refers to and is thus safe to
+// use after this class is destroyed.
+template <>
+StatusOr<Tensor> GetTensorForType<string_view>(
+ const std::vector<const uint64_t*>& key_iters) {
+ std::vector<std::string> strings_for_output;
+ for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
+ const intptr_t* ptr_to_string_address =
+ reinterpret_cast<const intptr_t*>(*key_it);
+ // The integer stored to represent a string is the address of the string
+ // stored in the intern_pool_. Thus this integer can be safely cast to a
+ // pointer and dereferenced to obtain the string.
+ const std::string* ptr =
+ reinterpret_cast<const std::string*>(*ptr_to_string_address);
+ strings_for_output.push_back(*ptr);
+ }
+ return Tensor::Create(
+ DT_STRING, TensorShape{key_iters.size()},
+ std::make_unique<VectorStringData>(std::move(strings_for_output)));
+}
+
+} // namespace
+
+CompositeKeyCombiner::CompositeKeyCombiner(std::vector<DataType> dtypes)
+ : dtypes_(dtypes) {
+ for (DataType dtype : dtypes) {
+ // Initialize to false to satisfy compiler that all cases in the DTYPE_CASES
+ // switch statement are covered, even though the cases that don't result in
+ // a value for data_type_supported will actually crash the program.
+ bool data_type_supported = false;
+ DTYPE_CASES(dtype, T, data_type_supported = CheckDataTypeSupported<T>());
+ FCP_CHECK(data_type_supported)
+ << "Unsupported data type for CompositeKeyCombiner: " << dtype;
+ }
+}
+
+// Returns a single tensor containing the ordinals of the composite keys
+// formed from the InputTensorList.
+StatusOr<Tensor> CompositeKeyCombiner::Accumulate(
+ const InputTensorList& tensors) {
+ FCP_ASSIGN_OR_RETURN(TensorShape shape, CheckValidAndGetShape(tensors));
+
+ // Determine the serialized size of the composite keys.
+ size_t composite_key_size = sizeof(uint64_t) * tensors.size();
+
+ std::vector<const void*> iterators;
+ iterators.reserve(tensors.size());
+ for (const Tensor* t : tensors) {
+ iterators.push_back(t->data().data());
+ }
+
+ // Iterate over all the TensorDataIterators at once to get the value for the
+ // composite key.
+ auto ordinals = std::make_unique<MutableVectorData<int64_t>>();
+ for (int i = 0; i < shape.NumElements(); ++i) {
+ // Create a string with the correct amount of memory to store an int64
+ // representation of the element in each input tensor at the current
+ // index.
+ std::string composite_key_data(composite_key_size, '\0');
+ uint64_t* key_ptr = reinterpret_cast<uint64_t*>(composite_key_data.data());
+
+ for (int j = 0; j < tensors.size(); ++j) {
+ // Copy the 64-bit representation of the element into the position in the
+ // composite key data corresponding to this tensor.
+ DTYPE_CASES(dtypes_[j], T,
+ CopyToDest<T>(iterators[j], key_ptr++, intern_pool_));
+ }
+ auto [it, inserted] = composite_keys_.insert(
+ {std::move(composite_key_data), composite_key_next_});
+ if (inserted) {
+ // This is the first time this CompositeKeyCombiner has encountered this
+ // particular composite key.
+ composite_key_next_++;
+ // Save the string representation of the key in order to recover the
+ // elements of the key when GetOutputKeys is called.
+ key_vec_.push_back(it->first);
+ }
+ // Insert the ordinal representing the composite key into the
+ // correct position in the output tensor.
+ ordinals->push_back(it->second);
+ }
+ return Tensor::Create(internal::TypeTraits<int64_t>::kDataType, shape,
+ std::move(ordinals));
+}
+
+StatusOr<std::vector<Tensor>> CompositeKeyCombiner::GetOutputKeys() const {
+ std::vector<Tensor> output_keys;
+ // Creating empty tensors is not allowed, so if there are no keys yet,
+ // which could happen if GetOutputKeys is called before Accumulate, return
+ // an empty vector.
+ if (key_vec_.empty()) return output_keys;
+ // Otherwise reserve space for a tensor for each data type.
+ output_keys.reserve(dtypes_.size());
+ std::vector<const uint64_t*> key_iters;
+ key_iters.reserve(key_vec_.size());
+ for (string_view s : key_vec_) {
+ key_iters.push_back(reinterpret_cast<const uint64_t*>(s.data()));
+ }
+ for (DataType dtype : dtypes_) {
+ StatusOr<Tensor> t;
+ DTYPE_CASES(dtype, T, t = GetTensorForType<T>(key_iters));
+ FCP_RETURN_IF_ERROR(t.status());
+ output_keys.push_back(std::move(t.value()));
+ for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
+ ++*key_it;
+ }
+ }
+ return output_keys;
+}
+
+StatusOr<TensorShape> CompositeKeyCombiner::CheckValidAndGetShape(
+ const InputTensorList& tensors) {
+ if (tensors.size() == 0) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "InputTensorList must contain at least one tensor.";
+ } else if (tensors.size() != dtypes_.size()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "InputTensorList size " << tensors.size()
+ << "is not the same as the length of expected dtypes "
+ << dtypes_.size();
+ }
+ // All the tensors in the input list should have the same shape and have
+ // a dense encoding.
+ const TensorShape* shape = nullptr;
+ for (int i = 0; i < tensors.size(); ++i) {
+ const Tensor* t = tensors[i];
+ if (shape == nullptr) {
+ shape = &t->shape();
+ } else {
+ if (*shape != t->shape()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "All tensors in the InputTensorList must have the expected "
+ "shape.";
+ }
+ }
+ if (!t->is_dense())
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "All tensors in the InputTensorList must be dense.";
+ // Ensure the data types of the input tensors match those provided to the
+ // constructor of this CompositeKeyCombiner.
+ DataType expected_dtype = dtypes_[i];
+ if (expected_dtype != t->dtype()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Tensor did not have expected dtype " << expected_dtype
+ << " and instead had dtype " << t->dtype();
+ }
+ }
+ return *shape;
+}
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/composite_key_combiner.h b/fcp/aggregation/core/composite_key_combiner.h
new file mode 100644
index 0000000..fa7b7ca
--- /dev/null
+++ b/fcp/aggregation/core/composite_key_combiner.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
+#define FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
+
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Class operating on sets of tensors of the same shape to combine indices for
+// which the same combination of elements occurs, or in other words, indices
+// containing the same composite key.
+//
+// This class contains two methods: Accumulate and GetOutputKeys, which can each
+// be called multiple times.
+//
+// Accumulate takes in an InputTensorList of tensors of the same shape, and
+// returns a Tensor of the same shape containing ordinals to represent the
+// composite key that exists at each index. Composite keys are stored
+// across calls to Accumulate, so if the same composite key is ever encountered
+// in two different indices, whether in the same or a different call to
+// Accumulate, the same ordinal will be returned in both these indices.
+//
+// GetOutputKeys returns the composite keys that have been seen in all previous
+// calls to Accumulate, represented by a vector of Tensors. If the ordinal
+// returned by Accumulate for that composite key was i, the composite key will
+// be found at position i in the output vector.
+//
+// This class is not threadsafe.
+class CompositeKeyCombiner {
+ public:
+ ~CompositeKeyCombiner() = default;
+
+ // CompositeKeyCombiner is not copyable or moveable.
+ CompositeKeyCombiner(const CompositeKeyCombiner&) = delete;
+ CompositeKeyCombiner& operator=(const CompositeKeyCombiner&) = delete;
+ CompositeKeyCombiner(CompositeKeyCombiner&&) = delete;
+ CompositeKeyCombiner& operator=(CompositeKeyCombiner&&) = delete;
+
+ // Creates a CompositeKeyCombiner if inputs are valid or crashes otherwise.
+ explicit CompositeKeyCombiner(std::vector<DataType> dtypes);
+
+ // Returns a single tensor containing the ordinals of the composite keys
+ // formed from the tensors in the InputTensorList.
+ //
+ // The shape of each of the input tensors must match the shape provided to the
+ // constructor, and the dtypes of the input tensors must match the dtypes
+ // provided to the constructor.
+ //
+ // For each index in the input tensors, the combination of elements from each
+ // tensor at that index forms a "composite key." Across calls to Accumulate,
+ // each unique composite key will be represented by a unique ordinal.
+ //
+ // The returned tensor is of data type DT_INT64 and the same shape that was
+ // provided to the constructor.
+ StatusOr<Tensor> Accumulate(const InputTensorList& tensors);
+
+ // Obtains the vector of output keys ordered by their representative ordinal.
+ //
+ // The datatypes of the tensors in the output vector will match the data types
+ // provided to the constructor.
+ //
+ // For each unique combination of elements that was seen across all calls to
+ // Accumulate on this class so far, the vector of output tensors will include
+ // that combination of elements. The ordering of the elements within the
+ // output tensors will correspond to the ordinals returned by Accumulate. For
+ // example, if Accumulate returned the integer 5 in the output tensor at
+ // position 8 when it encountered this combination of elements in the input
+ // tensor list at position 8, then the elements in the composite key will
+ // appear at position 5 in the output tensors returned by this method.
+ StatusOr<std::vector<Tensor>> GetOutputKeys() const;
+
+ private:
+ // Checks that the provided InputTensorList can be accumulated into this
+ // CompositeKeyCombiner.
+ StatusOr<TensorShape> CheckValidAndGetShape(const InputTensorList& tensors);
+
+ // The data types of the tensors in valid inputs to Accumulate, in this exact
+ // order.
+ // TODO(team): Use inlined vector to store the DataTypes instead.
+ std::vector<DataType> dtypes_;
+ // String views of the composite keys in the order the keys will appear in the
+ // output tensors returned by GetOutputKeys.
+ std::vector<string_view> key_vec_;
+ // Set of unique strings encountered in tensors of type DT_STRING on calls to
+ // Accumulate.
+ // Used as an optimization to avoid storing the same string multiple
+ // times even if it appears in many composite keys.
+ // TODO(team): Intern directly into the output tensor instead to avoid
+ // copies when creating the output tensors.
+ std::unordered_set<std::string> intern_pool_;
+ // Mapping of string representations of the composite keys seen so far to
+ // their ordinal position in the output tensors returned by GetOutputKeys.
+ std::unordered_map<std::string, int64_t> composite_keys_;
+ // Number of unique composite keys encountered so far across all calls to
+ // Accumulate.
+ int64_t composite_key_next_ = 0;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
diff --git a/fcp/aggregation/core/composite_key_combiner_test.cc b/fcp/aggregation/core/composite_key_combiner_test.cc
new file mode 100644
index 0000000..ccaf2c6
--- /dev/null
+++ b/fcp/aggregation/core/composite_key_combiner_test.cc
@@ -0,0 +1,257 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/composite_key_combiner.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Eq;
+using testing::IsEmpty;
+
+TEST(CompositeKeyCombinerTest, EmptyInput_Invalid) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT});
+ StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({}));
+ ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CompositeKeyCombinerTest, InputWithWrongShapeTensor_Invalid) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData<int32_t>({1, 2, 3, 4}))
+ .value();
+ StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1, &t2}));
+ ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CompositeKeyCombinerTest, InputWithTooFewTensors_Invalid) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
+ StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1}));
+ ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CompositeKeyCombinerTest, InputWithTooManyTensors_Invalid) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({4, 5, 6})).value();
+ StatusOr<Tensor> result =
+ combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
+ ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CompositeKeyCombinerTest, InputWithWrongTypes_Invalid) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT, DT_STRING});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
+ StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1, &t2}));
+ ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CompositeKeyCombinerTest, OutputBeforeAccumulate_Empty) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT});
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ ASSERT_OK(output);
+ EXPECT_THAT(output.value(), IsEmpty());
+}
+
+TEST(CompositeKeyCombinerTest, AccumulateAndOutput_SingleElement) {
+ CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1.3})).value();
+ StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1}));
+ ASSERT_OK(result);
+ EXPECT_THAT(result.value(), IsTensor<int64_t>({1}, {0}));
+
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ ASSERT_OK(output);
+ EXPECT_THAT(output.value().size(), Eq(1));
+ EXPECT_THAT(output.value()[0], IsTensor<float>({1}, {1.3}));
+}
+
+TEST(CompositeKeyCombinerTest, AccumulateAndOutput_NumericTypes) {
+ CompositeKeyCombiner combiner(
+ std::vector<DataType>{DT_FLOAT, DT_INT32, DT_INT64});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT64, {3}, CreateTestData<int64_t>({4, 5, 6})).value();
+ StatusOr<Tensor> result =
+ combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
+ ASSERT_OK(result);
+ EXPECT_THAT(result.value(), IsTensor<int64_t>({3}, {0, 1, 2}));
+
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ ASSERT_OK(output);
+ EXPECT_THAT(output.value().size(), Eq(3));
+ EXPECT_THAT(output.value()[0], IsTensor<float>({3}, {1.1, 1.2, 1.3}));
+ EXPECT_THAT(output.value()[1], IsTensor<int32_t>({3}, {1, 2, 3}));
+ EXPECT_THAT(output.value()[2], IsTensor<int64_t>({3}, {4, 5, 6}));
+}
+
+TEST(CompositeKeyCombinerTest,
+ NumericTypes_SameKeysResultInSameOrdinalsAcrossAccumulateCalls) {
+ CompositeKeyCombiner combiner(
+ std::vector<DataType>{DT_FLOAT, DT_INT32, DT_INT64});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({1.1, 1.2, 1.1, 1.2}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData<int32_t>({1, 2, 3, 2}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({4, 5, 6, 5}))
+ .value();
+ StatusOr<Tensor> result1 =
+ combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
+ ASSERT_OK(result1);
+ EXPECT_THAT(result1.value(), IsTensor<int64_t>({4}, {0, 1, 2, 1}));
+
+ // Across different calls to Accumulate, tensors can have different shape.
+ Tensor t4 = Tensor::Create(DT_FLOAT, {5},
+ CreateTestData<float>({1.2, 1.1, 1.1, 1.1, 1.2}))
+ .value();
+ Tensor t5 =
+ Tensor::Create(DT_INT32, {5}, CreateTestData<int32_t>({2, 3, 2, 3, 2}))
+ .value();
+ Tensor t6 =
+ Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({5, 6, 5, 6, 5}))
+ .value();
+ StatusOr<Tensor> result2 =
+ combiner.Accumulate(InputTensorList({&t4, &t5, &t6}));
+ ASSERT_OK(result2);
+ EXPECT_THAT(result2.value(), IsTensor<int64_t>({5}, {1, 2, 3, 2, 1}));
+
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ ASSERT_OK(output);
+ EXPECT_THAT(output.value().size(), Eq(3));
+ EXPECT_THAT(output.value()[0], IsTensor<float>({4}, {1.1, 1.2, 1.1, 1.1}));
+ EXPECT_THAT(output.value()[1], IsTensor<int32_t>({4}, {1, 2, 3, 2}));
+ EXPECT_THAT(output.value()[2], IsTensor<int64_t>({4}, {4, 5, 6, 5}));
+}
+
+TEST(CompositeKeyCombinerTest, AccumulateAndOutput_StringTypes) {
+ CompositeKeyCombiner combiner(
+ std::vector<DataType>{DT_FLOAT, DT_STRING, DT_STRING});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
+ .value();
+ Tensor t2 = Tensor::Create(DT_STRING, {3},
+ CreateTestData<string_view>({"abc", "de", ""}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_STRING, {3},
+ CreateTestData<string_view>({"fghi", "jklmn", "o"}))
+ .value();
+ StatusOr<Tensor> result =
+ combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
+ ASSERT_OK(result);
+ EXPECT_THAT(result.value(), IsTensor<int64_t>({3}, {0, 1, 2}));
+
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ ASSERT_OK(output);
+ EXPECT_THAT(output.value().size(), Eq(3));
+ EXPECT_THAT(output.value()[0], IsTensor<float>({3}, {1.1, 1.2, 1.3}));
+ EXPECT_THAT(output.value()[1], IsTensor<string_view>({3}, {"abc", "de", ""}));
+ EXPECT_THAT(output.value()[2],
+ IsTensor<string_view>({3}, {"fghi", "jklmn", "o"}));
+}
+
+TEST(CompositeKeyCombinerTest,
+ StringTypes_SameCompositeKeysResultInSameOrdinalsAcrossAccumulateCalls) {
+ CompositeKeyCombiner combiner(
+ std::vector<DataType>{DT_FLOAT, DT_STRING, DT_STRING});
+ Tensor t1 =
+ Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({1.1, 1.2, 1.2, 1.3}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_STRING, {4},
+ CreateTestData<string_view>({"abc", "de", "de", ""}))
+ .value();
+ Tensor t3 = Tensor::Create(
+ DT_STRING, {4},
+ CreateTestData<string_view>({"fghi", "jklmn", "jklmn", "o"}))
+ .value();
+ StatusOr<Tensor> result1 =
+ combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
+ ASSERT_OK(result1);
+ EXPECT_THAT(result1.value(), IsTensor<int64_t>({4}, {0, 1, 1, 2}));
+
+ // Across different calls to Accumulate, tensors can have different shape.
+ Tensor t4 = Tensor::Create(DT_FLOAT, {5},
+ CreateTestData<float>({1.3, 1.4, 1.1, 1.2, 1.1}))
+ .value();
+ Tensor t5 = Tensor::Create(
+ DT_STRING, {5},
+ CreateTestData<string_view>({"", "abc", "abc", "de", "abc"}))
+ .value();
+ Tensor t6 =
+ Tensor::Create(
+ DT_STRING, {5},
+ CreateTestData<string_view>({"o", "pqrs", "fghi", "jklmn", "fghi"}))
+ .value();
+ StatusOr<Tensor> result2 =
+ combiner.Accumulate(InputTensorList({&t4, &t5, &t6}));
+ ASSERT_OK(result2);
+ EXPECT_THAT(result2.value(), IsTensor<int64_t>({5}, {2, 3, 0, 1, 0}));
+
+ StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
+ EXPECT_THAT(output.value().size(), Eq(3));
+ EXPECT_THAT(output.value()[0], IsTensor<float>({4}, {1.1, 1.2, 1.3, 1.4}));
+ EXPECT_THAT(output.value()[1],
+ IsTensor<string_view>({4}, {"abc", "de", "", "abc"}));
+ EXPECT_THAT(output.value()[2],
+ IsTensor<string_view>({4}, {"fghi", "jklmn", "o", "pqrs"}));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/datatype.h b/fcp/aggregation/core/datatype.h
new file mode 100644
index 0000000..3c28289
--- /dev/null
+++ b/fcp/aggregation/core/datatype.h
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_DATATYPE_H_
+#define FCP_AGGREGATION_CORE_DATATYPE_H_
+
+#include <cstdint>
+
+#include "fcp/base/monitoring.h"
+
+#ifndef FCP_NANOLIBC
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+#ifndef FCP_NANOLIBC
+// Unless when building with Nanolibc, we can use absl::string_view directly.
+using string_view = absl::string_view;
+#else
+// TODO(team): Minimal implementation of string_view for bare-metal
+// environment.
+struct string_view {};
+#endif
+
+#ifdef FCP_NANOLIBC
+// TODO(team): Derive these values from tensor.proto built with Nanopb
+enum DataType {
+ // The constants below should be kept in sync with tensorflow::Datatype:
+ // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
+ // While not strictly required, that has a number of benefits.
+ DT_INVALID = 0,
+ DT_FLOAT = 1,
+ DT_DOUBLE = 2,
+ DT_INT32 = 3,
+ DT_STRING = 7,
+ DT_INT64 = 9,
+
+ // TODO(team): Add other types.
+ // This should be a small subset of tensorflow::DataType types and include
+ // only simple numeric types and floating point types.
+ //
+ // When a tensor DT_ type is added here, it must also be added to the list of
+ // MATCH_TYPE_AND_DTYPE macros below and to the CASES macro.
+};
+#endif // FCP_NANOLIBC
+
+namespace internal {
+
+// This struct is used to map typename T to DataType and specify other traits
+// of typename T.
+template <typename T>
+struct TypeTraits {
+ constexpr static DataType kDataType = DT_INVALID;
+};
+
+#define MATCH_TYPE_AND_DTYPE(TYPE, DTYPE) \
+ template <> \
+ struct TypeTraits<TYPE> { \
+ constexpr static DataType kDataType = DTYPE; \
+ }
+
+// Mapping of native types to DT_ types.
+// TODO(team): Add other types.
+MATCH_TYPE_AND_DTYPE(float, DT_FLOAT);
+MATCH_TYPE_AND_DTYPE(double, DT_DOUBLE);
+MATCH_TYPE_AND_DTYPE(int32_t, DT_INT32);
+MATCH_TYPE_AND_DTYPE(int64_t, DT_INT64);
+MATCH_TYPE_AND_DTYPE(string_view, DT_STRING);
+
+// The macros DTYPE_CASE and DTYPE_CASES are used to translate Tensor DataType
+// to strongly typed calls of code parameterized with the template typename
+// TYPE_ARG.
+//
+// For example, let's say there is a function that takes an AggVector<T>:
+// template <typename T>
+// void DoSomething(AggVector<T> agg_vector) { ... }
+//
+// Given a Tensor, the following code can be used to make a DoSomething call:
+// DTYPE_CASES(tensor.dtype(), T, DoSomething(tensor.AsAggVector<T>()));
+//
+// The second parameter specifies the type argument to be used as the template
+// parameter in the statement in the third argument.
+
+#define SINGLE_ARG(...) __VA_ARGS__
+#define DTYPE_CASE(TYPE, TYPE_ARG, STMTS) \
+ case internal::TypeTraits<TYPE>::kDataType: { \
+ typedef TYPE TYPE_ARG; \
+ STMTS; \
+ break; \
+ }
+
+// TODO(team): Add other types.
+#define DTYPE_CASES(TYPE_ENUM, TYPE_ARG, STMTS) \
+ switch (TYPE_ENUM) { \
+ DTYPE_CASE(float, TYPE_ARG, SINGLE_ARG(STMTS)) \
+ DTYPE_CASE(double, TYPE_ARG, SINGLE_ARG(STMTS)) \
+ DTYPE_CASE(int32_t, TYPE_ARG, SINGLE_ARG(STMTS)) \
+ DTYPE_CASE(int64_t, TYPE_ARG, SINGLE_ARG(STMTS)) \
+ DTYPE_CASE(string_view, TYPE_ARG, SINGLE_ARG(STMTS)) \
+ case DT_INVALID: \
+ FCP_LOG(FATAL) << "Invalid type"; \
+ break; \
+ default: \
+ FCP_LOG(FATAL) << "Unknown type"; \
+ }
+
+} // namespace internal
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_DATATYPE_H_
diff --git a/fcp/aggregation/core/federated_sum.cc b/fcp/aggregation/core/federated_sum.cc
new file mode 100644
index 0000000..11dee3f
--- /dev/null
+++ b/fcp/aggregation/core/federated_sum.cc
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/aggregation/core/agg_vector_aggregator.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Implementation of a generic sum aggregator.
+template <typename T>
+class FederatedSum final : public AggVectorAggregator<T> {
+ public:
+ using AggVectorAggregator<T>::AggVectorAggregator;
+ using AggVectorAggregator<T>::data;
+
+ private:
+ void AggregateVector(const AggVector<T>& agg_vector) override {
+ for (auto v : agg_vector) {
+ data()[v.index] += v.value;
+ }
+ }
+};
+
+template <typename T>
+StatusOr<std::unique_ptr<TensorAggregator>> CreateFederatedSum(
+ DataType dtype, TensorShape shape) {
+ return std::unique_ptr<TensorAggregator>(new FederatedSum<T>(dtype, shape));
+}
+
+// Not supported for DT_STRING
+template <>
+StatusOr<std::unique_ptr<TensorAggregator>> CreateFederatedSum<string_view>(
+ DataType dtype, TensorShape shape) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "FederatedSum isn't supported for DT_STRING datatype.";
+}
+
+// Factory class for the FederatedSum.
+class FederatedSumFactory final : public TensorAggregatorFactory {
+ public:
+ FederatedSumFactory() = default;
+
+ // FederatedSumFactory isn't copyable or moveable.
+ FederatedSumFactory(const FederatedSumFactory&) = delete;
+ FederatedSumFactory& operator=(const FederatedSumFactory&) = delete;
+
+ StatusOr<std::unique_ptr<TensorAggregator>> Create(
+ DataType dtype, TensorShape shape) const override {
+ StatusOr<std::unique_ptr<TensorAggregator>> aggregator;
+ DTYPE_CASES(dtype, T,
+ aggregator = CreateFederatedSum<T>(dtype, std::move(shape)));
+ return aggregator;
+ }
+};
+
+// TODO(team): Revise the registration mechanism below.
+#ifdef FCP_BAREMETAL
+extern "C" void RegisterFederatedSum() {
+ RegisterAggregatorFactory("federated_sum", new FederatedSumFactory());
+}
+#else
+REGISTER_AGGREGATOR_FACTORY("federated_sum", FederatedSumFactory);
+#endif
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/federated_sum_bench.cc b/fcp/aggregation/core/federated_sum_bench.cc
new file mode 100644
index 0000000..ab02264
--- /dev/null
+++ b/fcp/aggregation/core/federated_sum_bench.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+
+// Open-source version of benchmarking library
+#include "benchmark//benchmark.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+constexpr static int64_t kLength = 1000000;
+
+static void BM_FederatedSumAccumulate(benchmark::State& state) {
+ auto aggregator = (*GetAggregatorFactory("federated_sum"))
+ ->Create(DT_INT64, {kLength})
+ .value();
+ auto test_data = std::make_unique<MutableVectorData<int64_t>>(kLength);
+ std::vector<int64_t>& input = *test_data;
+ for (int64_t i = 0; i < kLength; ++i) {
+ input[i] = i % 123;
+ }
+ auto tensor = Tensor::Create(DT_INT64, {kLength}, std::move(test_data));
+ auto items_processed = 0;
+ for (auto s : state) {
+ benchmark::DoNotOptimize(aggregator->Accumulate(*tensor));
+ items_processed += kLength;
+ }
+ state.SetItemsProcessed(items_processed);
+}
+
+BENCHMARK(BM_FederatedSumAccumulate);
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/federated_sum_test.cc b/fcp/aggregation/core/federated_sum_test.cc
new file mode 100644
index 0000000..21b5358
--- /dev/null
+++ b/fcp/aggregation/core/federated_sum_test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using ::testing::Eq;
+using ::testing::IsTrue;
+
+TEST(FederatedSumTest, ScalarAggregation_Succeeds) {
+ auto aggregator =
+ (*GetAggregatorFactory("federated_sum"))->Create(DT_INT32, {}).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator->Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator->Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator->Accumulate(t3), IsOk());
+ EXPECT_THAT(aggregator->CanReport(), IsTrue());
+
+ auto result = std::move(*aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
+}
+
+TEST(FederatedSumTest, DenseAggregation_Succeeds) {
+ const TensorShape shape = {4};
+ auto aggregator =
+ (*GetAggregatorFactory("federated_sum"))->Create(DT_INT32, shape).value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator->Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator->Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator->Accumulate(t3), IsOk());
+ EXPECT_THAT(aggregator->CanReport(), IsTrue());
+ EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
+
+ auto result = std::move(*aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(AggVectorAggregationTest, Merge_Succeeds) {
+ auto aggregator1 =
+ (*GetAggregatorFactory("federated_sum"))->Create(DT_INT32, {}).value();
+ auto aggregator2 =
+ (*GetAggregatorFactory("federated_sum"))->Create(DT_INT32, {}).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator1->Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator2->Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator2->Accumulate(t3), IsOk());
+
+ EXPECT_THAT(aggregator1->MergeWith(std::move(*aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1->CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1->GetNumInputs(), Eq(3));
+
+ auto result = std::move(*aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/input_tensor_list.cc b/fcp/aggregation/core/input_tensor_list.cc
new file mode 100644
index 0000000..cd0782d
--- /dev/null
+++ b/fcp/aggregation/core/input_tensor_list.cc
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/input_tensor_list.h"
+
+#include <cstddef>
+#include <initializer_list>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/base/new.h"
+
+namespace fcp {
+namespace aggregation {
+
+InputTensorList::InputTensorList(std::initializer_list<const Tensor*> list)
+ : InputTensorList(list.size()) {
+ size_t i = 0;
+ for (const Tensor* t : list) {
+ data_ptr_[i++] = t;
+ }
+}
+
+InputTensorList::InputTensorList(size_t size)
+ : size_(size), is_allocated_(size > kInlinedSize) {
+ if (is_allocated_) {
+ // Since the `allocated` union member is a class with user-defined
+ // constructors and destructors, to switch the active member, explicit
+ // placement new is needed. See
+ // https://en.cppreference.com/w/cpp/language/union.
+ new (&data_storage_.allocated) std::vector<const Tensor*>(size);
+ data_ptr_ = data_storage_.allocated.data();
+ } else {
+ // Use the new syntax to initialize elements to nullptr.
+ new (&data_storage_.inlined) int[size]();
+ data_ptr_ = data_storage_.inlined;
+ }
+}
+
+InputTensorList::InputTensorList(InputTensorList&& other)
+ : size_(other.size_), is_allocated_(other.is_allocated_) {
+ MoveData(std::move(other));
+}
+
+InputTensorList& InputTensorList::operator=(InputTensorList&& other) {
+ // Destroy any existing allocated storage.
+ if (is_allocated_) {
+ data_storage_.allocated.~vector();
+ }
+ size_ = other.size_;
+ is_allocated_ = other.is_allocated_;
+ MoveData(std::move(other));
+ return *this;
+}
+
+void InputTensorList::MoveData(InputTensorList&& other) {
+ if (is_allocated_) {
+ new (&data_storage_.allocated) std::vector<const Tensor*>;
+ data_storage_.allocated = std::move(other.data_storage_.allocated);
+ data_ptr_ = data_storage_.allocated.data();
+ other.data_storage_.allocated.~vector();
+ } else {
+ // If the storage is inlined copy the data; this is cheap since
+ // size_ < kInlinedSize.
+ for (size_t i = 0; i < size_; ++i) {
+ data_storage_.inlined[i] = other.data_storage_.inlined[i];
+ }
+ data_ptr_ = data_storage_.inlined;
+ }
+ new (&other.data_storage_.inlined) int[0]();
+ other.size_ = 0;
+ other.is_allocated_ = false;
+}
+
+InputTensorList::~InputTensorList() {
+ // Since the `allocated` union member is a class with user-defined
+ // constructors and destructors, explicit destruction is needed. See
+ // https://en.cppreference.com/w/cpp/language/union.
+ if (is_allocated_) {
+ data_storage_.allocated.~vector();
+ }
+}
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/input_tensor_list.h b/fcp/aggregation/core/input_tensor_list.h
new file mode 100644
index 0000000..d8c0ee9
--- /dev/null
+++ b/fcp/aggregation/core/input_tensor_list.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
+#define FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <vector>
+
+#include "fcp/aggregation/core/tensor.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Maximum size of InputTensorList for which inlined storage will be used.
+// Any InputTensorList with more elements than kInlinedSize will use allocated
+// storage.
+// TODO(team): Determine optimal size for this constant based on
+// microbenchmarks.
+constexpr int32_t kInlinedSize = 5;
+
+// InputTensorList holds pointers to some number of unowned tensors to be used
+// as input to a function.
+//
+// For efficiency, if there are fewer than kInlinedSize tensors, the memory to
+// hold the pointers is inlined rather than allocated.
+class InputTensorList final {
+ public:
+ typedef const Tensor* const* const_iterator;
+
+ // Creates an InputTensorList with the provided elements.
+ InputTensorList(std::initializer_list<const Tensor*> list);
+
+ // Creates an InputTensorList with a single input tensor.
+ InputTensorList(const Tensor& tensor) : InputTensorList({&tensor}) {}
+
+ // Creates an InputTensorList of a specific size. All elements will initially
+ // be set to nullptr.
+ explicit InputTensorList(size_t size);
+
+ // InputTensorList class isn't copyable.
+ InputTensorList(const InputTensorList&) = delete;
+
+ // Move constructor.
+ InputTensorList(InputTensorList&& other);
+
+ // Move assignment.
+ InputTensorList& operator=(InputTensorList&& other);
+
+ ~InputTensorList();
+
+ inline const_iterator begin() const { return data_ptr_; }
+
+ inline const_iterator end() const { return data_ptr_ + size_; }
+
+ inline size_t size() const { return size_; }
+
+ inline const Tensor* const& operator[](size_t i) const {
+ return data_ptr_[i];
+ }
+
+ inline const Tensor*& operator[](size_t i) { return data_ptr_[i]; }
+
+ private:
+ union DataStorage {
+ constexpr DataStorage() : inlined{} {};
+ ~DataStorage() {}
+ const Tensor* inlined[kInlinedSize];
+ std::vector<const Tensor*> allocated;
+ };
+
+ void MoveData(InputTensorList&& other);
+
+ size_t size_;
+ bool is_allocated_;
+ DataStorage data_storage_;
+ const Tensor** data_ptr_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
diff --git a/fcp/aggregation/core/input_tensor_list_test.cc b/fcp/aggregation/core/input_tensor_list_test.cc
new file mode 100644
index 0000000..710462c
--- /dev/null
+++ b/fcp/aggregation/core/input_tensor_list_test.cc
@@ -0,0 +1,366 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/input_tensor_list.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Eq;
+using testing::Not;
+
+class InputTensorListTest : public testing::Test {
+ protected:
+ InputTensorListTest()
+ : t1_(Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1})).value()),
+ t2_(Tensor::Create(DT_INT32, {2}, CreateTestData<int32_t>({2, 3}))
+ .value()),
+ t3_(Tensor::Create(DT_INT64, {3}, CreateTestData<int64_t>({4, 5, 6}))
+ .value()),
+ t4_(Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({7, 8, 9, 10}))
+ .value()),
+ t5_(Tensor::Create(DT_INT32, {5},
+ CreateTestData<int32_t>({11, 12, 13, 14, 15}))
+ .value()),
+ t6_(Tensor::Create(DT_INT64, {6},
+ CreateTestData<int64_t>({16, 17, 18, 19, 20, 21}))
+ .value()) {}
+
+ InputTensorList CreateInlined() {
+ return InputTensorList({&t1_, &t2_, &t3_});
+ }
+
+ InputTensorList CreateAllocated() {
+ return InputTensorList({&t1_, &t2_, &t3_, &t4_, &t5_, &t6_});
+ }
+
+ Tensor t1_;
+ Tensor t2_;
+ Tensor t3_;
+ Tensor t4_;
+ Tensor t5_;
+ Tensor t6_;
+};
+
+TEST_F(InputTensorListTest, Inlined_Size) {
+ InputTensorList tensor_list = CreateInlined();
+ EXPECT_THAT(tensor_list.size(), Eq(3));
+}
+
+TEST_F(InputTensorListTest, Inlined_Iterate) {
+ InputTensorList tensor_list = CreateInlined();
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Inlined_MoveConstructor_Iterate) {
+ InputTensorList moved_tensor_list = CreateInlined();
+ InputTensorList tensor_list(std::move(moved_tensor_list));
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Inlined_MoveAssignment_Iterate) {
+ InputTensorList moved_tensor_list = CreateInlined();
+ // Initially, create the tensor list as an allocated tensor list before
+ // assigning it to an inlined InputTensorList via move assignment.
+ InputTensorList tensor_list = CreateAllocated();
+ tensor_list = std::move(moved_tensor_list);
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+
+ // Assigning back to the moved variable is valid.
+ moved_tensor_list = std::move(tensor_list);
+ iter = moved_tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT(iter, Eq(moved_tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Inlined_ForEachLoop) {
+ InputTensorList tensor_list = CreateInlined();
+ uint64_t expected_size = 1;
+ for (const Tensor* t : tensor_list) {
+ EXPECT_THAT(t->shape().NumElements(), Eq(expected_size));
+ expected_size++;
+ }
+}
+
+TEST_F(InputTensorListTest, Inlined_Iterate_MultiPassGuarantee) {
+ // Ensure the iterator meets the multi-pass guarantee requirements required
+ // by forward iterators.
+ // (https://en.cppreference.com/w/cpp/iterator/forward_iterator)
+ InputTensorList tensor_list = CreateInlined();
+ auto iterI = tensor_list.begin();
+ auto iterJ = tensor_list.begin();
+ EXPECT_THAT(iterI, Eq(iterJ));
+ EXPECT_THAT(*iterI, Eq(*iterJ));
+ const Tensor* elem = *iterI;
+ iterI++;
+ // iterJ points to the same element as before even though iterI was moved
+ // forward.
+ EXPECT_THAT(elem, Eq(*iterJ));
+ EXPECT_THAT(*iterI, Not(Eq(*iterJ)));
+ // After both iterators are incremented the same number of times they should
+ // again point to the same element.
+ iterI++;
+ iterJ++;
+ iterJ++;
+ EXPECT_THAT(*iterI, Eq(*iterJ));
+}
+
+TEST_F(InputTensorListTest, Inlined_Iterate_PostincrementAndPreincrement) {
+ InputTensorList tensor_list = CreateInlined();
+ auto iterI = tensor_list.begin();
+ // If postincrement works as expected, iterJ will be set to the value of iterI
+ // before it is incremented.
+ auto iterJ = iterI++;
+ EXPECT_THAT(iterJ, Eq(tensor_list.begin()));
+ // If preincrement works as expected, iterK should be set to the value of
+ // iterJ after it is incremented, which is now the same as iterI.
+ auto iterK = ++iterJ;
+ EXPECT_THAT(iterK, Eq(iterJ));
+ EXPECT_THAT(iterK, Eq(iterI));
+}
+
+TEST_F(InputTensorListTest, Inlined_Index) {
+ InputTensorList tensor_list = CreateInlined();
+ EXPECT_THAT(tensor_list[0]->shape(), Eq(TensorShape{1}));
+ EXPECT_THAT(tensor_list[1]->shape(), Eq(TensorShape{2}));
+ EXPECT_THAT(tensor_list[2]->shape(), Eq(TensorShape{3}));
+}
+
+TEST_F(InputTensorListTest, Inlined_SizeConstructorAndMutableIndex) {
+ InputTensorList tensor_list(3);
+ tensor_list[0] = &t1_;
+ tensor_list[1] = &t2_;
+ tensor_list[2] = &t3_;
+
+ EXPECT_THAT(tensor_list[0]->shape(), Eq(TensorShape{1}));
+ EXPECT_THAT(tensor_list[1]->shape(), Eq(TensorShape{2}));
+ EXPECT_THAT(tensor_list[2]->shape(), Eq(TensorShape{3}));
+}
+
+TEST_F(InputTensorListTest, Inlined_SizeConstructor_InitializesPointersToNull) {
+ InputTensorList tensor_list(3);
+
+ EXPECT_THAT(tensor_list[0], Eq(nullptr));
+ EXPECT_THAT(tensor_list[1], Eq(nullptr));
+ EXPECT_THAT(tensor_list[2], Eq(nullptr));
+}
+
+TEST_F(InputTensorListTest, Allocated_Size) {
+ InputTensorList tensor_list = CreateAllocated();
+ EXPECT_THAT(tensor_list.size(), Eq(6));
+}
+
+TEST_F(InputTensorListTest, Allocated_Iterate) {
+ InputTensorList tensor_list = CreateAllocated();
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{4}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{5}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{6}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Allocated_MoveConstructor_Iterate) {
+ InputTensorList moved_tensor_list = CreateAllocated();
+ InputTensorList tensor_list(std::move(moved_tensor_list));
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{4}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{5}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{6}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Allocated_MoveAssignment_Iterate) {
+ InputTensorList moved_tensor_list = CreateAllocated();
+ // Initially, create the tensor list as an inlined tensor list before
+ // assigning it to an inlined InputTensorList via move assignment.
+ InputTensorList tensor_list = CreateInlined();
+ tensor_list = std::move(moved_tensor_list);
+ auto iter = tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{4}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{5}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{6}));
+ iter++;
+ EXPECT_THAT(iter, Eq(tensor_list.end()));
+
+ // Assigning back to the moved variable is valid.
+ moved_tensor_list = std::move(tensor_list);
+ iter = moved_tensor_list.begin();
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{1}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{2}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{3}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{4}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{5}));
+ iter++;
+ EXPECT_THAT((*iter)->shape(), Eq(TensorShape{6}));
+ iter++;
+ EXPECT_THAT(iter, Eq(moved_tensor_list.end()));
+}
+
+TEST_F(InputTensorListTest, Allocated_ForEachLoop) {
+ InputTensorList tensor_list = CreateAllocated();
+ uint64_t expected_size = 1;
+ for (const Tensor* t : tensor_list) {
+ EXPECT_THAT(t->shape().NumElements(), Eq(expected_size));
+ expected_size++;
+ }
+}
+
+TEST_F(InputTensorListTest, Allocated_Iterate_MultiPassGuarantee) {
+ // Ensure the iterator meets the multi-pass guarantee requirements required
+ // by forward iterators
+ // (https://en.cppreference.com/w/cpp/iterator/forward_iterator)
+ InputTensorList tensor_list = CreateAllocated();
+ auto iterI = tensor_list.begin();
+ auto iterJ = tensor_list.begin();
+ EXPECT_THAT(iterI, Eq(iterJ));
+ EXPECT_THAT(*iterI, Eq(*iterJ));
+ const Tensor* elem = *iterI;
+ iterI++;
+ // iterJ points to the same element as before even though iterI was moved
+ // forward.
+ EXPECT_THAT(elem, Eq(*iterJ));
+ EXPECT_THAT(*iterI, Not(Eq(*iterJ)));
+ // After both iterators are incremented the same number of times they should
+ // again point to the same element.
+ iterI++;
+ iterJ++;
+ iterJ++;
+ EXPECT_THAT(*iterI, Eq(*iterJ));
+}
+
+TEST_F(InputTensorListTest, Allocated_Iterate_PostincrementAndPreincrement) {
+ InputTensorList tensor_list = CreateAllocated();
+ auto iterI = tensor_list.begin();
+ // If postincrement works as expected, iterJ will be set to the value of iterI
+ // before it is incremented.
+ auto iterJ = iterI++;
+ EXPECT_THAT(iterJ, Eq(tensor_list.begin()));
+ // If preincrement works as expected, iterK should be set to the value of
+ // iterJ after it is incremented, which is now the same as iterI.
+ auto iterK = ++iterJ;
+ EXPECT_THAT(iterK, Eq(iterJ));
+ EXPECT_THAT(iterK, Eq(iterI));
+}
+
+TEST_F(InputTensorListTest, Allocated_Index) {
+ InputTensorList tensor_list = CreateAllocated();
+ EXPECT_THAT(tensor_list[0]->shape(), Eq(TensorShape{1}));
+ EXPECT_THAT(tensor_list[1]->shape(), Eq(TensorShape{2}));
+ EXPECT_THAT(tensor_list[2]->shape(), Eq(TensorShape{3}));
+ EXPECT_THAT(tensor_list[3]->shape(), Eq(TensorShape{4}));
+ EXPECT_THAT(tensor_list[4]->shape(), Eq(TensorShape{5}));
+ EXPECT_THAT(tensor_list[5]->shape(), Eq(TensorShape{6}));
+}
+
+TEST_F(InputTensorListTest, Allocated_SizeConstructorAndMutableIndex) {
+ InputTensorList tensor_list(6);
+ tensor_list[0] = &t1_;
+ tensor_list[1] = &t2_;
+ tensor_list[2] = &t3_;
+ tensor_list[3] = &t4_;
+ tensor_list[4] = &t5_;
+ tensor_list[5] = &t6_;
+
+ EXPECT_THAT(tensor_list[0]->shape(), Eq(TensorShape{1}));
+ EXPECT_THAT(tensor_list[1]->shape(), Eq(TensorShape{2}));
+ EXPECT_THAT(tensor_list[2]->shape(), Eq(TensorShape{3}));
+ EXPECT_THAT(tensor_list[3]->shape(), Eq(TensorShape{4}));
+ EXPECT_THAT(tensor_list[4]->shape(), Eq(TensorShape{5}));
+ EXPECT_THAT(tensor_list[5]->shape(), Eq(TensorShape{6}));
+}
+
+TEST_F(InputTensorListTest,
+ Allocated_SizeConstructor_InitializesPointersToNull) {
+ InputTensorList tensor_list(6);
+
+ EXPECT_THAT(tensor_list[0], Eq(nullptr));
+ EXPECT_THAT(tensor_list[1], Eq(nullptr));
+ EXPECT_THAT(tensor_list[2], Eq(nullptr));
+ EXPECT_THAT(tensor_list[3], Eq(nullptr));
+ EXPECT_THAT(tensor_list[4], Eq(nullptr));
+ EXPECT_THAT(tensor_list[5], Eq(nullptr));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/mutable_vector_data.h b/fcp/aggregation/core/mutable_vector_data.h
new file mode 100644
index 0000000..c2f43a9
--- /dev/null
+++ b/fcp/aggregation/core/mutable_vector_data.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_MUTABLE_VECTOR_DATA_H_
+#define FCP_AGGREGATION_CORE_MUTABLE_VECTOR_DATA_H_
+
+#include <cstddef>
+#include <vector>
+
+#include "fcp/aggregation/core/tensor_data.h"
+
+namespace fcp {
+namespace aggregation {
+
+// MutableVectorData implements TensorData by wrapping std::vector and using it
+// as a backing storage. MutableVectorData can be mutated using std::vector
+// methods.
+template <typename T>
+class MutableVectorData : public std::vector<T>, public TensorData {
+ public:
+ // Derive constructors from the base vector class.
+ using std::vector<T>::vector;
+
+ ~MutableVectorData() override = default;
+
+ // Implementation of the base class methods.
+ size_t byte_size() const override { return this->size() * sizeof(T); }
+ const void* data() const override { return this->std::vector<T>::data(); }
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_MUTABLE_VECTOR_DATA_H_
diff --git a/fcp/aggregation/core/mutable_vector_data_test.cc b/fcp/aggregation/core/mutable_vector_data_test.cc
new file mode 100644
index 0000000..4b62074
--- /dev/null
+++ b/fcp/aggregation/core/mutable_vector_data_test.cc
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/mutable_vector_data.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+TEST(MutableVectorDataTest, MutableVectorDataValid) {
+ MutableVectorData<int64_t> vector_data;
+ vector_data.push_back(1);
+ vector_data.push_back(2);
+ vector_data.push_back(3);
+ EXPECT_THAT(vector_data.CheckValid(sizeof(int64_t)), IsOk());
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator.h b/fcp/aggregation/core/one_dim_grouping_aggregator.h
new file mode 100644
index 0000000..6241fbc
--- /dev/null
+++ b/fcp/aggregation/core/one_dim_grouping_aggregator.h
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
+#define FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_data.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// GroupingAggregator class is a specialization of TensorAggregator which
+// takes in a tensor containing ordinals and a tensor containing values, and
+// accumulates the values into the output positions indicated by the
+// corresponding ordinals.
+//
+// Currently only 1D input tensors are supported.
+//
+// The specific means of accumulating values and producing default values are
+// left to the subclass.
+//
+// The implementation operates on AggVector<T> instances rather than tensors.
+template <typename T>
+class OneDimGroupingAggregator : public TensorAggregator {
+ public:
+ // TODO(team): Support accumulating tensors of multiple dimensions. In
+ // that case, the size of all dimensions but one (the dimension corresponding
+ // to the ordinal tensor) should be known in advance and thus this constructor
+ // should take in a shape with a single unknown dimension.
+ explicit OneDimGroupingAggregator(DataType dtype)
+ : data_vector_(std::make_unique<MutableVectorData<T>>()), num_inputs_(0) {
+ FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
+ << "Incompatible dtype";
+ }
+
+ Status MergeWith(TensorAggregator&& other) override {
+ FCP_RETURN_IF_ERROR(CheckValid());
+ OneDimGroupingAggregator<T>* other_ptr =
+ dynamic_cast<OneDimGroupingAggregator<T>*>(&other);
+ if (other_ptr == nullptr) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::MergeOutputTensors: Can only merge with "
+ "another GroupingAggregator operating on the same dtype "
+ << internal::TypeTraits<T>::kDataType;
+ }
+ FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
+ int other_num_inputs = other.GetNumInputs();
+ OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs();
+
+ if (output_tensors.size() == 1) {
+ AggVector<T> other_data_vector = output_tensors[0].AsAggVector<T>();
+ if (other_data_vector.size() > data_vector_->size()) {
+ data_vector_->resize(other_data_vector.size(), GetDefaultValue());
+ }
+ AggregateVector(other_data_vector);
+ } else {
+ // An empty output is valid and merging it into the current
+ // GroupingAggregator is a no-op.
+ FCP_CHECK(output_tensors.empty())
+ << "GroupingAggregator::MergeOutputTensors: GroupingAggregator "
+ "should produce at most a single output tensor.";
+ }
+
+ num_inputs_ += other_num_inputs;
+ return FCP_STATUS(OK);
+ }
+
+ int GetNumInputs() const override { return num_inputs_; }
+
+ protected:
+ // Provides mutable access to the aggregator data as a vector<T>
+ inline std::vector<T>& data() { return *data_vector_; }
+
+ // Implementation of the tensor aggregation.
+ // Expects 2 tensors as input: a tensor containing ordinals and a tensor
+ // containing values.
+ //
+ // Accumulates the values into the positions in the output tensor which are
+ // indicated by the corresponding ordinals.
+ Status AggregateTensors(InputTensorList tensors) override {
+ FCP_CHECK(tensors.size() == 2)
+ << "GroupingAggregator should operate on 2 input tensors";
+
+ const Tensor* ordinals = tensors[0];
+ if (ordinals->dtype() != DT_INT64) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::AggregateTensors: dtype mismatch for "
+ "tensor 0. Expected DT_INT64.";
+ }
+ const Tensor* tensor = tensors[1];
+ if (tensor->dtype() != internal::TypeTraits<T>::kDataType) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::AggregateTensors: dtype mismatch for "
+ "tensor 1";
+ }
+ if (ordinals->shape() != tensor->shape()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::AggregateTensors: tensor shape mismatch. "
+ "Shape of both tensors must be the same.";
+ }
+ int num_dimensions = tensor->shape().dim_sizes().size();
+ if (num_dimensions > 1) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::AggregateTensors: Only 1 dimensional "
+ "tensors supported. Input tensor has "
+ << num_dimensions << " dimensions.";
+ }
+ if (!ordinals->is_dense() || !tensor->is_dense()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "GroupingAggregator::AggregateTensors: Only dense tensors are "
+ "supported.";
+ }
+ num_inputs_++;
+ AggVector<T> value_vector = tensor->AsAggVector<T>();
+ AggVector<int64_t> ordinals_vector = ordinals->AsAggVector<int64_t>();
+ size_t final_size = data_vector_->size();
+ for (auto o : ordinals_vector) {
+ if (o.value >= final_size) {
+ final_size = o.value + 1;
+ }
+ }
+ // Resize once outside the loop to avoid quadratic behavior.
+ data_vector_->resize(final_size, GetDefaultValue());
+ AggregateVectorByOrdinals(ordinals_vector, value_vector);
+ return FCP_STATUS(OK);
+ }
+
+ Status CheckValid() const override {
+ if (data_vector_ == nullptr) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "GroupingAggregator::CheckValid: Output has already been "
+ "consumed.";
+ }
+ return FCP_STATUS(OK);
+ }
+
+ OutputTensorList TakeOutputs() && override {
+ OutputTensorList outputs = std::vector<Tensor>();
+ if (!data_vector_->empty()) {
+ outputs.push_back(Tensor::Create(internal::TypeTraits<T>::kDataType,
+ TensorShape{data_vector_->size()},
+ std::move(data_vector_))
+ .value());
+ }
+ data_vector_ = nullptr;
+ return outputs;
+ }
+
+ // Delegates AggVector aggregation by ordinal to a derived class.
+ //
+ // The size of the vector returned by data() must be greater than the largest
+ // ordinal in this vector.
+ //
+ // To avoid making a virtual function call per value in the tensor, the whole
+ // vector is passed to the subclass for aggregation, which provides better
+ // performance but comes at the cost of duplicated code between subclasses for
+ // iterating over the vectors.
+ virtual void AggregateVectorByOrdinals(
+ const AggVector<int64_t>& ordinals_vector,
+ const AggVector<T>& value_vector) = 0;
+
+ // Delegates AggVector aggregation to a derived class.
+ //
+ // This vector must be the same size as the vector returned by data().
+ //
+ // To avoid making a virtual function call per value in the tensor, the whole
+ // vector is passed to the subclass for aggregation, which provides better
+ // performance but comes at the cost of duplicated code between subclasses for
+ // iterating over the vectors.
+ virtual void AggregateVector(const AggVector<T>& agg_vector) = 0;
+
+ // Delegates initialization of previously unseen ordinals to a derived class.
+ virtual T GetDefaultValue() = 0;
+
+ private:
+ std::unique_ptr<MutableVectorData<T>> data_vector_;
+ int num_inputs_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc
new file mode 100644
index 0000000..d98ce7e
--- /dev/null
+++ b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc
@@ -0,0 +1,582 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/one_dim_grouping_aggregator.h"
+
+#include <climits>
+#include <cstdint>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Eq;
+using testing::IsFalse;
+using testing::IsTrue;
+
+// A simple Sum Aggregator
+template <typename T>
+class SumGroupingAggregator final : public OneDimGroupingAggregator<T> {
+ public:
+ using OneDimGroupingAggregator<T>::OneDimGroupingAggregator;
+ using OneDimGroupingAggregator<T>::data;
+
+ private:
+ void AggregateVectorByOrdinals(const AggVector<int64_t>& ordinals_vector,
+ const AggVector<T>& value_vector) override {
+ auto value_it = value_vector.begin();
+ for (auto o : ordinals_vector) {
+ int64_t output_index = o.value;
+ // If this function returned a failed Status at this point, the
+ // data_vector_ may have already been partially modified, leaving the
+ // GroupingAggregator in a bad state. Thus, check that the indices of the
+ // ordinals tensor and the data tensor match with FCP_CHECK instead.
+ //
+ // TODO(team): Revisit the constraint that the indices of the
+ // values must match the indices of the ordinals when sparse tensors are
+ // implemented. It may be possible for the value to be omitted for a given
+ // ordinal in which case the default value should be used.
+ FCP_CHECK(value_it.index() == o.index)
+ << "Indices in AggVector of ordinals and AggVector of values "
+ "are mismatched.";
+ // Delegate the actual aggregation to the specific aggregation
+ // intrinsic implementation.
+ AggregateValue(output_index, value_it++.value());
+ }
+ }
+
+ void AggregateVector(const AggVector<T>& value_vector) override {
+ for (auto it : value_vector) {
+ AggregateValue(it.index, it.value);
+ }
+ }
+
+ inline void AggregateValue(int64_t i, T value) { data()[i] += value; }
+
+ T GetDefaultValue() override { return static_cast<T>(0); }
+};
+
+// A simple Min Aggregator that works for int32_t
+class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> {
+ public:
+ using OneDimGroupingAggregator<int32_t>::OneDimGroupingAggregator;
+ using OneDimGroupingAggregator<int32_t>::data;
+
+ private:
+ void AggregateVectorByOrdinals(
+ const AggVector<int64_t>& ordinals_vector,
+ const AggVector<int32_t>& value_vector) override {
+ auto value_it = value_vector.begin();
+ for (auto o : ordinals_vector) {
+ int64_t output_index = o.value;
+ // If this function returned a failed Status at this point, the
+ // data_vector_ may have already been partially modified, leaving the
+ // GroupingAggregator in a bad state. Thus, check that the indices of the
+ // ordinals tensor and the data tensor match with FCP_CHECK instead.
+ //
+ // TODO(team): Revisit the constraint that the indices of the
+ // values must match the indices of the ordinals when sparse tensors are
+ // implemented. It may be possible for the value to be omitted for a given
+ // ordinal in which case the default value should be used.
+ FCP_CHECK(value_it.index() == o.index)
+ << "Indices in AggVector of ordinals and AggVector of values "
+ "are mismatched.";
+ // Delegate the actual aggregation to the specific aggregation
+ // intrinsic implementation.
+ AggregateValue(output_index, value_it++.value());
+ }
+ }
+
+ void AggregateVector(const AggVector<int32_t>& value_vector) override {
+ for (auto it : value_vector) {
+ AggregateValue(it.index, it.value);
+ }
+ }
+
+ inline void AggregateValue(int64_t i, int32_t value) {
+ if (value < data()[i]) {
+ data()[i] = value;
+ }
+ }
+ int32_t GetDefaultValue() override { return INT_MAX; }
+};
+
+TEST(GroupingAggregatorTest, EmptyReport) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result->size(), Eq(0));
+}
+
+TEST(GroupingAggregatorTest, ScalarAggregation_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t1}), IsOk());
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t2}), IsOk());
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t3}), IsOk());
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
+}
+
+TEST(GroupingAggregatorTest, DenseAggregation_Succeeds) {
+ const TensorShape shape = {4};
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinals =
+ Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinals, &t1}), IsOk());
+ EXPECT_THAT(aggregator.Accumulate({&ordinals, &t2}), IsOk());
+ EXPECT_THAT(aggregator.Accumulate({&ordinals, &t3}), IsOk());
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result->size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
+ const TensorShape shape = {4};
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({3, 3, 2, 0}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // Totals: [27, 0, 15, 4]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({1, 0, 1, 4}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // Totals: [32, 11, 15, 4, 2]
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({2, 2, 5, 1}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // Totals: [32, 31, 29, 4, 2, 7]
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({6}, {32, 31, 29, 4, 2, 7}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({17, 3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // Totals: [3, 0, 17]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({1, 0, 1, 4, 3, 0}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, 5}))
+ .value();
+ EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // Totals: [13, 23, 17, 4, 2]
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, {5}, CreateTestData({3, 11, 7, 6, 3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // Totals: [13, 30, 31, 4, 2]
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {19, 30, 31, 4, 5}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest,
+ DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) {
+ // Use a MinGroupingAggregator which has a non-zero default value so we can
+ // test that when the output grows, elements are set to the default value.
+ MinGroupingAggregator aggregator(DT_INT32);
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({17, 3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // Totals: [3, INT_MAX, 17]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({0, 0, 0, 4, 4, 0}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50}))
+ .value();
+ EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // Totals: [-50, INT_MAX, 17, INT_MAX, 2]
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, {5}, CreateTestData({33, 11, 7, 6, 3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // Totals: [-50, 7, 11, INT_MAX, 2]
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {-50, 7, 11, INT_MAX, 2}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, Merge_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&ordinal, &t1}), IsOk());
+ EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t2}), IsOk());
+ EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t3}), IsOk());
+
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
+}
+
+TEST(GroupingAggregatorTest, Merge_BothEmpty_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+
+ // Merge the two empty aggregators together.
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(0));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result->size(), Eq(0));
+}
+
+TEST(GroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
+ EXPECT_THAT(aggregator2.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // aggregator2 totals: [27, 0, 15, 4]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
+ EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // aggregator2 totals: [32, 11, 15, 4, 2]
+
+ // Merge aggregator2 into aggregator1 which has not received any inputs.
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(2));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 15, 4, 2}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // aggregator1 totals: [27, 0, 15, 4]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // aggregator1 totals: [32, 11, 15, 4, 2]
+
+ // Merge with aggregator2 which has not received any inputs.
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(2));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 15, 4, 2}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // aggregator1 totals: [27, 0, 15, 4]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // aggregator1 totals: [32, 11, 15, 4, 2]
+
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {2}, CreateTestData({3, 11})).value();
+ EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // aggregator2 totals: [0, 0, 14]
+
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 29, 4, 2}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
+ .value();
+ Tensor t1 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // aggregator1 totals: [27, 0, 15, 4]
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // aggregator1 totals: [32, 11, 15, 4, 2]
+
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({2, 2, 5, 1}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, {4}, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // aggregator2 totals: [0, 20, 14, 0, 0, 7]
+
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({6}, {32, 31, 29, 4, 2, 7}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest,
+ Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) {
+ // Use a MinGroupingAggregator which has a non-zero default value so we can
+ // test that when the output grows, elements are set to the default value.
+ MinGroupingAggregator aggregator1(DT_INT32);
+ MinGroupingAggregator aggregator2(DT_INT32);
+ Tensor t1_ordinals =
+ Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({-17, 3})).value();
+ EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
+ // aggregator1 totals: [3, INT_MAX, -17]
+
+ Tensor t2_ordinals =
+ Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({0, 0, 0, 4, 4, 0}))
+ .value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50}))
+ .value();
+ EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk());
+ // aggregator2 totals: [-50, INT_MAX, INT_MAX, INT_MAX, 2]
+ Tensor t3_ordinals =
+ Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
+ .value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, {5}, CreateTestData({33, 11, 7, 6, 3})).value();
+ EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
+ // aggregator2 totals: [-50, 7, 11, INT_MAX, 2]
+
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
+ EXPECT_THAT(aggregator1.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));
+
+ auto result = std::move(aggregator1).Report();
+ EXPECT_THAT(result, IsOk());
+ EXPECT_THAT(result.value().size(), Eq(1));
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value()[0], IsTensor({5}, {-50, 7, -17, INT_MAX, 2}));
+ // Also ensure that the resulting tensor is dense.
+ EXPECT_TRUE(result.value()[0].is_dense());
+}
+
+TEST(GroupingAggregatorTest, Aggregate_OrdinalTensorHasIncompatibleDataType) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT32, {}, CreateTestData<int32_t>({0})).value();
+ Tensor t = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(GroupingAggregatorTest, Aggregate_IncompatibleDataType) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
+ Tensor t = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(GroupingAggregatorTest,
+ Aggregate_OrdinalAndValueTensorsHaveIncompatibleShapes) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
+ Tensor t = Tensor::Create(DT_INT32, {2}, CreateTestData({0, 1})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(GroupingAggregatorTest, Aggregate_MultidimensionalTensorsNotSupported) {
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {2, 2}, CreateTestData<int64_t>({0, 0, 0, 0}))
+ .value();
+ Tensor t =
+ Tensor::Create(DT_INT32, {2, 2}, CreateTestData({0, 1, 2, 3})).value();
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(GroupingAggregatorTest, Merge_IncompatibleDataType) {
+ SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
+ SumGroupingAggregator<float> aggregator2(DT_FLOAT);
+ EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(GroupingAggregatorTest, FailsAfterBeingConsumed) {
+ Tensor ordinal =
+ Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
+ Tensor t = Tensor::Create(DT_INT32, {}, CreateTestData({0})).value();
+ SumGroupingAggregator<int32_t> aggregator(DT_INT32);
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsOk());
+ EXPECT_THAT(std::move(aggregator).Report(), IsOk());
+
+ // Now the aggregator instance has been consumed and should fail any
+ // further operations.
+ EXPECT_THAT(aggregator.CanReport(), IsFalse()); // NOLINT
+ EXPECT_THAT(std::move(aggregator).Report(),
+ IsCode(FAILED_PRECONDITION)); // NOLINT
+ EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), // NOLINT
+ IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(
+ aggregator.MergeWith(SumGroupingAggregator<int32_t>(DT_INT32)), // NOLINT
+ IsCode(FAILED_PRECONDITION));
+
+ // Passing this aggregator as an argument to another MergeWith must fail too.
+ SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
+ EXPECT_THAT(aggregator2.MergeWith(std::move(aggregator)), // NOLINT
+ IsCode(FAILED_PRECONDITION));
+}
+
+TEST(GroupingAggregatorTest, TypeCheckFailure) {
+ EXPECT_DEATH(new SumGroupingAggregator<float>(DT_INT32),
+ "Incompatible dtype");
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor.cc b/fcp/aggregation/core/tensor.cc
new file mode 100644
index 0000000..76f7bc6
--- /dev/null
+++ b/fcp/aggregation/core/tensor.cc
@@ -0,0 +1,257 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor.h"
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+#ifndef FCP_NANOLIBC
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+Status Tensor::CheckValid() const {
+ if (dtype_ == DT_INVALID) {
+ return FCP_STATUS(FAILED_PRECONDITION) << "Invalid Tensor dtype.";
+ }
+
+ size_t value_size = 0;
+ DTYPE_CASES(dtype_, T, value_size = sizeof(T));
+
+ // Verify that the storage is consistent with the value size in terms of
+ // size and alignment.
+ FCP_RETURN_IF_ERROR(data_->CheckValid(value_size));
+
+ // Verify that the total size of the data is consistent with the value type
+ // and the shape.
+ // TODO(team): Implement sparse tensors.
+ if (data_->byte_size() != shape_.NumElements() * value_size) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "TensorData byte_size is inconsistent with the Tensor dtype and "
+ "shape.";
+ }
+
+ return FCP_STATUS(OK);
+}
+
+StatusOr<Tensor> Tensor::Create(DataType dtype, TensorShape shape,
+ std::unique_ptr<TensorData> data) {
+ Tensor tensor(dtype, std::move(shape), std::move(data));
+ FCP_RETURN_IF_ERROR(tensor.CheckValid());
+ return std::move(tensor);
+}
+
+#ifndef FCP_NANOLIBC
+
+// SerializedContentNumericData implements TensorData by wrapping the serialized
+// content string and using it directly as a backing storage. This relies on the
+// fact that the serialized content uses the same layout as in memory
+// representation if we assume that this code runs on a little-endian system.
+// TODO(team): Ensure little-endianness.
+class SerializedContentNumericData : public TensorData {
+ public:
+ explicit SerializedContentNumericData(std::string content)
+ : content_(std::move(content)) {}
+ ~SerializedContentNumericData() override = default;
+
+ // Implementation of TensorData methods.
+ size_t byte_size() const override { return content_.size(); }
+ const void* data() const override { return content_.data(); }
+
+ private:
+ std::string content_;
+};
+
+// Converts the tensor data to a serialized blob saved as the content field
+// in the TensorProto. The `num` argument is needed in case the number of
+// values can't be derived from the TensorData size.
+template <typename T>
+std::string EncodeContent(const TensorData* data, size_t num) {
+ // Default encoding of tensor data, valid only for numeric data types.
+ return std::string(reinterpret_cast<const char*>(data->data()),
+ data->byte_size());
+}
+
+// Specialization of EncodeContent for DT_STRING data type.
+template <>
+std::string EncodeContent<string_view>(const TensorData* data, size_t num) {
+ std::string content;
+ google::protobuf::io::StringOutputStream out(&content);
+ google::protobuf::io::CodedOutputStream coded_out(&out);
+ auto ptr = reinterpret_cast<const string_view*>(data->data());
+
+ // Write all string sizes as Varint64.
+ for (size_t i = 0; i < num; ++i) {
+ coded_out.WriteVarint64(ptr[i].size());
+ }
+
+ // Write all string contents.
+ for (size_t i = 0; i < num; ++i) {
+ coded_out.WriteRaw(ptr[i].data(), static_cast<int>(ptr[i].size()));
+ }
+
+ return content;
+}
+
+// Converts the serialized TensorData content stored in TensorProto to an
+// instance of TensorData. The `num` argument is needed in case the number of
+// values can't be derived from the content size.
+template <typename T>
+StatusOr<std::unique_ptr<TensorData>> DecodeContent(std::string content,
+ size_t num) {
+ // Default decoding of tensor data, valid only for numeric data types.
+ return std::make_unique<SerializedContentNumericData>(std::move(content));
+}
+
+// Wraps the serialized TensorData content stored and surfaces it as pointer
+// string_view values pointing back into the wrapped content. This class is
+// be created and initialized from within the DecodeContent<string_view>().
+class SerializedContentStringData : public TensorData {
+ public:
+ SerializedContentStringData() = default;
+ ~SerializedContentStringData() override = default;
+
+ // Implementation of TensorData methods.
+ size_t byte_size() const override {
+ return string_views_.size() * sizeof(string_view);
+ }
+ const void* data() const override { return string_views_.data(); }
+
+ // Initializes the string_view values to point to the strings embedded in the
+ // content.
+ Status Initialize(std::string content, size_t num) {
+ content_ = std::move(content);
+ google::protobuf::io::ArrayInputStream input(content_.data(),
+ static_cast<int>(content_.size()));
+ google::protobuf::io::CodedInputStream coded_input(&input);
+
+ // The pointer to the first string in the content is unknown at this point
+ // because there are multiple string sizes at the front, all encoded as
+ // VarInts. To avoid using the extra storage this code reuses the same
+ // string_views_ vector in the two passes. First it initializes the data
+ // pointers to start with the beginning of the content. Then in the second
+ // pass it shifts all data pointers to where strings actually begin in the
+ // content.
+ string_views_.resize(num);
+ size_t cumulative_size = 0;
+
+ // The first pass reads the string sizes;
+ for (size_t i = 0; i < num; ++i) {
+ size_t size;
+ if (!coded_input.ReadVarint64(&size)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Expected to read " << num
+ << " string values but the input tensor content doesn't contain "
+ "a size for the "
+ << i << "th string. The content size is " << content_.size()
+ << " bytes.";
+ }
+ string_views_[i] = string_view(content_.data() + cumulative_size, size);
+ cumulative_size += size;
+ }
+
+ // The current position in the input stream after reading all the string
+ // sizes. The input stream must be at the beginning of the first string now.
+ size_t offset = coded_input.CurrentPosition();
+
+ // Verify that the content is large enough.
+ if (content_.size() < offset + cumulative_size) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Input tensor content has insufficient size to store " << num
+ << " string values. The content size is " << content_.size()
+ << " bytes, but " << offset + cumulative_size
+ << " bytes are required.";
+ }
+
+ // The second pass offsets string_view pointers so that the first one points
+ // to the first string embedded in the content, then all others are shifted
+ // by the same offset to point to subsequent strings.
+ for (size_t i = 0; i < num; ++i) {
+ string_views_[i] = string_view(string_views_[i].data() + offset,
+ string_views_[i].size());
+ }
+
+ return FCP_STATUS(OK);
+ }
+
+ private:
+ std::string content_;
+ std::vector<string_view> string_views_;
+};
+
+template <>
+StatusOr<std::unique_ptr<TensorData>> DecodeContent<string_view>(
+ std::string content, size_t num) {
+ auto tensor_data = std::make_unique<SerializedContentStringData>();
+ FCP_RETURN_IF_ERROR(tensor_data->Initialize(std::move(content), num));
+ return tensor_data;
+}
+
+StatusOr<Tensor> Tensor::FromProto(const TensorProto& tensor_proto) {
+ FCP_ASSIGN_OR_RETURN(TensorShape shape,
+ TensorShape::FromProto(tensor_proto.shape()));
+ // TODO(team): The num_values is valid only for dense tensors.
+ size_t num_values = shape.NumElements();
+ StatusOr<std::unique_ptr<TensorData>> data;
+ DTYPE_CASES(tensor_proto.dtype(), T,
+ data = DecodeContent<T>(tensor_proto.content(), num_values));
+ FCP_RETURN_IF_ERROR(data);
+ return Create(tensor_proto.dtype(), std::move(shape),
+ std::move(data).value());
+}
+
+StatusOr<Tensor> Tensor::FromProto(TensorProto&& tensor_proto) {
+ FCP_ASSIGN_OR_RETURN(TensorShape shape,
+ TensorShape::FromProto(tensor_proto.shape()));
+ // TODO(team): The num_values is valid only for dense tensors.
+ size_t num_values = shape.NumElements();
+ std::string content = std::move(*tensor_proto.mutable_content());
+ StatusOr<std::unique_ptr<TensorData>> data;
+ DTYPE_CASES(tensor_proto.dtype(), T,
+ data = DecodeContent<T>(std::move(content), num_values));
+ FCP_RETURN_IF_ERROR(data);
+ return Create(tensor_proto.dtype(), std::move(shape),
+ std::move(data).value());
+}
+
+TensorProto Tensor::ToProto() const {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(dtype_);
+ *(tensor_proto.mutable_shape()) = shape_.ToProto();
+ // TODO(team): The num_values is valid only for dense tensors.
+ size_t num_values = shape_.NumElements();
+ std::string content;
+ DTYPE_CASES(dtype_, T, content = EncodeContent<T>(data_.get(), num_values));
+ *(tensor_proto.mutable_content()) = std::move(content);
+ return tensor_proto;
+}
+
+#endif // FCP_NANOLIBC
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor.h b/fcp/aggregation/core/tensor.h
new file mode 100644
index 0000000..6e4ff2d
--- /dev/null
+++ b/fcp/aggregation/core/tensor.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_H_
+#define FCP_AGGREGATION_CORE_TENSOR_H_
+
+#include <memory>
+#include <utility>
+
+#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_data.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+#ifndef FCP_NANOLIBC
+#include "fcp/aggregation/core/tensor.pb.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+// Tensor class is a container that packages the tensor data with the tensor
+// metadata such as the value type and the shape.
+//
+// For the most part, the aggregation code won't be consuming tensors directly.
+// Instead the aggregation code will be working with AggVector instances that
+// represent the tensor data in a flattened way.
+class Tensor final {
+ public:
+ // Tensor class isn't copyable.
+ Tensor(const Tensor&) = delete;
+
+ // Move constructor.
+ Tensor(Tensor&& other)
+ : dtype_(other.dtype_),
+ shape_(std::move(other.shape_)),
+ data_(std::move(other.data_)) {
+ other.dtype_ = DT_INVALID;
+ }
+
+ // Move assignment.
+ Tensor& operator=(Tensor&& other) {
+ dtype_ = other.dtype_;
+ shape_ = std::move(other.shape_);
+ data_ = std::move(other.data_);
+ other.dtype_ = DT_INVALID;
+ return *this;
+ }
+
+ // Define a default constructor to allow for initalization of array
+ // to enable creation of a vector of Tensors.
+ // A tensor created with the default constructor is not valid and thus should
+ // not actually be used.
+ Tensor() : dtype_(DT_INVALID), shape_{}, data_(nullptr) {}
+
+ // Validates parameters and creates a Tensor instance.
+ static StatusOr<Tensor> Create(DataType dtype, TensorShape shape,
+ std::unique_ptr<TensorData> data);
+
+#ifndef FCP_NANOLIBC
+ // Creates a Tensor instance from a TensorProto.
+ static StatusOr<Tensor> FromProto(const TensorProto& tensor_proto);
+
+ // Creates a Tensor instance from a TensorProto, consuming the proto.
+ static StatusOr<Tensor> FromProto(TensorProto&& tensor_proto);
+
+ // Converts Tensor to TensorProto
+ TensorProto ToProto() const;
+#endif // FCP_NANOLIBC
+
+ // Validates the tensor.
+ Status CheckValid() const;
+
+ // Gets the tensor value type.
+ DataType dtype() const { return dtype_; }
+
+ // Gets the tensor shape.
+ const TensorShape& shape() const { return shape_; }
+
+ // Readonly access to the tensor data.
+ const TensorData& data() const { return *data_; }
+
+ // Returns true is the current tensor data is dense.
+ // TODO(team): Implement sparse tensors.
+ bool is_dense() const { return true; }
+
+ // Provides access to the tensor data via a strongly typed AggVector.
+ template <typename T>
+ AggVector<T> AsAggVector() const {
+ FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype_)
+ << "Incompatible tensor dtype()";
+ return AggVector<T>(data_.get());
+ }
+
+ // TODO(team): Add serialization functions.
+
+ private:
+ Tensor(DataType dtype, TensorShape shape, std::unique_ptr<TensorData> data)
+ : dtype_(dtype), shape_(std::move(shape)), data_(std::move(data)) {}
+
+ // Tensor data type.
+ DataType dtype_;
+ // Tensor shape.
+ TensorShape shape_;
+ // The underlying tensor data.
+ std::unique_ptr<TensorData> data_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_H_
diff --git a/fcp/aggregation/core/tensor.proto b/fcp/aggregation/core/tensor.proto
new file mode 100644
index 0000000..7e5b93b
--- /dev/null
+++ b/fcp/aggregation/core/tensor.proto
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package fcp.aggregation;
+
+// Data types for individual tensor values.
+enum DataType {
+ // The constants below should be kept in sync with tensorflow::Datatype:
+ // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
+ // While not strictly required, that has a number of benefits, including
+ // easier porting of tensors from tensorflow::Tensor to aggregation tensors.
+ DT_INVALID = 0;
+ DT_FLOAT = 1;
+ DT_DOUBLE = 2;
+ DT_INT32 = 3;
+ DT_STRING = 7;
+ DT_INT64 = 9;
+ // TODO(team): Add other types
+}
+
+// Tensor shape (e.g. dimensions)
+message TensorShapeProto {
+ // Sizes of each dimension in the tensor.
+ // Values must be >= -1, however values of -1 are reserved for "unknown".
+ //
+ // The order of entries in `dim_sizes` matters: It indicates the layout of the
+ // values in the tensor in-memory representation.
+ //
+ // The first entry in `dim_sizes` is the outermost dimension used to layout
+ // the values, the last entry is the innermost dimension. This matches the
+ // in-memory layout of row-major tensors.
+ //
+ // A scalar tensor has a shape with zero dimensions.
+ repeated int64 dim_sizes = 1;
+}
+
+// This message describes aggregation tensor name, type, and shape.
+message TensorSpecProto {
+ // Tensor name
+ string name = 1;
+
+ // Type of the tensor values.
+ DataType dtype = 2;
+
+ // Shape of the tensor.
+ TensorShapeProto shape = 3;
+}
+
+// Optional descriptor of the sparse index encoding, that is applicable only
+// to sparse tensors. If this message is empty (default) that means that
+// the tensor is dense.
+// The best way to think about SparsityEncoding as a way to describe mapping
+// of the indices in the tensor content to the indices in the dense tensor.
+message SparsityEncoding {
+ // TODO(team): Implement SparsityEncoding.
+}
+
+// Protocol buffer representation of a tensor.
+message TensorProto {
+ // Type of the tensor values.
+ DataType dtype = 1;
+
+ // Shape of the tensor.
+ TensorShapeProto shape = 2;
+
+ // Optional descriptor of sparse index encoding.
+ SparsityEncoding sparsity_encoding = 3;
+
+ // Serialized tensor values packed into a single blob.
+ // The exact format of the blob depends on dtype.
+ //
+ // For numeric data types, the following applies:
+ // For a dense tensor, the content matches in-memory representation of a
+ // C-style row-major multi-dimensional array of values.
+ // For a sparse tensor, the content matches in-memory representation of a
+ // one dimensional array of non-zero values, which order is described by
+ // the `sparsity_encoding`.
+ // The values must be encoded using little-endian byte layout.
+ bytes content = 4;
+}
diff --git a/fcp/aggregation/core/tensor_aggregator.cc b/fcp/aggregation/core/tensor_aggregator.cc
new file mode 100644
index 0000000..3dc1af7
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_aggregator.h"
+
+#include <utility>
+
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+Status TensorAggregator::Accumulate(InputTensorList tensors) {
+ FCP_RETURN_IF_ERROR(CheckValid());
+
+ // Delegate aggregation to the derived class.
+ return AggregateTensors(std::move(tensors));
+}
+
+bool TensorAggregator::CanReport() const { return CheckValid().ok(); }
+
+StatusOr<OutputTensorList> TensorAggregator::Report() && {
+ FCP_RETURN_IF_ERROR(CheckValid());
+ if (!CanReport()) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "TensorAggregator::Report: the report goal isn't met";
+ }
+ return std::move(*this).TakeOutputs();
+}
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_aggregator.h b/fcp/aggregation/core/tensor_aggregator.h
new file mode 100644
index 0000000..d0590be
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_H_
+#define FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_H_
+
+#include <vector>
+
+#include "fcp/aggregation/core/aggregator.h"
+#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+using OutputTensorList = std::vector<Tensor>;
+
+// TensorAggregator is a base class for implementing Aggregation intrinsics
+// with Tensor being an input and output type for the aggregation.
+class TensorAggregator
+ : public Aggregator<InputTensorList, OutputTensorList, TensorAggregator> {
+ public:
+ ~TensorAggregator() override = default;
+
+ // Implementation of the base Aggregator class methods.
+ Status Accumulate(InputTensorList tensors) override;
+ bool CanReport() const override;
+ StatusOr<OutputTensorList> Report() && override;
+
+ // Returns the number of aggregated inputs.
+ virtual int GetNumInputs() const = 0;
+
+ protected:
+ // Construct TensorAggregator
+ explicit TensorAggregator() {}
+
+ // The actual implementation of the tensor aggregation to be provided by
+ // a derived class.
+ virtual Status AggregateTensors(InputTensorList tensors) = 0;
+
+ // Checks if the current TensorAggregator is valid e.g. the resulting output
+ // hasn't been consumed.
+ virtual Status CheckValid() const = 0;
+
+ // Consumes the output of this TensorAggregator.
+ virtual OutputTensorList TakeOutputs() && = 0;
+
+ private:
+ // Extracts the aggregated tensor and makes the current aggregator "consumed".
+ OutputTensorList TakeTensors() &&;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_H_
diff --git a/fcp/aggregation/core/tensor_aggregator_factory.h b/fcp/aggregation/core/tensor_aggregator_factory.h
new file mode 100644
index 0000000..6a3f5d7
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator_factory.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_FACTORY_H_
+#define FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_FACTORY_H_
+
+#include <memory>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// This class is the interface for the abstract factory that creates an instance
+// of a TensorAggregator derived class.
+class TensorAggregatorFactory {
+ public:
+ virtual ~TensorAggregatorFactory() = default;
+
+ // Creates an instance of a specific aggregator for the specified type of the
+ // aggregation instrinsic and the tensor specifications.
+ // TODO(team): Generalize this to allow multiple inputs and outputs,
+ // and an arbitrary number of arguments.
+ virtual StatusOr<std::unique_ptr<TensorAggregator>> Create(
+ DataType dtype, TensorShape shape) const = 0;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_FACTORY_H_
diff --git a/fcp/aggregation/core/tensor_aggregator_registry.cc b/fcp/aggregation/core/tensor_aggregator_registry.cc
new file mode 100644
index 0000000..36182bc
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator_registry.cc
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+
+#ifdef FCP_BAREMETAL
+#include <unordered_map>
+#else
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+namespace internal {
+
+class Registry final {
+ public:
+ void RegisterAggregatorFactory(const std::string& intrinsic_uri,
+ const TensorAggregatorFactory* factory) {
+ FCP_CHECK(factory != nullptr);
+
+#ifndef FCP_BAREMETAL
+ absl::MutexLock lock(&mutex_);
+#endif
+ FCP_CHECK(map_.find(intrinsic_uri) == map_.end())
+ << "A factory for intrinsic_uri '" << intrinsic_uri
+ << "' is already registered.";
+ map_[intrinsic_uri] = factory;
+ FCP_LOG(INFO) << "TensorAggregatorFactory for intrinsic_uri '"
+ << intrinsic_uri << "' is registered.";
+ }
+
+ StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
+ const std::string& intrinsic_uri) {
+#ifndef FCP_BAREMETAL
+ absl::MutexLock lock(&mutex_);
+#endif
+ auto it = map_.find(intrinsic_uri);
+ if (it == map_.end()) {
+ return FCP_STATUS(NOT_FOUND)
+ << "Unknown factory for intrinsic_uri '" << intrinsic_uri << "'.";
+ }
+ return it->second;
+ }
+
+ private:
+#ifdef FCP_BAREMETAL
+ std::unordered_map<std::string, const TensorAggregatorFactory*> map_;
+#else
+ // Synchronization of potentially concurrent registry calls is done only in
+ // the non-baremetal environment. In the baremetal environment, since there is
+ // no OS, a single thread execution environment is expected and the
+ // synchronization primitives aren't available.
+ absl::Mutex mutex_;
+ absl::flat_hash_map<std::string, const TensorAggregatorFactory*> map_
+ ABSL_GUARDED_BY(mutex_);
+#endif
+};
+
+#ifdef FCP_BAREMETAL
+// TODO(team): Revise the registration mechanism below.
+// In a baremetal build the static initialization mechanism isn't available
+// which means that all the aggregation intrinsics need to be explicitly
+// registered below.
+extern "C" void RegisterFederatedSum();
+
+void RegisterAll() { RegisterFederatedSum(); }
+#endif // FCP_BAREMETAL
+
+Registry* GetRegistry() {
+ static Registry* global_registry = new Registry();
+#ifdef FCP_BAREMETAL
+ // TODO(team): Revise the registration mechanism below.
+ static bool registration_done = false;
+ if (!registration_done) {
+ registration_done = true;
+ RegisterAll();
+ }
+#endif
+ return global_registry;
+}
+
+} // namespace internal
+
+// Registers a factory instance for the given intrinsic type.
+void RegisterAggregatorFactory(const std::string& intrinsic_uri,
+ const TensorAggregatorFactory* factory) {
+ internal::GetRegistry()->RegisterAggregatorFactory(intrinsic_uri, factory);
+}
+
+// Looks up a factory instance for the given intrinsic type.
+StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
+ const std::string& intrinsic_uri) {
+ return internal::GetRegistry()->GetAggregatorFactory(intrinsic_uri);
+}
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_aggregator_registry.h b/fcp/aggregation/core/tensor_aggregator_registry.h
new file mode 100644
index 0000000..ddc6648
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator_registry.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_REGISTRY_H_
+#define FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_REGISTRY_H_
+
+#include <string>
+
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Registers a factory instance for the given intrinsic type.
+void RegisterAggregatorFactory(const std::string& intrinsic_uri,
+ const TensorAggregatorFactory* factory);
+
+// Looks up a factory instance for the given intrinsic type.
+StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
+ const std::string& intrinsic_uri);
+
+namespace internal {
+
+template <typename FactoryType>
+struct Registrar {
+ explicit Registrar(const std::string& intrinsic_uri) {
+ RegisterAggregatorFactory(intrinsic_uri, new FactoryType());
+ }
+};
+
+} // namespace internal
+
+// This macro is used to register a factory type with the intrinsic uri.
+#define REGISTER_AGGREGATOR_FACTORY(intrinsic_uri, FactoryType) \
+ static auto unused = \
+ ::fcp::aggregation::internal::Registrar<FactoryType>(intrinsic_uri);
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_REGISTRY_H_
diff --git a/fcp/aggregation/core/tensor_aggregator_registry_test.cc b/fcp/aggregation/core/tensor_aggregator_registry_test.cc
new file mode 100644
index 0000000..e3f3d69
--- /dev/null
+++ b/fcp/aggregation/core/tensor_aggregator_registry_test.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+class MockFactory : public TensorAggregatorFactory {
+ MOCK_METHOD(StatusOr<std::unique_ptr<TensorAggregator>>, Create,
+ (DataType dtype, TensorShape shape), (const override));
+};
+
+REGISTER_AGGREGATOR_FACTORY("foobar", MockFactory);
+
+TEST(TensorAggregatorRegistryTest, FactoryRegistrationSuccessful) {
+ EXPECT_THAT(GetAggregatorFactory("foobar"), IsOk());
+ EXPECT_THAT(GetAggregatorFactory("xyz"), IsCode(NOT_FOUND));
+}
+
+TEST(TensorAggregatorRegistryTest, RepeatedRegistrationUnsuccessful) {
+ MockFactory factory2;
+ EXPECT_DEATH(RegisterAggregatorFactory("foobar", &factory2),
+ "already registered");
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_data.cc b/fcp/aggregation/core/tensor_data.cc
new file mode 100644
index 0000000..f98f6ef
--- /dev/null
+++ b/fcp/aggregation/core/tensor_data.cc
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_data.h"
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+Status TensorData::CheckValid(size_t value_size) const {
+ FCP_CHECK(value_size > 0);
+ if (byte_size() == 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "TensorData: non-empty size required";
+ }
+
+ if ((byte_size() % value_size) != 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "TensorData: byte_size() must be a multiple of value_size "
+ << value_size;
+ }
+
+ if ((reinterpret_cast<size_t>(data()) % value_size) != 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "TensorData: data() address is not aligned by value_size "
+ << value_size;
+ }
+
+ return FCP_STATUS(OK);
+}
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_data.h b/fcp/aggregation/core/tensor_data.h
new file mode 100644
index 0000000..c1f5ae2
--- /dev/null
+++ b/fcp/aggregation/core/tensor_data.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_DATA_H_
+#define FCP_AGGREGATION_CORE_TENSOR_DATA_H_
+
+#include <cstddef>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace aggregation {
+
+// Abstract representation of tensor data storage.
+//
+// Tensor data is flattened one-dimensional array of tensor of tensor values
+// where each values takes sizeof(T) bytes.
+//
+// All tensor values are stored in a single blob regardless of whether the
+// tensor is dense or sparse.
+//
+// If the tensor is dense, then the values are flattened into
+// one-dimensional array the following way:
+// - First iterating over the last dimension
+// - Then incrementing the second from the last dimension and then iterating
+// over the last dimension
+// - Then gradually moving towards the first dimension.
+// For example, if we had a 3-dimensional {3 x 2 x 4} Tensor, the values
+// in TensorData would be ordered in the following way, showing 3-dimensional
+// indices of the tensor values:
+// (0,0,0), (0,0,1), (0,0,2), (0,0,3)
+// (0,1,0), (0,1,1), (0,1,2), (0,1,3)
+// (1,0,0), (1,0,1), (1,0,2), (1,0,3)
+// (1,1,0), (1,1,1), (1,1,2), (1,1,3)
+// (2,0,0), (2,0,1), (2,0,2), (2,0,3)
+// (2,1,0), (2,1,1), (2,1,2), (2,1,3)
+//
+// If the tensor is sparse, then the order of values in the array is arbitrary
+// and can be described by the tensor SparsityParameters which describes the
+// mapping from the value indices in tensor data to indices in the dense tensor
+// flattened the way described above.
+//
+// The tensor data can be backed by different implementations depending on
+// where the data comes from.
+class TensorData {
+ public:
+ virtual ~TensorData() = default;
+
+ // Tensor data pointer.
+ virtual const void* data() const = 0;
+
+ // The overall size of the tensor data in bytes.
+ virtual size_t byte_size() const = 0;
+
+ // Validates TensorData constraints given the specified value_size.
+ // The value_size is the size of the native data type (e.g. 4 bytes for int32
+ // or float, 8 bytes for int64). This is used to verify data alignment - that
+ // all offsets and sizes are multiples of value_size that pointers are memory
+ // aligned to the value_size.
+ // TODO(team): Consider separate sizes for the pointer alignment and
+ // the slices offsets/sizes. The latter may need to be more coarse.
+ Status CheckValid(size_t value_size) const;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_DATA_H_
diff --git a/fcp/aggregation/core/tensor_data_test.cc b/fcp/aggregation/core/tensor_data_test.cc
new file mode 100644
index 0000000..9743037
--- /dev/null
+++ b/fcp/aggregation/core/tensor_data_test.cc
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_data.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Return;
+
+class MockTensorData : public TensorData {
+ public:
+ MockTensorData(size_t data_pointer_offset, size_t size);
+
+ MOCK_METHOD(const void*, data, (), (const override));
+ MOCK_METHOD(size_t, byte_size, (), (const override));
+};
+
+MockTensorData::MockTensorData(size_t data_pointer_offset, size_t size) {
+ EXPECT_CALL(*this, byte_size()).WillRepeatedly(Return(size));
+ EXPECT_CALL(*this, data())
+ .WillRepeatedly(Return(reinterpret_cast<void*>(data_pointer_offset)));
+}
+
+TEST(TensorDataTest, CheckValid_ZeroByteSize) {
+ MockTensorData tensor_data(0, 0);
+ EXPECT_THAT(tensor_data.CheckValid(1), IsCode(FAILED_PRECONDITION));
+}
+
+TEST(TensorDataTest, CheckValid_ByteSizeNotAligned) {
+ MockTensorData tensor_data(0, 33);
+ EXPECT_THAT(tensor_data.CheckValid(4), IsCode(FAILED_PRECONDITION));
+}
+
+TEST(TensorDataTest, CheckValid_AddressNotAligned) {
+ MockTensorData tensor_data(3, 100);
+ EXPECT_THAT(tensor_data.CheckValid(4), IsCode(FAILED_PRECONDITION));
+}
+
+TEST(TensorDataTest, CheckValid_Success) {
+ MockTensorData tensor_data(0, 96);
+ EXPECT_THAT(tensor_data.CheckValid(1), IsOk());
+ EXPECT_THAT(tensor_data.CheckValid(2), IsOk());
+ EXPECT_THAT(tensor_data.CheckValid(4), IsOk());
+ EXPECT_THAT(tensor_data.CheckValid(8), IsOk());
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_shape.cc b/fcp/aggregation/core/tensor_shape.cc
new file mode 100644
index 0000000..07ad581
--- /dev/null
+++ b/fcp/aggregation/core/tensor_shape.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_shape.h"
+
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+
+#ifndef FCP_NANOLIBC
+#include "fcp/aggregation/core/tensor.pb.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+size_t TensorShape::NumElements() const {
+ size_t num_elements = 1;
+ for (auto dim_size : dim_sizes_) {
+ num_elements *= dim_size;
+ }
+ return num_elements;
+}
+
+#ifndef FCP_NANOLIBC
+
+StatusOr<TensorShape> TensorShape::FromProto(
+ const TensorShapeProto& shape_proto) {
+ TensorShape::DimSizesVector dim_sizes;
+ for (int64_t dim_size : shape_proto.dim_sizes()) {
+ if (dim_size < 0) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Negative dimension size isn't supported when converting from "
+ << "shape_proto: " << shape_proto.ShortDebugString();
+ }
+ dim_sizes.push_back(dim_size);
+ }
+ return TensorShape(std::move(dim_sizes));
+}
+
+TensorShapeProto TensorShape::ToProto() const {
+ TensorShapeProto shape_proto;
+ for (auto dim_size : dim_sizes()) {
+ shape_proto.add_dim_sizes(static_cast<int64_t>(dim_size));
+ }
+ return shape_proto;
+}
+
+#endif // FCP_NANOLIBC
+
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_shape.h b/fcp/aggregation/core/tensor_shape.h
new file mode 100644
index 0000000..2023857
--- /dev/null
+++ b/fcp/aggregation/core/tensor_shape.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_
+#define FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <utility>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+
+#ifndef FCP_NANOLIBC
+#include "fcp/aggregation/core/tensor.pb.h"
+#endif
+
+namespace fcp {
+namespace aggregation {
+
+// Represents a tensor shape as a collection of
+// dimension sizes.
+class TensorShape final {
+ public:
+ using DimSizesVector = std::vector<size_t>;
+
+ template <typename ForwardIterator>
+ TensorShape(ForwardIterator first, ForwardIterator last)
+ : dim_sizes_(first, last) {}
+
+ TensorShape(std::initializer_list<size_t> dim_sizes)
+ : dim_sizes_(dim_sizes) {}
+
+#ifndef FCP_NANOLIBC
+ // Creates a TensorShape from a TensorShapeProto.
+ // Returns an error if any of the shape dimensions are unknown.
+ static StatusOr<TensorShape> FromProto(const TensorShapeProto& shape_proto);
+
+ // Returns a TensorShapeProto representation of the tensor shape.
+ TensorShapeProto ToProto() const;
+#endif
+
+ // Gets the dimensions and their sizes.
+ const DimSizesVector& dim_sizes() const { return dim_sizes_; }
+
+ // Gets the total number of elements (which is a multiplication of sizes of
+ // all dimensions).
+ // For a scalar tensor with zero dimensions this returns 1.
+ size_t NumElements() const;
+
+ friend bool operator==(const TensorShape& a, const TensorShape& b) {
+ return a.dim_sizes_ == b.dim_sizes_;
+ }
+
+ friend bool operator!=(const TensorShape& a, const TensorShape& b) {
+ return a.dim_sizes_ != b.dim_sizes_;
+ }
+
+ private:
+ explicit TensorShape(DimSizesVector&& dim_sizes)
+ : dim_sizes_(std::move(dim_sizes)) {}
+
+ // TODO(team): Consider optimizing the storage for better inlining
+ // of small number of dimensions.
+ DimSizesVector dim_sizes_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_
diff --git a/fcp/aggregation/core/tensor_shape_test.cc b/fcp/aggregation/core/tensor_shape_test.cc
new file mode 100644
index 0000000..3691bd0
--- /dev/null
+++ b/fcp/aggregation/core/tensor_shape_test.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor_shape.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::ElementsAre;
+
+TEST(TensorShapeTest, CreateFromInitializerList) {
+ TensorShape shape({2, 3, 5});
+ EXPECT_THAT(shape.dim_sizes(), ElementsAre(2, 3, 5));
+}
+
+TEST(TensorShapeTest, CreateFromIterator) {
+ std::vector<size_t> dim_sizes = {4, 8, 3, 2};
+ TensorShape shape(dim_sizes.begin(), dim_sizes.end());
+ EXPECT_THAT(shape.dim_sizes(), ElementsAre(4, 8, 3, 2));
+}
+
+TEST(TensorShapeTest, NumElements) {
+ TensorShape shape({2, 3, 5});
+ EXPECT_EQ(shape.NumElements(), 30);
+}
+
+TEST(TensorShapeTest, ScalarShape) {
+ TensorShape shape({});
+ EXPECT_EQ(shape.dim_sizes().size(), 0);
+ EXPECT_EQ(shape.NumElements(), 1);
+}
+
+TEST(TensorShapeTest, EqualityOperators) {
+ TensorShape shape({3, 5});
+ EXPECT_EQ(shape, TensorShape({3, 5}));
+ EXPECT_NE(shape, TensorShape({}));
+ EXPECT_NE(shape, TensorShape({1}));
+ EXPECT_NE(shape, TensorShape({3, 4}));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_spec.h b/fcp/aggregation/core/tensor_spec.h
new file mode 100644
index 0000000..0e63527
--- /dev/null
+++ b/fcp/aggregation/core/tensor_spec.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_CORE_TENSOR_SPEC_H_
+#define FCP_AGGREGATION_CORE_TENSOR_SPEC_H_
+
+#include <string>
+#include <utility>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+
+namespace fcp {
+namespace aggregation {
+
+// A tuple representing tensor name, data type, and shape.
+class TensorSpec final {
+ public:
+ TensorSpec(std::string name, DataType dtype, TensorShape shape)
+ : name_(std::move(name)), dtype_(dtype), shape_(std::move(shape)) {}
+
+ const std::string& name() const { return name_; }
+ DataType dtype() const { return dtype_; }
+ const TensorShape& shape() const { return shape_; }
+
+ private:
+ const std::string name_;
+ const DataType dtype_;
+ const TensorShape shape_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_TENSOR_SPEC_H_
diff --git a/fcp/aggregation/core/tensor_test.cc b/fcp/aggregation/core/tensor_test.cc
new file mode 100644
index 0000000..3a81b6e
--- /dev/null
+++ b/fcp/aggregation/core/tensor_test.cc
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/core/tensor.h"
+
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+using testing::Eq;
+
+TEST(TensorTest, Create_Dense) {
+ auto t = Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1, 2, 3}));
+ EXPECT_THAT(t, IsOk());
+ EXPECT_THAT(t->dtype(), Eq(DT_FLOAT));
+ EXPECT_THAT(t->shape(), Eq(TensorShape{3}));
+ EXPECT_TRUE(t->is_dense());
+ EXPECT_THAT(t->AsAggVector<float>().size(), Eq(3));
+}
+
+TEST(TensorTest, Create_StringTensor) {
+ auto t = Tensor::Create(DT_STRING, {2},
+ CreateTestData<string_view>({"foo", "bar"}));
+ EXPECT_THAT(t, IsOk());
+ EXPECT_THAT(t->dtype(), Eq(DT_STRING));
+ EXPECT_THAT(t->shape(), Eq(TensorShape{2}));
+ EXPECT_TRUE(t->is_dense());
+ EXPECT_THAT(t->AsAggVector<string_view>().size(), Eq(2));
+}
+
+TEST(TensorTest, Create_DataValidationError) {
+ auto t = Tensor::Create(DT_FLOAT, {}, CreateTestData<char>({'a', 'b', 'c'}));
+ EXPECT_THAT(t, IsCode(FAILED_PRECONDITION));
+}
+
+TEST(TensorTest, Create_DataSizeError) {
+ auto t = Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1, 2}));
+ EXPECT_THAT(t, IsCode(FAILED_PRECONDITION));
+}
+
+struct FooBar {};
+
+TEST(TensorTest, AsAggVector_TypeCheckFailure) {
+ auto t = Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1}));
+ EXPECT_DEATH(t->AsAggVector<FooBar>(), "Incompatible tensor dtype()");
+ EXPECT_DEATH(t->AsAggVector<int>(), "Incompatible tensor dtype()");
+}
+
+template <typename T>
+std::string ToProtoContent(std::initializer_list<T> values) {
+ return std::string(reinterpret_cast<char*>(std::vector(values).data()),
+ values.size() * sizeof(T));
+}
+
+template <>
+std::string ToProtoContent(std::initializer_list<string_view> values) {
+ // The following is the simplified version of serializing the string values
+ // that works only for short strings that are shorter than 128 characters, in
+ // which case string lengths can be encoded with one byte each.
+ std::string content(values.size(), '\0');
+ size_t index = 0;
+ // Write sizes of strings first.
+ for (string_view value : values) {
+ FCP_CHECK(value.size() < 128);
+ content[index++] = static_cast<char>(value.size());
+ }
+ // Append data of all strings.
+ for (string_view value : values) {
+ content.append(value.data(), value.size());
+ }
+ return content;
+}
+
+TEST(TensorTest, ToProto_Numeric_Success) {
+ std::initializer_list<int32_t> values{1, 2, 3, 4};
+ auto t = Tensor::Create(DT_INT32, {2, 2}, CreateTestData(values));
+ TensorProto expected_proto;
+ expected_proto.set_dtype(DT_INT32);
+ expected_proto.mutable_shape()->add_dim_sizes(2);
+ expected_proto.mutable_shape()->add_dim_sizes(2);
+ expected_proto.set_content(ToProtoContent(values));
+ EXPECT_THAT(t->ToProto(), EqualsProto(expected_proto));
+}
+
+TEST(TensorTest, ToProto_String_Success) {
+ std::initializer_list<string_view> values{"abc", "de", "",
+ "fghi", "jklmn", "o"};
+ auto t = Tensor::Create(DT_STRING, {2, 3}, CreateTestData(values));
+ TensorProto expected_proto;
+ expected_proto.set_dtype(DT_STRING);
+ expected_proto.mutable_shape()->add_dim_sizes(2);
+ expected_proto.mutable_shape()->add_dim_sizes(3);
+ expected_proto.set_content(ToProtoContent(values));
+ EXPECT_THAT(t->ToProto(), EqualsProto(expected_proto));
+}
+
+TEST(TensorTest, FromProto_Numeric_Success) {
+ std::initializer_list<int32_t> values{5, 6, 7, 8, 9, 10};
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(DT_INT32);
+ tensor_proto.mutable_shape()->add_dim_sizes(2);
+ tensor_proto.mutable_shape()->add_dim_sizes(3);
+ tensor_proto.set_content(ToProtoContent(values));
+ auto t = Tensor::FromProto(tensor_proto);
+ EXPECT_THAT(t, IsOk());
+ EXPECT_THAT(*t, IsTensor({2, 3}, values));
+}
+
+TEST(TensorTest, FromProto_String_Success) {
+ std::initializer_list<string_view> values{"aaaaaaaa", "b", "cccc", "ddddddd"};
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(DT_STRING);
+ tensor_proto.mutable_shape()->add_dim_sizes(2);
+ tensor_proto.mutable_shape()->add_dim_sizes(2);
+ tensor_proto.set_content(ToProtoContent(values));
+ auto t = Tensor::FromProto(tensor_proto);
+ EXPECT_THAT(t, IsOk());
+ EXPECT_THAT(*t, IsTensor({2, 2}, values));
+}
+
+TEST(TensorTest, LargeStringValuesSerialization) {
+ std::string s1(123456, 'a');
+ std::string s2(7890, 'b');
+ std::string s3(1357924, 'c');
+ auto t1 =
+ Tensor::Create(DT_STRING, {3}, CreateTestData<string_view>({s1, s2, s3}));
+ auto proto = t1->ToProto();
+ auto t2 = Tensor::FromProto(proto);
+ EXPECT_THAT(*t2, IsTensor<string_view>({3}, {s1, s2, s3}));
+}
+
+TEST(TensorTest, FromProto_Mutable_Success) {
+ std::initializer_list<int32_t> values{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(DT_INT32);
+ tensor_proto.mutable_shape()->add_dim_sizes(10);
+ tensor_proto.set_content(ToProtoContent(values));
+ // Store the data pointer to make sure that the tensor retains the same data.
+ void* data_ptr = tensor_proto.mutable_content()->data();
+ auto t = Tensor::FromProto(std::move(tensor_proto));
+ EXPECT_THAT(t, IsOk());
+ EXPECT_THAT(*t, IsTensor({10}, values));
+ EXPECT_EQ(data_ptr, t->data().data());
+}
+
+TEST(TensorTest, FromProto_NegativeDimSize) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(DT_INT32);
+ tensor_proto.mutable_shape()->add_dim_sizes(-1);
+ tensor_proto.set_content(ToProtoContent<int32_t>({1}));
+ EXPECT_THAT(Tensor::FromProto(tensor_proto), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(TensorTest, FromProto_InvalidStringContent) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(DT_STRING);
+ tensor_proto.mutable_shape()->add_dim_sizes(1);
+ tensor_proto.set_content("");
+ EXPECT_THAT(Tensor::FromProto(tensor_proto), IsCode(INVALID_ARGUMENT));
+
+ std::string content(1, '\5');
+ tensor_proto.set_content(content);
+ EXPECT_THAT(Tensor::FromProto(tensor_proto), IsCode(INVALID_ARGUMENT));
+
+ content.append("abc");
+ tensor_proto.set_content(content);
+ EXPECT_THAT(Tensor::FromProto(tensor_proto), IsCode(INVALID_ARGUMENT));
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/core/vector_string_data.h b/fcp/aggregation/core/vector_string_data.h
new file mode 100644
index 0000000..8f83324
--- /dev/null
+++ b/fcp/aggregation/core/vector_string_data.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_AGGREGATION_CORE_VECTOR_STRING_DATA_H_
+#define FCP_AGGREGATION_CORE_VECTOR_STRING_DATA_H_
+
+#include <cstddef>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_data.h"
+
+namespace fcp {
+namespace aggregation {
+
+class VectorStringData : public TensorData {
+ public:
+ explicit VectorStringData(std::vector<std::string>&& strings)
+ : strings_(std::move(strings)) {
+ string_views_.reserve(strings_.size());
+ for (const std::string& s : strings_) string_views_.emplace_back(s);
+ }
+ ~VectorStringData() override = default;
+
+ // Implementation of TensorData methods.
+ size_t byte_size() const override {
+ return string_views_.size() * sizeof(string_view);
+ }
+ const void* data() const override { return string_views_.data(); }
+
+ private:
+ std::vector<std::string> strings_;
+ std::vector<string_view> string_views_;
+};
+
+} // namespace aggregation
+} // namespace fcp
+
+#endif // FCP_AGGREGATION_CORE_VECTOR_STRING_DATA_H_
diff --git a/fcp/aggregation/core/vector_string_data_test.cc b/fcp/aggregation/core/vector_string_data_test.cc
new file mode 100644
index 0000000..702ad4b
--- /dev/null
+++ b/fcp/aggregation/core/vector_string_data_test.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/aggregation/core/vector_string_data.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace aggregation {
+namespace {
+
+TEST(VectorDataTest, VectorStringDataValid) {
+ VectorStringData vector_data(std::vector<std::string>(
+ {"string1", "another-string", "one_more_string"}));
+ EXPECT_THAT(vector_data.CheckValid(sizeof(string_view)), IsOk());
+}
+
+} // namespace
+} // namespace aggregation
+} // namespace fcp
diff --git a/fcp/aggregation/protocol/BUILD b/fcp/aggregation/protocol/BUILD
new file mode 100644
index 0000000..a5306f6
--- /dev/null
+++ b/fcp/aggregation/protocol/BUILD
@@ -0,0 +1,118 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Aggregation Protocol Package
+
+load("@org_tensorflow//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
+load("//fcp:config.bzl", "FCP_COPTS")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+proto_library(
+ name = "proto",
+ srcs = ["aggregation_protocol_messages.proto"],
+ deps = [
+ "//fcp/secagg/shared:proto",
+ ],
+)
+
+cc_proto_library(
+ name = "cc_proto",
+ deps = [
+ ":proto",
+ ],
+)
+
+py_proto_library(
+ name = "py_pb2",
+ deps = [
+ ":proto",
+ ],
+)
+
+# Using tf_proto_library to get dependencies to TF protos built correctly.
+tf_proto_library(
+ name = "configuration_proto",
+ srcs = ["configuration.proto"],
+ protodeps = [
+ "@org_tensorflow//tensorflow/core:protos_all",
+ ],
+)
+
+# Allowing to refer to the cc library generated by the rule above in usual way:
+alias(
+ name = "configuration_cc_proto",
+ actual = "configuration_proto_cc",
+)
+
+alias(
+ name = "configuration_py_pb2",
+ actual = "configuration_proto_py",
+)
+
+cc_library(
+ name = "aggregation_protocol",
+ hdrs = [
+ "aggregation_protocol.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:cord",
+ ],
+)
+
+cc_library(
+ name = "resource_resolver",
+ hdrs = [
+ "resource_resolver.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ ],
+)
+
+cc_library(
+ name = "checkpoint_builder",
+ hdrs = [
+ "checkpoint_builder.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ ],
+)
+
+cc_library(
+ name = "checkpoint_parser",
+ hdrs = [
+ "checkpoint_parser.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ ],
+)
diff --git a/fcp/aggregation/protocol/aggregation_protocol.h b/fcp/aggregation/protocol/aggregation_protocol.h
new file mode 100644
index 0000000..384cc7b
--- /dev/null
+++ b/fcp/aggregation/protocol/aggregation_protocol.h
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_AGGREGATION_PROTOCOL_H_
+#define FCP_AGGREGATION_PROTOCOL_AGGREGATION_PROTOCOL_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+
+namespace fcp::aggregation {
+
+// Describes a abstract aggregation protocol interface between a networking
+// layer (e.g. a service that handles receiving and sending messages with the
+// client devices) and an implementation of an aggregation algorithm.
+//
+// The design of the AggregationProtocol follows a Bridge Pattern
+// (https://en.wikipedia.org/wiki/Bridge_pattern) in that it is meant to
+// decouple an abstraction of the layers above and below the AggregationProtocol
+// from the implementation.
+//
+// In this interface the receiving and sending contributing inputs or
+// messages is abstracted from the actual mechanism for sending and receiving
+// data over the network and from the actual aggregation mechanism.
+//
+// Client identification: the real client identities are hidden from the
+// protocol implementations. Instead each client is identified by a client_id
+// number in a range [0, num_clients) where num_clients is the number of clients
+// the protocol started with or the extended number of clients, which is the
+// sum of the starting num_clients and num_clients passed to each subsequent
+// AddClients call.
+//
+// Thread safety: for any given client identified by a unique client_id, the
+// protocol methods are expected to be called sequentially. But there are no
+// assumptions about concurrent calls made for different clients. Specific
+// implementations of AggregationProtocol are expected to handle concurrent
+// calls. The caller side of the protocol isn't expected to queue messages.
+class AggregationProtocol {
+ public:
+ AggregationProtocol() = default;
+ virtual ~AggregationProtocol() = default;
+
+ // Instructs the protocol to start with the specified number of clients.
+ //
+ // Depending on the protocol implementation, the starting number of clients
+ // may be zero. This method is guaranteed to be the first method called on
+ // the protocol.
+ //
+ // AcceptClients callback is expected in response to this method.
+ virtual absl::Status Start(int64_t num_clients) = 0;
+
+ // Adds an additional batch of clients to the protocol.
+ //
+ // Depending on the protocol implementation, adding clients may not be allowed
+ // and this method might return an error Status.
+ //
+ // AcceptClients callback is expected in response to this method.
+ virtual absl::Status AddClients(int64_t num_clients) = 0;
+
+ // Handles a message from a given client.
+ //
+ // Depending on the specific protocol implementation there may be multiple
+ // messages exchanged with each clients.
+ //
+ // This method should return an error status only if there is an unrecoverable
+ // error which must result in aborting the protocol. Any client specific
+ // error, like an invalid message, should result in closing the protocol with
+ // that specific client only, but this method should still return OK status.
+ virtual absl::Status ReceiveClientMessage(int64_t client_id,
+ const ClientMessage& message) = 0;
+
+ // Notifies the protocol about a communication with a given client being
+ // closed, either normally or abnormally.
+ //
+ // The client_status indicates whether the client connection was closed
+ // normally.
+ //
+ // No further calls or callbacks specific to the given client are expected
+ // after this method.
+ virtual absl::Status CloseClient(int64_t client_id,
+ absl::Status client_status) = 0;
+
+ // Forces the protocol to complete.
+ //
+ // Once the protocol has completed successfully, the Complete callback will
+ // be invoked and provide the aggregation result. If the protocol cannot be
+ // completed in its current state, this method should return an error status.
+ // It is also possible for the completion to fail eventually due to finishing
+ // some asynchronous work, in which case the Abort callback will be invoked.
+ //
+ // No further protocol method calls except Abort and GetStatus are expected
+ // after this method.
+ virtual absl::Status Complete() = 0;
+
+ // Forces the protocol to Abort.
+ //
+ // No further protocol method calls except GetStatus are expected after this
+ // method.
+ virtual absl::Status Abort() = 0;
+
+ // Called periodically to receive the protocol status.
+ //
+ // This method can still be called after the protocol has been completed or
+ // aborted.
+ virtual StatusMessage GetStatus() = 0;
+
+ // Callback interface which methods are implemented by the protocol host.
+ class Callback {
+ public:
+ Callback() = default;
+ virtual ~Callback() = default;
+
+ // Called in response to either StartProtocol or AddClients methods being
+ // called and provides protocol parameters to be broadcasted to all newly
+ // joined clients.
+ virtual void OnAcceptClients(int64_t start_client_id, int64_t num_clients,
+ const AcceptanceMessage& message) = 0;
+
+ // Called by the protocol to deliver a message to a given client.
+ //
+ // Depending on the specific protocol implementation there may be multiple
+ // messages exchanged with each clients, but not all protocols need to
+ // send messages to clients.
+ virtual void OnSendServerMessage(int64_t client_id,
+ const ServerMessage& message) = 0;
+
+ // Called by the protocol to force communication with a client to be closed,
+ // for example due to a client specific error or due to the protocol getting
+ // into a state where no further input for that client is needed.
+ //
+ // No further calls or callbacks specific to the given client are expected
+ // after this method.
+ virtual void OnCloseClient(int64_t client_id,
+ absl::Status diagnostic_status) = 0;
+
+ // Indicates successful completion of the aggregation protocol, contains
+ // the result of the aggregation.
+ //
+ // The format of the result blob is unspecified and can be different for
+ // each specific aggregation protocol implementation. Completing the
+ // protocol should close communications with all remaining clients.
+ virtual void OnComplete(absl::Cord result) = 0;
+
+ // Called by the protocol to indicate that the protocol has been aborted
+ // for internal reasons (e.g. the number of remaining clients dropping
+ // too low).
+ //
+ // Aborting the protocol should close communications with all remaining
+ // clients.
+ virtual void OnAbort(absl::Status diagnostic_status) = 0;
+ };
+};
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_AGGREGATION_PROTOCOL_H_
diff --git a/fcp/aggregation/protocol/aggregation_protocol_messages.proto b/fcp/aggregation/protocol/aggregation_protocol_messages.proto
new file mode 100644
index 0000000..84dd4b0
--- /dev/null
+++ b/fcp/aggregation/protocol/aggregation_protocol_messages.proto
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package fcp.aggregation;
+
+import "fcp/secagg/shared/secagg_messages.proto";
+
+// Client resource passed to the Aggregation Protocol either as an inlined
+// binary data or as a resource URI.
+message ClientResource {
+ oneof resource_kind {
+ bytes inline_bytes = 1
+ ;
+ string uri = 2;
+ }
+}
+
+// Polymorhic client-to-server message carrying protocol specific content.
+message ClientMessage {
+ oneof protocol_kind {
+ SimpleAggregation simple_aggregation = 1;
+ SecureAggregation secure_aggregation = 2;
+ }
+
+ message SimpleAggregation {
+ ClientResource input = 1;
+ }
+
+ message SecureAggregation {
+ fcp.secagg.ClientToServerWrapperMessage content = 1;
+ ClientResource masked_input = 2;
+ }
+}
+
+// Polymorhic server-to-client message carrying protocol specific content.
+message ServerMessage {
+ oneof protocol_kind {
+ SecureAggregation secure_aggregation = 1;
+ }
+
+ message SecureAggregation {
+ fcp.secagg.ServerToClientWrapperMessage content = 1;
+ }
+}
+
+// Polymorhic server-to-client message which content is included into
+// response to each client when that client joins the aggregation and shortly
+// before the client begins the aggregation protocol.
+message AcceptanceMessage {
+ oneof protocol_kind {
+ SecureAggregation secure_aggregation = 1;
+ }
+
+ message SecureAggregation {
+ // TODO(team): define this message.
+ }
+}
+
+// Status of the aggregation protocol.
+message StatusMessage {
+ // The below buckets are mutually exclusive and exhaustive, such that
+ // it should always be the case that:
+ // #clients = num_clients_completed
+ // + num_clients_failed
+ // + num_clients_pending
+ // + num_clients_aborted
+ //
+ // Number of clients that have successfully completed the aggregation
+ // protocol.
+ int64 num_clients_completed = 1;
+
+ // Number of clients that started the aggregation protocol but failed
+ // to complete e.g. dropped out in the middle of the protocol.
+ int64 num_clients_failed = 2;
+
+ // Number of clients that started the aggregation protocol but have not
+ // finished yet (either successfully or not).
+ int64 num_clients_pending = 3;
+
+ // Number of clients that started the aggregation protocol but were aborted by
+ // the server before they could complete e.g. if progress on the session was
+ // no longer needed.
+ int64 num_clients_aborted = 4;
+
+ // The below buckets provide a breakdown of the aggregated inputs that have
+ // been submitted by the completed clients.
+ // The below should always be true:
+ // num_clients_completed = num_inputs_aggregated_and_included
+ // + num_inputs_aggregated_and_pending
+ // + num_inputs_discarded.
+ //
+ // Number of inputs that were successfully aggregated and included in the
+ // final result of the protocol.
+ int64 num_inputs_aggregated_and_included = 5;
+
+ // Number of inputs that were received and are pending i.e. the inputs have
+ // not been included in the final result of the protocol yet.
+ int64 num_inputs_aggregated_and_pending = 6;
+
+ // Number of inputs that were received by the protocol but discarded for
+ // whatever reason, for example if the protocol has reached a state where it
+ // no longer needs client inputs to complete.
+ int64 num_inputs_discarded = 7;
+}
diff --git a/fcp/aggregation/protocol/checkpoint_builder.h b/fcp/aggregation/protocol/checkpoint_builder.h
new file mode 100644
index 0000000..558df13
--- /dev/null
+++ b/fcp/aggregation/protocol/checkpoint_builder.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_CHECKPOINT_BUILDER_H_
+#define FCP_AGGREGATION_PROTOCOL_CHECKPOINT_BUILDER_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/core/tensor.h"
+
+namespace fcp::aggregation {
+
+// Describes an abstract interface for building and formatting a checkpoint
+// from a set of named tensors.
+class CheckpointBuilder {
+ public:
+ virtual ~CheckpointBuilder() = default;
+
+ // Adds a tensor to the checkpoint.
+ virtual absl::Status Add(const std::string& name, const Tensor& tensor) = 0;
+
+ // Builds and formats the checkpoint.
+ virtual absl::StatusOr<absl::Cord> Build() = 0;
+};
+
+// Describes an abstract factory for creating instances of CheckpointBuilder.
+class CheckpointBuilderFactory {
+ public:
+ virtual ~CheckpointBuilderFactory() = default;
+
+ // Creates an instance of CheckpointBuilder.
+ virtual std::unique_ptr<CheckpointBuilder> Create() const = 0;
+};
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_CHECKPOINT_BUILDER_H_
diff --git a/fcp/aggregation/protocol/checkpoint_parser.h b/fcp/aggregation/protocol/checkpoint_parser.h
new file mode 100644
index 0000000..ad7e0ed
--- /dev/null
+++ b/fcp/aggregation/protocol/checkpoint_parser.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_CHECKPOINT_PARSER_H_
+#define FCP_AGGREGATION_PROTOCOL_CHECKPOINT_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/core/tensor.h"
+
+namespace fcp::aggregation {
+
+// Describes an abstract interface for parsing a checkpoint from a blob
+// and returning a set of named tensors.
+class CheckpointParser {
+ public:
+ virtual ~CheckpointParser() = default;
+
+ // Gets a tensor by name.
+ virtual absl::StatusOr<Tensor> GetTensor(const std::string& name) const = 0;
+};
+
+// Describes an abstract factory for creating instances of CheckpointParser.
+class CheckpointParserFactory {
+ public:
+ virtual ~CheckpointParserFactory() = default;
+
+ // Creates an instance of CheckpointParser with the provided serialized
+ // checkpoint content.
+ virtual absl::StatusOr<std::unique_ptr<CheckpointParser>> Create(
+ const absl::Cord& serialized_checkpoint) const = 0;
+};
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_CHECKPOINT_PARSER_H_
diff --git a/fcp/aggregation/protocol/configuration.proto b/fcp/aggregation/protocol/configuration.proto
new file mode 100644
index 0000000..e681219
--- /dev/null
+++ b/fcp/aggregation/protocol/configuration.proto
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package fcp.aggregation;
+
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/protobuf/struct.proto";
+
+option java_package = "fcp.aggregation";
+option java_multiple_files = true;
+
+// Configuration describing the kind of aggregation to perform.
+message Configuration {
+ // Represents a single aggregation operation, combining one or more input
+ // tensors from a collection of clients into one or more output tensors on the
+ // server.
+ // TODO(team): Remove usage of TensorFlow from this message.
+ message ServerAggregationConfig {
+ // The uri of the aggregation intrinsic (e.g. 'federated_sum').
+ string intrinsic_uri = 1;
+
+ // Describes an argument to the aggregation operation.
+ message IntrinsicArg {
+ oneof arg {
+ // Input tensor provided by each client.
+ tensorflow.TensorSpecProto input_tensor = 2;
+
+ // Constant parameter that is independent of client data (e.g. a modulus
+ // for a federated modular sum operation).
+ tensorflow.TensorProto parameter = 3;
+ }
+ }
+
+ // List of arguments for the aggregation operation. The arguments can be
+ // dependent on client data (in which case they must be retrieved from
+ // clients) or they can be independent of client data (in which case they
+ // can be configured server-side). For now we assume all client-independent
+ // arguments are constants. The arguments must be in the order expected by
+ // the server.
+ repeated IntrinsicArg intrinsic_args = 4;
+
+ // List of server-side outputs produced by the aggregation operation.
+ repeated tensorflow.TensorSpecProto output_tensors = 5;
+
+ // List of inner aggregation intrinsics. This can be used to delegate parts
+ // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
+ // a sum operation to a sum intrinsic).
+ repeated ServerAggregationConfig inner_aggregations = 6;
+ }
+
+ // A list of client-to-server aggregations to perform.
+ repeated ServerAggregationConfig aggregation_configs = 1;
+}
diff --git a/fcp/aggregation/protocol/python/BUILD b/fcp/aggregation/protocol/python/BUILD
new file mode 100644
index 0000000..f65841c
--- /dev/null
+++ b/fcp/aggregation/protocol/python/BUILD
@@ -0,0 +1,22 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+)
+
+pybind_extension(
+ name = "aggregation_protocol",
+ srcs = ["aggregation_protocol.cc"],
+ pytype_deps = [],
+ deps = [
+ "//fcp/aggregation/protocol:aggregation_protocol",
+ "//fcp/aggregation/protocol:cc_proto",
+ "//fcp/aggregation/protocol:configuration_cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ "@pybind11_abseil//pybind11_abseil:absl_casters",
+ "@pybind11_abseil//pybind11_abseil:status_casters",
+ "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
+ ],
+)
diff --git a/fcp/aggregation/protocol/python/aggregation_protocol.cc b/fcp/aggregation/protocol/python/aggregation_protocol.cc
new file mode 100644
index 0000000..7901d71
--- /dev/null
+++ b/fcp/aggregation/protocol/python/aggregation_protocol.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/protocol/aggregation_protocol.h"
+
+#include <pybind11/pybind11.h>
+
+#include <cstdint>
+
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+#include "fcp/aggregation/protocol/configuration.pb.h"
+#include "pybind11_abseil/absl_casters.h"
+#include "pybind11_abseil/status_casters.h"
+#include "pybind11_protobuf/native_proto_caster.h"
+
+namespace {
+
+namespace py = ::pybind11;
+
+using ::fcp::aggregation::AcceptanceMessage;
+using ::fcp::aggregation::AggregationProtocol;
+using ::fcp::aggregation::ServerMessage;
+
+// Allow AggregationProtocol::Callback to be subclassed in Python. See
+// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
+class PyAggregationProtocolCallback : public AggregationProtocol::Callback {
+ public:
+ void OnAcceptClients(int64_t start_client_id, int64_t num_clients,
+ const AcceptanceMessage& message) override {
+ PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAcceptClients,
+ start_client_id, num_clients, message);
+ }
+
+ void OnSendServerMessage(int64_t client_id,
+ const ServerMessage& message) override {
+ PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback,
+ OnSendServerMessage, client_id, message);
+ }
+
+ void OnCloseClient(int64_t client_id,
+ absl::Status diagnostic_status) override {
+ PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnCloseClient,
+ client_id,
+ py::google::DoNotThrowStatus(diagnostic_status));
+ }
+
+ void OnComplete(absl::Cord result) override {
+ PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnComplete,
+ result);
+ }
+
+ void OnAbort(absl::Status diagnostic_status) override {
+ PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAbort,
+ py::google::DoNotThrowStatus(diagnostic_status));
+ }
+};
+
+} // namespace
+
+PYBIND11_MODULE(aggregation_protocol, m) {
+ pybind11::google::ImportStatusModule();
+ pybind11_protobuf::ImportNativeProtoCasters();
+
+ auto py_aggregation_protocol =
+ py::class_<AggregationProtocol>(m, "AggregationProtocol")
+ .def("Start", &AggregationProtocol::Start)
+ .def("AddClients", &AggregationProtocol::AddClients)
+ .def("ReceiveClientMessage",
+ &AggregationProtocol::ReceiveClientMessage)
+ .def("CloseClient", &AggregationProtocol::CloseClient)
+ .def("Complete", &AggregationProtocol::Complete)
+ .def("Abort", &AggregationProtocol::Abort)
+ .def("GetStatus", &AggregationProtocol::GetStatus);
+
+ pybind11::class_<AggregationProtocol::Callback,
+ PyAggregationProtocolCallback>(py_aggregation_protocol,
+ "Callback")
+ .def(py::init<>())
+ .def("OnAcceptClients", &AggregationProtocol::Callback::OnAcceptClients)
+ .def("OnSendServerMessage",
+ &AggregationProtocol::Callback::OnSendServerMessage)
+ .def("OnCloseClient", &AggregationProtocol::Callback::OnCloseClient)
+ .def("OnComplete", &AggregationProtocol::Callback::OnComplete)
+ .def("OnAbort", &AggregationProtocol::Callback::OnAbort);
+}
diff --git a/fcp/aggregation/protocol/resource_resolver.h b/fcp/aggregation/protocol/resource_resolver.h
new file mode 100644
index 0000000..11ad08d
--- /dev/null
+++ b/fcp/aggregation/protocol/resource_resolver.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_RESOURCE_RESOLVER_H_
+#define FCP_AGGREGATION_PROTOCOL_RESOURCE_RESOLVER_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+
+namespace fcp::aggregation {
+
+// Describes an abstract interface for resolving a resource from a given client
+// and uri.
+class ResourceResolver {
+ public:
+ virtual ~ResourceResolver() = default;
+
+ // Retrieves a resource for the given `client_id` and `uri` combination.
+ // The resource can be accessed exactly once and must be deleted (best-effort)
+ // after it is returned.
+ virtual absl::StatusOr<absl::Cord> RetrieveResource(
+ int64_t client_id, const std::string& uri) = 0;
+};
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_RESOURCE_RESOLVER_H_
diff --git a/fcp/aggregation/protocol/simple_aggregation/BUILD b/fcp/aggregation/protocol/simple_aggregation/BUILD
new file mode 100644
index 0000000..863c346
--- /dev/null
+++ b/fcp/aggregation/protocol/simple_aggregation/BUILD
@@ -0,0 +1,83 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Aggregation Protocol Package
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "simple_aggregation",
+ srcs = [
+ "simple_aggregation_protocol.cc",
+ ],
+ hdrs = [
+ "simple_aggregation_protocol.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:aggregator",
+ "//fcp/aggregation/core:federated_sum",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/protocol:aggregation_protocol",
+ "//fcp/aggregation/protocol:cc_proto",
+ "//fcp/aggregation/protocol:checkpoint_builder",
+ "//fcp/aggregation/protocol:checkpoint_parser",
+ "//fcp/aggregation/protocol:configuration_cc_proto",
+ "//fcp/aggregation/protocol:resource_resolver",
+ "//fcp/aggregation/tensorflow:converters",
+ "//fcp/base",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ ],
+ alwayslink = 1,
+)
+
+cc_test(
+ name = "simple_aggregation_test",
+ srcs = ["simple_aggregation_protocol_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":simple_aggregation",
+ "//fcp/aggregation/core:aggregator",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/protocol:cc_proto",
+ "//fcp/aggregation/protocol:checkpoint_builder",
+ "//fcp/aggregation/protocol:checkpoint_parser",
+ "//fcp/aggregation/protocol:configuration_cc_proto",
+ "//fcp/aggregation/protocol/testing:mock_callback",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.cc b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.cc
new file mode 100644
index 0000000..ab71526
--- /dev/null
+++ b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.cc
@@ -0,0 +1,558 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+#include "fcp/aggregation/protocol/checkpoint_builder.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+#include "fcp/aggregation/tensorflow/converters.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp::aggregation {
+
+// Creates an INVALID_ARGUMENT error with the provided error message.
+absl::Status ServerAggregationConfigArgumentError(
+ const Configuration::ServerAggregationConfig& aggregation_config,
+ absl::string_view error_message) {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("ServerAggregationConfig: %s\n:%s", error_message,
+ aggregation_config.DebugString()));
+}
+
+// Creates an aggregation intrinsic based on the intrinsic configuration.
+absl::StatusOr<SimpleAggregationProtocol::Intrinsic>
+SimpleAggregationProtocol::CreateIntrinsic(
+ const Configuration::ServerAggregationConfig& aggregation_config) {
+ // Resolve the intrinsic_uri to the registered TensorAggregatorFactory.
+ FCP_ASSIGN_OR_RETURN(
+ const TensorAggregatorFactory* factory,
+ GetAggregatorFactory(aggregation_config.intrinsic_uri()));
+
+ // Convert the input tensor specification.
+ FCP_ASSIGN_OR_RETURN(
+ TensorSpec input_spec,
+ tensorflow::ConvertTensorSpec(
+ aggregation_config.intrinsic_args(0).input_tensor()));
+
+ // Convert the output tensor specification.
+ FCP_ASSIGN_OR_RETURN(
+ TensorSpec output_spec,
+ tensorflow::ConvertTensorSpec(aggregation_config.output_tensors(0)));
+
+ // TODO(team): currently the input and output data type and shape are
+ // expected to be the same.
+ if (input_spec.dtype() != output_spec.dtype() ||
+ input_spec.shape() != output_spec.shape()) {
+ return ServerAggregationConfigArgumentError(
+ aggregation_config, "Input and output tensors have mismatched specs.");
+ }
+
+ // Use the factory to create the TensorAggregator instance.
+ FCP_ASSIGN_OR_RETURN(std::unique_ptr<TensorAggregator> aggregator,
+ factory->Create(input_spec.dtype(), input_spec.shape()));
+
+ return Intrinsic{std::move(input_spec), std::move(output_spec),
+ std::move(aggregator)};
+}
+
+absl::Status SimpleAggregationProtocol::ValidateConfig(
+ const Configuration& configuration) {
+ for (const Configuration::ServerAggregationConfig& aggregation_config :
+ configuration.aggregation_configs()) {
+ // TODO(team): Add support for other intrinsics after MVP launch.
+ if (!GetAggregatorFactory(aggregation_config.intrinsic_uri()).ok()) {
+ return ServerAggregationConfigArgumentError(
+ aggregation_config,
+ absl::StrFormat("%s is not a supported intrinsic_uri.",
+ aggregation_config.intrinsic_uri()));
+ }
+
+ // TODO(team): Support multiple intrinsic args.
+ if (aggregation_config.intrinsic_args_size() != 1) {
+ return ServerAggregationConfigArgumentError(
+ aggregation_config, "Exactly one intrinsic argument is expected.");
+ }
+
+ if (aggregation_config.output_tensors_size() != 1) {
+ return ServerAggregationConfigArgumentError(
+ aggregation_config, "Exactly one output tensor is expected.");
+ }
+
+ if (!aggregation_config.intrinsic_args(0).has_input_tensor()) {
+ return ServerAggregationConfigArgumentError(
+ aggregation_config, "Intrinsic arguments must be input tensors.");
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>>
+SimpleAggregationProtocol::Create(
+ const Configuration& configuration, AggregationProtocol::Callback* callback,
+ const CheckpointParserFactory* checkpoint_parser_factory,
+ const CheckpointBuilderFactory* checkpoint_builder_factory,
+ ResourceResolver* resource_resolver) {
+ FCP_CHECK(callback != nullptr);
+ FCP_CHECK(checkpoint_parser_factory != nullptr);
+ FCP_CHECK(checkpoint_builder_factory != nullptr);
+ FCP_CHECK(resource_resolver != nullptr);
+ FCP_RETURN_IF_ERROR(ValidateConfig(configuration));
+
+ std::vector<Intrinsic> intrinsics;
+ for (const Configuration::ServerAggregationConfig& aggregation_config :
+ configuration.aggregation_configs()) {
+ FCP_ASSIGN_OR_RETURN(Intrinsic intrinsic,
+ CreateIntrinsic(aggregation_config));
+ intrinsics.emplace_back(std::move(intrinsic));
+ }
+
+ return absl::WrapUnique(new SimpleAggregationProtocol(
+ std::move(intrinsics), callback, checkpoint_parser_factory,
+ checkpoint_builder_factory, resource_resolver));
+}
+
+SimpleAggregationProtocol::SimpleAggregationProtocol(
+ std::vector<Intrinsic> intrinsics, AggregationProtocol::Callback* callback,
+ const CheckpointParserFactory* checkpoint_parser_factory,
+ const CheckpointBuilderFactory* checkpoint_builder_factory,
+ ResourceResolver* resource_resolver)
+ : protocol_state_(PROTOCOL_CREATED),
+ intrinsics_(std::move(intrinsics)),
+ callback_(callback),
+ checkpoint_parser_factory_(checkpoint_parser_factory),
+ checkpoint_builder_factory_(checkpoint_builder_factory),
+ resource_resolver_(resource_resolver) {}
+
+absl::string_view SimpleAggregationProtocol::ProtocolStateDebugString(
+ ProtocolState state) {
+ switch (state) {
+ case PROTOCOL_CREATED:
+ return "PROTOCOL_CREATED";
+ case PROTOCOL_STARTED:
+ return "PROTOCOL_STARTED";
+ case PROTOCOL_COMPLETED:
+ return "PROTOCOL_COMPLETED";
+ case PROTOCOL_ABORTED:
+ return "PROTOCOL_ABORTED";
+ }
+}
+
+absl::string_view SimpleAggregationProtocol::ClientStateDebugString(
+ ClientState state) {
+ switch (state) {
+ case CLIENT_PENDING:
+ return "CLIENT_PENDING";
+ case CLIENT_RECEIVED_INPUT_AND_PENDING:
+ return "CLIENT_RECEIVED_INPUT_AND_PENDING";
+ case CLIENT_COMPLETED:
+ return "CLIENT_COMPLETED";
+ case CLIENT_FAILED:
+ return "CLIENT_FAILED";
+ case CLIENT_ABORTED:
+ return "CLIENT_ABORTED";
+ case CLIENT_DISCARDED:
+ return "CLIENT_DISCARDED";
+ }
+}
+
+absl::Status SimpleAggregationProtocol::CheckProtocolState(
+ ProtocolState state) const {
+ if (protocol_state_ != state) {
+ return absl::FailedPreconditionError(
+ absl::StrFormat("The current protocol state is %s, expected %s.",
+ ProtocolStateDebugString(protocol_state_),
+ ProtocolStateDebugString(state)));
+ }
+ return absl::OkStatus();
+}
+
+void SimpleAggregationProtocol::SetProtocolState(ProtocolState state) {
+ FCP_CHECK(
+ (protocol_state_ == PROTOCOL_CREATED && state == PROTOCOL_STARTED) ||
+ (protocol_state_ == PROTOCOL_STARTED &&
+ (state == PROTOCOL_COMPLETED || state == PROTOCOL_ABORTED)))
+ << "Invalid protocol state transition from "
+ << ProtocolStateDebugString(protocol_state_) << " to "
+ << ProtocolStateDebugString(state) << ".";
+ protocol_state_ = state;
+}
+
+absl::StatusOr<SimpleAggregationProtocol::ClientState>
+SimpleAggregationProtocol::GetClientState(int64_t client_id) const {
+ if (client_id < 0 || client_id >= client_states_.size()) {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("client_id %ld is outside the valid range", client_id));
+ }
+ return client_states_[client_id];
+}
+
+void SimpleAggregationProtocol::SetClientState(int64_t client_id,
+ ClientState to_state) {
+ FCP_CHECK(client_id >= 0 && client_id < client_states_.size());
+ ClientState from_state = client_states_[client_id];
+ FCP_CHECK(from_state != to_state);
+ if (from_state == CLIENT_RECEIVED_INPUT_AND_PENDING) {
+ num_clients_received_and_pending_--;
+ } else if (from_state == CLIENT_COMPLETED) {
+ FCP_CHECK(to_state == CLIENT_DISCARDED)
+ << "Client state can't be changed from CLIENT_COMPLETED to "
+ << ClientStateDebugString(to_state);
+ num_clients_aggregated_--;
+ } else {
+ FCP_CHECK(from_state == CLIENT_PENDING)
+ << "Client state can't be changed from "
+ << ClientStateDebugString(from_state);
+ }
+ client_states_[client_id] = to_state;
+ switch (to_state) {
+ case CLIENT_PENDING:
+ FCP_LOG(FATAL) << "Client state can't be changed to CLIENT_PENDING";
+ break;
+ case CLIENT_RECEIVED_INPUT_AND_PENDING:
+ num_clients_received_and_pending_++;
+ break;
+ case CLIENT_COMPLETED:
+ num_clients_aggregated_++;
+ break;
+ case CLIENT_FAILED:
+ num_clients_failed_++;
+ break;
+ case CLIENT_ABORTED:
+ num_clients_aborted_++;
+ break;
+ case CLIENT_DISCARDED:
+ num_clients_discarded_++;
+ break;
+ }
+}
+
+absl::StatusOr<SimpleAggregationProtocol::TensorMap>
+SimpleAggregationProtocol::ParseCheckpoint(absl::Cord report) const {
+ FCP_ASSIGN_OR_RETURN(std::unique_ptr<CheckpointParser> parser,
+ checkpoint_parser_factory_->Create(report));
+ TensorMap tensor_map;
+ for (const auto& intrinsic : intrinsics_) {
+ // TODO(team): Support multiple input tensors.
+ FCP_ASSIGN_OR_RETURN(Tensor tensor,
+ parser->GetTensor(intrinsic.input.name()));
+ if (tensor.dtype() != intrinsic.input.dtype() ||
+ tensor.shape() != intrinsic.input.shape()) {
+ // TODO(team): Detailed diagnostics including the expected vs
+ // actual data types and shapes.
+ return absl::InvalidArgumentError("Input tensor spec mismatch.");
+ }
+ tensor_map.emplace(intrinsic.input.name(), std::move(tensor));
+ }
+
+ return tensor_map;
+}
+
+absl::Status SimpleAggregationProtocol::AggregateClientInput(
+ SimpleAggregationProtocol::TensorMap tensor_map) {
+ absl::MutexLock lock(&aggregation_mu_);
+ if (!aggregation_finished_) {
+ for (const auto& intrinsic : intrinsics_) {
+ // TODO(team): Support multiple input tensors.
+ const auto& it = tensor_map.find(intrinsic.input.name());
+ FCP_CHECK(it != tensor_map.end());
+ FCP_CHECK(intrinsic.aggregator != nullptr)
+ << "CreateReport() has already been called.";
+ FCP_RETURN_IF_ERROR(intrinsic.aggregator->Accumulate(it->second));
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::StatusOr<absl::Cord> SimpleAggregationProtocol::CreateReport() {
+ absl::MutexLock lock(&aggregation_mu_);
+ for (auto& intrinsic : intrinsics_) {
+ FCP_CHECK(intrinsic.aggregator != nullptr)
+ << "CreateReport() has already been called.";
+ if (!intrinsic.aggregator->CanReport()) {
+ return absl::FailedPreconditionError(
+ "The aggregation can't be completed due to failed preconditions.");
+ }
+ }
+
+ // Build the resulting checkpoint.
+ std::unique_ptr<CheckpointBuilder> checkpoint_builder =
+ checkpoint_builder_factory_->Create();
+ for (auto& intrinsic : intrinsics_) {
+ FCP_ASSIGN_OR_RETURN(OutputTensorList output_tensors,
+ std::move(*intrinsic.aggregator).Report());
+ // TODO(team): Support multiple output tensors per intrinsic.
+ FCP_CHECK(output_tensors.size() == 1);
+ const Tensor& tensor = output_tensors[0];
+ FCP_CHECK(tensor.dtype() == intrinsic.output.dtype());
+ FCP_CHECK(tensor.shape() == intrinsic.output.shape());
+ FCP_RETURN_IF_ERROR(
+ checkpoint_builder->Add(intrinsic.output.name(), tensor));
+ }
+ aggregation_finished_ = true;
+ return checkpoint_builder->Build();
+}
+
+absl::Status SimpleAggregationProtocol::Start(int64_t num_clients) {
+ if (num_clients < 0) {
+ return absl::InvalidArgumentError("Number of clients cannot be negative.");
+ }
+ {
+ absl::MutexLock lock(&state_mu_);
+ FCP_RETURN_IF_ERROR(CheckProtocolState(PROTOCOL_CREATED));
+ SetProtocolState(PROTOCOL_STARTED);
+ FCP_CHECK(client_states_.empty());
+ client_states_.resize(num_clients, CLIENT_PENDING);
+ }
+ if (num_clients > 0) {
+ AcceptanceMessage acceptance_message;
+ callback_->OnAcceptClients(0, num_clients, acceptance_message);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status SimpleAggregationProtocol::AddClients(int64_t num_clients) {
+ int64_t start_index;
+ {
+ absl::MutexLock lock(&state_mu_);
+ FCP_RETURN_IF_ERROR(CheckProtocolState(PROTOCOL_STARTED));
+ if (num_clients <= 0) {
+ return absl::InvalidArgumentError("Non-zero number of clients required");
+ }
+ start_index = client_states_.size();
+ client_states_.resize(start_index + num_clients, CLIENT_PENDING);
+ }
+ AcceptanceMessage acceptance_message;
+ callback_->OnAcceptClients(start_index, num_clients, acceptance_message);
+ return absl::OkStatus();
+}
+
+absl::Status SimpleAggregationProtocol::ReceiveClientMessage(
+ int64_t client_id, const ClientMessage& message) {
+ if (!message.has_simple_aggregation() ||
+ !message.simple_aggregation().has_input()) {
+ return absl::InvalidArgumentError("Unexpected message");
+ }
+
+ if (!message.simple_aggregation().input().has_inline_bytes() &&
+ !message.simple_aggregation().input().has_uri()) {
+ return absl::InvalidArgumentError(
+ "Only inline_bytes or uri type of input is supported");
+ }
+
+ // Verify the state.
+ {
+ absl::MutexLock lock(&state_mu_);
+ if (protocol_state_ == PROTOCOL_CREATED) {
+ return absl::FailedPreconditionError("The protocol hasn't been started");
+ }
+ FCP_ASSIGN_OR_RETURN(auto client_state, GetClientState(client_id));
+ if (client_state != CLIENT_PENDING) {
+ // TODO(team): Decide whether the logging level should be INFO or
+ // WARNING, or perhaps it should depend on the client state (e.g. WARNING
+ // for COMPLETED and INFO for other states).
+ FCP_LOG(INFO) << "ReceiveClientMessage: client " << client_id
+ << " message ignored, the state is already "
+ << ClientStateDebugString(client_state);
+ return absl::OkStatus();
+ }
+ SetClientState(client_id, CLIENT_RECEIVED_INPUT_AND_PENDING);
+ }
+
+ absl::Status client_completion_status = absl::OkStatus();
+ ClientState client_completion_state = CLIENT_COMPLETED;
+
+ absl::Cord report;
+ if (message.simple_aggregation().input().has_inline_bytes()) {
+ // Parse the client input concurrently with other protocol calls.
+ report =
+ absl::Cord(message.simple_aggregation().input().inline_bytes());
+ } else {
+ absl::StatusOr<absl::Cord> report_or_status =
+ resource_resolver_->RetrieveResource(
+ client_id, message.simple_aggregation().input().uri());
+ if (!report_or_status.ok()) {
+ client_completion_status = report_or_status.status();
+ client_completion_state = CLIENT_FAILED;
+ FCP_LOG(WARNING) << "Report with resource uri "
+ << message.simple_aggregation().input().uri()
+ << " for client " << client_id << "is missing. "
+ << client_completion_status.ToString();
+ } else {
+ report = std::move(report_or_status.value());
+ }
+ }
+
+ if (client_completion_state != CLIENT_FAILED) {
+ absl::StatusOr<TensorMap> tensor_map_or_status =
+ ParseCheckpoint(std::move(report));
+ if (!tensor_map_or_status.ok()) {
+ client_completion_status = tensor_map_or_status.status();
+ client_completion_state = CLIENT_FAILED;
+ FCP_LOG(WARNING) << "Client " << client_id << " input can't be parsed: "
+ << client_completion_status.ToString();
+ } else {
+ // Aggregate the client input which would block on aggregation_mu_ if
+ // there are any concurrent AggregateClientInput calls.
+ client_completion_status =
+ AggregateClientInput(std::move(tensor_map_or_status).value());
+ if (!client_completion_status.ok()) {
+ client_completion_state = CLIENT_DISCARDED;
+ FCP_LOG(INFO) << "Client " << client_id << " input is discarded: "
+ << client_completion_status.ToString();
+ }
+ }
+ }
+
+ // Update the state post aggregation.
+ {
+ absl::MutexLock lock(&state_mu_);
+ // Change the client state only if the current state is still
+ // CLIENT_RECEIVED_INPUT_AND_PENDING, meaning that the client wasn't already
+ // closed by a concurrent Complete or Abort call.
+ if (client_states_[client_id] == CLIENT_RECEIVED_INPUT_AND_PENDING) {
+ SetClientState(client_id, client_completion_state);
+ callback_->OnCloseClient(client_id, client_completion_status);
+ }
+ }
+ return absl::OkStatus();
+}
+
+absl::Status SimpleAggregationProtocol::CloseClient(
+ int64_t client_id, absl::Status client_status) {
+ {
+ absl::MutexLock lock(&state_mu_);
+ if (protocol_state_ == PROTOCOL_CREATED) {
+ return absl::FailedPreconditionError("The protocol hasn't been started");
+ }
+ FCP_ASSIGN_OR_RETURN(auto client_state, GetClientState(client_id));
+ // Close the client only if the client is currently pending.
+ if (client_state == CLIENT_PENDING) {
+ FCP_LOG(INFO) << "Closing client " << client_id << " with the status "
+ << client_status.ToString();
+ SetClientState(client_id,
+ client_status.ok() ? CLIENT_DISCARDED : CLIENT_FAILED);
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status SimpleAggregationProtocol::Complete() {
+ absl::Cord result;
+ std::vector<int64_t> client_ids_to_close;
+ {
+ absl::MutexLock lock(&state_mu_);
+ FCP_RETURN_IF_ERROR(CheckProtocolState(PROTOCOL_STARTED));
+ FCP_ASSIGN_OR_RETURN(result, CreateReport());
+ SetProtocolState(PROTOCOL_COMPLETED);
+ for (int64_t client_id = 0; client_id < client_states_.size();
+ client_id++) {
+ switch (client_states_[client_id]) {
+ case CLIENT_PENDING:
+ SetClientState(client_id, CLIENT_ABORTED);
+ client_ids_to_close.push_back(client_id);
+ break;
+ case CLIENT_RECEIVED_INPUT_AND_PENDING:
+ SetClientState(client_id, CLIENT_DISCARDED);
+ client_ids_to_close.push_back(client_id);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ for (int64_t client_id : client_ids_to_close) {
+ callback_->OnCloseClient(
+ client_id, absl::AbortedError("The protocol has completed before the "
+ "client input has been aggregated."));
+ }
+ callback_->OnComplete(std::move(result));
+ return absl::OkStatus();
+}
+
+absl::Status SimpleAggregationProtocol::Abort() {
+ std::vector<int64_t> client_ids_to_close;
+ {
+ absl::MutexLock lock(&state_mu_);
+ FCP_RETURN_IF_ERROR(CheckProtocolState(PROTOCOL_STARTED));
+ aggregation_finished_ = true;
+ SetProtocolState(PROTOCOL_ABORTED);
+ for (int64_t client_id = 0; client_id < client_states_.size();
+ client_id++) {
+ switch (client_states_[client_id]) {
+ case CLIENT_PENDING:
+ SetClientState(client_id, CLIENT_ABORTED);
+ client_ids_to_close.push_back(client_id);
+ break;
+ case CLIENT_RECEIVED_INPUT_AND_PENDING:
+ SetClientState(client_id, CLIENT_DISCARDED);
+ client_ids_to_close.push_back(client_id);
+ break;
+ case CLIENT_COMPLETED:
+ SetClientState(client_id, CLIENT_DISCARDED);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ for (int64_t client_id : client_ids_to_close) {
+ callback_->OnCloseClient(
+ client_id, absl::AbortedError("The protocol has aborted before the "
+ "client input has been aggregated."));
+ }
+ return absl::OkStatus();
+}
+
+StatusMessage SimpleAggregationProtocol::GetStatus() {
+ absl::MutexLock lock(&state_mu_);
+ int64_t num_clients_completed = num_clients_received_and_pending_ +
+ num_clients_aggregated_ +
+ num_clients_discarded_;
+ StatusMessage message;
+ message.set_num_clients_completed(num_clients_completed);
+ message.set_num_clients_failed(num_clients_failed_);
+ message.set_num_clients_pending(client_states_.size() -
+ num_clients_completed - num_clients_failed_ -
+ num_clients_aborted_);
+ message.set_num_inputs_aggregated_and_included(num_clients_aggregated_);
+ message.set_num_inputs_aggregated_and_pending(
+ num_clients_received_and_pending_);
+ message.set_num_clients_aborted(num_clients_aborted_);
+ message.set_num_inputs_discarded(num_clients_discarded_);
+ return message;
+}
+
+} // namespace fcp::aggregation
diff --git a/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h
new file mode 100644
index 0000000..83aa272
--- /dev/null
+++ b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h
@@ -0,0 +1,217 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
+#define FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_spec.h"
+#include "fcp/aggregation/protocol/aggregation_protocol.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+#include "fcp/aggregation/protocol/checkpoint_builder.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+#include "fcp/aggregation/protocol/configuration.pb.h"
+#include "fcp/aggregation/protocol/resource_resolver.h"
+
+namespace fcp::aggregation {
+
+// Implementation of the simple aggregation protocol.
+//
+// This version of the protocol receives updates in the clear from clients in a
+// TF checkpoint and aggregates them in memory. The aggregated updates are
+// released only if the number of participants exceed configured threshold.
+class SimpleAggregationProtocol final : public AggregationProtocol {
+ public:
+ // Validates the Configuration that will subsequently be used to create an
+ // instance of this protocol.
+ // Returns INVALID_ARGUMENT if the configuration is invalid.
+ static absl::Status ValidateConfig(const Configuration& configuration);
+
+ // Factory method to create an instance of the Simple Aggregation Protocol.
+ //
+ // Does not take ownership of the callback, which must refer to a valid object
+ // that outlives the SimpleAggregationProtocol instance.
+ static absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>> Create(
+ const Configuration& configuration,
+ AggregationProtocol::Callback* callback,
+ const CheckpointParserFactory* checkpoint_parser_factory,
+ const CheckpointBuilderFactory* checkpoint_builder_factory,
+ ResourceResolver* resource_resolver);
+
+ // Implementation of the overridden Aggregation Protocol methods.
+ absl::Status Start(int64_t num_clients) override;
+ absl::Status AddClients(int64_t num_clients) override;
+ absl::Status ReceiveClientMessage(int64_t client_id,
+ const ClientMessage& message) override;
+ absl::Status CloseClient(int64_t client_id,
+ absl::Status client_status) override;
+ absl::Status Complete() override;
+ absl::Status Abort() override;
+ StatusMessage GetStatus() override;
+
+ ~SimpleAggregationProtocol() override = default;
+
+ // SimpleAggregationProtocol is neither copyable nor movable.
+ SimpleAggregationProtocol(const SimpleAggregationProtocol&) = delete;
+ SimpleAggregationProtocol& operator=(const SimpleAggregationProtocol&) =
+ delete;
+
+ private:
+ // The structure representing a single aggregation intrinsic.
+ // TODO(team): Implement mapping of multiple inputs and outputs to
+ // individual TensorAggregator instances.
+ struct Intrinsic {
+ TensorSpec input;
+ TensorSpec output;
+ std::unique_ptr<TensorAggregator> aggregator
+ ABSL_PT_GUARDED_BY(&SimpleAggregationProtocol::aggregation_mu_);
+ };
+
+ // Private constructor.
+ SimpleAggregationProtocol(
+ std::vector<Intrinsic> intrinsics,
+ AggregationProtocol::Callback* callback,
+ const CheckpointParserFactory* checkpoint_parser_factory,
+ const CheckpointBuilderFactory* checkpoint_builder_factory,
+ ResourceResolver* resource_resolver);
+
+ // Creates an aggregation intrinsic based on the intrinsic configuration.
+ static absl::StatusOr<Intrinsic> CreateIntrinsic(
+ const Configuration::ServerAggregationConfig& aggregation_config);
+
+ // Describes the overall protocol state.
+ enum ProtocolState {
+ // The initial state indicating that the protocol was created.
+ PROTOCOL_CREATED,
+ // The protocol `Start` method has been called.
+ PROTOCOL_STARTED,
+ // The protocol `Complete` method has finished successfully.
+ PROTOCOL_COMPLETED,
+ // The protocol `Abort` method has been called.
+ PROTOCOL_ABORTED
+ };
+
+ // Describes state of each client participating in the protocol.
+ enum ClientState : uint8_t {
+ // No input received from the client yet.
+ CLIENT_PENDING,
+ // Client input received but the aggregation still pending, which may
+ // be the case when there are multiple concurrent ReceiveClientMessage
+ // calls.
+ CLIENT_RECEIVED_INPUT_AND_PENDING,
+ // Client input has been successfully aggregated.
+ CLIENT_COMPLETED,
+ // Client failed either by being closed with an error or by submitting a
+ // malformed input.
+ CLIENT_FAILED,
+ // Client which has been aborted by the server before its input has been
+ // received.
+ CLIENT_ABORTED,
+ // Client input has been received but discarded, for example due to the
+ // protocol Abort method being called.
+ CLIENT_DISCARDED
+ };
+
+ // Returns string representation of the protocol state.
+ static absl::string_view ProtocolStateDebugString(ProtocolState state);
+
+ // Returns string representation of the client state.
+ static absl::string_view ClientStateDebugString(ClientState state);
+
+ // Returns an error if the current protocol state isn't the expected one.
+ absl::Status CheckProtocolState(ProtocolState state) const
+ ABSL_SHARED_LOCKS_REQUIRED(state_mu_);
+
+ // Changes the protocol state.
+ void SetProtocolState(ProtocolState state)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+
+ // Gets the client state for the given client ID.
+ absl::StatusOr<ClientState> GetClientState(int64_t client_id) const
+ ABSL_SHARED_LOCKS_REQUIRED(state_mu_);
+
+ // Sets the client state for the given client ID.
+ void SetClientState(int64_t client_id, ClientState state)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
+
+ // Parses and validates the client report.
+ // This function involves a potentially expensive I/O and parsing and should
+ // run concurrently as much as possible. The ABSL_LOCKS_EXCLUDED attribution
+ // below is used to emphasize that.
+ using TensorMap = absl::flat_hash_map<std::string, Tensor>;
+ absl::StatusOr<TensorMap> ParseCheckpoint(absl::Cord report) const
+ ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_);
+
+ // Aggregates the input via the underlying aggregators.
+ absl::Status AggregateClientInput(TensorMap tensor_map)
+ ABSL_LOCKS_EXCLUDED(state_mu_, aggregation_mu_);
+
+ // Produces the report via the underlying aggregators.
+ absl::StatusOr<absl::Cord> CreateReport()
+ ABSL_LOCKS_EXCLUDED(aggregation_mu_);
+
+ // Protects the mutable state.
+ absl::Mutex state_mu_;
+ // Protects calls into the aggregators.
+ absl::Mutex aggregation_mu_;
+ // This indicates that the aggregation has finished either by completing
+ // the protocol or by aborting it. This can be triggered without locking on
+ // the aggregation_mu_ mutex first to allow aborting the protocol promptly and
+ // discarding all the pending aggregation calls.
+ std::atomic_bool aggregation_finished_ = false;
+
+ // The overall state of the protocol.
+ ProtocolState protocol_state_ ABSL_GUARDED_BY(state_mu_);
+
+ // Holds state of all clients. The length of the vector equals
+ // to the number of clients accepted into the protocol.
+ std::vector<ClientState> client_states_ ABSL_GUARDED_BY(state_mu_);
+
+ // Counters for various client states other than pending.
+ // Note that the number of pending clients can be found by subtracting the
+ // sum of the below counters from `client_states_.size()`.
+ uint64_t num_clients_received_and_pending_ ABSL_GUARDED_BY(state_mu_) = 0;
+ uint64_t num_clients_aggregated_ ABSL_GUARDED_BY(state_mu_) = 0;
+ uint64_t num_clients_failed_ ABSL_GUARDED_BY(state_mu_) = 0;
+ uint64_t num_clients_aborted_ ABSL_GUARDED_BY(state_mu_) = 0;
+ uint64_t num_clients_discarded_ ABSL_GUARDED_BY(state_mu_) = 0;
+
+ // Intrinsics are immutable and shouldn't be guarded by the either of mutexes.
+ // Please note that the access to the aggregators that intrinsics point to
+ // still needs to be strictly sequential. That is guarded separatedly by
+ // `aggregators_mu_`.
+ std::vector<Intrinsic> const intrinsics_;
+
+ AggregationProtocol::Callback* const callback_;
+ const CheckpointParserFactory* const checkpoint_parser_factory_;
+ const CheckpointBuilderFactory* const checkpoint_builder_factory_;
+ ResourceResolver* const resource_resolver_;
+};
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_SIMPLE_AGGREGATION_SIMPLE_AGGREGATION_PROTOCOL_H_
diff --git a/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol_test.cc b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol_test.cc
new file mode 100644
index 0000000..f988d7b
--- /dev/null
+++ b/fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol_test.cc
@@ -0,0 +1,972 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h"
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/synchronization/notification.h"
+#include "fcp/aggregation/core/agg_vector_aggregator.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
+#include "fcp/aggregation/core/tensor_aggregator_registry.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+#include "fcp/aggregation/protocol/checkpoint_builder.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+#include "fcp/aggregation/protocol/configuration.pb.h"
+#include "fcp/aggregation/protocol/testing/test_callback.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::aggregation {
+namespace {
+
+using ::testing::_;
+using ::testing::ByMove;
+using ::testing::Eq;
+using ::testing::Invoke;
+using ::testing::Return;
+using ::testing::StrEq;
+
+// TODO(team): Consider moving mock classes into a separate test library.
+class MockCheckpointParser : public CheckpointParser {
+ public:
+ MOCK_METHOD(absl::StatusOr<Tensor>, GetTensor, (const std::string& name),
+ (const override));
+};
+
+class MockCheckpointParserFactory : public CheckpointParserFactory {
+ public:
+ MOCK_METHOD(absl::StatusOr<std::unique_ptr<CheckpointParser>>, Create,
+ (const absl::Cord& serialized_checkpoint), (const override));
+};
+
+class MockCheckpointBuilder : public CheckpointBuilder {
+ public:
+ MOCK_METHOD(absl::Status, Add,
+ (const std::string& name, const Tensor& tensor), (override));
+ MOCK_METHOD(absl::StatusOr<absl::Cord>, Build, (), (override));
+};
+
+class MockCheckpointBuilderFactory : public CheckpointBuilderFactory {
+ public:
+ MOCK_METHOD(std::unique_ptr<CheckpointBuilder>, Create, (), (const override));
+};
+
+class MockResourceResolver : public ResourceResolver {
+ public:
+ MOCK_METHOD(absl::StatusOr<absl::Cord>, RetrieveResource,
+ (int64_t client_id, const std::string& uri), (override));
+};
+
+class SimpleAggregationProtocolTest : public ::testing::Test {
+ protected:
+ // Returns default configuration.
+ Configuration default_configuration() const;
+
+ // Returns the default instance of checkpoint bilder;
+ MockCheckpointBuilder& ExpectCheckpointBuilder() {
+ MockCheckpointBuilder& checkpoint_builder = *wrapped_checkpoint_builder_;
+ EXPECT_CALL(checkpoint_builder_factory_, Create())
+ .WillOnce(Return(ByMove(std::move(wrapped_checkpoint_builder_))));
+ EXPECT_CALL(checkpoint_builder, Build()).WillOnce(Return(absl::Cord{}));
+ return checkpoint_builder;
+ }
+
+ // Creates an instance of SimpleAggregationProtocol with the specified config.
+ std::unique_ptr<SimpleAggregationProtocol> CreateProtocol(
+ Configuration config);
+
+ // Creates an instance of SimpleAggregationProtocol with the default config.
+ std::unique_ptr<SimpleAggregationProtocol> CreateProtocolWithDefaultConfig() {
+ return CreateProtocol(default_configuration());
+ }
+
+ MockAggregationProtocolCallback callback_;
+
+ MockCheckpointParserFactory checkpoint_parser_factory_;
+ MockCheckpointBuilderFactory checkpoint_builder_factory_;
+ MockResourceResolver resource_resolver_;
+
+ private:
+ std::unique_ptr<MockCheckpointBuilder> wrapped_checkpoint_builder_ =
+ std::make_unique<MockCheckpointBuilder>();
+};
+
+Configuration SimpleAggregationProtocolTest::default_configuration() const {
+ // One "federated_sum" intrinsic with a single scalar int32 tensor.
+ return PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+}
+
+std::unique_ptr<SimpleAggregationProtocol>
+SimpleAggregationProtocolTest::CreateProtocol(Configuration config) {
+ // Verify that the protocol can be created successfully.
+ absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>>
+ protocol_or_status = SimpleAggregationProtocol::Create(
+ config, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_);
+ EXPECT_THAT(protocol_or_status, IsOk());
+ return std::move(protocol_or_status).value();
+}
+
+ClientMessage MakeClientMessage() {
+ ClientMessage message;
+ message.mutable_simple_aggregation()->mutable_input()->set_inline_bytes("");
+ return message;
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfInputs) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ intrinsic_args {
+ input_tensor {
+ name: "bar"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfOutputs) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ output_tensors {
+ name: "bar_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputType) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args { parameter {} }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedIntrinsicUri) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "unsupported_xyz"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputSpec) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape { dim { size: -1 } }
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputDataType) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_FLOAT
+ shape {}
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputShape) {
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape { dim { size: 1 } }
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape { dim { size: 2 } }
+ }
+ }
+ )pb");
+ EXPECT_THAT(SimpleAggregationProtocol::Create(
+ config_message, &callback_, &checkpoint_parser_factory_,
+ &checkpoint_builder_factory_, &resource_resolver_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, StartProtocol_Success) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_THAT(protocol->Start(3), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 3")));
+}
+
+// TODO(team): Add similar tests for other callbacks.
+TEST_F(SimpleAggregationProtocolTest,
+ StartProtocol_AcceptClientsProtocolReentrace) {
+ // This verifies that the protocol can be re-entered from the callback.
+ auto protocol = CreateProtocolWithDefaultConfig();
+
+ EXPECT_CALL(callback_, OnAcceptClients(0, 1, _)).WillOnce(Invoke([&]() {
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
+ }));
+
+ EXPECT_THAT(protocol->Start(1), IsOk());
+}
+
+TEST_F(SimpleAggregationProtocolTest, StartProtocol_MultipleCalls) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients).Times(1);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+ // The second Start call must fail.
+ EXPECT_THAT(protocol->Start(1), IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(SimpleAggregationProtocolTest, StartProtocol_ZeroClients) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients).Times(0);
+ EXPECT_THAT(protocol->Start(0), IsOk());
+}
+
+TEST_F(SimpleAggregationProtocolTest, StartProtocol_NegativeClients) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients).Times(0);
+ EXPECT_THAT(protocol->Start(-1), IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, AddClients_Success) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+
+ EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
+ EXPECT_THAT(protocol->Start(1), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
+
+ EXPECT_CALL(callback_, OnAcceptClients(1, 3, _));
+ EXPECT_THAT(protocol->AddClients(3), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 4")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, AddClients_ProtocolNotStarted) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ // Must fail because the protocol isn't started.
+ EXPECT_CALL(callback_, OnAcceptClients).Times(0);
+ EXPECT_THAT(protocol->AddClients(1), IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_ProtocolNotStarted) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ // Must fail because the protocol isn't started.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()),
+ IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidMessage) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ ClientMessage message;
+ // Empty message without SimpleAggregation.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
+ IsCode(INVALID_ARGUMENT));
+ // Message with SimpleAggregation but without the input.
+ message.mutable_simple_aggregation();
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
+ IsCode(INVALID_ARGUMENT));
+ // Message with empty input.
+ message.mutable_simple_aggregation()->mutable_input();
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidClientId) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+ // Must fail for the client_id -1 and 2.
+ EXPECT_THAT(protocol->ReceiveClientMessage(-1, MakeClientMessage()),
+ IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(protocol->ReceiveClientMessage(2, MakeClientMessage()),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(SimpleAggregationProtocolTest,
+ ReceiveClientMessage_DuplicateClientIdInputs) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(2), IsOk());
+
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
+ }));
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser))));
+
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ // The second input for the same client must succeed to without changing the
+ // aggregated state.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_pending: 1 num_clients_completed: 1 "
+ "num_inputs_aggregated_and_included: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_AfterClosingClient) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ EXPECT_THAT(protocol->CloseClient(0, absl::OkStatus()), IsOk());
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_completed: 1 num_inputs_discarded: 1")));
+ // This must succeed to without changing the aggregated state.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_completed: 1 num_inputs_discarded: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest,
+ ReceiveClientMessage_FailureToParseInput) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(
+ Return(ByMove(absl::InvalidArgumentError("Invalid checkpoint"))));
+
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));
+
+ // Receiving the client input should still succeed.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MissingTensor) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo")))
+ .WillOnce(Return(ByMove(absl::NotFoundError("Missing tensor foo"))));
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser))));
+
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(NOT_FOUND)));
+
+ // Receiving the client input should still succeed.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MismatchingTensor) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
+ }));
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser))));
+
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));
+
+ // Receiving the client input should still succeed.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_UriType_Success) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
+ }));
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser))));
+
+ // Receive input for the client #0
+ EXPECT_CALL(resource_resolver_, RetrieveResource(0, StrEq("foo_uri")))
+ .WillOnce(Return(absl::Cord{}));
+ ClientMessage message;
+ message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, message), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_completed: 1 num_inputs_aggregated_and_included: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest,
+ ReceiveClientMessage_UriType_FailToParse) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ // Receive invalid input for the client #0
+ EXPECT_CALL(resource_resolver_, RetrieveResource(0, _))
+ .WillOnce(Return(absl::InvalidArgumentError("Invalid uri")));
+ ClientMessage message;
+ message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));
+
+ // Receiving the client input should still succeed.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, message), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Complete_NoInputsReceived) {
+ // Two intrinsics:
+ // 1) federated_sum "foo" that takes int32 {2,3} tensors.
+ // 2) federated_sum "bar" that takes scalar float tensors.
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ }
+ }
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "bar"
+ dtype: DT_FLOAT
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "bar_out"
+ dtype: DT_FLOAT
+ shape {}
+ }
+ }
+ )pb");
+ auto protocol = CreateProtocol(config_message);
+
+ EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
+ EXPECT_THAT(protocol->Start(1), IsOk());
+
+ // Verify that the checkpoint builder is created.
+ auto& checkpoint_builder = ExpectCheckpointBuilder();
+
+ // Verify that foo_out and bar_out tensors are added to the result checkpoint
+ EXPECT_CALL(checkpoint_builder,
+ Add(StrEq("foo_out"), IsTensor({2, 3}, {0, 0, 0, 0, 0, 0})))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(checkpoint_builder, Add(StrEq("bar_out"), IsTensor({}, {0.f})))
+ .WillOnce(Return(absl::OkStatus()));
+
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
+
+ // Verify that the pending client is closed.
+ EXPECT_CALL(callback_, OnCloseClient(0, IsCode(ABORTED)));
+ // Verify that the Complete callback method is called.
+ EXPECT_CALL(callback_, OnComplete);
+
+ EXPECT_THAT(protocol->Complete(), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Complete_TwoInputsReceived) {
+ // Two intrinsics:
+ // 1) federated_sum "foo" that takes int32 {2,3} tensors.
+ // 2) federated_sum "bar" that takes scalar float tensors.
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ }
+ }
+ aggregation_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args {
+ input_tensor {
+ name: "bar"
+ dtype: DT_FLOAT
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "bar_out"
+ dtype: DT_FLOAT
+ shape {}
+ }
+ }
+ )pb");
+ auto protocol = CreateProtocol(config_message);
+ EXPECT_CALL(callback_, OnAcceptClients);
+ EXPECT_THAT(protocol->Start(2), IsOk());
+
+ // Expect two inputs.
+ auto parser1 = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser1, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_INT32, {2, 3},
+ CreateTestData({4, 3, 11, 7, 1, 6}));
+ }));
+ EXPECT_CALL(*parser1, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_FLOAT, {}, CreateTestData({1.f}));
+ }));
+
+ auto parser2 = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser2, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_INT32, {2, 3},
+ CreateTestData({1, 8, 2, 10, 13, 2}));
+ }));
+ EXPECT_CALL(*parser2, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
+ }));
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser1))))
+ .WillOnce(Return(ByMove(std::move(parser2))));
+
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
+ EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));
+
+ // Handle the inputs.
+ EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_pending: 1 num_clients_completed: 1 "
+ "num_inputs_aggregated_and_included: 1")));
+
+ EXPECT_THAT(protocol->ReceiveClientMessage(1, MakeClientMessage()), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));
+
+ // Complete the protocol.
+ // Verify that the checkpoint builder is created.
+ auto& checkpoint_builder = ExpectCheckpointBuilder();
+
+ // Verify that foo_out and bar_out tensors are added to the result checkpoint
+ EXPECT_CALL(checkpoint_builder,
+ Add(StrEq("foo_out"), IsTensor({2, 3}, {5, 11, 13, 17, 14, 8})))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(checkpoint_builder, Add(StrEq("bar_out"), IsTensor({}, {3.f})))
+ .WillOnce(Return(absl::OkStatus()));
+
+ // Verify that the OnComplete callback method is called.
+ EXPECT_CALL(callback_, OnComplete);
+
+ EXPECT_THAT(protocol->Complete(), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Complete_ProtocolNotStarted) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_THAT(protocol->Complete(), IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Abort_NoInputsReceived) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
+ EXPECT_THAT(protocol->Start(2), IsOk());
+
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
+ EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(ABORTED)));
+ EXPECT_THAT(protocol->Abort(), IsOk());
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 2")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Abort_OneInputReceived) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
+ EXPECT_THAT(protocol->Start(2), IsOk());
+
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
+ return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
+ }));
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_))
+ .WillOnce(Return(ByMove(std::move(parser))));
+
+ // Receive input for the client #1
+ EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));
+ EXPECT_THAT(protocol->ReceiveClientMessage(1, MakeClientMessage()), IsOk());
+
+ // The client #0 should be aborted on Abort().
+ EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
+ EXPECT_THAT(protocol->Abort(), IsOk());
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
+ "num_clients_aborted: 1 num_clients_completed:1 "
+ "num_inputs_discarded: 1")));
+}
+
+TEST_F(SimpleAggregationProtocolTest, Abort_ProtocolNotStarted) {
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_THAT(protocol->Abort(), IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_Success) {
+ const int64_t kNumClients = 10;
+ auto protocol = CreateProtocolWithDefaultConfig();
+ EXPECT_CALL(callback_, OnAcceptClients(0, kNumClients, _));
+ EXPECT_THAT(protocol->Start(kNumClients), IsOk());
+
+ // The following block will repeatedly create CheckpointParser instances
+ // which will be creating scalar int tensors with repeatedly incrementing
+ // values.
+ std::atomic<int> tensor_value = 0;
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_)).WillRepeatedly(Invoke([&] {
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([&] {
+ return Tensor::Create(DT_INT32, {}, CreateTestData({++tensor_value}));
+ }));
+ return parser;
+ }));
+
+ // Schedule receiving inputs on 4 concurrent threads.
+ auto scheduler = CreateThreadPoolScheduler(4);
+ for (int64_t i = 0; i < kNumClients; ++i) {
+ scheduler->Schedule([&, i]() {
+ EXPECT_THAT(protocol->ReceiveClientMessage(i, MakeClientMessage()),
+ IsOk());
+ });
+ }
+ scheduler->WaitUntilIdle();
+
+ // Complete the protocol.
+ // Verify that the checkpoint builder is created.
+ auto& checkpoint_builder = ExpectCheckpointBuilder();
+ // Verify that foo_out tensor is added to the result checkpoint
+ EXPECT_CALL(checkpoint_builder, Add(StrEq("foo_out"), IsTensor({}, {55})))
+ .WillOnce(Return(absl::OkStatus()));
+
+ // Verify that the OnComplete callback method is called.
+ EXPECT_CALL(callback_, OnComplete);
+ EXPECT_THAT(protocol->Complete(), IsOk());
+}
+
+// A trivial test aggregator that delegates aggregation to a function.
+class FunctionAggregator final : public AggVectorAggregator<int> {
+ public:
+ using Func = std::function<int(int, int)>;
+
+ FunctionAggregator(DataType dtype, TensorShape shape, Func agg_function)
+ : AggVectorAggregator<int>(dtype, shape), agg_function_(agg_function) {}
+
+ private:
+ void AggregateVector(const AggVector<int>& agg_vector) override {
+ for (auto [i, v] : agg_vector) {
+ data()[i] = agg_function_(data()[i], v);
+ }
+ }
+
+ const Func agg_function_;
+};
+
+// Factory for the FunctionAggregator.
+class FunctionAggregatorFactory final : public TensorAggregatorFactory {
+ public:
+ explicit FunctionAggregatorFactory(FunctionAggregator::Func agg_function)
+ : agg_function_(agg_function) {}
+
+ private:
+ absl::StatusOr<std::unique_ptr<TensorAggregator>> Create(
+ DataType dtype, TensorShape shape) const override {
+ if (dtype != DT_INT32) {
+ return absl::InvalidArgumentError("Unsupported dtype: expected DT_INT32");
+ }
+ return std::make_unique<FunctionAggregator>(dtype, shape, agg_function_);
+ }
+
+ const FunctionAggregator::Func agg_function_;
+};
+
+TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_AbortWhileQueued) {
+ const int64_t kNumClients = 10;
+ const int64_t kNumClientBeforeBlocking = 3;
+
+ // Notifies the aggregation to unblock;
+ absl::Notification resume_aggregation_notification;
+ absl::Notification aggregation_blocked_notification;
+ std::atomic<int> agg_counter = 0;
+ FunctionAggregatorFactory agg_factory([&](int a, int b) {
+ if (++agg_counter > kNumClientBeforeBlocking &&
+ !aggregation_blocked_notification.HasBeenNotified()) {
+ aggregation_blocked_notification.Notify();
+ resume_aggregation_notification.WaitForNotification();
+ }
+ return a + b;
+ });
+ RegisterAggregatorFactory("foo1_aggregation", &agg_factory);
+
+ // The configuration below refers to the custom aggregation registered
+ // above.
+ auto protocol = CreateProtocol(PARSE_TEXT_PROTO(R"pb(
+ aggregation_configs {
+ intrinsic_uri: "foo1_aggregation"
+ intrinsic_args {
+ input_tensor {
+ name: "foo"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb"));
+ EXPECT_CALL(callback_, OnAcceptClients(0, kNumClients, _));
+ EXPECT_THAT(protocol->Start(kNumClients), IsOk());
+
+ EXPECT_CALL(checkpoint_parser_factory_, Create(_)).WillRepeatedly(Invoke([&] {
+ auto parser = std::make_unique<MockCheckpointParser>();
+ EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([&] {
+ return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
+ }));
+ return parser;
+ }));
+
+ // Schedule receiving inputs on 10 concurrent threads.
+ auto scheduler = CreateThreadPoolScheduler(10);
+ for (int64_t i = 0; i < kNumClients; ++i) {
+ scheduler->Schedule([&, i]() {
+ EXPECT_THAT(protocol->ReceiveClientMessage(i, MakeClientMessage()),
+ IsOk());
+ });
+ }
+
+ aggregation_blocked_notification.WaitForNotification();
+
+ StatusMessage status_message;
+ do {
+ status_message = protocol->GetStatus();
+ } while (status_message.num_clients_pending() > 0);
+
+ // At this point one input must be blocked inside the aggregation waiting for
+ // the notification, 3 inputs should already be gone through the aggregation,
+ // and the remaining 6 inputs should be blocked waiting to enter the
+ // aggregation.
+
+ // TODO(team): Need to revise the status implementation because it
+ // treats received and pending (queued) inputs "as aggregated and pending".
+ EXPECT_THAT(protocol->GetStatus(),
+ EqualsProto<StatusMessage>(
+ PARSE_TEXT_PROTO("num_clients_completed: 10 "
+ "num_inputs_aggregated_and_pending: 7 "
+ "num_inputs_aggregated_and_included: 3")));
+
+ resume_aggregation_notification.Notify();
+
+ // Abort and let all blocked aggregations continue.
+ EXPECT_THAT(protocol->Abort(), IsOk());
+ scheduler->WaitUntilIdle();
+
+ // All 10 inputs should now be discarded.
+ EXPECT_THAT(
+ protocol->GetStatus(),
+ EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_completed: 10 "
+ "num_inputs_discarded: 10")));
+}
+
+} // namespace
+} // namespace fcp::aggregation
diff --git a/fcp/aggregation/protocol/testing/BUILD b/fcp/aggregation/protocol/testing/BUILD
new file mode 100644
index 0000000..b91b5b0
--- /dev/null
+++ b/fcp/aggregation/protocol/testing/BUILD
@@ -0,0 +1,19 @@
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+cc_library(
+ name = "mock_callback",
+ testonly = 1,
+ hdrs = [
+ "test_callback.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/protocol:aggregation_protocol",
+ "//fcp/aggregation/protocol:cc_proto",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/aggregation/protocol/testing/test_callback.h b/fcp/aggregation/protocol/testing/test_callback.h
new file mode 100644
index 0000000..4106cf9
--- /dev/null
+++ b/fcp/aggregation/protocol/testing/test_callback.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_PROTOCOL_TESTING_TEST_CALLBACK_H_
+#define FCP_AGGREGATION_PROTOCOL_TESTING_TEST_CALLBACK_H_
+
+#include "gmock/gmock.h"
+#include "fcp/aggregation/protocol/aggregation_protocol.h"
+#include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
+
+namespace fcp::aggregation {
+
+class MockAggregationProtocolCallback : public AggregationProtocol::Callback {
+ public:
+ MOCK_METHOD(void, OnAcceptClients,
+ (int64_t start_client_id, int64_t num_clients,
+ const AcceptanceMessage& message),
+ (override));
+ MOCK_METHOD(void, OnSendServerMessage,
+ (int64_t client_id, const ServerMessage& message), (override));
+ MOCK_METHOD(void, OnCloseClient,
+ (int64_t client_id, absl::Status diagnostic_status), (override));
+ MOCK_METHOD(void, OnComplete, (absl::Cord result), (override));
+ MOCK_METHOD(void, OnAbort, (absl::Status diagnostic_status), (override));
+};
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_PROTOCOL_TESTING_TEST_CALLBACK_H_
diff --git a/fcp/aggregation/tensorflow/BUILD b/fcp/aggregation/tensorflow/BUILD
new file mode 100644
index 0000000..2bfb90a
--- /dev/null
+++ b/fcp/aggregation/tensorflow/BUILD
@@ -0,0 +1,191 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+)
+
+cc_library(
+ name = "converters",
+ srcs = ["converters.cc"],
+ hdrs = ["converters.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/core:tensor_cc_proto",
+ "//fcp/base",
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "converters_test",
+ srcs = ["converters_test.cc"],
+ deps = [
+ ":converters",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/testing",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "checkpoint_reader",
+ srcs = ["checkpoint_reader.cc"],
+ hdrs = ["checkpoint_reader.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":converters",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/base",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@org_tensorflow//tensorflow/c:checkpoint_reader",
+ "@org_tensorflow//tensorflow/c:tf_status_headers",
+ "@org_tensorflow//tensorflow/c:tf_status_helper",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "checkpoint_reader_test",
+ srcs = ["checkpoint_reader_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":checkpoint_reader",
+ "//fcp/aggregation/testing",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "checkpoint_writer",
+ srcs = ["checkpoint_writer.cc"],
+ hdrs = ["checkpoint_writer.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ "//fcp/base",
+ "//fcp/tensorflow:status",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core/platform:tstring",
+ ],
+)
+
+cc_test(
+ name = "checkpoint_writer_test",
+ srcs = ["checkpoint_writer_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":checkpoint_reader",
+ ":checkpoint_writer",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tensorflow_checkpoint_builder_factory",
+ srcs = ["tensorflow_checkpoint_builder_factory.cc"],
+ hdrs = ["tensorflow_checkpoint_builder_factory.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":checkpoint_writer",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/protocol:checkpoint_builder",
+ "//fcp/base",
+ "//fcp/tensorflow:status",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@org_tensorflow//tensorflow/core/platform:env",
+ "@org_tensorflow//tensorflow/core/platform:status",
+ ],
+)
+
+cc_test(
+ name = "tensorflow_checkpoint_builder_factory_test",
+ srcs = ["tensorflow_checkpoint_builder_factory_test.cc"],
+ deps = [
+ ":tensorflow_checkpoint_builder_factory",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/testing",
+ "//fcp/aggregation/testing:test_data",
+ "//fcp/testing",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tensorflow_checkpoint_parser_factory",
+ srcs = ["tensorflow_checkpoint_parser_factory.cc"],
+ hdrs = ["tensorflow_checkpoint_parser_factory.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":checkpoint_reader",
+ "//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/protocol:checkpoint_parser",
+ "//fcp/base",
+ "//fcp/tensorflow:status",
+ "@com_google_absl//absl/cleanup",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@org_tensorflow//tensorflow/core/platform:env",
+ ],
+)
+
+cc_test(
+ name = "tensorflow_checkpoint_parser_factory_test",
+ srcs = ["tensorflow_checkpoint_parser_factory_test.cc"],
+ deps = [
+ ":tensorflow_checkpoint_parser_factory",
+ "//fcp/aggregation/protocol:checkpoint_parser",
+ "//fcp/aggregation/testing",
+ "//fcp/base",
+ "//fcp/tensorflow:status",
+ "//fcp/testing",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/cc:cc_ops",
+ "@org_tensorflow//tensorflow/core:tensorflow",
+ ],
+)
diff --git a/fcp/aggregation/tensorflow/checkpoint_reader.cc b/fcp/aggregation/tensorflow/checkpoint_reader.cc
new file mode 100644
index 0000000..7ec0a72
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_reader.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/checkpoint_reader.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/tensorflow/converters.h"
+#include "fcp/base/monitoring.h"
+#include "tensorflow/c/checkpoint_reader.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace fcp::aggregation::tensorflow {
+
+namespace tf = ::tensorflow;
+
+absl::StatusOr<std::unique_ptr<CheckpointReader>> CheckpointReader::Create(
+ const std::string& filename) {
+ tf::TF_StatusPtr tf_status(TF_NewStatus());
+ auto tf_checkpoint_reader =
+ std::make_unique<tf::checkpoint::CheckpointReader>(filename,
+ tf_status.get());
+ if (TF_GetCode(tf_status.get()) != TF_OK) {
+ return absl::InternalError(
+ absl::StrFormat("Couldn't read checkpoint: %s : %s", filename,
+ TF_Message(tf_status.get())));
+ }
+
+ // Populate the DataType map.
+ DataTypeMap data_type_map;
+ for (const auto& [name, tf_dtype] :
+ tf_checkpoint_reader->GetVariableToDataTypeMap()) {
+ FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tf_dtype));
+ data_type_map.emplace(name, dtype);
+ }
+
+ // Populate the TensorShape map.
+ TensorShapeMap shape_map;
+ for (const auto& [name, tf_shape] :
+ tf_checkpoint_reader->GetVariableToShapeMap()) {
+ shape_map.emplace(name, ConvertShape(tf_shape));
+ }
+
+ return std::unique_ptr<CheckpointReader>(
+ new CheckpointReader(std::move(tf_checkpoint_reader),
+ std::move(data_type_map), std::move(shape_map)));
+}
+
+CheckpointReader::CheckpointReader(
+ std::unique_ptr<tf::checkpoint::CheckpointReader>
+ tensorflow_checkpoint_reader,
+ DataTypeMap data_type_map, TensorShapeMap shape_map)
+ : tf_checkpoint_reader_(std::move(tensorflow_checkpoint_reader)),
+ data_type_map_(std::move(data_type_map)),
+ shape_map_(std::move(shape_map)) {}
+
+StatusOr<Tensor> CheckpointReader::GetTensor(const std::string& name) const {
+ std::unique_ptr<tf::Tensor> tensor;
+ const tf::TF_StatusPtr read_status(TF_NewStatus());
+ tf_checkpoint_reader_->GetTensor(name, &tensor, read_status.get());
+ if (TF_GetCode(read_status.get()) != TF_OK) {
+ return absl::NotFoundError(
+ absl::StrFormat("Checkpoint doesn't have tensor %s", name));
+ }
+ return ConvertTensor(std::move(tensor));
+}
+
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/checkpoint_reader.h b/fcp/aggregation/tensorflow/checkpoint_reader.h
new file mode 100644
index 0000000..fcb6a1d
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_reader.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_
+#define FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "tensorflow/c/checkpoint_reader.h"
+
+namespace fcp::aggregation::tensorflow {
+
+// This class wraps Tensorflow checkpoint reader and provides a similar
+// functionality but returns Aggregation Core tensors instead.
+// This class is designed to read only dense tensors that consist of a
+// single slice.
+class CheckpointReader final {
+ public:
+ // CheckpointReader is neither copyable nor moveable
+ CheckpointReader(const CheckpointReader&) = delete;
+ CheckpointReader& operator=(const CheckpointReader&) = delete;
+
+ using DataTypeMap = absl::flat_hash_map<std::string, DataType>;
+ using TensorShapeMap = absl::flat_hash_map<std::string, TensorShape>;
+
+ static absl::StatusOr<std::unique_ptr<CheckpointReader>> Create(
+ const std::string& filename);
+
+ const DataTypeMap& GetDataTypeMap() const { return data_type_map_; }
+ const TensorShapeMap& GetTensorShapeMap() const { return shape_map_; }
+
+ absl::StatusOr<Tensor> GetTensor(const std::string& name) const;
+
+ private:
+ CheckpointReader(std::unique_ptr<::tensorflow::checkpoint::CheckpointReader>
+ tensorflow_checkpoint_reader,
+ DataTypeMap data_type_map, TensorShapeMap shape_map);
+
+ std::unique_ptr<::tensorflow::checkpoint::CheckpointReader>
+ tf_checkpoint_reader_;
+ DataTypeMap data_type_map_;
+ TensorShapeMap shape_map_;
+};
+
+} // namespace fcp::aggregation::tensorflow
+
+#endif // FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_
diff --git a/fcp/aggregation/tensorflow/checkpoint_reader_test.cc b/fcp/aggregation/tensorflow/checkpoint_reader_test.cc
new file mode 100644
index 0000000..40121eb
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_reader_test.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/checkpoint_reader.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+using ::testing::Key;
+using ::testing::UnorderedElementsAre;
+
+TEST(CheckpointReaderTest, ReadTensors) {
+ // Write a test TF checkpoint with 3 tensors
+ auto temp_filename = TemporaryTestFile(".ckpt");
+ auto tensor_a =
+ CreateTfTensor<float>(tf::DT_FLOAT, {4}, {1.0, 2.0, 3.0, 4.0});
+ auto tensor_b =
+ CreateTfTensor<int32_t>(tf::DT_INT32, {2, 3}, {11, 12, 13, 14, 15, 16});
+ auto tensor_c = CreateStringTfTensor({}, {"foobar"});
+ EXPECT_TRUE(CreateTfCheckpoint(temp_filename, {"a", "b", "c"},
+ {tensor_a, tensor_b, tensor_c})
+ .ok());
+
+ // Read the checkpoint using the Aggregation Core checkpoint reader.
+ auto checkpoint_reader_or_status = CheckpointReader::Create(temp_filename);
+ EXPECT_OK(checkpoint_reader_or_status.status());
+
+ auto checkpoint_reader = std::move(checkpoint_reader_or_status).value();
+ EXPECT_THAT(checkpoint_reader->GetDataTypeMap(),
+ UnorderedElementsAre(Key("a"), Key("b"), Key("c")));
+ EXPECT_THAT(checkpoint_reader->GetTensorShapeMap(),
+ UnorderedElementsAre(Key("a"), Key("b"), Key("c")));
+
+ // Read and verify the tensors.
+ EXPECT_THAT(*checkpoint_reader->GetTensor("a"),
+ IsTensor<float>({4}, {1.0, 2.0, 3.0, 4.0}));
+ EXPECT_THAT(*checkpoint_reader->GetTensor("b"),
+ IsTensor<int32_t>({2, 3}, {11, 12, 13, 14, 15, 16}));
+ EXPECT_THAT(*checkpoint_reader->GetTensor("c"),
+ IsTensor<string_view>({}, {"foobar"}));
+}
+
+TEST(CheckpointReaderTest, InvalidFileName) {
+ auto checkpoint_reader_or_status = CheckpointReader::Create("foo/bar");
+ EXPECT_THAT(checkpoint_reader_or_status, IsCode(INTERNAL));
+}
+
+TEST(CheckpointReaderTest, MalformedFile) {
+ auto temp_filename = TemporaryTestFile(".ckpt");
+ WriteStringToFile(temp_filename, "foobar").IgnoreError();
+ auto checkpoint_reader_or_status = CheckpointReader::Create(temp_filename);
+ EXPECT_THAT(checkpoint_reader_or_status, IsCode(INTERNAL));
+}
+
+} // namespace
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/checkpoint_writer.cc b/fcp/aggregation/tensorflow/checkpoint_writer.cc
new file mode 100644
index 0000000..dda7a5f
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_writer.cc
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/checkpoint_writer.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/tstring.h"
+
+namespace fcp::aggregation::tensorflow {
+
+namespace tf = ::tensorflow;
+
+tf::TensorShape ConvertShape(const TensorShape& shape) {
+ tf::TensorShape tf_shape;
+ for (auto dim : shape.dim_sizes()) {
+ tf_shape.AddDim(dim);
+ }
+ FCP_CHECK(tf_shape.IsValid());
+ return tf_shape;
+}
+
+template <typename T>
+tf::Status AddTensorSlice(tf::checkpoint::TensorSliceWriter* writer,
+ const std::string& name, const tf::TensorShape& shape,
+ const tf::TensorSlice& slice, const Tensor& tensor) {
+ return writer->Add<T>(name, shape, slice,
+ static_cast<const T*>(tensor.data().data()));
+}
+
+template <>
+tf::Status AddTensorSlice<string_view>(
+ tf::checkpoint::TensorSliceWriter* writer, const std::string& name,
+ const tf::TensorShape& shape, const tf::TensorSlice& slice,
+ const Tensor& tensor) {
+ std::vector<tf::tstring> values(tensor.shape().NumElements());
+ const auto* string_views =
+ static_cast<const string_view*>(tensor.data().data());
+ for (size_t i = 0; i < values.size(); ++i) {
+ values[i].assign_as_view(string_views[i].data(), string_views[i].size());
+ }
+ return writer->Add(name, shape, slice, values.data());
+}
+
+CheckpointWriter::CheckpointWriter(const std::string& filename)
+ : tensorflow_writer_(filename,
+ tf::checkpoint::CreateTableTensorSliceBuilder) {}
+
+CheckpointWriter::CheckpointWriter(
+ const std::string& filename,
+ tf::checkpoint::TensorSliceWriter::CreateBuilderFunction create_builder_fn)
+ : tensorflow_writer_(filename, create_builder_fn) {}
+
+absl::Status CheckpointWriter::Add(const std::string& tensor_name,
+ const Tensor& tensor) {
+ tf::TensorShape tf_shape = ConvertShape(tensor.shape());
+ tf::TensorSlice tf_slice(tf_shape.dims());
+ FCP_CHECK(tensor.is_dense())
+ << "Only dense tensors with one slice are supported";
+ tf::Status tf_status;
+ DTYPE_CASES(tensor.dtype(), T,
+ tf_status = AddTensorSlice<T>(&tensorflow_writer_, tensor_name,
+ tf_shape, tf_slice, tensor));
+ return ConvertFromTensorFlowStatus(tf_status);
+}
+
+absl::Status CheckpointWriter::Finish() {
+ return ConvertFromTensorFlowStatus(tensorflow_writer_.Finish());
+}
+
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/checkpoint_writer.h b/fcp/aggregation/tensorflow/checkpoint_writer.h
new file mode 100644
index 0000000..cf8c6e5
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_writer.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_
+#define FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_
+
+#include <string>
+
+#include "absl/status/status.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace fcp::aggregation::tensorflow {
+
+// This class wraps TensorSliceWriter and provides a similar
+// functionality but accepts Aggregation Core tensors instead.
+// This class is designed to write only dense tensors that consist of a
+// single slice.
+class CheckpointWriter final {
+ public:
+ // CheckpointReader is neither copyable nor moveable
+ CheckpointWriter(const CheckpointWriter&) = delete;
+ CheckpointWriter& operator=(const CheckpointWriter&) = delete;
+
+ // Constructs CheckpointWriter for the given filename.
+ explicit CheckpointWriter(const std::string& filename);
+
+ // Constructs CheckpointWriter for the given filename and
+ // CreateBuilderFunction.
+ explicit CheckpointWriter(
+ const std::string& filename,
+ ::tensorflow::checkpoint::TensorSliceWriter::CreateBuilderFunction
+ create_builder_fn);
+
+ // Adds a tensor to the checkpoint.
+ absl::Status Add(const std::string& tensor_name, const Tensor& tensor);
+
+ // Writes the checkpoint to the file.
+ absl::Status Finish();
+
+ private:
+ ::tensorflow::checkpoint::TensorSliceWriter tensorflow_writer_;
+};
+
+} // namespace fcp::aggregation::tensorflow
+
+#endif // FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_
diff --git a/fcp/aggregation/tensorflow/checkpoint_writer_test.cc b/fcp/aggregation/tensorflow/checkpoint_writer_test.cc
new file mode 100644
index 0000000..700f62e
--- /dev/null
+++ b/fcp/aggregation/tensorflow/checkpoint_writer_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/checkpoint_writer.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/tensorflow/checkpoint_reader.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+using ::testing::Key;
+using ::testing::UnorderedElementsAre;
+
+TEST(CheckpointWriterTest, WriteTensors) {
+ // Write the checkpoint using Aggregation Core checkpoint writer.
+ auto temp_filename = TemporaryTestFile(".ckpt");
+
+ auto t1 = Tensor::Create(DT_FLOAT, TensorShape({4}),
+ CreateTestData<float>({1.0, 2.0, 3.0, 4.0}))
+ .value();
+ auto t2 = Tensor::Create(DT_INT32, TensorShape({2, 3}),
+ CreateTestData<int32_t>({11, 12, 13, 14, 15, 16}))
+ .value();
+ auto t3 =
+ Tensor::Create(
+ DT_STRING, TensorShape({3}),
+ CreateTestData<string_view>({"foo", "bar", "bazzzzzzzzzzzzzzzzzzz"}))
+ .value();
+
+ CheckpointWriter checkpoint_writer(temp_filename);
+ EXPECT_OK(checkpoint_writer.Add("a", t1));
+ EXPECT_OK(checkpoint_writer.Add("b", t2));
+ EXPECT_OK(checkpoint_writer.Add("c", t3));
+ EXPECT_OK(checkpoint_writer.Finish());
+
+ // Read the checkpoint using the Aggregation Core checkpoint reader.
+ auto checkpoint_reader_or_status = CheckpointReader::Create(temp_filename);
+ EXPECT_OK(checkpoint_reader_or_status.status());
+
+ auto checkpoint_reader = std::move(checkpoint_reader_or_status).value();
+ EXPECT_THAT(checkpoint_reader->GetDataTypeMap(),
+ UnorderedElementsAre(Key("a"), Key("b"), Key("c")));
+ EXPECT_THAT(checkpoint_reader->GetTensorShapeMap(),
+ UnorderedElementsAre(Key("a"), Key("b"), Key("c")));
+
+ // Read and verify the tensors.
+ EXPECT_THAT(*checkpoint_reader->GetTensor("a"),
+ IsTensor<float>({4}, {1.0, 2.0, 3.0, 4.0}));
+ EXPECT_THAT(*checkpoint_reader->GetTensor("b"),
+ IsTensor<int32_t>({2, 3}, {11, 12, 13, 14, 15, 16}));
+ EXPECT_THAT(
+ *checkpoint_reader->GetTensor("c"),
+ IsTensor<string_view>({3}, {"foo", "bar", "bazzzzzzzzzzzzzzzzzzz"}));
+}
+
+} // namespace
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/converters.cc b/fcp/aggregation/tensorflow/converters.cc
new file mode 100644
index 0000000..97a70e1
--- /dev/null
+++ b/fcp/aggregation/tensorflow/converters.cc
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/converters.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.pb.h"
+#include "fcp/base/monitoring.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace fcp::aggregation::tensorflow {
+
+namespace tf = ::tensorflow;
+
+StatusOr<DataType> ConvertDataType(tf::DataType dtype) {
+ switch (dtype) {
+ case tf::DT_FLOAT:
+ return DT_FLOAT;
+ case tf::DT_DOUBLE:
+ return DT_DOUBLE;
+ case tf::DT_INT32:
+ return DT_INT32;
+ case tf::DT_INT64:
+ return DT_INT64;
+ case tf::DT_STRING:
+ return DT_STRING;
+ default:
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Unsupported tf::DataType: " << dtype;
+ }
+}
+
+TensorShape ConvertShape(const tf::TensorShape& shape) {
+ FCP_CHECK(shape.IsFullyDefined());
+ std::vector<size_t> dim_sizes;
+ for (auto dim_size : shape.dim_sizes()) {
+ FCP_CHECK(dim_size >= 0);
+ dim_sizes.push_back(dim_size);
+ }
+ return TensorShape(dim_sizes.begin(), dim_sizes.end());
+}
+
+StatusOr<TensorSpec> ConvertTensorSpec(
+ const ::tensorflow::TensorSpecProto& spec) {
+ FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(spec.dtype()));
+ tf::TensorShape tf_shape;
+ if (!tf::TensorShape::BuildTensorShape(spec.shape(), &tf_shape).ok()) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Unsupported tf::TensorShape: " << spec.shape().DebugString();
+ }
+ return TensorSpec(spec.name(), dtype, ConvertShape(tf_shape));
+}
+
+// A primitive TensorData implementation that wraps the original
+// tf::Tensor data.
+// NumericTensorDataAdapter gets the ownership of the wrapped tensor, which
+// keeps the underlying data alive.
+class NumericTensorDataAdapter : public TensorData {
+ public:
+ explicit NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
+ : tensor_(std::move(tensor)) {}
+
+ // The source tf::Tensor has the data as one continuous blob.
+ size_t byte_size() const override { return tensor_->tensor_data().size(); }
+ const void* data() const override { return tensor_->tensor_data().data(); }
+
+ private:
+ std::unique_ptr<tf::Tensor> tensor_;
+};
+
+// Similar to NumericTensorDataAdapter but performs additional conversion
+// of the original tensor tstring values to string_view while keeping the
+// the tstring values owned by the original tensor.
+class StringTensorDataAdapter : public TensorData {
+ public:
+ explicit StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
+ : tensor_(std::move(tensor)), string_views_(tensor_->NumElements()) {
+ auto string_values = tensor_->flat<tf::tstring>();
+ for (size_t i = 0; i < string_values.size(); ++i) {
+ string_views_[i] = string_values(i);
+ }
+ }
+
+ size_t byte_size() const override {
+ return string_views_.size() * sizeof(string_view);
+ }
+ const void* data() const override { return string_views_.data(); }
+
+ private:
+ std::unique_ptr<tf::Tensor> tensor_;
+ std::vector<string_view> string_views_;
+};
+
+// Conversion of tensor data for numeric data types, which can be
+// done by simply wrapping the original tensorflow tensor data.
+template <typename t>
+std::unique_ptr<TensorData> ConvertTensorData(
+ std::unique_ptr<tf::Tensor> tensor) {
+ return std::make_unique<NumericTensorDataAdapter>(std::move(tensor));
+}
+
+// Specialization of ConvertTensorData for the DT_STRING data type.
+template <>
+std::unique_ptr<TensorData> ConvertTensorData<string_view>(
+ std::unique_ptr<tf::Tensor> tensor) {
+ return std::make_unique<StringTensorDataAdapter>(std::move(tensor));
+}
+
+StatusOr<Tensor> ConvertTensor(std::unique_ptr<tf::Tensor> tensor) {
+ FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tensor->dtype()));
+ TensorShape shape = ConvertShape(tensor->shape());
+ std::unique_ptr<TensorData> data;
+ DTYPE_CASES(dtype, T, data = ConvertTensorData<T>(std::move(tensor)));
+ return Tensor::Create(dtype, std::move(shape), std::move(data));
+}
+
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/converters.h b/fcp/aggregation/tensorflow/converters.h
new file mode 100644
index 0000000..3a75b57
--- /dev/null
+++ b/fcp/aggregation/tensorflow/converters.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_
+#define FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_
+
+#include <memory>
+
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/core/tensor_spec.h"
+#include "fcp/base/monitoring.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp::aggregation::tensorflow {
+
+// Converts Tensorflow DataType to Aggregation DataType.
+// Returns an error status if the input data type isn't supported by
+// the Aggregation Core.
+StatusOr<DataType> ConvertDataType(::tensorflow::DataType dtype);
+
+// Converts Tensorflow TensorShape to Aggregation TensorShape.
+// Note that the Tensorflow shape is expected to be valid (it seems impossible
+// to create an invalid shape).
+TensorShape ConvertShape(const ::tensorflow::TensorShape& shape);
+
+// Converts Tensorflow TensorSpecProto to Aggregation TensorSpec.
+// Returns an error status if supplied TensorSpecProto data type or shape isn't
+// supported by the Aggregation Core.
+StatusOr<TensorSpec> ConvertTensorSpec(
+ const ::tensorflow::TensorSpecProto& spec);
+
+// Converts Tensorflow Tensor to Aggregation Tensor.
+// Returns an error status if supplied Tensor data type or shape isn't
+// supported by the Aggregation Core.
+// Note that this function consumes the Tensorflow tensor.
+StatusOr<Tensor> ConvertTensor(std::unique_ptr<::tensorflow::Tensor> tensor);
+
+} // namespace fcp::aggregation::tensorflow
+
+#endif // FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_
diff --git a/fcp/aggregation/tensorflow/converters_test.cc b/fcp/aggregation/tensorflow/converters_test.cc
new file mode 100644
index 0000000..6c6ebac
--- /dev/null
+++ b/fcp/aggregation/tensorflow/converters_test.cc
@@ -0,0 +1,152 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/converters.h"
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/core/tensor_spec.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+namespace tf = ::tensorflow;
+
+tf::TensorShape CreateTfShape(std::initializer_list<int64_t> dim_sizes) {
+ tf::TensorShape shape;
+ EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
+ return shape;
+}
+
+tf::TensorSpecProto CreateTfTensorSpec(
+ const std::string& name, tf::DataType dtype,
+ std::initializer_list<int64_t> dim_sizes) {
+ tf::TensorSpecProto spec;
+ spec.set_name(name);
+ spec.set_dtype(dtype);
+ for (auto dim_size : dim_sizes) {
+ spec.mutable_shape()->add_dim()->set_size(dim_size);
+ }
+ return spec;
+}
+
+TEST(ConvertersTest, ConvertDataType_Success) {
+ EXPECT_EQ(*ConvertDataType(tf::DT_FLOAT), DT_FLOAT);
+ EXPECT_EQ(*ConvertDataType(tf::DT_DOUBLE), DT_DOUBLE);
+ EXPECT_EQ(*ConvertDataType(tf::DT_INT32), DT_INT32);
+ EXPECT_EQ(*ConvertDataType(tf::DT_INT64), DT_INT64);
+ EXPECT_EQ(*ConvertDataType(tf::DT_STRING), DT_STRING);
+}
+
+TEST(ConvertersTest, ConvertDataType_Unsupported) {
+ EXPECT_THAT(ConvertDataType(tf::DT_VARIANT), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(ConvertersTest, ConvertShape_Success) {
+ EXPECT_EQ(ConvertShape(CreateTfShape({})), TensorShape({}));
+ EXPECT_EQ(ConvertShape(CreateTfShape({1})), TensorShape({1}));
+ EXPECT_EQ(ConvertShape(CreateTfShape({2, 3})), TensorShape({2, 3}));
+}
+
+TEST(ConvertersTest, ConvertTensorSpec_Success) {
+ auto tensor_spec =
+ ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_FLOAT, {1, 2, 3}));
+ ASSERT_THAT(tensor_spec, IsOk());
+ EXPECT_EQ(tensor_spec->name(), "foo");
+ EXPECT_EQ(tensor_spec->dtype(), DT_FLOAT);
+ EXPECT_EQ(tensor_spec->shape(), TensorShape({1, 2, 3}));
+}
+
+TEST(ConvertersTest, ConvertTensorSpec_UnsupportedDataType) {
+ EXPECT_THAT(
+ ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_VARIANT, {1, 2, 3})),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(ConvertersTest, ConvertTensorSpec_UnsupportedShape) {
+ EXPECT_THAT(
+ ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_FLOAT, {1, -1})),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(ConvertersTest, ConvertTensor_Numeric) {
+ tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ float_val: 1
+ float_val: 2
+ float_val: 3
+ float_val: 4
+ float_val: 5
+ float_val: 6
+ )pb");
+ auto tensor = std::make_unique<tf::Tensor>();
+ ASSERT_TRUE(tensor->FromProto(tensor_proto));
+ EXPECT_THAT(*ConvertTensor(std::move(tensor)),
+ IsTensor<float>({2, 3}, {1, 2, 3, 4, 5, 6}));
+}
+
+TEST(ConvertersTest, ConvertTensor_String) {
+ tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
+ dtype: DT_STRING
+ tensor_shape { dim { size: 3 } }
+ string_val: "abcd"
+ string_val: "foobar"
+ string_val: "zzzzzzzzzzzzzz"
+ )pb");
+ auto tensor = std::make_unique<tf::Tensor>();
+ ASSERT_TRUE(tensor->FromProto(tensor_proto));
+ EXPECT_THAT(*ConvertTensor(std::move(tensor)),
+ IsTensor<string_view>({3}, {"abcd", "foobar", "zzzzzzzzzzzzzz"}));
+}
+
+TEST(ConvertersTest, ConvertTensor_ScalarString) {
+ tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
+ dtype: DT_STRING
+ tensor_shape {}
+ string_val: "0123456789"
+ )pb");
+ auto tensor = std::make_unique<tf::Tensor>();
+ ASSERT_TRUE(tensor->FromProto(tensor_proto));
+ EXPECT_THAT(*ConvertTensor(std::move(tensor)),
+ IsTensor<string_view>({}, {"0123456789"}));
+}
+
+TEST(ConvertersTest, ConvertTensor_UnsupportedDataType) {
+ auto tensor = std::make_unique<tf::Tensor>(tf::DT_VARIANT, CreateTfShape({}));
+ EXPECT_THAT(ConvertTensor(std::move(tensor)), IsCode(INVALID_ARGUMENT));
+}
+
+} // namespace
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/python/BUILD b/fcp/aggregation/tensorflow/python/BUILD
new file mode 100644
index 0000000..f00299f
--- /dev/null
+++ b/fcp/aggregation/tensorflow/python/BUILD
@@ -0,0 +1,36 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+load("@rules_python//python:defs.bzl", "py_test")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+)
+
+pybind_extension(
+ name = "aggregation_protocols",
+ srcs = ["aggregation_protocols.cc"],
+ pytype_deps = [
+ "//fcp/aggregation/protocol/python:aggregation_protocol",
+ ],
+ deps = [
+ "//fcp/aggregation/protocol:aggregation_protocol",
+ "//fcp/aggregation/protocol:configuration_cc_proto",
+ "//fcp/aggregation/protocol:resource_resolver",
+ "//fcp/aggregation/protocol/simple_aggregation",
+ "//fcp/aggregation/tensorflow:tensorflow_checkpoint_builder_factory",
+ "//fcp/aggregation/tensorflow:tensorflow_checkpoint_parser_factory",
+ "@com_google_absl//absl/status:statusor",
+ "@pybind11_abseil//pybind11_abseil:status_casters",
+ "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
+ ],
+)
+
+py_test(
+ name = "aggregation_protocols_test",
+ srcs = ["aggregation_protocols_test.py"],
+ data = ["@pybind11_abseil//pybind11_abseil:status.so"],
+ deps = [
+ ":aggregation_protocols",
+ "//fcp/aggregation/protocol:configuration_py_pb2",
+ "//fcp/aggregation/protocol:py_pb2",
+ ],
+)
diff --git a/fcp/aggregation/tensorflow/python/aggregation_protocols.cc b/fcp/aggregation/tensorflow/python/aggregation_protocols.cc
new file mode 100644
index 0000000..199244e
--- /dev/null
+++ b/fcp/aggregation/tensorflow/python/aggregation_protocols.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <pybind11/pybind11.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "fcp/aggregation/protocol/aggregation_protocol.h"
+#include "fcp/aggregation/protocol/configuration.pb.h"
+#include "fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h"
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h"
+#include "pybind11_abseil/status_casters.h"
+#include "pybind11_protobuf/native_proto_caster.h"
+
+namespace py = ::pybind11;
+
+using ::fcp::aggregation::AggregationProtocol;
+using ::fcp::aggregation::Configuration;
+using ::fcp::aggregation::ResourceResolver;
+using ::fcp::aggregation::tensorflow::TensorflowCheckpointBuilderFactory;
+using ::fcp::aggregation::tensorflow::TensorflowCheckpointParserFactory;
+
+PYBIND11_MODULE(aggregation_protocols, m) {
+ class DefaultResourceResolver : public ResourceResolver {
+ absl::StatusOr<absl::Cord> RetrieveResource(
+ int64_t client_id, const std::string& uri) override {
+ return absl::UnimplementedError("RetrieveResource() is not supported.");
+ }
+ };
+
+ pybind11::google::ImportStatusModule();
+ pybind11_protobuf::ImportNativeProtoCasters();
+
+ static const TensorflowCheckpointBuilderFactory* const
+ kCheckpointBuilderFactory = new TensorflowCheckpointBuilderFactory();
+ static const TensorflowCheckpointParserFactory* const
+ kCheckpointParserFactory = new TensorflowCheckpointParserFactory();
+ static ResourceResolver* kResourceResolver = new DefaultResourceResolver();
+
+ m.def(
+ "create_simple_aggregation_protocol",
+ [](const Configuration& configuration,
+ AggregationProtocol::Callback* callback)
+ -> absl::StatusOr<std::unique_ptr<AggregationProtocol>> {
+ return fcp::aggregation::SimpleAggregationProtocol::Create(
+ configuration, callback, kCheckpointParserFactory,
+ kCheckpointBuilderFactory, kResourceResolver);
+ },
+ // Ensure the Callback object outlives the AggregationProtocol.
+ py::keep_alive<0, 2>());
+}
diff --git a/fcp/aggregation/tensorflow/python/aggregation_protocols_test.py b/fcp/aggregation/tensorflow/python/aggregation_protocols_test.py
new file mode 100644
index 0000000..e32a6f5
--- /dev/null
+++ b/fcp/aggregation/tensorflow/python/aggregation_protocols_test.py
@@ -0,0 +1,119 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for aggregation_protocols."""
+
+import tempfile
+from typing import Any
+from unittest import mock
+
+from absl.testing import absltest
+import tensorflow as tf
+
+from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
+from fcp.aggregation.protocol import configuration_pb2
+from fcp.aggregation.protocol.python import aggregation_protocol
+from fcp.aggregation.tensorflow.python import aggregation_protocols
+from pybind11_abseil import status
+
+
+def create_client_input(tensors: dict[str, Any]) -> apm_pb2.ClientMessage:
+ with tempfile.NamedTemporaryFile() as tmpfile:
+ tf.raw_ops.Save(
+ filename=tmpfile.name,
+ tensor_names=list(tensors.keys()),
+ data=list(tensors.values()))
+ with open(tmpfile.name, 'rb') as f:
+ return apm_pb2.ClientMessage(
+ simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
+ input=apm_pb2.ClientResource(inline_bytes=f.read())))
+
+
+class CallbackProxy(aggregation_protocol.AggregationProtocol.Callback):
+ """A pass-through Callback that delegates to another Callback.
+
+ This works around the issue that mock.Mock objects aren't recognized as
+ Callback subclasses by pybind11.
+ """
+
+ def __init__(self,
+ callback: aggregation_protocol.AggregationProtocol.Callback):
+ super().__init__()
+ self._callback = callback
+
+ def OnAcceptClients(self, start_client_id: int, num_clients: int,
+ message: apm_pb2.AcceptanceMessage):
+ self._callback.OnAcceptClients(start_client_id, num_clients, message)
+
+ def OnSendServerMessage(self, client_id: int, message: apm_pb2.ServerMessage):
+ self._callback.OnSendServerMessage(client_id, message)
+
+ def OnCloseClient(self, client_id: int, diagnostic_status: status.Status):
+ self._callback.OnCloseClient(client_id, diagnostic_status)
+
+ def OnComplete(self, result: bytes):
+ self._callback.OnComplete(result)
+
+ def OnAbort(self, diagnostic_status: status.Status):
+ self._callback.OnAbort(diagnostic_status)
+
+
+class AggregationProtocolsTest(absltest.TestCase):
+
+ def test_simple_aggregation_protocol(self):
+ input_tensor = tf.TensorSpec((), tf.int32, 'in')
+ output_tensor = tf.TensorSpec((), tf.int32, 'out')
+ config = configuration_pb2.Configuration(aggregation_configs=[
+ configuration_pb2.Configuration.ServerAggregationConfig(
+ intrinsic_uri='federated_sum',
+ intrinsic_args=[
+ configuration_pb2.Configuration.ServerAggregationConfig.
+ IntrinsicArg(input_tensor=input_tensor.experimental_as_proto()),
+ ],
+ output_tensors=[output_tensor.experimental_as_proto()],
+ ),
+ ])
+ callback = mock.create_autospec(
+ aggregation_protocol.AggregationProtocol.Callback, instance=True)
+
+ agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
+ config, CallbackProxy(callback))
+ self.assertIsNotNone(agg_protocol)
+
+ agg_protocol.Start(2)
+ callback.OnAcceptClients.assert_called_once_with(mock.ANY, 2, mock.ANY)
+ start_client_id = callback.OnAcceptClients.call_args.args[0]
+
+ agg_protocol.ReceiveClientMessage(
+ start_client_id, create_client_input({input_tensor.name: 3}))
+ agg_protocol.ReceiveClientMessage(
+ start_client_id + 1, create_client_input({input_tensor.name: 5}))
+ callback.OnCloseClient.assert_has_calls([
+ mock.call(start_client_id, status.Status.OkStatus()),
+ mock.call(start_client_id + 1, status.Status.OkStatus()),
+ ])
+
+ agg_protocol.Complete()
+ callback.OnComplete.assert_called_once()
+ with tempfile.NamedTemporaryFile('wb') as tmpfile:
+ tmpfile.write(callback.OnComplete.call_args.args[0])
+ tmpfile.flush()
+ self.assertEqual(
+ tf.raw_ops.Restore(
+ file_pattern=tmpfile.name,
+ tensor_name=output_tensor.name,
+ dt=output_tensor.dtype), 8)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.cc b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.cc
new file mode 100644
index 0000000..e62b1e2
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/protocol/checkpoint_builder.h"
+#include "fcp/aggregation/tensorflow/checkpoint_writer.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+using ::tensorflow::Env;
+
+// A CheckpointBuilder implementation that builds TensorFlow checkpoints using a
+// CheckpointWriter.
+class TensorflowCheckpointBuilder : public CheckpointBuilder {
+ public:
+ explicit TensorflowCheckpointBuilder(std::string filename)
+ : filename_(std::move(filename)) {}
+
+ ~TensorflowCheckpointBuilder() override {
+ Env::Default()->DeleteFile(filename_).IgnoreError();
+ }
+
+ absl::Status Add(const std::string& name, const Tensor& tensor) override {
+ return writer_.Add(name, tensor);
+ }
+
+ absl::StatusOr<absl::Cord> Build() override {
+ FCP_RETURN_IF_ERROR(writer_.Finish());
+
+ // Read the checkpoints contents from the file.
+ std::unique_ptr<::tensorflow::RandomAccessFile> file;
+ FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
+ Env::Default()->NewRandomAccessFile(filename_, &file)));
+
+ absl::Cord output;
+ for (;;) {
+ char scratch[4096];
+ absl::string_view read_result;
+ ::tensorflow::Status status =
+ file->Read(output.size(), sizeof(scratch), &read_result, scratch);
+ output.Append(read_result);
+ if (status.code() == ::tensorflow::error::OUT_OF_RANGE) {
+ return output;
+ } else if (!status.ok()) {
+ return ConvertFromTensorFlowStatus(status);
+ }
+ }
+ }
+
+ private:
+ std::string filename_;
+ CheckpointWriter writer_{filename_};
+};
+
+} // namespace
+
+std::unique_ptr<CheckpointBuilder> TensorflowCheckpointBuilderFactory::Create()
+ const {
+ // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
+ // results in a second in-memory copy of the data but avoids disk I/O.
+ std::string filename =
+ absl::StrCat("ram://",
+ absl::Hex(absl::Uniform(
+ absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
+ ".ckpt");
+
+ return std::make_unique<TensorflowCheckpointBuilder>(std::move(filename));
+}
+
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h
new file mode 100644
index 0000000..59d8776
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_BUILDER_FACTORY_H_
+#define FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_BUILDER_FACTORY_H_
+
+#include <memory>
+
+#include "fcp/aggregation/protocol/checkpoint_builder.h"
+
+namespace fcp::aggregation::tensorflow {
+
+// A CheckpointBuilderFactory implementation that writes TensorFlow checkpoints.
+class TensorflowCheckpointBuilderFactory
+ : public fcp::aggregation::CheckpointBuilderFactory {
+ public:
+ std::unique_ptr<fcp::aggregation::CheckpointBuilder> Create() const override;
+};
+
+} // namespace fcp::aggregation::tensorflow
+
+#endif // FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_BUILDER_FACTORY_H_
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory_test.cc b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory_test.cc
new file mode 100644
index 0000000..7ccced9
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory_test.cc
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
+
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/testing/test_data.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+using ::testing::AllOf;
+using ::testing::Each;
+using ::testing::Pair;
+using ::testing::SizeIs;
+using ::testing::StartsWith;
+using ::testing::UnorderedElementsAre;
+
+TEST(TensorflowCheckpointBuilderFactoryTest, BuildCheckpoint) {
+ TensorflowCheckpointBuilderFactory factory;
+ std::unique_ptr<CheckpointBuilder> builder = factory.Create();
+
+ absl::StatusOr<Tensor> t1 = Tensor::Create(
+ DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_OK(t1.status());
+ absl::StatusOr<Tensor> t2 = Tensor::Create(DT_FLOAT, TensorShape({2}),
+ CreateTestData<float>({5.0, 6.0}));
+ ASSERT_OK(t2.status());
+
+ EXPECT_OK(builder->Add("t1", *t1));
+ EXPECT_OK(builder->Add("t2", *t2));
+ absl::StatusOr<absl::Cord> checkpoint = builder->Build();
+ ASSERT_OK(checkpoint.status());
+ auto summary = SummarizeCheckpoint(*checkpoint);
+ ASSERT_OK(summary.status());
+ EXPECT_THAT(*summary,
+ UnorderedElementsAre(Pair("t1", "1 2 3 4"), Pair("t2", "5 6")));
+}
+
+// Check that multiple checkpoints can be built simultanously.
+TEST(TensorflowCheckpointBuilderFactoryTest, SimultaneousWrites) {
+ TensorflowCheckpointBuilderFactory factory;
+
+ absl::StatusOr<Tensor> t1 = Tensor::Create(
+ DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_OK(t1.status());
+ absl::StatusOr<Tensor> t2 = Tensor::Create(DT_FLOAT, TensorShape({2}),
+ CreateTestData<float>({5.0, 6.0}));
+ ASSERT_OK(t2.status());
+
+ std::unique_ptr<CheckpointBuilder> builder1 = factory.Create();
+ std::unique_ptr<CheckpointBuilder> builder2 = factory.Create();
+ EXPECT_OK(builder1->Add("t1", *t1));
+ EXPECT_OK(builder2->Add("t2", *t2));
+ absl::StatusOr<absl::Cord> checkpoint1 = builder1->Build();
+ ASSERT_OK(checkpoint1.status());
+ absl::StatusOr<absl::Cord> checkpoint2 = builder2->Build();
+ ASSERT_OK(checkpoint2.status());
+ auto summary1 = SummarizeCheckpoint(*checkpoint1);
+ ASSERT_OK(summary1.status());
+ EXPECT_THAT(*summary1, UnorderedElementsAre(Pair("t1", "1 2 3 4")));
+ auto summary2 = SummarizeCheckpoint(*checkpoint2);
+ ASSERT_OK(summary2.status());
+ EXPECT_THAT(*summary2, UnorderedElementsAre(Pair("t2", "5 6")));
+}
+
+TEST(TensorflowCheckpointBuilderFactoryTest, LargeCheckpoint) {
+ TensorflowCheckpointBuilderFactory factory;
+ std::unique_ptr<CheckpointBuilder> builder = factory.Create();
+
+ // Add 10 tensors that each require at least 8kB to exercise reading and
+ // writing in multiple chunks.
+ static constexpr int kTensorSize = 1024;
+ absl::StatusOr<Tensor> t =
+ Tensor::Create(DT_INT64, TensorShape({kTensorSize}),
+ std::make_unique<MutableVectorData<int64_t>>(kTensorSize));
+ ASSERT_OK(t.status());
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_OK(builder->Add(absl::StrCat("t", i), *t));
+ }
+ absl::StatusOr<absl::Cord> checkpoint = builder->Build();
+ ASSERT_OK(checkpoint.status());
+ auto summary = SummarizeCheckpoint(*checkpoint);
+ ASSERT_OK(summary.status());
+ EXPECT_THAT(*summary,
+ AllOf(SizeIs(10), Each(Pair(StartsWith("t"),
+ StartsWith("0 0 0 0 0 0 0 0 0")))));
+}
+
+} // namespace
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.cc b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.cc
new file mode 100644
index 0000000..9214a93
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h"
+
+#include <stdint.h>
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/cleanup/cleanup.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+#include "fcp/aggregation/tensorflow/checkpoint_reader.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+using ::tensorflow::Env;
+
+// A CheckpointParser implementation that reads TensorFlow checkpoints using a
+// CheckpointReader.
+class TensorflowCheckpointParser : public CheckpointParser {
+ public:
+ TensorflowCheckpointParser(std::string filename,
+ std::unique_ptr<CheckpointReader> reader)
+ : filename_(std::move(filename)), reader_(std::move(reader)) {}
+
+ ~TensorflowCheckpointParser() override {
+ Env::Default()->DeleteFile(filename_).IgnoreError();
+ }
+
+ absl::StatusOr<Tensor> GetTensor(const std::string& name) const override {
+ return reader_->GetTensor(name);
+ }
+
+ private:
+ std::string filename_;
+ std::unique_ptr<CheckpointReader> reader_;
+};
+
+} // namespace
+
+absl::StatusOr<std::unique_ptr<CheckpointParser>>
+TensorflowCheckpointParserFactory::Create(
+ const absl::Cord& serialized_checkpoint) const {
+ // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
+ // results in a second in-memory copy of the data but avoids disk I/O.
+ std::string filename =
+ absl::StrCat("ram://",
+ absl::Hex(absl::Uniform(
+ absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
+ ".ckpt");
+
+ // Write the checkpoint to the temporary file.
+ std::unique_ptr<::tensorflow::WritableFile> file;
+ FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
+ Env::Default()->NewWritableFile(filename, &file)));
+ absl::Cleanup cleanup = [&] {
+ Env::Default()->DeleteFile(filename).IgnoreError();
+ };
+ for (absl::string_view chunk : serialized_checkpoint.Chunks()) {
+ FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Append(chunk)));
+ }
+ FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Close()));
+
+ // Return a TensorflowCheckpointParser that will read from the file.
+ FCP_ASSIGN_OR_RETURN(std::unique_ptr<CheckpointReader> reader,
+ CheckpointReader::Create(filename));
+ std::move(cleanup).Cancel();
+ return std::make_unique<TensorflowCheckpointParser>(std::move(filename),
+ std::move(reader));
+}
+
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h
new file mode 100644
index 0000000..b065fba
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_PARSER_FACTORY_H_
+#define FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_PARSER_FACTORY_H_
+
+#include <memory>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+
+namespace fcp::aggregation::tensorflow {
+
+// A CheckpointParserFactory implementation that reads TensorFlow checkpoints.
+class TensorflowCheckpointParserFactory
+ : public fcp::aggregation::CheckpointParserFactory {
+ public:
+ absl::StatusOr<std::unique_ptr<fcp::aggregation::CheckpointParser>> Create(
+ const absl::Cord& serialized_checkpoint) const override;
+};
+
+} // namespace fcp::aggregation::tensorflow
+
+#endif // FCP_AGGREGATION_TENSORFLOW_TENSORFLOW_CHECKPOINT_PARSER_FACTORY_H_
diff --git a/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory_test.cc b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory_test.cc
new file mode 100644
index 0000000..c7e4d17
--- /dev/null
+++ b/fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory_test.cc
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "fcp/aggregation/protocol/checkpoint_parser.h"
+#include "fcp/aggregation/testing/testing.h"
+#include "fcp/base/platform.h"
+#include "fcp/tensorflow/status.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::aggregation::tensorflow {
+namespace {
+
+TEST(TensorflowCheckpointParserFactoryTest, ReadCheckpoint) {
+ std::string filename = TemporaryTestFile(".ckpt");
+ ASSERT_OK(ConvertFromTensorFlowStatus(CreateTfCheckpoint(
+ filename, {"t1", "t2"}, {{1.0f, 2.0f, 3.0f, 4.0f}, {5.0f, 6.0f}})));
+ absl::StatusOr<absl::Cord> checkpoint = ReadFileToCord(filename);
+ ASSERT_OK(checkpoint.status());
+
+ TensorflowCheckpointParserFactory factory;
+ absl::StatusOr<std::unique_ptr<CheckpointParser>> parser =
+ factory.Create(*checkpoint);
+ ASSERT_OK(parser.status());
+
+ auto t1 = (*parser)->GetTensor("t1");
+ ASSERT_OK(t1.status());
+ EXPECT_THAT(*t1, IsTensor<float>({4}, {1.0, 2.0, 3.0, 4.0}));
+ auto t2 = (*parser)->GetTensor("t2");
+ ASSERT_OK(t2.status());
+ EXPECT_THAT(*t2, IsTensor<float>({2}, {5.0, 6.0}));
+ EXPECT_FALSE((*parser)->GetTensor("t3").ok());
+}
+
+TEST(TensorflowCheckpointParserFactoryTest, InvalidCheckpoint) {
+ TensorflowCheckpointParserFactory factory;
+ EXPECT_FALSE(factory.Create(absl::Cord("invalid")).ok());
+}
+
+} // namespace
+} // namespace fcp::aggregation::tensorflow
diff --git a/fcp/aggregation/testing/BUILD b/fcp/aggregation/testing/BUILD
new file mode 100644
index 0000000..c5854af
--- /dev/null
+++ b/fcp/aggregation/testing/BUILD
@@ -0,0 +1,61 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp/aggregation:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "testing",
+ testonly = True,
+ srcs = [
+ "testing.cc",
+ ],
+ hdrs = [
+ "testing.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ "//fcp/base",
+ "//fcp/tensorflow:status",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/c:checkpoint_reader",
+ "@org_tensorflow//tensorflow/c:tf_status_headers",
+ "@org_tensorflow//tensorflow/c:tf_status_helper",
+ "@org_tensorflow//tensorflow/cc:cc_ops",
+ "@org_tensorflow//tensorflow/cc:ops",
+ "@org_tensorflow//tensorflow/cc:scope",
+ "@org_tensorflow//tensorflow/core:core_cpu",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_library(
+ name = "test_data",
+ testonly = True,
+ hdrs = [
+ "test_data.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/aggregation/core:tensor",
+ ],
+)
diff --git a/fcp/aggregation/testing/test_data.h b/fcp/aggregation/testing/test_data.h
new file mode 100644
index 0000000..d559c63
--- /dev/null
+++ b/fcp/aggregation/testing/test_data.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TESTING_TEST_DATA_H_
+#define FCP_AGGREGATION_TESTING_TEST_DATA_H_
+
+#include <initializer_list>
+#include <memory>
+
+#include "fcp/aggregation/core/mutable_vector_data.h"
+
+namespace fcp::aggregation {
+
+// Creates test tensor data based on a vector<T>.
+template <typename T>
+std::unique_ptr<MutableVectorData<T>> CreateTestData(
+ std::initializer_list<T> values) {
+ return std::make_unique<MutableVectorData<T>>(values);
+}
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_TESTING_TEST_DATA_H_
diff --git a/fcp/aggregation/testing/testing.cc b/fcp/aggregation/testing/testing.cc
new file mode 100644
index 0000000..d8624e6
--- /dev/null
+++ b/fcp/aggregation/testing/testing.cc
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/aggregation/testing/testing.h"
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "fcp/base/platform.h"
+#include "fcp/tensorflow/status.h"
+#include "fcp/testing/testing.h"
+#include "tensorflow/c/checkpoint_reader.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/io_ops.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/public/session.h"
+
+namespace fcp::aggregation {
+
+using ::tensorflow::StatusFromTF_Status;
+using ::tensorflow::TF_StatusPtr;
+using ::tensorflow::checkpoint::CheckpointReader;
+
+std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
+ DTYPE_CASES(tensor.dtype(), T,
+ DescribeTensor<T>(&os, tensor.dtype(), tensor.shape(),
+ TensorValuesToVector<T>(tensor)));
+ return os;
+}
+
+tf::Tensor CreateStringTfTensor(std::initializer_list<int64_t> dim_sizes,
+ std::initializer_list<string_view> values) {
+ tf::TensorShape shape;
+ EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
+ tf::Tensor tensor(tf::DT_STRING, shape);
+ auto* tensor_data_ptr = reinterpret_cast<tf::tstring*>(tensor.data());
+ for (auto value : values) {
+ *tensor_data_ptr++ = value;
+ }
+ return tensor;
+}
+
+tf::Status CreateTfCheckpoint(tf::Input filename, tf::Input tensor_names,
+ tf::InputList tensors) {
+ tf::Scope scope = tf::Scope::NewRootScope();
+
+ tf::ops::Save save(scope, std::move(filename), std::move(tensor_names),
+ std::move(tensors));
+
+ tf::GraphDef graph;
+ if (auto s = scope.ToGraphDef(&graph); !s.ok()) return s;
+
+ auto session = absl::WrapUnique(tf::NewSession(tf::SessionOptions()));
+ if (auto s = session->Create(graph); !s.ok()) return s;
+ return session->Run({}, {}, {save.operation.node()->name()}, nullptr);
+}
+
+absl::StatusOr<absl::flat_hash_map<std::string, std::string>>
+SummarizeCheckpoint(const absl::Cord& checkpoint) {
+ std::string filename = TemporaryTestFile(".ckpt");
+ FCP_RETURN_IF_ERROR(WriteCordToFile(filename, checkpoint));
+
+ TF_StatusPtr tf_status(TF_NewStatus());
+ auto reader = std::make_unique<CheckpointReader>(filename, tf_status.get());
+ FCP_RETURN_IF_ERROR(
+ ConvertFromTensorFlowStatus(StatusFromTF_Status(tf_status.get())));
+
+ absl::flat_hash_map<std::string, std::string> tensors;
+ for (const auto& [name, shape] : reader->GetVariableToShapeMap()) {
+ std::unique_ptr<::tensorflow::Tensor> tensor;
+ reader->GetTensor(name, &tensor, tf_status.get());
+ FCP_RETURN_IF_ERROR(
+ ConvertFromTensorFlowStatus(StatusFromTF_Status(tf_status.get())));
+ tensors[name] = tensor->SummarizeValue(/*max_entries=*/10);
+ }
+ return tensors;
+}
+} // namespace fcp::aggregation
diff --git a/fcp/aggregation/testing/testing.h b/fcp/aggregation/testing/testing.h
new file mode 100644
index 0000000..4b397d2
--- /dev/null
+++ b/fcp/aggregation/testing/testing.h
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_AGGREGATION_TESTING_TESTING_H_
+#define FCP_AGGREGATION_TESTING_TESTING_H_
+
+#include <initializer_list>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/aggregation/core/datatype.h"
+#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor_shape.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace fcp::aggregation {
+
+namespace tf = ::tensorflow;
+
+template <typename T>
+tf::Tensor CreateTfTensor(tf::DataType data_type,
+ std::initializer_list<int64_t> dim_sizes,
+ std::initializer_list<T> values) {
+ tf::TensorShape shape;
+ EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
+ tf::Tensor tensor(data_type, shape);
+ T* tensor_data_ptr = reinterpret_cast<T*>(tensor.data());
+ for (auto value : values) {
+ *tensor_data_ptr++ = value;
+ }
+ return tensor;
+}
+
+tf::Tensor CreateStringTfTensor(std::initializer_list<int64_t> dim_sizes,
+ std::initializer_list<string_view> values);
+
+// Wrapper around tf::ops::Save that sets up and runs the op.
+tf::Status CreateTfCheckpoint(tf::Input filename, tf::Input tensor_names,
+ tf::InputList tensors);
+
+// Returns a summary of the checkpoint as a map of tensor names and values.
+absl::StatusOr<absl::flat_hash_map<std::string, std::string>>
+SummarizeCheckpoint(const absl::Cord& checkpoint);
+
+// Converts a potentially sparse tensor to a flat vector of tensor values.
+template <typename T>
+std::vector<T> TensorValuesToVector(const Tensor& arg) {
+ std::vector<T> vec(arg.shape().NumElements());
+ AggVector<T> agg_vector = arg.AsAggVector<T>();
+ for (auto [i, v] : agg_vector) {
+ vec[i] = v;
+ }
+ return vec;
+}
+
+// Writes description of a tensor to the ostream.
+template <typename T>
+void DescribeTensor(::std::ostream* os, DataType dtype, TensorShape shape,
+ std::vector<T> values) {
+ // Max number of tensor values to be printed.
+ constexpr int kMaxValues = 100;
+ // TODO(team): Print dtype name istead of number.
+ *os << "{dtype: " << dtype;
+ *os << ", shape: {";
+ bool insert_comma = false;
+ for (auto dim_size : shape.dim_sizes()) {
+ if (insert_comma) {
+ *os << ", ";
+ }
+ *os << dim_size;
+ insert_comma = true;
+ }
+ *os << "}, values: {";
+ int num_values = 0;
+ insert_comma = false;
+ for (auto v : values) {
+ if (++num_values > kMaxValues) {
+ *os << "...";
+ break;
+ }
+ if (insert_comma) {
+ *os << ", ";
+ }
+ *os << v;
+ insert_comma = true;
+ }
+ *os << "}}";
+}
+
+// Writes description of a tensor to the ostream.
+std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
+
+// TensorMatcher implementation.
+template <typename T>
+class TensorMatcherImpl : public ::testing::MatcherInterface<const Tensor&> {
+ public:
+ TensorMatcherImpl(DataType expected_dtype, TensorShape expected_shape,
+ std::vector<T> expected_values)
+ : expected_dtype_(expected_dtype),
+ expected_shape_(expected_shape),
+ expected_values_(expected_values) {}
+
+ void DescribeTo(std::ostream* os) const override {
+ DescribeTensor<T>(os, expected_dtype_, expected_shape_, expected_values_);
+ }
+
+ bool MatchAndExplain(
+ const Tensor& arg,
+ ::testing::MatchResultListener* listener) const override {
+ return arg.dtype() == expected_dtype_ && arg.shape() == expected_shape_ &&
+ TensorValuesToVector<T>(arg) == expected_values_;
+ }
+
+ private:
+ DataType expected_dtype_;
+ TensorShape expected_shape_;
+ std::vector<T> expected_values_;
+};
+
+// TensorMatcher can be used to compare a tensor against an expected
+// value type, shape, and the list of values.
+template <typename T>
+class TensorMatcher {
+ public:
+ explicit TensorMatcher(DataType expected_dtype, TensorShape expected_shape,
+ std::initializer_list<T> expected_values)
+ : expected_dtype_(expected_dtype),
+ expected_shape_(expected_shape),
+ expected_values_(expected_values.begin(), expected_values.end()) {}
+ // Intentionally allowed to be implicit.
+ operator ::testing::Matcher<const Tensor&>() const { // NOLINT
+ return ::testing::MakeMatcher(new TensorMatcherImpl<T>(
+ expected_dtype_, expected_shape_, expected_values_));
+ }
+
+ private:
+ DataType expected_dtype_;
+ TensorShape expected_shape_;
+ std::vector<T> expected_values_;
+};
+
+template <typename T>
+TensorMatcher<T> IsTensor(TensorShape expected_shape,
+ std::initializer_list<T> expected_values) {
+ return TensorMatcher<T>(internal::TypeTraits<T>::kDataType, expected_shape,
+ expected_values);
+}
+
+} // namespace fcp::aggregation
+
+#endif // FCP_AGGREGATION_TESTING_TESTING_H_
diff --git a/fcp/artifact_building/BUILD b/fcp/artifact_building/BUILD
new file mode 100644
index 0000000..a005503
--- /dev/null
+++ b/fcp/artifact_building/BUILD
@@ -0,0 +1,261 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+default_visibility = ["//fcp:internal"]
+
+py_library(
+ name = "artifact_constants",
+ srcs = ["artifact_constants.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+)
+
+py_library(
+ name = "checkpoint_type",
+ srcs = ["checkpoint_type.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+)
+
+py_library(
+ name = "checkpoint_utils",
+ srcs = ["checkpoint_utils.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":artifact_constants",
+ ":tensor_utils",
+ ":type_checks",
+ ":variable_helpers",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_test(
+ name = "checkpoint_utils_test",
+ srcs = [
+ "checkpoint_utils_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":checkpoint_utils",
+ "//fcp/protos:plan_py_pb2",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "data_spec",
+ srcs = ["data_spec.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":type_checks",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_test(
+ name = "data_spec_test",
+ srcs = [
+ "data_spec_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":data_spec",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "federated_compute_plan_builder",
+ srcs = ["federated_compute_plan_builder.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":artifact_constants",
+ ":checkpoint_type",
+ ":checkpoint_utils",
+ ":data_spec",
+ ":graph_helpers",
+ ":proto_helpers",
+ ":tensor_utils",
+ ":type_checks",
+ ":variable_helpers",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/tensorflow:append_slices_py",
+ "//fcp/tensorflow:delete_file_py",
+ ],
+)
+
+py_library(
+ name = "graph_helpers",
+ srcs = ["graph_helpers.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":data_spec",
+ ":tensor_utils",
+ ":type_checks",
+ "//fcp/tensorflow:external_dataset_py",
+ ],
+)
+
+py_test(
+ name = "graph_helpers_test",
+ size = "small",
+ srcs = ["graph_helpers_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":data_spec",
+ ":graph_helpers",
+ ":variable_helpers",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "plan_utils",
+ srcs = ["plan_utils.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":tensor_utils",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_test(
+ name = "plan_utils_test",
+ srcs = [
+ "plan_utils_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":checkpoint_utils",
+ ":plan_utils",
+ ":test_utils",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "proto_helpers",
+ srcs = ["proto_helpers.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":tensor_utils",
+ ":type_checks",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_test(
+ name = "proto_helpers_test",
+ size = "small",
+ srcs = ["proto_helpers_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":proto_helpers",
+ ":variable_helpers",
+ ],
+)
+
+py_library(
+ name = "tensor_utils",
+ srcs = ["tensor_utils.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+)
+
+py_test(
+ name = "tensor_utils_test",
+ srcs = [
+ "tensor_utils_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":tensor_utils",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "test_utils",
+ srcs = ["test_utils.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = ["//fcp/protos:plan_py_pb2"],
+)
+
+py_test(
+ name = "test_utils_test",
+ srcs = [
+ "test_utils_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":checkpoint_utils",
+ ":test_utils",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "type_checks",
+ srcs = ["type_checks.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+)
+
+py_test(
+ name = "type_checks_test",
+ srcs = [
+ "type_checks_test.py",
+ ],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [":type_checks"],
+)
+
+py_library(
+ name = "variable_helpers",
+ srcs = ["variable_helpers.py"],
+ srcs_version = "PY3",
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":tensor_utils",
+ ":type_checks",
+ ],
+)
+
+py_test(
+ name = "variable_helpers_test",
+ size = "small",
+ srcs = ["variable_helpers_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":artifact_constants",
+ ":variable_helpers",
+ ],
+)
diff --git a/fcp/artifact_building/artifact_constants.py b/fcp/artifact_building/artifact_constants.py
new file mode 100644
index 0000000..6729844
--- /dev/null
+++ b/fcp/artifact_building/artifact_constants.py
@@ -0,0 +1,34 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Constants used throughout artifact building."""
+
+# These constants are required for legacy execution and harmless for federated
+# programs. They may be removed in the future.
+SERVER_STATE_VAR_PREFIX = 'server'
+SERVER_METRICS_VAR_PREFIX = 'metrics'
+
+# The name given to variables part of the client 'update' name space.
+UPDATE = 'update'
+
+# Indices into DistributeAggregateForm.client_to_server_aggregation parameter.
+INTERMEDIATE_STATE_INDEX = 0
+CLIENT_CHECKPOINT_INDEX = 1
+# This map is used in the construction of the aggregation portion of the Plan
+# proto to ensure that the names of the input tensor names for the aggregation
+# logic match the names of the corresponding output tensors generated by prior
+# stages of the computation.
+AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT = {
+ INTERMEDIATE_STATE_INDEX: 'intermediate_state',
+ CLIENT_CHECKPOINT_INDEX: 'update',
+}
diff --git a/fcp/artifact_building/checkpoint_type.py b/fcp/artifact_building/checkpoint_type.py
new file mode 100644
index 0000000..97cfd92
--- /dev/null
+++ b/fcp/artifact_building/checkpoint_type.py
@@ -0,0 +1,29 @@
+"""A module holding enum determining checkpoint format type.
+
+This FCP module should be kept lightweight as it is used for enum flag
+definition purposes.
+"""
+
+import enum
+
+
+@enum.unique
+class CheckpointFormatType(enum.Enum):
+ """Option adjusting checkpoint format between client and server.
+
+ Values:
+ TF1_SAVE_SLICES: The default value. Uses a standard TFv1 format.
+ APPEND_SLICES_MERGE_WRITE: Experimental value allowing to stream data to
+ checkpoint. The conversion from stream of checkpoint slices to TFv1 format
+ happens right after all the chunks are written. So the transport format is
+ identical to the TF1_SAVE_SLICES option.
+ APPEND_SLICES_MERGE_READ: Experimental value allowing to stream data to
+ checkpoint. The conversion from stream of checkpoint slices to TFv1 format
+ happens before the data is read. So the transport format is the
+ appended slices format. This setting has the smallest write memory
+ overhead.
+ """
+
+ TF1_SAVE_SLICES = 'tf1_save_slices'
+ APPEND_SLICES_MERGE_WRITE = 'append_slices_merge_write'
+ APPEND_SLICES_MERGE_READ = 'append_slices_merge_read'
diff --git a/fcp/artifact_building/checkpoint_utils.py b/fcp/artifact_building/checkpoint_utils.py
new file mode 100644
index 0000000..935fed5
--- /dev/null
+++ b/fcp/artifact_building/checkpoint_utils.py
@@ -0,0 +1,520 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper methods for working with demo server checkpoints."""
+
+import collections
+from collections.abc import Callable, Iterable, Mapping
+from typing import Any, Optional, Union
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import artifact_constants
+from fcp.artifact_building import tensor_utils
+from fcp.artifact_building import type_checks
+from fcp.artifact_building import variable_helpers
+from fcp.protos import plan_pb2
+
+SAVE_SERVER_SAVEPOINT_NAME = 'save_server_savepoint'
+
+
+def create_server_checkpoint_vars_and_savepoint(
+ *,
+ server_state_type: tff.StructType,
+ server_metrics_type: Optional[tff.StructType] = None,
+ write_metrics_to_checkpoint: bool = True,
+ additional_checkpoint_metadata_var_fn: Optional[
+ Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
+ ] = None,
+) -> tuple[
+ list[tf.Variable],
+ list[tf.Variable],
+ list[tf.Variable],
+ plan_pb2.CheckpointOp,
+]:
+ """Creates tf.Variables for a server checkpoint and the associated savepoint.
+
+ The variables and the associated saver are constructed in the default graph.
+
+ For now, only `server_state_type` is required. If metrics are to be saved in
+ the server checkpoint, `server_metrics_type` and `server_result_type` must
+ be provided. `server_state_type` refers to the server state portion of the
+ checkpoint and is used in the `Restore` op of the savepoint. The
+ `server_metrics_type` refers to the metrics saved in the checkpoint, and is
+ not used in the `Restore` op of the savepoint. `server_result_type` refers to
+ the complete round result structure stored in the checkpoint for a round.
+
+ Args:
+ server_state_type: A `tff.Type` with the type signature of the state. This
+ is used to construct the server state variable names stored in the
+ checkpoint and is used to create the metadata variables for the checkpoint
+ if `server_result_type` is not provided.
+ server_metrics_type: Optional. A `tff.Type` with the type signature of the
+ metrics. If provided, this is used to construct the metric variable names
+ that are stored in the checkpoint.
+ write_metrics_to_checkpoint: If False, revert to legacy behavior where
+ metrics and other non-state values were handled by post-processing
+ separate from the outputted checkpoint.
+ additional_checkpoint_metadata_var_fn: An optional method that takes in the
+ server_state_type, server_metrics_type, and write_metrics_to_checkpoint to
+ produce additional metadata variables.
+
+ Returns:
+ A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`:
+ - `state_vars` is a Python `list` of variables that hold the state.
+ - `metric_vars` is a Python `list` of variables that hold the metrics.
+ - `metadata_vars` is a Python `list` of variables that hold optional
+ metadata.
+ - `savepoint` is the associated savepoint, i.e., an instance of
+ `plan_pb2.CheckpointOp` with a saver configured for saving the
+ `state_vars`, `metadata_vars`, and, if write_metrics_to_checkpoint is
+ True, `metric_vars`, and restoring the `state_vars` and
+ `metadata_vars`.
+ """
+ has_metrics = False
+ metric_vars = []
+ save_tensor_name = None
+ type_checks.check_type(server_state_type, tff.Type, name='server_state_type')
+ state_vars = variable_helpers.create_vars_for_tff_type(
+ server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX
+ )
+ var_names = list(map(tensor_utils.bare_name, state_vars))
+ metadata_vars = []
+ if server_metrics_type is not None:
+ type_checks.check_type(
+ server_metrics_type, tff.Type, name='server_metrics_type'
+ )
+ metric_vars = variable_helpers.create_vars_for_tff_type(
+ server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX
+ )
+ if additional_checkpoint_metadata_var_fn:
+ metadata_vars = additional_checkpoint_metadata_var_fn(
+ state_vars, metric_vars, write_metrics_to_checkpoint
+ )
+
+ has_metrics = bool(tff.structure.flatten(server_metrics_type))
+ if has_metrics and write_metrics_to_checkpoint:
+ var_names.extend(list(map(tensor_utils.bare_name, metric_vars)))
+
+ temp_saver_for_all_vars = create_deterministic_saver(
+ var_list=state_vars + metadata_vars + metric_vars,
+ name=SAVE_SERVER_SAVEPOINT_NAME,
+ )
+ temp_saver_def = temp_saver_for_all_vars.as_saver_def()
+ save_tensor_name = temp_saver_def.save_tensor_name
+ else:
+ if additional_checkpoint_metadata_var_fn:
+ metadata_vars = additional_checkpoint_metadata_var_fn(
+ state_vars, None, write_metrics_to_checkpoint
+ )
+
+ saver = create_deterministic_saver(
+ var_list=state_vars + metadata_vars,
+ name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX),
+ )
+ savepoint = plan_pb2.CheckpointOp()
+ savepoint.saver_def.CopyFrom(saver.as_saver_def())
+
+ if save_tensor_name is not None:
+ # Replace the save_tensor_name to the one in
+ # temp_saver_for_all_vars so that we are additionally saving metrics vars
+ # in the checkpoint that don't need to be restored as part of the input
+ # computation state.
+ # Once we create the server GraphDef, we will edit the GraphDef directly
+ # to ensure the input filename links to the filename tensor from the
+ # `savepoint`.
+ savepoint.saver_def.save_tensor_name = save_tensor_name
+ return state_vars, metric_vars, metadata_vars, savepoint
+
+
+def create_state_vars_and_savepoint(
+ type_spec: variable_helpers.AllowedTffTypes, name: str
+) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]:
+ """Creates state variables and their savepoint as a `plan_pb2.CheckpointOp`.
+
+ The variables and the associated saver are constructed in the default graph.
+
+ Args:
+ type_spec: An instance of `tff.Type` with the type signature of the state.
+ name: The string to use as a basis for naming the vars and the saver. The
+ vars will be under `${name}_state`, and saver under `${name}_savepoint`.
+
+ Returns:
+ A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables
+ that hold the state, and `savepoint` is the associated savepoint, i.e.,
+ an instance of `plan_pb2.CheckpointOp` with a saver configured for saving
+ and restoring the `vars`.
+
+ Raises:
+ ValueError: If the name is empty.
+ """
+ state_vars, saver = create_state_vars_and_saver(type_spec, name)
+ savepoint = plan_pb2.CheckpointOp()
+ savepoint.saver_def.CopyFrom(saver.as_saver_def())
+ return state_vars, savepoint
+
+
+def create_state_vars_and_saver(
+ type_spec: variable_helpers.AllowedTffTypes, name: str
+) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]:
+ """Creates state variables and the associated saver.
+
+ The variables and the associated saver are constructed in the default graph.
+
+ Args:
+ type_spec: An instance of `tff.Type` with the type signature of the state.
+ name: The string to use as a basis for naming the vars and the saver. The
+ vars will be under `${name}_state`, and saver under `${name}_savepoint`.
+
+ Returns:
+ A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables
+ that hold the state, and `savepoint` is the associated
+ `tf.compat.v1.train.Saver`.
+
+ Raises:
+ ValueError: If the name is empty.
+ """
+ type_checks.check_type(type_spec, tff.Type, name='type_spec')
+ type_checks.check_type(name, str, name='name')
+ if not name:
+ raise ValueError('Name cannot be empty.')
+ state_vars = variable_helpers.create_vars_for_tff_type(type_spec, name)
+ saver = create_deterministic_saver(
+ state_vars, name='{}_savepoint'.format(name)
+ )
+ return state_vars, saver
+
+
+def restore_tensors_from_savepoint(
+ tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor
+) -> list[tf.Tensor]:
+ """Restores tensors from a checkpoint designated by a tensor filepath.
+
+ Args:
+ tensor_specs: A `list` of `tf.TensorSpec`s with the names and dtypes of the
+ tensors to restore.
+ filepath_tensor: A placeholder tensor that contains file names with a given
+ pattern.
+
+ Returns:
+ A list of restored tensors.
+ """
+ return [
+ tensor_utils.restore(
+ filepath_tensor, tensor_utils.bare_name(spec.name), spec.dtype
+ )
+ for spec in tensor_specs
+ ]
+
+
+def create_deterministic_saver(
+ var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]],
+ *args,
+ **kwargs,
+) -> tf.compat.v1.train.Saver:
+ """Creates a `tf.compat.v1.Saver` that is deterministic.
+
+ This method sorts the `var_list` to ensure a deterministic ordering which
+ in turn ensures a deterministic checkpoint.
+
+ Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints.
+
+ Args:
+ var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the
+ case of a `dict`, the keys become the names of the checkpoint variables
+ (rather than reading the names off the `tf.Variable` values).
+ *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver`
+ constructor.
+ **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver`
+ constructor.
+
+ Returns:
+ A `tf.compat.v1.train.Saver` instance.
+ """
+ if isinstance(var_list, collections.abc.Mapping):
+ determinisic_names = collections.OrderedDict(sorted(var_list.items()))
+ elif isinstance(var_list, collections.abc.Iterable):
+ determinisic_names = sorted(var_list, key=lambda v: v.name)
+ else:
+ raise ValueError(
+ 'Do not know how to make a deterministic saver for '
+ '`var_list` of type [{t}]. Must be a Mapping or Sequence'.format(
+ t=type(var_list)
+ )
+ )
+ return tf.compat.v1.train.Saver(
+ determinisic_names,
+ write_version=tf.compat.v1.train.SaverDef.V1,
+ *args,
+ **kwargs,
+ )
+
+
+def tff_type_to_dtype_list(
+ tff_type: variable_helpers.AllowedTffTypes,
+) -> list[tf.DType]:
+ """Creates a flat list of `tf.DType`s for tensors in a `tff.Type`.
+
+ Args:
+ tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
+ `tff.TensorType` object.
+
+ Returns:
+ A flat list of `tf.DType`s.
+ """
+ type_checks.check_type(
+ tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
+ )
+ if isinstance(tff_type, tff.TensorType):
+ return [tff_type.dtype]
+ elif isinstance(tff_type, tff.FederatedType):
+ return tff_type_to_dtype_list(tff_type.member)
+ else: # tff.StructType
+ elem_list = []
+ for elem_type in tff_type:
+ elem_list.extend(tff_type_to_dtype_list(elem_type))
+ return elem_list
+
+
+def tff_type_to_tensor_spec_list(
+ tff_type: variable_helpers.AllowedTffTypes,
+) -> list[tf.TensorSpec]:
+ """Creates a flat list of tensor specs for tensors in a `tff.Type`.
+
+ Args:
+ tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
+ `tff.TensorType` object.
+
+ Returns:
+ A flat list of `tf.TensorSpec`s.
+ """
+ type_checks.check_type(
+ tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
+ )
+ if isinstance(tff_type, tff.TensorType):
+ return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)]
+ elif isinstance(tff_type, tff.FederatedType):
+ return tff_type_to_tensor_spec_list(tff_type.member)
+ else: # tff.StructType
+ elem_list = []
+ for elem_type in tff_type:
+ elem_list.extend(tff_type_to_tensor_spec_list(elem_type))
+ return elem_list
+
+
+def pack_tff_value(
+ tff_type: variable_helpers.AllowedTffTypes, value_list: Any
+) -> Any:
+ """Packs a list of values into a shape specified by a `tff.Type`.
+
+ Args:
+ tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
+ `tff.TensorType` object.
+ value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`.
+
+ Returns:
+ A Python container with a structure consistent with a `tff.Type`.
+
+ Raises:
+ ValueError: If the number of leaves in `tff_type` does not match the length
+ of `value_list`, or `tff_type` is of a disallowed type.
+ """
+ type_checks.check_type(
+ tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
+ )
+
+ # We must "unwrap" any FederatedTypes because the
+ # `tff.structure.pack_sequence_as` call below will fail to recurse into them.
+ # Instead, we remove all the FederatedTypes, because we're only trying to
+ # build up a Python tree structure that matches the struct/tensor types from a
+ # list of values.
+ def remove_federated_types(
+ type_spec: tff.Type,
+ ) -> Union[tff.StructType, tff.TensorType]:
+ """Removes `FederatedType` from a type tree, returning a new tree."""
+ if type_spec.is_tensor():
+ return type_spec
+ elif type_spec.is_federated():
+ return type_spec.member
+ elif type_spec.is_struct():
+ return tff.StructType(
+ (elem_name, remove_federated_types(elem_type))
+ for elem_name, elem_type in tff.structure.iter_elements(type_spec)
+ )
+ else:
+ raise ValueError(
+ 'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.'
+ f' Got a {type(type_spec)}'
+ )
+
+ try:
+ tff_type = remove_federated_types(tff_type)
+ except ValueError as e:
+ raise ValueError(
+ '`tff_type` is not packable, see earlier error. '
+ f'Attempted to pack type: {tff_type}'
+ ) from e
+
+ ordered_dtypes = tff_type_to_dtype_list(tff_type)
+ if len(ordered_dtypes) != len(value_list):
+ raise ValueError(
+ 'The number of leaves in `tff_type` must equals the length'
+ ' of `value_list`. Found `tff_type` with'
+ f' {len(ordered_dtypes)} leaves and `value_list` of length'
+ f' {len(value_list)}.'
+ )
+
+ if tff_type.is_tensor():
+ return value_list[0]
+ elif tff_type.is_struct():
+ return tff.structure.pack_sequence_as(tff_type, value_list)
+ else:
+ raise ValueError(
+ '`tff_type` must be either tff.TensorType or '
+ 'tff.StructType, reaching here is an internal coding '
+ 'error, please file a bug.'
+ )
+
+
+def variable_names_from_structure(
+ tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v'
+) -> list[str]:
+ """Creates a flattened list of variable names for the given structure.
+
+ If the `tff_structure` is a `tf.Tensor`, the name is the `name` parameter if
+ specified, otheriwse a default name: `v`. If `tff_structure` is a
+ `tff.structure.Struct` then '/' is used between inner and outer fields
+ together with the tuple name or index of the element in the tuple.
+
+ Some examples:
+ 1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>`
+ and name is not specified, the returned variable name list is
+ ['v/a', 'v/b'].
+ 2. If the `tff_structure` is `<None=tf.constant(1.0), None=tf.constant(0.0)>`
+ and `name` is `update`, the returned variable name list is
+ ['update/0', 'update/1'].
+ 3. If the `tff_structure` is
+ `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is
+ `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
+ 4. If the `tff_structure` is
+ `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and
+ `name` is `update`, the returned variable name list is ['update/a/b',
+ 'update/a/c', 'update/a/2'].
+
+ Args:
+ tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object.
+ name: The preferred name to use at the top-most level (if not None, must be
+ a string). If `tff_structure` is a `tff.structure.Struct`, the names of
+ the inner fields will be scoped under `name`, e.g. `some_name/field_name`.
+
+ Returns:
+ A flat Python `list` of `str` names.
+
+ Raises:
+ TypeError: If either argument is of the wrong type.
+ """
+ type_checks.check_type(
+ tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type'
+ )
+ type_checks.check_type(name, str, name='name')
+ if isinstance(tff_structure, tf.Tensor):
+ return [name]
+ elif isinstance(tff_structure, tff.structure.Struct):
+ result = []
+ fields = tff.structure.iter_elements(tff_structure)
+ for index, (field_name, field_type) in enumerate(fields):
+ # Default the name of the element to its index so that we don't wind up
+ # with multiple child fields listed under `/v/`
+ field_name = field_name or str(index)
+ result.extend(
+ variable_names_from_structure(
+ field_type, name=name + '/' + field_name
+ )
+ )
+ return result
+ else:
+ raise TypeError(
+ 'Cannot create variable names from [{t}] type. Short-hand: {s}'.format(
+ t=type(tff_structure), s=tff_structure
+ )
+ )
+
+
+def is_structure_of_allowed_types(
+ structure: Union[
+ tff.structure.Struct,
+ tf.Tensor,
+ np.ndarray,
+ np.number,
+ int,
+ float,
+ str,
+ bytes,
+ ]
+) -> bool:
+ """Checks if each node in `structure` is an allowed type for serialization."""
+ flattened_structure = tff.structure.flatten(structure)
+ for item in flattened_structure:
+ if not (
+ tf.is_tensor(item)
+ or isinstance(item, (np.ndarray, np.number, int, float, str, bytes))
+ ):
+ return False
+ return True
+
+
+def save_tff_structure_to_checkpoint(
+ tff_structure: Union[tff.structure.Struct, tf.Tensor],
+ ordered_var_names: list[str],
+ output_checkpoint_path: str,
+) -> None:
+ """Saves a TFF structure to a checkpoint file.
+
+ The input `tff_structure` is a either `tff.structure.Struct` or a single
+ `tf.Tensor`. This function saves `tff_structure` to a checkpoint file using
+ variable names supplied via the `ordered_var_names` argument.
+
+ Args:
+ tff_structure: A `tff.structure.Struct` of values or a single value. Each
+ leaf in the structure must be a value serializable to a TensorFlow
+ checkpoint.
+ ordered_var_names: The list of variable names for the values that appear in
+ `tff_structure` after calling `tff.structure.flatten()`.
+ output_checkpoint_path: A string specifying the path to the output
+ checkpoint file.
+
+ Raises:
+ TypeError: If not all leaves in `tff_structure` are of allowed types.
+ ValueError: If the number of `tf.Tensor`s in `tff_structure` does not match
+ the size of `ordered_var_names`.
+ """
+ if not is_structure_of_allowed_types(tff_structure):
+ raise TypeError(
+ 'Not all leaves in `tff_structure` are `tf.Tensor`s, '
+ '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
+ f'{tff.structure.map_structure(type, tff_structure)!r})'
+ )
+
+ tensors = tff.structure.flatten(tff_structure)
+ if len(tensors) != len(ordered_var_names):
+ raise ValueError(
+ 'The length of `ordered_var_names` does not match the '
+ 'number of tensors in `tff_structure`:'
+ f'{len(ordered_var_names)} != {len(tensors)}'
+ )
+
+ tensor_utils.save(
+ output_checkpoint_path, tensor_names=ordered_var_names, tensors=tensors
+ )
diff --git a/fcp/artifact_building/checkpoint_utils_test.py b/fcp/artifact_building/checkpoint_utils_test.py
new file mode 100644
index 0000000..8450bd9
--- /dev/null
+++ b/fcp/artifact_building/checkpoint_utils_test.py
@@ -0,0 +1,364 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for checkpoint_utils."""
+
+import collections
+import os
+import typing
+from typing import Any
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from google.protobuf import any_pb2
+from fcp.artifact_building import checkpoint_utils
+from fcp.protos import plan_pb2
+
+
+class CheckpointUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+ def _assert_variable_functionality(
+ self, test_vars: list[tf.Variable], test_value_to_save: Any = 10
+ ):
+ self.assertIsInstance(test_vars, list)
+ initializer = tf.compat.v1.global_variables_initializer()
+ for test_variable in test_vars:
+ with self.test_session() as session:
+ session.run(initializer)
+ session.run(test_variable.assign(test_value_to_save))
+ self.assertEqual(session.run(test_variable), test_value_to_save)
+
+ def test_create_server_checkpoint_vars_and_savepoint_succeeds_state_vars(
+ self,
+ ):
+ with tf.Graph().as_default():
+ state_vars, _, _, savepoint = (
+ checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
+ server_state_type=tff.to_type([('foo1', tf.int32)]),
+ server_metrics_type=tff.to_type([('bar2', tf.int32)]),
+ write_metrics_to_checkpoint=True,
+ )
+ )
+ self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
+ self._assert_variable_functionality(state_vars)
+
+ def test_create_server_checkpoint_vars_and_savepoint_succeeds_metadata_vars(
+ self,
+ ):
+ def additional_checkpoint_metadata_var_fn(
+ state_vars, metrics_vars, write_metrics_to_checkpoint
+ ):
+ del state_vars, metrics_vars, write_metrics_to_checkpoint
+ return [tf.Variable(initial_value=b'dog', name='metadata')]
+
+ with tf.Graph().as_default():
+ _, _, metadata_vars, savepoint = (
+ checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
+ server_state_type=tff.to_type([('foo3', tf.int32)]),
+ server_metrics_type=tff.to_type([('bar1', tf.int32)]),
+ additional_checkpoint_metadata_var_fn=(
+ additional_checkpoint_metadata_var_fn
+ ),
+ write_metrics_to_checkpoint=True,
+ )
+ )
+ self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
+ self._assert_variable_functionality(
+ metadata_vars, test_value_to_save=b'cat'
+ )
+
+ def test_create_server_checkpoint_vars_and_savepoint_succeeds_metrics_vars(
+ self,
+ ):
+ with tf.Graph().as_default():
+ _, metrics_vars, _, savepoint = (
+ checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
+ server_state_type=tff.to_type([('foo2', tf.int32)]),
+ server_metrics_type=tff.to_type([('bar3', tf.int32)]),
+ write_metrics_to_checkpoint=True,
+ )
+ )
+ self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
+ self._assert_variable_functionality(metrics_vars)
+
+ def test_tff_type_to_dtype_list_as_expected(self):
+ tff_type = tff.FederatedType(
+ tff.StructType([('foo', tf.int32), ('bar', tf.string)]), tff.SERVER
+ )
+ expected_dtype_list = [tf.int32, tf.string]
+ self.assertEqual(
+ checkpoint_utils.tff_type_to_dtype_list(tff_type), expected_dtype_list
+ )
+
+ def test_tff_type_to_dtype_list_type_error(self):
+ list_type = [tf.int32, tf.string]
+ with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
+ checkpoint_utils.tff_type_to_dtype_list(list_type)
+
+ def test_tff_type_to_tensor_spec_list_as_expected(self):
+ tff_type = tff.FederatedType(
+ tff.StructType(
+ [('foo', tf.int32), ('bar', tff.TensorType(tf.string, shape=[1]))]
+ ),
+ tff.SERVER,
+ )
+ expected_tensor_spec_list = [
+ tf.TensorSpec([], tf.int32),
+ tf.TensorSpec([1], tf.string),
+ ]
+ self.assertEqual(
+ checkpoint_utils.tff_type_to_tensor_spec_list(tff_type),
+ expected_tensor_spec_list,
+ )
+
+ def test_tff_type_to_tensor_spec_list_type_error(self):
+ list_type = [tf.int32, tf.string]
+ with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
+ checkpoint_utils.tff_type_to_tensor_spec_list(list_type)
+
+ def test_pack_tff_value_with_tensors_as_expected(self):
+ tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
+ value_list = [
+ tf.constant(1, dtype=tf.int32),
+ tf.constant('bla', dtype=tf.string),
+ ]
+ expected_packed_structure = tff.structure.Struct([
+ ('foo', tf.constant(1, dtype=tf.int32)),
+ ('bar', tf.constant('bla', dtype=tf.string)),
+ ])
+ self.assertEqual(
+ checkpoint_utils.pack_tff_value(tff_type, value_list),
+ expected_packed_structure,
+ )
+
+ def test_pack_tff_value_with_federated_server_tensors_as_expected(self):
+ # This test must create a type that has `StructType`s nested under the
+ # `FederatedType` to cover testing that tff.structure.pack_sequence_as
+ # package correctly descends through the entire type tree.
+ tff_type = tff.to_type(
+ collections.OrderedDict(
+ foo=tff.FederatedType(tf.int32, tff.SERVER),
+ # Some arbitrarily deep nesting to ensure full traversals.
+ bar=tff.FederatedType([(), ([tf.int32], tf.int32)], tff.SERVER),
+ )
+ )
+ value_list = [tf.constant(1), tf.constant(2), tf.constant(3)]
+ expected_packed_structure = tff.structure.from_container(
+ collections.OrderedDict(
+ foo=tf.constant(1), bar=[(), ([tf.constant(2)], tf.constant(3))]
+ ),
+ recursive=True,
+ )
+ self.assertEqual(
+ checkpoint_utils.pack_tff_value(tff_type, value_list),
+ expected_packed_structure,
+ )
+
+ def test_pack_tff_value_with_unmatched_input_sizes(self):
+ tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
+ value_list = [tf.constant(1, dtype=tf.int32)]
+ with self.assertRaises(ValueError):
+ checkpoint_utils.pack_tff_value(tff_type, value_list)
+
+ def test_pack_tff_value_with_tff_type_error(self):
+ @tff.federated_computation
+ def fed_comp():
+ return tff.federated_value(0, tff.SERVER)
+
+ tff_function_type = fed_comp.type_signature
+ value_list = [tf.constant(1, dtype=tf.int32)]
+ with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
+ checkpoint_utils.pack_tff_value(tff_function_type, value_list)
+
+ def test_variable_names_from_structure_with_tensor_and_no_name(self):
+ names = checkpoint_utils.variable_names_from_structure(tf.constant(1.0))
+ self.assertEqual(names, ['v'])
+
+ def test_variable_names_from_structure_with_tensor(self):
+ names = checkpoint_utils.variable_names_from_structure(
+ tf.constant(1.0), 'test_name'
+ )
+ self.assertEqual(names, ['test_name'])
+
+ def test_variable_names_from_structure_with_named_tuple_type_and_no_name(
+ self,
+ ):
+ names = checkpoint_utils.variable_names_from_structure(
+ tff.structure.Struct([
+ ('a', tf.constant(1.0)),
+ (
+ 'b',
+ tff.structure.Struct(
+ [('c', tf.constant(True)), ('d', tf.constant(0.0))]
+ ),
+ ),
+ ])
+ )
+ self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])
+
+ def test_variable_names_from_structure_with_named_struct(self):
+ names = checkpoint_utils.variable_names_from_structure(
+ tff.structure.Struct([
+ ('a', tf.constant(1.0)),
+ (
+ 'b',
+ tff.structure.Struct(
+ [('c', tf.constant(True)), ('d', tf.constant(0.0))]
+ ),
+ ),
+ ]),
+ 'test_name',
+ )
+ self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])
+
+ def test_variable_names_from_structure_with_named_tuple_type_no_name_field(
+ self,
+ ):
+ names = checkpoint_utils.variable_names_from_structure(
+ tff.structure.Struct([
+ (None, tf.constant(1.0)),
+ (
+ 'b',
+ tff.structure.Struct(
+ [(None, tf.constant(False)), ('d', tf.constant(0.0))]
+ ),
+ ),
+ ]),
+ 'test_name',
+ )
+ self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])
+
+ def test_save_tf_tensor_to_checkpoint_as_expected(self):
+ temp_dir = self.create_tempdir()
+ output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')
+
+ tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+
+ checkpoint_utils.save_tff_structure_to_checkpoint(
+ tensor, ['v'], output_checkpoint_path=output_checkpoint_path
+ )
+
+ reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
+ var_to_shape_map = reader.get_variable_to_shape_map()
+ self.assertLen(var_to_shape_map, 1)
+ self.assertIn('v', var_to_shape_map)
+ np.testing.assert_almost_equal(
+ [[1.0, 2.0], [3.0, 4.0]], reader.get_tensor('v')
+ )
+
+ def test_save_tff_struct_to_checkpoint_as_expected(self):
+ temp_dir = self.create_tempdir()
+ output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')
+
+ struct = tff.structure.Struct([
+ ('foo', tf.constant(1, dtype=tf.int32)),
+ ('bar', tf.constant('bla', dtype=tf.string)),
+ ])
+
+ checkpoint_utils.save_tff_structure_to_checkpoint(
+ struct,
+ ordered_var_names=['v/foo', 'v/bar'],
+ output_checkpoint_path=output_checkpoint_path,
+ )
+
+ reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
+ var_to_shape_map = reader.get_variable_to_shape_map()
+ self.assertLen(var_to_shape_map, 2)
+ self.assertIn('v/foo', var_to_shape_map)
+ self.assertIn('v/bar', var_to_shape_map)
+ self.assertEqual(1, reader.get_tensor('v/foo'))
+ self.assertEqual(b'bla', reader.get_tensor('v/bar'))
+
+ def test_save_tff_struct_to_checkpoint_fails_if_wrong_num_var_names(self):
+ temp_dir = self.create_tempdir()
+ output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')
+
+ struct = tff.structure.Struct([
+ ('foo', tf.constant(1, dtype=tf.int32)),
+ ('bar', tf.constant('bla', dtype=tf.string)),
+ ])
+
+ with self.assertRaisesRegex(ValueError, 'does not match the number'):
+ checkpoint_utils.save_tff_structure_to_checkpoint(
+ struct,
+ ordered_var_names=['v/foo'],
+ output_checkpoint_path=output_checkpoint_path,
+ )
+
+ @parameterized.named_parameters(
+ ('tf.tensor', tf.constant(1.0)),
+ ('ndarray', np.asarray([1.0, 2.0, 3.0])),
+ ('npnumber', np.float64(1.0)),
+ ('int', 1),
+ ('float', 1.0),
+ ('str', 'test'),
+ ('bytes', b'test'),
+ )
+ def test_is_allowed(self, structure):
+ self.assertTrue(checkpoint_utils.is_structure_of_allowed_types(structure))
+
+ @parameterized.named_parameters(
+ ('function', lambda x: x),
+ ('any_proto', any_pb2.Any()),
+ )
+ def test_is_not_allowed(self, structure):
+ self.assertFalse(checkpoint_utils.is_structure_of_allowed_types(structure))
+
+
+class CreateDeterministicSaverTest(tf.test.TestCase):
+
+ def test_failure_unknown_type(self):
+ with self.assertRaisesRegex(ValueError, 'Do not know how to make'):
+ # Using a cast in case the test is being run with static type checking.
+ checkpoint_utils.create_deterministic_saver(
+ typing.cast(list[tf.Variable], 0)
+ )
+
+ def test_creates_saver_for_list(self):
+ with tf.Graph().as_default() as g:
+ saver = checkpoint_utils.create_deterministic_saver([
+ tf.Variable(initial_value=1.0, name='z'),
+ tf.Variable(initial_value=2.0, name='x'),
+ tf.Variable(initial_value=3.0, name='y'),
+ ])
+ self.assertIsInstance(saver, tf.compat.v1.train.Saver)
+ test_filepath = self.create_tempfile().full_path
+ with tf.compat.v1.Session(graph=g) as sess:
+ sess.run(tf.compat.v1.global_variables_initializer())
+ saver.save(sess, save_path=test_filepath)
+ variable_specs = tf.train.list_variables(test_filepath)
+ self.assertEqual([('x', []), ('y', []), ('z', [])], variable_specs)
+
+ def test_creates_saver_for_dict(self):
+ with tf.Graph().as_default() as g:
+ saver = checkpoint_utils.create_deterministic_saver({
+ 'foo': tf.Variable(initial_value=1.0, name='z'),
+ 'baz': tf.Variable(initial_value=2.0, name='x'),
+ 'bar': tf.Variable(initial_value=3.0, name='y'),
+ })
+ self.assertIsInstance(saver, tf.compat.v1.train.Saver)
+ test_filepath = self.create_tempfile().full_path
+ with tf.compat.v1.Session(graph=g) as sess:
+ sess.run(tf.compat.v1.global_variables_initializer())
+ saver.save(sess, save_path=test_filepath)
+ variable_specs = tf.train.list_variables(test_filepath)
+ self.assertEqual([('bar', []), ('baz', []), ('foo', [])], variable_specs)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/artifact_building/data_spec.py b/fcp/artifact_building/data_spec.py
new file mode 100644
index 0000000..619ca5b
--- /dev/null
+++ b/fcp/artifact_building/data_spec.py
@@ -0,0 +1,150 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A class to specify on-device dataset inputs."""
+
+from collections.abc import Callable
+from typing import Any, Optional, Union
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import type_checks
+from fcp.protos import plan_pb2
+
+
+class DataSpec:
+ """A specification of a single dataset input."""
+
+ __slots__ = (
+ '_example_selector_proto',
+ '_preprocessing_fn',
+ '_preprocessing_comp',
+ '_fingerprint',
+ )
+
+ def __init__(
+ self,
+ example_selector_proto: plan_pb2.ExampleSelector,
+ preprocessing_fn: Optional[
+ Callable[[tf.data.Dataset], tf.data.Dataset]
+ ] = None,
+ ):
+ """Constructs a specification of a dataset input.
+
+ Args:
+ example_selector_proto: An instance of `plan_pb2.ExampleSelector` proto.
+ preprocessing_fn: A callable that accepts as an argument the raw input
+ `tf.data.Dataset` with `string`-serialized items, performs any desired
+ preprocessing such as deserialization, filtering, batching, and
+ formatting, and returns the transformed `tf.data.Dataset` as a result.
+ If preprocessing_fn is set to None, it is expected that any client data
+ preprocessing has already been incorporated into the `tff.Computation`
+ that this `DataSpec` is associated with.
+
+ Raises:
+ TypeError: If the types of the arguments are invalid.
+ """
+ type_checks.check_type(
+ example_selector_proto,
+ plan_pb2.ExampleSelector,
+ name='example_selector_proto',
+ )
+ if preprocessing_fn is not None:
+ type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')
+ self._example_selector_proto = example_selector_proto
+ self._preprocessing_fn = preprocessing_fn
+ # Set once self.preprocessing_comp is accessed, as we can't call
+ # tff.computation in __init__.
+ self._preprocessing_comp = None
+
+ @property
+ def example_selector_proto(self) -> plan_pb2.ExampleSelector:
+ return self._example_selector_proto
+
+ @property
+ def preprocessing_fn(
+ self,
+ ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
+ return self._preprocessing_fn
+
+ @property
+ def preprocessing_comp(self) -> tff.Computation:
+ """Returns the preprocessing computation for the input dataset."""
+ if self._preprocessing_comp is None:
+ if self.preprocessing_fn is None:
+ raise ValueError(
+ "DataSpec's preprocessing_fn is None so a "
+ 'preprocessing tff.Computation cannot be generated.'
+ )
+ self._preprocessing_comp = tff.tf_computation(
+ self.preprocessing_fn, tff.SequenceType(tf.string)
+ )
+ return self._preprocessing_comp
+
+ @property
+ def type_signature(self) -> tff.Type:
+ """Returns the type signature of the result of the preprocessing_comp.
+
+ Effectively the type or 'spec' of the parsed example from the example store
+ pointed at by `example_selector_proto`.
+ """
+ return self.preprocessing_comp.type_signature.result
+
+
+def is_data_spec_or_structure(x: Any) -> bool:
+ """Returns True iff `x` is either a `DataSpec` or a nested structure of it."""
+ if x is None:
+ return False
+ if isinstance(x, DataSpec):
+ return True
+ try:
+ x = tff.structure.from_container(x)
+ return all(
+ is_data_spec_or_structure(y) for _, y in tff.structure.to_elements(x)
+ )
+ except TypeError:
+ return False
+
+
+def check_data_spec_or_structure(x: Any, name: str):
+ """Raises error iff `x` is not a `DataSpec` or a nested structure of it."""
+ if not is_data_spec_or_structure(x):
+ raise TypeError(
+ f'Expected `{name}` to be a `DataSpec` or a nested '
+ f'structure of it, found {str(x)}.'
+ )
+
+
+NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]
+
+
+def generate_example_selector_bytes_list(ds: NestedDataSpec):
+ """Returns an ordered list of the bytes of each DataSpec's example selector.
+
+ The order aligns with the order of a struct given by
+ tff.structure.to_elements().
+
+ Args:
+ ds: A `NestedDataSpec`.
+ """
+ if isinstance(ds, DataSpec):
+ return [ds.example_selector_proto.SerializeToString()]
+ else:
+ ds = tff.structure.from_container(ds)
+ assert isinstance(ds, tff.structure.Struct)
+ data_spec_elements = tff.structure.to_elements(ds)
+ selector_bytes_list = []
+ for _, element in data_spec_elements:
+ selector_bytes_list.extend(generate_example_selector_bytes_list(element))
+ return selector_bytes_list
diff --git a/fcp/artifact_building/data_spec_test.py b/fcp/artifact_building/data_spec_test.py
new file mode 100644
index 0000000..c8d3538
--- /dev/null
+++ b/fcp/artifact_building/data_spec_test.py
@@ -0,0 +1,71 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for data_spec.py."""
+
+import collections
+
+from absl.testing import absltest
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import data_spec
+from fcp.protos import plan_pb2
+
+_TEST_EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(
+ collection_uri='app://fake_uri'
+)
+
+
+class DataSpecTest(absltest.TestCase):
+
+ def test_construction_with_valid_arguments(self):
+ preprocessing_fn = lambda ds: ds.batch(10)
+ ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn)
+ self.assertIs(ds.example_selector_proto, _TEST_EXAMPLE_SELECTOR)
+ self.assertIs(ds.preprocessing_fn, preprocessing_fn)
+
+ def test_is_data_spec_or_structure(self):
+ preprocessing_fn = lambda ds: ds.batch(10)
+ ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn)
+ self.assertTrue(data_spec.is_data_spec_or_structure(ds))
+ self.assertTrue(data_spec.is_data_spec_or_structure([ds, ds]))
+ self.assertTrue(data_spec.is_data_spec_or_structure({'a': ds}))
+ self.assertFalse(data_spec.is_data_spec_or_structure(10))
+ self.assertFalse(data_spec.is_data_spec_or_structure({'a': 10}))
+
+ def test_type_signature(self):
+ def parsing_fn(serialized_example):
+ parsing_dict = {
+ 'key': tf.io.FixedLenFeature(shape=[1], dtype=tf.int64),
+ }
+ parsed_example = tf.io.parse_example(serialized_example, parsing_dict)
+ return collections.OrderedDict([('key', parsed_example['key'])])
+
+ preprocessing_fn = lambda ds: ds.map(parsing_fn)
+ ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn)
+
+ expected_type = tff.SequenceType(
+ tff.types.to_type(
+ collections.OrderedDict(
+ [('key', tf.TensorSpec(shape=(1,), dtype=tf.int64))]
+ )
+ )
+ )
+ self.assertEqual(ds.type_signature, expected_type)
+
+
+if __name__ == '__main__':
+ tf.compat.v1.enable_v2_behavior()
+ absltest.main()
diff --git a/fcp/artifact_building/federated_compute_plan_builder.py b/fcp/artifact_building/federated_compute_plan_builder.py
new file mode 100644
index 0000000..9ed4778
--- /dev/null
+++ b/fcp/artifact_building/federated_compute_plan_builder.py
@@ -0,0 +1,1802 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A library responsible for building Federated Compute plans.
+
+This library builds TFF-backed plans, using the `MapReduceForm` object
+output by the TFF compiler pipeline.
+"""
+
+import collections
+from collections.abc import Callable, Iterable, Mapping, Sequence
+import enum
+from typing import Optional, TypeVar, Union
+
+import attr
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import artifact_constants
+from fcp.artifact_building import checkpoint_type
+from fcp.artifact_building import checkpoint_utils
+from fcp.artifact_building import data_spec
+from fcp.artifact_building import graph_helpers
+from fcp.artifact_building import proto_helpers
+from fcp.artifact_building import tensor_utils
+from fcp.artifact_building import type_checks
+from fcp.artifact_building import variable_helpers
+from fcp.protos import plan_pb2
+from fcp.tensorflow import append_slices
+from fcp.tensorflow import delete_file
+
+SECURE_SUM_BITWIDTH_URI = 'federated_secure_sum_bitwidth'
+SECURE_SUM_URI = 'federated_secure_sum'
+SECURE_MODULAR_SUM_URI = 'federated_secure_modular_sum'
+
+
+class SecureAggregationTensorShapeError(Exception):
+ """Error raised when secagg tensors do not have fully defined shape."""
+
+
+@enum.unique
+class ClientPlanType(enum.Enum):
+ """Option adjusting client plan type during plan building.
+
+ Values:
+ TENSORFLOW: The default value. Uses a TF client graph for client
+ computation.
+ EXAMPLE_QUERY: Uses an example query containing client computation logic in
+ the provided example selector(s).
+ """
+
+ TENSORFLOW = 'tensorflow'
+ EXAMPLE_QUERY = 'example_query'
+
+
+# A type representing a potentially nested struct of integers.
+IntStruct = Union[
+ int,
+ Mapping[str, Union['IntStruct', int]],
+ Sequence[Union['IntStruct', int]],
+]
+
+
+def _compute_secagg_parameters(
+ mrf: tff.backends.mapreduce.MapReduceForm,
+) -> tuple[IntStruct, IntStruct, IntStruct]:
+ """Executes the TensorFlow logic that computes the SecAgg parameters.
+
+ This function makes use of `mrf.secure_sum_bitwidth`,
+ `mrf.secure_sum_max_input`, and `mrf.secure_modular_sum_modulus` to derive
+ the parameters needed for the SecAgg protocol variants.
+
+ Args:
+ mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
+
+ Returns:
+ A 3-tuple of `bitwidth`, `max_input` and `moduli` structures of parameters
+ for the associated SecAgg variants.
+ """
+ type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
+ secagg_parameters = []
+ with tf.Graph().as_default() as g:
+ for name, computation in [
+ ('bitwidth', mrf.secure_sum_bitwidth),
+ ('max_input', mrf.secure_sum_max_input),
+ ('modulus', mrf.secure_modular_sum_modulus),
+ ]:
+ secagg_parameters.append(
+ graph_helpers.import_tensorflow(name, computation)
+ )
+ with tf.compat.v1.Session(graph=g) as sess:
+ flat_output = sess.run(fetches=tf.nest.flatten(secagg_parameters))
+ return tf.nest.pack_sequence_as(secagg_parameters, flat_output)
+
+
+# A side-channel through which one tensor is securely aggregated.
+@attr.s
+class SecAggSideChannel:
+ # The name of the tensor being aggregated in the client and server graphs.
+ tensor_name: str = attr.ib()
+ # A proto describing how the side-channel is to be aggregated.
+ side_channel_proto: plan_pb2.SideChannel = attr.ib()
+ # A placeholder tensor into which the sidechannel aggregation is filled.
+ placeholder: tf.Tensor = attr.ib()
+ # The variable to feed into the server graph.
+ update_var: tf.Variable = attr.ib()
+
+
+SecAggParam = TypeVar('SecAggParam')
+
+
+def _create_secagg_sidechannels(
+ intrinsic_name: str,
+ update_type: variable_helpers.AllowedTffTypes,
+ get_modulus_scheme: Callable[
+ [SecAggParam], plan_pb2.SideChannel.SecureAggregand
+ ],
+ params: list[SecAggParam],
+) -> list[SecAggSideChannel]:
+ """Returns `SecAggSideChannel`s for tensors aggregated with `intrinsic_name`.
+
+ This method also creates variables for the securely-aggregated tensors within
+ the current default graph using `create_vars_for_tff_type`.
+
+ Args:
+ intrinsic_name: The name of the intrinsic (e.g.
+ `federated_secure_sum_bitwidth`) with which the tensors in `update_type`
+ are being aggregated.
+ update_type: The TFF type representing a structure of all tensors being
+ aggregated with `intrinsic_name`.
+ get_modulus_scheme: A function which will get the modulus scheme being used.
+ This typically requires some additional per-tensor parameters which must
+ be supplied using `params`.
+ params: A list of arguments to pass to `set_modulus_scheme`. There must be
+ exactly one element in this list per tensor in `update_type`.
+
+ Returns:
+ A list of `SecAggSideChannel`s describing how to aggregate each tensor.
+ """
+ # For secure aggregation, we don't use a saver (but still store metadata in a
+ # CheckpointOp). Instead we create sidechannel tensors that get fed into the
+ # server graph.
+ update_vars = variable_helpers.create_vars_for_tff_type(
+ update_type, f'{intrinsic_name}_update'
+ )
+
+ # For tensors aggregated by secagg, we make sure the tensor names are aligned
+ # in both client and sever graph by getting the names from the same method.
+ tensor_names = variable_helpers.get_shared_secagg_tensor_names(
+ intrinsic_name, update_type
+ )
+ assert len(update_vars) == len(params) == len(tensor_names), (
+ 'The length of update_vars, params and tensor_names for'
+ f' {{intrinsic_name}} should be all equal, but found: {len(update_vars)},'
+ f' {len(params)}, and {len(tensor_names)}.'
+ )
+
+ results = []
+ for param, update_var, tensor_name in zip(params, update_vars, tensor_names):
+ secure_aggregand = get_modulus_scheme(param)
+ secure_aggregand.dimension.extend(
+ plan_pb2.SideChannel.SecureAggregand.Dimension(size=d.value)
+ for d in update_var.shape.dims
+ )
+ secure_aggregand.dtype = update_var.dtype.base_dtype.as_datatype_enum
+ placeholder = tf.compat.v1.placeholder(
+ update_var.dtype, update_var.get_shape()
+ )
+ side_channel_proto = plan_pb2.SideChannel(
+ secure_aggregand=secure_aggregand, restore_name=placeholder.name
+ )
+ results.append(
+ SecAggSideChannel(
+ tensor_name=tensor_name,
+ side_channel_proto=side_channel_proto,
+ placeholder=placeholder,
+ update_var=update_var,
+ )
+ )
+ return results
+
+
+def _read_secagg_update_from_sidechannel_into_vars(
+ *, # Require parameters to be named.
+ secagg_intermediate_update_vars: list[tf.Variable],
+ secure_sum_bitwidth_update_type: (variable_helpers.AllowedTffTypes),
+ bitwidths: list[int],
+ secure_sum_update_type: (variable_helpers.AllowedTffTypes),
+ max_inputs: list[int],
+ secure_modular_sum_update_type: (variable_helpers.AllowedTffTypes),
+ moduli: list[int],
+) -> plan_pb2.CheckpointOp:
+ """Creates the `read_secagg_update` op.
+
+ `read_secagg_update` is a `plan_pb2.CheckpointOp` and used to restore the
+ secagg tensors in server graph.
+
+ Args:
+ secagg_intermediate_update_vars: A list of variables to assign the
+ secagg_update_data in the `after_restore_op`.
+ secure_sum_bitwidth_update_type: The type of the tensors aggregated using
+ `bitwidth`-based secure sum.
+ bitwidths: The `bitwidth`s for the tensors that will be aggregated using
+ `bitwidth`-based secure summation.
+ secure_sum_update_type: The type of the tensors aggregated using
+ `max_input`-based secure sum.
+ max_inputs: The max_input`s for the tensors that will be aggregated using
+ `max_input`-based secure summation.
+ secure_modular_sum_update_type: The type of the tensors aggregated using
+ modular secure summation.
+ moduli: The `modulus`s for the tensors that will be aggregated using modular
+ secure summation.
+
+ Returns:
+ A `plan_pb2.CheckpointOp` which performs the `read_secagg_update`.
+ """
+ side_channels: list[SecAggSideChannel] = []
+
+ def _aggregand_for_bitwidth(bitwidth):
+ return plan_pb2.SideChannel.SecureAggregand(
+ quantized_input_bitwidth=bitwidth
+ )
+
+ side_channels += _create_secagg_sidechannels(
+ SECURE_SUM_BITWIDTH_URI,
+ secure_sum_bitwidth_update_type,
+ _aggregand_for_bitwidth,
+ bitwidths,
+ )
+
+ def _aggregand_for_max_input(max_input):
+ # Note the +1-- `max_input` is inclusive, so `base_modulus == max_input`
+ # would overflow maximum-valued inputs to zero.
+ base_modulus = max_input + 1
+ modulus_times_shard_size = (
+ plan_pb2.SideChannel.SecureAggregand.ModulusTimesShardSize(
+ base_modulus=base_modulus
+ )
+ )
+ return plan_pb2.SideChannel.SecureAggregand(
+ modulus_times_shard_size=modulus_times_shard_size
+ )
+
+ side_channels += _create_secagg_sidechannels(
+ SECURE_SUM_URI,
+ secure_sum_update_type,
+ _aggregand_for_max_input,
+ max_inputs,
+ )
+
+ def _aggregand_for_modulus(modulus):
+ fixed_modulus = plan_pb2.SideChannel.SecureAggregand.FixedModulus(
+ modulus=modulus
+ )
+ return plan_pb2.SideChannel.SecureAggregand(fixed_modulus=fixed_modulus)
+
+ side_channels += _create_secagg_sidechannels(
+ SECURE_MODULAR_SUM_URI,
+ secure_modular_sum_update_type,
+ _aggregand_for_modulus,
+ moduli,
+ )
+
+ # Operations assigning from sidechannel placeholders to update variables.
+ assign_placeholders_to_updates = []
+ # Operations assigning from update variables to the result variables.
+ assign_updates_to_intermediate = []
+ read_secagg_update = plan_pb2.CheckpointOp()
+ for intermediate_update_var, side_channel in zip(
+ secagg_intermediate_update_vars, side_channels
+ ):
+ assign_placeholders_to_updates.append(
+ side_channel.update_var.assign(side_channel.placeholder)
+ )
+ assign_updates_to_intermediate.append(
+ intermediate_update_var.assign(side_channel.update_var)
+ )
+ read_secagg_update.side_channel_tensors[side_channel.tensor_name].CopyFrom(
+ side_channel.side_channel_proto
+ )
+
+ read_secagg_update.before_restore_op = tf.group(
+ *(assign_placeholders_to_updates)
+ ).name
+ read_secagg_update.after_restore_op = tf.group(
+ *(assign_updates_to_intermediate)
+ ).name
+
+ return read_secagg_update
+
+
+def _merge_secagg_vars(
+ secure_sum_bitwidth_update_type: tff.Type,
+ secure_sum_update_type: tff.Type,
+ secure_modular_sum_update_type: tff.Type,
+ flattened_moduli: list[int],
+ variables: list[tf.Variable],
+ tensors: list[tf.Variable],
+) -> list[tf.Operation]:
+ """Generates a set of ops to `merge` secagg `tensors` into `variables`."""
+ if len(variables) != len(tensors):
+ raise ValueError(
+ 'Expected an equal number of variables and tensors, but found '
+ f'{len(variables)} variables and {len(tensors)} tensors.'
+ )
+ num_simple_add_vars = len(
+ tff.structure.flatten(
+ tff.to_type([
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ ])
+ )
+ )
+ num_modular_add_vars = len(
+ tff.structure.flatten(secure_modular_sum_update_type)
+ )
+ # There must be one variable and tensor for each tensor in the secure update
+ # types.
+ num_vars_from_types = num_simple_add_vars + num_modular_add_vars
+ if num_vars_from_types != len(variables):
+ raise ValueError(
+ 'Expected one variable for each leaf element of the secagg update, but '
+ f'found {len(variables)} variables and {num_vars_from_types} leaf '
+ 'elements in the following types:\n'
+ f'secure_sum_bitwidth_update_type: {secure_sum_bitwidth_update_type}\n'
+ f'secure_sum_update_type: {secure_sum_update_type}\n'
+ f'secure_modular_sum_update_type: {secure_modular_sum_update_type}\n'
+ )
+ if num_modular_add_vars != len(flattened_moduli):
+ raise ValueError(
+ 'Expected one modulus for each leaf element of the secure modular sum '
+ f'update type. Found {len(flattened_moduli)} moduli and '
+ f'{num_modular_add_vars} leaf elements in the secure modular sum '
+ f'update type:\n{secure_modular_sum_update_type}'
+ )
+ # Add `tensors` to `vars`, using simple addition for the first
+ # `num_secagg_simple_add_vars` variables and modular addition for the rest
+ # (those coming from `secure_modular_sum`).
+ ops = []
+ simple_add_vars = variables[:num_simple_add_vars]
+ simple_add_tensors = tensors[:num_simple_add_vars]
+ for variable, tensor in zip(simple_add_vars, simple_add_tensors):
+ ops.append(variable.assign_add(tensor))
+ modular_add_vars = variables[num_simple_add_vars:]
+ modular_add_tensors = tensors[num_simple_add_vars:]
+ for modulus, (variable, tensor) in zip(
+ flattened_moduli, zip(modular_add_vars, modular_add_tensors)
+ ):
+ new_sum = tf.math.add(variable.read_value(), tensor)
+ modular_sum = tf.math.floormod(new_sum, modulus)
+ ops.append(variable.assign(tf.reshape(modular_sum, tf.shape(variable))))
+ return ops
+
+
+def _build_server_graphs_from_distribute_aggregate_form(
+ daf: tff.backends.mapreduce.DistributeAggregateForm,
+ is_broadcast_empty: bool,
+ grappler_config: tf.compat.v1.ConfigProto,
+) -> tuple[
+ tf.compat.v1.GraphDef, tf.compat.v1.GraphDef, plan_pb2.ServerPhaseV2
+]:
+ """Generates the server plan components based on DistributeAggregateForm.
+
+ Derives the pre-broadcast, aggregation, and post-aggregation logical
+ components in the ServerPhaseV2 message that will be executed on the server.
+ The pre-broadcast and post-aggregation components are to be executed with a
+ single TF sess.run call using the corresponding GraphDef. The aggregation
+ component is to be executed natively (i.e. not using TensorFlow) according to
+ the aggregation messages contained in the ServerPhaseV2 message.
+
+ Args:
+ daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
+ is_broadcast_empty: A boolean indicating whether the broadcasted value from
+ the server is expected to be empty based on the DistributeAggregateForm,
+ in which case the server should broadcast a placeholder tf.int32 tensor as
+ empty checkpoints are not supported.
+ grappler_config: The config specifying Grappler optimizations for TFF-
+ generated graphs.
+
+ Returns:
+ A tuple containing the server_prepare GraphDef, the server_result GraphDef,
+ and the ServerPhaseV2 message.
+ """
+ # Generate the TensorFlow graph needed to execute the server_prepare step,
+ # including reading input checkpoints and writing output checkpoints.
+ server_prepare_input_tensors = []
+ server_prepare_target_nodes = []
+ with tf.Graph().as_default() as server_prepare_graph:
+ # Create the placeholders for the input and output filenames needed by
+ # the server_prepare step.
+ server_prepare_server_state_input_filepath_placeholder = (
+ tf.compat.v1.placeholder(
+ name='server_state_input_filepath', shape=(), dtype=tf.string
+ )
+ )
+ server_prepare_output_filepath_placeholder = tf.compat.v1.placeholder(
+ name='server_prepare_output_filepath', shape=(), dtype=tf.string
+ )
+ server_prepare_intermediate_state_output_filepath_placeholder = (
+ tf.compat.v1.placeholder(
+ name='server_intermediate_state_output_filepath',
+ shape=(),
+ dtype=tf.string,
+ )
+ )
+ server_prepare_input_tensors.extend([
+ server_prepare_server_state_input_filepath_placeholder,
+ server_prepare_output_filepath_placeholder,
+ server_prepare_intermediate_state_output_filepath_placeholder,
+ ])
+
+ # Restore the server state.
+ server_state_type = daf.server_prepare.type_signature.parameter
+ server_state_vars = variable_helpers.create_vars_for_tff_type(
+ server_state_type, name='server'
+ )
+ server_state_tensor_specs = tf.nest.map_structure(
+ variable_helpers.tensorspec_from_var, server_state_vars
+ )
+ server_state = checkpoint_utils.restore_tensors_from_savepoint(
+ server_state_tensor_specs,
+ server_prepare_server_state_input_filepath_placeholder,
+ )
+
+ # TODO(team): Add support for federated select slice generation.
+
+ # Perform the server_prepare step.
+ prepared_values, intermediate_state_values = (
+ graph_helpers.import_tensorflow(
+ 'server_prepare',
+ tff.framework.ConcreteComputation.from_building_block(
+ tff.backends.mapreduce.consolidate_and_extract_local_processing(
+ daf.server_prepare.to_building_block(), grappler_config
+ )
+ ),
+ server_state,
+ split_outputs=True,
+ )
+ )
+
+ # Create checkpoints storing the broadcast values and intermediate server
+ # state. If there is no broadcast value, create a checkpoint containing a
+ # placeholder tf.int32 constant since empty broadcasts are not supported.
+ # If there is no intermediate server state, don't create an intermediate
+ # server state checkpoint.
+ save_tensor_names = variable_helpers.variable_names_from_type(
+ daf.server_prepare.type_signature.result[0], name='client'
+ )
+ save_tensors = prepared_values
+ if is_broadcast_empty:
+ save_tensor_names = variable_helpers.variable_names_from_type(
+ tff.StructType([tf.int32]), name='client'
+ )
+ save_tensors = [tf.constant(0, tf.int32)]
+ prepared_values_save_op = tensor_utils.save(
+ filename=server_prepare_output_filepath_placeholder,
+ tensor_names=save_tensor_names,
+ tensors=save_tensors,
+ name='save_prepared_values_tensors',
+ )
+ server_prepare_target_nodes.append(prepared_values_save_op.name)
+
+ intermediate_state_empty = (
+ isinstance(daf.server_prepare.type_signature.result[1], tff.StructType)
+ and not daf.server_prepare.type_signature.result[1]
+ )
+ if not intermediate_state_empty:
+ intermediate_state_values_save_op = tensor_utils.save(
+ filename=server_prepare_intermediate_state_output_filepath_placeholder,
+ tensor_names=variable_helpers.variable_names_from_type(
+ daf.server_prepare.type_signature.result[1], 'intermediate_state'
+ ),
+ tensors=intermediate_state_values,
+ name='save_intermediate_state_values_tensors',
+ )
+ server_prepare_target_nodes.append(intermediate_state_values_save_op.name)
+
+ # Build aggregations.
+ # The client_to_server_aggregation computation is guaranteed to conform to
+ # a specific structure. It is a lambda computation whose result block contains
+ # locals that are exclusively aggregation-type intrinsics.
+ aggregations_bb = daf.client_to_server_aggregation.to_building_block()
+ aggregations_bb.check_lambda()
+ aggregations_bb.result.check_block() # pytype: disable=attribute-error
+
+ # Get lists of the TensorSpecProtos for the inputs and outputs of all
+ # intrinsic calls. These lists are formatted such that the ith entry
+ # represents the TensorSpecProtos for the ith intrinsic in the aggregation
+ # computation. Since intrinsics may have one or more args, the ith entry in
+ # the input TensorSpecProto list is itself a list, where the jth entry
+ # represents the TensorSpecProtos corresponding to the jth argument of the
+ # ith intrinsic.
+ grouped_input_tensor_specs = variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
+ aggregations_bb,
+ artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
+ )
+ grouped_output_tensor_specs = (
+ variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
+ aggregations_bb
+ )
+ )
+ assert len(grouped_input_tensor_specs) == len(grouped_output_tensor_specs)
+
+ intrinsic_uris = [
+ local_value.function.intrinsic_def().uri
+ for _, local_value in aggregations_bb.result.locals # pytype: disable=attribute-error
+ ]
+ assert len(intrinsic_uris) == len(grouped_output_tensor_specs)
+
+ # Each intrinsic input arg can be a struct or even a nested struct, which
+ # requires the intrinsic to be applied independently to each element (e.g. a
+ # tff.federated_sum call applied to a struct will result in a federated_sum
+ # aggregation message for each element of the struct). Note that elements of
+ # structs can themselves be multi-dimensional tensors. When an intrinsic call
+ # has multiple args with mismatching structure (e.g. a federated_weighted_mean
+ # intrinsic applied to a 2D struct value arg and scalar weight arg), some args
+ # will need to be "scaled up" via repetition to match the args with the
+ # "largest" structure.
+ aggregations = []
+ for intrinsic_index, (input_tensor_specs, output_tensor_specs) in enumerate(
+ zip(grouped_input_tensor_specs, grouped_output_tensor_specs)
+ ):
+ # Generate the aggregation messages for this intrinsic call.
+ max_input_struct_length = max([len(x) for x in input_tensor_specs])
+ max_struct_length = max(max_input_struct_length, len(output_tensor_specs))
+ for i in range(max_struct_length):
+ intrinsic_args = []
+ for j, _ in enumerate(input_tensor_specs):
+ # Scale up any "smaller" structure args by reusing their last element.
+ tensor_spec = input_tensor_specs[j][
+ min(i, len(input_tensor_specs[j]) - 1)
+ ]
+ if tensor_spec.name.startswith('update'):
+ intrinsic_args.append(
+ plan_pb2.ServerAggregationConfig.IntrinsicArg(
+ input_tensor=tensor_spec.experimental_as_proto()
+ )
+ )
+ else:
+ intrinsic_args.append(
+ plan_pb2.ServerAggregationConfig.IntrinsicArg(
+ state_tensor=tensor_spec.experimental_as_proto()
+ )
+ )
+ aggregations.append(
+ plan_pb2.ServerAggregationConfig(
+ intrinsic_uri=intrinsic_uris[intrinsic_index],
+ intrinsic_args=intrinsic_args,
+ # Scale up the output structure by reusing the last element if
+ # needed.
+ output_tensors=[
+ output_tensor_specs[
+ min(i, len(output_tensor_specs) - 1)
+ ].experimental_as_proto()
+ ],
+ )
+ )
+
+ # Generate the TensorFlow graph needed to execute the server_result step,
+ # including reading input checkpoints, writing output checkpoints, and
+ # generating output tensors.
+ server_result_input_tensors = []
+ server_result_output_tensors = []
+ server_result_target_nodes = []
+ with tf.Graph().as_default() as server_result_graph:
+ # Create the placeholders for the input and output filenames needed by
+ # the server_result step.
+ server_result_intermediate_state_input_filepath_placeholder = (
+ tf.compat.v1.placeholder(
+ name='server_intermediate_state_input_filepath',
+ shape=(),
+ dtype=tf.string,
+ )
+ )
+ server_result_aggregate_result_input_filepath_placeholder = (
+ tf.compat.v1.placeholder(
+ name='aggregate_result_input_filepath', shape=(), dtype=tf.string
+ )
+ )
+ server_result_server_state_output_filepath_placeholder = (
+ tf.compat.v1.placeholder(
+ name='server_state_output_filepath', shape=(), dtype=tf.string
+ )
+ )
+ server_result_input_tensors.extend([
+ server_result_intermediate_state_input_filepath_placeholder,
+ server_result_aggregate_result_input_filepath_placeholder,
+ server_result_server_state_output_filepath_placeholder,
+ ])
+
+ # Restore the intermediate server state.
+ intermediate_state = []
+ if not intermediate_state_empty:
+ intermediate_state_type = daf.server_result.type_signature.parameter[0]
+ intermediate_state_vars = variable_helpers.create_vars_for_tff_type(
+ intermediate_state_type, 'intermediate_state'
+ )
+ intermediate_state_tensor_specs = tf.nest.map_structure(
+ variable_helpers.tensorspec_from_var, intermediate_state_vars
+ )
+ intermediate_state = checkpoint_utils.restore_tensors_from_savepoint(
+ intermediate_state_tensor_specs,
+ server_result_intermediate_state_input_filepath_placeholder,
+ )
+
+ # Restore the aggregation results.
+ aggregate_result_type = tff.StructType(
+ [daf.server_result.type_signature.parameter[1]]
+ )
+ aggregate_result_vars = variable_helpers.create_vars_for_tff_type(
+ aggregate_result_type, 'intermediate_update'
+ )
+ aggregate_result_tensor_specs = tf.nest.map_structure(
+ variable_helpers.tensorspec_from_var, aggregate_result_vars
+ )
+ aggregate_result = checkpoint_utils.restore_tensors_from_savepoint(
+ aggregate_result_tensor_specs,
+ server_result_aggregate_result_input_filepath_placeholder,
+ )
+
+ # Perform the server_result step.
+ server_state_values, server_output_values = graph_helpers.import_tensorflow(
+ 'server_result',
+ tff.framework.ConcreteComputation.from_building_block(
+ tff.backends.mapreduce.consolidate_and_extract_local_processing(
+ daf.server_result.to_building_block(), grappler_config
+ )
+ ),
+ (intermediate_state, aggregate_result),
+ split_outputs=True,
+ )
+
+ # Create checkpoints storing the updated server state.
+ server_state_save_op = tensor_utils.save(
+ filename=server_result_server_state_output_filepath_placeholder,
+ tensor_names=variable_helpers.variable_names_from_type(
+ daf.server_result.type_signature.result[0], 'server'
+ ),
+ tensors=server_state_values,
+ name='save_server_state_tensors',
+ )
+ server_result_target_nodes.append(server_state_save_op.name)
+
+ # Generate the output TensorSpecProtos for the server metrics if some exist.
+ server_output_empty = (
+ isinstance(daf.server_result.type_signature.result[1], tff.StructType)
+ and not daf.server_result.type_signature.result[1]
+ )
+ if not server_output_empty:
+ metric_names = variable_helpers.variable_names_from_type(
+ daf.server_result.type_signature.result[1], 'server'
+ )
+ metric_tensors = [
+ tf.identity(tensor, name)
+ for tensor, name in zip(server_output_values, metric_names)
+ ]
+ for metric in metric_tensors:
+ server_result_output_tensors.append(
+ proto_helpers.make_tensor_spec_from_tensor(
+ metric
+ ).experimental_as_proto()
+ )
+
+ # Create the TensorflowSpec messages for the pre-broadcast (server_prepare)
+ # and post-aggregation (server_result) steps.
+ tensorflow_spec_prepare = plan_pb2.TensorflowSpec(
+ input_tensor_specs=[
+ proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
+ for t in server_prepare_input_tensors
+ ],
+ target_node_names=server_prepare_target_nodes,
+ )
+ tensorflow_spec_result = plan_pb2.TensorflowSpec(
+ input_tensor_specs=[
+ proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
+ for t in server_result_input_tensors
+ ],
+ output_tensor_specs=server_result_output_tensors,
+ target_node_names=server_result_target_nodes,
+ )
+
+ # Create the IORouter messages for the pre-broadcast (server_prepare) and
+ # post-aggregation (server_result) steps.
+ server_prepare_io_router = plan_pb2.ServerPrepareIORouter(
+ prepare_server_state_input_filepath_tensor_name=server_prepare_server_state_input_filepath_placeholder.name,
+ prepare_output_filepath_tensor_name=server_prepare_output_filepath_placeholder.name,
+ prepare_intermediate_state_output_filepath_tensor_name=server_prepare_intermediate_state_output_filepath_placeholder.name,
+ )
+ server_result_io_router = plan_pb2.ServerResultIORouter(
+ result_intermediate_state_input_filepath_tensor_name=server_result_intermediate_state_input_filepath_placeholder.name,
+ result_aggregate_result_input_filepath_tensor_name=server_result_aggregate_result_input_filepath_placeholder.name,
+ result_server_state_output_filepath_tensor_name=server_result_server_state_output_filepath_placeholder.name,
+ )
+
+ server_phase_v2 = plan_pb2.ServerPhaseV2(
+ tensorflow_spec_prepare=tensorflow_spec_prepare,
+ prepare_router=server_prepare_io_router,
+ aggregations=aggregations,
+ tensorflow_spec_result=tensorflow_spec_result,
+ result_router=server_result_io_router,
+ )
+
+ return (
+ server_prepare_graph.as_graph_def(),
+ server_result_graph.as_graph_def(),
+ server_phase_v2,
+ )
+
+
+def _build_server_graph(
+ mrf: tff.backends.mapreduce.MapReduceForm,
+ broadcast_tff_type: variable_helpers.AllowedTffTypes,
+ is_broadcast_empty: bool,
+ flattened_bitwidths: list[int],
+ flattened_max_inputs: list[int],
+ flattened_moduli: list[int],
+ write_metrics_to_checkpoint: bool = True,
+ additional_checkpoint_metadata_var_fn: Optional[
+ Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
+ ] = None,
+ experimental_client_update_format: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
+) -> tuple[
+ tf.compat.v1.GraphDef,
+ plan_pb2.CheckpointOp,
+ plan_pb2.ServerPhase,
+ list[tf.TensorSpec],
+]:
+ """Builds the `tf.Graph` that will run on the server.
+
+ Args:
+ mrf: A `MapReduceForm` object containing the different computations to
+ combine into a single server graph.
+ broadcast_tff_type: A `tff.Type` object that specifies the tensors in the
+ model that are broadcasted and aggregated.
+ is_broadcast_empty: boolean indicating whether the broadcasted value from
+ the server was initially empty.
+ flattened_bitwidths: The `bitwidth`s for the tensors that will be aggregated
+ using `bitwidth`-based secure summation.
+ flattened_max_inputs: The max_input`s for the tensors that will be
+ aggregated using `max_input`-based secure summation.
+ flattened_moduli: The `modulus`s for the tensors that will be aggregated
+ using modular secure summation.
+ write_metrics_to_checkpoint: If False, revert to legacy behavior where
+ metrics values were handled by post-processing separate from the outputted
+ checkpoint. Regardless, they will additionally continue to be written to
+ recordio and accumulator checkpoints as defined by the Plan proto.
+ additional_checkpoint_metadata_var_fn: An optional method that takes in a
+ server state type, a server metrics type, and a boolean determining
+ whether to revert to legacy metrics behavior to produce additional
+ metadata variables.
+ experimental_client_update_format: Determines how the client update will be
+ interpreted. The value has to match experimental_checkpoint_write argument
+ of the _build_client_graph_with_tensorflow_spec call.
+
+ Returns:
+ A `tuple` containing the following (in order):
+ - A server `tf.GraphDef`,
+ - A server checkpoint,
+ - A server phase proto message, and
+ - A list of `tf.TensorSpec`s for the broadcasted values.
+ """
+ (
+ simpleagg_update_type,
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ ) = mrf.work.type_signature.result
+ with tf.Graph().as_default() as server_graph:
+ # Creates all server-side variables and savepoints for both the coordinator
+ # and the intermediate aggregators.
+ # server_state_type will be a SERVER-placed federated type.
+ server_state_type, server_metrics_type = mrf.type_signature.result
+ assert server_state_type.is_federated(), server_state_type
+ assert server_state_type.placement == tff.SERVER, server_state_type
+ # server_metrics_type can be a tff.FederatedType or a structure containing
+ # tff.FederatedTypes.
+ if isinstance(server_metrics_type, tff.FederatedType):
+ # We need to check for server metrics without the placement so
+ # tff.structure.flatten works correctly.
+ has_server_metrics = bool(
+ tff.structure.flatten(server_metrics_type.member)
+ )
+ else:
+ has_server_metrics = bool(tff.structure.flatten(server_metrics_type))
+ if isinstance(server_metrics_type, tff.TensorType) or (
+ isinstance(server_metrics_type, tff.FederatedType)
+ and isinstance(server_metrics_type.member, tff.TensorType)
+ ):
+ # Single tensor; must be wrapped inside of a NamedTuple for proper
+ # variable initialization.
+ server_metrics_type = tff.StructType([server_metrics_type])
+ (
+ server_state_vars,
+ server_metrics_vars,
+ metadata_vars,
+ server_savepoint,
+ ) = checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
+ server_state_type=server_state_type,
+ server_metrics_type=server_metrics_type,
+ write_metrics_to_checkpoint=write_metrics_to_checkpoint,
+ additional_checkpoint_metadata_var_fn=(
+ additional_checkpoint_metadata_var_fn
+ ),
+ )
+
+ # TODO(team): Switch to `tf.save()` in lieu of savers to avoid the
+ # need to create client variables on the server.
+ client_vars_on_server, write_client = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ broadcast_tff_type, 'client'
+ )
+ )
+
+ secure_sum_update_types = [
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ ]
+ combined_intermediate_update_type = tff.StructType(
+ [mrf.zero.type_signature.result] + secure_sum_update_types
+ )
+
+ combined_intermediate_update_vars, write_intermediate_update = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ combined_intermediate_update_type, 'intermediate_update'
+ )
+ )
+ num_simpleagg_vars = len(combined_intermediate_update_vars) - len(
+ tff.structure.flatten(tff.to_type(secure_sum_update_types))
+ )
+ intermediate_update_vars = combined_intermediate_update_vars[
+ :num_simpleagg_vars
+ ]
+ secagg_intermediate_update_vars = combined_intermediate_update_vars[
+ num_simpleagg_vars:
+ ]
+
+ read_secagg_update = _read_secagg_update_from_sidechannel_into_vars(
+ secagg_intermediate_update_vars=secagg_intermediate_update_vars,
+ secure_sum_bitwidth_update_type=secure_sum_bitwidth_update_type,
+ bitwidths=flattened_bitwidths,
+ secure_sum_update_type=secure_sum_update_type,
+ max_inputs=flattened_max_inputs,
+ secure_modular_sum_update_type=secure_modular_sum_update_type,
+ moduli=flattened_moduli,
+ )
+
+ combined_aggregated_update_vars, write_accumulators = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ combined_intermediate_update_type, 'aggregated_update'
+ )
+ )
+ aggregated_update_vars = combined_aggregated_update_vars[
+ :num_simpleagg_vars
+ ]
+ secagg_aggregated_update_vars = combined_aggregated_update_vars[
+ num_simpleagg_vars:
+ ]
+
+ # Throws in the initializer for all state variables, to be executed prior
+ # to restoring the savepoint. Run this variable initializer prior to
+ # restoring from the savepoint to allow the vars to be overwritten by the
+ # savepoint in this case, and so they do not get re-executed after being
+ # overwritten. Also include the metrics vars here in case the execution
+ # environment wants to read those in.
+ server_vars_initializer = tf.compat.v1.variables_initializer(
+ server_state_vars + metadata_vars + server_metrics_vars,
+ 'initialize_server_state_and_non_state_vars',
+ )
+ server_savepoint.before_restore_op = server_vars_initializer.name
+
+ # In graph mode, TensorFlow does not allow creating a
+ # `tf.compat.v1.train.Saver` when there are no variables. As a result,
+ # calling `create_state_vars_and_savepoint` below will fail when there are
+ # no SimpleAgg variables (e.g., all results are aggregated via SecAgg). In
+ # this case, there are no client checkpoints, and hence, no need to populate
+ # the `read_update` field.
+ if num_simpleagg_vars > 0:
+ # Run the initializer for update vars prior to restoring the client update
+ update_vars, read_update = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ simpleagg_update_type, artifact_constants.UPDATE
+ )
+ )
+ update_vars_initializer = tf.compat.v1.variables_initializer(
+ update_vars, 'initialize_update_vars'
+ )
+ if (
+ experimental_client_update_format
+ == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ
+ ):
+ graph = tf.compat.v1.get_default_graph()
+ checkpoint_pl = graph.get_tensor_by_name(
+ read_update.saver_def.filename_tensor_name
+ )
+ merge_checkpoint_slices = append_slices.merge_appended_slices(
+ checkpoint_pl, 'merge_checkpoint_slices'
+ )
+ init_merge = tf.group(update_vars_initializer, merge_checkpoint_slices)
+ read_update.before_restore_op = init_merge.name
+ else:
+ read_update.before_restore_op = update_vars_initializer.name
+ else:
+ # Create a empty list for `update_vars` when there are no SimpleAgg
+ # variables, to be compatible with the `accumulated_values` defined below.
+ update_vars = []
+
+ # Copy the intermediate aggregator's update saver for use on coordinator.
+ read_intermediate_update = plan_pb2.CheckpointOp()
+ read_intermediate_update.CopyFrom(write_intermediate_update)
+
+ # Condition all the remaining logic on the variable initializers, since
+ # intermediate aggregators are supposed to be stateless (no savepoint, and
+ # therefore no `before_restore_op`, either).
+ with tf.control_dependencies(
+ [
+ tf.compat.v1.variables_initializer(
+ (intermediate_update_vars + aggregated_update_vars),
+ 'initialize_accumulator_vars',
+ )
+ ]
+ ):
+ # Embeds the `zero` logic and hooks it up to `after_restore_op` of
+ # server's checkpointed state (shared between the coordinator and the
+ # intermediate aggregators). The zeros get assigned to
+ # `intermediate_update_vars` and to the `aggregated_update_vars` at the
+ # very beginning, right after restoring from `server_savepoint`.
+ zero_values = graph_helpers.import_tensorflow('zero', mrf.zero)
+ assign_zero_ops = tf.nest.map_structure(
+ lambda variable, value: variable.assign(value),
+ intermediate_update_vars,
+ zero_values,
+ ) + tf.nest.map_structure(
+ lambda variable, value: variable.assign(value),
+ aggregated_update_vars,
+ zero_values,
+ )
+
+ # Embeds the `prepare` logic, and hooks it up to `before_save_op` of
+ # client state (to be checkpointed and sent to the clients at the
+ # beginning of the round by the central coordinator).
+ with tf.control_dependencies(
+ [
+ tf.compat.v1.variables_initializer(
+ client_vars_on_server, 'initialize_client_vars_on_server'
+ )
+ ]
+ ):
+ # Configure the session token for `write_client` so that the `prepare`
+ # operation may be fed the callback ID for the `SaveSlices` op
+ # (necessary for plans containing `federated_select`).
+ write_client_session_token = tf.compat.v1.placeholder_with_default(
+ input='', shape=(), name='write_client_session_token'
+ )
+ prepared_values = graph_helpers.import_tensorflow(
+ 'prepare',
+ mrf.prepare,
+ server_state_vars,
+ session_token_tensor=write_client_session_token,
+ )
+ if is_broadcast_empty:
+ # If the broadcast was empty, don't assigning the sample incoming
+ # tf.int32 to anything.
+ client_state_assign_ops = [tf.no_op()]
+ else:
+ client_state_assign_ops = tf.nest.map_structure(
+ lambda variable, tensor: variable.assign(tensor),
+ client_vars_on_server,
+ prepared_values,
+ )
+ write_client.before_save_op = tf.group(*client_state_assign_ops).name
+ write_client.session_token_tensor_name = write_client_session_token.name
+
+ # Embeds the `accumulate` logic, and hooks up the assignment of a client
+ # update to the intermediate update to `aggregate_into_accumulators_op`.
+ accumulated_values = graph_helpers.import_tensorflow(
+ 'accumulate', mrf.accumulate, (intermediate_update_vars, update_vars)
+ )
+ intermediate_update_assign_ops = tf.nest.map_structure(
+ lambda variable, tensor: variable.assign(tensor),
+ intermediate_update_vars,
+ accumulated_values,
+ )
+ aggregate_into_accumulators_op = tf.group(
+ *intermediate_update_assign_ops
+ ).name
+
+ secagg_aggregated_update_init = tf.compat.v1.variables_initializer(
+ secagg_aggregated_update_vars
+ )
+
+ # Reset the accumulators in `phase_init_op`, after variable initializers
+ # and after restoring from the savepoint.
+ phase_init_op = tf.group(
+ *(assign_zero_ops + [secagg_aggregated_update_init])
+ ).name
+
+ # Embeds the `merge` logic, and hooks up the assignment of an intermediate
+ # update to the top-level aggregate update at the coordinator to
+ # `intermediate_aggregate_into_accumulators_op`.
+ merged_values = graph_helpers.import_tensorflow(
+ 'merge', mrf.merge, (aggregated_update_vars, intermediate_update_vars)
+ )
+ aggregated_update_assign_ops = tf.nest.map_structure(
+ lambda variable, tensor: variable.assign(tensor),
+ aggregated_update_vars,
+ merged_values,
+ )
+
+ secagg_aggregated_update_ops = _merge_secagg_vars(
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ flattened_moduli,
+ secagg_aggregated_update_vars,
+ secagg_intermediate_update_vars,
+ )
+
+ intermediate_aggregate_into_accumulators_op = tf.group(
+ *(aggregated_update_assign_ops + secagg_aggregated_update_ops)
+ ).name
+
+ # Embeds the `report` and `update` logic, and hooks up the assignments of
+ # the results of the final update to the server state and metric vars, to
+ # be triggered by `apply_aggregrated_updates_op`.
+ simpleagg_reported_values = graph_helpers.import_tensorflow(
+ 'report', mrf.report, aggregated_update_vars
+ )
+
+ # NOTE: In MapReduceForm, the `update` method takes in the simpleagg vars
+ # and SecAgg vars as a tuple of two separate lists. However, here, as
+ # above, we concatenate the simpleagg values and the secure values into a
+ # single list. This mismatch is not a problem because the variables are all
+ # flattened either way when traveling in and out of the tensorflow graph.
+ combined_update_vars = (
+ simpleagg_reported_values + secagg_aggregated_update_vars
+ )
+ new_server_state_values, server_metrics_values = (
+ graph_helpers.import_tensorflow(
+ artifact_constants.UPDATE,
+ mrf.update,
+ (server_state_vars, combined_update_vars),
+ split_outputs=True,
+ )
+ )
+
+ assign_server_state_ops = tf.nest.map_structure(
+ lambda variable, tensor: variable.assign(tensor),
+ server_state_vars,
+ new_server_state_values,
+ )
+ assign_non_state_ops = tf.nest.map_structure(
+ lambda variable, value: variable.assign(value),
+ server_metrics_vars,
+ server_metrics_values,
+ )
+ all_assign_ops = assign_server_state_ops + assign_non_state_ops
+ apply_aggregrated_updates_op = tf.group(*all_assign_ops).name
+
+ # Constructs the metadata for server metrics to be included in the plan.
+ server_metrics = [
+ proto_helpers.make_metric(v, artifact_constants.SERVER_STATE_VAR_PREFIX)
+ for v in server_metrics_vars
+ ]
+
+ server_phase_kwargs = collections.OrderedDict(
+ phase_init_op=phase_init_op,
+ write_client_init=write_client,
+ read_aggregated_update=read_secagg_update,
+ write_intermediate_update=write_intermediate_update,
+ read_intermediate_update=read_intermediate_update,
+ intermediate_aggregate_into_accumulators_op=(
+ intermediate_aggregate_into_accumulators_op
+ ),
+ write_accumulators=write_accumulators,
+ apply_aggregrated_updates_op=apply_aggregrated_updates_op,
+ metrics=server_metrics,
+ )
+
+ if num_simpleagg_vars > 0:
+ # The `read_update` loads SimpleAgg updates from client checkpoints. The
+ # `aggregate_into_accumulators_op` aggregates SimpleAgg data after loading
+ # the client updates. No need to populate the two fields if there are no
+ # SimpleAgg variables (e.g., if all results are aggregated via SecAgg).
+ server_phase_kwargs['read_update'] = read_update
+ server_phase_kwargs['aggregate_into_accumulators_op'] = (
+ aggregate_into_accumulators_op
+ )
+
+ server_phase = plan_pb2.ServerPhase(**server_phase_kwargs)
+
+ broadcasted_tensor_specs = tf.nest.map_structure(
+ variable_helpers.tensorspec_from_var, client_vars_on_server
+ )
+ server_graph_def = server_graph.as_graph_def()
+
+ if write_metrics_to_checkpoint:
+ server_graph_def = _redirect_save_saver_to_restore_saver_placeholder(
+ server_graph_def
+ )
+
+ return (
+ server_graph_def,
+ server_savepoint,
+ server_phase,
+ broadcasted_tensor_specs,
+ )
+
+
+def _redirect_save_saver_to_restore_saver_placeholder(
+ graph_def: tf.compat.v1.GraphDef,
+) -> tf.compat.v1.GraphDef:
+ """Updates save Saver's savepoint to point to restore Saver's placeholder.
+
+ NOTE: mutates the GraphDef passed in and returns the mutated GraphDef.
+
+ When we created the server_savepoint Saver when we are outputting all of
+ the metrics to the output checkpoint as well, we set different nodes for
+ saving and restoring so that we could save state + metrics and restore
+ just state. However, the only way to do so was to make two Savers and
+ splice them together. This meant that the save and restore operations
+ depend on two different placeholders for the checkpoint filename. To
+ avoid server changes that pass the same checkpoint name in twice to both
+ placeholders, we make a few changes to the server GraphDef so that the
+ saving op connects back to the placeholder for the restore operation.
+ Once this is done, the original save placeholder node will still exist in
+ the graph, but it won't be used by any part of the graph that connects to
+ an operation we care about.
+
+ Args:
+ graph_def: A `tf.compat.v1.GraphDef` to mutate.
+
+ Returns:
+ The mutated `tf.compat.v1.GraphDef` that was passed in as graph_def.
+ """
+ old_const_node = f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/Const'
+ new_const_node = (
+ f'{artifact_constants.SERVER_STATE_VAR_PREFIX}_savepoint/Const'
+ )
+ nodes_to_change = [
+ f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/save',
+ f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/control_dependency',
+ f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/RestoreV2',
+ ]
+ num_changed_nodes = 0
+ for node in graph_def.node:
+ if node.name in nodes_to_change:
+ input_index = 0
+ for input_index, input_node in enumerate(node.input):
+ if input_node == old_const_node:
+ node.input[input_index] = new_const_node
+ break
+ assert input_index != len(
+ node.input
+ ), 'Missed input arg in saver GraphDef rewriting.'
+ num_changed_nodes = num_changed_nodes + 1
+ if num_changed_nodes == len(nodes_to_change):
+ # Once we've changed all of the callsites, we stop.
+ return graph_def
+ return graph_def
+
+
+def _build_client_graph_with_tensorflow_spec(
+ client_work_comp: tff.Computation,
+ dataspec,
+ broadcasted_tensor_specs: Iterable[tf.TensorSpec],
+ is_broadcast_empty: bool,
+ *,
+ experimental_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
+) -> tuple[tf.compat.v1.GraphDef, plan_pb2.ClientPhase]:
+ """Builds the client graph and ClientPhase with TensorflowSpec populated.
+
+ This function builds a client phase with tensorflow specs proto.
+
+ Args:
+ client_work_comp: A `tff.Computation` that represents the TensorFlow logic
+ run on-device.
+ dataspec: Either an instance of `data_spec.DataSpec` or a nested structure
+ of these that matches the structure of the first element of the input to
+ `client_work_comp`.
+ broadcasted_tensor_specs: A list of `tf.TensorSpec` containing the name and
+ dtype of the variables arriving via the broadcast checkpoint.
+ is_broadcast_empty: A boolean indicating whether the MapReduce form
+ initially called for an empty broadcast. In this case the
+ broadcasted_tensor_specs will contain a single tf.int32, but it will be
+ ignored.
+ experimental_checkpoint_write: Determines the format of the final client
+ update checkpoint. The value affects required operations and might have
+ performance implications.
+
+ Returns:
+ A `tuple` of the client TensorFlow GraphDef and the client phase protocol
+ message.
+
+ Raises:
+ SecureAggregationTensorShapeError: If SecAgg tensors do not have all
+ dimensions of their shape fully defined.
+ ValueError: If any of the arguments are found to be in an unexpected form.
+ """
+ if (
+ not isinstance(client_work_comp.type_signature.parameter, tff.StructType)
+ or len(client_work_comp.type_signature.parameter) < 1
+ ):
+ raise ValueError(
+ 'client_work_comp.type_signature.parameter should be a '
+ '`tff.StructType` with length >= 1, but found: {p}.'.format(
+ p=client_work_comp.type_signature.parameter
+ )
+ )
+
+ if (
+ not isinstance(client_work_comp.type_signature.result, tff.StructType)
+ or len(client_work_comp.type_signature.result) != 4
+ ):
+ raise ValueError(
+ 'client_work_comp.type_signature.result should be a '
+ '`tff.StructType` with length == 4, but found: {r}.'.format(
+ r=client_work_comp.type_signature.result
+ )
+ )
+
+ (
+ simpleagg_update_type,
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ ) = client_work_comp.type_signature.result
+
+ # A list of tensors that will be passed into TensorFlow, corresponding to
+ # `plan_pb2.ClientPhase.tensorflow_spec.input_tensor_specs`. Note that the
+ # dataset token is excluded from this list. In general, this list should
+ # include the filepath placeholder tensors for the input checkpoint file and
+ # output checkpoint file.
+ input_tensors = []
+
+ # A list of tensor specs that should be fetched from TensorFlow, corresponding
+ # to `plan_pb2.ClientPhase.tensorflow_spec.output_tensor_specs`. In general,
+ # this list should include the tensors that are not in the output checkpoint
+ # file, such as secure aggregation tensors.
+ output_tensor_specs = []
+
+ # A list of node names in the client graph that should be executed but no
+ # output returned, corresponding to
+ # `plan_pb2.ClientPhase.tensorflow_spec.target_node_names`. In general, this
+ # list should include the op that creates the output checkpoint file.
+ target_nodes = []
+ with tf.Graph().as_default() as client_graph:
+ input_filepath_placeholder = None
+ if not is_broadcast_empty:
+ input_filepath_placeholder = tf.compat.v1.placeholder(
+ name='input_filepath', shape=(), dtype=tf.string
+ )
+ weights_from_server = checkpoint_utils.restore_tensors_from_savepoint(
+ broadcasted_tensor_specs, input_filepath_placeholder
+ )
+ input_tensors.append(input_filepath_placeholder)
+ else:
+ weights_from_server = []
+
+ # Add the custom Dataset ops to the graph.
+ token_placeholder, data_values, example_selector_placeholders = (
+ graph_helpers.embed_data_logic(
+ client_work_comp.type_signature.parameter[0], dataspec
+ )
+ )
+
+ # Embed the graph coming from TFF into the client work graph.
+ combined_update_tensors = graph_helpers.import_tensorflow(
+ 'work',
+ client_work_comp,
+ (data_values, weights_from_server),
+ split_outputs=False,
+ session_token_tensor=token_placeholder,
+ ) # pytype: disable=wrong-arg-types
+
+ num_simpleagg_tensors = len(tff.structure.flatten(simpleagg_update_type))
+ simpleagg_tensors = combined_update_tensors[:num_simpleagg_tensors]
+ secagg_tensors = combined_update_tensors[num_simpleagg_tensors:]
+
+ # For tensors aggregated by secagg, we make sure the tensor names are
+ # aligned in both client and sever graph by getting the names from the same
+ # method.
+ secagg_tensor_names = []
+ secagg_tensor_types = []
+ for uri, update_type in [
+ (SECURE_SUM_BITWIDTH_URI, secure_sum_bitwidth_update_type),
+ (SECURE_SUM_URI, secure_sum_update_type),
+ (SECURE_MODULAR_SUM_URI, secure_modular_sum_update_type),
+ ]:
+ secagg_tensor_names += variable_helpers.get_shared_secagg_tensor_names(
+ uri, update_type
+ )
+ secagg_tensor_types += tff.structure.flatten(update_type)
+
+ secagg_tensors = [
+ tf.identity(tensor, name=tensor_utils.bare_name(name))
+ for tensor, name in zip(secagg_tensors, secagg_tensor_names)
+ ]
+ for t, type_spec in zip(secagg_tensors, secagg_tensor_types):
+ secagg_tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
+ t, shape_hint=type_spec.shape
+ )
+ output_tensor_specs.append(secagg_tensor_spec.experimental_as_proto())
+
+ # Verify that SecAgg Tensors have all dimension fully defined.
+ for tensor_spec in output_tensor_specs:
+ if not tf.TensorShape(tensor_spec.shape).is_fully_defined():
+ raise SecureAggregationTensorShapeError(
+ '`TensorflowSpec.output_tensor_specs` has unknown dimension.'
+ )
+
+ output_filepath_placeholder = None
+ if simpleagg_tensors:
+ output_filepath_placeholder = tf.compat.v1.placeholder(
+ dtype=tf.string, shape=(), name='output_filepath'
+ )
+ simpleagg_variable_names = variable_helpers.variable_names_from_type(
+ simpleagg_update_type, name=artifact_constants.UPDATE
+ )
+ if experimental_checkpoint_write in [
+ checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE,
+ checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ,
+ ]:
+ delete_op = delete_file.delete_file(output_filepath_placeholder)
+ with tf.control_dependencies([delete_op]):
+ append_ops = []
+ for tensor_name, tensor in zip(
+ simpleagg_variable_names, simpleagg_tensors
+ ):
+ append_ops.append(
+ tensor_utils.save(
+ filename=output_filepath_placeholder,
+ tensor_names=[tensor_name],
+ tensors=[tensor],
+ save_op=append_slices.append_slices,
+ )
+ )
+ if (
+ experimental_checkpoint_write
+ == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE
+ ):
+ with tf.control_dependencies(append_ops):
+ save_op = append_slices.merge_appended_slices(
+ filename=output_filepath_placeholder
+ )
+ else:
+ # APPEND_SLICES_MERGE_READ
+ save_op = tf.group(*append_ops)
+
+ elif (
+ experimental_checkpoint_write
+ == checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES
+ ):
+ save_op = tensor_utils.save(
+ filename=output_filepath_placeholder,
+ tensor_names=simpleagg_variable_names,
+ tensors=simpleagg_tensors,
+ name='save_client_update_tensors',
+ )
+ else:
+ raise NotImplementedError(
+ f'Unsupported CheckpointFormatType {experimental_checkpoint_write}.'
+ )
+ input_tensors.append(output_filepath_placeholder)
+ target_nodes.append(save_op.name)
+
+ tensorflow_spec = plan_pb2.TensorflowSpec()
+ if token_placeholder is not None:
+ tensorflow_spec.dataset_token_tensor_name = token_placeholder.name
+ if input_tensors:
+ tensorflow_spec.input_tensor_specs.extend(
+ tf.TensorSpec.from_tensor(t, name=t.name).experimental_as_proto()
+ for t in input_tensors
+ )
+ if output_tensor_specs:
+ tensorflow_spec.output_tensor_specs.extend(output_tensor_specs)
+ if target_nodes:
+ tensorflow_spec.target_node_names.extend(target_nodes)
+ if example_selector_placeholders:
+ for placeholder in example_selector_placeholders:
+ # Generating the default TensorProto will create a TensorProto with an
+ # DT_INVALID DType. This identifies that there is a placeholder that is
+ # needed. In order to have the Plan proto be completely runnable, the
+ # value will need to be filled in with a real TensorProto that matches
+ # the shape/type of the expected input.
+ tensorflow_spec.constant_inputs[placeholder.name].dtype = 0
+
+ io_router = plan_pb2.FederatedComputeIORouter()
+ if input_filepath_placeholder is not None:
+ io_router.input_filepath_tensor_name = input_filepath_placeholder.name
+ if output_filepath_placeholder is not None:
+ io_router.output_filepath_tensor_name = output_filepath_placeholder.name
+ for secagg_tensor in secagg_tensors:
+ io_router.aggregations[secagg_tensor.name].CopyFrom(
+ plan_pb2.AggregationConfig(
+ secure_aggregation=plan_pb2.SecureAggregationConfig()
+ )
+ )
+
+ return client_graph.as_graph_def(), plan_pb2.ClientPhase(
+ tensorflow_spec=tensorflow_spec, federated_compute=io_router
+ )
+
+
+def _build_client_phase_with_example_query_spec(
+ client_work_comp: tff.Computation,
+ example_query_spec: plan_pb2.ExampleQuerySpec,
+) -> plan_pb2.ClientPhase:
+ """Builds the ClientPhase with `ExampleQuerySpec` populated.
+
+ Args:
+ client_work_comp: A `tff.Computation` that represents the TensorFlow logic
+ run on-device.
+ example_query_spec: Field containing output vector information for client
+ example query. The output vector names listed in the spec are expected to
+ be consistent with the output names we would produce in the
+ `MapReduceForm` client work computation, if we were to build a TF-based
+ plan from that `MapReduceForm`.
+
+ Returns:
+ A client phase part of the federated protocol.
+ """
+ expected_vector_names = set(
+ variable_helpers.variable_names_from_type(
+ client_work_comp.type_signature.result[0], artifact_constants.UPDATE
+ )
+ )
+ used_names = set()
+ io_router = plan_pb2.FederatedExampleQueryIORouter()
+ for example_query in example_query_spec.example_queries:
+ vector_names = set(example_query.output_vector_specs.keys())
+ if not all([name in expected_vector_names for name in vector_names]):
+ raise ValueError(
+ 'Found unexpected vector names in supplied `example_query_spec`. '
+ f'Expected names: {expected_vector_names}. '
+ f'Found unexpected names: {vector_names-expected_vector_names}.'
+ )
+
+ if any([name in used_names for name in vector_names]):
+ raise ValueError(
+ 'Duplicate vector names found in supplied `example_query_spec`. '
+ f'Duplicates: {vector_names.intersection(used_names)}'
+ )
+
+ used_names.update(vector_names)
+
+ for vector_name in vector_names:
+ io_router.aggregations[vector_name].CopyFrom(
+ plan_pb2.AggregationConfig(
+ tf_v1_checkpoint_aggregation=plan_pb2.TFV1CheckpointAggregation()
+ )
+ )
+
+ if used_names != expected_vector_names:
+ raise ValueError(
+ 'Not all expected vector names were in supplied `example_query_spec`.'
+ f' Expected names: {expected_vector_names}. Names not present in'
+ f' `example_query_spec`: {expected_vector_names-vector_names}'
+ )
+ return plan_pb2.ClientPhase(
+ example_query_spec=example_query_spec, federated_example_query=io_router
+ )
+
+
+def build_plan(
+ mrf: tff.backends.mapreduce.MapReduceForm,
+ daf: Optional[tff.backends.mapreduce.DistributeAggregateForm] = None,
+ dataspec: Optional[data_spec.NestedDataSpec] = None,
+ example_query_spec: Optional[plan_pb2.ExampleQuerySpec] = None,
+ grappler_config: Optional[tf.compat.v1.ConfigProto] = None,
+ additional_checkpoint_metadata_var_fn: Optional[
+ Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
+ ] = None,
+ experimental_client_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
+ generate_server_phase_v2: bool = False,
+ write_metrics_to_checkpoint: bool = True,
+) -> plan_pb2.Plan:
+ """Constructs an instance of `plan_pb2.Plan` given a `MapReduceForm` instance.
+
+ Plans generated by this method are executable, but a number of features have
+ yet to be implemented.
+
+ These include:
+
+ - Setting metrics' `stat_name` field based on externally-supplied metadata,
+ such as that from the model stampers. Currently, these names are based on
+ the names of TensorFlow variables, which in turn are based on the TFF
+ type signatures.
+
+ - Populating the client `example_selector` field. Currently not set.
+
+ - Populating client-side `savepoint`. Currently not set.
+
+ - Populating the plan's `tensorflow_config_proto`. Currently not set.
+
+ - Setting a field in the plan that represets a token to drive the custom op
+ that iplements the client-side dataset. There is no such field in the plan
+ at the time of this writing.
+
+ - Populating plan fields related to secure aggregation and side channels,
+ such as the `read_aggregated_update` checkpoint op.
+
+ Args:
+ mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
+ daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
+ dataspec: If provided, either an instance of `data_spec.DataSpec` or a
+ nested structure of these that matches the structure of the first element
+ of the input to client-side processing computation `mrf.work`. If not
+ provided and `example_query_spec` is also not provided, then placeholders
+ are added to the client graph via `embed_data_logic()` and the example
+ selectors will need to be passed to the client via the `constant_inputs`
+ part of the `TensorflowSpec`. The constant_inputs field needs to be
+ populated outside of `build_plan()`. Can only provide one of `dataspec` or
+ `example_query_spec`.
+ example_query_spec: An instance of `plan_pb2.ExampleQuerySpec`. If provided
+ it is assumed a light weight client plan should be constructed. No client
+ graph will be included in the produced plan object. Instead the generated
+ plan will have an `ExampleQuerySpec` and `FederatedExampleQueryIORouter`.
+ Can only supply one of `dataspec` or `example_query_spec`.
+ grappler_config: The config specifying Grappler optimizations for TFF-
+ generated graphs. Should be provided if daf is provided.
+ additional_checkpoint_metadata_var_fn: An optional method that takes in a
+ server state type, a server metrics type, and a boolean determining
+ whether to revert to legacy metrics behavior to produce additional
+ metadata variables.
+ experimental_client_checkpoint_write: Determines the style of writing of the
+ client checkpoint (client->server communication). The value affects the
+ operation used and might have impact on overall task performance.
+ generate_server_phase_v2: Iff `True`, will produce a ServerPhaseV2 message
+ in the plan in addition to a ServerPhase message.
+ write_metrics_to_checkpoint: If False, revert to legacy behavior where
+ metrics values were handled by post-processing separate from the outputted
+ checkpoint. Regardless, they will additionally continue to be written to
+ recordio and accumulator checkpoints as defined by the Plan proto.
+
+ Returns:
+ An instance of `plan_pb2.Plan` corresponding to MapReduce form `mrf`.
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ ValueError: If any of the arguments are found to be in an unexpected form.
+ """
+ type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
+ client_plan_type = (
+ ClientPlanType.TENSORFLOW
+ if example_query_spec is None
+ else ClientPlanType.EXAMPLE_QUERY
+ )
+
+ if example_query_spec is not None:
+ if dataspec is not None:
+ raise ValueError(
+ '`example_query_spec` or `dataspec` cannot both be specified.'
+ )
+
+ with tff.framework.get_context_stack().install(
+ tff.test.create_runtime_error_context()
+ ):
+ is_broadcast_empty = (
+ isinstance(mrf.prepare.type_signature.result, tff.StructType)
+ and not mrf.prepare.type_signature.result
+ )
+ if is_broadcast_empty:
+ # This MapReduceForm does not send any server state to clients, however we
+ # need something to satisfy current restrictions from the FCP server.
+ # Use a placeholder scalar int.
+ broadcast_tff_type = tff.TensorType(tf.int32)
+ else:
+ broadcast_tff_type = mrf.prepare.type_signature.result
+
+ # Execute the bitwidths TFF computation using the default TFF executor.
+ bitwidths, max_inputs, moduli = _compute_secagg_parameters(mrf)
+ # Note: the variables below are flat lists, even though
+ # `secure_sum_bitwidth_update_type`
+ # could potentially represent a large group of nested tensors. In order
+ # for each var to line up with the appropriate bitwidth, we must also
+ # flatten the list of bitwidths.
+ flattened_bitwidths = tff.structure.flatten(bitwidths)
+ flattened_max_inputs = tff.structure.flatten(max_inputs)
+ flattened_moduli = tff.structure.flatten(moduli)
+
+ (
+ server_graph_def,
+ server_savepoint,
+ server_phase,
+ broadcasted_tensor_specs,
+ ) = _build_server_graph(
+ mrf,
+ broadcast_tff_type,
+ is_broadcast_empty,
+ flattened_bitwidths,
+ flattened_max_inputs,
+ flattened_moduli,
+ write_metrics_to_checkpoint,
+ additional_checkpoint_metadata_var_fn,
+ experimental_client_update_format=experimental_client_checkpoint_write,
+ )
+
+ if client_plan_type == ClientPlanType.TENSORFLOW:
+ client_graph_def, client_phase = _build_client_graph_with_tensorflow_spec(
+ mrf.work,
+ dataspec,
+ broadcasted_tensor_specs,
+ is_broadcast_empty,
+ experimental_checkpoint_write=experimental_client_checkpoint_write,
+ )
+ elif client_plan_type == ClientPlanType.EXAMPLE_QUERY:
+ client_phase = _build_client_phase_with_example_query_spec(
+ mrf.work, example_query_spec
+ )
+ else:
+ raise ValueError(
+ f'Unexpected value for `client_plan_type`: {client_plan_type}'
+ )
+
+ combined_phases = plan_pb2.Plan.Phase(
+ server_phase=server_phase, client_phase=client_phase
+ )
+
+ if generate_server_phase_v2:
+ assert daf
+ assert grappler_config
+ (server_graph_def_prepare, server_graph_def_result, server_phase_v2) = (
+ _build_server_graphs_from_distribute_aggregate_form(
+ daf, is_broadcast_empty, grappler_config
+ )
+ )
+ combined_phases.server_phase_v2.CopyFrom(server_phase_v2)
+
+ plan = plan_pb2.Plan(
+ version=1, server_savepoint=server_savepoint, phase=[combined_phases]
+ )
+
+ plan.server_graph_bytes.Pack(server_graph_def)
+ if client_plan_type == ClientPlanType.TENSORFLOW:
+ plan.client_graph_bytes.Pack(client_graph_def)
+
+ if generate_server_phase_v2:
+ plan.server_graph_prepare_bytes.Pack(server_graph_def_prepare)
+ plan.server_graph_result_bytes.Pack(server_graph_def_result)
+ return plan
+
+
+def build_cross_round_aggregation_execution(
+ mrf: tff.backends.mapreduce.MapReduceForm,
+) -> bytes:
+ """Constructs an instance of `plan_pb2.CrossRoundAggregationExecution`.
+
+ Args:
+ mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
+
+ Returns:
+ A serialized instance of `plan_pb2.CrossRoundAggregationExecution` for given
+ `mrf`.
+ """
+ type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
+
+ server_metrics_type = mrf.update.type_signature.result[1]
+ (
+ simpleagg_update_type,
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ ) = mrf.work.type_signature.result
+ # We don't ever work directly on `simpleagg_update_type` because client
+ # updates are transformed by `accumulate` and `merge` before ever being passed
+ # into cross-round aggregation.
+ del simpleagg_update_type
+ simpleagg_merge_type = mrf.merge.type_signature.result
+ flattened_moduli = tff.structure.flatten(mrf.secure_modular_sum_modulus())
+
+ if not server_metrics_type:
+ # No metrics to aggregrate; will initialize to no-op.
+ server_metrics_type = tff.StructType([])
+ elif isinstance(server_metrics_type, tff.TensorType):
+ # Single tensor metric; must be wrapped inside of a NamedTuple for proper
+ # variable initialiazation.
+ server_metrics_type = tff.StructType([server_metrics_type])
+ combined_aggregated_update_type = tff.StructType([
+ simpleagg_merge_type,
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ ])
+
+ with tf.Graph().as_default() as cross_round_aggregation_graph:
+ server_state_vars = variable_helpers.create_vars_for_tff_type(
+ mrf.update.type_signature.parameter[0],
+ artifact_constants.SERVER_STATE_VAR_PREFIX,
+ )
+
+ combined_aggregated_update_vars, read_aggregated_update = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ combined_aggregated_update_type, 'aggregated_update'
+ )
+ )
+
+ num_simpleagg_vars = len(tff.structure.flatten(simpleagg_merge_type))
+
+ aggregated_update_vars = combined_aggregated_update_vars[
+ :num_simpleagg_vars
+ ]
+ secagg_aggregated_update_vars = combined_aggregated_update_vars[
+ num_simpleagg_vars:
+ ]
+
+ # Add a new output for metrics_loader `merge` and `report`.
+ combined_final_accumulator_vars, read_write_final_accumulators = (
+ checkpoint_utils.create_state_vars_and_savepoint(
+ combined_aggregated_update_type, 'final_accumulators'
+ )
+ )
+
+ final_accumulator_vars = combined_final_accumulator_vars[
+ :num_simpleagg_vars
+ ]
+ secagg_final_accumulator_vars = combined_final_accumulator_vars[
+ num_simpleagg_vars:
+ ]
+
+ var_init_op = tf.compat.v1.initializers.variables(
+ server_state_vars
+ + combined_aggregated_update_vars
+ + combined_final_accumulator_vars
+ )
+
+ # Embeds the MapReduce form `merge` logic.
+ merged_values = graph_helpers.import_tensorflow(
+ 'merge', mrf.merge, (final_accumulator_vars, aggregated_update_vars)
+ )
+ final_accumulator_assign_ops = tf.nest.map_structure(
+ lambda variable, tensor: variable.assign(tensor),
+ final_accumulator_vars,
+ merged_values,
+ )
+
+ # SecAgg tensors' aggregation is not provided in the imported TensorFlow,
+ # but is instead fixed based on the operator (e.g. `assign_add` for
+ # variables passed into `secure_sum`).
+ secagg_final_accumulator_ops = _merge_secagg_vars(
+ secure_sum_bitwidth_update_type,
+ secure_sum_update_type,
+ secure_modular_sum_update_type,
+ flattened_moduli,
+ secagg_final_accumulator_vars,
+ secagg_aggregated_update_vars,
+ )
+ final_accumulator_op = tf.group(
+ *(final_accumulator_assign_ops + secagg_final_accumulator_ops)
+ ).name
+
+ # Embeds the `report` and `update` logic, and hooks up the assignments of
+ # the results of the final update to the server state and metric vars, to
+ # be triggered by `apply_aggregrated_updates_op`.
+ simpleagg_reported_values = graph_helpers.import_tensorflow(
+ 'report', mrf.report, final_accumulator_vars
+ )
+ combined_final_vars = (
+ simpleagg_reported_values + secagg_final_accumulator_vars
+ )
+ (_, server_metric_values) = graph_helpers.import_tensorflow(
+ artifact_constants.UPDATE,
+ mrf.update,
+ (server_state_vars, combined_final_vars),
+ split_outputs=True,
+ )
+
+ server_metrics_names = variable_helpers.variable_names_from_type(
+ server_metrics_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX
+ )
+
+ flattened_metrics_types = tff.structure.flatten(server_metrics_type)
+ measurements = [
+ proto_helpers.make_measurement(v, s, a)
+ for v, s, a in zip(
+ server_metric_values, server_metrics_names, flattened_metrics_types
+ )
+ ]
+
+ cross_round_aggregation_execution = plan_pb2.CrossRoundAggregationExecution(
+ init_op=var_init_op.name,
+ read_aggregated_update=read_aggregated_update,
+ merge_op=final_accumulator_op,
+ read_write_final_accumulators=read_write_final_accumulators,
+ measurements=measurements,
+ )
+
+ cross_round_aggregation_execution.cross_round_aggregation_graph_bytes.Pack(
+ cross_round_aggregation_graph.as_graph_def()
+ )
+
+ return cross_round_aggregation_execution.SerializeToString()
diff --git a/fcp/artifact_building/graph_helpers.py b/fcp/artifact_building/graph_helpers.py
new file mode 100644
index 0000000..6bd2804
--- /dev/null
+++ b/fcp/artifact_building/graph_helpers.py
@@ -0,0 +1,659 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for manipulating TensorFlow graph logic."""
+
+from typing import Optional, Union
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import data_spec
+from fcp.artifact_building import tensor_utils
+from fcp.artifact_building import type_checks
+from fcp.tensorflow import external_dataset
+from tensorflow_federated.proto.v0 import computation_pb2
+
+TfValue = Union[tf.Variable, tf.Tensor]
+DatasetTensor = tf.Tensor
+Argument = Union[TfValue, list[TfValue], DatasetTensor]
+Args = Optional[Union[Argument, tuple[Argument, ...]]]
+
+Result = Argument
+MaybeSplitOutputs = Union[Result, tuple[Result, ...]]
+
+
+EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX = 'example_selector'
+
+
+def generate_example_selector_placeholders(
+ type_spec: tff.Type,
+ name_prefix: str,
+):
+ """Generates list of tff.compat.v1.placeholders for each leaf in a type spec.
+
+ The order of the placeholders aligns with the order given by
+ tff.structure.to_elements().
+
+ Placeholders will be named by concatenating the name_prefix arg with the list
+ of indexes at each level of the struct to get to the placeholder's leaf in the
+ tff.Type.
+
+ Args:
+ type_spec: A type spec to infer the list of placeholders from. This is
+ expected to be a tff.SequenceType or a tff.StructType, and if it is a
+ tff.StructType, it is expected to be a tree of tff.StructTypes with
+ tff.SequenceTypes at the leaves. This is expected to reflect the TFF type
+ signature of the input client data.
+ name_prefix: The name prefix that should be used when naming each
+ placeholder.
+
+ Returns:
+ A list of tf.compat.v2.placeholders.
+ """
+ type_spec = tff.to_type(type_spec)
+ type_checks.check_type(
+ type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
+ )
+ if type_spec.is_sequence():
+ # Each client input is a sequence of serialized `tf.Example`s, which is why
+ # the leaves of these TFF type signatures are sequences. Each input sequence
+ # of `tf.Example`s requires a single `ExampleSelector` that determines that
+ # stream of `tf.Example`s is selected from the data store, which is why we
+ # only have a single placeholder for the `ExampleSelector`.
+ return [tf.compat.v1.placeholder(tf.string, shape=[], name=name_prefix)]
+ else:
+ type_spec.check_struct()
+ type_spec_elements = tff.structure.to_elements(type_spec)
+ placeholders = []
+ for element_index, (_, element_type) in enumerate(type_spec_elements):
+ placeholders.extend(
+ generate_example_selector_placeholders(
+ element_type, f'{name_prefix}_{element_index}'
+ )
+ )
+ return placeholders
+
+
+def embed_data_logic(
+ client_data_type: tff.Type,
+ dataspec: Optional[data_spec.NestedDataSpec] = None,
+) -> tuple[tf.Tensor, list[MaybeSplitOutputs], list[tf.Tensor]]:
+ """Embeds the data logic into the current TensorFlow graph.
+
+ Adds dataset ops to the current graph, using the custom `ExternalDataset`
+ which returns a placeholder token. The initialization op and data values are
+ also returned.
+
+ Args:
+ client_data_type: The TFF type signature of the input client data.
+ dataspec: If provided, either an instance of `data_spec.DataSpec` or a
+ nested structure of these that matches the structure of the first element
+ of the input to the client work part of the computation.
+
+ Returns:
+ A `tuple` containing the following (in order):
+ token_placeholder: A dataset token placeholder tensor
+ data_values: A list of dataset output values
+ example_selector_placeholders: A possibly empty list of placeholders used
+ for passing in example selector information into the client graph. This
+ list will be empty iff dataspec is supplied.
+
+ Raises:
+ ValueError: If the number of dataset output from one data source is not 1.
+ ValueError: If a node exists in the graph already that contains a node with
+ the same name as the example selector placeholders.
+ """
+ data_values = []
+ # Embeds the token placeholder for the custom ExternalDataset op.
+ token_placeholder = tf.compat.v1.placeholder(
+ tf.string, shape=[], name='data_token'
+ )
+
+ example_selector_placeholders = []
+ if dataspec is None:
+ example_selector_placeholders = generate_example_selector_placeholders(
+ client_data_type, EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX
+ )
+ # If the first placeholder does not have the expected prefix, then it is due
+ # to other variables in the graph, likely created from the input
+ # tff.Computation, having the special name. This check ensures that no other
+ # variables use this special example selector placeholder name and we can
+ # easily extract example selector placeholders in the generated artifact.
+ if example_selector_placeholders and (
+ not (
+ example_selector_placeholders[0].name.startswith(
+ f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}:'
+ )
+ or example_selector_placeholders[0].name.startswith(
+ f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}_0'
+ )
+ )
+ ):
+ raise ValueError(
+ 'Graph already contains a placeholder with name '
+ f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}. Please '
+ 'avoid the use of this special name.'
+ )
+ data_sources = make_data_sources_without_dataspec(client_data_type)
+ assert len(example_selector_placeholders) == len(data_sources)
+ else:
+ data_sources = make_data_sources_with_dataspec(client_data_type, dataspec)
+
+ # Embeds data source computations into the current graph.
+ for index, data_comp in enumerate(data_sources):
+ data_comp_import_args = [token_placeholder]
+ if example_selector_placeholders:
+ data_comp_import_args.append(example_selector_placeholders[index])
+ ds_values = import_tensorflow(
+ 'data_{}'.format(index), data_comp, data_comp_import_args
+ ) # pytype: disable=wrong-arg-types
+ if len(ds_values) != 1:
+ raise ValueError(
+ 'Expected one dataset output from a data source, found {}.'.format(
+ str(len(ds_values))
+ )
+ )
+ data_values.extend(ds_values)
+
+ return token_placeholder, data_values, example_selector_placeholders
+
+
+def import_tensorflow(
+ name: str,
+ comp: tff.framework.ConcreteComputation,
+ args: Args = None,
+ split_outputs: bool = False,
+ session_token_tensor: Optional[tf.Tensor] = None,
+) -> MaybeSplitOutputs:
+ """Imports a tensorflow computation into the current graph.
+
+ Args:
+ name: The string name to use as the graph import prefix.
+ comp: An instance of `tff.framework.ConcreteComputation` with just the
+ `tensorflow` section.
+ args: Either a single argument, a tuple of arguments, or None. An argument
+ must be either: - a Python `list` containing either tensors or variables,
+ or - a single variant tensor representing a dataset input.
+ split_outputs: Whether to unpack the result tuple into a Python tuple. If
+ `True`, `import_tensorflow` will return a tuple with multiple result
+ objects, corresponding to the return elements in the type signature of
+ `comp`. Notice that the return type signature of `comp` must be a tuple in
+ this case. If `False`, `import_tensorflow` will return the entire result
+ in a flattened form as a single Python result object. Each Python result
+ object, similar to the argumens in `args`, will be either a Python `list`
+ of variant tensors or a singleton Python list containing only the dataset
+ variant tensor.
+ session_token_tensor: A tensor in the current graph containing the "session
+ token" of the TensorFlow being imported. This is useful for passing a
+ session-global identifier into the graph for use with ops like
+ `ServeSlices` and `ExternalDataset` that take in a token which references
+ session-global state.
+
+ Returns:
+ One of:
+ - A single result (Python `list` of variable value or variant tensors) if
+ `split_outputs` is `False`.
+ - A Python `tuple` of such results, if `split_outputs` is `True`.
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ """
+ type_checks.check_type(name, str, name='name')
+ type_checks.check_type(comp, tff.framework.ConcreteComputation, name='comp')
+ type_checks.check_type(split_outputs, bool, name='split_outputs')
+
+ comp_proto = tff.framework.ConcreteComputation.get_proto(comp)
+ type_checks.check_type(
+ comp_proto, computation_pb2.Computation, name='comp_proto'
+ )
+
+ which_comp = comp_proto.WhichOneof('computation')
+ if which_comp != 'tensorflow':
+ raise TypeError(
+ 'Expected a TensorFlow computation, found {}.'.format(which_comp)
+ )
+ if args is None:
+ input_map = None
+ elif isinstance(args, tuple):
+ which_binding = comp_proto.tensorflow.parameter.WhichOneof('binding')
+ if which_binding != 'struct':
+ raise TypeError(
+ 'Expected a struct binding with a struct of args, found {}.'.format(
+ which_binding
+ )
+ )
+ input_map = {}
+ for index, arg in enumerate(args):
+ input_map.update(
+ create_tensor_map(
+ comp_proto.tensorflow.parameter.struct.element[index], arg
+ )
+ )
+ else:
+ input_map = create_tensor_map(comp_proto.tensorflow.parameter, args)
+ if input_map is not None:
+ # Add remappings for all potential control dependencies in the graph as
+ # well. Since `tf.graph_util.import_graph_def` input map works on the tensor
+ # (not graph node) level, we must handle this case also.
+ def control_dep_name(name: str) -> str:
+ if name.startswith('^'):
+ return name
+ node_name = name.split(':', maxsplit=1)[0]
+ return f'^{node_name}'
+
+ input_map.update(
+ {
+ control_dep_name(k): control_dep_name(v.name)
+ for k, v in input_map.items()
+ if not k.startswith('^')
+ }
+ )
+ input_map = {} if input_map is None else input_map
+ if (
+ session_token_tensor is not None
+ and comp_proto.tensorflow.session_token_tensor_name
+ ):
+ input_map[comp_proto.tensorflow.session_token_tensor_name] = (
+ session_token_tensor
+ )
+ if split_outputs:
+ return_elements = []
+ subset_sizes = []
+ which_binding = comp_proto.tensorflow.result.WhichOneof('binding')
+ if which_binding != 'struct':
+ raise TypeError(
+ 'If `split_outputs` is `True`, the result of the computation we are '
+ 'importing must be a `struct`; found {}.'.format(which_binding)
+ )
+ for binding in comp_proto.tensorflow.result.struct.element:
+ tensor_names = _list_tensor_names_in_binding(binding)
+ return_elements.extend(tensor_names)
+ subset_sizes.append(len(tensor_names))
+ else:
+ return_elements = _list_tensor_names_in_binding(
+ comp_proto.tensorflow.result
+ )
+ subset_sizes = [len(return_elements)]
+
+ graph_def = tensor_utils.import_graph_def_from_any(
+ comp_proto.tensorflow.graph_def
+ )
+
+ # We will be importing multiple GraphDefs into the server or client graphs.
+ # These individual graphs may have identifical `shared_name` attributes on
+ # variable ops, which causes the runtime to reference the same resource, which
+ # is highly undesired. We must uniquify the names before importing.
+ def uniquify_shared_names(
+ graph_def: tf.compat.v1.GraphDef, suffix: bytes
+ ) -> tf.compat.v1.GraphDef:
+ for x in graph_def.node:
+ shared_name = x.attr.get('shared_name')
+ if shared_name is not None:
+ if not shared_name.s:
+ # Encountered an empty string shared name, avoid creating a shared
+ # name that starts with an underscore (not allowed by TF).
+ shared_name.s = b'None'
+ shared_name.s += b'_' + suffix
+ return graph_def
+
+ uniquified_graph_def = uniquify_shared_names(
+ graph_def, suffix=name.encode('utf-8')
+ )
+ if comp_proto.tensorflow.initialize_op:
+ uniquified_graph_def = add_control_deps_for_init_op(
+ uniquified_graph_def, comp_proto.tensorflow.initialize_op
+ )
+ import_result = tf.graph_util.import_graph_def(
+ uniquified_graph_def,
+ input_map=input_map,
+ return_elements=return_elements,
+ name=name,
+ )
+
+ if split_outputs:
+ subsets = []
+ offset = 0
+ for subset_size in subset_sizes:
+ next_offset = offset + subset_size
+ subsets.append(import_result[offset:next_offset])
+ offset = next_offset
+ results = tuple(subsets)
+ else:
+ results = import_result[: subset_sizes[0]]
+ return results
+
+
+def _get_deps_for_graph_node(
+ graph_def: tf.compat.v1.GraphDef, node_name: str
+) -> set[str]:
+ """Returns the set of node names that a node named `node_name` depends on.
+
+ Note that this function does not work for nodes in the function library.
+
+ Args:
+ graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
+ node_name: The node name, a string.
+
+ Returns:
+ An instance of `set()` containing string names of the nodes `node_name`
+ depends on in graph_def.
+
+ Raises:
+ TypeError: If either argument is of the wrong type.
+ """
+ type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
+ type_checks.check_type(node_name, str, name='node_name')
+ input_map = {}
+ for node in graph_def.node:
+ input_map[node.name] = set(tensor_utils.bare_name(x) for x in node.input)
+ dependencies = set()
+ initial_singleton = set([node_name])
+ nodes_to_process = initial_singleton
+ while nodes_to_process:
+ dependencies.update(nodes_to_process)
+ nodes_to_process = set.union(
+ *[input_map[name] for name in nodes_to_process]
+ ).difference(dependencies)
+ return dependencies.difference(initial_singleton)
+
+
+def add_control_deps_for_init_op(
+ graph_def: tf.compat.v1.GraphDef, init_op: str
+) -> tf.compat.v1.GraphDef:
+ """Adds control deps on `init_op` to nodes in GraphDef.
+
+ Note that control deps are not added to any of the ancestors of `init_op`
+ (which would result in a control dep cycle) and control deps are not added to
+ any nodes in the function library of a GraphDef.
+
+ Args:
+ graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
+ init_op: The init op name, a string.
+
+ Returns:
+ The updated graph, an instance of `tf.compat.v1.GraphDef`.
+
+ Raises:
+ TypeError: If either argument is of the wrong type.
+ """
+ type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
+ type_checks.check_type(init_op, str, name='init_op')
+ init_op_str = tensor_utils.bare_name(init_op)
+ init_op_control_dep = '^{}'.format(init_op_str)
+ deps = _get_deps_for_graph_node(graph_def, init_op_str).union(
+ set([init_op_str])
+ )
+ new_graph_def = tf.compat.v1.GraphDef()
+ new_graph_def.CopyFrom(graph_def)
+ for new_node in new_graph_def.node:
+ if new_node.name not in deps:
+ node_inputs = new_node.input
+ if init_op_control_dep not in node_inputs:
+ new_node.input.append(init_op_control_dep)
+ return new_graph_def
+
+
+def create_tensor_map(
+ binding: computation_pb2.TensorFlow.Binding,
+ arg: list[Union[tf.Tensor, tf.Variable]],
+) -> dict[str, tf.Tensor]:
+ """Creates a `dict` mapping tensor names in the binding to tensors in `arg`.
+
+ Args:
+ binding: An instance of `computation_pb2.TensorFlow.Binding`.
+ arg: Either a singleton Python `list` with variant tensor in case of a
+ sequence binding, or a Python `list` of tensors or resource variables
+ otherwise for a tuple binding.
+
+ Returns:
+ An instance of Python `dict` with the map as specified above.
+
+ Raises:
+ TypeError: If the argument types are incorrect.
+ ValueError: If the arguments are malformed (e.g., multiple variant tensors).
+ """
+ type_checks.check_type(
+ binding, computation_pb2.TensorFlow.Binding, name='binding'
+ )
+ type_checks.check_type(arg, list, name='arg')
+ tensor_names_in_binding = _list_tensor_names_in_binding(binding)
+ which_binding = binding.WhichOneof('binding')
+ if which_binding == 'sequence':
+ if (len(tensor_names_in_binding) != 1) or (len(arg) != 1):
+ raise ValueError('Multiple variant tensors found.')
+ variant_tensor_name = tensor_names_in_binding[0]
+ arg = arg[0]
+ if not tf.is_tensor(arg):
+ raise TypeError('Expected a tensor, found {!r}.'.format(type(arg)))
+ if arg.dtype != tf.variant:
+ raise TypeError('Expected `tf.variant`, found {!r}.'.format(arg.dtype))
+ return {variant_tensor_name: arg}
+ else:
+ return {
+ k: v.read_value() if hasattr(v, 'read_value') else v
+ for k, v in zip(tensor_names_in_binding, arg)
+ }
+
+
+def _validate_data_comp(data_comp: tff.Computation, type_spec: tff.Type):
+ type_checks.check_type(data_comp.type_signature, tff.FunctionType)
+ if not type_spec.is_assignable_from(data_comp.type_signature.result):
+ type_mismatch_string = tff.types.type_mismatch_error_message(
+ type_spec,
+ data_comp.type_signature.result,
+ tff.types.TypeRelation.ASSIGNABLE,
+ )
+ raise TypeError(
+ 'The data source constructed with the supplied dataspec returns data '
+ 'which does not match type of request. Details of the mismatch:\n'
+ + type_mismatch_string
+ )
+
+
+def make_data_sources_with_dataspec(
+ type_spec: tff.Type, ds: data_spec.NestedDataSpec
+) -> list[tff.Computation]:
+ """Creates a list of computations that feed data into the graph using specified example selectors.
+
+ The computations use the custom ExternalDataset op to feed in example data.
+ The computations will expect one input:
+ -- A token specifying where the data store is on the device.
+ Example selectors that describes what data to take from the on-device data
+ store will be hard-coded into the computations.
+
+ Args:
+ type_spec: The TFF type signature of the output, which must be either a
+ sequence, or a named tuple of sequences.
+ ds: Either a single `data_spec.DataSpec`, or a nested structure of these,
+ made up of Python containers, that exactly matches the structure of the
+ `type_spec`.
+
+ Returns:
+ A list of `tff.Computation`s, each of which accepts a single `string`-typed
+ tensor as input (the token for the ExternalDataset op) and returns a
+ sequence as output (with the result that matches the corresponding part of
+ `type_spec`). The computations appear on the list in a depth-first order
+ (matching exactly the convention used in the
+ `_list_tensor_names_in_binding()` method below).
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ """
+ assert ds
+ type_spec = tff.to_type(type_spec)
+ type_checks.check_type(
+ type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
+ )
+ if type_spec.is_sequence():
+ type_checks.check_type(ds, data_spec.DataSpec)
+ assert isinstance(ds, data_spec.DataSpec)
+ assert ds.example_selector_proto is not None
+ sel_bytes = ds.example_selector_proto.SerializeToString()
+
+ @tff.tf_computation(tf.string)
+ def data_comp(token):
+ """The data source computation.
+
+ Args:
+ token: The token placeholder tensor (`tf.string`).
+
+ Returns:
+ An instance of `tf.data.Dataset`.
+ """
+ if ds.preprocessing_fn is not None:
+ processed_ds = ds.preprocessing_fn(
+ external_dataset.ExternalDataset(token=token, selector=sel_bytes)
+ )
+ else:
+ processed_ds = external_dataset.ExternalDataset(
+ token=token, selector=sel_bytes
+ )
+
+ if 'Dataset' not in type(processed_ds).__name__:
+ raise TypeError(
+ 'The preprocessing function returned an unrecognized non-dataset '
+ 'type {!r}.'.format(type(processed_ds))
+ )
+ return processed_ds
+
+ _validate_data_comp(data_comp, type_spec)
+ return [data_comp]
+ else:
+ type_spec.check_struct()
+ if isinstance(ds, data_spec.DataSpec):
+ raise TypeError(
+ 'Expected nested structure of `DataSpec`s conforming to '
+ f'the structure of the type {type_spec}. '
+ 'Found single `DataSpec` instead.'
+ )
+ ds = tff.structure.from_container(ds)
+ assert isinstance(ds, tff.structure.Struct)
+ type_spec_elements = tff.structure.to_elements(type_spec)
+ data_spec_elements = tff.structure.to_elements(ds)
+ type_spec_element_names = [str(k) for k, _ in type_spec_elements]
+ data_spec_element_names = [str(k) for k, _ in data_spec_elements]
+ if type_spec_element_names != data_spec_element_names:
+ raise TypeError(
+ 'Type vs. data spec elements names mismatch: {} vs. {}.'.format(
+ str(type_spec_element_names), str(data_spec_element_names)
+ )
+ )
+ elements = []
+ for element_index, (_, element_type) in enumerate(type_spec_elements):
+ elements.extend(
+ make_data_sources_with_dataspec(element_type, ds[element_index])
+ )
+ return elements
+
+
+def make_data_sources_without_dataspec(type_spec) -> list[tff.Computation]:
+ """Creates a list of computations that feed data into the graph.
+
+ The computations use the custom ExternalDataset op to feed in example data.
+ The computations will expect two inputs:
+ -- A token specifying where the data store is on the device.
+ -- An example selector that describes what data to take from the on-device
+ data store.
+
+ Args:
+ type_spec: The TFF type signature of the output, which must be either a
+ sequence, or a named tuple of sequences.
+
+ Returns:
+ A list of `tff.Computation`s, each of which accepts a single `string`-typed
+ tensor as input (the token for the ExternalDataset op) and returns a
+ sequence as output (with the result that matches the corresponding part of
+ `type_spec`). The computations appear on the list in a depth-first order
+ (matching exactly the convention used in the
+ `_list_tensor_names_in_binding()` method below).
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ """
+ type_spec = tff.to_type(type_spec)
+ type_checks.check_type(
+ type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
+ )
+ if type_spec.is_sequence():
+
+ @tff.tf_computation(tf.string, tf.string)
+ def data_comp(token, example_selector):
+ """The data source computation.
+
+ Args:
+ token: The token placeholder tensor (`tf.string`).
+ example_selector: The example selector placeholder tensor (`tf.string`).
+
+ Returns:
+ An instance of `tf.data.Dataset`.
+ """
+ processed_ds = external_dataset.ExternalDataset(
+ token=token, selector=example_selector
+ )
+
+ if 'Dataset' not in type(processed_ds).__name__:
+ raise TypeError(
+ 'The preprocessing function returned an unrecognized non-dataset '
+ 'type {!r}.'.format(type(processed_ds))
+ )
+ return processed_ds
+
+ _validate_data_comp(data_comp, type_spec)
+ return [data_comp]
+ else: # type_spec is a struct.
+ type_spec.check_struct()
+ type_spec_elements = tff.structure.to_elements(type_spec)
+ elements = []
+ for _, element_type in type_spec_elements:
+ elements.extend(make_data_sources_without_dataspec(element_type))
+ return elements
+
+
+def _list_tensor_names_in_binding(
+ binding: computation_pb2.TensorFlow.Binding,
+) -> list[str]:
+ """Returns a flat Python list of tensor names that appear in the `binding`.
+
+ Args:
+ binding: An instance of `computation_pb2.TensorFlow.Binding` in which any
+ sequence bindings must contain variant tensors.
+
+ Returns:
+ A list of `str` instances with tensor names that appear in `binding` in the
+ order in which they appear in the depth-first traversal of the potentially
+ nested binding structure.
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ """
+ type_checks.check_type(binding, computation_pb2.TensorFlow.Binding)
+ which_binding = binding.WhichOneof('binding')
+ if which_binding == 'tensor':
+ return [str(binding.tensor.tensor_name)]
+ elif which_binding == 'struct':
+ result = []
+ for element in binding.struct.element:
+ result.extend(_list_tensor_names_in_binding(element))
+ return result
+ elif which_binding == 'sequence':
+ which_sequence = binding.sequence.WhichOneof('binding')
+ if which_sequence != 'variant_tensor_name':
+ raise TypeError(
+ 'Expected a variant tensor in sequence binding, found {}.'.format(
+ which_sequence
+ )
+ )
+ return [binding.sequence.variant_tensor_name]
+ else:
+ raise TypeError('Unexpected type of binding {}.'.format(which_binding))
diff --git a/fcp/artifact_building/graph_helpers_test.py b/fcp/artifact_building/graph_helpers_test.py
new file mode 100644
index 0000000..819ac2f
--- /dev/null
+++ b/fcp/artifact_building/graph_helpers_test.py
@@ -0,0 +1,409 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for graph_helpers.py."""
+
+import collections
+
+from absl.testing import absltest
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import data_spec
+from fcp.artifact_building import graph_helpers
+from fcp.artifact_building import variable_helpers
+from fcp.protos import plan_pb2
+from tensorflow_federated.proto.v0 import computation_pb2
+
+TRAIN_URI = 'boo'
+TEST_URI = 'foo'
+NUM_PIXELS = 784
+FAKE_INPUT_DIRECTORY_TENSOR = tf.constant('/path/to/input_dir')
+
+
+class EmbedDataLogicTest(absltest.TestCase):
+
+ def assertTensorSpec(self, tensor, name, shape, dtype):
+ self.assertIsInstance(tensor, tf.Tensor)
+ self.assertEqual(tensor.name, name)
+ self.assertEqual(tensor.shape, shape)
+ self.assertEqual(tensor.dtype, dtype)
+
+ def test_one_dataset_of_integers_w_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(
+ tff.SequenceType((tf.string)),
+ data_spec.DataSpec(
+ plan_pb2.ExampleSelector(collection_uri='app://fake_uri')
+ ),
+ )
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+ self.assertLen(data_values, 1)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertEmpty(placeholders)
+
+ def test_two_datasets_of_integers_w_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(
+ collections.OrderedDict(
+ A=tff.SequenceType((tf.string)),
+ B=tff.SequenceType((tf.string)),
+ ),
+ collections.OrderedDict(
+ A=data_spec.DataSpec(
+ plan_pb2.ExampleSelector(collection_uri='app://foo')
+ ),
+ B=data_spec.DataSpec(
+ plan_pb2.ExampleSelector(collection_uri='app://bar')
+ ),
+ ),
+ )
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+
+ self.assertLen(data_values, 2)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
+ self.assertEmpty(placeholders)
+
+ def test_nested_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(
+ collections.OrderedDict(
+ A=collections.OrderedDict(B=tff.SequenceType((tf.string)))
+ ),
+ collections.OrderedDict(
+ A=collections.OrderedDict(
+ B=data_spec.DataSpec(
+ plan_pb2.ExampleSelector(collection_uri='app://foo')
+ )
+ )
+ ),
+ )
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+ self.assertLen(data_values, 1)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertEmpty(placeholders)
+
+ def test_one_dataset_of_integers_without_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(tff.SequenceType((tf.string)))
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+ self.assertLen(data_values, 1)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertLen(placeholders, 1)
+ self.assertEqual(placeholders[0].name, 'example_selector:0')
+
+ def test_two_datasets_of_integers_without_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(
+ collections.OrderedDict(
+ A=tff.SequenceType((tf.string)),
+ B=tff.SequenceType((tf.string)),
+ )
+ )
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+
+ self.assertLen(data_values, 2)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
+ self.assertLen(placeholders, 2)
+ self.assertEqual(placeholders[0].name, 'example_selector_0:0')
+ self.assertEqual(placeholders[1].name, 'example_selector_1:0')
+
+ def test_nested_input_without_dataspec(self):
+ with tf.Graph().as_default():
+ token_placeholder, data_values, placeholders = (
+ graph_helpers.embed_data_logic(
+ collections.OrderedDict(
+ A=collections.OrderedDict(B=tff.SequenceType((tf.string)))
+ )
+ )
+ )
+
+ self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
+ self.assertLen(data_values, 1)
+ self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
+ self.assertLen(placeholders, 1)
+ self.assertEqual(placeholders[0].name, 'example_selector_0_0:0')
+
+
+class GraphHelperTest(absltest.TestCase):
+
+ def test_import_tensorflow(self):
+ # NOTE: Minimal test for now, since this is exercised by other components,
+ # just a single example with a combo of all flavors of params and results.
+ @tff.tf_computation(tff.SequenceType(tf.int64), tf.int64)
+ def work(ds, x):
+ return x + 1, ds.map(lambda a: a + x)
+
+ with tf.Graph().as_default():
+ ds = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
+ v = tf.constant(10, dtype=tf.int64)
+ y, ds2_variant = graph_helpers.import_tensorflow(
+ 'work', work, ([ds], [v]), split_outputs=True
+ )
+ ds2 = tf.data.experimental.from_variant(
+ ds2_variant[0], tf.TensorSpec([], tf.int64)
+ )
+ z = ds2.reduce(np.int64(0), lambda x, y: x + y)
+ with tf.compat.v1.Session() as sess:
+ self.assertEqual(sess.run(y[0]), 11)
+ self.assertEqual(sess.run(z), 33)
+
+ def test_import_tensorflow_with_session_token(self):
+ @tff.tf_computation
+ def return_value():
+ return tff.framework.get_session_token()
+
+ with tf.Graph().as_default():
+ x = tf.compat.v1.placeholder(dtype=tf.string)
+ output = graph_helpers.import_tensorflow(
+ 'return_value', comp=return_value, session_token_tensor=x
+ )
+ with tf.compat.v1.Session() as sess:
+ self.assertEqual(sess.run(output[0], feed_dict={x: 'value'}), b'value')
+
+ def test_import_tensorflow_with_control_dep_remap(self):
+ # Assert that importing graphdef remaps both regular and control dep inputs.
+ @tff.tf_computation(tf.int64, tf.int64)
+ def work(x, y):
+ # Insert a control dependency to ensure it is remapped during import.
+ with tf.compat.v1.control_dependencies([y]):
+ return tf.identity(x)
+
+ with tf.Graph().as_default():
+ x = tf.compat.v1.placeholder(dtype=tf.int64)
+ y = tf.compat.v1.placeholder(dtype=tf.int64)
+ output = graph_helpers.import_tensorflow(
+ 'control_dep_graph', comp=work, args=[x, y]
+ )
+ with tf.compat.v1.Session() as sess:
+ self.assertEqual(sess.run(output, feed_dict={x: 10, y: 20})[0], 10)
+
+ def test_add_control_deps_for_init_op(self):
+ # Creates a graph (double edges are regular dependencies, single edges are
+ # control dependencies) like this:
+ #
+ # ghi
+ # |
+ # def
+ # ||
+ # def:0 foo
+ # || // ||
+ # abc bar ||
+ # \ // \\ ||
+ # bak baz
+ #
+ graph_def = tf.compat.v1.GraphDef(
+ node=[
+ tf.compat.v1.NodeDef(name='foo', input=[]),
+ tf.compat.v1.NodeDef(name='bar', input=['foo']),
+ tf.compat.v1.NodeDef(name='baz', input=['foo', 'bar']),
+ tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
+ tf.compat.v1.NodeDef(name='abc', input=['def:0']),
+ tf.compat.v1.NodeDef(name='def', input=['^ghi']),
+ tf.compat.v1.NodeDef(name='ghi', input=[]),
+ ]
+ )
+ new_graph_def = graph_helpers.add_control_deps_for_init_op(graph_def, 'abc')
+ self.assertEqual(
+ ','.join(
+ '{}({})'.format(node.name, ','.join(node.input))
+ for node in new_graph_def.node
+ ),
+ (
+ 'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
+ 'bak(bar,^abc),abc(def:0),def(^ghi),ghi()'
+ ),
+ )
+
+ def test_create_tensor_map_with_sequence_binding_and_variant(self):
+ with tf.Graph().as_default():
+ variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
+ input_map = graph_helpers.create_tensor_map(
+ computation_pb2.TensorFlow.Binding(
+ sequence=computation_pb2.TensorFlow.SequenceBinding(
+ variant_tensor_name='foo'
+ )
+ ),
+ [variant_tensor],
+ )
+ self.assertLen(input_map, 1)
+ self.assertCountEqual(list(input_map.keys()), ['foo'])
+ self.assertIs(input_map['foo'], variant_tensor)
+
+ def test_create_tensor_map_with_sequence_binding_and_multiple_variants(self):
+ with tf.Graph().as_default():
+ variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
+ with self.assertRaises(ValueError):
+ graph_helpers.create_tensor_map(
+ computation_pb2.TensorFlow.Binding(
+ sequence=computation_pb2.TensorFlow.SequenceBinding(
+ variant_tensor_name='foo'
+ )
+ ),
+ [variant_tensor, variant_tensor],
+ )
+
+ def test_create_tensor_map_with_sequence_binding_and_non_variant(self):
+ with tf.Graph().as_default():
+ non_variant_tensor = tf.constant(1)
+ with self.assertRaises(TypeError):
+ graph_helpers.create_tensor_map(
+ computation_pb2.TensorFlow.Binding(
+ sequence=computation_pb2.TensorFlow.SequenceBinding(
+ variant_tensor_name='foo'
+ )
+ ),
+ [non_variant_tensor],
+ )
+
+ def test_create_tensor_map_with_non_sequence_binding_and_vars(self):
+ with tf.Graph().as_default():
+ vars_list = variable_helpers.create_vars_for_tff_type(
+ tff.to_type([('a', tf.int32), ('b', tf.int32)])
+ )
+ init_op = tf.compat.v1.global_variables_initializer()
+ assign_op = tf.group(
+ *(v.assign(tf.constant(k + 1)) for k, v in enumerate(vars_list))
+ )
+ input_map = graph_helpers.create_tensor_map(
+ computation_pb2.TensorFlow.Binding(
+ struct=computation_pb2.TensorFlow.StructBinding(
+ element=[
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='foo'
+ )
+ ),
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='bar'
+ )
+ ),
+ ]
+ )
+ ),
+ vars_list,
+ )
+ with tf.compat.v1.Session() as sess:
+ sess.run(init_op)
+ sess.run(assign_op)
+ self.assertDictEqual(sess.run(input_map), {'foo': 1, 'bar': 2})
+
+ def test_get_deps_for_graph_node(self):
+ # Creates a graph (double edges are regular dependencies, single edges are
+ # control dependencies) like this:
+ # foo
+ # // \\
+ # foo:0 foo:1
+ # || //
+ # abc bar //
+ # // \ // \\ //
+ # abc:0 bak baz
+ # ||
+ # def
+ # |
+ # ghi
+ #
+ graph_def = tf.compat.v1.GraphDef(
+ node=[
+ tf.compat.v1.NodeDef(name='foo', input=[]),
+ tf.compat.v1.NodeDef(name='bar', input=['foo:0']),
+ tf.compat.v1.NodeDef(name='baz', input=['foo:1', 'bar']),
+ tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
+ tf.compat.v1.NodeDef(name='abc', input=[]),
+ tf.compat.v1.NodeDef(name='def', input=['abc:0']),
+ tf.compat.v1.NodeDef(name='ghi', input=['^def']),
+ ]
+ )
+
+ def _get_deps(x):
+ return ','.join(
+ sorted(list(graph_helpers._get_deps_for_graph_node(graph_def, x)))
+ )
+
+ self.assertEqual(_get_deps('foo'), '')
+ self.assertEqual(_get_deps('bar'), 'foo')
+ self.assertEqual(_get_deps('baz'), 'bar,foo')
+ self.assertEqual(_get_deps('bak'), 'abc,bar,foo')
+ self.assertEqual(_get_deps('abc'), '')
+ self.assertEqual(_get_deps('def'), 'abc')
+ self.assertEqual(_get_deps('ghi'), 'abc,def')
+
+ def test_list_tensor_names_in_binding(self):
+ binding = computation_pb2.TensorFlow.Binding(
+ struct=computation_pb2.TensorFlow.StructBinding(
+ element=[
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='a'
+ )
+ ),
+ computation_pb2.TensorFlow.Binding(
+ struct=computation_pb2.TensorFlow.StructBinding(
+ element=[
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='b'
+ )
+ ),
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='c'
+ )
+ ),
+ ]
+ )
+ ),
+ computation_pb2.TensorFlow.Binding(
+ tensor=computation_pb2.TensorFlow.TensorBinding(
+ tensor_name='d'
+ )
+ ),
+ computation_pb2.TensorFlow.Binding(
+ sequence=computation_pb2.TensorFlow.SequenceBinding(
+ variant_tensor_name='e'
+ )
+ ),
+ ]
+ )
+ )
+ self.assertEqual(
+ graph_helpers._list_tensor_names_in_binding(binding),
+ ['a', 'b', 'c', 'd', 'e'],
+ )
+
+
+if __name__ == '__main__':
+ with tff.framework.get_context_stack().install(
+ tff.test.create_runtime_error_context()
+ ):
+ absltest.main()
diff --git a/fcp/artifact_building/plan_utils.py b/fcp/artifact_building/plan_utils.py
new file mode 100644
index 0000000..419d952
--- /dev/null
+++ b/fcp/artifact_building/plan_utils.py
@@ -0,0 +1,161 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities related to plan protos."""
+
+from typing import TypeVar
+
+import tensorflow as tf
+from fcp.artifact_building import tensor_utils
+from fcp.protos import plan_pb2
+
+
+_PlanT = TypeVar('_PlanT', plan_pb2.Plan, plan_pb2.ClientOnlyPlan)
+
+
+# TODO(team): Remove in favor of save_from_checkpoint_op.
+def write_checkpoint(sess, checkpoint_op, checkpoint_filename):
+ """Writes from a CheckpointOp, without executing before/after restore ops."""
+ if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
+ raise ValueError('A CheckpointOp is required.')
+ if (
+ checkpoint_op
+ and checkpoint_op.saver_def
+ and checkpoint_op.saver_def.save_tensor_name
+ ):
+ sess.run(
+ checkpoint_op.saver_def.save_tensor_name,
+ {checkpoint_op.saver_def.filename_tensor_name: checkpoint_filename},
+ )
+
+
+# TODO(team): Remove in favor of restore_from_checkpoint_op.
+def read_checkpoint(sess, checkpoint_op, checkpoint_filename):
+ """Reads from a CheckpointOp, without executing before/after restore ops."""
+ if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
+ raise ValueError('A CheckpointOp is required.')
+ if (
+ checkpoint_op
+ and checkpoint_op.saver_def
+ and checkpoint_op.saver_def.restore_op_name
+ ):
+ sess.run(
+ checkpoint_op.saver_def.restore_op_name,
+ {checkpoint_op.saver_def.filename_tensor_name: checkpoint_filename},
+ )
+
+
+def convert_graphdef_to_flatbuffer(
+ graph: tf.compat.v1.GraphDef,
+ spec: plan_pb2.TensorflowSpec,
+ guarantee_all_funcs_one_use: bool = False,
+):
+ """Converts a tf.Graph to a serialized TFLite model FlatBuffer."""
+
+ def create_input(input_tensor):
+ return (input_tensor.name, [item.size for item in input_tensor.shape.dim])
+
+ inputs = [(spec.dataset_token_tensor_name, [])]
+ for input_tensor in spec.input_tensor_specs:
+ inputs.append(create_input(input_tensor))
+ converter = tf.compat.v1.lite.TFLiteConverter(
+ graph,
+ input_tensors=None,
+ output_tensors=None,
+ input_arrays_with_shape=inputs,
+ output_arrays=[item.name for item in spec.output_tensor_specs],
+ )
+
+ # pylint: disable=protected-access
+ # Sets the control output node names. This is used when converting a tf.Graph
+ # with no output tensors.
+ converter._control_output_arrays = spec.target_node_names
+ # Set this flag to true so that flatbuffer size can be reduced.
+ converter._experimental_unfold_large_splat_constant = True
+ # Exclude conversion metadata generation to reduce conversion time.
+ converter.exclude_conversion_metadata = True
+ converter.target_spec.supported_ops = [
+ tf.lite.OpsSet.TFLITE_BUILTINS,
+ tf.lite.OpsSet.SELECT_TF_OPS,
+ ]
+ converter._experimental_allow_all_select_tf_ops = True
+ converter._experimental_guarantee_all_funcs_one_use = (
+ guarantee_all_funcs_one_use
+ )
+ # Instructs the TF Lite converter to not eliminate Assert ops, since the
+ # client code needs this op to verify result correctness.
+ converter._experimental_preserve_assert_op = True
+ # pylint: enable=protected-access
+ converter.experimental_enable_resource_variables = True
+ return converter.convert()
+
+
+def generate_and_add_flat_buffer_to_plan(
+ plan: _PlanT, forgive_tflite_conversion_failure=True
+) -> _PlanT:
+ """Generates and adds a TFLite model to the specified Plan.
+
+ Note: This method mutates the plan argument.
+
+ Args:
+ plan: An input plan_pb2.Plan object.
+ forgive_tflite_conversion_failure: If True, if TFLite conversion fails no
+ exception will be raised and the Plan will be returned unmutated.
+
+ Returns:
+ The input Plan mutated to include a TFLite model when TFLite conversion
+ succeeds, or the Plan without any mutation if TFLite conversion does not
+ succeed.
+
+ Raises:
+ RuntimeError: if TFLite conversion fails and
+ forgive_tflite_conversion_failure is set to False.
+ """
+
+ def convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
+ stateful_partitioned_call_err = (
+ "'tf.StatefulPartitionedCall' op is"
+ + ' neither a custom op nor a flex op'
+ )
+ # Pack the TFLite flatbuffer into a BytesValue proto.
+ try:
+ return convert_graphdef_to_flatbuffer(
+ graph_def, tensorflow_spec, guarantee_all_funcs_one_use
+ )
+ except Exception as e: # pylint: disable=broad-except
+ # Try to handle conversion errors and run converter again.
+ if (
+ stateful_partitioned_call_err in str(e)
+ and not guarantee_all_funcs_one_use
+ ):
+ return convert(graph_def, tensorflow_spec, True)
+ elif forgive_tflite_conversion_failure:
+ return b''
+ else:
+ raise RuntimeError(
+ f'Failure during TFLite conversion of the client graph: {str(e)}'
+ ) from e
+
+ if isinstance(plan, plan_pb2.Plan):
+ client_graph_def = tensor_utils.import_graph_def_from_any(
+ plan.client_graph_bytes
+ )
+ plan.client_tflite_graph_bytes = convert(
+ client_graph_def, plan.phase[0].client_phase.tensorflow_spec
+ )
+ elif isinstance(plan, plan_pb2.ClientOnlyPlan):
+ client_graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
+ plan.tflite_graph = convert(client_graph_def, plan.phase.tensorflow_spec)
+ else:
+ raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
+ return plan
diff --git a/fcp/artifact_building/plan_utils_test.py b/fcp/artifact_building/plan_utils_test.py
new file mode 100644
index 0000000..cd9eb8e
--- /dev/null
+++ b/fcp/artifact_building/plan_utils_test.py
@@ -0,0 +1,252 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test class for plan_utils."""
+
+import os
+
+import tensorflow as tf
+
+from fcp.artifact_building import checkpoint_utils
+from fcp.artifact_building import plan_utils
+from fcp.artifact_building import test_utils
+from fcp.protos import plan_pb2
+
+
+class PlanUtilsTest(tf.test.TestCase):
+
+ def test_write_checkpoint(self):
+ checkpoint_op = plan_pb2.CheckpointOp()
+ graph = tf.Graph()
+ with graph.as_default():
+ v = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
+ saver = checkpoint_utils.create_deterministic_saver([v])
+ test_utils.set_checkpoint_op(checkpoint_op, saver)
+ init_op = v.assign(tf.constant(2))
+ change_op = v.assign(tf.constant(3))
+
+ with tf.compat.v1.Session(graph=graph) as sess:
+ sess.run(init_op)
+ temp_file = self.create_tempfile().full_path
+ plan_utils.write_checkpoint(sess, checkpoint_op, temp_file)
+ # Change the variable in this session.
+ sess.run(change_op)
+
+ with tf.compat.v1.Session(graph=graph) as sess:
+ saver.restore(sess, temp_file)
+ # Should not see update to 3.
+ self.assertEqual(2, sess.run(v))
+
+ def test_write_checkpoint_not_checkpoint_op(self):
+ with self.assertRaises(ValueError):
+ plan_utils.write_checkpoint(None, 'not_checkpoint_op', None)
+
+ def test_write_checkpoint_skips_when_no_saver_def(self):
+ checkpoint_op = plan_pb2.CheckpointOp()
+ with tf.compat.v1.Session() as sess:
+ temp_file = self.create_tempfile().full_path
+ # Close deletes the file, we just want a good name.
+ os.remove(temp_file)
+ plan_utils.write_checkpoint(sess, checkpoint_op, temp_file)
+ self.assertFalse(os.path.isfile(temp_file))
+
+ def test_read_checkpoint(self):
+ checkpoint_op = plan_pb2.CheckpointOp()
+ graph = tf.Graph()
+ with graph.as_default():
+ v = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
+ saver = checkpoint_utils.create_deterministic_saver([v])
+ test_utils.set_checkpoint_op(checkpoint_op, saver)
+ init_op = v.assign(tf.constant(2))
+ change_op = v.assign(tf.constant(3))
+
+ with tf.compat.v1.Session(graph=graph) as sess:
+ sess.run(init_op)
+ temp_file = self.create_tempfile().full_path
+ saver.save(sess, temp_file)
+ sess.run(change_op)
+
+ plan_utils.read_checkpoint(sess, checkpoint_op, temp_file)
+ # Should not see update to 3.
+ self.assertEqual(2, sess.run(v))
+
+ def test_generate_and_add_tflite_model_to_plan(self):
+ # Create a graph for y = x ^ 2.
+ graph = tf.Graph()
+ with graph.as_default():
+ x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
+ _ = tf.math.pow(x, 2, name='y')
+ input_tensor_spec = tf.TensorSpec(
+ shape=tf.TensorShape([]), dtype=tf.int32, name='x:0'
+ ).experimental_as_proto()
+ output_tensor_spec = tf.TensorSpec(
+ shape=tf.TensorShape([]), dtype=tf.int32, name='y:0'
+ ).experimental_as_proto()
+
+ tensorflow_spec = plan_pb2.TensorflowSpec()
+ tensorflow_spec.input_tensor_specs.append(input_tensor_spec)
+ tensorflow_spec.output_tensor_specs.append(output_tensor_spec)
+
+ flatbuffer = plan_utils.convert_graphdef_to_flatbuffer(
+ graph.as_graph_def(), tensorflow_spec
+ )
+
+ interpreter = tf.lite.Interpreter(model_content=flatbuffer)
+ interpreter.allocate_tensors()
+ input_data = tf.constant(3, shape=[])
+ # Model has single output.
+ model_output = interpreter.get_output_details()[0]
+ # Model has single input.
+ model_input = interpreter.get_input_details()[0]
+ interpreter.set_tensor(model_input['index'], input_data)
+ interpreter.invoke()
+ self.assertEqual(interpreter.get_tensor(model_output['index']), 9)
+
+
+class TfLiteTest(tf.test.TestCase):
+ """Tests common methods related to TFLite support."""
+
+ def test_caught_exception_in_tflite_conversion_failure_for_plan(self):
+ plan = plan_pb2.Plan()
+ plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
+ plan.phase.add()
+ with self.assertRaisesRegex(
+ RuntimeError, 'Failure during TFLite conversion'
+ ):
+ plan_utils.generate_and_add_flat_buffer_to_plan(
+ plan, forgive_tflite_conversion_failure=False
+ )
+
+ def test_forgive_tflite_conversion_failure_for_plan(self):
+ plan = plan_pb2.Plan()
+ plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
+ plan.phase.add()
+ plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
+ plan, forgive_tflite_conversion_failure=True
+ )
+ self.assertIsInstance(plan_after_conversion, plan_pb2.Plan)
+ self.assertEmpty(plan_after_conversion.client_tflite_graph_bytes)
+
+ def test_caught_exception_in_tflite_conversion_failure_for_client_only_plan(
+ self,
+ ):
+ client_only_plan = plan_pb2.ClientOnlyPlan()
+ client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString()
+ with self.assertRaisesRegex(
+ RuntimeError, 'Failure during TFLite conversion'
+ ):
+ plan_utils.generate_and_add_flat_buffer_to_plan(
+ client_only_plan, forgive_tflite_conversion_failure=False
+ )
+
+ def test_forgive_tflite_conversion_failure_for_client_only_plan(self):
+ client_only_plan = plan_pb2.ClientOnlyPlan()
+ client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString()
+ plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
+ client_only_plan, forgive_tflite_conversion_failure=True
+ )
+ self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan)
+ self.assertEmpty(plan_after_conversion.tflite_graph)
+
+ def _create_test_graph_with_associated_tensor_specs(self):
+ # Create a graph for y = x ^ 2.
+ graph = tf.Graph()
+ with graph.as_default():
+ x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
+ _ = tf.math.pow(x, 2, name='y')
+ input_tensor_spec = tf.TensorSpec(
+ shape=tf.TensorShape([]), dtype=tf.int32, name='x:0'
+ ).experimental_as_proto()
+ output_tensor_spec = tf.TensorSpec(
+ shape=tf.TensorShape([]), dtype=tf.int32, name='y:0'
+ ).experimental_as_proto()
+ return graph, input_tensor_spec, output_tensor_spec
+
+ def _assert_tflite_flatbuffer_is_equivalent_to_test_graph(self, tflite_graph):
+ # Check that the generated TFLite model also is y = x ^ 2.
+ self.assertNotEmpty(tflite_graph)
+ interpreter = tf.lite.Interpreter(model_content=tflite_graph)
+ interpreter.allocate_tensors()
+ input_data = tf.constant(3, shape=[])
+ # Model has single output.
+ model_output = interpreter.get_output_details()[0]
+ # Model has single input.
+ model_input = interpreter.get_input_details()[0]
+ interpreter.set_tensor(model_input['index'], input_data)
+ interpreter.invoke()
+ self.assertEqual(interpreter.get_tensor(model_output['index']), 9)
+
+ def test_add_equivalent_tflite_model_to_plan(self):
+ """Tests that the generated tflite model is identical to the tf.Graph."""
+
+ graph, input_tensor_spec, output_tensor_spec = (
+ self._create_test_graph_with_associated_tensor_specs()
+ )
+
+ # Create a fairly empty Plan with just the graph and the
+ # TensorSpecProtos populated (since that is all that is needed for
+ # conversion.)
+ plan_proto = plan_pb2.Plan()
+ plan_proto.client_graph_bytes.Pack(graph.as_graph_def())
+ plan_proto.phase.add()
+ plan_proto.phase[0].client_phase.tensorflow_spec.input_tensor_specs.append(
+ input_tensor_spec
+ )
+ plan_proto.phase[0].client_phase.tensorflow_spec.output_tensor_specs.append(
+ output_tensor_spec
+ )
+
+ # Generate the TFLite model.
+ plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
+ plan_proto
+ )
+
+ self.assertIsInstance(plan_after_conversion, plan_pb2.Plan)
+ self.assertEqual(plan_after_conversion, plan_proto)
+ self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
+ plan_after_conversion.client_tflite_graph_bytes
+ )
+
+ def test_add_equivalent_tflite_model_to_client_only_plan(self):
+ """Tests that the generated tflite model is identical to the tf.Graph."""
+
+ graph, input_tensor_spec, output_tensor_spec = (
+ self._create_test_graph_with_associated_tensor_specs()
+ )
+
+ # Create a fairly empty ClientOnlyPlan with just the graph and the
+ # TensorSpecProtos populated (since that is all that is needed for
+ # conversion.)
+ client_only_plan_proto = plan_pb2.ClientOnlyPlan()
+ client_only_plan_proto.graph = graph.as_graph_def().SerializeToString()
+ client_only_plan_proto.phase.tensorflow_spec.input_tensor_specs.append(
+ input_tensor_spec
+ )
+ client_only_plan_proto.phase.tensorflow_spec.output_tensor_specs.append(
+ output_tensor_spec
+ )
+
+ # Generate the TFLite model.
+ plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
+ client_only_plan_proto
+ )
+
+ self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan)
+ self.assertEqual(plan_after_conversion, client_only_plan_proto)
+ self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
+ plan_after_conversion.tflite_graph
+ )
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/artifact_building/proto_helpers.py b/fcp/artifact_building/proto_helpers.py
new file mode 100644
index 0000000..5418191
--- /dev/null
+++ b/fcp/artifact_building/proto_helpers.py
@@ -0,0 +1,129 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper methods for proto creation logic."""
+
+from typing import Optional
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import tensor_utils
+from fcp.artifact_building import type_checks
+from fcp.protos import plan_pb2
+
+
+def make_tensor_spec_from_tensor(
+ t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
+) -> tf.TensorSpec:
+ """Creates a `TensorSpec` from Tensor w/ optional shape hint.
+
+ Args:
+ t: A `tf.Tensor` instance to be used to create a `TensorSpec`.
+ shape_hint: A `tf.TensorShape` that provides a fully defined shape in the
+ case that `t` is partially defined. If `t` has a fully defined shape,
+ `shape_hint` is ignored. `shape_hint` must be compatible with the
+ partially defined shape of `t`.
+
+ Returns:
+ A `tf.TensorSpec` instance corresponding to the input `tf.Tensor`.
+
+ Raises:
+ NotImplementedError: If the input `tf.Tensor` type is not supported.
+ TypeError: if `shape_hint` is not `None` and is incompatible with the
+ runtime shape of `t`.
+ """
+ if not tf.is_tensor(t):
+ raise NotImplementedError(
+ 'Cannot handle type {t}: {v}'.format(t=type(t), v=t)
+ )
+ derived_shape = tf.TensorShape(t.shape)
+ if not derived_shape.is_fully_defined() and shape_hint is not None:
+ if derived_shape.is_compatible_with(shape_hint):
+ shape = shape_hint
+ else:
+ raise TypeError(
+ 'shape_hint is not compatible with tensor ('
+ f'{shape_hint} vs {derived_shape})'
+ )
+ else:
+ shape = derived_shape
+ return tf.TensorSpec(shape, t.dtype, name=t.name)
+
+
+def make_measurement(
+ t: tf.Tensor, name: str, tff_type: tff.types.TensorType
+) -> plan_pb2.Measurement:
+ """Creates a `plan_pb.Measurement` descriptor for a tensor.
+
+ Args:
+ t: A tensor to create the measurement for.
+ name: The name of the measurement (e.g. 'server/loss').
+ tff_type: The `tff.Type` of the measurement.
+
+ Returns:
+ An instance of `plan_pb.Measurement`.
+
+ Raises:
+ ValueError: If the `dtype`s or `shape`s of the provided tensor and TFF type
+ do not match.
+ """
+ type_checks.check_type(tff_type, tff.types.TensorType)
+ if tff_type.dtype != t.dtype:
+ raise ValueError(
+ f'`tff_type.dtype`: {tff_type.dtype} does not match '
+ f"provided tensor's dtype: {t.dtype}."
+ )
+ if tff_type.shape.is_fully_defined() and t.shape.is_fully_defined():
+ if tff_type.shape.as_list() != t.shape.as_list():
+ raise ValueError(
+ f'`tff_type.shape`: {tff_type.shape} does not match '
+ f"provided tensor's shape: {t.shape}."
+ )
+ return plan_pb2.Measurement(
+ read_op_name=t.name,
+ name=name,
+ tff_type=tff.types.serialize_type(tff_type).SerializeToString(),
+ )
+
+
+def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
+ """Creates a `plan_pb.Metric` descriptor for a resource variable.
+
+ The stat name is formed by stripping the leading `..../` prefix and any
+ colon-based suffix.
+
+ Args:
+ v: A variable to create the metric descriptor for.
+ stat_name_prefix: The prefix (string) to use in formulating a stat name,
+ excluding the trailing slash `/` (added automatically).
+
+ Returns:
+ An instance of `plan_pb.Metric` for `v`.
+
+ Raises:
+ TypeError: If the arguments are of the wrong types.
+ ValueError: If the arguments are malformed (e.g., no leading name prefix).
+ """
+ type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
+ if not hasattr(v, 'read_value'):
+ raise TypeError('Expected a resource variable, found {!r}.'.format(type(v)))
+ bare_name = tensor_utils.bare_name(v.name)
+ if '/' not in bare_name:
+ raise ValueError(
+ 'Expected a prefix in the name, found none in {}.'.format(bare_name)
+ )
+ stat_name = '{}/{}'.format(
+ stat_name_prefix, bare_name[(bare_name.find('/') + 1) :]
+ )
+ return plan_pb2.Metric(variable_name=v.read_value().name, stat_name=stat_name)
diff --git a/fcp/artifact_building/proto_helpers_test.py b/fcp/artifact_building/proto_helpers_test.py
new file mode 100644
index 0000000..2c74aeb
--- /dev/null
+++ b/fcp/artifact_building/proto_helpers_test.py
@@ -0,0 +1,184 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for proto_helpers.py."""
+
+import collections
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import proto_helpers
+from fcp.artifact_building import variable_helpers
+
+
+class MakeMetricTest(tf.test.TestCase):
+
+ def test_make_metric(self):
+ with tf.Graph().as_default():
+ v = variable_helpers.create_vars_for_tff_type(
+ tff.to_type(collections.OrderedDict([("bar", tf.int32)])), name="foo"
+ )
+ self.assertProtoEquals(
+ "variable_name: 'Identity:0' stat_name: 'client/bar'",
+ proto_helpers.make_metric(v[0], "client"),
+ )
+
+
+class MakeTensorSpecTest(tf.test.TestCase):
+
+ def test_fully_defined_shape(self):
+ with tf.Graph().as_default():
+ test_tensor = tf.constant([[1], [2]]) # Shape [1, 2]
+ with self.subTest("no_hint"):
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
+ self.assertProtoEquals(
+ (
+ "name: 'Const:0' "
+ "shape { "
+ " dim { size: 2 } "
+ " dim { size: 1 } "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+ with self.subTest("ignored_hint"):
+ # Supplied shape hint is incompatible, but ignored because tensor is
+ # fully defined.
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
+ test_tensor, shape_hint=tf.TensorShape([1, 4])
+ )
+ self.assertProtoEquals(
+ (
+ "name: 'Const:0' "
+ "shape { "
+ " dim { size: 2 } "
+ " dim { size: 1 } "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+
+ def test_undefined_shape(self):
+ with tf.Graph().as_default():
+ # Create a undefined shape tensor via a placeholder and an op that doesn't
+ # alter shape.
+ test_tensor = tf.clip_by_value(
+ tf.compat.v1.placeholder(dtype=tf.int32), 0, 1
+ )
+ with self.subTest("no_hint"):
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
+ self.assertProtoEquals(
+ (
+ "name: 'clip_by_value:0' "
+ "shape { "
+ " unknown_rank: true "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+ with self.subTest("hint"):
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
+ test_tensor, shape_hint=tf.TensorShape([1, 4])
+ )
+ self.assertProtoEquals(
+ (
+ "name: 'clip_by_value:0' "
+ "shape { "
+ " dim { size: 1 } "
+ " dim { size: 4 } "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+
+ def test_partially_defined_shape(self):
+ with tf.Graph().as_default():
+ # Create a partially defined shape tensor via a placeholder and a reshape
+ # to specify some dimensions.
+ test_tensor = tf.reshape(
+ tf.compat.v1.placeholder(dtype=tf.int32), [2, -1]
+ )
+ with self.subTest("no_hint"):
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
+ self.assertProtoEquals(
+ (
+ "name: 'Reshape:0' "
+ "shape { "
+ " dim { size: 2 } "
+ " dim { size: -1 } "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+ with self.subTest("hint"):
+ tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
+ test_tensor, shape_hint=tf.TensorShape([2, 4])
+ )
+ self.assertProtoEquals(
+ (
+ "name: 'Reshape:0' "
+ "shape { "
+ " dim { size: 2 } "
+ " dim { size: 4} "
+ "} "
+ "dtype: DT_INT32"
+ ),
+ tensor_spec.experimental_as_proto(),
+ )
+ with self.subTest("invalid_hint"):
+ with self.assertRaises(TypeError):
+ _ = proto_helpers.make_tensor_spec_from_tensor(
+ test_tensor, shape_hint=tf.TensorShape([1, 4])
+ )
+
+
+class MakeMeasurementTest(tf.test.TestCase):
+
+ def test_succeeds(self):
+ with tf.Graph().as_default():
+ tensor = tf.constant(1)
+ tff_type = tff.types.TensorType(tensor.dtype, tensor.shape)
+ m = proto_helpers.make_measurement(
+ t=tensor, name="test", tff_type=tff_type
+ )
+
+ self.assertEqual(m.name, "test")
+ self.assertProtoEquals(
+ m.tff_type, tff.types.serialize_type(tff_type).SerializeToString()
+ )
+
+ def test_fails_for_non_matching_dtype(self):
+ with tf.Graph().as_default():
+ tensor = tf.constant(1.0)
+ tff_type = tff.types.TensorType(tf.int32, tensor.shape)
+
+ with self.assertRaisesRegex(ValueError, ".* does not match.*"):
+ proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)
+
+ def test_fails_for_non_matching_shape(self):
+ with tf.Graph().as_default():
+ tensor = tf.constant(1.0)
+ tff_type = tff.types.TensorType(tensor.dtype, shape=[5])
+
+ with self.assertRaisesRegex(ValueError, ".* does not match.*"):
+ proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/fcp/artifact_building/tensor_utils.py b/fcp/artifact_building/tensor_utils.py
new file mode 100644
index 0000000..4547492
--- /dev/null
+++ b/fcp/artifact_building/tensor_utils.py
@@ -0,0 +1,153 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities specific to the manipulation of tensors and operators."""
+
+from typing import Any, Callable, Optional, Union
+
+import tensorflow as tf
+
+
+######################################################################
+# Helper functions for names and naming.
+#
+def bare_name(v) -> str:
+ """Strips off the part after the colon in a tensor name."""
+ name = name_or_str(v)
+ if name[0] == '^':
+ name = name[1:]
+ # User specified names are everything up to the first colon. User supplied
+ # names cannot contain colons, TensorFlow will raise an error on invalid name.
+ colon = name.find(':')
+ if colon >= 0:
+ return name[:colon]
+ else:
+ return name
+
+
+def name_or_str(v) -> str:
+ """Returns the name of v, or if v has no name attr, str(op)."""
+ if hasattr(v, 'name'):
+ name = v.name
+ assert isinstance(name, str)
+ return name
+ return str(v)
+
+
+######################################################################
+# Helper function for graphs.
+#
+
+
+def import_graph_def_from_any(an) -> tf.compat.v1.GraphDef:
+ """Parses a tf.compat.v1.GraphDef from an Any message.
+
+ Args:
+ an: An 'Any' message, which contains a serialized tf.compat.v1.GraphDef. The
+ type_url field of the Any message must identify a supported type;
+ currently, the only supported type is 'type.googleapis.com/GraphDef'.
+
+ Returns:
+ A tf.compat.v1.GraphDef object.
+ """
+ assert an
+ # The only kind of supported graph is a TensorFlow GraphDef.
+ assert an.Is(tf.compat.v1.GraphDef.DESCRIPTOR)
+ g = tf.compat.v1.GraphDef()
+ an.Unpack(g)
+ return g
+
+
+######################################################################
+# Helper functions for savers or saverdefs.
+#
+
+
+def save(
+ filename: Union[tf.Tensor, str],
+ tensor_names: list[str],
+ tensors: list[tf.Tensor],
+ tensor_slices: Optional[list[str]] = None,
+ name: str = 'save',
+ save_op: Callable[..., Any] = tf.raw_ops.SaveSlices,
+) -> tf.Operation:
+ """Saves a list of tensors to file.
+
+ This function always passes a value for the `tensor_slices` argument in order
+ to use the `SaveSlices` op (instead of a `Save` op).
+
+ Args:
+ filename: A string or a scalar tensor of dtype string that specifies the
+ path to file.
+ tensor_names: A list of strings.
+ tensors: A list of tensors to be saved.
+ tensor_slices: An optional list of strings, that specifies the shape and
+ slices of a larger virtual tensor that each tensor is a part of. If not
+ specified, each tensor is saved as a full slice.
+ name: An optional name for the op.
+ save_op: A callable that creates the op(s) to use for performing the tensor
+ save. Defaults to `tf.raw_ops.SaveSlices`.
+
+ Returns:
+ A `SaveSlices` op in graph mode or None in eager mode.
+ """
+ tensor_slices = tensor_slices if tensor_slices else ([''] * len(tensors))
+ return save_op(
+ filename=filename,
+ tensor_names=tensor_names,
+ shapes_and_slices=tensor_slices,
+ data=tensors,
+ name=name,
+ )
+
+
+def restore(
+ filename: Union[tf.Tensor, str],
+ tensor_name: str,
+ tensor_type: tf.DType,
+ tensor_shape: Optional[tf.TensorShape] = None,
+ name: str = 'restore',
+) -> tf.Tensor:
+ """Restores a tensor from the file.
+
+ It is a wrapper of `tf.raw_ops.RestoreV2`. When used in graph mode, it adds a
+ `RestoreV2` op to the graph.
+
+ Args:
+ filename: A string or a scalar tensor of dtype string that specifies the
+ path to file.
+ tensor_name: The name of the tensor to restore.
+ tensor_type: The type of the tensor to restore.
+ tensor_shape: Optional. The shape of the tensor to restore.
+ name: An optional name for the op.
+
+ Returns:
+ A tensor of dtype `tensor_type`.
+ """
+ shape_str = ''
+ slice_str = ''
+ if tensor_shape is not None and tensor_shape.rank > 0:
+ shape_str = ' '.join('%d' % d for d in tensor_shape) + ' '
+ # Ideally we want to pass an empty string to slice, but this is not allowed
+ # because the size of the slice string list (after the string is split by
+ # separator ':') needs to match the rank of the tensor (see b/197779415 for
+ # more information).
+ slice_str = ':-' * tensor_shape.rank
+ restored_tensors = tf.raw_ops.RestoreV2(
+ prefix=filename,
+ tensor_names=[tensor_name],
+ shape_and_slices=[shape_str + slice_str],
+ dtypes=[tensor_type],
+ name=name,
+ )
+ return restored_tensors[0]
diff --git a/fcp/artifact_building/tensor_utils_test.py b/fcp/artifact_building/tensor_utils_test.py
new file mode 100644
index 0000000..9b4a630
--- /dev/null
+++ b/fcp/artifact_building/tensor_utils_test.py
@@ -0,0 +1,157 @@
+"""Tests for tensor_utils."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import tensorflow as tf
+
+from google.protobuf import any_pb2
+from fcp.artifact_building import tensor_utils
+
+
+class TensorUtilsTest(parameterized.TestCase, tf.test.TestCase):
+
+ def test_bare_name(self):
+ self.assertEqual(tensor_utils.bare_name('foo'), 'foo')
+ self.assertEqual(tensor_utils.bare_name('foo:0'), 'foo')
+ self.assertEqual(tensor_utils.bare_name('foo:1'), 'foo')
+ self.assertEqual(tensor_utils.bare_name('^foo:1'), 'foo')
+ self.assertEqual(tensor_utils.bare_name('^foo:output:2'), 'foo')
+ with tf.Graph().as_default() as g:
+ v = tf.Variable(0.0, name='foo')
+ self.assertEqual(tensor_utils.bare_name(v), 'foo')
+
+ @tf.function
+ def foo(x):
+ return tf.add(x, v.read_value(), 'add_op')
+
+ foo(tf.constant(1.0))
+
+ # Exchange the input tensor names (the outputs of other nodes) in the graph
+ # to ensure we can recover the original user-specified bare names.
+ graph_def = g.as_graph_def()
+ # Test that the graph def contains
+ graph_def_str = str(graph_def)
+ self.assertIn('add_op:z:0', graph_def_str)
+ self.assertIn('Read/ReadVariableOp:value:0', graph_def_str)
+ # Ensure that we can locate
+ required_names = ['add_op', 'Read/ReadVariableOp']
+ for node in graph_def.library.function[0].node_def:
+ for i in node.input:
+ if tensor_utils.bare_name(i) in required_names:
+ required_names.remove(tensor_utils.bare_name(i))
+ self.assertEmpty(required_names)
+
+ def test_bare_name_with_scope(self):
+ self.assertEqual(tensor_utils.bare_name('bar/foo:1'), 'bar/foo')
+
+ with tf.Graph().as_default():
+ with tf.compat.v1.variable_scope('bar'):
+ v = tf.Variable(0.0, name='foo')
+ self.assertEqual(tensor_utils.bare_name(v), 'bar/foo')
+
+ def test_name_or_str_with_named_variable(self):
+ with tf.Graph().as_default():
+ v = tf.Variable(0.0, name='foo')
+ self.assertEqual('foo:0', tensor_utils.name_or_str(v))
+
+ def test_name_or_str_with_unnamed_variable(self):
+ with tf.Graph().as_default():
+ v = tf.Variable(0.0)
+ self.assertEqual('Variable:0', tensor_utils.name_or_str(v))
+
+ def test_import_graph_def_from_any(self):
+ with tf.Graph().as_default() as g:
+ tf.constant(0.0)
+ graph_def = g.as_graph_def()
+ graph_def_any = any_pb2.Any()
+ graph_def_any.Pack(graph_def)
+ # Graph object doesn't have equality, so we check that the graph defs match.
+ self.assertEqual(
+ tensor_utils.import_graph_def_from_any(graph_def_any), g.as_graph_def()
+ )
+
+ def test_save_and_restore_in_eager_mode(self):
+ filename = tf.constant(self.create_tempfile().full_path)
+ tensor_name = 'a'
+ tensor = tf.constant(1.0)
+ tensor_utils.save(filename, [tensor_name], [tensor])
+ restored_tensor = tensor_utils.restore(filename, tensor_name, tensor.dtype)
+ self.assertAllEqual(tensor, restored_tensor)
+
+ @parameterized.named_parameters(
+ ('scalar_tensor', tf.constant(1.0)),
+ ('non_scalar_tensor', tf.constant([1.0, 2.0])),
+ )
+ def test_save_and_restore_with_shape_info_in_eager_mode(self, tensor):
+ filename = tf.constant(self.create_tempfile().full_path)
+ tensor_name = 'a'
+ tensor_utils.save(filename, [tensor_name], [tensor])
+ restored_tensor = tensor_utils.restore(
+ filename, tensor_name, tensor.dtype, tensor.shape
+ )
+ self.assertAllEqual(tensor, restored_tensor)
+
+ def _assert_op_in_graph(self, expected_op, graph):
+ graph_def = graph.as_graph_def()
+ node_ops = [node.op for node in graph_def.node]
+ self.assertIn(expected_op, node_ops)
+
+ def _get_shape_and_slices_value(self, graph):
+ graph_def = graph.as_graph_def()
+ node_name_to_value_dict = {node.name: node for node in graph_def.node}
+ self.assertIn('restore/shape_and_slices', node_name_to_value_dict)
+ return (
+ node_name_to_value_dict['restore/shape_and_slices']
+ .attr['value']
+ .tensor.string_val[0]
+ )
+
+ def test_save_and_restore_in_graph_mode(self):
+ temp_file = self.create_tempfile().full_path
+ graph = tf.Graph()
+ with graph.as_default():
+ filename = tf.constant(temp_file)
+ tensor_name = 'a'
+ tensor = tf.constant(1.0)
+ save_op = tensor_utils.save(filename, [tensor_name], [tensor])
+ restored = tensor_utils.restore(filename, tensor_name, tensor.dtype)
+ with tf.compat.v1.Session(graph=graph) as sess:
+ sess.run(save_op)
+ expected_tensor, restored_tensor = sess.run([tensor, restored])
+ self.assertAllEqual(expected_tensor, restored_tensor)
+ self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
+ self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
+ self.assertEqual(b'', self._get_shape_and_slices_value(graph))
+
+ @parameterized.named_parameters(
+ ('scalar_tensor', lambda: tf.constant(1.0), b''),
+ ('non_scalar_tensor', lambda: tf.constant([1.0, 2.0]), b'2 :-'),
+ )
+ def test_save_and_restore_with_shape_info_in_graph_mode(
+ self, tensor_builder, expected_shape_and_slices_value
+ ):
+ temp_file = self.create_tempfile().full_path
+ graph = tf.Graph()
+ with graph.as_default():
+ filename = tf.constant(temp_file)
+ tensor_name = 'a'
+ tensor = tensor_builder()
+ save_op = tensor_utils.save(filename, [tensor_name], [tensor])
+ restored = tensor_utils.restore(
+ filename, tensor_name, tensor.dtype, tensor.shape
+ )
+ with tf.compat.v1.Session(graph=graph) as sess:
+ sess.run(save_op)
+ expected_tensor, restored_tensor = sess.run([tensor, restored])
+ self.assertAllEqual(expected_tensor, restored_tensor)
+ self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
+ self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
+ self.assertEqual(
+ expected_shape_and_slices_value,
+ self._get_shape_and_slices_value(graph),
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/artifact_building/test_utils.py b/fcp/artifact_building/test_utils.py
new file mode 100644
index 0000000..ea600dd
--- /dev/null
+++ b/fcp/artifact_building/test_utils.py
@@ -0,0 +1,40 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities used in tests."""
+
+import tensorflow as tf
+
+from fcp.protos import plan_pb2
+
+
+def set_checkpoint_op(
+ checkpoint_op_proto: plan_pb2.CheckpointOp,
+ saver: tf.compat.v1.train.SaverDef,
+):
+ """Sets the saver_def from saver onto checkpoint_op_proto and fixes a name."""
+ if not saver:
+ return
+ saver_def_proto = checkpoint_op_proto.saver_def
+
+ saver_def_proto.CopyFrom(saver.as_saver_def())
+ # They are calling an Op a Tensor and it works in python and
+ # breaks in C++. However, for use in the python Saver class, we
+ # need the tensor because we need sess.run() to return the
+ # tensor's value. So, we only strip the ":0" in the case of
+ # plan execution, where we use the write_checkpoint and
+ # read_checkpoint methods below instead of the Saver.
+ saver_def_proto.save_tensor_name = saver_def_proto.save_tensor_name.replace(
+ ':0', ''
+ )
+ assert saver_def_proto.save_tensor_name.rfind(':') == -1
diff --git a/fcp/artifact_building/test_utils_test.py b/fcp/artifact_building/test_utils_test.py
new file mode 100644
index 0000000..70bee87
--- /dev/null
+++ b/fcp/artifact_building/test_utils_test.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for test_utils.py."""
+
+from absl.testing import absltest
+
+import tensorflow as tf
+
+from fcp.artifact_building import checkpoint_utils
+from fcp.artifact_building import test_utils
+from fcp.protos import plan_pb2
+
+
+class TestUtilsTest(absltest.TestCase):
+
+ def test_set_savepoint(self):
+ checkpoint_op = plan_pb2.CheckpointOp()
+ graph = tf.Graph()
+ with graph.as_default():
+ v = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
+ saver = checkpoint_utils.create_deterministic_saver([v])
+ test_utils.set_checkpoint_op(checkpoint_op, saver)
+ self.assertTrue(checkpoint_op.HasField('saver_def'))
+ self.assertNotIn(':', checkpoint_op.saver_def.save_tensor_name)
+
+ def test_set_savepoint_no_saver(self):
+ checkpoint_op = plan_pb2.CheckpointOp()
+ test_utils.set_checkpoint_op(checkpoint_op, None)
+ self.assertEqual(plan_pb2.CheckpointOp(), checkpoint_op)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/artifact_building/type_checks.py b/fcp/artifact_building/type_checks.py
new file mode 100644
index 0000000..022594f
--- /dev/null
+++ b/fcp/artifact_building/type_checks.py
@@ -0,0 +1,103 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper methods for doing runtime type checks."""
+
+from typing import Any, Optional, Tuple, Type, Union
+
+import tensorflow as tf
+
+
+def _format_name_for_error(name: Optional[Any]) -> str:
+ """Formats an optional object name for `check_*` error messages.
+
+ Args:
+ name: Optional name of the object being checked. If unspecified, will use a
+ placeholder object name instead.
+
+ Returns:
+ A formatted name for the object suitable for including in error messages.
+ """
+ return f'`{name}`' if name else 'argument'
+
+
+def check_type(
+ obj: Any,
+ t: Union[Type[Any], Tuple[Type[Any], ...]],
+ name: Optional[str] = None,
+) -> None:
+ """Checks if an object is an instance of a type.
+
+ Args:
+ obj: The object to check.
+ t: The type to test whether `obj` is an instance or not.
+ name: Optional name of the object being checked. Will be included in the
+ error message if specified.
+
+ Raises:
+ TypeError: If `obj` is not an instance of `t`.
+ """
+ if not isinstance(obj, t):
+ msg_name = _format_name_for_error(name)
+ raise TypeError(
+ f'Expected {msg_name} to be an instance of type {t!r}, but '
+ f'found an instance of type {type(obj)!r}.'
+ )
+
+
+def check_callable(obj: Any, name: Optional[str] = None) -> None:
+ """Checks if an object is a Python callable.
+
+ Args:
+ obj: The object to check.
+ name: Optional name of the object being checked. Will be included in the
+ error message if specified.
+
+ Raises:
+ TypeError: If `obj` is not a Python callable.
+ """
+ if not callable(obj):
+ msg_name = _format_name_for_error(name)
+ raise TypeError(
+ f'Expected {msg_name} to be callable, but found an '
+ f'instance of {type(obj)!r}.'
+ )
+
+
+def check_dataset(
+ obj: Union[
+ tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset
+ ],
+ name: Optional[str] = None,
+) -> None:
+ """Checks that the runtime type of the input is a Tensorflow Dataset.
+
+ Tensorflow has many classes which conform to the Dataset API. This method
+ checks each of the known Dataset types.
+
+ Args:
+ obj: The input object to check.
+ name: Optional name of the object being checked. Will be included in the
+ error message if specified.
+ """
+ dataset_types = (
+ tf.data.Dataset,
+ tf.compat.v1.data.Dataset,
+ tf.compat.v2.data.Dataset,
+ )
+ if not isinstance(obj, dataset_types):
+ msg_name = _format_name_for_error(name)
+ raise TypeError(
+ f'Expected {msg_name} to be a Dataset; but found an '
+ f'instance of {type(obj).__name__}.'
+ )
diff --git a/fcp/artifact_building/type_checks_test.py b/fcp/artifact_building/type_checks_test.py
new file mode 100644
index 0000000..579e60e
--- /dev/null
+++ b/fcp/artifact_building/type_checks_test.py
@@ -0,0 +1,109 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for type_checks.py."""
+
+from absl.testing import absltest
+
+import tensorflow as tf
+
+from fcp.artifact_building import type_checks
+
+
+class TestObject:
+ pass
+
+
+class TestObject2:
+ pass
+
+
+class TypeChecksTest(absltest.TestCase):
+
+ def test_check_callable_succeeds(self):
+ type_checks.check_callable(lambda: None)
+
+ def foo():
+ pass
+
+ type_checks.check_callable(foo)
+
+ class Bar:
+
+ def __call__(self):
+ pass
+
+ type_checks.check_callable(Bar())
+
+ def test_check_callable_failure(self):
+ with self.assertRaisesRegex(TypeError, 'Expected argument to be callable'):
+ type_checks.check_callable(None)
+ with self.assertRaisesRegex(TypeError, 'Expected argument to be callable'):
+ type_checks.check_callable(0)
+ with self.assertRaisesRegex(TypeError, 'Expected argument to be callable'):
+ type_checks.check_callable([])
+
+ def test_check_callable_failure_message_with_name(self):
+ with self.assertRaisesRegex(TypeError, r'\bfoo\b'):
+ type_checks.check_callable(3, name='foo')
+
+ def test_check_type_succeeds(self):
+ with self.subTest('int'):
+ type_checks.check_type(3, int)
+ type_checks.check_type(3, (int, float))
+ type_checks.check_type(3, (int, float, TestObject))
+
+ with self.subTest('custom_class'):
+ test_obj = TestObject()
+ type_checks.check_type(test_obj, object) # Also true for parent classes.
+ type_checks.check_type(test_obj, TestObject)
+ type_checks.check_type(test_obj, (object, TestObject))
+ type_checks.check_type(test_obj, (int, TestObject))
+
+ def test_check_type_fails(self):
+ with self.subTest('int'):
+ with self.assertRaises(TypeError):
+ type_checks.check_type(3, float)
+ with self.assertRaises(TypeError):
+ type_checks.check_type(3, (float, TestObject))
+
+ with self.subTest('custom_class'):
+ test_obj = TestObject()
+ with self.assertRaises(TypeError):
+ type_checks.check_type(test_obj, TestObject2)
+ with self.assertRaises(TypeError):
+ type_checks.check_type(test_obj, int)
+
+ def test_check_type_failure_message_with_name(self):
+ with self.assertRaisesRegex(TypeError, r'\bfoo\b'):
+ type_checks.check_type(3, float, name='foo')
+
+ def test_check_dataset(self):
+ # Should not raise
+ type_checks.check_dataset(tf.data.Dataset.from_tensors([42]))
+ type_checks.check_dataset(tf.compat.v1.data.Dataset.from_tensors([42]))
+ type_checks.check_dataset(tf.compat.v2.data.Dataset.from_tensors([42]))
+
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ 'Expected argument to be a Dataset; but found an instance of int.',
+ ):
+ type_checks.check_dataset(1234)
+
+ def test_check_dataset_failure_message_with_name(self):
+ with self.assertRaisesRegex(TypeError, r'\bfoo\b'):
+ type_checks.check_dataset(3, name='foo')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/artifact_building/variable_helpers.py b/fcp/artifact_building/variable_helpers.py
new file mode 100644
index 0000000..e4e5d04
--- /dev/null
+++ b/fcp/artifact_building/variable_helpers.py
@@ -0,0 +1,460 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper methods for TensorFlow variables."""
+
+from typing import Optional, Union
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import tensor_utils
+from fcp.artifact_building import type_checks
+
+# TFF types allowed for variables created at input/output serialization
+# boundaries.
+AllowedTffTypes = Union[tff.TensorType, tff.StructType, tff.FederatedType]
+
+
+# The prefix for the name of the sidechannel for a securely-summed variable.
+#
+# This transformed name is used as the name of the Op which *reads* from the
+# variable, rather than identifies the variable itself. Names with this prefix
+# are used as the keys in the `side_channel_tensors` map entries corresponding
+# with the variable of the unprefixed name.
+SIDECHANNEL_NAME_PREFIX = 'sidechannel_'
+
+# `variable_names_from_type` returns the `name` argument of `tf.Variable()`.
+# However when the variable is created, the name of its tensor is actually
+# `<name>:0`. This macro is created to match this behavior.
+_TF_TENSOR_NAME_SUFFIX = ':0'
+
+
+def _create_var_for_tff_tensor(
+ tff_type: tff.TensorType, name: str, **kwargs
+) -> tf.Variable:
+ """Creates a TensorFlow variable to hold a value of the `tff.TensorType`."""
+ type_checks.check_type(tff_type, tff.TensorType)
+ type_checks.check_type(name, str)
+ # `tff_type` can have shapes that contain `None` or `0`:
+ # * `None` shape cannot be used in `tf.zeros` to create the initial value
+ # of a `tf.Variable`. Hence, we replace it with a `0` in `tf.zeros`.
+ # * The dimension that has `0` shape may change its shape at run time. To
+ # support this, we use `None` for that dimension when creating the
+ # `tf.Variable`.
+ initial_value_shape = []
+ variable_shape = []
+ for shape in tff_type.shape.as_list():
+ if shape is None or shape == 0:
+ initial_value_shape.append(0)
+ variable_shape.append(None)
+ else:
+ initial_value_shape.append(shape)
+ variable_shape.append(shape)
+ return tf.Variable(
+ initial_value=tf.zeros(shape=initial_value_shape, dtype=tff_type.dtype),
+ name=name,
+ dtype=tff_type.dtype,
+ shape=variable_shape,
+ **kwargs,
+ )
+
+
+# Build the TensorSpec for the values we will send to the client so that the
+# client graph will know how to read the incoming values.
+def tensorspec_from_var(var: tf.Variable) -> tf.TensorSpec:
+ """Builds `tf.TensorSpec` from `tf.Variables`.
+
+ Args:
+ var: An instance of `tf.Variable`.
+
+ Returns:
+ An instance of `tf.TensorSpec` corresponding to the input `tf.Variable`.
+ """
+ return tf.TensorSpec(
+ shape=var.shape, dtype=var.dtype, name=tensor_utils.bare_name(var.name)
+ )
+
+
+def create_vars_for_tff_type(
+ tff_type: AllowedTffTypes, name: Optional[str] = None, **kwargs
+) -> list[tf.Variable]:
+ """Creates TensorFlow variables to hold a value of the given `tff_type`.
+
+ The variables are created in the default graph and scope. The variables are
+ automatically given `tf.zeros` initializers.
+
+ Args:
+ tff_type: Either a `tff.StructType`, SERVER-placed `tff.FederatedType` or a
+ `tff.TensorType` object.
+ name: The preferred name to use at the top-most level (if not None, must be
+ a string). If `tff_type` is a `tff.StructType`, the names of the inner
+ fields will be scoped under `name`, e.g. `some_name/field_name`.
+ **kwargs: Optional arguments, if any, to pass to the `tf.Variable()` calls.
+
+ Returns:
+ A flat Python `list` of TensorFlow variable instances.
+
+ Raises:
+ TypeError: If the argument is of the wrong type or has the wrong placement.
+ """
+ type_checks.check_type(
+ tff_type,
+ (tff.TensorType, tff.StructType, tff.FederatedType),
+ name='tff_type',
+ )
+ if name is not None:
+ type_checks.check_type(name, str)
+ else:
+ name = 'v'
+ if isinstance(tff_type, tff.TensorType):
+ return [_create_var_for_tff_tensor(tff_type, name, **kwargs)]
+ elif isinstance(tff_type, tff.FederatedType):
+ if tff_type.placement != tff.SERVER:
+ raise TypeError(
+ 'Can only create vars for unplaced types or types placed '
+ 'on the SERVER.'
+ )
+ return create_vars_for_tff_type(tff_type.member, name, **kwargs)
+ else: # tff.StructType
+ result = []
+ with tf.compat.v1.variable_scope(name):
+ fields = tff.structure.to_elements(tff_type)
+ for index, (field_name, field_type) in enumerate(fields):
+ # Default the name of the element to its index so that we don't wind up
+ # with multiple child fields listed under `/v/`
+ if field_name is None:
+ field_name = str(index)
+ result.extend(
+ create_vars_for_tff_type(field_type, name=field_name, **kwargs)
+ )
+ return result
+
+
+def variable_names_from_type(
+ tff_type: AllowedTffTypes, name: str = 'v'
+) -> list[str]:
+ """Creates a flattened list of variables names for the given `tff_type`.
+
+ If `tff_type` is a `tff.TensorType`, the name is the `name` parameter if
+ specified, otherwise a default name: `v`. If `tff_type` is a
+ `tff.StructType` then '/' is used between inner and outer fields together
+ with the tuple name or index of the element in the tuple.
+
+ Some examples:
+ 1. If the tff_type is `<'a'=tf.int32, 'b'=tf.int32>` and `name` is not
+ specified, the returned variable name list is ['v/a', 'v/b'].
+ 2. If the tff_type is `<tf.int32, tf.int32>` and `name` is `update`, the
+ returned variable name list is ['update/0', 'update/1'].
+ 3. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32>>` and `name` is
+ `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
+ 4. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32, tf.int32>>` and
+ `name` is `update`, the returned variable name list is ['update/a/b',
+ 'update/a/c', 'update/a/2'].
+
+ Args:
+ tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
+ `tff.TensorType` object.
+ name: The preferred name to use at the top-most level (if not None, must be
+ a string). If `tff_type` is a `tff.StructType`, the names of the inner
+ fields will be scoped under `name`, e.g. `some_name/field_name`.
+
+ Returns:
+ A flat Python `list` of `str` names.
+
+ Raises:
+ TypeError: If the argument is of the wrong type.
+ """
+ type_checks.check_type(
+ tff_type,
+ (tff.TensorType, tff.FederatedType, tff.StructType),
+ name='tff_type',
+ )
+ type_checks.check_type(name, str, name='name')
+ if isinstance(tff_type, tff.TensorType):
+ return [name]
+ elif isinstance(tff_type, tff.FederatedType):
+ return variable_names_from_type(tff_type.member, name)
+ elif isinstance(tff_type, tff.StructType):
+ result = []
+ fields = tff.structure.iter_elements(tff_type)
+ for index, (field_name, field_type) in enumerate(fields):
+ # Default the name of the element to its index so that we don't wind up
+ # with multiple child fields listed under `/v/`
+ field_name = field_name or str(index)
+ result.extend(
+ variable_names_from_type(field_type, name=name + '/' + field_name)
+ )
+ return result
+ else:
+ raise TypeError(
+ 'Cannot create variable names from [{t}] TFF type. '
+ 'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
+ )
+
+
+def get_shared_secagg_tensor_names(
+ intrinsic_name: str, tff_type: AllowedTffTypes
+) -> list[str]:
+ """Creates the shared name of secagg tensors in client and server graph.
+
+ This is the canonical function for ensuring the secagg tensor names in the
+ client and server graph are the same. The server uses secagg tensor
+ names as the keys to retrieve values from secagg server which are originally
+ from client graph, so if the secagg tensor names in the client and server
+ graph are not the same, the server could not find secagg tensors. This
+ function is created to ensure this implicit dependency.
+
+ Args:
+ intrinsic_name: The name of the secure aggregation intrinsic being used.
+ tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
+ `tff.TensorType` object.
+
+ Returns:
+ A list of variable names created from the input TFF type.
+ """
+ tensor_names = variable_names_from_type(
+ tff_type, f'secagg_{intrinsic_name}_update'
+ )
+ return [
+ SIDECHANNEL_NAME_PREFIX + name + _TF_TENSOR_NAME_SUFFIX
+ for name in tensor_names
+ ]
+
+
+def get_flattened_tensor_specs(
+ tff_type: AllowedTffTypes, name: str
+) -> list[tf.TensorSpec]:
+ """Generates TensorSpecs for a flattened version of the given `tff_type`.
+
+ This function uses the same naming logic as `variable_names_from_type`. Please
+ see that function's docstring.
+
+ Args:
+ tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
+ `tff.TensorType` object.
+ name: The preferred name to use at the top-most level (if not None, must be
+ a string). If `tff_type` is a `tff.StructType`, the names of the inner
+ fields will be scoped under `name`, e.g. `some_name/field_name`.
+
+ Returns:
+ A flat Python `list` of `TensorSpec`s.
+
+ Raises:
+ TypeError: If the argument is of the wrong type.
+ """
+ type_checks.check_type(
+ tff_type,
+ (tff.TensorType, tff.FederatedType, tff.StructType),
+ name='tff_type',
+ )
+ type_checks.check_type(name, str, name='name')
+ if isinstance(tff_type, tff.TensorType):
+ return [tf.TensorSpec(tff_type.shape, tff_type.dtype, name=name)]
+ elif isinstance(tff_type, tff.FederatedType):
+ return get_flattened_tensor_specs(tff_type.member, name)
+ elif isinstance(tff_type, tff.StructType):
+ result = []
+ fields = tff.structure.iter_elements(tff_type)
+ for index, (field_name, field_type) in enumerate(fields):
+ # Default the name of the element to its index so that we don't wind up
+ # with multiple child fields listed under `/v/`
+ field_name = field_name or str(index)
+ result.extend(
+ get_flattened_tensor_specs(field_type, name=name + '/' + field_name)
+ )
+ return result
+ else:
+ raise TypeError(
+ 'Cannot create TensorSpecs from [{t}] TFF type. Short-hand: {s}'.format(
+ t=type(tff_type), s=tff_type
+ )
+ )
+
+
+def get_grouped_input_tensor_specs_for_aggregations(
+ aggregation_comp: tff.framework.ComputationBuildingBlock,
+ names: dict[int, str],
+) -> list[list[list[tf.TensorSpec]]]:
+ """Gets the input TensorSpecs for an aggregation computation.
+
+ This function can be used to generate the TensorSpecs that are assigned to
+ ServerAggregationConfig.IntrinsicArg messages to represent the aggregation
+ intrinsic calls in DistributeAggregateForm.client_to_server_aggregation.
+
+ It derives the tensor name(s) for each intrinsic input argument by following
+ naming logic similar to `variable_names_from_type`. DistributeAggregateForm
+ does guarantee that each intrinsic input argument will be a
+ `building_block.Selection` or a (potentially nested) struct of
+ `building_block.Selection`s. The first element of the path is used to
+ determine the top-level name, which must match the top-level name that was
+ used to construct the tensor that will be getting consumed by this argument.
+
+ Args:
+ aggregation_comp: The aggregation computation.
+ names: A dictionary describing how to map the first element of the path to a
+ top-level name.
+
+ Returns:
+ A `list` where the ith entry represents the input tensor specs for the
+ ith intrinsic in the aggregation computation. The ith entry is itself a list
+ where the jth entry represents the input tensor specs for the jth argument
+ of the ith intrinsic in the aggregation computation.
+
+ Raises:
+ TypeError: If the argument is of the wrong type.
+ ValueError: If the argument contains an unexpected
+ `building_block.Selection` index.
+ """
+
+ def _get_selection_path(
+ selection: tff.framework.ComputationBuildingBlock,
+ ) -> list[int]:
+ """Gets the list of selection indices for a building_blocks.Selection."""
+
+ path = []
+ while selection.is_selection():
+ path.append(selection.index) # pytype: disable=attribute-error
+ selection = selection.source # pytype: disable=attribute-error
+ # In ASTs like x[0][1], we'll see the last (outermost) selection first.
+ path.reverse()
+ return path
+
+ def _get_input_tensor_specs_for_aggregation_arg(
+ value: tff.framework.ComputationBuildingBlock, names: dict[int, str]
+ ) -> list[tf.TensorSpec]:
+ """Gets the input TensorSpecs for a single intrinsic argument."""
+
+ # An intrinsic arg may be a `building_block.Selection` or a (potentially
+ # nested) struct of `building_block.Selection`s. Start by creating a
+ # flattened list of the `building_block.Selection`s.
+ inner_values = []
+ if value.is_struct():
+ inner_values = tff.structure.flatten(value)
+ else:
+ inner_values = [value]
+
+ # For each `building_block.Selection`, reconstruct the tensor name that
+ # will be used to supply that value. The first index of the selection path
+ # indicates whether the tensor will be coming from the intermediate state
+ # checkpoint (0) or from the client checkpoint (1), since TFF condenses
+ # daf.client_to_server_aggregation(temp_server_state, client_update)
+ # into a 1-arg function. Since the tensors within the checkpoints
+ # corresponding to temp_server_state and work_at_clients will be named using
+ # variable_names_from_type, which uses a simple filepath-like naming pattern
+ # to refer to the tensors within a struct, we can reconstruct the relevant
+ # tensor name by concatenating together the remaining indices of each
+ # selection path.
+ tensor_specs = []
+ for inner_value in inner_values:
+ inner_value.check_selection()
+ path = _get_selection_path(inner_value)
+ arg_index = path[0]
+ if arg_index in names:
+ prefix = names[arg_index]
+ else:
+ raise ValueError('Unexpected arg index for aggregation selection')
+ prefix += '/' + '/'.join([str(x) for x in path[1:]])
+ tensor_specs.extend(
+ get_flattened_tensor_specs(inner_value.type_signature, name=prefix)
+ )
+
+ return tensor_specs
+
+ grouped_input_tensor_specs = []
+
+ for _, local_value in aggregation_comp.result.locals: # pytype: disable=attribute-error
+ local_value.check_call()
+ local_value.function.check_intrinsic()
+ assert local_value.function.intrinsic_def().aggregation_kind
+
+ # Collect the input TensorFlowSpecs for each argument for this intrinsic.
+ input_tensor_specs_for_intrinsic = []
+ if (
+ local_value.function.intrinsic_def().type_signature.parameter.is_struct()
+ ):
+ for element in local_value.argument.children():
+ input_tensor_specs_for_intrinsic.append(
+ _get_input_tensor_specs_for_aggregation_arg(element, names)
+ )
+ else:
+ input_tensor_specs_for_intrinsic.append(
+ _get_input_tensor_specs_for_aggregation_arg(
+ local_value.argument, names
+ )
+ )
+
+ grouped_input_tensor_specs.append(input_tensor_specs_for_intrinsic)
+
+ return grouped_input_tensor_specs
+
+
+def get_grouped_output_tensor_specs_for_aggregations(
+ aggregation_comp: tff.framework.ComputationBuildingBlock,
+) -> list[list[tf.TensorSpec]]:
+ """Gets the output TensorSpecs for an aggregation computation.
+
+ This function can be used to generate the TensorSpecs that are assigned
+ to the output_tensors value in ServerAggregationConfig messages to represent
+ the aggregation intrinsic calls in
+ DistributeAggregateForm.client_to_server_aggregation.
+
+ It derives the tensor name(s) for each intrinsic output argument by following
+ naming logic similar to `variable_names_from_type`. It must produce tensor
+ names that match the tensor names that are expected by the post-aggregation
+ computation.
+
+ Args:
+ aggregation_comp: The aggregation computation.
+
+ Returns:
+ A list where the ith entry represents the output tensor specs for the ith
+ intrinsic in the aggregation computation.
+
+ Raises:
+ TypeError: If the argument is of the wrong type.
+ """
+ # TensorflowSpecs for all the intrinsic results. These TensorflowSpecs must
+ # have names that mirror the result of calling variable_names_from_type on
+ # the output type of DistributeAggregateForm.client_to_server_aggregation
+ # (which is the same as the type of the aggregation result input arg in
+ # DistributeAggregateForm.server_result).
+ output_tensor_specs = get_flattened_tensor_specs(
+ tff.StructType([aggregation_comp.type_signature.result]),
+ name='intermediate_update',
+ )
+ output_tensor_spec_index = 0
+
+ grouped_output_tensor_specs = []
+
+ for _, local_value in aggregation_comp.result.locals: # pytype: disable=attribute-error
+ local_value.check_call()
+ local_value.function.check_intrinsic()
+ local_value.type_signature.check_federated()
+ assert local_value.function.intrinsic_def().aggregation_kind
+
+ tensor_specs = []
+ # If the output is a struct, select the appropriate number of
+ # TensorflowSpecs.
+ if local_value.type_signature.member.is_struct():
+ num_specs = len(tff.structure.flatten(local_value.type_signature.member))
+ tensor_specs = output_tensor_specs[
+ output_tensor_spec_index : output_tensor_spec_index + num_specs
+ ]
+ output_tensor_spec_index += num_specs
+ else:
+ tensor_specs.append(output_tensor_specs[output_tensor_spec_index])
+ output_tensor_spec_index += 1
+ grouped_output_tensor_specs.append(tensor_specs)
+
+ return grouped_output_tensor_specs
diff --git a/fcp/artifact_building/variable_helpers_test.py b/fcp/artifact_building/variable_helpers_test.py
new file mode 100644
index 0000000..1857448
--- /dev/null
+++ b/fcp/artifact_building/variable_helpers_test.py
@@ -0,0 +1,328 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for variable_helpers.py."""
+
+from absl.testing import absltest
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import artifact_constants
+from fcp.artifact_building import variable_helpers
+
+
+@tff.federated_computation(
+ tff.type_at_server(tf.int32), tff.type_at_clients(tf.float32)
+)
+def sample_comp(x, y):
+ a = tff.federated_broadcast(x)
+ output1 = tff.federated_secure_sum_bitwidth(a, 5)
+ output2 = tff.federated_mean([y, y], y)
+ return output1, output2
+
+
+class VariableHelpersTest(absltest.TestCase):
+
+ def test_create_vars_for_tff_type(self):
+ with tf.Graph().as_default():
+ vl = variable_helpers.create_vars_for_tff_type(
+ tff.to_type(
+ [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
+ ),
+ 'x',
+ )
+ self.assertLen(vl, 3)
+ for v in vl:
+ self.assertTrue(type(v).__name__.endswith('Variable'))
+ self.assertEqual(v.shape.ndims, 0)
+ self.assertEqual([v.dtype for v in vl], [tf.int32, tf.bool, tf.float32])
+ self.assertEqual([v.name for v in vl], ['x/a:0', 'x/b/c:0', 'x/b/d:0'])
+
+ def test_create_vars_for_tff_type_with_none_and_zero_shape(self):
+ with tf.Graph().as_default():
+ vl = variable_helpers.create_vars_for_tff_type(
+ tff.TensorType(dtype=tf.int32, shape=[5, None, 0])
+ )
+ self.assertLen(vl, 1)
+ test_variable = vl[0]
+ self.assertEqual(test_variable.initial_value.shape.as_list(), [5, 0, 0])
+ self.assertEqual(test_variable.shape.as_list(), [5, None, None])
+
+ def test_create_vars_for_tff_federated_type(self):
+ tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER)
+ with tf.Graph().as_default():
+ vl = variable_helpers.create_vars_for_tff_type(tff_type)
+
+ self.assertLen(vl, 1)
+ v = vl[0]
+ self.assertTrue(type(v).__name__.endswith('Variable'))
+ self.assertEqual(v.shape.ndims, 0)
+ self.assertEqual(v.dtype, tf.int32)
+ self.assertEqual(v.name, 'v:0')
+
+ def test_create_vars_for_struct_of_tff_federated_types(self):
+ tff_type = tff.StructType([
+ (
+ 'num_examples_secagg',
+ tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
+ ),
+ (
+ 'num_examples_simpleagg',
+ tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
+ ),
+ ])
+ with tf.Graph().as_default():
+ vl = variable_helpers.create_vars_for_tff_type(tff_type)
+
+ self.assertLen(vl, 2)
+ for v in vl:
+ self.assertTrue(type(v).__name__.endswith('Variable'))
+ self.assertEqual(v.shape.ndims, 0)
+ self.assertEqual([v.dtype for v in vl], [tf.int32, tf.int32])
+ self.assertEqual(
+ [v.name for v in vl],
+ ['v/num_examples_secagg:0', 'v/num_examples_simpleagg:0'],
+ )
+
+ def test_create_vars_fails_for_client_placed_type(self):
+ tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS)
+ with self.assertRaisesRegex(TypeError, 'Can only create vars'):
+ with tf.Graph().as_default():
+ _ = variable_helpers.create_vars_for_tff_type(tff_type)
+
+ def test_create_vars_fails_for_struct_with_client_placed_type(self):
+ tff_type = tff.StructType([
+ (
+ 'num_examples_secagg',
+ tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
+ ),
+ (
+ 'num_examples_simpleagg',
+ tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS),
+ ),
+ ])
+ with self.assertRaisesRegex(TypeError, 'Can only create vars'):
+ with tf.Graph().as_default():
+ _ = variable_helpers.create_vars_for_tff_type(tff_type)
+
+ def test_variable_names_from_type_with_tensor_type_and_no_name(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.TensorType(dtype=tf.int32)
+ )
+ self.assertEqual(names, ['v'])
+
+ def test_variable_names_from_type_with_tensor_type(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.TensorType(dtype=tf.int32), 'test_name'
+ )
+ self.assertEqual(names, ['test_name'])
+
+ def test_variable_names_from_type_with_federated_type(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.FederatedType(tff.TensorType(dtype=tf.int32), tff.SERVER),
+ 'test_name',
+ )
+ self.assertEqual(names, ['test_name'])
+
+ def test_variable_names_from_type_with_named_tuple_type_and_no_name(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.to_type(
+ [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
+ )
+ )
+ self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])
+
+ def test_variable_names_from_type_with_named_tuple_type(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.to_type(
+ [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
+ ),
+ 'test_name',
+ )
+ self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])
+
+ def test_variable_names_from_type_with_named_tuple_type_no_name_field(self):
+ names = variable_helpers.variable_names_from_type(
+ tff.to_type([(tf.int32), ('b', [(tf.bool), ('d', tf.float32)])]),
+ 'test_name',
+ )
+ self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])
+
+ def test_get_flattened_tensor_specs_with_tensor_type(self):
+ specs = variable_helpers.get_flattened_tensor_specs(
+ tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
+ 'test_name',
+ )
+ self.assertEqual(
+ specs,
+ [
+ tf.TensorSpec(
+ name='test_name',
+ shape=tf.TensorShape([3, 5]),
+ dtype=tf.int32,
+ )
+ ],
+ )
+
+ def test_get_flattened_tensor_specs_with_federated_type(self):
+ specs = variable_helpers.get_flattened_tensor_specs(
+ tff.FederatedType(
+ tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
+ tff.SERVER,
+ ),
+ 'test_name',
+ )
+ self.assertEqual(
+ specs,
+ [
+ tf.TensorSpec(
+ name='test_name',
+ shape=tf.TensorShape([3, 5]),
+ dtype=tf.int32,
+ )
+ ],
+ )
+
+ def test_get_flattened_tensor_specs_with_tuple_type(self):
+ specs = variable_helpers.get_flattened_tensor_specs(
+ tff.StructType([
+ (
+ 'a',
+ tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
+ ),
+ (
+ 'b',
+ tff.StructType([
+ (tff.TensorType(dtype=tf.bool, shape=tf.TensorShape([4]))),
+ (
+ 'd',
+ tff.TensorType(
+ dtype=tf.float32,
+ shape=tf.TensorShape([1, 3, 5]),
+ ),
+ ),
+ ]),
+ ),
+ ]),
+ 'test_name',
+ )
+ self.assertEqual(
+ specs,
+ [
+ tf.TensorSpec(
+ name='test_name/a',
+ shape=tf.TensorShape([3, 5]),
+ dtype=tf.int32,
+ ),
+ tf.TensorSpec(
+ name='test_name/b/0',
+ shape=tf.TensorShape([4]),
+ dtype=tf.bool,
+ ),
+ tf.TensorSpec(
+ name='test_name/b/d',
+ shape=tf.TensorShape([1, 3, 5]),
+ dtype=tf.float32,
+ ),
+ ],
+ )
+
+ def test_get_grouped_input_tensor_specs_for_aggregations(self):
+ daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
+ sample_comp
+ )
+ grouped_input_tensor_specs = variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
+ daf.client_to_server_aggregation.to_building_block(),
+ artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
+ )
+ self.assertEqual(
+ grouped_input_tensor_specs,
+ [
+ [ # federated_weighted_mean intrinsic args
+ [ # federated_weighted_mean value arg
+ tf.TensorSpec(
+ name='update/0/0',
+ shape=tf.TensorShape([]),
+ dtype=tf.float32,
+ ),
+ tf.TensorSpec(
+ name='update/0/1',
+ shape=tf.TensorShape([]),
+ dtype=tf.float32,
+ ),
+ ],
+ [ # federated_weighted_mean weight arg
+ tf.TensorSpec(
+ name='update/1',
+ shape=tf.TensorShape([]),
+ dtype=tf.float32,
+ )
+ ],
+ ],
+ [ # federated_secure_sum_bitwidth intrinsic args
+ [ # federated_secure_sum_bitwidth value arg
+ tf.TensorSpec(
+ name='update/2',
+ shape=tf.TensorShape([]),
+ dtype=tf.int32,
+ )
+ ],
+ [ # federated_secure_sum_bitwidth bitwidth arg
+ tf.TensorSpec(
+ name='intermediate_state/0',
+ shape=tf.TensorShape([]),
+ dtype=tf.int32,
+ )
+ ],
+ ],
+ ],
+ )
+
+ def test_get_grouped_output_tensor_specs_for_aggregations(self):
+ daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
+ sample_comp
+ )
+ grouped_output_tensor_specs = (
+ variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
+ daf.client_to_server_aggregation.to_building_block()
+ )
+ )
+ self.assertEqual(
+ grouped_output_tensor_specs,
+ [
+ [ # federated_weighted_mean intrinsic output
+ tf.TensorSpec(
+ name='intermediate_update/0/0/0',
+ shape=tf.TensorShape([]),
+ dtype=tf.float32,
+ ),
+ tf.TensorSpec(
+ name='intermediate_update/0/0/1',
+ shape=tf.TensorShape([]),
+ dtype=tf.float32,
+ ),
+ ],
+ [ # federated_secure_sum_bitwidth intrinsic output
+ tf.TensorSpec(
+ name='intermediate_update/0/1',
+ shape=tf.TensorShape([]),
+ dtype=tf.int32,
+ )
+ ],
+ ],
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/base/BUILD b/fcp/base/BUILD
new file mode 100644
index 0000000..6f2cc37
--- /dev/null
+++ b/fcp/base/BUILD
@@ -0,0 +1,556 @@
+# Description:
+# Base component, containing common functionality used by other FCP components.
+
+load("//fcp:config.bzl", "FCP_BAREMETAL_COPTS", "FCP_COPTS")
+load("//fcp/tracing:build_defs.bzl", "tracing_schema_cc_library")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+tracing_schema_cc_library(
+ name = "tracing_schema",
+ srcs = ["tracing_schema.fbs"],
+)
+
+# Used to detect when we're building for Android, using select().
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+)
+
+cc_library(
+ name = "base",
+ srcs = [
+ "base_name.cc",
+ "monitoring.cc",
+ "platform.cc",
+ ],
+ hdrs = [
+ "base_name.h",
+ "monitoring.h",
+ "move_to_lambda.h",
+ "new.h",
+ "platform.h",
+ ],
+ copts = FCP_COPTS,
+ linkopts = select({
+ ":android": [
+ # For accessing Android's native logging APIs.
+ "-llog",
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/base:log_severity",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+# TODO(team): Consider using configuration profiles to merge "base_baremetal" with "base"
+cc_library(
+ name = "baremetal_base",
+ srcs = [
+ "base_name.cc",
+ "monitoring.cc",
+ "string_stream.cc",
+ ],
+ hdrs = [
+ "base_name.h",
+ "monitoring.h",
+ "new.h",
+ "string_stream.h",
+ ],
+ copts = FCP_BAREMETAL_COPTS,
+ features = ["-use_header_modules"],
+ linkstatic = True,
+)
+
+cc_library(
+ name = "bounds",
+ srcs = [
+ ],
+ hdrs = [
+ "bounds.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [":base"],
+)
+
+cc_test(
+ name = "bounds_test",
+ srcs = [
+ "bounds_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":bounds",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "meta",
+ srcs = [
+ ],
+ hdrs = [
+ "meta.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [":base"],
+)
+
+cc_test(
+ name = "meta_test",
+ srcs = [
+ "meta_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":meta",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "reentrancy_guard",
+ hdrs = [
+ "reentrancy_guard.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [":base"],
+)
+
+cc_test(
+ name = "reentrancy_guard_test",
+ srcs = [
+ "reentrancy_guard_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":reentrancy_guard",
+ ":scheduler",
+ "//fcp/testing",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "random_token",
+ srcs = [
+ "random_token.cc",
+ ],
+ hdrs = [
+ "random_token.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "@boringssl//:crypto",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "random_token_test",
+ srcs = [
+ "random_token_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":random_token",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/hash:hash_testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "future",
+ srcs = [
+ "future.cc",
+ ],
+ hdrs = [
+ "future.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":meta",
+ ":scheduler",
+ ":unique_value",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "future_test",
+ srcs = ["future_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":future",
+ ":meta",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "monitoring_test",
+ size = "small",
+ srcs = [
+ "monitoring_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "@com_google_absl//absl/base:log_severity",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "baremetal_monitoring_test",
+ size = "small",
+ srcs = [
+ "monitoring_test.cc",
+ ],
+ copts = FCP_COPTS,
+ local_defines = ["FCP_BAREMETAL"],
+ deps = [
+ ":baremetal_base",
+ "@com_google_absl//absl/base:log_severity",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "baremetal_string_stream_test",
+ size = "small",
+ srcs = [
+ "string_stream_test.cc",
+ ],
+ copts = FCP_COPTS,
+ local_defines = ["FCP_BAREMETAL"],
+ deps = [
+ ":baremetal_base",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "platform_test",
+ size = "small",
+ srcs = [
+ "platform_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "//fcp/testing",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "scheduler",
+ srcs = [
+ "scheduler.cc",
+ ],
+ hdrs = [
+ "scheduler.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "scheduler_test",
+ size = "small",
+ srcs = [
+ "scheduler_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":scheduler",
+ "//fcp/testing",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "move_to_lambda_test",
+ size = "small",
+ srcs = [
+ "move_to_lambda_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":unique_value",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "unique_value",
+ srcs = [
+ ],
+ hdrs = [
+ "unique_value.h",
+ ],
+ copts = FCP_COPTS,
+)
+
+cc_test(
+ name = "unique_value_test",
+ srcs = [
+ "unique_value_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ ":unique_value",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "error",
+ hdrs = [
+ "error.h",
+ ],
+ copts = FCP_COPTS,
+)
+
+cc_library(
+ name = "result",
+ srcs = ["result.cc"],
+ hdrs = [
+ "result.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":error",
+ ":meta",
+ ":source_location",
+ ":tracing_schema",
+ "//fcp/tracing",
+ ],
+)
+
+cc_library(
+ name = "status_converters",
+ srcs = ["status_converters.cc"],
+ hdrs = ["status_converters.h"],
+ deps = [
+ ":base",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/status",
+ ],
+)
+
+cc_test(
+ name = "result_test",
+ srcs = [
+ "result_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":result",
+ ":tracing_schema",
+ ":unique_value",
+ "//fcp/testing",
+ "//fcp/testing:result_matchers",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "source_location",
+ srcs = [
+ ],
+ hdrs = [
+ "source_location.h",
+ ],
+ copts = FCP_COPTS,
+)
+
+cc_test(
+ name = "source_location_test",
+ srcs = [
+ "source_location_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":source_location",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "match",
+ srcs = [
+ ],
+ hdrs = [
+ "match.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [":meta"],
+)
+
+cc_test(
+ name = "match_test",
+ srcs = [
+ "match_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":match",
+ ":result",
+ "//fcp/testing",
+ "//fcp/testing:result_matchers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "process_unique_id",
+ srcs = ["process_unique_id.cc"],
+ hdrs = ["process_unique_id.h"],
+)
+
+cc_test(
+ name = "process_unique_id_test",
+ srcs = ["process_unique_id_test.cc"],
+ deps = [
+ ":process_unique_id",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "clock",
+ srcs = [
+ "clock.cc",
+ ],
+ hdrs = [
+ "clock.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "simulated_clock",
+ testonly = 1,
+ srcs = [
+ "simulated_clock.cc",
+ ],
+ hdrs = [
+ "simulated_clock.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":clock",
+ ],
+)
+
+cc_test(
+ name = "simulated_clock_test",
+ srcs = [
+ "simulated_clock_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":simulated_clock",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "realtime_clock_test",
+ srcs = [
+ "realtime_clock_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":clock",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "time_util",
+ srcs = ["time_util.cc"],
+ hdrs = ["time_util.h"],
+ deps = [
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "time_util_test",
+ srcs = ["time_util_test.cc"],
+ deps = [
+ ":time_util",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "wall_clock_stopwatch",
+ srcs = ["wall_clock_stopwatch.cc"],
+ hdrs = ["wall_clock_stopwatch.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":base",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_test(
+ name = "wall_clock_stopwatch_test",
+ srcs = ["wall_clock_stopwatch_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":clock",
+ ":scheduler",
+ ":wall_clock_stopwatch",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/base/base_name.cc b/fcp/base/base_name.cc
new file mode 100644
index 0000000..acc3d9a
--- /dev/null
+++ b/fcp/base/base_name.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/base_name.h"
+
+#include <string.h>
+
+#include <cstdint>
+#include <string>
+
+namespace fcp {
+
+#ifdef _WIN32
+constexpr char kPathSeparator = '\\';
+#else
+constexpr char kPathSeparator = '/';
+#endif
+
+std::string BaseName(const std::string& path) {
+ // Note: the code below needs to be compatible with baremetal build with
+ // nanolibc. Therefore it is implemented via the standard "C" library strrchr.
+ const char* separator_ptr = strrchr(path.c_str(), kPathSeparator);
+ if (separator_ptr == nullptr) return path;
+
+ return path.substr((separator_ptr - path.c_str()) + 1);
+}
+
+} // namespace fcp
diff --git a/fcp/base/base_name.h b/fcp/base/base_name.h
new file mode 100644
index 0000000..c24d422
--- /dev/null
+++ b/fcp/base/base_name.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_BASE_NAME_H_
+#define FCP_BASE_BASE_NAME_H_
+
+#include <string>
+
+#ifndef FCP_BAREMETAL
+#include "absl/strings/string_view.h"
+#endif
+
+namespace fcp {
+
+/**
+ * Returns the file base name of a path.
+ */
+std::string BaseName(const std::string& path);
+
+#ifndef FCP_BAREMETAL
+inline std::string BaseName(const char* path) {
+ return BaseName(std::string(path));
+}
+
+inline std::string BaseName(absl::string_view path) {
+ return BaseName(std::string(path));
+}
+#endif
+
+} // namespace fcp
+
+#endif // FCP_BASE_BASE_NAME_H_
diff --git a/fcp/base/bounds.h b/fcp/base/bounds.h
new file mode 100644
index 0000000..6074369
--- /dev/null
+++ b/fcp/base/bounds.h
@@ -0,0 +1,180 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This file defines safe operations related to bounds checking and size
+ * calculations - robust against overflow, etc.
+ */
+
+#ifndef FCP_BASE_BOUNDS_H_
+#define FCP_BASE_BOUNDS_H_
+
+#include <limits>
+#include <type_traits>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+/**
+ * Attempts to cast 'from' to 'To'. Returns true (and sets *to) if the cast is
+ * lossless; otherwise, returns false without modifying *to.
+ *
+ * Examples:
+ * TryCastInteger<uint8_t, uint32_t>(1024, &to) => false
+ * static_cast<uint8_t>(1024) would yield 0
+ * TryCastInteger<uint8_t, uint32_t>(123, &to) => true
+ * TryCastInteger<uint32_t, int32_t>(123, &to) => true
+ * TryCastInteger<uint32_t, int32_t>(-1, &to) => false
+ */
+template <typename To, typename From>
+bool TryCastInteger(From from, To* to) {
+ static_assert(std::is_integral<To>::value && std::is_integral<From>::value,
+ "Both types must be integral");
+
+ if (std::is_signed<From>::value == std::is_signed<To>::value) {
+ // Same sign: Easy!
+ if (from < std::numeric_limits<To>::min() ||
+ from > std::numeric_limits<To>::max()) {
+ return false;
+ }
+ } else if (std::is_signed<From>::value && !std::is_signed<To>::value) {
+ // Signed => Unsigned: Widening conversion would sign-extend 'from' first;
+ // i.e. -1 would look larger than To's min(). Negative values are definitely
+ // out of range anyway. Positive values are effectively zero-extended,
+ // which is fine.
+ if (from < 0 || from > std::numeric_limits<To>::max()) {
+ return false;
+ }
+ } else {
+ // Unsigned => Signed: We don't want to mention min(), since widening
+ // conversion of min() would have the same problem as in the prior case.
+ if (from > std::numeric_limits<To>::max()) {
+ return false;
+ }
+ }
+
+ *to = static_cast<To>(from);
+ return true;
+}
+
+/**
+ * Casts from 'from' to 'To'. Check-fails if the cast is not lossless.
+ * See also: TryCastInteger
+ */
+template <typename To, typename From>
+To CastIntegerChecked(From from) {
+ To to;
+ FCP_CHECK(TryCastInteger(from, &to));
+ return to;
+}
+
+/** Multiplies without the possibility of overflow. */
+inline uint64_t SafeMultiply(uint32_t a, uint32_t b) {
+ return static_cast<uint64_t>(a) * static_cast<uint64_t>(b);
+}
+
+/**
+ * Represents an embedded address space as a pair of a starting address and a
+ * size. This is a correspondence of the addresses [0, size) <=> [start, start +
+ * size), for the embedded address space and this one, respectively.
+ *
+ * A ForeignPointer represents an address in an embedded address space (left).
+ * Given a ForeignPointer and a supposed size, one can use
+ * MapOutOfAddressSpace() to get a pointer in this address space (right),
+ * subject to bounds checking.
+ *
+ * We require that start + size does not overflow; this is convenient for bounds
+ * checks. Since start + size is the open part of the interval (one past the
+ * end), that happens to mean that ~0 cannot be in bounds (irrelevant in
+ * practice).
+ */
+struct AddressSpace {
+ void* start;
+ uint64_t size;
+
+ /**
+ * Returns a representation of the ambient, 'native' address space - it just
+ * starts at address zero and has maximum size. Note that the highest address
+ * (~0) is thus out of bounds.
+ */
+ static constexpr AddressSpace Current() {
+ return AddressSpace{nullptr, ~static_cast<uint64_t>(0)};
+ }
+
+ /**
+ * Returns an AddressSpace spanning mapping [0, size) <=> [start, start +
+ * size).
+ */
+ static AddressSpace Embedded(void* start, uint64_t size) {
+ uint64_t end;
+ FCP_CHECK(
+ !__builtin_add_overflow(reinterpret_cast<uint64_t>(start), size, &end));
+ return AddressSpace{start, size};
+ }
+};
+
+/**
+ * An address in some AddressSpace. It can be translated to a pointer in this
+ * address space with MapOutOfAddressSpace().
+ */
+struct ForeignPointer {
+ uint64_t value;
+};
+
+/**
+ * Translates a ForeignPointer out of an embedded AddressSpace, yielding a void*
+ * in this address space. The pointer is understood to refer to (size * count)
+ * bytes of memory (i.e. an array), as useful pointers tend to do.
+ *
+ * If that span does _not_ fully reside within the provided AddressSpace,
+ * returns nullptr. Otherwise, returns space.start + ptr.value.
+ *
+ * This function is intended to behave safely for arbitrary values of 'ptr',
+ * 'size', and 'count', perhaps provided by untrusted code. 'size' and 'count'
+ * are provided separately for this reason (to save the caller from worrying
+ * about multiplication overflow).
+ */
+inline void* MapOutOfAddressSpace(AddressSpace const& space, ForeignPointer ptr,
+ uint32_t size, uint32_t count) {
+ // Because the element size and count are each 32 bits, we can't overflow a 64
+ // bit total_size.
+ uint64_t total_size = SafeMultiply(size, count);
+
+ // The span in the embedded space is [ptr, ptr + total_size).
+ uint64_t ptr_end;
+ if (__builtin_add_overflow(ptr.value, total_size, &ptr_end)) {
+ return nullptr;
+ }
+
+ // The embedded address space ranges from [0, space.size). We know that ptr >=
+ // 0 and that ptr <= ptr_end, so it is sufficient to check that ptr_end <=
+ // space.size. Note that this allows ptr == space.size, iff total_size == 0.
+ if (ptr_end > space.size) {
+ return nullptr;
+ }
+
+ // AddressSpace requires that start + size does not overflow.
+ // - Since ptr_end <= space.size, space.start + ptr_end does not overflow.
+ // - Since ptr <= ptr_end, space.start + ptr does not overflow.
+ // Therefore, we can return the offset span space.start + [ptr, ptr_end).
+ return reinterpret_cast<void*>(reinterpret_cast<uint64_t>(space.start) +
+ ptr.value);
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_BOUNDS_H_
diff --git a/fcp/base/bounds_test.cc b/fcp/base/bounds_test.cc
new file mode 100644
index 0000000..639e1d7
--- /dev/null
+++ b/fcp/base/bounds_test.cc
@@ -0,0 +1,316 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/bounds.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+//
+// Now follows a bunch of support for the TryCastInteger test. That
+// implementation is very easy to get wrong, due to the sneaky conversion rules
+// imposed on us by C++ (nice example: (int32_t)-1 > (uint32_t)1).
+//
+// In fact, it's fairly easy for the same class of issues to manifest in the
+// test itself (or, even just mistaking the type of an integer literal).
+//
+// So, we approach this test like so:
+//
+// - Prevent unintended conversions in the test itself, and be explicit about
+// integer types. Recall that this tends to compile:
+// uint8_t u = -1; // Oh dear.
+//
+// But meanwhile, the rules for integer template arguments (maybe constant
+// expressions generally?) are much stricter.
+//
+// We introduce a macro to raise an integer literal into a template
+// argument. Subsequent attempts to cast it are checked at compile time with
+// the stricter rules (Clang, by default, complains under
+// -Wc++11-narrowing).
+//
+// LITERAL(-1)::As<uint8_t> // Error!
+// LITERAL(-1)::As<int8_t> // Good.
+//
+// - Be thorough: just test all the interesting combinations!
+//
+// We start by defining a set (AllIntTypes) of the common integer types to
+// test with: the signed and unsigned variants with sizes of 1, 2, 4, and 8
+// bytes.
+//
+// Then, we choose some interesting LITERALs to test: for example, -1, and
+// the min / max values for each type in AllIntTypes. For each LITERAL, we
+// write down the _expected_ set of types that can fully represent it. As a
+// shorthand, we work in terms of the _minimum_ sizes that work (a larger
+// type of the same signed-ness always works). Some examples:
+//
+// // Any signed type (i.e. >= 8 bits) can represent -1
+// VerifyIntCasts<LITERAL(-1), Signed<k8>>();
+// // The max value of int64_t also fits in uint64_t
+// VerifyIntCasts<MAX_LITERAL(int64_t), SignedOrUnsigned<k64>>();
+// // Nothing else can hold a uint64_t's max value
+// VerifyIntCasts<MAX_LITERAL(uint64_t), Unsigned<k64>>();
+//
+// (a LITERAL is safely cast to each type in its expected set at compile
+// time)
+//
+// For each literal, we test runtime casts like follows:
+//
+// for (Type From : Expected) {
+// From e = As<From>(); // Checked at compile time
+// for (Type To : AllIntTypes) {
+// bool succeeded = TryCastInteger<To, From>(...);
+// if (To in Expected) {
+// EXPECT_TRUE(succeeded);
+// ...
+// } else {
+// EXPECT_FALSE(succeeded);
+// }
+// }
+// }
+//
+
+enum IntSize { kNone, k8 = 1, k16 = 2, k32 = 4, k64 = 8 };
+
+/**
+ * The set of the integer types we care about, with a filter applied.
+ * Apply() instantiates a given template, for each integer type passing the
+ * filter. We use the filter to model the set of types that can represent a
+ * literal (each instantiation tries to compile ::As).
+ */
+template <typename FilterType>
+class IntTypeSet {
+ public:
+ template <typename F>
+ static void Apply(F f) {
+ ApplySingle<uint8_t>(f);
+ ApplySingle<uint16_t>(f);
+ ApplySingle<uint32_t>(f);
+ ApplySingle<uint64_t>(f);
+
+ ApplySingle<int8_t>(f);
+ ApplySingle<int16_t>(f);
+ ApplySingle<int32_t>(f);
+ ApplySingle<int64_t>(f);
+ }
+
+ template <typename T>
+ static constexpr bool Matches() {
+ return FilterType::template Matches<T>();
+ }
+
+ private:
+ template <bool B>
+ using BoolTag = std::integral_constant<bool, B>;
+
+ template <typename T, typename F>
+ static void ApplySingle(F f) {
+ ApplySingleImpl<T>(f, BoolTag<Matches<T>()>{});
+ }
+
+ template <typename T, typename F>
+ static void ApplySingleImpl(F f, BoolTag<true>) {
+ f.template Apply<T>();
+ }
+
+ template <typename T, typename F>
+ static void ApplySingleImpl(F f, BoolTag<false>) {}
+};
+
+struct NoFilter {
+ template <typename T>
+ static constexpr bool Matches() {
+ return true;
+ }
+};
+
+/**
+ * The filter type we use per literal. It's sufficient to give a minimum size,
+ * separately per signed / unsigned.
+ */
+template <IntSize MinSignedSize, IntSize MinUnsignedSize>
+struct IntSizeFilter {
+ template <typename T>
+ static constexpr bool Matches() {
+ return SizeRequiredForType<T>() != IntSize::kNone &&
+ sizeof(T) >= SizeRequiredForType<T>();
+ }
+
+ template <typename T>
+ static constexpr IntSize SizeRequiredForType() {
+ return std::is_signed<T>() ? MinSignedSize : MinUnsignedSize;
+ }
+};
+
+using AllIntTypes = IntTypeSet<NoFilter>;
+
+template <IntSize MinSignedSize>
+using Signed = IntTypeSet<IntSizeFilter<MinSignedSize, kNone>>;
+template <IntSize MinUnsignedSize>
+using Unsigned = IntTypeSet<IntSizeFilter<kNone, MinUnsignedSize>>;
+template <IntSize MinSignedSize, IntSize MinUnsignedSize = MinSignedSize>
+using SignedOrUnsigned =
+ IntTypeSet<IntSizeFilter<MinSignedSize, MinUnsignedSize>>;
+
+template <typename T, T Value_>
+struct Literal {
+ template <typename R>
+ using As = Literal<R, Value_>;
+
+ static constexpr T Value() { return Value_; }
+};
+
+/**
+ * This is the per-literal test as described at the top of the file -
+ * but uglier.
+ */
+template <typename LiteralType, typename SetType>
+struct VerifyCastFromEachInSetToAll {
+ // Outer loop body: called for each type in the literal's 'expected' set.
+ template <typename FromType>
+ void Apply() {
+ AllIntTypes::Apply(ForAll<FromType>{});
+ }
+
+ template <typename FromType>
+ struct ForAll {
+ static constexpr FromType From() {
+ return LiteralType::template As<FromType>::Value();
+ }
+
+ // Inner loop body: called for all integer types.
+ template <typename ToType>
+ void Apply() {
+ ToType actual;
+ bool succeeded = TryCastInteger(From(), &actual);
+ if (SetType::template Matches<ToType>()) {
+ EXPECT_TRUE(succeeded);
+ if (succeeded) {
+ EXPECT_THAT(actual, Eq(static_cast<ToType>(From())));
+ EXPECT_THAT(static_cast<FromType>(actual), Eq(From()));
+ }
+ } else {
+ EXPECT_FALSE(succeeded);
+ }
+ }
+ };
+};
+
+template <typename LiteralType, typename SetType>
+void VerifyIntCasts() {
+ SetType::Apply(VerifyCastFromEachInSetToAll<LiteralType, SetType>{});
+}
+
+#define LITERAL(i) Literal<decltype(i), i>
+#define MAX_LITERAL(t) Literal<t, std::numeric_limits<t>::max()>
+#define MIN_LITERAL(t) Literal<t, std::numeric_limits<t>::min()>
+
+TEST(BoundsTest, TryCastInteger) {
+ VerifyIntCasts<LITERAL(-1), Signed<k8>>();
+ VerifyIntCasts<LITERAL(0), SignedOrUnsigned<k8>>();
+
+ VerifyIntCasts<MAX_LITERAL(int8_t), SignedOrUnsigned<k8>>();
+ VerifyIntCasts<MAX_LITERAL(int16_t), SignedOrUnsigned<k16>>();
+ VerifyIntCasts<MAX_LITERAL(int32_t), SignedOrUnsigned<k32>>();
+ VerifyIntCasts<MAX_LITERAL(int64_t), SignedOrUnsigned<k64>>();
+
+ VerifyIntCasts<MAX_LITERAL(uint8_t), SignedOrUnsigned<k16, k8>>();
+ VerifyIntCasts<MAX_LITERAL(uint16_t), SignedOrUnsigned<k32, k16>>();
+ VerifyIntCasts<MAX_LITERAL(uint32_t), SignedOrUnsigned<k64, k32>>();
+ VerifyIntCasts<MAX_LITERAL(uint64_t), Unsigned<k64>>();
+
+ VerifyIntCasts<MIN_LITERAL(int8_t), Signed<k8>>();
+ VerifyIntCasts<MIN_LITERAL(int16_t), Signed<k16>>();
+ VerifyIntCasts<MIN_LITERAL(int32_t), Signed<k32>>();
+ VerifyIntCasts<MIN_LITERAL(int64_t), Signed<k64>>();
+
+ VerifyIntCasts<MIN_LITERAL(uint8_t), SignedOrUnsigned<k8>>();
+ VerifyIntCasts<MIN_LITERAL(uint16_t), SignedOrUnsigned<k8>>();
+ VerifyIntCasts<MIN_LITERAL(uint32_t), SignedOrUnsigned<k8>>();
+ VerifyIntCasts<MIN_LITERAL(uint64_t), SignedOrUnsigned<k8>>();
+}
+
+//
+// End of the TryCastInteger test
+//
+
+AddressSpace MakeFakeAddressSpace(uint64_t start, uint64_t size) {
+ return AddressSpace::Embedded(reinterpret_cast<void*>(start), size);
+}
+
+MATCHER_P(IsAddress, addr, "") {
+ return reinterpret_cast<uint64_t>(arg) == addr;
+}
+
+TEST(BoundsTest, MapOutOfAddressSpace_Success) {
+ AddressSpace space = MakeFakeAddressSpace(128, 128);
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{0}, 8, 1),
+ IsAddress(128));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{0}, 8, 128 / 8),
+ IsAddress(128));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{127}, 1, 1),
+ IsAddress(128 + 127));
+}
+
+TEST(BoundsTest, MapOutOfAddressSpace_OutOfBounds) {
+ AddressSpace space = MakeFakeAddressSpace(128, 128);
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{128}, 8, 1),
+ Eq(nullptr));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{8}, 8, 128 / 8),
+ Eq(nullptr));
+}
+
+TEST(BoundsTest, MapOutOfAddressSpace_ZeroSizeEdge) {
+ AddressSpace space = MakeFakeAddressSpace(128, 128);
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{128}, 1, 0),
+ IsAddress(128 + 128));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{128}, 1, 1),
+ Eq(nullptr));
+}
+
+TEST(BoundsTest, MapOutOfAddressSpace_HighAddress) {
+ constexpr uint64_t kMax = std::numeric_limits<uint64_t>::max();
+ // Note that kMax is *not* a valid address; AddressSpace requires that 'one
+ // past the end' is <= kMax.
+ AddressSpace space = MakeFakeAddressSpace(kMax - 128, 128);
+
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{0}, 8, 128 / 8),
+ IsAddress(kMax - 128));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{127}, 1, 1),
+ IsAddress(kMax - 1));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{128}, 1, 0),
+ IsAddress(kMax));
+
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{128}, 1, 1),
+ Eq(nullptr));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{0}, 1, 129),
+ Eq(nullptr));
+}
+
+TEST(BoundsTest, MapOutOfAddressSpace_Overflow) {
+ constexpr uint64_t kMax = std::numeric_limits<uint64_t>::max();
+ AddressSpace space = MakeFakeAddressSpace(0, kMax);
+
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{kMax - 1}, 1, 1),
+ IsAddress(kMax - 1));
+ EXPECT_THAT(MapOutOfAddressSpace(space, ForeignPointer{kMax - 1}, 1, 2),
+ Eq(nullptr));
+}
+
+} // namespace fcp
diff --git a/fcp/base/clock.cc b/fcp/base/clock.cc
new file mode 100644
index 0000000..c5f83cc
--- /dev/null
+++ b/fcp/base/clock.cc
@@ -0,0 +1,225 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/clock.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <thread> // NOLINT(build/c++11)
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+// Implements global realtime clock that uses timers to schedule wake-up of
+// waiters.
+class RealTimeClock : public Clock {
+ public:
+ RealTimeClock();
+ ~RealTimeClock() override {
+ FCP_LOG(FATAL) << "RealTimeClock should never be destroyed";
+ }
+
+ // Returns the current time.
+ absl::Time Now() override { return absl::Now(); }
+ absl::Time NowLocked() override { return absl::Now(); }
+
+ // Sleeps for the specified duration.
+ void Sleep(absl::Duration d) override { absl::SleepFor(d); }
+
+ // Schedules wakeup at the specified wakeup_time.
+ void ScheduleWakeup(absl::Time wakeup_time) override;
+
+ private:
+ // The worker loop that performs the sleep and dispatches wake-up calls.
+ void WorkerLoop();
+
+ // The currently scheduled wake-up time. There is at most one wake-up
+ // time per process.
+ absl::Time next_wakeup_ ABSL_GUARDED_BY(&wakeup_mu_) = absl::InfiniteFuture();
+ // Mutex that protects next_wakeup and used with the wake-up CondVar.
+ absl::Mutex wakeup_mu_;
+ // CondVar used to sleep until the next wake-up deadline.
+ absl::CondVar wakeup_condvar_;
+ // Worker thread that runs the worker loop function. Since this class
+ // is singleton, there is only one thread per process, and that thread is
+ // never terminated.
+ std::unique_ptr<std::thread> worker_thread_;
+};
+
+Clock* Clock::RealClock() {
+ static Clock* realtime_clock = new RealTimeClock;
+ return realtime_clock;
+}
+
+void Clock::WakeupWithDeadline(absl::Time deadline,
+ const std::shared_ptr<Clock::Waiter>& waiter) {
+ // Insert the new waiter into the map ordered by its deadline.
+ // Waiters with matching deadlines are inserted into the same bucket and
+ // their order within the bucket is preserved.
+ {
+ absl::MutexLock lock(mutex());
+ WaiterMap::iterator it;
+ if ((it = pending_waiters_.find(deadline)) != pending_waiters_.end()) {
+ it->second.push_back(waiter);
+ } else {
+ pending_waiters_.insert(std::make_pair(deadline, WaiterList{waiter}));
+ }
+ }
+
+ // Inserting a new waiter may trigger an immediate wake-up if the deadline
+ // is due. Otherwise a new wake-up is scheduled at the end on the dispatch.
+ DispatchWakeups();
+}
+
+// DispatchWakeup performs the following actions in the loop:
+// - Check for reentrancy to avoid more than one concurrent dispatch loop
+// - Take out all waiters that are due
+// - Make WakeUp calls on all of those waiters. This step is done outside
+// of lock because it may potentially take longer time and new waiters may
+// potentially be inserted during that step.
+// - If there are any waiters that are still due at that point (because the
+// the previous step took too long and new waiters have expired or because
+// there were any new waiters inserted during the previous steps), loop
+// back and repeat the previous steps.
+// - Otherwise finish the dispatch by scheduling a new wakeup for the bucket
+// that expires the soonest.
+void Clock::DispatchWakeups() {
+ do {
+ if (CheckReentrancy()) {
+ // Avoid reentrancy. An ongoing DispatchWakeups() call will take care
+ // of dispatching any new due wakeups if necessary.
+ // If there is a race condition, only one of dispatch calls will go
+ // through and all other will just increment the dispatch_level and
+ // return.
+ return;
+ }
+
+ // Collect waiters that are due.
+ WaiterList wakeup_calls = GetExpiredWaiters();
+
+ // Dispatch WakeUp calls to those waiters.
+ for (const auto& waiter : wakeup_calls) {
+ waiter->WakeUp();
+ }
+ // One more dispatch loop may be needed if there were any reentrant calls
+ // or if WakeUp() calls took so long that more waiters have become due.
+ } while (!FinishDispatchAndScheduleNextWakeup());
+}
+
+// Called at the beginning of dispatch loop.
+// Increments dispatch_level_ and returns true if there is already
+// another dispatch call in progress.
+bool Clock::CheckReentrancy() {
+ absl::MutexLock lock(mutex());
+ return ++dispatch_level_ > 1;
+}
+
+// Iterate through waiter buckets ordered by deadline time and take out all
+// waiters that are due.
+Clock::WaiterList Clock::GetExpiredWaiters() {
+ absl::MutexLock lock(mutex());
+ absl::Time now = NowLocked();
+ std::vector<std::shared_ptr<Waiter>> wakeup_calls;
+ WaiterMap::iterator iter;
+
+ while ((iter = pending_waiters_.begin()) != pending_waiters_.end() &&
+ iter->first <= now) {
+ std::move(iter->second.begin(), iter->second.end(),
+ std::back_inserter(wakeup_calls));
+ pending_waiters_.erase(iter);
+ }
+ return wakeup_calls;
+}
+
+// Called at the end of dispatch loop to check post-dispatch conditions,
+// reset re-entracy level, and schedule a next dispatch if needed.
+// Returns true if the dispatch loop has ended.
+// Returns false if more the dispatch loop needs to be repeated.
+bool Clock::FinishDispatchAndScheduleNextWakeup() {
+ absl::MutexLock lock(mutex());
+ int current_dispatch_level = dispatch_level_;
+ dispatch_level_ = 0;
+
+ if (!pending_waiters_.empty()) {
+ if (current_dispatch_level > 1) {
+ // There was another dispatch call while this one was in progress.
+ // One more dispatch loop is needed.
+ return false;
+ }
+
+ absl::Time next_wakeup = pending_waiters_.begin()->first;
+ if (next_wakeup <= NowLocked()) {
+ // One more dispatch loop is needed because a new waiter has become due
+ // while the wake-ups were dispatched.
+ return false;
+ }
+
+ // Schedule DispatchWakeups() to be called at a future next_wakeup time.
+ ScheduleWakeup(next_wakeup);
+ }
+
+ return true;
+}
+
+RealTimeClock::RealTimeClock() {
+ worker_thread_ =
+ std::make_unique<std::thread>([this] { this->WorkerLoop(); });
+}
+
+void RealTimeClock::WorkerLoop() {
+ for (;;) {
+ bool dispatch = false;
+
+ {
+ absl::MutexLock lock(&wakeup_mu_);
+ wakeup_condvar_.WaitWithDeadline(&wakeup_mu_, next_wakeup_);
+ if (Now() >= next_wakeup_) {
+ dispatch = true;
+ next_wakeup_ = absl::InfiniteFuture();
+ }
+ }
+
+ if (dispatch) {
+ DispatchWakeups();
+ }
+ }
+}
+
+// RealTimeClock implementation of ScheduleWakeup.
+void RealTimeClock::ScheduleWakeup(absl::Time wakeup_time) {
+ absl::MutexLock lock(&wakeup_mu_);
+
+ // Optimization: round wakeup_time up to whole milliseconds.
+ wakeup_time = absl::FromUDate(ceil(absl::ToUDate(wakeup_time)));
+
+ // ScheduleWakeup may be called repeatedly with the same time if a new timer
+ // is created in the future after already existing timer. In that case
+ // this function continues relying on already scheduled wake-up time.
+ // A new ScheduleWakeup call will be made from within DispatchWakeups() once
+ // the current timer expires.
+ if (wakeup_time == next_wakeup_) {
+ return;
+ }
+
+ next_wakeup_ = wakeup_time;
+ wakeup_condvar_.Signal();
+}
+
+} // namespace fcp
diff --git a/fcp/base/clock.h b/fcp/base/clock.h
new file mode 100644
index 0000000..66742a9
--- /dev/null
+++ b/fcp/base/clock.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_CLOCK_H_
+#define FCP_BASE_CLOCK_H_
+
+#include <map>
+#include <memory>
+#include <optional>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+
+namespace fcp {
+
+/*
+ * Clock is an abstract class representing a Clock, which is an object that can
+ * tell you the current time and schedule a wakeable event in a future.
+ */
+class Clock {
+ public:
+ // Returns a pointer to the global realtime clock. The caller does not
+ // own the returned pointer and should not delete it. The returned clock
+ // is thread-safe.
+ static Clock* RealClock();
+
+ virtual ~Clock() = default;
+
+ // Returns current time.
+ virtual absl::Time Now() = 0;
+
+ // Sleeps for the specified duration.
+ virtual void Sleep(absl::Duration d) = 0;
+
+ // An abstract interface for a waiter class that is passed to
+ // WakeupWithDeadline and is responsible for handling a timer wake-up.
+ // Waiter interface doesn't support a cancellation mechanism which means
+ //
+ // Note: it is up to Waiter implementation how to handle a cancellation. Clock
+ // itself doesn't manage cancellation and will call WakeUp() on all all alarms
+ // once their deadline time is past due.
+ class Waiter {
+ public:
+ virtual ~Waiter() = default;
+ // A callback method that is called when the corresponding deadline is
+ // reached. This method may be called on an arbitrary thread.
+ virtual void WakeUp() = 0;
+ };
+
+ // Schedule the waiter to be waked up at the specified deadline.
+ void WakeupWithDeadline(absl::Time deadline,
+ const std::shared_ptr<Waiter>& waiter);
+
+ protected:
+ // Accessors shared for derived clases.
+ absl::Mutex* mutex() { return &mu_; }
+
+ // Internal version of now which is called under mutex.
+ virtual absl::Time NowLocked()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex()) = 0;
+
+ // Overloaded by derived class to implement the actual scheduling.
+ virtual void ScheduleWakeup(absl::Time wakeup_time)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex()) = 0;
+
+ // Called to dispatch wakeup to all due waiters.
+ void DispatchWakeups();
+
+ private:
+ using WaiterList = std::vector<std::shared_ptr<Waiter>>;
+ using WaiterMap = std::map<absl::Time, WaiterList>;
+
+ bool CheckReentrancy();
+ WaiterList GetExpiredWaiters();
+ bool FinishDispatchAndScheduleNextWakeup();
+
+ // Mutex that protects the internal state.
+ absl::Mutex mu_;
+ // Pending (unexpired) waiters ordered by deadline - soonest to latest.
+ // Waiters with exactly the same deadline are stored in the same bucket and
+ // the order at which they were added is preserved.
+ WaiterMap pending_waiters_ ABSL_GUARDED_BY(mutex());
+ // This value =0 when no DispatchWakeups() is running;
+ // =1 when DispatchWakeups() is running
+ // >1 when at least one additional DispatchWakeups() call happened
+ // while DispatchWakeups() was running, for example from
+ // a timer elapsing and triggering a wake-up.
+ int dispatch_level_ ABSL_GUARDED_BY(mutex()) = 0;
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_CLOCK_H_
diff --git a/fcp/base/error.h b/fcp/base/error.h
new file mode 100644
index 0000000..99784da
--- /dev/null
+++ b/fcp/base/error.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_ERROR_H_
+#define FCP_BASE_ERROR_H_
+
+namespace fcp {
+
+// An Error indicates that something went wrong. Errors carry data, but are
+// opaque to most code; error data is intended to support diagnostics, but not
+// program behavior.
+//
+// Code should be written such that values are turned into opaque Errors only
+// after they are judged to be problematic (something went wrong). If one finds
+// the need to inspect the "reason" for an Error, some earlier (non-error) value
+// should be used instead.
+//
+// Errors are typically returned inside of a Result.
+class Error {
+ // TODO(team): Make it possible to get a stack trace from a tracing span
+ // and save this in the error.
+ public:
+ class ConstructorAccess {
+ template <class FlatBufferTable, class... Arg>
+ friend Error TraceError(Arg&&... args);
+
+ private:
+ constexpr ConstructorAccess() = default;
+ constexpr ConstructorAccess(const ConstructorAccess&) = default;
+ };
+
+ // Error is copyable but not trivially constructible.
+ // Use global TraceError() function to construct an Error.
+ explicit constexpr Error(ConstructorAccess) {}
+ Error(const Error&) = default;
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_ERROR_H_
diff --git a/fcp/base/future.cc b/fcp/base/future.cc
new file mode 100644
index 0000000..5355981
--- /dev/null
+++ b/fcp/base/future.cc
@@ -0,0 +1,15 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/base/future.h"
diff --git a/fcp/base/future.h b/fcp/base/future.h
new file mode 100644
index 0000000..f464f49
--- /dev/null
+++ b/fcp/base/future.h
@@ -0,0 +1,290 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This file provides a pair of types Future<T> (a value to wait for) and
+ * Promise<T> (allows providing the value for an associated future).
+ *
+ * These serve the same purpose as std::future and std::promise, but with a few
+ * differences:
+ * - They do not represent exceptions (i.e. std::promise::set_exception).
+ * Consider representing failure conditions with StatusOr or std::variant
+ * - They do not throw future-related exceptions (e.g. std::future::get throws
+ * if the promise was 'abandoned'; this one indicates that with a value).
+ * - There is no integration with std::async etc.
+ * - They use absl::Duration / absl::Time for waiting with a timeout.
+ * - They are created as a pair (vs. std::promise::get_future(), which throws
+ * an exception if called twice).
+ * - Setting (promise) and taking (future) require rvalues (you might need to
+ * use std::move). This is to indicate that these are 'consuming' operations
+ * (to humans and static analysis tools).
+ */
+
+#ifndef FCP_BASE_FUTURE_H_
+#define FCP_BASE_FUTURE_H_
+
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <variant>
+
+#include "absl/base/macros.h"
+#include "absl/synchronization/notification.h"
+#include "fcp/base/meta.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/move_to_lambda.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/base/unique_value.h"
+
+namespace fcp {
+
+// Since fcp::Promise is already defined by the reactive streams library
+// (fcp/reactive/), we'll define fcp::thread::{Promise, Future}.
+namespace thread {
+
+// Forward declarations; see doc comments below
+template <typename T>
+class Future;
+template <typename T>
+class Promise;
+
+template <typename T>
+struct FuturePair {
+ Promise<T> promise;
+ Future<T> future;
+};
+
+namespace future_internal {
+// We want Promise and Future to be created only as a pair, with MakeFuture.
+// This type is given permission to construct them.
+struct Maker {
+ template <typename T>
+ static FuturePair<T> Make();
+};
+
+// Common state of a Promise / Future pair. Destructed when *both* the promise
+// and future are gone.
+//
+// States: NotSet, Set, Taken
+// Transitions:
+// NotSet -> Set: When a value is provided (std::nullopt indicates an
+// abandoned promise). *Before* ready_ is signalled.
+// Set -> Taken: When a future takes a value. *After* ready_ is signalled.
+template <typename T>
+class FutureState {
+ public:
+ bool Wait(absl::Duration timeout) const;
+ std::optional<T> Take();
+ void Set(std::optional<T> val);
+
+ private:
+ enum class State { kNotSet, kSet, kTaken };
+
+ absl::Notification ready_;
+ State state_ = State::kNotSet;
+ std::optional<T> value_;
+};
+
+// A Future and Promise share a single FutureState. That is, FutureState
+// is ref-counted, with two initial refs (no additional refs can be created,
+// since Future and Promise are move-only). So, we define FutureStateRef as a
+// move-only std::shared_ptr.
+template <typename T>
+using FutureStateRef = UniqueValue<std::shared_ptr<FutureState<T>>>;
+} // namespace future_internal
+
+/**
+ * Allows waiting for and retrieving a value (provided eventually by a paired
+ * Promise).
+ *
+ * If the paired Promise is 'abandoned' (destructed without having a value set),
+ * then the Future's value is std::nullopt.
+ */
+template <typename T>
+class Future {
+ public:
+ Future(Future&&) = default;
+ Future& operator=(Future&&) = default;
+
+ /**
+ * Retrieves the future value, waiting until it is available.
+ * Taking from a future *consumes* it, and so requires an rvalue. To take
+ * from a Future<T> f:
+ * std::move(f).Take()
+ *
+ * If the paired promise is 'abandoned' (destructed before a real value is
+ * provided), the value is std::nullopt.
+ */
+ ABSL_MUST_USE_RESULT
+ std::optional<T> Take() && {
+ future_internal::FutureStateRef<T> state = std::move(state_);
+ FCP_CHECK(state.has_value());
+ return (*state)->Take();
+ }
+
+ /**
+ * Waits for the value to become available, with a timeout. Unlike Take(),
+ * this does *not* consume the value.
+ *
+ * Returns a bool indicating if the value is available (if so, Take() will
+ * return immediately).
+ */
+ ABSL_MUST_USE_RESULT
+ bool Wait(absl::Duration timeout) const {
+ FCP_CHECK(state_.has_value());
+ return (*state_)->Wait(timeout);
+ }
+
+ private:
+ friend struct future_internal::Maker;
+
+ explicit Future(future_internal::FutureStateRef<T> state)
+ : state_(std::move(state)) {}
+
+ future_internal::FutureStateRef<T> state_;
+};
+
+/**
+ * Allows providing a value to satisfy a paired Future.
+ *
+ * If this Promise is 'abandoned' (destructed without having a value set),
+ * then the Future gets the value std::nullopt.
+ */
+template <typename T>
+class Promise {
+ public:
+ Promise(Promise&&) = default;
+ Promise& operator=(Promise&&) = default;
+
+ ~Promise() {
+ if (state_.has_value()) {
+ // Abandoned
+ (*state_)->Set(std::nullopt);
+ }
+ }
+
+ /**
+ * Provides a value to the paired Future. Setting a promise *consumes* it,
+ * and so requires an rvalue. To set a Promise<T> p:
+ * std::move(p).Set(...)
+ */
+ void Set(T value) && {
+ future_internal::FutureStateRef<T> state = std::move(state_);
+ FCP_CHECK(state.has_value());
+ (*state)->Set(std::move(value));
+ }
+
+ private:
+ friend struct future_internal::Maker;
+
+ explicit Promise(future_internal::FutureStateRef<T> state)
+ : state_(std::move(state)) {}
+
+ future_internal::FutureStateRef<T> state_;
+};
+
+/** Creates a paired Future and Promise. */
+template <typename T>
+FuturePair<T> MakeFuture() {
+ return future_internal::Maker::Make<T>();
+}
+
+/**
+ * Schedules a task which calls a function computing a value. Returns a future
+ * to wait for and access the value once it is computed.
+ */
+template <typename T>
+Future<T> ScheduleFuture(Scheduler* scheduler, std::function<T()> func) {
+ thread::FuturePair<T> p = thread::MakeFuture<T>();
+ MoveToLambdaWrapper<thread::Promise<T>> promise_capture =
+ MoveToLambda(std::move(p.promise));
+ // Lambda is stateful (since the promise is consumed). This is okay, since
+ // it should only be called once.
+ scheduler->Schedule([promise_capture, func]() mutable {
+ std::move(*promise_capture).Set(func());
+ });
+
+ return std::move(p.future);
+}
+
+namespace future_internal {
+
+template <typename T>
+FuturePair<T> Maker::Make() {
+ std::shared_ptr<FutureState<T>> state = std::make_shared<FutureState<T>>();
+
+ auto promise_ref = FutureStateRef<T>(state);
+ // Note that we use std::move this time, to avoid ref-count churn.
+ auto future_ref = FutureStateRef<T>(std::move(state));
+ return {Promise<T>(std::move(promise_ref)), Future<T>(std::move(future_ref))};
+}
+
+template <typename T>
+bool FutureState<T>::Wait(absl::Duration timeout) const {
+ return ready_.WaitForNotificationWithTimeout(timeout);
+}
+
+template <typename T>
+void FutureState<T>::Set(std::optional<T> val) {
+ FCP_CHECK(!ready_.HasBeenNotified())
+ << "Attempted to set a FutureState which has already been notified";
+ // Not notified => value_ has *not* been set, and the Promise has exclusive
+ // access (no atomics or locks needed below).
+ switch (state_) {
+ case State::kNotSet:
+ state_ = State::kSet;
+ value_ = std::move(val);
+ // This has release semantics; stores to state_ and value_ will be visible
+ // to whomever sees that the notification.
+ ready_.Notify();
+ return;
+ case State::kSet:
+ FCP_CHECK(false) << "FutureState has been notified, so state_ should be "
+ "kTaken or kSet";
+ abort(); // Compiler thinks FCP_CHECK(false) might return
+ case State::kTaken:
+ FCP_CHECK(false) << "FutureState has already been taken from";
+ abort(); // Compiler thinks FCP_CHECK(false) might return
+ }
+}
+
+template <typename T>
+std::optional<T> FutureState<T>::Take() {
+ ready_.WaitForNotification();
+ // Notified => value_ has been set, and exclusive access has been transferred
+ // from the Promise to the Future (no atomics or locks needed below).
+ switch (state_) {
+ case State::kSet:
+ state_ = State::kTaken;
+ // value_.has_value() will still be set, but we won't read it again
+ // in the kTaken state.
+ return std::move(value_);
+ case State::kNotSet:
+ FCP_CHECK(false) << "FutureState has been notified, so state_ should be "
+ "kTaken or kSet";
+ abort(); // Compiler thinks FCP_CHECK(false) might return
+ case State::kTaken:
+ FCP_CHECK(false) << "FutureState has already been taken from";
+ abort(); // Compiler thinks FCP_CHECK(false) might return
+ }
+}
+
+} // namespace future_internal
+
+} // namespace thread
+} // namespace fcp
+
+#endif // FCP_BASE_FUTURE_H_
diff --git a/fcp/base/future_test.cc b/fcp/base/future_test.cc
new file mode 100644
index 0000000..e4a5fde
--- /dev/null
+++ b/fcp/base/future_test.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/future.h"
+
+#include <functional>
+#include <memory>
+#include <thread> // NOLINT(build/c++11)
+#include <type_traits>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/barrier.h"
+#include "absl/time/time.h"
+#include "fcp/base/meta.h"
+#include "fcp/base/move_to_lambda.h"
+
+namespace fcp {
+namespace thread {
+
+using ::testing::Eq;
+
+// Future::Wait and Future::Take sometimes block. We'd like to test thread
+// interleavings where these calls block before being woken up. In the absence
+// of instrumentation of the underlying synchronization primitives, we just use
+// this arbitrary delay before unblocking operations (the other thread
+// "probably" has time to block). Note that the other ordering (Promise::Set
+// before Future::Take etc.) is easy to guarantee.
+constexpr absl::Duration kArbitraryDelay = absl::Milliseconds(50);
+
+void Delay() { absl::SleepFor(kArbitraryDelay); }
+
+// Freely copyable test value. We put in a 'valid' value and hope to find it
+// again.
+enum class V { kInvalid, kValid };
+
+// Move-only test value (we use this for Promise and Future, to make sure they
+// are compatible with move-only types - typically the harder case). This
+// corresponds to V:
+// Value present (not moved) <=> kValid
+// Value moved <=> kInvalid
+// The TakeV / SetV wrappers below actually do those conversions, since the test
+// assertions (e.g. Eq matcher) are difficult to use with move-only types.
+using UV = UniqueValue<Unit>;
+
+static_assert(!std::is_copy_constructible<UV>::value,
+ "Expected to be move-only");
+
+std::optional<V> TakeV(Future<UV> future) {
+ std::optional<UV> maybe_uv = std::move(future).Take();
+ if (maybe_uv.has_value()) {
+ UniqueValue<Unit> uv = *std::move(maybe_uv);
+ return uv.has_value() ? V::kValid : V::kInvalid;
+ } else {
+ return std::nullopt;
+ }
+}
+
+void SetV(Promise<UV> promise) { std::move(promise).Set(UV(Unit{})); }
+
+absl::Barrier MakeBarrier() { return absl::Barrier(2); }
+
+void RunThreads(std::vector<std::function<void()>> fns) {
+ std::vector<std::thread> threads;
+ for (auto& fn : fns) {
+ threads.push_back(std::thread(std::move(fn)));
+ }
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+}
+
+void RunThreadsWithFuture(std::function<void(Promise<UV>)> promise_fn,
+ std::function<void(Future<UV>)> future_fn) {
+ FuturePair<UV> pair = MakeFuture<UV>();
+
+ MoveToLambdaWrapper<Promise<UV>> promise_capture =
+ MoveToLambda(std::move(pair.promise));
+ auto promise_fn_wrapped = [promise_capture, promise_fn]() mutable {
+ promise_fn(std::move(*promise_capture));
+ };
+
+ MoveToLambdaWrapper<Future<UV>> future_capture =
+ MoveToLambda(std::move(pair.future));
+ auto future_fn_wrapped = [future_capture, future_fn]() mutable {
+ future_fn(std::move(*future_capture));
+ };
+
+ RunThreads({std::move(promise_fn_wrapped), std::move(future_fn_wrapped)});
+}
+
+TEST(FutureTest, WaitTimeouts) {
+ absl::Barrier waited = MakeBarrier();
+ absl::Barrier set = MakeBarrier();
+
+ auto promise_fn = [&](Promise<UV> promise) {
+ waited.Block();
+ SetV(std::move(promise));
+ set.Block();
+ };
+
+ auto future_fn = [&](Future<UV> future) {
+ // Before set: Timeout should elapse
+ EXPECT_FALSE(future.Wait(absl::Milliseconds(1)))
+ << "Future shouldn't be ready yet";
+ waited.Block();
+ set.Block();
+ // After set: Zero timeout should be sufficient
+ EXPECT_TRUE(future.Wait(absl::ZeroDuration()))
+ << "Future should be ready without waiting";
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+TEST(FutureTest, TakeAfterSet) {
+ absl::Barrier set = MakeBarrier();
+
+ auto promise_fn = [&](Promise<UV> promise) {
+ SetV(std::move(promise));
+ set.Block();
+ };
+
+ auto future_fn = [&](Future<UV> future) {
+ set.Block();
+ EXPECT_THAT(TakeV(std::move(future)), Eq(V::kValid));
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+TEST(FutureTest, TakeProbablyBeforeSet) {
+ auto promise_fn = [](Promise<UV> promise) {
+ Delay();
+ SetV(std::move(promise));
+ };
+
+ auto future_fn = [](Future<UV> future) {
+ EXPECT_THAT(TakeV(std::move(future)), Eq(V::kValid));
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+TEST(FutureTest, AbandonWhileProbablyTaking) {
+ auto promise_fn = [](Promise<UV> promise) {
+ Delay();
+ { Promise<UV> dies = std::move(promise); }
+ };
+
+ auto future_fn = [](Future<UV> future) {
+ EXPECT_THAT(std::move(future).Take(), Eq(std::nullopt));
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+TEST(FutureTest, SetWhileProbablyWaiting) {
+ auto promise_fn = [](Promise<UV> promise) {
+ Delay();
+ SetV(std::move(promise));
+ };
+
+ auto future_fn = [](Future<UV> future) {
+ EXPECT_TRUE(future.Wait(absl::InfiniteDuration()));
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+TEST(FutureTest, AbandonWhileProbablyWaiting) {
+ auto promise_fn = [](Promise<UV> promise) {
+ Delay();
+ { Promise<UV> dies = std::move(promise); }
+ };
+
+ auto future_fn = [](Future<UV> future) {
+ EXPECT_TRUE(future.Wait(absl::InfiniteDuration()));
+ };
+
+ RunThreadsWithFuture(std::move(promise_fn), std::move(future_fn));
+}
+
+} // namespace thread
+} // namespace fcp
diff --git a/fcp/base/golden_file.bzl b/fcp/base/golden_file.bzl
new file mode 100644
index 0000000..13edd8b
--- /dev/null
+++ b/fcp/base/golden_file.bzl
@@ -0,0 +1,65 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides the golden_file_test() rule."""
+
+GOLDEN_FILE_CHECKER_TEMPLATE = """
+#!/bin/bash
+set -o nounset
+set -o errexit
+
+readonly EXPECTED={EXPECTED}
+readonly ACTUAL={ACTUAL}
+
+diff -u --label EXPECTED "$EXPECTED" --label ACTUAL "$ACTUAL"
+readonly r=$?
+
+if (( $r != 0 )); then
+ echo "***" >&2
+ echo "*** FAIL: Contents of $ACTUAL do not match $EXPECTED" >&2
+ echo "***" >&2
+fi
+
+exit $r
+"""
+
+def _golden_file_test_impl(ctx):
+ replacements = {
+ "{EXPECTED}": repr(ctx.file.expected.short_path),
+ "{ACTUAL}": repr(ctx.file.actual.short_path),
+ }
+
+ contents = GOLDEN_FILE_CHECKER_TEMPLATE
+ for k, v in replacements.items():
+ contents = contents.replace(k, v)
+
+ ctx.actions.write(
+ ctx.outputs.sh,
+ contents,
+ is_executable = True,
+ )
+
+ runfiles = ctx.runfiles(files = [ctx.file.expected, ctx.file.actual])
+ return [DefaultInfo(executable = ctx.outputs.sh, runfiles = runfiles)]
+
+golden_file_test = rule(
+ implementation = _golden_file_test_impl,
+ outputs = {"sh": "%{name}.sh"},
+ test = True,
+ attrs = {
+ "actual": attr.label(allow_single_file = True),
+ "expected": attr.label(allow_single_file = True),
+ },
+)
+"""Checks that two files are equal; fails with a text diff otherwise."""
diff --git a/fcp/base/match.h b/fcp/base/match.h
new file mode 100644
index 0000000..253b59b
--- /dev/null
+++ b/fcp/base/match.h
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// 'Match' expressions for {std, absl}::variant.
+//
+// {std, absl}::variant is an algebraic sum type. However, the standard library
+// does not provide a convenient way to destructure or match on them - unlike in
+// Haskell, Rust, etc.
+//
+// This file provides a way to match on :variant in a way akin to a switch
+// statement.
+//
+// Example:
+//
+// using V = std::variant<X, Y, Z>;
+// V v = ...;
+// ...
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// [](Y const& y) { return 2; },
+// [](Z const& z) { return 3; });
+//
+// It is a compile-time error if the match is not exhaustive. A 'Default' case
+// can be provided:
+//
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// // Called with the otherwise-unhandled alternative (see decltype(alt)).
+// [](Default, auto const& alt) { ...; });
+//
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// // Called with the variant itself.
+// [](Default, V const& v) { ...; });
+//
+// If constructing the matcher lambdas is non-trivial, it might be worthwhile to
+// create a re-usable matcher object. See 'MakeMatcher'.
+
+#ifndef FCP_BASE_MATCH_H_
+#define FCP_BASE_MATCH_H_
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+#include "fcp/base/meta.h"
+
+namespace fcp {
+
+// Marker type for default match cases.
+struct Default {};
+
+namespace match_internal {
+
+template <typename... CaseFns>
+struct MatchCasesCallable : public CaseFns... {
+ // Each CaseFn provides operator(). We want to pick one by overload
+ // resolution.
+ using CaseFns::operator()...;
+};
+
+template <typename ToType, typename... CaseFns>
+class MatchCases {
+ public:
+ explicit constexpr MatchCases(MatchCasesCallable<CaseFns...> c)
+ : callable_(std::move(c)) {}
+
+ // False by default
+ template <typename Enable, typename... T>
+ struct IsCaseHandledImpl : public std::false_type {};
+
+ // True when m.MatchCases(args...) is well-formed, for a
+ // MatchCases<CaseFns...> m and T arg.
+ template <typename... T>
+ struct IsCaseHandledImpl<
+ std::void_t<decltype(std::declval<MatchCasesCallable<CaseFns...>>()(
+ std::declval<T>()...))>,
+ T...> : public std::true_type {};
+
+ template <typename... T>
+ static constexpr bool IsCaseHandled() {
+ return IsCaseHandledImpl<void, T...>::value;
+ }
+
+ template <typename ToType_ = ToType, typename... Args>
+ constexpr auto operator()(Args&&... args) const {
+ if constexpr (std::is_void_v<ToType_>) {
+ return callable_(std::forward<Args>(args)...);
+ } else {
+ return ToType_(callable_(std::forward<Args>(args)...));
+ }
+ }
+
+ private:
+ MatchCasesCallable<CaseFns...> callable_;
+};
+
+template <typename ToType, typename... CaseFns>
+constexpr MatchCases<ToType, CaseFns...> MakeMatchCases(CaseFns... case_fns) {
+ return MatchCases<ToType, CaseFns...>(
+ MatchCasesCallable<CaseFns...>{case_fns...});
+}
+
+template <typename CasesType, typename VariantType, typename ArgType>
+constexpr auto ApplyCase(CasesType const& cases, VariantType&& v,
+ ArgType&& arg) {
+ if constexpr (CasesType::template IsCaseHandled<ArgType>()) {
+ return cases(std::forward<ArgType>(arg));
+ } else if constexpr (CasesType::template IsCaseHandled<Default, ArgType>()) {
+ return cases(Default{}, std::forward<ArgType>(arg));
+ } else if constexpr (CasesType::template IsCaseHandled<Default,
+ VariantType>()) {
+ return cases(Default{}, std::forward<VariantType>(v));
+ } else if constexpr (CasesType::template IsCaseHandled<Default>()) {
+ return cases(Default{});
+ } else {
+ static_assert(
+ FailIfReached<ArgType>(),
+ "Provide a case for all variant alternatives, or a 'Default' case");
+ }
+}
+
+template <typename Traits, typename CasesType>
+class VariantMatcherImpl {
+ public:
+ using ValueType = typename Traits::ValueType;
+
+ explicit constexpr VariantMatcherImpl(CasesType cases)
+ : cases_(std::move(cases)) {}
+
+ constexpr auto Match(ValueType* v) const { return MatchInternal(v); }
+
+ constexpr auto Match(ValueType const& v) const { return MatchInternal(v); }
+
+ constexpr auto Match(ValueType&& v) const {
+ return MatchInternal(std::move(v));
+ }
+
+ private:
+ template <typename FromType>
+ constexpr auto MatchInternal(FromType&& v) const {
+ return Traits::Visit(std::forward<FromType>(v), [this, &v](auto&& alt) {
+ return ApplyCase(cases_, std::forward<FromType>(v),
+ std::forward<decltype(alt)>(alt));
+ });
+ }
+
+ CasesType cases_;
+};
+
+template <typename T, typename Enable = void>
+struct MatchTraits {
+ static_assert(FailIfReached<T>(),
+ "Only variant-like (e.g. std::variant<...> types can be "
+ "matched. See MatchTraits.");
+};
+
+template <typename... AltTypes>
+struct MatchTraits<std::variant<AltTypes...>> {
+ using ValueType = std::variant<AltTypes...>;
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), v);
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), std::move(v));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
+ return absl::visit([fn = std::forward<VisitFn>(fn)](
+ auto& alt) mutable { return fn(&alt); },
+ *v);
+ }
+};
+
+template <typename T>
+struct MatchTraits<std::optional<T>> {
+ using ValueType = std::optional<T>;
+
+ static constexpr auto Wrap(std::optional<T>* o)
+ -> std::variant<T*, std::nullopt_t> {
+ if (o->has_value()) {
+ return &**o;
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ static constexpr auto Wrap(std::optional<T> const& o)
+ -> std::variant<std::reference_wrapper<T const>, std::nullopt_t> {
+ if (o.has_value()) {
+ return std::ref(*o);
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ static constexpr auto Wrap(std::optional<T>&& o)
+ -> std::variant<T, std::nullopt_t> {
+ if (o.has_value()) {
+ return *std::move(o);
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ template <typename V, typename VisitFn>
+ static constexpr auto Visit(V&& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), Wrap(std::forward<V>(v)));
+ }
+};
+
+template <typename T>
+struct MatchTraits<T, std::void_t<typename T::VariantType>> {
+ using ValueType = T;
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ v.variant(), std::forward<VisitFn>(fn));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ std::move(v).variant(), std::forward<VisitFn>(fn));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ &v->variant(), std::forward<VisitFn>(fn));
+ }
+};
+
+template <typename VariantType, typename CasesType>
+constexpr auto CreateMatcherImpl(CasesType cases) {
+ return VariantMatcherImpl<MatchTraits<VariantType>, CasesType>(
+ std::move(cases));
+}
+
+} // namespace match_internal
+
+// See file remarks.
+template <typename From, typename To = void, typename... CaseFnTypes>
+constexpr auto MakeMatcher(CaseFnTypes... fns) {
+ return match_internal::CreateMatcherImpl<From>(
+ match_internal::MakeMatchCases<To>(fns...));
+}
+
+// See file remarks.
+//
+// Note that the order of template arguments differs from MakeMatcher; it is
+// expected that 'From' is always deduced (but it can be useful to specify 'To'
+// explicitly).
+template <typename To = void, typename From, typename... CaseFnTypes>
+constexpr auto Match(From&& v, CaseFnTypes... fns) {
+ // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
+ // const&).
+ auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
+ // The full type is still relevant for forwarding.
+ return m.Match(std::forward<From>(v));
+}
+
+template <typename To = void, typename From, typename... CaseFnTypes>
+constexpr auto Match(From* v, CaseFnTypes... fns) {
+ // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
+ // const*).
+ auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
+ return m.Match(v);
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_MATCH_H_
diff --git a/fcp/base/match_test.cc b/fcp/base/match_test.cc
new file mode 100644
index 0000000..156fc17
--- /dev/null
+++ b/fcp/base/match_test.cc
@@ -0,0 +1,260 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/match.h"
+
+#include <memory>
+#include <optional>
+#include <variant>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/result.h"
+#include "fcp/testing/result_matchers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+
+namespace {
+
+using ::testing::Eq;
+using ::testing::Optional;
+
+struct X {
+ int x;
+};
+
+struct Y {
+ int y;
+};
+
+struct Z {
+ int z;
+};
+
+using V = std::variant<X, Y, Z>;
+
+using VMoveOnly =
+ std::variant<std::unique_ptr<X>, std::unique_ptr<Y>, std::unique_ptr<Z>>;
+
+TEST(MatchTest, AllDefault) {
+ constexpr auto matcher =
+ MakeMatcher<V>([](Default, V const&) { return 1; });
+
+ static_assert(matcher.Match(X{}) == 1);
+ static_assert(matcher.Match(Z{}) == 1);
+ static_assert(matcher.Match(Y{}) == 1);
+}
+
+TEST(MatchTest, SingleDefault) {
+ constexpr auto matcher = MakeMatcher<V>(
+ [](X const& x) { return 10 + x.x; }, //
+ [](Z const& z) { return 20 + z.z; },
+ [](Default, V const& v) { return 30 + absl::get<Y>(v).y; });
+ static_assert(matcher.Match(X{1}) == 11);
+ static_assert(matcher.Match(Z{2}) == 22);
+ static_assert(matcher.Match(Y{3}) == 33);
+}
+
+TEST(MatchTest, SingleDefault_Pointer) {
+ constexpr auto matcher =
+ MakeMatcher<V>([](X* x) { return 10 + x->x; }, //
+ [](Z* z) { return 20 + z->z; },
+ [](Default, V* v) { return 30 + absl::get<Y>(*v).y; });
+
+ V x = X{1};
+ V z = Z{2};
+ V y = Y{3};
+
+ EXPECT_THAT(matcher.Match(&x), Eq(11));
+ EXPECT_THAT(matcher.Match(&z), Eq(22));
+ EXPECT_THAT(matcher.Match(&y), Eq(33));
+}
+
+TEST(MatchTest, Exhaustive) {
+ constexpr auto matcher = MakeMatcher<V>(
+ [](X const& x) { return 10 + x.x; }, [](Z const& z) { return 20 + z.z; },
+ [](Y const& y) { return 30 + y.y; });
+ static_assert(matcher.Match(X{1}) == 11);
+ static_assert(matcher.Match(Z{2}) == 22);
+ static_assert(matcher.Match(Y{3}) == 33);
+}
+
+TEST(MatchTest, Exhaustive_Pointer) {
+ constexpr auto matcher = MakeMatcher<V>([](X* x) { return 10 + x->x; },
+ [](Z* z) { return 20 + z->z; },
+ [](Y* y) { return 30 + y->y; });
+
+ V x = X{1};
+ V z = Z{2};
+ V y = Y{3};
+
+ EXPECT_THAT(matcher.Match(&x), Eq(11));
+ EXPECT_THAT(matcher.Match(&z), Eq(22));
+ EXPECT_THAT(matcher.Match(&y), Eq(33));
+}
+
+TEST(MatchTest, Exhaustive_MatchInsteadOfMatcher) {
+ constexpr auto do_match = [](V const& v) {
+ return Match(
+ v, //
+ [](X const& x) { return 10 + x.x; },
+ [](Z const& z) { return 20 + z.z; },
+ [](Y const& y) { return 30 + y.y; });
+ };
+
+ static_assert(do_match(X{1}) == 11);
+ static_assert(do_match(Z{2}) == 22);
+ static_assert(do_match(Y{3}) == 33);
+}
+
+TEST(MatchTest, CoerceViaExplicitReturnType) {
+ constexpr auto do_match = [](V const& v) {
+ return Match<std::optional<int>>(
+ v, //
+ [](X const& x) { return 10 + x.x; },
+ [](Z const& z) { return 20 + z.z; },
+ [](Y const& y) { return std::nullopt; });
+ };
+
+ static_assert(*do_match(X{1}) == 11);
+ static_assert(*do_match(Z{2}) == 22);
+ static_assert(!do_match(Y{3}).has_value());
+}
+
+TEST(MatchTest, MoveOnly_Borrow_Exhaustive) {
+ constexpr auto matcher = MakeMatcher<VMoveOnly>(
+ [](std::unique_ptr<X> const& x) { return 10 + x->x; },
+ [](std::unique_ptr<Z> const& z) { return 20 + z->z; },
+ [](std::unique_ptr<Y> const& y) { return 30 + y->y; });
+
+ VMoveOnly v_x = std::make_unique<X>(X{1});
+ VMoveOnly v_z = std::make_unique<Z>(Z{2});
+ VMoveOnly v_y = std::make_unique<Y>(Y{3});
+
+ EXPECT_THAT(matcher.Match(v_x), Eq(11));
+ EXPECT_THAT(matcher.Match(v_z), Eq(22));
+ EXPECT_THAT(matcher.Match(v_y), Eq(33));
+}
+
+TEST(MatchTest, MoveOnly_Consume_Exhaustive) {
+ constexpr auto matcher = MakeMatcher<VMoveOnly>(
+ [](std::unique_ptr<X> x) { return 10 + x->x; },
+ [](std::unique_ptr<Z> z) { return 20 + z->z; },
+ [](std::unique_ptr<Y> y) { return 30 + y->y; });
+
+ VMoveOnly v_x = std::make_unique<X>(X{1});
+ VMoveOnly v_z = std::make_unique<Z>(Z{2});
+ VMoveOnly v_y = std::make_unique<Y>(Y{3});
+
+ EXPECT_THAT(matcher.Match(std::move(v_x)), Eq(11));
+ EXPECT_THAT(matcher.Match(std::move(v_z)), Eq(22));
+ EXPECT_THAT(matcher.Match(std::move(v_y)), Eq(33));
+}
+
+// std::optional is handled with a special MatchTraits implementation.
+// The corresponding std::variant has to be synthesized on the fly, so that
+// implementation is trickier than usual.
+
+TEST(MatchTest, Optional_Ref) {
+ using O = std::optional<std::unique_ptr<X>>;
+ constexpr auto matcher =
+ MakeMatcher<O>([](std::unique_ptr<X> const& x) { return x->x; },
+ [](std::nullopt_t) { return 0; });
+
+ O const engaged = std::make_unique<X>(X{123});
+ O const empty = std::nullopt;
+
+ EXPECT_THAT(matcher.Match(engaged), Eq(123));
+ EXPECT_THAT(matcher.Match(empty), Eq(0));
+}
+
+TEST(MatchTest, Optional_Pointer) {
+ using O = std::optional<std::unique_ptr<X>>;
+ constexpr auto matcher = MakeMatcher<O>(
+ [](std::unique_ptr<X>* x) {
+ int v = (*x)->x;
+ x->reset();
+ return v;
+ },
+ [](std::nullopt_t) { return 0; });
+
+ O engaged = std::make_unique<X>(X{123});
+ O empty = std::nullopt;
+
+ EXPECT_THAT(matcher.Match(&engaged), Eq(123));
+ EXPECT_THAT(engaged, Optional(Eq(nullptr)));
+ EXPECT_THAT(matcher.Match(&empty), Eq(0));
+}
+
+TEST(MatchTest, Optional_Consume) {
+ using O = std::optional<std::unique_ptr<X>>;
+ constexpr auto matcher =
+ MakeMatcher<O>([](std::unique_ptr<X> x) { return x->x; },
+ [](std::nullopt_t) { return 0; });
+
+ EXPECT_THAT(matcher.Match(O{std::make_unique<X>(X{123})}), Eq(123));
+ EXPECT_THAT(matcher.Match(O{std::nullopt}), Eq(0));
+}
+
+// Result<T> uses the extensibility mechanism provided by MatchTrait
+// (VariantType alias and a variant() accessor). These tests demonstrate that
+// MatchTraits is extensible (in addition to testing the particular
+// implementation for Result).
+
+TEST(MatchTest, Result_Ref) {
+ using R = Result<std::unique_ptr<X>>;
+ constexpr auto matcher =
+ MakeMatcher<R>([](std::unique_ptr<X> const& x) { return x->x; },
+ [](Error) { return 0; });
+
+ R const val = std::make_unique<X>(X{123});
+ R const err = TraceTestError();
+
+ EXPECT_THAT(matcher.Match(val), Eq(123));
+ EXPECT_THAT(matcher.Match(err), Eq(0));
+}
+
+TEST(MatchTest, Result_Pointer) {
+ using R = Result<std::unique_ptr<X>>;
+ constexpr auto matcher = MakeMatcher<R>(
+ [](std::unique_ptr<X>* x) {
+ int v = (*x)->x;
+ x->reset();
+ return v;
+ },
+ [](Error*) { return 0; });
+
+ R val = std::make_unique<X>(X{123});
+ R err = TraceTestError();
+
+ EXPECT_THAT(matcher.Match(&val), Eq(123));
+ EXPECT_THAT(val, HasValue(Eq(nullptr)));
+ EXPECT_THAT(matcher.Match(&err), Eq(0));
+}
+
+TEST(MatchTest, Result_Consume) {
+ using R = Result<std::unique_ptr<X>>;
+ constexpr auto matcher = MakeMatcher<R>(
+ [](std::unique_ptr<X> x) { return x->x; }, [](Error) { return 0; });
+
+ EXPECT_THAT(matcher.Match(R(std::make_unique<X>(X{123}))), Eq(123));
+ EXPECT_THAT(matcher.Match(R(TraceTestError())), Eq(0));
+}
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/base/meta.h b/fcp/base/meta.h
new file mode 100644
index 0000000..91d53cc
--- /dev/null
+++ b/fcp/base/meta.h
@@ -0,0 +1,444 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This file provides general utilities for metaprogramming.
+ *
+ * - LIFT_MEMBER_TO_TYPE: Generates distinct types in correspondence with
+ * member-pointers (for both fields and functions). For example,
+ * LIFT_MEMBER_TO_TYPE(S, X) != LIFT_MEMBER_TO_TYPE(R, X), even if the
+ * declarations of S::X and R::X are identical.
+ *
+ * - Unit: An empty struct (i.e. has a single canonical element). It is useful
+ * in contexts where a non-void return type is necessary but undesired: for
+ * example, in a constexpr function called only for static_asserts.
+ *
+ * - Pack<T...>: Helps passing around the 'parameter packs' arising from
+ * variadic templates (they are not first-class). In particular, these allow
+ * writing a function such that F<A, B>() and F<Pack<A, B>>() are equivalent
+ * (Pack<A, B> _is_ first-class).
+ *
+ * - MemberPointerTraits: Allows removing the 'container' part from
+ * member-pointer types, e.g.
+ * 'R T::*' => 'R'
+ * 'R (T::*)(A, B)' => 'R(A, B)'
+ *
+ * - FunctionTraits: Allows destructuring function types, e.g. 'bool(int,
+ * int)' into ResultType = bool, ArgPackType = Pack<int, int>.
+ *
+ * - FailIfReached: Allows writing static_asserts for templates that should
+ * never be instantiated. This is a workaround for the fact that
+ * 'static_assert(false, "")' can trigger regardless of where it's located.
+ *
+ * - Identity<T>: An alias useful with higher order templates.
+ *
+ * - CastContainerElements: Allows 'casting' homogenous containers to
+ * heterogenous tuples, e.g. vector<X> -> tuple<A, B> - useful when
+ * when the type-list was erased earlier.
+ *
+ * - LiftVoidReturn: Wraps a callable object, so that returned 'void' becomes
+ * 'Unit' (if applicable). This avoids spread of special cases when handling
+ * callables and function-types generically (e.g. 'auto r = f()' is valid
+ * for f() returning anything _except_ void).
+ *
+ * - MAKE_LINK and LinkedType<T>: Given types T, U and MAKE_LINK(T, U),
+ * LinkedType<T> == U. This can often be handled with template
+ * specialization, but (like AbslHashValue) we use ADL so that T (and
+ * MAKE_LINK next to it) can appear in any namespace.
+ *
+ * - IsTypeOneOf<T, Us...>: A function to determines if the type T is in the
+ * list of types Us.
+ *
+ * - IsSubsetOf<Pack<Ts...>, Pack<Us...>>: Determins if a pack of types Ts
+ * is a subset of a pack of types Us.
+ */
+
+#ifndef FCP_BASE_META_H_
+#define FCP_BASE_META_H_
+
+#include <tuple>
+#include <type_traits>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+/**
+ * An empty struct - i.e. there is a single canonical element.
+ *
+ * It is useful in contexts where a non-void return type is necessary but
+ * undesired: for example, in a constexpr function called only for
+ * static_asserts.
+ *
+ * Unit defines equality (they're always equal). True() always returns true,
+ * which is convenient for allowing a unit-returning function call in a
+ * static_assert.
+ *
+ * Unit::Ignore(...) sinks any arguments to a Unit. This is useful in C++11's
+ * restricted constexpr as well as for parameter-pack expansions.
+ */
+struct Unit {
+ constexpr bool operator==(Unit other) const { return true; }
+ constexpr bool operator!=(Unit other) const { return !(*this == other); }
+ constexpr bool True() const { return true; }
+
+ /** Ignores all arguments (of any type), returning Unit */
+ template <typename... ArgTypes>
+ static constexpr Unit Ignore(ArgTypes... args) {
+ return {};
+ }
+};
+
+/**
+ * Pack<T...> facilitates passing around a parameter-pack T...
+ *
+ * Types are more or less first-class, in that you can place one somewhere (e.g.
+ * as a struct member) and use it later. This is not the case for
+ * parameter-packs: one can only expand T... within some template<typename...
+ * T>.
+ *
+ * Pack<> is a work-around for that:
+ *
+ * - To store a parameter-pack T... in hand: Instead store Pack<T...>, e.g.
+ * 'using P = Pack<T...>'
+ *
+ * - To revitalize the parameter pack later: Define a target function like
+ * template<typename... T> F(Pack<T...>)
+ * and call it as
+ * F(P{})
+ * (noting P from the prior example). The T... in scope of F arises from
+ * template argument deduction.
+ */
+template <typename... T>
+struct Pack {
+ /** Returns the related index-sequence type.
+ *
+ * Example:
+ *
+ * template <typename... T, size_t... Idx>
+ * void Impl(Pack<T...>, absl::index_sequence<Idx...>) {
+ * auto zipped[] = {
+ * F<T>(Idx)... // T... and Idx... are zipped together.
+ * };
+ * }
+ *
+ * template <typename... T>
+ * void Foo(Pack<T...> pack) {
+ * Impl(pack, pack.MakeIndexSequence());
+ * }
+ */
+ static constexpr absl::index_sequence_for<T...> MakeIndexSequence() {
+ return {};
+ }
+};
+
+/**
+ * Workaround for static_assert(false) tripping even for un-instantiated
+ * templates.
+ */
+template <typename T>
+constexpr bool FailIfReached() {
+ return !std::is_same<T, T>::value;
+}
+
+namespace meta_internal {
+
+template <typename T, T M>
+struct MemberTag {
+ static_assert(std::is_member_pointer<T>::value,
+ "Expected a member-pointer type");
+};
+
+template <typename CastOp, typename... T>
+using CastResultType = std::tuple<typename CastOp::template TargetType<T>...>;
+
+template <typename... T, size_t... Idx, typename Container, typename CastOp>
+CastResultType<CastOp, T...> CastContainerElementsImpl(
+ Container const& container, CastOp const& cast, Pack<T...>,
+ absl::index_sequence<Idx...>) {
+ FCP_CHECK(sizeof...(T) == container.size());
+ return CastResultType<CastOp, T...>{cast.template Cast<T>(container[Idx])...};
+}
+
+template <typename F>
+class VoidLifter {
+ private:
+ template <typename T>
+ struct Tag {};
+
+ template <typename... A>
+ Unit DoCall(Tag<void>, A&&... args) {
+ f_(std::forward<A>(args)...);
+ return {};
+ }
+
+ template <typename R, typename... A>
+ R DoCall(Tag<R>, A&&... args) {
+ return f_(std::forward<A>(args)...);
+ }
+
+ public:
+ explicit VoidLifter(F f) : f_(std::move(f)) {}
+
+ template <typename... A>
+ auto operator()(A&&... args) -> decltype(
+ DoCall(Tag<decltype(std::declval<F>()(std::forward<A>(args)...))>{},
+ std::forward<A>(args)...)) {
+ return DoCall(Tag<decltype(std::declval<F>()(std::forward<A>(args)...))>{},
+ std::forward<A>(args)...);
+ }
+
+ private:
+ F f_;
+};
+
+template <typename U, typename Dummy = void>
+struct FailIfLinkMissing {
+ using Type = U;
+};
+
+template <typename Dummy>
+struct FailIfLinkMissing<void, Dummy> {
+ static_assert(FailIfReached<Dummy>(),
+ "Expected a type linked from T, via MAKE_LINK(T, U). Note that "
+ "MAKE_LINK must appear in the same namespace as T.");
+};
+
+template <typename T>
+struct LinkedTypeToken {
+ using Type = T;
+};
+
+/**
+ * Default case for LookupTypeLink. MAKE_LINK creates overloads which are more
+ * specific (argument type matches without needing a template).
+ */
+template <typename T>
+inline LinkedTypeToken<void> TypeLink_(LinkedTypeToken<T>) {
+ return {};
+}
+
+/**
+ * Resolves MAKE_LINK at the level of values (i.e. the link target is
+ * represented in the return type). May be called qualified, i.e.
+ * fcp::meta_internal::LookupTypeLink.
+ *
+ * This depends on ADL. TypeLink_ is an unqualified name, so those next to T are
+ * overload candidates. As such, it's fine to call meta_internal::LookupTypeLink
+ * but *not* meta_internal::TypeLink_ (hence this indirection).
+ */
+template <typename T>
+constexpr auto LookupTypeLink(LinkedTypeToken<T> t) -> decltype(TypeLink_(t)) {
+ return {};
+}
+
+template <template <typename> class M, typename Z>
+struct UnwrapTemplateImpl {
+ static constexpr bool kValid = false;
+
+ struct Type {
+ static_assert(FailIfReached<Z>(), "Z must be M<T> for some type T");
+ };
+};
+
+template <template <typename> class M, typename T>
+struct UnwrapTemplateImpl<M, M<T>> {
+ static constexpr bool kValid = true;
+ using Type = T;
+};
+
+template <template <typename> class M, typename Z>
+using UnwrapTemplate = meta_internal::UnwrapTemplateImpl<M, std::decay_t<Z>>;
+
+} // namespace meta_internal
+
+/**
+ * Generates distinct types in correspondence with member-pointers (for both
+ * fields and functions).
+ *
+ * For example, LIFT_MEMBER_TO_TYPE(S, X) != LIFT_MEMBER_TO_TYPE(R, X), even if
+ * the declarations of S::X and R::X are identical.
+ *
+ * The lifted type is always an empty struct, so it can be instantiated with {}
+ * (for use in overload resolution) at no cost.
+ */
+#define LIFT_MEMBER_TO_TYPE(type, member) \
+ LIFT_MEMBER_POINTER_TO_TYPE(&type::member)
+
+/**
+ * Same as LIFT_MEMBER_TO_TYPE, but invoked as e.g.
+ * LIFT_MEMBER_POINTER_TO_TYPE(&S::X)
+ */
+#define LIFT_MEMBER_POINTER_TO_TYPE(ptr) \
+ ::fcp::meta_internal::MemberTag<decltype(ptr), ptr>
+
+/**
+ * Allows removing the 'container' part from member-pointer types, e.g.
+ * 'R T::*' => 'R' 'R (T::*)(A, B)' => 'R(A, B)'
+ */
+template <typename T>
+struct MemberPointerTraits {
+ static_assert(
+ FailIfReached<T>(),
+ "Expected a member pointer (both fields and functions are accepted)");
+};
+
+template <typename T, typename R>
+struct MemberPointerTraits<R T::*> {
+ using TargetType = R;
+};
+
+template <typename T>
+struct FunctionTraits {
+ static_assert(FailIfReached<T>(), "Expected a function type");
+};
+
+template <typename R, typename... A>
+struct FunctionTraits<R(A...)> {
+ using ResultType = R;
+ using ArgPackType = Pack<A...>;
+};
+
+/** Type-level identity function; useful for higher order templates */
+template <typename T>
+using Identity = T;
+
+/** See other overload; this one takes a Pack<T...> instead of explicit T... */
+template <typename... T, typename Container, typename CastOp>
+auto CastContainerElements(Pack<T...> pack, Container const& container,
+ CastOp const& cast)
+ -> decltype(meta_internal::CastContainerElementsImpl(
+ container, cast, pack, pack.MakeIndexSequence())) {
+ return meta_internal::CastContainerElementsImpl(container, cast, pack,
+ pack.MakeIndexSequence());
+}
+
+/**
+ * Allows 'casting' homogenous containers to heterogenous tuples, e.g.
+ * vector<X> -> tuple<A, B> - useful when when the type-list was erased
+ * earlier.
+ *
+ * 'CastOp' determines how to cast each element. It should be a type like the
+ * following:
+ *
+ * struct FooCast {
+ * template<typename T>
+ * using TargetType = Y<T>;
+ *
+ * template <typename T>
+ * TargetType<T> Cast(X const& val) const {
+ * ...
+ * }
+ * };
+ *
+ * Supposing vector<X> vx, CastContainerElements<A, B>(vx, FooCast{}) would
+ * yield a tuple<Y<A>, Y<B>> with values {Cast<A>(vx[0]), Cast<B>(vx[1])}.
+ *
+ * This function supports the 'Pack' wrapper. For example, the previous example
+ * could also be written as CastContainerElements(Pack<X, Y>{}, vx, FooCast{}).
+ */
+template <typename... T, typename Container, typename CastOp>
+auto CastContainerElements(Container const& container, CastOp const& cast)
+ -> decltype(CastContainerElements(Pack<T...>{}, container, cast)) {
+ return CastContainerElements(Pack<T...>{}, container, cast);
+}
+
+/**
+ * Wraps a callable object, so that returned 'void' becomes 'Unit' (if
+ * applicable). This avoids spread of special cases when handling callables and
+ * function-types generically (e.g. 'auto r = f()' is valid for f() returning
+ * anything _except_ void).
+ */
+template <typename F>
+meta_internal::VoidLifter<F> LiftVoidReturn(F f) {
+ return meta_internal::VoidLifter<F>(std::move(f));
+}
+
+/** See LinkedType<T> */
+#define MAKE_LINK(a, b) \
+ inline ::fcp::meta_internal::LinkedTypeToken<b> TypeLink_( \
+ ::fcp::meta_internal::LinkedTypeToken<a>) { \
+ return {}; \
+ }
+
+/**
+ * See LinkedType<T>. This form returns void instead of failing when a link is
+ * missing
+ */
+template <typename T>
+using LinkedTypeOrVoid = typename decltype(meta_internal::LookupTypeLink(
+ std::declval<meta_internal::LinkedTypeToken<T>>()))::Type;
+
+/**
+ * Indicates if some MAKE_LINK(T, ...) is visible.
+ */
+template <typename T>
+constexpr bool HasLinkedType() {
+ return !std::is_same<LinkedTypeOrVoid<T>, void>::value;
+}
+
+/**
+ * Given types T, U and MAKE_LINK(T, U), LinkedType<T> == U.
+ *
+ * This can often be handled with template specialization, but (like
+ * AbslHashValue) we use ADL to avoid restrictions on the namespaces in which
+ * specializations can appear.
+ *
+ * The type T can appear in any namespace, but MAKE_LINK(T, U) must appear in
+ * the same namespace (ideally, place it right after the declaration of T).
+ * LinkedType<T> then works in any namespace.
+ *
+ * It is an error to use this alias for a T without a MAKE_LINK. See
+ * HasLinkedType() and LinkedTypeOrVoid.
+ */
+template <typename T>
+using LinkedType =
+ typename meta_internal::FailIfLinkMissing<LinkedTypeOrVoid<T>>::Type;
+
+
+/*
+ * Given type T and typelist Us... determines if T is one of the types in Us.
+ */
+template <typename T, typename... Us>
+struct IsTypeOneOfT : std::disjunction<std::is_same<T, Us>...> {};
+
+template <typename T, typename... Us>
+constexpr bool IsTypeOneOf() {
+ return IsTypeOneOfT<T, Us...>::value;
+}
+
+/*
+ * Given two typelists Ts... and Us... determines if Ts is a subset of Us.
+ */
+template <typename Ts, typename Us>
+struct IsSubsetOf : std::false_type {};
+
+template <typename... Ts, typename... Us>
+struct IsSubsetOf<Pack<Ts...>, Pack<Us...>>
+ : std::conjunction<IsTypeOneOfT<Ts, Us...>...> {};
+
+template <template <typename> class M, typename Z>
+using UnapplyTemplate =
+ typename meta_internal::UnwrapTemplateImpl<M, std::decay_t<Z>>::Type;
+
+template <template <typename> class M, typename Z>
+constexpr bool IsAppliedTemplate() {
+ return meta_internal::UnwrapTemplateImpl<M, std::decay_t<Z>>::kValid;
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_META_H_
diff --git a/fcp/base/meta_test.cc b/fcp/base/meta_test.cc
new file mode 100644
index 0000000..911cc8e
--- /dev/null
+++ b/fcp/base/meta_test.cc
@@ -0,0 +1,397 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/meta.h"
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+using ::testing::Not;
+
+struct R {
+ virtual ~R() = default;
+ virtual bool Virt1(int i) = 0;
+ virtual bool Virt2(int i, int j) = 0;
+ virtual void Virt3() = 0;
+ int NonVirt1() { return 1; }
+ int NonVirt2() { return 2; }
+ char field;
+};
+
+struct S {
+ virtual ~S() = default;
+ virtual bool Virt1(int i) = 0;
+ virtual bool Virt2(int i, int j) = 0;
+ virtual void Virt3() = 0;
+ int NonVirt1() { return 1; }
+ int NonVirt2() { return 2; }
+ char field;
+};
+
+//
+// Compile-time tests for MemberPointerTraits
+//
+#define STATIC_ASSERT_TARGET_TYPE(member, type) \
+ static_assert( \
+ std::is_same<MemberPointerTraits<decltype(&S::member)>::TargetType, \
+ type>::value, \
+ "Incorrect target type from MemberPointerTraits");
+
+// For some reason the linter otherwise thinks e.g. 'bool(int)' is an old-style
+// cast.
+template<typename T>
+struct Id { using Type = T; };
+
+STATIC_ASSERT_TARGET_TYPE(Virt1, Id<bool(int)>::Type);
+STATIC_ASSERT_TARGET_TYPE(Virt2, Id<bool(int, int)>::Type);
+STATIC_ASSERT_TARGET_TYPE(Virt3, Id<void()>::Type);
+STATIC_ASSERT_TARGET_TYPE(NonVirt1, Id<int()>::Type);
+STATIC_ASSERT_TARGET_TYPE(NonVirt2, Id<int()>::Type);
+STATIC_ASSERT_TARGET_TYPE(field, Id<char>::Type);
+//
+// End compile-time tests for MemberPointerTraits
+//
+
+template<typename T>
+struct TypeIdHolder {
+ static constexpr char kTarget = '\0';
+};
+template<typename T>
+constexpr char TypeIdHolder<T>::kTarget;
+
+// This gives us a unique runtime value per unique type - with which its much
+// easier to verify uniqueness (below).
+template<typename T>
+static constexpr char const* TypeId() {
+ return &TypeIdHolder<T>::kTarget;
+}
+
+//
+// LIFT_MEMBER_TO_TYPE
+//
+
+TEST(MetaTest, MemberTagUniqueness) {
+ std::vector<char const*> type_ids = {
+ TypeId<LIFT_MEMBER_TO_TYPE(S, Virt1)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(S, Virt2)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(S, Virt3)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(S, NonVirt1)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(S, NonVirt2)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(S, field)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, Virt1)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, Virt2)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, Virt3)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, NonVirt1)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, NonVirt2)>(),
+ TypeId<LIFT_MEMBER_TO_TYPE(R, field)>(),
+ };
+
+ for (int i = 0; i < type_ids.size(); i++) {
+ for (int j = 0; j < type_ids.size(); j++) {
+ if (i == j) {
+ continue;
+ }
+ EXPECT_THAT(type_ids[i], Not(Eq(type_ids[j])))
+ << "Member tags must be unique";
+ }
+ }
+}
+
+int PokeMemberCase(LIFT_MEMBER_TO_TYPE(S, Virt1)) {
+ return 1;
+}
+
+int PokeMemberCase(LIFT_MEMBER_TO_TYPE(S, Virt2)) {
+ return 2;
+}
+
+template<typename R, R S::* M>
+int PokeMember() {
+ return PokeMemberCase(LIFT_MEMBER_POINTER_TO_TYPE(M){});
+}
+
+TEST(MetaTest, MemberTagDispatch) {
+ EXPECT_THAT((PokeMember<bool(int), &S::Virt1>()), Eq(1));
+ EXPECT_THAT((PokeMember<bool(int, int), &S::Virt2>()), Eq(2));
+}
+
+//
+// CastContainerElements
+//
+
+struct X {
+ static constexpr int tag() { return 1; }
+};
+
+struct Y {
+ static constexpr int tag() { return 2; }
+};
+
+template<typename T>
+struct TypedVal {
+ int value;
+
+ bool operator==(TypedVal<T> const& other) const {
+ return other.value == value;
+ }
+ bool operator!=(TypedVal<T> const& other) const { return !(*this == other); }
+};
+
+struct UntypedVal {
+ int tag;
+ int value;
+};
+
+struct CastValByTag {
+ template <typename T>
+ using TargetType = std::optional<TypedVal<T>>;
+
+ template <typename T>
+ TargetType<T> Cast(UntypedVal const& val) const {
+ if (val.tag == T::tag()) {
+ return TypedVal<T>{val.value};
+ } else {
+ return std::nullopt;
+ }
+ }
+};
+
+TEST(MetaTest, CastContainerElements_AllSuccess) {
+ std::vector<UntypedVal> v{
+ {X::tag(), 123},
+ {Y::tag(), 456},
+ {X::tag(), 789}
+ };
+
+ auto actual = CastContainerElements<X, Y, X>(v, CastValByTag{});
+ auto expected = std::make_tuple(absl::make_optional(TypedVal<X>{123}),
+ absl::make_optional(TypedVal<Y>{456}),
+ absl::make_optional(TypedVal<X>{789}));
+
+ EXPECT_THAT(actual, Eq(expected));
+}
+
+TEST(MetaTest, CastContainerElements_AllSuccess_Pack) {
+ std::vector<UntypedVal> v{
+ {X::tag(), 123},
+ {Y::tag(), 456},
+ {X::tag(), 789}
+ };
+
+ // This uses the Pack<> overload instead.
+ auto actual = CastContainerElements(Pack<X, Y, X>{}, v, CastValByTag{});
+ auto expected = std::make_tuple(absl::make_optional(TypedVal<X>{123}),
+ absl::make_optional(TypedVal<Y>{456}),
+ absl::make_optional(TypedVal<X>{789}));
+
+ EXPECT_THAT(actual, Eq(expected));
+}
+
+TEST(MetaTest, CastContainerElements_OneFails) {
+ std::vector<UntypedVal> v{
+ {X::tag(), 123},
+ {X::tag(), 456},
+ {X::tag(), 789}
+ };
+
+ // Second element does not have the tag for Y.
+ auto actual = CastContainerElements<X, Y, X>(v, CastValByTag{});
+ auto expected = std::make_tuple(absl::make_optional(TypedVal<X>{123}),
+ std::optional<TypedVal<Y>>(std::nullopt),
+ absl::make_optional(TypedVal<X>{789}));
+
+ EXPECT_THAT(actual, Eq(expected));
+}
+
+//
+// MAKE_LINK and LinkedType<>
+//
+
+namespace links {
+
+namespace a {
+
+struct A1 {};
+struct A2 {};
+struct A3 {};
+
+MAKE_LINK(A1, A2);
+
+} // namespace a
+
+namespace b {
+
+struct B1 {};
+
+MAKE_LINK(B1, a::A3);
+
+} // namespace b
+
+} // namespace links
+
+static_assert(std::is_same<LinkedType<links::a::A1>, links::a::A2>::value,
+ "A1 -> A2");
+static_assert(HasLinkedType<links::a::A1>(), "A1 -> A2");
+static_assert(std::is_same<LinkedTypeOrVoid<links::a::A2>, void>::value,
+ "A2 -/>");
+static_assert(!HasLinkedType<links::a::A2>(), "A2 -/>");
+static_assert(std::is_same<LinkedTypeOrVoid<links::a::A3>, void>::value,
+ "A3 -/>");
+static_assert(!HasLinkedType<links::a::A3>(), "A3 -/>");
+static_assert(std::is_same<LinkedType<links::b::B1>, links::a::A3>::value,
+ "b::B1 -> a::A3");
+static_assert(HasLinkedType<links::b::B1>(), "b::B1 -> a::A3");
+
+//
+// Pack<>
+//
+
+template<typename A1, typename A2, size_t I1, size_t I2>
+constexpr Unit CheckUnpack() {
+ static_assert(std::is_same<A1, X>::value, "A1 == X");
+ static_assert(std::is_same<A2, Y>::value, "A2 == Y");
+ static_assert(I1 == 0, "I1 == 0");
+ static_assert(I2 == 1, "I2 == 0");
+ return {};
+}
+
+template<typename... A, size_t... I>
+constexpr Unit UsePack(Pack<A...>, absl::index_sequence<I...>) {
+ return CheckUnpack<A..., I...>();
+}
+
+template<typename... A>
+constexpr Unit MakeAndUsePack() {
+ return UsePack(Pack<A...>{}, Pack<A...>::MakeIndexSequence());
+}
+
+static_assert(MakeAndUsePack<X, Y>().True(), "Pack<>");
+
+//
+// LiftVoidReturn
+//
+
+TEST(MetaTest, LiftVoidReturn_Void) {
+ int counter = 0;
+ std::function<void()> f = [&counter]() { counter++; };
+
+ f();
+ EXPECT_THAT(counter, Eq(1));
+ auto f_wrapped = LiftVoidReturn(f);
+ EXPECT_THAT(f_wrapped(), Eq(Unit{}));
+ EXPECT_THAT(counter, Eq(2));
+}
+
+TEST(MetaTest, LiftVoidReturn_Void_Args) {
+ int counter = 0;
+ std::function<void(int)> f = [&counter](int i) { counter += i; };
+
+ f(10);
+ EXPECT_THAT(counter, Eq(10));
+ auto f_wrapped = LiftVoidReturn(f);
+ EXPECT_THAT(f_wrapped(32), Eq(Unit{}));
+ EXPECT_THAT(counter, Eq(42));
+}
+
+TEST(MetaTest, LiftVoidReturn_NonVoid) {
+ int counter = 0;
+ std::function<int(int)> f = [&counter](int i) {
+ counter += i;
+ return counter;
+ };
+
+ EXPECT_THAT(f(10), Eq(10));
+ EXPECT_THAT(counter, Eq(10));
+ auto f_wrapped = LiftVoidReturn(f);
+ EXPECT_THAT(f_wrapped(32), Eq(42));
+ EXPECT_THAT(counter, Eq(42));
+}
+
+TEST(MetaTest, LiftVoidReturn_Mutable) {
+ int r = -1;
+ auto f = [&r, counter = 0]() mutable {
+ counter++;
+ r = counter;
+ };
+
+ f();
+ EXPECT_THAT(r, Eq(1));
+ auto f_wrapped = LiftVoidReturn(f);
+ EXPECT_THAT(f_wrapped(), Eq(Unit{}));
+ EXPECT_THAT(r, Eq(2));
+}
+
+TEST(MetaTest, LiftVoidReturn_MutableAndMoveOnly) {
+ int r = -1;
+ auto f = [&r, counter = std::make_unique<int>(0)]() mutable {
+ (*counter)++;
+ r = *counter;
+ };
+
+ f();
+ EXPECT_THAT(r, Eq(1));
+ auto f_wrapped = LiftVoidReturn(std::move(f));
+ EXPECT_THAT(f_wrapped(), Eq(Unit{}));
+ EXPECT_THAT(r, Eq(2));
+}
+
+//
+// FunctionTraits
+//
+
+#define STATIC_ASSERT_FUNCTION_TRAITS(fn, r, ...) \
+ static_assert(std::is_same<FunctionTraits<fn>::ResultType, r>::value, \
+ "Incorrect result type from FunctionTraits"); \
+ static_assert( \
+ std::is_same<FunctionTraits<fn>::ArgPackType, Pack<__VA_ARGS__>>::value, \
+ "Incorrect arg pack from FunctionTraits")
+
+STATIC_ASSERT_FUNCTION_TRAITS(void(), void);
+STATIC_ASSERT_FUNCTION_TRAITS(void(int, char), void, int, char);
+STATIC_ASSERT_FUNCTION_TRAITS(Identity<bool(char const*, int)>, bool,
+ char const*, int);
+
+TEST(MetaTest, IsTypeOneOf) {
+ static_assert(IsTypeOneOf<int, int>());
+ static_assert(IsTypeOneOf<int, int, double>());
+ static_assert(IsTypeOneOf<int, double, int>());
+ static_assert(!IsTypeOneOf<int, bool>());
+ static_assert(!IsTypeOneOf<int, double, char>());
+}
+
+TEST(MetaTest, IsSubsetOf) {
+ using T1 = Pack<int, double>;
+ using T2 = Pack<double, int>;
+ using T3 = Pack<int, double, char>;
+
+ static_assert(IsSubsetOf<T1, T1>::value);
+ static_assert(IsSubsetOf<T1, T2>::value);
+ static_assert(IsSubsetOf<T2, T1>::value);
+ static_assert(IsSubsetOf<T2, T3>::value);
+ static_assert(!IsSubsetOf<T3, T2>::value);
+}
+
+} // namespace fcp
diff --git a/fcp/base/monitoring.cc b/fcp/base/monitoring.cc
new file mode 100644
index 0000000..7af8daf
--- /dev/null
+++ b/fcp/base/monitoring.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/monitoring.h"
+
+#include <stdlib.h> /* for abort() */
+
+#include <string>
+
+#ifndef FCP_BAREMETAL
+#include <stdarg.h>
+#include <stdio.h>
+
+#ifdef __ANDROID__
+#include <android/log.h>
+#endif
+
+#include <cstring>
+
+#include "absl/strings/str_format.h"
+#endif // FCP_BAREMETAL
+
+#include "fcp/base/base_name.h"
+
+namespace fcp {
+
+namespace internal {
+
+namespace {
+#ifdef __ANDROID__
+constexpr char kAndroidLogTag[] = "fcp";
+
+int AndroidLogLevel(LogSeverity severity) {
+ switch (severity) {
+ case LogSeverity::kFatal:
+ return ANDROID_LOG_FATAL;
+ case LogSeverity::kError:
+ return ANDROID_LOG_ERROR;
+ case LogSeverity::kWarning:
+ return ANDROID_LOG_WARN;
+ default:
+ return ANDROID_LOG_INFO;
+ }
+}
+#endif
+
+} // namespace
+
+// Provides safe static initialization of the default global logger instance.
+// TODO(team): Improve the logger registration mechanism.
+Logger*& GetGlobalLogger() {
+ static Logger* global_logger = new Logger();
+ return global_logger;
+}
+
+Logger* logger() { return GetGlobalLogger(); }
+void set_logger(Logger* logger) { GetGlobalLogger() = logger; }
+
+void Logger::Log(const char* file, int line, LogSeverity severity,
+ const char* message) {
+#ifndef FCP_BAREMETAL
+ auto base_file_name = BaseName(file);
+#ifdef __ANDROID__
+ bool log_to_logcat = true;
+#ifdef NDEBUG
+ // We don't log INFO logs on Android if this is a production build, since
+ // they're too verbose. We can't just log them at ANDROID_LOG_VERBOSE either,
+ // since then they'd still show up in the logcat unless we first check
+ // __android_log_is_loggable, but that function isn't available until Android
+ // API level 30. So to keep things simple we only log warnings or above,
+ // unless this is a debug build.
+ log_to_logcat = severity != LogSeverity::kInfo;
+#endif // NDEBUG
+ if (log_to_logcat) {
+ int level = AndroidLogLevel(severity);
+ __android_log_print(level, kAndroidLogTag, "%c %s:%d %s\n",
+ absl::LogSeverityName(severity)[0],
+ base_file_name.c_str(), line, message);
+ }
+#endif // __ANDROID__
+ // Note that on Android we print both to logcat *and* stderr. This allows
+ // tests to use ASSERT_DEATH to test for fatal error messages, among other
+ // uses.
+ absl::FPrintF(stderr, "%c %s:%d %s\n", absl::LogSeverityName(severity)[0],
+ base_file_name, line, message);
+#endif // FCP_BAREMETAL
+}
+
+StatusBuilder::StatusBuilder(StatusCode code, const char* file, int line)
+ : file_(file), line_(line), code_(code), message_() {}
+
+StatusBuilder::StatusBuilder(StatusBuilder const& other)
+ : file_(other.file_),
+ line_(other.line_),
+ code_(other.code_),
+ message_(other.message_.str()) {}
+
+StatusBuilder::operator Status() {
+ auto message_str = message_.str();
+ if (code_ != OK) {
+ StringStream status_message;
+ status_message << "(at " << BaseName(file_) << ":" << line_ << message_str;
+ message_str = status_message.str();
+ if (log_severity_ != kNoLog) {
+ StringStream log_message;
+ log_message << "[" << code_ << "] " << message_str;
+ logger()->Log(file_, line_, log_severity_, log_message.str().c_str());
+ if (log_severity_ == LogSeverity::kFatal) {
+ abort();
+ }
+ }
+ }
+ return Status(code_, message_str);
+}
+
+} // namespace internal
+
+} // namespace fcp
diff --git a/fcp/base/monitoring.h b/fcp/base/monitoring.h
new file mode 100644
index 0000000..988ff91
--- /dev/null
+++ b/fcp/base/monitoring.h
@@ -0,0 +1,587 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_BASE_MONITORING_H_
+#define FCP_BASE_MONITORING_H_
+
+#include <string>
+#include <utility>
+
+#include "fcp/base/new.h"
+
+#ifdef FCP_BAREMETAL
+#include "fcp/base/string_stream.h"
+#else
+#include <cstdlib>
+#include <iostream>
+#include <ostream>
+#include <sstream>
+
+#include "absl/base/attributes.h"
+#include "absl/base/log_severity.h"
+#include "absl/base/optimization.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#endif // FCP_BAREMETAL
+
+namespace fcp {
+
+// General Definitions
+// ===================
+
+/**
+ * Indicates to run the functionality in this file in debugging mode. This
+ * creates more exhaustive diagnostics in some cases, as documented case by
+ * case.
+ *
+ * We may want to make this a flag if this turns out to be often needed.
+ */
+constexpr bool fcp_debug = false;
+
+// Logging and Assertions
+// ======================
+
+/**
+ * Defines a subset of Google style logging. Use FCP_LOG(INFO),
+ * FCP_LOG(WARNING), FCP_LOG(ERROR) or FCP_LOG(FATAL) to stream log messages.
+ * Example:
+ *
+ * FCP_LOG(INFO) << "some info log";
+ */
+// TODO(team): adapt to absl logging once available
+#define FCP_LOG(severity) _FCP_LOG_##severity
+
+#define FCP_LOG_IF(severity, condition) \
+ !(condition) ? (void)0 \
+ : ::fcp::internal::LogMessageVoidify() & _FCP_LOG_##severity
+
+/**
+ * Log Severity
+ */
+#ifdef FCP_BAREMETAL
+enum class LogSeverity : int {
+ kInfo = 0,
+ kWarning = 1,
+ kError = 2,
+ kFatal = 3,
+};
+#else
+using LogSeverity = absl::LogSeverity;
+#endif // FCP_BAREMETAL
+
+#define _FCP_LOG_INFO \
+ ::fcp::internal::LogMessage(__FILE__, __LINE__, ::fcp::LogSeverity::kInfo)
+#define _FCP_LOG_WARNING \
+ ::fcp::internal::LogMessage(__FILE__, __LINE__, ::fcp::LogSeverity::kWarning)
+#define _FCP_LOG_ERROR \
+ ::fcp::internal::LogMessage(__FILE__, __LINE__, ::fcp::LogSeverity::kError)
+#define _FCP_LOG_FATAL \
+ ::fcp::internal::LogMessage(__FILE__, __LINE__, ::fcp::LogSeverity::kFatal)
+
+#ifdef FCP_BAREMETAL
+#define FCP_PREDICT_FALSE(x) (x)
+#define FCP_PREDICT_TRUE(x) (x)
+#else
+#define FCP_PREDICT_FALSE(x) ABSL_PREDICT_FALSE(x)
+#define FCP_PREDICT_TRUE(x) ABSL_PREDICT_TRUE(x)
+#endif
+
+/**
+ * Check that the condition holds, otherwise die. Any additional messages can
+ * be streamed into the invocation. Example:
+ *
+ * FCP_CHECK(condition) << "stuff went wrong";
+ */
+#define FCP_CHECK(condition) \
+ FCP_LOG_IF(FATAL, FCP_PREDICT_FALSE(!(condition))) \
+ << ("Check failed: " #condition ". ")
+
+/**
+ * Check that the expression generating a status code is OK, otherwise die.
+ * Any additional messages can be streamed into the invocation.
+ */
+#define FCP_CHECK_STATUS(status) \
+ for (auto __check_status = (status); \
+ __check_status.code() != ::fcp::StatusCode::kOk;) \
+ FCP_LOG_IF(FATAL, __check_status.code() != ::fcp::StatusCode::kOk) \
+ << "status not OK: " << __check_status
+
+// Logging Implementation Details
+// ==============================
+
+namespace internal {
+
+/**
+ * An object which implements a logger sink. The default sink sends log
+ * messages to stderr.
+ */
+class Logger {
+ public:
+ virtual ~Logger() = default;
+
+ /**
+ * Basic log function.
+ *
+ * @param file The name of the associated file.
+ * @param line The line in this file.
+ * @param severity Severity of the log message.
+ * @param message The message to log.
+ */
+ virtual void Log(const char* file, int line, LogSeverity severity,
+ const char* message);
+};
+
+/**
+ * Gets/sets the active logger object.
+ */
+Logger* logger();
+void set_logger(Logger* logger);
+
+#ifndef FCP_BAREMETAL
+using StringStream = std::ostringstream;
+#endif // FCP_BAREMETAL
+
+/**
+ * Object allowing to construct a log message by streaming into it. This is
+ * used by the macro LOG(severity).
+ */
+class LogMessage {
+ public:
+ LogMessage(const char* file, int line, LogSeverity severity)
+ : file_(file), line_(line), severity_(severity) {}
+
+ template <typename T>
+ LogMessage& operator<<(const T& x) {
+ message_ << x;
+ return *this;
+ }
+
+#ifndef FCP_BAREMETAL
+ LogMessage& operator<<(std::ostream& (*pf)(std::ostream&)) {
+ message_ << pf;
+ return *this;
+ }
+#endif
+
+ ~LogMessage() {
+ logger()->Log(file_, line_, severity_, message_.str().c_str());
+ if (severity_ == LogSeverity::kFatal) {
+ abort();
+ }
+ }
+
+ private:
+ const char* file_;
+ int line_;
+ LogSeverity severity_;
+ StringStream message_;
+};
+
+/**
+ * This class is used to cast a LogMessage instance to void within a ternary
+ * expression in the expansion of FCP_LOG_IF. The cast is necessary so
+ * that the types of the expressions on either side of the : match, and the &
+ * operator is used because its precedence is lower than << but higher than
+ * ?:.
+ */
+class LogMessageVoidify {
+ public:
+ void operator&(LogMessage&) {} // NOLINT
+};
+
+} // namespace internal
+
+// Status and StatusOr
+// ===================
+
+/**
+ * Constructor for a status. A status message can be streamed into it. This
+ * captures the current file and line position and includes it into the status
+ * message if the status code is not OK.
+ *
+ * Use as in:
+ *
+ * FCP_STATUS(OK); // signal success
+ * FCP_STATUS(code) << message; // signal failure
+ *
+ * FCP_STATUS can be used in places which either expect a Status or a
+ * StatusOr<T>.
+ *
+ * You can configure the constructed status to also emit a log entry if the
+ * status is not OK by using LogInfo, LogWarning, LogError, or LogFatal as
+ * below:
+ *
+ * FCP_STATUS(code).LogInfo() << message;
+ *
+ * If the constant rx_debug is true, by default, all FCP_STATUS invocations
+ * will be logged on INFO level.
+ */
+#define FCP_STATUS(code) \
+ ::fcp::internal::MakeStatusBuilder(code, __FILE__, __LINE__)
+
+#ifdef FCP_BAREMETAL
+#define FCP_MUST_USE_RESULT __attribute__((warn_unused_result))
+#else
+#define FCP_MUST_USE_RESULT ABSL_MUST_USE_RESULT
+#endif
+
+#ifdef FCP_BAREMETAL
+
+// The bare-metal implementation doesn't depend on Abseil library and
+// provides its own implementations of StatusCode, Status and StatusOr that are
+// source code compatible with absl::StatusCode, absl::Status,
+// and absl::StatusOr.
+
+// See absl::StatusCode for details.
+enum StatusCode : int {
+ kOk = 0,
+ kCancelled = 1,
+ kUnknown = 2,
+ kInvalidArgument = 3,
+ kDeadlineExceeded = 4,
+ kNotFound = 5,
+ kAlreadyExists = 6,
+ kPermissionDenied = 7,
+ kResourceExhausted = 8,
+ kFailedPrecondition = 9,
+ kAborted = 10,
+ kOutOfRange = 11,
+ kUnimplemented = 12,
+ kInternal = 13,
+ kUnavailable = 14,
+ kDataLoss = 15,
+ kUnauthenticated = 16,
+};
+
+class FCP_MUST_USE_RESULT Status final {
+ public:
+ Status() : code_(StatusCode::kOk) {}
+ Status(StatusCode code, const std::string& message)
+ : code_(code), message_(message) {}
+
+ // Status is copyable and moveable.
+ Status(const Status&) = default;
+ Status& operator=(const Status&) = default;
+ Status(Status&&) = default;
+ Status& operator=(Status&&) = default;
+
+ // Tests whether this status is OK.
+ bool ok() const { return code_ == StatusCode::kOk; }
+
+ // Gets this status code.
+ StatusCode code() const { return code_; }
+
+ // Gets this status message.
+ const std::string& message() const { return message_; }
+
+ private:
+ StatusCode code_;
+ std::string message_;
+};
+
+template <typename T>
+class FCP_MUST_USE_RESULT StatusOr final {
+ public:
+ // Default constructor initializes StatusOr with kUnknown code.
+ explicit StatusOr() : StatusOr(StatusCode::kUnknown) {}
+
+ // Constructs a StatusOr from a failed status. The passed status must not be
+ // OK. This constructor is expected to be implicitly called.
+ StatusOr(Status status) // NOLINT
+ : status_(std::move(status)) {
+ FCP_CHECK(!this->status().ok());
+ }
+
+ // Constructs a StatusOr from a status code.
+ explicit StatusOr(StatusCode code) : StatusOr(code, "") {}
+
+ // Constructs a StatusOr from a status code and a message.
+ StatusOr(StatusCode code, const std::string& message)
+ : status_(Status(code, message)) {
+ FCP_CHECK(!this->status().ok());
+ }
+
+ // Construct a StatusOr from a value.
+ StatusOr(const T& value) // NOLINT
+ : value_(value) {}
+
+ // Construct a StatusOr from an R-value.
+ StatusOr(T&& value) // NOLINT
+ : value_(std::move(value)) {}
+
+ // StatusOr is copyable and moveable.
+ StatusOr(const StatusOr& other) : status_(other.status_) {
+ if (ok()) {
+ AssignValue(other.value_);
+ }
+ }
+
+ StatusOr(StatusOr&& other) : status_(std::move(other.status_)) {
+ if (ok()) {
+ AssignValue(std::move(other.value_));
+ }
+ }
+
+ StatusOr& operator=(const StatusOr& other) {
+ if (this != &other) {
+ ClearValue();
+ if (other.ok()) {
+ AssignValue(other.value_);
+ }
+ status_ = other.status_;
+ }
+ return *this;
+ }
+
+ StatusOr& operator=(StatusOr&& other) {
+ if (this != &other) {
+ ClearValue();
+ if (other.ok()) {
+ AssignValue(std::move(other.value_));
+ }
+ status_ = std::move(other.status_);
+ }
+ return *this;
+ }
+
+ ~StatusOr() { ClearValue(); }
+
+ // Tests whether this StatusOr is OK and has a value.
+ bool ok() const { return status_.ok(); }
+
+ // Returns the status.
+ const Status& status() const { return status_; }
+
+ // Returns the value if the StatusOr is OK.
+ const T& value() const& {
+ CheckOk();
+ return value_;
+ }
+ T& value() & {
+ CheckOk();
+ return value_;
+ }
+ T&& value() && {
+ CheckOk();
+ return std::move(value_);
+ }
+
+ // Operator *
+ const T& operator*() const& { return value(); }
+ T& operator*() & { return value(); }
+ T&& operator*() && { return std::move(value()); }
+
+ // Operator ->
+ const T* operator->() const { return &value(); }
+ T* operator->() { return &value(); }
+
+ // Used to explicitly ignore a StatusOr (avoiding unused-result warnings).
+ void Ignore() const {}
+
+ private:
+ void CheckOk() const { FCP_CHECK(ok()) << "StatusOr has no value"; }
+
+ // This is used to assign the value in place without invoking the assignment
+ // operator. Using the assignment operator wouldn't work in case the value_
+ // wasn't previously initialized. For example the value_ object might try
+ // to clear its previous value.
+ template <typename Arg>
+ void AssignValue(Arg&& arg) {
+ new (&unused_) T(std::forward<Arg>(arg));
+ }
+
+ // Destroy the current value if it was initialized.
+ void ClearValue() {
+ if (ok()) value_.~T();
+ }
+
+ Status status_;
+
+ // Using the union allows to avoid initializing the value_ field when
+ // StatusOr is constructed with Status.
+ struct Unused {};
+ union {
+ Unused unused_;
+ T value_;
+ };
+};
+
+#else
+
+// By default absl::Status and absl::StatusOr classes are used.
+using Status = absl::Status;
+using StatusCode = absl::StatusCode;
+template <typename T>
+using StatusOr = absl::StatusOr<T>;
+
+#endif // FCP_BAREMETAL
+
+constexpr auto OK = StatusCode::kOk;
+constexpr auto CANCELLED = StatusCode::kCancelled;
+constexpr auto UNKNOWN = StatusCode::kUnknown;
+constexpr auto INVALID_ARGUMENT = StatusCode::kInvalidArgument;
+constexpr auto DEADLINE_EXCEEDED = StatusCode::kDeadlineExceeded;
+constexpr auto NOT_FOUND = StatusCode::kNotFound;
+constexpr auto ALREADY_EXISTS = StatusCode::kAlreadyExists;
+constexpr auto PERMISSION_DENIED = StatusCode::kPermissionDenied;
+constexpr auto RESOURCE_EXHAUSTED = StatusCode::kResourceExhausted;
+constexpr auto FAILED_PRECONDITION = StatusCode::kFailedPrecondition;
+constexpr auto ABORTED = StatusCode::kAborted;
+constexpr auto OUT_OF_RANGE = StatusCode::kOutOfRange;
+constexpr auto UNIMPLEMENTED = StatusCode::kUnimplemented;
+constexpr auto INTERNAL = StatusCode::kInternal;
+constexpr auto UNAVAILABLE = StatusCode::kUnavailable;
+constexpr auto DATA_LOSS = StatusCode::kDataLoss;
+constexpr auto UNAUTHENTICATED = StatusCode::kUnauthenticated;
+
+namespace internal {
+/** Functions to assist with FCP_RETURN_IF_ERROR() */
+inline const Status AsStatus(const Status& status) { return status; }
+template <typename T>
+inline const Status AsStatus(const StatusOr<T>& status_or) {
+ return status_or.status();
+}
+} // namespace internal
+
+/**
+ * Macro which allows to check for a Status (or StatusOr) and return from the
+ * current method if not OK. Example:
+ *
+ * Status DoSomething() {
+ * FCP_RETURN_IF_ERROR(Step1());
+ * FCP_RETURN_IF_ERROR(Step2ReturningStatusOr().status());
+ * return FCP_STATUS(OK);
+ * }
+ */
+#define FCP_RETURN_IF_ERROR(expr) \
+ do { \
+ ::fcp::Status __status = ::fcp::internal::AsStatus(expr); \
+ if (__status.code() != ::fcp::StatusCode::kOk) { \
+ return (__status); \
+ } \
+ } while (false)
+
+/**
+ * Macro which allows to check for a StatusOr and return it's status if not OK,
+ * otherwise assign the value in the StatusOr to variable or declaration. Usage:
+ *
+ * StatusOr<bool> DoSomething() {
+ * FCP_ASSIGN_OR_RETURN(auto value, TryComputeSomething());
+ * if (!value) {
+ * FCP_ASSIGN_OR_RETURN(value, TryComputeSomethingElse());
+ * }
+ * return value;
+ * }
+ */
+#define FCP_ASSIGN_OR_RETURN(lhs, expr) \
+ _FCP_ASSIGN_OR_RETURN_1( \
+ _FCP_ASSIGN_OR_RETURN_CONCAT(statusor_for_aor, __LINE__), lhs, expr)
+
+#define _FCP_ASSIGN_OR_RETURN_1(statusor, lhs, expr) \
+ auto statusor = (expr); \
+ if (!statusor.ok()) { \
+ return statusor.status(); \
+ } \
+ lhs = std::move(statusor).value()
+
+// See https://goo.gl/x3iba2 for the reason of this construction.
+#define _FCP_ASSIGN_OR_RETURN_CONCAT(x, y) \
+ _FCP_ASSIGN_OR_RETURN_CONCAT_INNER(x, y)
+#define _FCP_ASSIGN_OR_RETURN_CONCAT_INNER(x, y) x##y
+
+// Status Implementation Details
+// =============================
+
+namespace internal {
+
+/**
+ * Helper class which allows to construct a status with message by streaming
+ * into it. Implicitly converts to Status and StatusOr so can be used as a drop
+ * in replacement when those types are expected.
+ */
+class FCP_MUST_USE_RESULT StatusBuilder {
+ public:
+ /** Construct a StatusBuilder from status code. */
+ StatusBuilder(StatusCode code, const char* file, int line);
+
+ /**
+ * Copy constructor for status builder. Most of the time not needed because of
+ * copy ellision. */
+ StatusBuilder(StatusBuilder const& other);
+
+ /** Return true if the constructed status will be OK. */
+ inline bool ok() const { return code_ == OK; }
+
+ /** Returns the code of the constructed status. */
+ inline StatusCode code() const { return code_; }
+
+ /** Stream into status message of this builder. */
+ template <typename T>
+ StatusBuilder& operator<<(T x) {
+ message_ << x;
+ return *this;
+ }
+
+ /** Mark this builder to emit a log message when the result is constructed. */
+ inline StatusBuilder& LogInfo() {
+ log_severity_ = LogSeverity::kInfo;
+ return *this;
+ }
+
+ /** Mark this builder to emit a log message when the result is constructed. */
+ inline StatusBuilder& LogWarning() {
+ log_severity_ = LogSeverity::kWarning;
+ return *this;
+ }
+
+ /** Mark this builder to emit a log message when the result is constructed. */
+ inline StatusBuilder& LogError() {
+ log_severity_ = LogSeverity::kError;
+ return *this;
+ }
+
+ /** Mark this builder to emit a log message when the result is constructed. */
+ inline StatusBuilder& LogFatal() {
+ log_severity_ = LogSeverity::kFatal;
+ return *this;
+ }
+
+ /** Implicit conversion to Status. */
+ operator Status(); // NOLINT
+
+ /** Implicit conversion to StatusOr. */
+ template <typename T>
+ inline operator StatusOr<T>() { // NOLINT
+ return StatusOr<T>(static_cast<Status>(*this));
+ }
+
+ private:
+ static constexpr LogSeverity kNoLog = static_cast<LogSeverity>(-1);
+ const char* const file_;
+ const int line_;
+ const StatusCode code_;
+ StringStream message_;
+ LogSeverity log_severity_ = fcp_debug ? LogSeverity::kInfo : kNoLog;
+};
+
+inline StatusBuilder MakeStatusBuilder(StatusCode code, const char* file,
+ int line) {
+ return StatusBuilder(code, file, line);
+}
+
+} // namespace internal
+} // namespace fcp
+
+#endif // FCP_BASE_MONITORING_H_
diff --git a/fcp/base/monitoring_test.cc b/fcp/base/monitoring_test.cc
new file mode 100644
index 0000000..41b8e41
--- /dev/null
+++ b/fcp/base/monitoring_test.cc
@@ -0,0 +1,269 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/monitoring.h"
+
+#include <stdio.h>
+
+#include <array>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/log_severity.h"
+#include "absl/strings/str_format.h"
+#include "fcp/base/base_name.h"
+
+namespace fcp {
+namespace {
+
+using ::testing::Eq;
+using ::testing::MatchesRegex;
+using ::testing::Not;
+
+MATCHER(IsOk, "") { return arg.ok(); }
+
+MATCHER_P(IsOkAndHolds, m, "") {
+ return testing::ExplainMatchResult(IsOk(), arg, result_listener) &&
+ testing::ExplainMatchResult(m, arg.value(), result_listener);
+}
+
+class BaremetalLogger final : public internal::Logger {
+ public:
+ void Log(const char* file, int line, LogSeverity severity,
+ const char* message) override {
+ absl::FPrintF(stderr, "%c %s:%d %s\n",
+ absl::LogSeverityName(static_cast<absl::LogSeverity>(
+ static_cast<int>(severity)))[0],
+ BaseName(file), line, message);
+ }
+};
+
+class MonitoringTest : public ::testing::TestWithParam<bool> {
+ public:
+ void SetUp() override {
+ if (replace_logger_) {
+ prev_logger_ = internal::logger();
+ internal::set_logger(&logger_);
+ }
+ }
+ void TearDown() override {
+ if (replace_logger_) {
+ internal::set_logger(prev_logger_);
+ }
+ }
+
+ private:
+ const bool replace_logger_ = GetParam();
+ internal::Logger* prev_logger_ = nullptr;
+ BaremetalLogger logger_;
+};
+
+#ifdef FCP_BAREMETAL
+INSTANTIATE_TEST_SUITE_P(Baremetal, MonitoringTest, testing::Values(true));
+#else
+INSTANTIATE_TEST_SUITE_P(Base, MonitoringTest, testing::Values(false));
+#endif
+
+TEST_P(MonitoringTest, LogInfo) {
+ testing::internal::CaptureStderr();
+ FCP_LOG(INFO) << "info log of something happening";
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("I.*info log of something happening\n"));
+}
+
+TEST_P(MonitoringTest, LogWarning) {
+ testing::internal::CaptureStderr();
+ FCP_LOG(WARNING) << "warning log of something happening";
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("W.*warning log of something happening\n"));
+}
+
+TEST_P(MonitoringTest, LogError) {
+ testing::internal::CaptureStderr();
+ FCP_LOG(ERROR) << "error log of something happening";
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("E.*error log of something happening\n"));
+}
+
+TEST_P(MonitoringTest, LogFatal) {
+ ASSERT_DEATH({ FCP_LOG(FATAL) << "fatal log"; }, "fatal log");
+}
+
+TEST_P(MonitoringTest, StatusBuilderLogInfo) {
+ testing::internal::CaptureStderr();
+ Status status = (FCP_STATUS(ABORTED) << "something happened").LogInfo();
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("I.*something happened\n"));
+}
+
+TEST_P(MonitoringTest, StatusBuilderLogWarning) {
+ testing::internal::CaptureStderr();
+ Status status = (FCP_STATUS(ABORTED) << "something happened").LogWarning();
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("W.*something happened\n"));
+}
+
+TEST_P(MonitoringTest, StatusBuilderLogError) {
+ testing::internal::CaptureStderr();
+ Status status = (FCP_STATUS(ABORTED) << "something happened").LogError();
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("E.*something happened\n"));
+}
+
+TEST_P(MonitoringTest, StatusBuilderLogFatal) {
+ ASSERT_DEATH(
+ {
+ Status status =
+ (FCP_STATUS(ABORTED) << "something happened").LogFatal();
+ },
+ "something happened");
+}
+
+TEST_P(MonitoringTest, LogIfTrue) {
+ testing::internal::CaptureStderr();
+ FCP_LOG_IF(INFO, true) << "some log";
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_THAT(output, MatchesRegex("I.*some log\n"));
+}
+
+TEST_P(MonitoringTest, LogIfFalse) {
+ testing::internal::CaptureStderr();
+ FCP_LOG_IF(INFO, false) << "some log";
+ std::string output = testing::internal::GetCapturedStderr();
+ ASSERT_EQ(output, "");
+}
+
+TEST_P(MonitoringTest, CheckSucceeds) { FCP_CHECK(1 < 2); }
+
+TEST_P(MonitoringTest, CheckFails) {
+ ASSERT_DEATH({ FCP_CHECK(1 < 0); }, "Check failed: 1 < 0.");
+}
+
+TEST_P(MonitoringTest, StatusOr) {
+ StatusOr<int> default_constructed_status;
+ ASSERT_FALSE(default_constructed_status.ok());
+ ASSERT_EQ(default_constructed_status.status().code(), UNKNOWN);
+
+ StatusOr<int> fail_status = FCP_STATUS(ABORTED) << "operation aborted";
+ ASSERT_FALSE(fail_status.ok());
+ ASSERT_EQ(fail_status.status().code(), ABORTED);
+ // TODO(team): Port StatusIs matcher to avoid casting message(),
+ // which is string_view, to std::string.
+ ASSERT_THAT(fail_status.status().message(),
+ MatchesRegex(".*operation aborted"));
+}
+
+TEST_P(MonitoringTest, StatusOrCopyAssignment) {
+ StatusOr<int> fail_status = FCP_STATUS(ABORTED) << "operation aborted";
+ StatusOr<int> copy_of_fail_status(fail_status);
+ ASSERT_FALSE(copy_of_fail_status.ok());
+ ASSERT_EQ(copy_of_fail_status.status().code(), ABORTED);
+ ASSERT_THAT(copy_of_fail_status.status().message(),
+ MatchesRegex(".*operation aborted"));
+
+ StatusOr<int> ok_status = 42;
+ StatusOr<int> copy_of_ok_status(ok_status);
+ ASSERT_THAT(copy_of_ok_status, IsOkAndHolds(Eq(42)));
+ ASSERT_EQ(copy_of_ok_status.value(), 42);
+}
+
+TEST_P(MonitoringTest, StatusOrMoveAssignment) {
+ StatusOr<std::unique_ptr<std::string>> fail_status = FCP_STATUS(ABORTED)
+ << "operation aborted";
+ StatusOr<std::unique_ptr<std::string>> moved_fail_status(
+ std::move(fail_status));
+ ASSERT_FALSE(moved_fail_status.ok());
+ ASSERT_EQ(moved_fail_status.status().code(), ABORTED);
+ ASSERT_THAT(moved_fail_status.status().message(),
+ MatchesRegex(".*operation aborted"));
+
+ auto value = std::make_unique<std::string>("foobar");
+ StatusOr<std::unique_ptr<std::string>> ok_status = std::move(value);
+ StatusOr<std::unique_ptr<std::string>> moved_ok_status(std::move(ok_status));
+ ASSERT_TRUE(moved_ok_status.ok());
+ ASSERT_EQ(*moved_ok_status.value(), "foobar");
+}
+
+TEST_P(MonitoringTest, StatusOrCopying) {
+ StatusOr<int> fail_status = FCP_STATUS(ABORTED) << "operation aborted";
+ StatusOr<int> copy_of_status = fail_status;
+ ASSERT_FALSE(copy_of_status.ok());
+ ASSERT_EQ(copy_of_status.status().code(), ABORTED);
+ ASSERT_THAT(copy_of_status.status().message(),
+ MatchesRegex(".*operation aborted"));
+
+ StatusOr<int> ok_status = 42;
+ copy_of_status = ok_status;
+ ASSERT_THAT(copy_of_status, IsOkAndHolds(Eq(42)));
+ ASSERT_EQ(copy_of_status.value(), 42);
+}
+
+TEST_P(MonitoringTest, StatusOrMoving) {
+ StatusOr<std::unique_ptr<std::string>> fail_status = FCP_STATUS(ABORTED)
+ << "operation aborted";
+ StatusOr<std::unique_ptr<std::string>> moved_status = std::move(fail_status);
+ ASSERT_FALSE(moved_status.ok());
+ ASSERT_EQ(moved_status.status().code(), ABORTED);
+ ASSERT_THAT(moved_status.status().message(),
+ MatchesRegex(".*operation aborted"));
+
+ auto value = std::make_unique<std::string>("foobar");
+ StatusOr<std::unique_ptr<std::string>> ok_status = std::move(value);
+ moved_status = std::move(ok_status);
+ ASSERT_TRUE(moved_status.ok());
+ ASSERT_EQ(*moved_status.value(), "foobar");
+}
+
+TEST_P(MonitoringTest, StatusBuilder) {
+ ASSERT_FALSE(FCP_STATUS(ABORTED).ok());
+ ASSERT_EQ(FCP_STATUS(ABORTED).code(), ABORTED);
+}
+
+TEST_P(MonitoringTest, FcpReturnIfError) {
+ ASSERT_THAT(
+ []() -> StatusOr<int> {
+ Status fail_status = FCP_STATUS(ABORTED);
+ FCP_RETURN_IF_ERROR(fail_status);
+ return 0;
+ }(),
+ Not(IsOk()));
+ ASSERT_THAT(
+ []() -> StatusOr<int> {
+ FCP_RETURN_IF_ERROR(Status());
+ return 0;
+ }(),
+ IsOkAndHolds(0));
+
+ ASSERT_THAT(
+ []() -> StatusOr<int> {
+ StatusOr<int> fail_statusor = FCP_STATUS(ABORTED);
+ FCP_RETURN_IF_ERROR(fail_statusor);
+ return 0;
+ }(),
+ Not(IsOk()));
+ ASSERT_THAT(
+ []() -> StatusOr<int> {
+ FCP_RETURN_IF_ERROR(StatusOr<int>(0));
+ return 0;
+ }(),
+ IsOkAndHolds(0));
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/base/move_to_lambda.h b/fcp/base/move_to_lambda.h
new file mode 100644
index 0000000..a314a78
--- /dev/null
+++ b/fcp/base/move_to_lambda.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_MOVE_TO_LAMBDA_H_
+#define FCP_BASE_MOVE_TO_LAMBDA_H_
+
+#include <utility>
+
+namespace fcp {
+
+/**
+ * Copyable wrapper for a move-only value. See MoveToLambda.
+ * The value is accessible with the * and -> operators.
+ *
+ * You must be careful to avoid accidnetal copies of this type. Copies are
+ * destructive (by design), so accidental copies might lead to using a
+ */
+template <typename T>
+class MoveToLambdaWrapper {
+ public:
+ explicit MoveToLambdaWrapper(T t) : value_(std::move(t)) {}
+
+ // The copy and move constructors are intentionally non-const.
+
+ MoveToLambdaWrapper(MoveToLambdaWrapper const& other)
+ : value_(std::move(other.value_)) {}
+
+ MoveToLambdaWrapper& operator=(MoveToLambdaWrapper const& other) {
+ value_ = std::move(other.value_);
+ return *this;
+ }
+
+ // We respect const-ness of the wrapper when dereferencing, so that 'mutable'
+ // is required on the lambda depending on usage of the value; changes
+ // to a captured value persist across calls to the lambda, which is rarely
+ // desired.
+
+ T const& operator*() const & {
+ return value_;
+ }
+
+ T const* operator->() const & {
+ return &value_;
+ }
+
+ T& operator*() & {
+ return value_;
+ }
+
+ T* operator->() & {
+ return &value_;
+ }
+
+ private:
+ mutable T value_;
+};
+
+/**
+ * Allows capturing a value into a lambda 'by move', before C++14. This is
+ * implemented by a copyable wrapper, which actually moves its value.
+ *
+ * auto moving = MoveToLambda(value);
+ * DoSometing([moving]{ V const& v = *moving; ... });
+ */
+template <typename T>
+MoveToLambdaWrapper<std::remove_reference_t<T>> MoveToLambda(T&& value) {
+ static_assert(
+ std::is_rvalue_reference<T&&>::value,
+ "Expected an rvalue: If the value is copied anyway (to this function), "
+ "you might as well put it in the lambda-capture list directly.");
+ return MoveToLambdaWrapper<std::remove_reference_t<T>>(
+ std::forward<T>(value));
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_MOVE_TO_LAMBDA_H_
diff --git a/fcp/base/move_to_lambda_test.cc b/fcp/base/move_to_lambda_test.cc
new file mode 100644
index 0000000..efe5e93
--- /dev/null
+++ b/fcp/base/move_to_lambda_test.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/move_to_lambda.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/unique_value.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+TEST(MoveToLambda, Basic) {
+ auto capture = MoveToLambda(UniqueValue<int>{123});
+ auto lambda = [capture]() {
+ EXPECT_TRUE(capture->has_value()) << "Should have moved the original";
+ return **capture;
+ };
+
+ int returned = lambda();
+ EXPECT_FALSE(capture->has_value()) << "Should have moved the original";
+ EXPECT_THAT(returned, Eq(123));
+
+ int returned_again = lambda();
+ EXPECT_THAT(returned_again, Eq(123)) << "Usage shouldn't be destructive";
+}
+
+TEST(MoveToLambda, Mutable) {
+ auto capture = MoveToLambda(UniqueValue<int>{0});
+ auto counter = [capture]() mutable {
+ EXPECT_TRUE(capture->has_value()) << "Should have moved the original";
+ return (**capture)++;
+ };
+
+ EXPECT_FALSE(capture->has_value()) << "Should have moved the original";
+
+ EXPECT_THAT(counter(), Eq(0));
+ EXPECT_THAT(counter(), Eq(1));
+ EXPECT_THAT(counter(), Eq(2));
+}
+
+} // namespace fcp
diff --git a/fcp/base/new.h b/fcp/base/new.h
new file mode 100644
index 0000000..5220a4d
--- /dev/null
+++ b/fcp/base/new.h
@@ -0,0 +1,13 @@
+#ifndef FCP_BASE_NEW_H_
+#define FCP_BASE_NEW_H_
+
+#ifdef FCP_NANOLIBC
+// Definitions of placement operator new are needed because nanolibc doesn't
+// currently have the <new> header.
+inline void* operator new(size_t, void* p) noexcept { return p; }
+inline void* operator new[](size_t, void* p) noexcept { return p; }
+#else
+#include <new>
+#endif // FCP_NANOLIBC
+
+#endif // FCP_BASE_NEW_H_
diff --git a/fcp/base/platform.cc b/fcp/base/platform.cc
new file mode 100644
index 0000000..0f43c72
--- /dev/null
+++ b/fcp/base/platform.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/platform.h"
+
+#include <stdlib.h>
+#include <sys/stat.h>
+
+#include <cstdio>
+#include <fstream>
+#include <sstream>
+#include <string>
+
+#include "absl/status/status.h"
+
+#ifdef _WIN32
+#include <direct.h>
+#endif
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+namespace {
+#ifdef _WIN32
+constexpr char kPathSeparator[] = "\\";
+#else
+constexpr char kPathSeparator[] = "/";
+#endif
+} // namespace
+
+std::string ConcatPath(absl::string_view path1, absl::string_view path2) {
+ if (path1.empty()) {
+ return std::string(path2);
+ }
+ return absl::StrCat(path1, kPathSeparator, path2);
+}
+
+absl::string_view StripTrailingPathSeparator(absl::string_view path) {
+ return absl::StripSuffix(path, kPathSeparator);
+}
+
+namespace internal {
+
+template <typename T>
+absl::StatusOr<T> ReadFile(absl::string_view file_name) {
+ auto file_name_str = std::string(file_name);
+ std::ifstream is(file_name_str);
+ if (!is) {
+ return absl::InternalError(
+ absl::StrCat("cannot read file ", file_name_str));
+ }
+ std::ostringstream buffer;
+ buffer << is.rdbuf();
+ if (!is) {
+ return absl::InternalError(
+ absl::StrCat("error reading file ", file_name_str));
+ }
+ return static_cast<T>(buffer.str());
+}
+
+} // namespace internal
+
+absl::StatusOr<std::string> ReadFileToString(absl::string_view file_name) {
+ return internal::ReadFile<std::string>(file_name);
+}
+
+absl::StatusOr<absl::Cord> ReadFileToCord(absl::string_view file_name) {
+ return internal::ReadFile<absl::Cord>(file_name);
+}
+
+absl::Status WriteStringToFile(absl::string_view file_name,
+ absl::string_view content) {
+ auto file_name_str = std::string(file_name);
+ std::ofstream os(file_name_str);
+ if (!os) {
+ return absl::InternalError(
+ absl::StrCat("cannot create file ", file_name_str));
+ }
+ os << content;
+ if (!os) {
+ return absl::InternalError(
+ absl::StrCat("error writing to file ", file_name_str));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status WriteCordToFile(absl::string_view file_name,
+ const absl::Cord& content) {
+ auto file_name_str = std::string(file_name);
+ std::ofstream os(file_name_str);
+ if (!os) {
+ return absl::InternalError(
+ absl::StrCat("cannot create file ", file_name_str));
+ }
+ for (absl::string_view chunk : content.Chunks()) {
+ os << chunk;
+ if (!os) {
+ return absl::InternalError(
+ absl::StrCat("error writing to file ", file_name_str));
+ }
+ }
+ return absl::OkStatus();
+}
+
+bool FileExists(absl::string_view file_name) {
+ struct stat info;
+ return stat(std::string(file_name).c_str(), &info) == 0;
+}
+
+std::string GetDataPath(absl::string_view relative_path) {
+ return std::string(relative_path);
+}
+
+} // namespace fcp
diff --git a/fcp/base/platform.h b/fcp/base/platform.h
new file mode 100644
index 0000000..bdf7dc8
--- /dev/null
+++ b/fcp/base/platform.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_PLATFORM_H_
+#define FCP_BASE_PLATFORM_H_
+
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/string_view.h"
+
+// This file defines platform dependent utilities.
+
+namespace fcp {
+
+/**
+ * Concatenates two file path components using platform specific separator.
+ */
+std::string ConcatPath(absl::string_view path1, absl::string_view path2);
+
+/**
+ * Strips a single platform specific path separator from the end of a path if
+ * it is present, returns the original path otherwise.
+ */
+absl::string_view StripTrailingPathSeparator(absl::string_view path);
+
+/**
+ * Reads file content into string.
+ */
+absl::StatusOr<std::string> ReadFileToString(absl::string_view file_name);
+
+/**
+ * Reads file content into absl::Cord.
+ */
+absl::StatusOr<absl::Cord> ReadFileToCord(absl::string_view file_name);
+
+/**
+ * Writes string content into file.
+ */
+absl::Status WriteStringToFile(absl::string_view file_name,
+ absl::string_view content);
+
+/**
+ * Writes cord content into file.
+ */
+absl::Status WriteCordToFile(absl::string_view file_name,
+ const absl::Cord& content);
+
+/**
+ * Returns true if the file exists.
+ */
+bool FileExists(absl::string_view file_name);
+
+/**
+ * Get absolute path given `relative_path`
+ */
+std::string GetDataPath(absl::string_view relative_path);
+
+} // namespace fcp
+
+#endif // FCP_BASE_PLATFORM_H_
diff --git a/fcp/base/platform_test.cc b/fcp/base/platform_test.cc
new file mode 100644
index 0000000..51e8e9a
--- /dev/null
+++ b/fcp/base/platform_test.cc
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/platform.h"
+
+#include "gtest/gtest.h"
+#include "absl/strings/cord.h"
+#include "fcp/base/base_name.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+
+namespace {
+
+TEST(PlatformTest, ConcatPath) {
+ auto combined = ConcatPath("first", "second");
+#if _WIN32
+ ASSERT_EQ(combined, "first\\second");
+#else
+ ASSERT_EQ(combined, "first/second");
+#endif
+}
+
+TEST(PlatformTest, StripTrailingPathSeparator) {
+#if _WIN32
+ ASSERT_EQ(StripTrailingPathSeparator("path\\"), "path");
+ ASSERT_EQ(StripTrailingPathSeparator("dir/path"), "dir/path");
+#else
+ ASSERT_EQ(StripTrailingPathSeparator("path/"), "path");
+ ASSERT_EQ(StripTrailingPathSeparator("dir/path"), "dir/path");
+#endif
+}
+
+TEST(PlatformTest, ReadWriteString) {
+ auto file = TemporaryTestFile(".dat");
+ ASSERT_EQ(WriteStringToFile(file, "Ein Text").code(), OK);
+ auto status_or_string = ReadFileToString(file);
+ ASSERT_TRUE(status_or_string.ok()) << status_or_string.status();
+ ASSERT_EQ(status_or_string.value(), "Ein Text");
+}
+
+TEST(PlatformTest, ReadWriteCord) {
+ auto file = TemporaryTestFile(".dat");
+ // Make cord with two chunks.
+ absl::Cord content("Ein");
+ content.Append(" Text");
+ ASSERT_EQ(WriteCordToFile(file, content).code(), OK);
+ auto status_or_cord = ReadFileToCord(file);
+ ASSERT_TRUE(status_or_cord.ok()) << status_or_cord.status();
+ ASSERT_EQ(status_or_cord.value(), "Ein Text");
+}
+
+TEST(PlatformTest, ReadStringFails) {
+ ASSERT_FALSE(ReadFileToString("foobarbaz").ok());
+}
+
+TEST(PlatformTest, ReadCordFails) {
+ ASSERT_FALSE(ReadFileToCord("foobarbaz").ok());
+}
+
+TEST(PlatformTest, BaseName) {
+ ASSERT_EQ(BaseName(ConcatPath("foo", "bar.x")), "bar.x");
+}
+
+TEST(PlatformTest, FileExists) {
+ auto file = TemporaryTestFile(".dat");
+ ASSERT_EQ(WriteStringToFile(file, "Ein Text").code(), OK);
+ ASSERT_TRUE(FileExists(file));
+}
+
+TEST(PlatformTest, FileExistsNot) { ASSERT_FALSE(FileExists("foobarbaz")); }
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/base/process_unique_id.cc b/fcp/base/process_unique_id.cc
new file mode 100644
index 0000000..b939bb1
--- /dev/null
+++ b/fcp/base/process_unique_id.cc
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/process_unique_id.h"
+
+namespace fcp {
+
+namespace {
+// This will be zero-initialized.
+std::atomic<uint64_t> next_id;
+} // namespace
+
+ProcessUniqueId ProcessUniqueId::Next() {
+ return ProcessUniqueId(next_id.fetch_add(1, std::memory_order_relaxed));
+}
+
+} // namespace fcp
diff --git a/fcp/base/process_unique_id.h b/fcp/base/process_unique_id.h
new file mode 100644
index 0000000..93ff880
--- /dev/null
+++ b/fcp/base/process_unique_id.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_PROCESS_UNIQUE_ID_H_
+#define FCP_BASE_PROCESS_UNIQUE_ID_H_
+
+#include <atomic>
+
+namespace fcp {
+
+// Threadsafe class for creating IDs that are unique per-process.
+class ProcessUniqueId {
+ public:
+ // Threadsafe method for getting a new unique ID, intended to be cheap.
+ static ProcessUniqueId Next();
+ uint64_t value() { return value_; }
+
+ private:
+ explicit constexpr ProcessUniqueId(uint64_t value) : value_(value) {}
+ // Value of the unique ID.
+ uint64_t value_;
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_PROCESS_UNIQUE_ID_H_
diff --git a/fcp/base/process_unique_id_test.cc b/fcp/base/process_unique_id_test.cc
new file mode 100644
index 0000000..91710cd
--- /dev/null
+++ b/fcp/base/process_unique_id_test.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/process_unique_id.h"
+
+#include <thread> // NOLINT(build/c++11)
+
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
+
+namespace fcp {
+
+TEST(UniqueIdGeneratorTest, CallingNextYieldsDifferentIds) {
+ uint64_t id1 = ProcessUniqueId::Next().value();
+ uint64_t id2 = ProcessUniqueId::Next().value();
+ EXPECT_NE(id1, id2);
+}
+
+TEST(UniqueIdGeneratorTest, MultipleThreads) {
+ const int n = 15;
+ std::thread threads[n];
+ uint64_t ids[n];
+ for (int i = 0; i < n; i++) {
+ threads[i] =
+ std::thread([&ids, i]() { ids[i] = ProcessUniqueId::Next().value(); });
+ }
+ for (int i = 0; i < n; i++) {
+ threads[i].join();
+ }
+ absl::flat_hash_set<uint64_t> id_set;
+ for (int i = 0; i < n; i++) {
+ bool inserted = id_set.insert(ids[i]).second;
+ EXPECT_TRUE(inserted);
+ }
+}
+
+} // namespace fcp
diff --git a/fcp/base/random_token.cc b/fcp/base/random_token.cc
new file mode 100644
index 0000000..8564a75
--- /dev/null
+++ b/fcp/base/random_token.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/random_token.h"
+
+#include <string.h>
+
+#include <string>
+
+#include "absl/strings/escaping.h"
+#include "fcp/base/monitoring.h"
+#include "openssl/rand.h"
+
+namespace fcp {
+
+RandomToken RandomToken::Generate() {
+ uint64_t words[2];
+ static_assert(sizeof(words) == kRandomTokenSizeInBytes,
+ "Should match the token size");
+ int r = RAND_bytes(reinterpret_cast<unsigned char*>(words),
+ kRandomTokenSizeInBytes);
+ FCP_CHECK(r == 1);
+ return RandomToken(words[0], words[1]);
+}
+
+RandomToken RandomToken::FromBytes(absl::Span<char const> bytes) {
+ FCP_CHECK(bytes.size() == kRandomTokenSizeInBytes);
+
+ uint64_t words[2];
+ static_assert(sizeof(words) == kRandomTokenSizeInBytes,
+ "Should match the token size");
+ memcpy(reinterpret_cast<char*>(words), bytes.data(), kRandomTokenSizeInBytes);
+ return RandomToken(words[0], words[1]);
+}
+
+std::array<char, kRandomTokenSizeInBytes> RandomToken::ToBytes() const {
+ std::array<char, kRandomTokenSizeInBytes> bytes;
+ memcpy(bytes.data(), reinterpret_cast<char const*>(words_),
+ kRandomTokenSizeInBytes);
+ return bytes;
+}
+
+std::string RandomToken::ToString() const {
+ return std::string(reinterpret_cast<char const*>(words_),
+ kRandomTokenSizeInBytes);
+}
+
+std::string RandomToken::ToPrintableString() const {
+ return absl::BytesToHexString(absl::string_view(
+ reinterpret_cast<char const*>(words_), kRandomTokenSizeInBytes));
+}
+
+} // namespace fcp
diff --git a/fcp/base/random_token.h b/fcp/base/random_token.h
new file mode 100644
index 0000000..1bb0911
--- /dev/null
+++ b/fcp/base/random_token.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_RANDOM_TOKEN_H_
+#define FCP_BASE_RANDOM_TOKEN_H_
+
+#include <stdint.h>
+
+#include <array>
+#include <string>
+#include <utility>
+
+#include "absl/types/span.h"
+
+namespace fcp {
+
+enum { kRandomTokenSizeInBytes = 16 };
+
+/**
+ * A RandomToken is a unique and "unguessable" value, thus usable as a
+ * 'password-capability' (even in an adversarial / network context). Each is
+ * comprised of 128 random bits, sourced from a CSRNG.
+ *
+ * The current implementation uses BoringSSL's RAND_bytes. Unless someone calls
+ * RAND_enable_fork_unsafe_buffering, it should be well-behaved even under
+ * fork(); i.e. tokens should not collide with those generated by new child
+ * processes.
+ */
+class RandomToken {
+ public:
+ /**
+ * Generates a new token. This sources bits from a CSRNG. The returned token
+ * can be assumed to have the desired properties (unique and unguessable) -
+ * unlike one that was deserialized from an untrusted source.
+ */
+ static RandomToken Generate();
+
+ /**
+ * Deserializes a token, serialized with ToBytes() or ToString().
+ * Note that tokens from untrusted sources might not have been generated
+ * correctly, so should not be assumed unique and unguessable.
+ *
+ * Precondition: bytes.size() == kRandomTokenSizeInBytes
+ */
+ static RandomToken FromBytes(absl::Span<char const> bytes);
+
+ /**
+ * Serializes a token, to something usable by FromBytes().
+ */
+ std::array<char, kRandomTokenSizeInBytes> ToBytes() const;
+
+ /**
+ * Serializes a token, to an std::string usable by FromBytes().
+ *
+ * Postcondition: returned_string.size() == kRandomTokenSizeInBytes.
+ */
+ std::string ToString() const;
+
+ /**
+ * Returns a hex-string representation (suitable for log output etc.)
+ */
+ std::string ToPrintableString() const;
+
+ constexpr bool operator==(RandomToken other) const {
+ return words_[0] == other.words_[0] && words_[1] == other.words_[1];
+ }
+
+ constexpr bool operator!=(RandomToken other) const {
+ return !(*this == other);
+ }
+
+ template <typename H>
+ friend H AbslHashValue(H h, RandomToken t) {
+ return H::combine(std::move(h), t.words_[0], t.words_[1]);
+ }
+
+ private:
+ explicit constexpr RandomToken(uint64_t a, uint64_t b) : words_{a, b} {}
+
+ // It would have been nice to write char bytes[16], with alignas(16).
+ // Surprisingly, even current compilers were found prone to unrolling
+ // byte-by-byte comparison loops etc. This representation yields very compact
+ // code.
+ uint64_t words_[2];
+};
+
+static_assert(sizeof(RandomToken) == kRandomTokenSizeInBytes,
+ "Incorrect RandomToken size");
+
+} // namespace fcp
+
+#endif // FCP_BASE_RANDOM_TOKEN_H_
diff --git a/fcp/base/random_token_test.cc b/fcp/base/random_token_test.cc
new file mode 100644
index 0000000..843e53e
--- /dev/null
+++ b/fcp/base/random_token_test.cc
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/random_token.h"
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/hash/hash_testing.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+TEST(RandomTokenTest, Equality) {
+ RandomToken a1 = RandomToken::Generate();
+ RandomToken a2 = a1;
+ RandomToken b = RandomToken::Generate();
+
+ EXPECT_TRUE(a1 == a2);
+ EXPECT_FALSE(a1 != a2);
+
+ EXPECT_TRUE(b != a1);
+ EXPECT_FALSE(b == a1);
+}
+
+TEST(RandomTokenTest, Hashing) {
+ std::vector<RandomToken> distinct;
+ for (int i = 0; i < 128; i++) {
+ distinct.push_back(RandomToken::Generate());
+ }
+
+ EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(distinct));
+}
+
+TEST(RandomTokenTest, Collisions) {
+ // If this test ever fails, then we've tragically over-estimated the quality
+ // of our random source.
+ absl::flat_hash_set<RandomToken> tokens;
+ for (int i = 0; i < 1024; i++) {
+ RandomToken t = RandomToken::Generate();
+ bool inserted = tokens.insert(t).second;
+ EXPECT_TRUE(inserted);
+ }
+}
+
+TEST(RandomTokenTest, Serialization) {
+ RandomToken original = RandomToken::Generate();
+ auto bytes = original.ToBytes();
+ RandomToken deserialized = RandomToken::FromBytes(bytes);
+ EXPECT_THAT(deserialized, Eq(original));
+}
+
+TEST(RandomTokenTest, SerializationToString) {
+ RandomToken original = RandomToken::Generate();
+ std::string str = original.ToString();
+ RandomToken deserialized = RandomToken::FromBytes(str);
+ EXPECT_THAT(deserialized, Eq(original));
+}
+
+TEST(RandomTokenTest, ToPrintableString) {
+ constexpr char const* kHex = "000102030405060708090a0b0c0d0e0f";
+ std::array<char, kRandomTokenSizeInBytes> kBytes{
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ EXPECT_THAT(RandomToken::FromBytes(kBytes).ToPrintableString(), Eq(kHex));
+}
+
+} // namespace fcp
diff --git a/fcp/base/realtime_clock_test.cc b/fcp/base/realtime_clock_test.cc
new file mode 100644
index 0000000..d38f6f3
--- /dev/null
+++ b/fcp/base/realtime_clock_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+
+namespace fcp {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Test;
+
+class RealtimeClockTest : public Test {
+ public:
+ RealtimeClockTest() : start_(Clock::RealClock()->Now()) {}
+
+ void OnWakeUp(int waiter_id) {
+ absl::MutexLock lock(&mu_);
+ waiter_ids_.push_back(waiter_id);
+ waiter_intervals_.push_back(Clock::RealClock()->Now() - start_);
+ }
+
+ protected:
+ absl::Time start_;
+ absl::Mutex mu_;
+ std::vector<int> waiter_ids_ ABSL_GUARDED_BY(&mu_);
+ std::vector<absl::Duration> waiter_intervals_ ABSL_GUARDED_BY(&mu_);
+};
+
+// Simple callback waiter that records current waiter ID with the test.
+class TestWaiter : public Clock::Waiter {
+ public:
+ explicit TestWaiter(int id, RealtimeClockTest* test) : id_(id), test_(test) {}
+
+ void WakeUp() override { test_->OnWakeUp(id_); }
+
+ private:
+ int id_;
+ RealtimeClockTest* test_;
+};
+
+TEST_F(RealtimeClockTest, MultipleTimerWakeUp) {
+ // Add 4 timers at various deadlines, the last one in the past.
+ Clock::RealClock()->WakeupWithDeadline(start_ + absl::Milliseconds(200),
+ std::make_shared<TestWaiter>(1, this));
+ Clock::RealClock()->WakeupWithDeadline(start_ + absl::Milliseconds(100),
+ std::make_shared<TestWaiter>(2, this));
+ Clock::RealClock()->WakeupWithDeadline(start_ + absl::Milliseconds(101),
+ std::make_shared<TestWaiter>(3, this));
+ Clock::RealClock()->WakeupWithDeadline(start_ - absl::Milliseconds(1),
+ std::make_shared<TestWaiter>(4, this));
+
+ // End the test when all 3 timers have been triggered.
+ auto test_done = [this]() {
+ mu_.AssertHeld();
+ return waiter_ids_.size() == 4;
+ };
+
+ absl::MutexLock lock(&mu_);
+ mu_.Await(absl::Condition(&test_done));
+
+ // Verify the results
+ EXPECT_THAT(waiter_ids_, ElementsAre(4, 2, 3, 1));
+ EXPECT_GE(waiter_intervals_[0], absl::ZeroDuration());
+ EXPECT_GE(waiter_intervals_[1], absl::Milliseconds(100));
+ EXPECT_GE(waiter_intervals_[2], absl::Milliseconds(101));
+ EXPECT_GE(waiter_intervals_[3], absl::Milliseconds(200));
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/base/reentrancy_guard.h b/fcp/base/reentrancy_guard.h
new file mode 100644
index 0000000..376110a
--- /dev/null
+++ b/fcp/base/reentrancy_guard.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_REENTRANCY_GUARD_H_
+#define FCP_BASE_REENTRANCY_GUARD_H_
+
+#include <atomic>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+/**
+ * ReentrancyGuard class is used to enforce strictly sequential calling pattern.
+ * Usage pattern:
+ *
+ * Status Method(...) {
+ * ReentrancyGuard guard;
+ * FCP_RETURN_IF_ERROR(guard.Check(&in_use_));
+ *
+ * // The rest of the method body...
+ * }
+ *
+ * in_use_ above is std::atomic<bool> value stored in the object which methods
+ * are enforced.
+ */
+class ReentrancyGuard {
+ public:
+ Status Check(std::atomic<bool>* in_use) {
+ FCP_CHECK(in_use != nullptr);
+ bool expected_value = false;
+ if (!in_use->compare_exchange_strong(expected_value, true)) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Concurrent method calls detected";
+ }
+
+ in_use_ = in_use;
+ return FCP_STATUS(OK);
+ }
+
+ ~ReentrancyGuard() {
+ if (in_use_ != nullptr) {
+ in_use_->store(false);
+ }
+ }
+
+ private:
+ // Pointer to atomic boolean value which is owned by the object which methods
+ // are guarded against reentrancy. This value is set to true when inside
+ // a method call; otherwise false.
+ // Note: std::atomic is used here rather than Mutex is emphasise non-blocking
+ // nature of the implementation. The purpose in_use_ is only to check against
+ // reentrancy rather than synchronization.
+ std::atomic<bool>* in_use_ = nullptr;
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_REENTRANCY_GUARD_H_
diff --git a/fcp/base/reentrancy_guard_test.cc b/fcp/base/reentrancy_guard_test.cc
new file mode 100644
index 0000000..c2d3595
--- /dev/null
+++ b/fcp/base/reentrancy_guard_test.cc
@@ -0,0 +1,86 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/base/reentrancy_guard.h"
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace {
+
+class ReentrancyGuardTest : public testing::Test {
+ protected:
+ Status SimpleMethod() {
+ ReentrancyGuard guard;
+ return guard.Check(&in_use_);
+ }
+
+ Status ReentrantMethod() {
+ ReentrancyGuard guard;
+ FCP_RETURN_IF_ERROR((guard.Check(&in_use_)));
+ return ReentrantMethod();
+ }
+
+ Status LongRunningMethod(absl::Notification* method_entered,
+ absl::Notification* resume) {
+ ReentrancyGuard guard;
+ FCP_RETURN_IF_ERROR((guard.Check(&in_use_)));
+ method_entered->Notify();
+ resume->WaitForNotification();
+ return FCP_STATUS(OK);
+ }
+
+ private:
+ std::atomic<bool> in_use_ = false;
+};
+
+TEST_F(ReentrancyGuardTest, SequentialCallsSucceed) {
+ ASSERT_THAT(SimpleMethod(), IsOk());
+ ASSERT_THAT(SimpleMethod(), IsOk());
+}
+
+TEST_F(ReentrancyGuardTest, ReentrantCallsFail) {
+ ASSERT_THAT(ReentrantMethod(), IsCode(FAILED_PRECONDITION));
+}
+
+TEST_F(ReentrancyGuardTest, ConcurrentCallsFail) {
+ absl::Notification long_running_method_entered;
+ absl::Notification resume;
+
+ auto pool = fcp::CreateThreadPoolScheduler(1);
+ pool->Schedule([&] {
+ ASSERT_THAT(LongRunningMethod(&long_running_method_entered, &resume),
+ IsOk());
+ });
+
+ // This signals that LongRunningMethod() has been entered and waits there
+ // to be resumed.
+ long_running_method_entered.WaitForNotification();
+
+ // Make a concurrent call, which is expected to fail.
+ ASSERT_THAT(SimpleMethod(), IsCode(FAILED_PRECONDITION));
+
+ // Resume LongRunningMethod() and wait for the thread to finish.
+ resume.Notify();
+ pool->WaitUntilIdle();
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/base/result.cc b/fcp/base/result.cc
new file mode 100644
index 0000000..a37ce9b
--- /dev/null
+++ b/fcp/base/result.cc
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/result.h"
+
+#include "fcp/base/tracing_schema.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+Error ExpectBase::TraceExpectError(const char* expectation) const {
+ return TraceError<ResultExpectError>(expectation, loc_.file_name(),
+ loc_.line());
+}
+
+static TracingStatusCode ConvertToTracingStatus(fcp::StatusCode code) {
+ switch (code) {
+ case fcp::StatusCode::kOk:
+ return TracingStatusCode_Ok;
+ case fcp::StatusCode::kCancelled:
+ return TracingStatusCode_Cancelled;
+ case fcp::StatusCode::kInvalidArgument:
+ return TracingStatusCode_InvalidArgument;
+ case fcp::StatusCode::kDeadlineExceeded:
+ return TracingStatusCode_DeadlineExceeded;
+ case fcp::StatusCode::kNotFound:
+ return TracingStatusCode_NotFound;
+ case fcp::StatusCode::kAlreadyExists:
+ return TracingStatusCode_AlreadyExists;
+ case fcp::StatusCode::kPermissionDenied:
+ return TracingStatusCode_PermissionDenied;
+ case fcp::StatusCode::kResourceExhausted:
+ return TracingStatusCode_ResourceExhausted;
+ case fcp::StatusCode::kFailedPrecondition:
+ return TracingStatusCode_FailedPrecondition;
+ case fcp::StatusCode::kAborted:
+ return TracingStatusCode_Aborted;
+ case fcp::StatusCode::kOutOfRange:
+ return TracingStatusCode_OutOfRange;
+ case fcp::StatusCode::kUnimplemented:
+ return TracingStatusCode_Unimplemented;
+ case fcp::StatusCode::kInternal:
+ return TracingStatusCode_Internal;
+ case fcp::StatusCode::kUnavailable:
+ return TracingStatusCode_Unavailable;
+ case fcp::StatusCode::kDataLoss:
+ return TracingStatusCode_DataLoss;
+ case fcp::StatusCode::kUnauthenticated:
+ return TracingStatusCode_Unauthenticated;
+ default:
+ return TracingStatusCode_Unknown;
+ }
+}
+
+Error ExpectBase::TraceUnexpectedStatus(fcp::StatusCode expected_code,
+ const fcp::Status& actual) const {
+ return TraceError<ResultExpectStatusError>(
+ ConvertToTracingStatus(expected_code),
+ ConvertToTracingStatus(actual.code()), actual.message(), loc_.file_name(),
+ loc_.line());
+}
+} // namespace fcp
diff --git a/fcp/base/result.h b/fcp/base/result.h
new file mode 100644
index 0000000..f97d6a5
--- /dev/null
+++ b/fcp/base/result.h
@@ -0,0 +1,401 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_RESULT_H_
+#define FCP_BASE_RESULT_H_
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+#include "fcp/base/error.h"
+#include "fcp/base/meta.h"
+#include "fcp/base/source_location.h"
+
+namespace fcp {
+namespace result_internal {
+
+template <typename R>
+struct ResultTraits;
+
+} // namespace result_internal
+
+// A Result is either a value (T) or an opaque Error. There are two main ways to
+// use one.
+//
+// Transform: Given a Result<T> r, and a callable f from T to Result<U>,
+// r.Then(f) returns Result<U>. Note that errors are passed through (without
+// calling f).
+//
+// Similarly, given a Result<T> r, and a callable f from T to U, r.Map(f)
+// returns Result<U>. The difference is that Then can introduce new errors (it
+// returns Result<U>) whereas Map only transforms values to other values.
+//
+// Result<int> some_int = ...
+// Result<bool> b = some_int.Then([](int i) -> Result<int> {
+// if (i < 0) {
+// return TraceError<...>(...);
+// } else {
+// return i;
+// }
+// }).Map([](int i) -> bool) {
+// return i % 2 == 0;
+// });
+//
+// Propagate: The FCP_TRY macro unwraps results to their values. If a result
+// contains an error, it is returned (from the function where FCP_TRY appears).
+//
+// Result<int> GetInt();
+//
+// Result<bool> F() {
+// int i = FCP_TRY(GetInt());
+// if (i < 0) {
+// }
+//
+// return i % 2 == 0;
+// }
+//
+// Result<T> provides implicit conversions from T and Error. As above, in
+// functions returning Result<T>, it is useful and encourage to return a T or
+// Error directly.
+template <typename T>
+class ABSL_MUST_USE_RESULT Result {
+ public:
+ using ValueType = T;
+
+ // These make Result<> usable as an argument to Match() (see match.h).
+ using VariantType = std::variant<Error, T>;
+ constexpr VariantType& variant() & { return val_; }
+ constexpr VariantType const& variant() const& { return val_; }
+ constexpr VariantType&& variant() && { return std::move(val_); }
+
+ // Implicit conversion from T
+ constexpr Result(T t) : val_(std::move(t)) {} // NOLINT
+
+ // Implicit conversion from Error
+ Result(Error e) : val_(e) {} // NOLINT
+
+ constexpr bool is_error() const {
+ return std::holds_alternative<Error>(val_);
+ }
+
+ // Returns a *reference* to the contained value.
+ // Requires (CHECK): !is_error()
+ constexpr T const& GetValueOrDie() const& {
+ FCP_CHECK(std::holds_alternative<T>(val_));
+ return absl::get<T>(val_);
+ }
+
+ // Returns the contained value (by move).
+ // This applies for results which are rvalues.
+ //
+ // Example:
+ // Result<X> r = f();
+ // X v = std::move(r).GetValueOrDie();
+ //
+ // Example:
+ // X v = f().GetValueOrDie();
+ //
+ // Requires (CHECK): !is_error()
+ constexpr T GetValueOrDie() && {
+ FCP_CHECK(std::holds_alternative<T>(val_));
+ return absl::get<T>(std::move(val_));
+ }
+
+ // Returns the contained error.
+ // Requires (CHECK): is_error()
+ Error GetErrorOrDie() const {
+ FCP_CHECK(std::holds_alternative<Error>(val_));
+ return absl::get<Error>(val_);
+ }
+
+ // Transforms this Result into another (with value type U).
+ //
+ // If this Result holds an Error, it is passed through.
+ // If this Result holds a value, then the callable 'fn' is applied to it.
+ // The callable 'fn' is expected to return Result<U>.
+ //
+ // Example:
+ //
+ // Result<int> some_int = ...
+ // Result<bool> b = some_int.Then([](int i) -> Result<bool> {
+ // if (i < 0) {
+ // return TraceError<...>(...);
+ // } else {
+ // return i % 2 == 0;
+ // }
+ // });
+ template <typename Fn>
+ constexpr auto Then(Fn fn) const& {
+ return ThenInternal<false>(*this, std::move(fn));
+ }
+
+ template <typename Fn>
+ constexpr auto Then(Fn fn) && {
+ return ThenInternal<true>(std::move(*this), std::move(fn));
+ }
+
+ // Maps values of type T to a values of type U.
+ //
+ // If this Result holds an Error, it is passed through.
+ // If this Result holds a value, then the callable 'fn' is applied to it.
+ //
+ // Example:
+ //
+ // Result<int> some_int = ...
+ // Result<bool> b = some_int.Map([](int i) {
+ // return i % 2 == 0;
+ // });
+ template <typename Fn>
+ constexpr auto Map(Fn fn) const& {
+ using U = std::invoke_result_t<Fn, T const&>;
+ return ThenInternal<false>(
+ *this, [fn = std::move(fn)](T const& t) { return Result<U>(fn(t)); });
+ }
+
+ template <typename Fn>
+ constexpr auto Map(Fn fn) && {
+ using U = std::invoke_result_t<Fn, T&&>;
+ return ThenInternal<true>(std::move(*this), [fn = std::move(fn)](T&& t) {
+ return Result<U>(fn(std::move(t)));
+ });
+ }
+
+ private:
+ template <bool Move, typename Fn>
+ static constexpr auto ThenInternal(
+ std::conditional_t<Move, Result<T>&&, Result<T> const&> r, Fn fn) {
+ using RefType = std::conditional_t<Move, T&&, T const&>;
+ using RetType = std::invoke_result_t<Fn, RefType>;
+ static_assert(
+ result_internal::ResultTraits<RetType>::is_result(),
+ "The callable provided to 'Then' must return Result<U> for "
+ "some type U. When always returning a value, use Map instead.");
+
+ if (r.is_error()) {
+ return RetType(r.GetErrorOrDie());
+ } else {
+ return fn(absl::get<T>(std::move(r).variant()));
+ }
+ }
+
+ std::variant<Error, T> val_;
+};
+
+// This is a deduction guide so that one can write Result(t) for a value t,
+// without explicitly specifying the value type. This one is implicitly
+// declared anyway; we make it explicit to suppress -Wctad-maybe-unsupported.
+template <typename T>
+Result(T) -> Result<T>;
+
+// ResultFrom<T> -> Result<T>
+// ResultFrom<Result<T>> -> Result<T>
+//
+// Note that ResultFrom<Error> is ill-formed (like Result<Error>).
+template <typename T>
+using ResultFrom = decltype(Result(std::declval<T>()));
+
+// ResultOf applied to the result of calling Fn with Args...
+template <typename Fn, typename... Args>
+using ResultOf = ResultFrom<std::invoke_result_t<Fn, Args...>>;
+
+namespace result_internal {
+
+template <typename R>
+struct ResultTraits {
+ using ValueType = void;
+};
+
+template <typename T>
+struct ResultTraits<Result<T>> {
+ static constexpr bool is_result() { return true; }
+ using ValueType = T;
+};
+
+// This is used in FCP_TRY, to require that the parameter to FCP_TRY has type
+// Result<T> for some T.
+template <typename T>
+constexpr bool ResultIsError(Result<T> const& r) {
+ return r.is_error();
+}
+
+} // namespace result_internal
+
+class ExpectBase {
+ public:
+ constexpr explicit ExpectBase(SourceLocation loc) : loc_(loc) {}
+
+ protected:
+ Error TraceExpectError(const char* expectation) const;
+ Error TraceUnexpectedStatus(fcp::StatusCode expected_code,
+ const fcp::Status& actual) const;
+
+ private:
+ SourceLocation loc_;
+};
+
+// Returns Result<T> if the current result has std::variant that holds a
+// value of type T; otherwise returns an error Result.
+template <typename T>
+struct ExpectIs : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectIs(SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename... Us>
+ constexpr Result<T> operator()(std::variant<Us...> v) const {
+ if (std::holds_alternative<T>(v)) {
+ return absl::get<T>(std::move(v));
+ } else {
+ return TraceExpectError("ExpectIs");
+ }
+ }
+};
+
+// Returns Result<std::variant<Us...>> if the current result has
+// std::variant that holds a value of one of the types from Us... typelist;
+// otherwise returns an error Result. This operation is valid only when the
+// set of expected types Us... is a subset of the set of types Ts... in the
+// current result.
+template <typename... Ts>
+struct ExpectOneOf : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectOneOf(SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename... Us>
+ constexpr Result<std::variant<Ts...>> operator()(
+ std::variant<Us...> v) const {
+ static_assert(IsSubsetOf<Pack<Ts...>, Pack<Us...>>::value);
+
+ // TODO(team): This should be expressible with Match
+ return absl::visit(
+ [this](auto arg) -> Result<std::variant<Ts...>> {
+ if constexpr (IsTypeOneOf<std::decay_t<decltype(arg)>, Ts...>()) {
+ return std::variant<Ts...>(std::move(arg));
+ } else {
+ return TraceExpectError("ExpectOneOf<>");
+ }
+ },
+ std::move(v));
+ }
+};
+
+// Returns Result<Unit> if the current result has boolean 'true' value;
+// otherwise returns an error Result.
+struct ExpectTrue : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectTrue(SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename... Us>
+ constexpr Result<Unit> operator()(bool b) const {
+ if (b) {
+ return Unit{};
+ } else {
+ return TraceExpectError("ExpectTrue");
+ }
+ }
+};
+
+// Returns Result<Unit> if the current result has boolean 'false' value;
+// otherwise returns an error Result.
+struct ExpectFalse : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectFalse(SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename... Us>
+ constexpr Result<Unit> operator()(bool b) const {
+ if (!b) {
+ return Unit{};
+ } else {
+ return TraceExpectError("ExpectTrue");
+ }
+ }
+};
+
+// Returns Result<T> if the current result has std::optional<T> has a value;
+// otherwise returns an error Result.
+struct ExpectHasValue : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectHasValue(
+ SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename T>
+ constexpr Result<T> operator()(std::optional<T> v) const {
+ if (v.has_value()) {
+ return *std::move(v);
+ } else {
+ return TraceExpectError("ExpectHasValue");
+ }
+ }
+};
+
+// Returns Result<Unit> if the current result has an empty std::optional;
+// otherwise returns an error Result.
+struct ExpectIsEmpty : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectIsEmpty(
+ SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename T>
+ constexpr Result<Unit> operator()(std::optional<T> v) const {
+ if (!v.has_value()) {
+ return Unit{};
+ } else {
+ return TraceExpectError("ExpectIsEmpty");
+ }
+ }
+};
+
+struct ExpectOk : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectOk(SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ template <typename T>
+ constexpr Result<T> operator()(StatusOr<T> s) const {
+ if (s.ok()) {
+ return std::move(s).value();
+ } else {
+ return TraceUnexpectedStatus(fcp::OK, s.status());
+ }
+ }
+
+ Result<Unit> operator()(const Status& s) const {
+ if (s.code() == fcp::OK) {
+ return Unit{};
+ } else {
+ return TraceUnexpectedStatus(fcp::OK, s);
+ }
+ }
+};
+
+} // namespace fcp
+
+#define FCP_TRY(...) \
+ ({ \
+ auto try_tmp_value_ = (__VA_ARGS__); \
+ if (::fcp::result_internal::ResultIsError(try_tmp_value_)) { \
+ return try_tmp_value_.GetErrorOrDie(); \
+ } \
+ std::move(try_tmp_value_).GetValueOrDie(); \
+ })
+
+#endif // FCP_BASE_RESULT_H_
diff --git a/fcp/base/result_test.cc b/fcp/base/result_test.cc
new file mode 100644
index 0000000..d05696a
--- /dev/null
+++ b/fcp/base/result_test.cc
@@ -0,0 +1,317 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/result.h"
+
+#include <memory>
+#include <utility>
+#include <variant>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/tracing_schema.h"
+#include "fcp/base/unique_value.h"
+#include "fcp/testing/result_matchers.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+using ::testing::VariantWith;
+
+template <typename T>
+constexpr bool HasValue(Result<T> r, T v) {
+ return !r.is_error() && r.GetValueOrDie() == v;
+}
+
+template <typename T>
+constexpr bool IsError(Result<T> r) {
+ return r.is_error();
+}
+
+TEST(ResultTest, Constructor) {
+ ASSERT_THAT(Result<int>(TraceTestError()), IsError());
+ ASSERT_THAT(Result(123), HasValue(123));
+}
+
+TEST(ResultTest, CombinatorsToValue) {
+ Result<bool> r = Result(123)
+ .Then([](int i) -> Result<bool> { return i != 123; })
+ .Map([](bool b) -> bool { return !b; });
+ ASSERT_THAT(r, HasValue(true));
+}
+
+TEST(ResultTest, CombinatorsToValue_MoveOnly) {
+ Result<bool> r =
+ Result(UniqueValue(123))
+ .Then([](UniqueValue<int> i) -> Result<UniqueValue<bool>> {
+ return UniqueValue(std::move(i).Take() != 123);
+ })
+ .Map([](UniqueValue<bool> b) -> UniqueValue<bool> {
+ return UniqueValue(!std::move(b).Take());
+ })
+ .Map([](UniqueValue<bool> b) -> bool { return std::move(b).Take(); });
+ ASSERT_THAT(r, HasValue(true));
+}
+
+TEST(ResultTest, MapToValue_MoveOnly_Const) {
+ Result<UniqueValue<int>> r1 = Result(UniqueValue(21));
+ Result<int> r2 = r1.Map([](UniqueValue<int> const& v) { return (*v) * 2; });
+ Result<bool> r1_still_valid =
+ r1.Map([](UniqueValue<int> const& v) { return v.has_value(); });
+
+ ASSERT_THAT(r2, HasValue(42));
+ ASSERT_THAT(r1_still_valid, HasValue(true));
+}
+
+TEST(ResultTest, ThenToValue_MoveOnly_Const) {
+ Result<UniqueValue<int>> r1 = Result(UniqueValue(21));
+ Result<int> r2 =
+ r1.Then([](UniqueValue<int> const& v) { return Result((*v) * 2); });
+ Result<bool> r1_still_valid =
+ r1.Then([](UniqueValue<int> const& v) { return Result(v.has_value()); });
+
+ ASSERT_THAT(r2, HasValue(42));
+ ASSERT_THAT(r1_still_valid, HasValue(true));
+}
+
+void ExpectUnreachable() { FAIL(); }
+
+TEST(ResultTest, CombinatorsToError) {
+ Result<Unit> r = Result(123)
+ .Then([](int i) -> Result<int> {
+ if (i > 0) {
+ return TraceTestError();
+ } else {
+ return i;
+ }
+ })
+ .Map([](int i) -> Unit {
+ ExpectUnreachable();
+ return Unit{};
+ });
+ ASSERT_THAT(r, IsError());
+}
+
+TEST(ResultTest, ResultFrom) {
+ static_assert(std::is_same_v<ResultFrom<Result<int>>, Result<int>>);
+ static_assert(std::is_same_v<ResultFrom<int>, Result<int>>);
+}
+
+template <typename Expect, typename Fn, typename... Args>
+constexpr Unit ExpectResultOf(Fn fn, Args... args) {
+ using R = ResultOf<Fn, Args...>;
+ static_assert(std::is_same_v<R, Expect>);
+ return {};
+}
+
+namespace result_of_example {
+
+Result<Unit> Result0() { return Unit{}; }
+Result<Unit> Result1(int) { return Unit{}; }
+Unit Value1(int i) { return Unit{}; }
+constexpr auto Generic = [](auto t) { return Result(t); };
+
+} // namespace result_of_example
+
+TEST(ResultTest, ResultOf) {
+ static_assert(
+ ExpectResultOf<Result<Unit>>(result_of_example::Result0).True());
+ static_assert(
+ ExpectResultOf<Result<Unit>>(result_of_example::Result1, 123).True());
+ static_assert(
+ ExpectResultOf<Result<Unit>>(result_of_example::Value1, 123).True());
+ static_assert(
+ ExpectResultOf<Result<bool>>(result_of_example::Generic, true).True());
+}
+
+Result<bool> Example_OneTryExpression(Result<int> r) {
+ int i = FCP_TRY(r);
+ if (i < 0) {
+ return TraceTestError();
+ }
+
+ return i % 2 == 0;
+}
+
+TEST(ResultTest, TryExpressionWithError) {
+ EXPECT_THAT(Example_OneTryExpression(TraceTestError()), IsError());
+}
+
+TEST(ResultTest, TryExpressionWithValue) {
+ EXPECT_THAT(Example_OneTryExpression(-1), IsError());
+ EXPECT_THAT(Example_OneTryExpression(1), HasValue(false));
+ EXPECT_THAT(Example_OneTryExpression(2), HasValue(true));
+}
+
+Result<bool> Example_OneTryExpression_UnparenthesizedCommas(
+ Result<std::variant<int, bool, Unit>> r) {
+ std::variant<int, bool> v = FCP_TRY(r.Then(ExpectOneOf<int, bool>()));
+ if (std::holds_alternative<int>(v)) {
+ return absl::get<int>(v) > 0;
+ } else {
+ return absl::get<bool>(v);
+ }
+}
+
+TEST(ResultTest, TryExpressionWithValue_UnparenthesizedCommas) {
+ using V = std::variant<int, bool, Unit>;
+ EXPECT_THAT(Example_OneTryExpression_UnparenthesizedCommas(V(-1)),
+ HasValue(false));
+ EXPECT_THAT(Example_OneTryExpression_UnparenthesizedCommas(V(1)),
+ HasValue(true));
+ EXPECT_THAT(Example_OneTryExpression_UnparenthesizedCommas(V(false)),
+ HasValue(false));
+ EXPECT_THAT(Example_OneTryExpression_UnparenthesizedCommas(V(true)),
+ HasValue(true));
+ EXPECT_THAT(Example_OneTryExpression_UnparenthesizedCommas(V(Unit{})),
+ IsError());
+}
+
+Result<int> Example_SumWithTryExpressions(Result<int> a, Result<int> b) {
+ return FCP_TRY(a) + FCP_TRY(b);
+}
+
+TEST(ResultTest, TwoTryExpressionsWithError) {
+ EXPECT_THAT(Example_SumWithTryExpressions(TraceTestError(), 1), IsError());
+ EXPECT_THAT(Example_SumWithTryExpressions(41, TraceTestError()), IsError());
+ EXPECT_THAT(Example_SumWithTryExpressions(TraceTestError(), TraceTestError()),
+ IsError());
+}
+
+TEST(ResultTest, TwoTryExpressionsWithValues) {
+ EXPECT_THAT(Example_SumWithTryExpressions(1, 41), HasValue(42));
+}
+
+Result<int> Example_TryExpression_MoveOnly(Result<std::unique_ptr<int>> r) {
+ std::unique_ptr<int> p = FCP_TRY(std::move(r));
+ return *p;
+}
+
+TEST(ResultTest, TryExpressionWithError_MoveOnly) {
+ EXPECT_THAT(Example_TryExpression_MoveOnly(TraceTestError()), IsError());
+}
+
+TEST(ResultTest, TryExpressionWithValue_MoveOnly) {
+ EXPECT_THAT(Example_TryExpression_MoveOnly(std::make_unique<int>(123)),
+ HasValue(123));
+}
+
+Result<bool> Example_TryStatement(Result<Unit> r) {
+ FCP_TRY(r);
+ return true;
+}
+
+TEST(ResultTest, TryStatementWithError) {
+ EXPECT_THAT(Example_TryStatement(TraceTestError()), IsError());
+}
+
+TEST(ResultTest, TryStatementWithValue) {
+ EXPECT_THAT(Example_TryStatement(Unit{}), HasValue(true));
+}
+
+TEST(ResultTest, ExpectTrue) {
+ EXPECT_THAT(Result(true).Then(ExpectTrue()), HasValue(Unit{}));
+ EXPECT_THAT(Result(false).Then(ExpectTrue()), IsError());
+ EXPECT_THAT(Result<bool>(TraceTestError()).Then(ExpectTrue()), IsError());
+}
+
+TEST(ResultTest, ExpectFalse) {
+ EXPECT_THAT(Result(false).Then(ExpectFalse()), HasValue(Unit{}));
+ EXPECT_THAT(Result(true).Then(ExpectFalse()), IsError());
+ EXPECT_THAT(Result<bool>(TraceTestError()).Then(ExpectFalse()), IsError());
+}
+
+TEST(ResultTest, ExpectHasValue) {
+ using V = std::optional<int>;
+ EXPECT_THAT(Result<V>(123).Then(ExpectHasValue()), HasValue(123));
+ EXPECT_THAT(Result<V>(V{}).Then(ExpectHasValue()), IsError());
+ EXPECT_THAT(Result<V>(TraceTestError()).Then(ExpectHasValue()), IsError());
+}
+
+TEST(ResultTest, ExpectIsEmpty) {
+ using V = std::optional<int>;
+ EXPECT_THAT(Result<V>(123).Then(ExpectIsEmpty()), IsError());
+ EXPECT_THAT(Result<V>(V{}).Then(ExpectIsEmpty()), HasValue(Unit{}));
+ EXPECT_THAT(Result<V>(TraceTestError()).Then(ExpectIsEmpty()), IsError());
+}
+
+TEST(ResultTest, ExpectIs) {
+ using V = std::variant<int, char>;
+ EXPECT_THAT(Result<V>(123).Then(ExpectIs<int>()), HasValue(123));
+ EXPECT_THAT(Result<V>('a').Then(ExpectIs<char>()), HasValue('a'));
+ EXPECT_THAT(Result<V>('b').Then(ExpectIs<int>()), IsError());
+ EXPECT_THAT(Result<V>(TraceTestError()).Then(ExpectIs<int>()), IsError());
+ EXPECT_THAT(Result<V>(TraceTestError()).Then(ExpectIs<char>()), IsError());
+}
+
+TEST(ResultTest, ExpectOneOf) {
+ using V = std::variant<int, char, bool>;
+ EXPECT_THAT(Result<V>(123).Then(ExpectOneOf<int>()),
+ HasValue(VariantWith<int>(123)));
+ EXPECT_THAT(Result<V>(123).Then(ExpectOneOf<bool>()), IsError());
+ EXPECT_THAT((Result<V>(123).Then(ExpectOneOf<int, bool>())),
+ HasValue(VariantWith<int>(123)));
+ EXPECT_THAT((Result<V>(123).Then(ExpectOneOf<char, bool>())), IsError());
+ EXPECT_THAT((Result<V>(TraceTestError()).Then(ExpectOneOf<int, bool>())),
+ IsError());
+}
+
+TEST(ResultTest, ExpectOk) {
+ TestTracingRecorder recorder;
+ EXPECT_THAT(Result<Status>(FCP_STATUS(OK)).Then(ExpectOk()),
+ HasValue(Unit{}));
+}
+
+TEST(ResultTest, ExpectOkReturnsError) {
+ TestTracingRecorder recorder;
+ recorder.ExpectError<ResultExpectStatusError>();
+ EXPECT_THAT(Result<Status>(FCP_STATUS(INVALID_ARGUMENT)).Then(ExpectOk()),
+ IsError());
+}
+
+TEST(ResultTest, ExpectOkStatusOr) {
+ TestTracingRecorder recorder;
+ EXPECT_THAT(Result<StatusOr<Unit>>(StatusOr<Unit>(Unit{})).Then(ExpectOk()),
+ HasValue(Unit{}));
+}
+
+TEST(ResultTest, ExpectOkStatusOrReturnsError) {
+ TestTracingRecorder recorder;
+ recorder.ExpectError<ResultExpectStatusError>();
+ EXPECT_THAT(
+ Result<StatusOr<Unit>>(FCP_STATUS(INVALID_ARGUMENT)).Then(ExpectOk()),
+ IsError());
+ EXPECT_THAT(
+ recorder.FindAllEvents<ResultExpectStatusError>(),
+ ElementsAre(IsEvent<ResultExpectStatusError>(
+ Eq(TracingStatusCode_Ok), Eq(TracingStatusCode_InvalidArgument))));
+}
+
+TEST(ResultTest, TraceFailedPrecondition) {
+ TestTracingRecorder recorder;
+ recorder.ExpectError<ResultExpectStatusError>();
+ EXPECT_THAT(
+ Result<StatusOr<Unit>>(FCP_STATUS(FAILED_PRECONDITION)).Then(ExpectOk()),
+ IsError());
+ EXPECT_THAT(
+ recorder.FindAllEvents<ResultExpectStatusError>(),
+ ElementsAre(IsEvent<ResultExpectStatusError>(
+ Eq(TracingStatusCode_Ok), Eq(TracingStatusCode_FailedPrecondition))));
+}
+
+} // namespace fcp
diff --git a/fcp/base/scheduler.cc b/fcp/base/scheduler.cc
new file mode 100644
index 0000000..e815d77
--- /dev/null
+++ b/fcp/base/scheduler.cc
@@ -0,0 +1,205 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/scheduler.h"
+
+#include <array>
+#include <functional>
+#include <memory>
+#include <queue>
+#include <thread> // NOLINT(build/c++11)
+#include <utility>
+#include <vector>
+
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/synchronization/mutex.h"
+
+namespace fcp {
+
+namespace {
+
+// A helper class to track information about lifetime of an object.
+// Uses a shared pointer (SharedMarker) to a boolean memory fragment
+// which remembers if the object has been destroyed. Capturing the
+// marker in a lambda gives us a clean way to CHECK fail if the
+// object is accessed post destruction.
+class LifetimeTracker {
+ public:
+ using SharedMarker = std::shared_ptr<bool>;
+ LifetimeTracker() : marker_(std::make_shared<bool>(true)) {}
+ virtual ~LifetimeTracker() { *marker_ = false; }
+ SharedMarker& marker() { return marker_; }
+
+ private:
+ SharedMarker marker_;
+};
+
+// Implementation of workers.
+class WorkerImpl : public Worker, public LifetimeTracker {
+ public:
+ explicit WorkerImpl(Scheduler* scheduler) : scheduler_(scheduler) {}
+
+ ~WorkerImpl() override = default;
+
+ void Schedule(std::function<void()> task) override {
+ absl::MutexLock lock(&busy_);
+ steps_.emplace_back(std::move(task));
+ MaybeRunNext();
+ }
+
+ private:
+ void MaybeRunNext() ABSL_EXCLUSIVE_LOCKS_REQUIRED(busy_) {
+ if (running_ || steps_.empty()) {
+ // Already running, and next task will be executed when finished, or
+ // nothing to run.
+ return;
+ }
+ auto task = std::move(steps_.front());
+ steps_.pop_front();
+ running_ = true;
+ auto wrapped_task = MoveToLambda(std::move(task));
+ auto marker = this->marker();
+ scheduler_->Schedule([this, marker, wrapped_task] {
+ // Call the Task which is stored in wrapped_task.value.
+ (*wrapped_task)();
+
+ // Run the next task.
+ FCP_CHECK(*marker) << "Worker destroyed before all tasks finished";
+ {
+ // Try run next task if any.
+ absl::MutexLock lock(&this->busy_);
+ this->running_ = false;
+ this->MaybeRunNext();
+ }
+ });
+ }
+
+ Scheduler* scheduler_;
+ absl::Mutex busy_;
+ bool running_ ABSL_GUARDED_BY(busy_) = false;
+ std::deque<std::function<void()>> steps_ ABSL_GUARDED_BY(busy_);
+};
+
+// Implementation of thread pools.
+class ThreadPoolScheduler : public Scheduler {
+ public:
+ explicit ThreadPoolScheduler(std::size_t thread_count)
+ : idle_condition_(absl::Condition(IdleCondition, this)),
+ active_count_(thread_count) {
+ FCP_CHECK(thread_count > 0) << "invalid thread_count";
+
+ // Create threads.
+ for (int i = 0; i < thread_count; ++i) {
+ threads_.emplace_back(std::thread([this] { this->PerThreadActivity(); }));
+ }
+ }
+
+ ~ThreadPoolScheduler() override {
+ {
+ absl::MutexLock lock(&busy_);
+ FCP_CHECK(IdleCondition(this))
+ << "Thread pool must be idle at destruction time";
+
+ threads_should_join_ = true;
+ work_available_cond_var_.SignalAll();
+ }
+
+ for (auto& thread : threads_) {
+ FCP_CHECK(thread.joinable()) << "Attempted to destroy a threadpool from "
+ "one of its running threads";
+ thread.join();
+ }
+ }
+
+ void Schedule(std::function<void()> task) override {
+ absl::MutexLock lock(&busy_);
+ todo_.push(std::move(task));
+ // Wake up a *single* thread to handle this task.
+ work_available_cond_var_.Signal();
+ }
+
+ void WaitUntilIdle() override {
+ busy_.LockWhen(idle_condition_);
+ busy_.Unlock();
+ }
+
+ static bool IdleCondition(ThreadPoolScheduler* pool)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(pool->busy_) {
+ return pool->todo_.empty() && pool->active_count_ == 0;
+ }
+
+ void PerThreadActivity() {
+ for (;;) {
+ std::function<void()> task;
+ {
+ absl::MutexLock lock(&busy_);
+ --active_count_;
+ while (todo_.empty()) {
+ if (threads_should_join_) {
+ return;
+ }
+
+ work_available_cond_var_.Wait(&busy_);
+ }
+
+ // Destructor invariant
+ FCP_CHECK(!threads_should_join_);
+ task = std::move(todo_.front());
+ todo_.pop();
+ ++active_count_;
+ }
+
+ task();
+ }
+ }
+
+ // A vector of threads allocated for execution.
+ std::vector<std::thread> threads_;
+
+ // A CondVar used to signal availability of tasks.
+ //
+ // We would prefer to use the more declarative absl::Condition instead,
+ // however, this one only allows to wake up all threads if a new task is
+ // available -- but we want to wake up only one in this case.
+ absl::CondVar work_available_cond_var_;
+
+ // See IdleCondition
+ absl::Condition idle_condition_;
+
+ // A mutex protecting mutable state in this class.
+ absl::Mutex busy_;
+
+ // Set when worker threads should join instead of waiting for work.
+ bool threads_should_join_ ABSL_GUARDED_BY(busy_) = false;
+
+ // Queue of tasks with work to do.
+ std::queue<std::function<void()>> todo_ ABSL_GUARDED_BY(busy_);
+
+ // The number of threads currently doing work in this pool.
+ std::size_t active_count_ ABSL_GUARDED_BY(busy_);
+};
+
+} // namespace
+
+std::unique_ptr<Worker> Scheduler::CreateWorker() {
+ return std::make_unique<WorkerImpl>(this);
+}
+
+std::unique_ptr<Scheduler> CreateThreadPoolScheduler(std::size_t thread_count) {
+ return std::make_unique<ThreadPoolScheduler>(thread_count);
+}
+
+} // namespace fcp
diff --git a/fcp/base/scheduler.h b/fcp/base/scheduler.h
new file mode 100644
index 0000000..2be149e
--- /dev/null
+++ b/fcp/base/scheduler.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_SCHEDULER_H_
+#define FCP_BASE_SCHEDULER_H_
+
+/**
+ * Overview
+ * ========
+ *
+ * A simple implementation of a scheduler (thread pool). Allows to schedule
+ * tasks and futures.
+ */
+
+#include <functional>
+#include <memory>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/base/move_to_lambda.h"
+
+namespace fcp {
+
+/**
+ * A Worker allows to schedule tasks which are executed sequentially.
+ * Workers are created from a Scheduler.
+ *
+ * Lifetime and destruction:
+ *
+ * - The scheduler from which a worker is created must not be destructed
+ * before the worker.
+ *
+ * - The worker must not be destructed before all its tasks are finished.
+ */
+class Worker {
+ public:
+ virtual ~Worker() = default;
+ Worker() = default;
+
+ Worker(Worker const&) = delete;
+ Worker& operator=(Worker const&) = delete;
+
+ /**
+ * Schedules a task on this worker. Tasks are executed strictly sequentially
+ * in the order they are scheduled.
+ */
+ virtual void Schedule(std::function<void()> task) = 0;
+};
+
+/**
+ * A Scheduler which allows to schedule 'tasks'.
+ *
+ * Lifetime and destruction:
+ *
+ * - A Scheduler *must* be idle (no active or pending work) at destruction
+ * time. See WaitUntilIdle.
+ *
+ * - Implies: A Scheduler *must not* be destructed by one of its own tasks
+ *
+ * - Implies: Task closures may safely hold raw pointers to their thread pool.
+ * They should *not* have ownership (via a smart-pointer or similar).
+ */
+class Scheduler {
+ public:
+ virtual ~Scheduler() = default;
+ Scheduler() = default;
+
+ Scheduler(Scheduler const&) = delete;
+ Scheduler& operator=(Scheduler const&) = delete;
+
+ /**
+ * Creates a new Worker based on this scheduler.
+ */
+ virtual std::unique_ptr<Worker> CreateWorker();
+
+ /**
+ * Schedules a task that will execute on the scheduler.
+ */
+ virtual void Schedule(std::function<void()> task) = 0;
+
+ /**
+ * Waits until there are no tasks running or pending.
+ *
+ * In this state, the thread pool will not restart working until some
+ * external entity is scheduling new tasks, as work caused by tasks spawning
+ * other tasks has ceased.
+ */
+ virtual void WaitUntilIdle() = 0;
+};
+
+/**
+ * Creates a scheduler using a fixed-size pool of threads to run tasks.
+ */
+std::unique_ptr<Scheduler> CreateThreadPoolScheduler(std::size_t thread_count);
+
+} // namespace fcp
+
+#endif // FCP_BASE_SCHEDULER_H_
diff --git a/fcp/base/scheduler_test.cc b/fcp/base/scheduler_test.cc
new file mode 100644
index 0000000..d9587e6
--- /dev/null
+++ b/fcp/base/scheduler_test.cc
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/scheduler.h"
+
+#include <atomic>
+#include <cstdlib> // for std::rand
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace base {
+namespace {
+
+// NOTE: many of tests below use reference captures in lambdas for locals.
+// This is sound because the test methods do not return before the thread
+// pool has become idle (pool->WaitUntilIdle()).
+
+// Tests whether scheduled tasks are successfully executed.
+TEST(ThreadPool, TasksAreExecuted) {
+ auto pool = CreateThreadPoolScheduler(2);
+
+ bool b1 = false;
+ bool b2 = false;
+ pool->Schedule([&b1]() { b1 = true; });
+ pool->Schedule([&b2]() { b2 = true; });
+
+ pool->WaitUntilIdle();
+
+ EXPECT_TRUE(b1);
+ EXPECT_TRUE(b2);
+}
+
+// Tests whether the pool uses actually multiple threads to execute tasks.
+// The test goal is achieved by blocking in one task until another task
+// unblocks, which can only work if multiple threads are used.
+TEST(ThreadPool, ThreadsAreUtilized) {
+ auto pool = CreateThreadPoolScheduler(2);
+
+ absl::BlockingCounter counter(1);
+ bool b1 = false;
+ bool b2 = false;
+
+ pool->Schedule([&b1, &counter] {
+ counter.Wait();
+ b1 = true;
+ });
+ pool->Schedule([&b2, &counter] {
+ counter.DecrementCount();
+ b2 = true;
+ });
+
+ pool->WaitUntilIdle();
+
+ EXPECT_TRUE(b1);
+ EXPECT_TRUE(b2);
+}
+
+TEST(ThreadPool, StressTest) {
+ // A simple stress test where we spawn many threads and let them after
+ // a random wait time increment a counter.
+ static constexpr int kThreads = 32;
+ static constexpr int kIterations = 16;
+ auto pool = CreateThreadPoolScheduler(kThreads);
+ std::atomic<int64_t> atomic_counter{0};
+
+ for (auto i = 0; i < kThreads; ++i) {
+ auto task = [&atomic_counter] {
+ for (auto j = 0; j < kIterations; ++j) {
+ absl::SleepFor(absl::Microseconds(std::rand() % 500));
+ atomic_counter.fetch_add(1);
+ }
+ };
+ pool->Schedule(task);
+ }
+
+ pool->WaitUntilIdle();
+ ASSERT_EQ(atomic_counter, kThreads * kIterations);
+}
+
+TEST(Worker, TasksAreExecutedSequentially) {
+ auto pool = CreateThreadPoolScheduler(3);
+ auto worker = pool->CreateWorker();
+ absl::Mutex mutex{};
+ std::vector<int> recorded{};
+ for (int i = 0; i < 128; i++) {
+ worker->Schedule([&mutex, &recorded, i] {
+ // Expect that no one is holding the mutex (tests for non-overlap).
+ if (mutex.TryLock()) {
+ // Add i to the recorded values (tests for execution in order).
+ recorded.push_back(i);
+ // Idle wait to be sure we don't execute faster than we schedule
+ absl::SleepFor(absl::Milliseconds(50));
+ mutex.Unlock();
+ } else {
+ FAIL() << "mutex was unexpectedly hold";
+ }
+ });
+ }
+ pool->WaitUntilIdle();
+
+ // Verify recorded values.
+ for (int i = 0; i < 128; i++) {
+ ASSERT_EQ(recorded[i], i);
+ }
+}
+
+} // namespace
+
+} // namespace base
+} // namespace fcp
diff --git a/fcp/base/simulated_clock.cc b/fcp/base/simulated_clock.cc
new file mode 100644
index 0000000..e1c3097
--- /dev/null
+++ b/fcp/base/simulated_clock.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/simulated_clock.h"
+
+namespace fcp {
+
+absl::Time SimulatedClock::Now() {
+ absl::MutexLock lock(mutex());
+ return NowLocked();
+}
+
+absl::Time SimulatedClock::NowLocked() {
+ mutex()->AssertHeld();
+ return now_;
+}
+
+void SimulatedClock::Sleep(absl::Duration d) {
+ absl::Time current;
+ {
+ absl::MutexLock lock(mutex());
+ current = now_;
+ }
+ absl::Time deadline = current + d;
+ while (true) {
+ {
+ absl::MutexLock lock(mutex());
+ current = now_;
+ }
+ if (current >= deadline) {
+ return;
+ }
+ }
+}
+
+void SimulatedClock::SetTime(absl::Time t) {
+ {
+ absl::MutexLock lock(mutex());
+ now_ = t;
+ }
+ DispatchWakeups();
+}
+
+void SimulatedClock::AdvanceTime(absl::Duration d) {
+ {
+ absl::MutexLock lock(mutex());
+ now_ += d;
+ }
+ DispatchWakeups();
+}
+
+} // namespace fcp
+
diff --git a/fcp/base/simulated_clock.h b/fcp/base/simulated_clock.h
new file mode 100644
index 0000000..625ad74
--- /dev/null
+++ b/fcp/base/simulated_clock.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_SIMULATED_CLOCK_H_
+#define FCP_BASE_SIMULATED_CLOCK_H_
+
+#include "fcp/base/clock.h"
+
+namespace fcp {
+
+/*
+ * A simulated clock is a concrete Clock implementation that does not "tick"
+ * on its own. Time is advanced by explicit calls to the AdvanceTime() or
+ * SetTime() functions.
+ */
+class SimulatedClock : public Clock {
+ public:
+ // Construct SimulatedClock with a specific initial time.
+ explicit SimulatedClock(absl::Time t) : now_(t) {}
+
+ // Construct SimulatedClock with default initial time (1970-01-01 00:00:00)
+ SimulatedClock() : SimulatedClock(absl::UnixEpoch()) {}
+
+ // Returns the simulated time.
+ absl::Time Now() override;
+
+ // Sleeps until the specified duration has elapsed according to this clock.
+ void Sleep(absl::Duration d) override;
+
+ // Sets the simulated time. Wakes up any waiters whose deadlines have now
+ // expired.
+ void SetTime(absl::Time t);
+
+ // Advances the simulated time. Wakes up any waiters whose deadlines have now
+ // expired.
+ void AdvanceTime(absl::Duration d);
+
+ private:
+ // Returns the simulated time (called internally from the base class).
+ absl::Time NowLocked() override;
+
+ // No specific scheduling is needed for SimulatedClock.
+ void ScheduleWakeup(absl::Time wakeup_time) override {}
+
+ absl::Time now_ ABSL_GUARDED_BY(mutex());
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_SIMULATED_CLOCK_H_
diff --git a/fcp/base/simulated_clock_test.cc b/fcp/base/simulated_clock_test.cc
new file mode 100644
index 0000000..30e4ad4
--- /dev/null
+++ b/fcp/base/simulated_clock_test.cc
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/simulated_clock.h"
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/time/civil_time.h"
+#include "absl/time/time.h"
+
+namespace fcp {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+
+// Simple callback waiter that runs the function on Wakeup.
+class CallbackWaiter : public Clock::Waiter {
+ public:
+ explicit CallbackWaiter(std::function<void()> callback)
+ : callback_(std::move(callback)) {}
+
+ void WakeUp() override { callback_(); }
+
+ private:
+ std::function<void()> callback_;
+};
+
+// Simple test waiter that adds its ID to the provided vector when WakeUp is
+// called. This is used to verify that waiters are woken up in the right order.
+class TestWaiter : public CallbackWaiter {
+ public:
+ explicit TestWaiter(int id, std::vector<int>* output)
+ : CallbackWaiter([=]() { output->push_back(id); }) {}
+};
+
+absl::Time GetTestInitialTime() {
+ return absl::FromCivil(absl::CivilDay(2020, 1, 1), absl::LocalTimeZone());
+}
+
+TEST(SimulatedClockTest, GetAndUpdateNow) {
+ absl::Time t = absl::UnixEpoch();
+ SimulatedClock clock;
+ EXPECT_THAT(clock.Now(), t);
+
+ absl::Time t2 = GetTestInitialTime();
+
+ SimulatedClock clock2(t2);
+ EXPECT_THAT(clock2.Now(), t2);
+
+ absl::Time t3 = t2 + absl::Seconds(42);
+ clock2.AdvanceTime(absl::Seconds(42));
+ EXPECT_THAT(clock2.Now(), t3);
+
+ absl::Time t4 = t3 + absl::Seconds(18);
+ clock2.SetTime(t4);
+ EXPECT_THAT(clock2.Now(), Eq(t4));
+}
+
+// Verifies that waiters with future deadlines are not triggered unless the
+// time is advanced.
+TEST(SimulatedClockTest, FutureDeadline) {
+ std::vector<int> output;
+ absl::Time t = GetTestInitialTime();
+ SimulatedClock clock(t);
+
+ clock.WakeupWithDeadline(t + absl::Seconds(1),
+ std::make_shared<TestWaiter>(1, &output));
+ EXPECT_THAT(output, ElementsAre());
+
+ clock.AdvanceTime(absl::Seconds(1));
+ EXPECT_THAT(output, ElementsAre(1));
+
+ // Advancing time again doesn't trigger the same waiter again.
+ clock.AdvanceTime(absl::Seconds(1));
+ EXPECT_THAT(output, ElementsAre(1));
+}
+
+// Verifies that the order of waiters with maching deadlines is preserved
+// when their wake-up is triggered.
+TEST(SimulatedClockTest, MatchingDeadlines) {
+ std::vector<int> output;
+ absl::Time t = GetTestInitialTime();
+ SimulatedClock clock(t);
+
+ absl::Time t1 = t + absl::Seconds(1);
+ absl::Time t2 = t + absl::Seconds(2);
+ clock.WakeupWithDeadline(t1, std::make_shared<TestWaiter>(1, &output));
+ clock.WakeupWithDeadline(t2, std::make_shared<TestWaiter>(2, &output));
+ clock.WakeupWithDeadline(t1, std::make_shared<TestWaiter>(3, &output));
+ clock.WakeupWithDeadline(t2, std::make_shared<TestWaiter>(4, &output));
+ clock.WakeupWithDeadline(t1, std::make_shared<TestWaiter>(5, &output));
+
+ // Trigger all waiters.
+ clock.AdvanceTime(absl::Seconds(2));
+ EXPECT_THAT(output, ElementsAre(1, 3, 5, 2, 4));
+}
+
+// Verifies that waiters with current or past deadlines are triggered promptly.
+TEST(SimulatedClockTest, PastAndCurrentDeadlines) {
+ std::vector<int> output;
+ absl::Time t =
+ absl::FromCivil(absl::CivilDay(2020, 1, 1), absl::LocalTimeZone());
+ SimulatedClock clock(t);
+
+ clock.WakeupWithDeadline(t, std::make_shared<TestWaiter>(1, &output));
+ clock.WakeupWithDeadline(t - absl::Seconds(1),
+ std::make_shared<TestWaiter>(2, &output));
+ EXPECT_THAT(output, ElementsAre(1, 2));
+}
+
+// Verifies that only expired waiters are triggered.
+TEST(SimulatedClockTest, MultipleWaiters) {
+ std::vector<int> output;
+ absl::Time t = GetTestInitialTime();
+ SimulatedClock clock(t);
+
+ clock.WakeupWithDeadline(t + absl::Seconds(30),
+ std::make_shared<TestWaiter>(1, &output));
+ clock.WakeupWithDeadline(t + absl::Seconds(20),
+ std::make_shared<TestWaiter>(2, &output));
+ clock.WakeupWithDeadline(t + absl::Seconds(10),
+ std::make_shared<TestWaiter>(3, &output));
+ // Advance by 15 seconds
+ clock.AdvanceTime(absl::Seconds(15));
+ // Advance by another 5 seconds
+ clock.AdvanceTime(absl::Seconds(5));
+ // Only waiters 3 and 2 should be triggered.
+ EXPECT_THAT(output, ElementsAre(3, 2));
+}
+
+// Verifies that a new timer can be scheduled when anoter timer is triggered.
+TEST(SimulatedClockTest, RecursiveWakeup) {
+ std::vector<int> output;
+ absl::Time t = GetTestInitialTime();
+ SimulatedClock clock(t);
+
+ clock.WakeupWithDeadline(t + absl::Seconds(20),
+ std::make_shared<TestWaiter>(1, &output));
+ clock.WakeupWithDeadline(
+ t + absl::Seconds(20), std::make_shared<CallbackWaiter>([&]() {
+ output.push_back(2);
+ clock.WakeupWithDeadline(t + absl::Seconds(15),
+ std::make_shared<TestWaiter>(3, &output));
+ }));
+ clock.AdvanceTime(absl::Seconds(20));
+ // Both waiters are triggered because the #3 one is already expired when
+ // inserted recursively by waiter #2.
+ EXPECT_THAT(output, ElementsAre(1, 2, 3));
+}
+
+// Verifies that a long taking Wakeup notification results in triggering
+// other waiters that expire later.
+TEST(SimulatedClockTest, LongRunningWakeup) {
+ std::vector<int> output;
+ absl::Time t = GetTestInitialTime();
+ SimulatedClock clock(t);
+
+ clock.WakeupWithDeadline(t + absl::Seconds(10),
+ std::make_shared<TestWaiter>(1, &output));
+ clock.WakeupWithDeadline(
+ t + absl::Seconds(20), std::make_shared<CallbackWaiter>([&]() {
+ output.push_back(2);
+ clock.AdvanceTime(absl::Seconds(10));
+ }));
+ clock.WakeupWithDeadline(t + absl::Seconds(30),
+ std::make_shared<TestWaiter>(3, &output));
+ // Advance time by 20 second, which will advance time by another 10 seconds
+ // when waking up waiter #2.
+ clock.AdvanceTime(absl::Seconds(20));
+ EXPECT_THAT(output, ElementsAre(1, 2, 3));
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/base/source_location.h b/fcp/base/source_location.h
new file mode 100644
index 0000000..984307f
--- /dev/null
+++ b/fcp/base/source_location.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_BASE_SOURCE_LOCATION_H_
+#define FCP_BASE_SOURCE_LOCATION_H_
+
+namespace fcp {
+
+#if (__clang_major__ >= 9) || (__GNUC__ >= 7)
+// Using non-standard builtin extensions of gcc and clang to capture call-site
+// location
+#define FCP_HAS_SOURCE_LOCATION
+#define FCP_BUILTIN_LINE() __builtin_LINE()
+#define FCP_BUILTIN_FILE() __builtin_FILE()
+#else
+// If compiler feature unavailable replace with stub values
+#define FCP_BUILTIN_LINE() (0)
+#define FCP_BUILTIN_FILE() ("<unknown_source>")
+
+#endif
+
+class SourceLocationImpl {
+ public:
+ static constexpr SourceLocationImpl current(
+ // Builtins _must_ be referenced from default arguments, so they get
+ // evaluated at the callsite.
+ int line = FCP_BUILTIN_LINE(),
+ const char* file_name = FCP_BUILTIN_FILE()) {
+ return SourceLocationImpl(line, file_name);
+ }
+ constexpr int line() const { return line_; }
+ constexpr const char* file_name() const { return file_name_; }
+
+ private:
+ constexpr SourceLocationImpl(int line, const char* file_name)
+ : line_(line), file_name_(file_name) {}
+ int line_;
+ const char* file_name_;
+};
+
+using SourceLocation = SourceLocationImpl;
+
+} // namespace fcp
+
+#endif // FCP_BASE_SOURCE_LOCATION_H_
diff --git a/fcp/base/source_location_test.cc b/fcp/base/source_location_test.cc
new file mode 100644
index 0000000..743248d
--- /dev/null
+++ b/fcp/base/source_location_test.cc
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/source_location.h"
+
+#include "gtest/gtest.h"
+
+namespace fcp {
+
+SourceLocation Foo() { return SourceLocation::current(); }
+
+SourceLocation Bar() { return SourceLocation::current(); }
+
+TEST(SourceLocation, Test) {
+ EXPECT_EQ(Foo().line(), Foo().line());
+ EXPECT_EQ(Bar().line(), Bar().line());
+ EXPECT_EQ(Bar().file_name(), Bar().file_name());
+#ifdef FCP_HAS_SOURCE_LOCATION
+ EXPECT_NE(Foo().line(), Bar().line());
+#endif
+}
+
+} // namespace fcp
diff --git a/fcp/base/status_converters.cc b/fcp/base/status_converters.cc
new file mode 100644
index 0000000..25630c1
--- /dev/null
+++ b/fcp/base/status_converters.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/base/status_converters.h"
+
+#include "grpcpp/support/status.h"
+
+namespace fcp {
+namespace base {
+
+#define MAP_FROM_GRPC_STATUS(grpc_name, absl_name) \
+ case grpc::StatusCode::grpc_name: \
+ return StatusCode::absl_name;
+
+#define MAP_TO_GRPC_STATUS(absl_name, grpc_name) \
+ case StatusCode::absl_name: \
+ return grpc::StatusCode::grpc_name;
+
+StatusCode FromGrpcStatusCode(grpc::StatusCode code) {
+ switch (code) {
+ MAP_FROM_GRPC_STATUS(OK, kOk)
+ MAP_FROM_GRPC_STATUS(CANCELLED, kCancelled)
+ MAP_FROM_GRPC_STATUS(UNKNOWN, kUnknown)
+ MAP_FROM_GRPC_STATUS(INVALID_ARGUMENT, kInvalidArgument)
+ MAP_FROM_GRPC_STATUS(DEADLINE_EXCEEDED, kDeadlineExceeded)
+ MAP_FROM_GRPC_STATUS(NOT_FOUND, kNotFound)
+ MAP_FROM_GRPC_STATUS(ALREADY_EXISTS, kAlreadyExists)
+ MAP_FROM_GRPC_STATUS(PERMISSION_DENIED, kPermissionDenied)
+ MAP_FROM_GRPC_STATUS(UNAUTHENTICATED, kUnauthenticated)
+ MAP_FROM_GRPC_STATUS(RESOURCE_EXHAUSTED, kResourceExhausted)
+ MAP_FROM_GRPC_STATUS(FAILED_PRECONDITION, kFailedPrecondition)
+ MAP_FROM_GRPC_STATUS(ABORTED, kAborted)
+ MAP_FROM_GRPC_STATUS(OUT_OF_RANGE, kOutOfRange)
+ MAP_FROM_GRPC_STATUS(UNIMPLEMENTED, kUnimplemented)
+ MAP_FROM_GRPC_STATUS(INTERNAL, kInternal)
+ MAP_FROM_GRPC_STATUS(UNAVAILABLE, kUnavailable)
+ MAP_FROM_GRPC_STATUS(DATA_LOSS, kDataLoss)
+ default:
+ return StatusCode::kUnknown;
+ }
+}
+
+Status FromGrpcStatus(grpc::Status status) {
+ return Status(FromGrpcStatusCode(status.error_code()),
+ status.error_message());
+}
+
+grpc::StatusCode ToGrpcStatusCode(StatusCode code) {
+ switch (code) {
+ MAP_TO_GRPC_STATUS(kOk, OK)
+ MAP_TO_GRPC_STATUS(kCancelled, CANCELLED)
+ MAP_TO_GRPC_STATUS(kUnknown, UNKNOWN)
+ MAP_TO_GRPC_STATUS(kInvalidArgument, INVALID_ARGUMENT)
+ MAP_TO_GRPC_STATUS(kDeadlineExceeded, DEADLINE_EXCEEDED)
+ MAP_TO_GRPC_STATUS(kNotFound, NOT_FOUND)
+ MAP_TO_GRPC_STATUS(kAlreadyExists, ALREADY_EXISTS)
+ MAP_TO_GRPC_STATUS(kPermissionDenied, PERMISSION_DENIED)
+ MAP_TO_GRPC_STATUS(kUnauthenticated, UNAUTHENTICATED)
+ MAP_TO_GRPC_STATUS(kResourceExhausted, RESOURCE_EXHAUSTED)
+ MAP_TO_GRPC_STATUS(kFailedPrecondition, FAILED_PRECONDITION)
+ MAP_TO_GRPC_STATUS(kAborted, ABORTED)
+ MAP_TO_GRPC_STATUS(kOutOfRange, OUT_OF_RANGE)
+ MAP_TO_GRPC_STATUS(kUnimplemented, UNIMPLEMENTED)
+ MAP_TO_GRPC_STATUS(kInternal, INTERNAL)
+ MAP_TO_GRPC_STATUS(kUnavailable, UNAVAILABLE)
+ MAP_TO_GRPC_STATUS(kDataLoss, DATA_LOSS)
+ default:
+ return grpc::StatusCode::UNKNOWN;
+ }
+}
+
+grpc::Status ToGrpcStatus(Status status) {
+ grpc::StatusCode code = ToGrpcStatusCode(status.code());
+ if (code != grpc::StatusCode::OK) {
+ return grpc::Status(code, std::string(status.message()));
+ }
+
+ return grpc::Status::OK;
+}
+
+} // namespace base
+} // namespace fcp
diff --git a/fcp/base/status_converters.h b/fcp/base/status_converters.h
new file mode 100644
index 0000000..bf77743
--- /dev/null
+++ b/fcp/base/status_converters.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_STATUS_CONVERTERS_H_
+#define FCP_BASE_STATUS_CONVERTERS_H_
+
+#include "grpcpp/support/status.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace base {
+
+/**
+ * Converts a grpc::StatusCode to an StatusCode.
+ */
+StatusCode FromGrpcStatusCode(grpc::StatusCode code);
+
+/**
+ * Converts grpc::Status to an fcp::Status.
+ */
+Status FromGrpcStatus(grpc::Status status);
+
+/**
+ * Converts an StatusCode to a grpc::StatusCode.
+ */
+grpc::StatusCode ToGrpcStatusCode(StatusCode code);
+
+/**
+ * Converts fcp::Status to grpc::Status.
+ */
+grpc::Status ToGrpcStatus(Status status);
+
+} // namespace base
+} // namespace fcp
+
+#endif // FCP_BASE_STATUS_CONVERTERS_H_
diff --git a/fcp/base/string_stream.cc b/fcp/base/string_stream.cc
new file mode 100644
index 0000000..60aad6b
--- /dev/null
+++ b/fcp/base/string_stream.cc
@@ -0,0 +1,54 @@
+#include "fcp/base/string_stream.h"
+
+#include <stdio.h>
+
+#include <string>
+
+namespace fcp {
+namespace internal {
+
+StringStream& StringStream::operator<<(bool x) {
+ return *this << (x ? "true" : "false");
+}
+
+template <typename T>
+StringStream& AppendWithFormat(StringStream& string_buffer, const char* format,
+ T value) {
+ char buf[64];
+ // The buffer is large enough for any possible value.
+ sprintf(buf, format, value); // NOLINT
+ return string_buffer << buf;
+}
+
+StringStream& StringStream::operator<<(int32_t x) {
+ return AppendWithFormat(*this, "%d", x);
+}
+
+StringStream& StringStream::operator<<(uint32_t x) {
+ return AppendWithFormat(*this, "%u", x);
+}
+
+StringStream& StringStream::operator<<(int64_t x) {
+ return AppendWithFormat(*this, "%ld", x);
+}
+
+StringStream& StringStream::operator<<(uint64_t x) {
+ return AppendWithFormat(*this, "%lu", x);
+}
+
+StringStream& StringStream::operator<<(double x) {
+ return AppendWithFormat(*this, "%f", x);
+}
+
+StringStream& StringStream::operator<<(const char* x) {
+ str_.append(x);
+ return *this;
+}
+
+StringStream& StringStream::operator<<(const std::string& x) {
+ str_.append(x);
+ return *this;
+}
+
+} // namespace internal
+} // namespace fcp
diff --git a/fcp/base/string_stream.h b/fcp/base/string_stream.h
new file mode 100644
index 0000000..d5bc911
--- /dev/null
+++ b/fcp/base/string_stream.h
@@ -0,0 +1,41 @@
+#ifndef FCP_BASE_STRING_STREAM_H_
+#define FCP_BASE_STRING_STREAM_H_
+
+#include <string>
+
+#ifndef FCP_BAREMETAL
+static_assert(false,
+ "StringStream should be used only when building FCP in bare "
+ "metal configuration.");
+#endif
+
+namespace fcp {
+namespace internal {
+
+// This is a class used for building diagnostic messages with
+// FCP_LOG and FCP_STATUS macros.
+// The class is designed to be a simplified replacement for std::ostringstream.
+class StringStream final {
+ public:
+ StringStream() = default;
+ explicit StringStream(const std::string& str) : str_(str) {}
+
+ StringStream& operator<<(bool x);
+ StringStream& operator<<(int32_t x);
+ StringStream& operator<<(uint32_t x);
+ StringStream& operator<<(int64_t x);
+ StringStream& operator<<(uint64_t x);
+ StringStream& operator<<(double x);
+ StringStream& operator<<(const char* x);
+ StringStream& operator<<(const std::string& x);
+
+ std::string str() const { return str_; }
+
+ private:
+ std::string str_;
+};
+
+} // namespace internal
+} // namespace fcp
+
+#endif // FCP_BASE_STRING_STREAM_H_
diff --git a/fcp/base/string_stream_test.cc b/fcp/base/string_stream_test.cc
new file mode 100644
index 0000000..1136841
--- /dev/null
+++ b/fcp/base/string_stream_test.cc
@@ -0,0 +1,20 @@
+#include "fcp/base/string_stream.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+
+namespace fcp {
+namespace internal {
+namespace {
+
+TEST(StringTest, Basic) {
+ StringStream s;
+ s << "A" << 1 << std::string("b") << 2U << "c" << -3L << "d" << 3.5f << "e"
+ << -3.14 << "f" << 9UL << true << ":" << false;
+ EXPECT_EQ(s.str(), "A1b2c-3d3.500000e-3.140000f9true:false");
+}
+
+} // namespace
+} // namespace internal
+} // namespace fcp
diff --git a/fcp/base/time_util.cc b/fcp/base/time_util.cc
new file mode 100644
index 0000000..14e9be2
--- /dev/null
+++ b/fcp/base/time_util.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/base/time_util.h"
+
+#include <limits>
+
+#include "absl/time/time.h"
+
+namespace fcp {
+
+google::protobuf::Timestamp TimeUtil::ConvertAbslToProtoTimestamp(
+ absl::Time t) {
+ google::protobuf::Timestamp proto_timestamp;
+ const int64_t s = absl::ToUnixSeconds(t);
+ proto_timestamp.set_seconds(s);
+ // The nanos field can only range from 0 to 1e9 - 1 so conversion to int32 is
+ // fine.
+ proto_timestamp.set_nanos((t - absl::FromUnixSeconds(s)) /
+ absl::Nanoseconds(1));
+ return proto_timestamp;
+}
+
+absl::Time TimeUtil::ConvertProtoToAbslTime(google::protobuf::Timestamp proto) {
+ return absl::FromUnixSeconds(proto.seconds()) +
+ absl::Nanoseconds(proto.nanos());
+}
+
+google::protobuf::Duration TimeUtil::ConvertAbslToProtoDuration(
+ absl::Duration absl_duration) {
+ google::protobuf::Duration proto_duration;
+ if (absl_duration == absl::InfiniteDuration()) {
+ proto_duration.set_seconds(std::numeric_limits<int64_t>::max());
+ proto_duration.set_nanos(static_cast<int32_t>(999999999));
+ } else if (absl_duration == -absl::InfiniteDuration()) {
+ proto_duration.set_seconds(std::numeric_limits<int64_t>::min());
+ proto_duration.set_nanos(static_cast<int32_t>(-999999999));
+ } else {
+ // s and n may both be negative, per the Duration proto spec.
+ const int64_t s =
+ absl::IDivDuration(absl_duration, absl::Seconds(1), &absl_duration);
+ const int64_t n =
+ absl::IDivDuration(absl_duration, absl::Nanoseconds(1), &absl_duration);
+ proto_duration.set_seconds(s);
+ proto_duration.set_nanos(static_cast<int32_t>(n));
+ }
+ return proto_duration;
+}
+
+absl::Duration TimeUtil::ConvertProtoToAbslDuration(
+ google::protobuf::Duration proto) {
+ return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
+}
+
+} // namespace fcp
diff --git a/fcp/base/time_util.h b/fcp/base/time_util.h
new file mode 100644
index 0000000..1c6d703
--- /dev/null
+++ b/fcp/base/time_util.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_BASE_TIME_UTIL_H_
+#define FCP_BASE_TIME_UTIL_H_
+
+#include "google/protobuf/duration.pb.h"
+#include "google/protobuf/timestamp.pb.h"
+#include "absl/time/time.h"
+
+namespace fcp {
+
+class TimeUtil {
+ public:
+ // Converts an absl::Time to a google::protobuf::Timestamp.
+ // Note that we assume the timestamps we deal with here are representable by
+ // both formats. If the resulted google::protobuf::Timestamp is invalid, it
+ // will lead to undefined behavior.
+ static google::protobuf::Timestamp ConvertAbslToProtoTimestamp(absl::Time t);
+
+ // Converts a google::protobuf::Timestamp to an absl::Time.
+ // Note that we assume the timestamps we deal with here are representable by
+ // both formats. If the resulted absl::Time is invalid, it will lead to
+ // undefined behavior.
+ static absl::Time ConvertProtoToAbslTime(google::protobuf::Timestamp proto);
+
+ // Converts an absl::Duration to a google::protobuf::Duration.
+ // Note that we assume the durations we deal with here are representable by
+ // both formats. If the resulted google::protobuf::Duration is invalid, it
+ // will lead to undefined behavior.
+ static google::protobuf::Duration ConvertAbslToProtoDuration(
+ absl::Duration absl_duration);
+
+ // Converts a google::protobuf::Duration to an absl::Duration.
+ // Note that we assume the timestamps we deal with here are representable by
+ // both formats. If the resulted google::protobuf::Duration is invalid, it
+ // will lead to undefined behavior.
+ static absl::Duration ConvertProtoToAbslDuration(
+ google::protobuf::Duration proto);
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_TIME_UTIL_H_
diff --git a/fcp/base/time_util_test.cc b/fcp/base/time_util_test.cc
new file mode 100644
index 0000000..728f478
--- /dev/null
+++ b/fcp/base/time_util_test.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/base/time_util.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace {
+
+TEST(ConvertAbslToProtoTimestampTest, ConvertSuccessfully) {
+ absl::Time time = absl::FromUnixSeconds(1000) + absl::Nanoseconds(3);
+ google::protobuf::Timestamp expected_timestamp;
+ expected_timestamp.set_seconds(1000L);
+ expected_timestamp.set_nanos(3);
+ EXPECT_THAT(TimeUtil::ConvertAbslToProtoTimestamp(time),
+ EqualsProto(expected_timestamp));
+}
+
+TEST(ConvertProtoToAbslTimeTest, ConvertSuccessfully) {
+ google::protobuf::Timestamp timestamp;
+ timestamp.set_seconds(1000L);
+ timestamp.set_nanos(3);
+ absl::Time expected_time = absl::FromUnixSeconds(1000) + absl::Nanoseconds(3);
+ EXPECT_EQ(TimeUtil::ConvertProtoToAbslTime(timestamp), expected_time);
+}
+
+TEST(ConvertAbslToProtoDurationTest, ConvertSuccessfully) {
+ absl::Duration duration = absl::Seconds(1000) + absl::Nanoseconds(3);
+ google::protobuf::Duration expected_duration;
+ expected_duration.set_seconds(1000L);
+ expected_duration.set_nanos(3);
+ EXPECT_THAT(TimeUtil::ConvertAbslToProtoDuration(duration),
+ EqualsProto(expected_duration));
+}
+
+TEST(ConvertProtoToAbslDurationTest, ConvertSuccessfully) {
+ google::protobuf::Duration duration;
+ duration.set_seconds(1000L);
+ duration.set_nanos(3);
+ absl::Duration expected_duration = absl::Seconds(1000) + absl::Nanoseconds(3);
+ EXPECT_EQ(TimeUtil::ConvertProtoToAbslDuration(duration), expected_duration);
+}
+
+} // anonymous namespace
+} // namespace fcp
diff --git a/fcp/base/tracing_schema.fbs b/fcp/base/tracing_schema.fbs
new file mode 100644
index 0000000..a311681
--- /dev/null
+++ b/fcp/base/tracing_schema.fbs
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "fcp/tracing/tracing_schema_common.fbs";
+
+// Keep in sync with absl status in third_party/absl/status/status.h
+enum TracingStatusCode : short {
+ Ok,
+ Cancelled,
+ Unknown,
+ InvalidArgument,
+ DeadlineExceeded,
+ NotFound,
+ AlreadyExists,
+ PermissionDenied,
+ ResourceExhausted,
+ FailedPrecondition,
+ Aborted,
+ OutOfRange,
+ Unimplemented,
+ Internal,
+ Unavailable,
+ DataLoss,
+ Unauthenticated,
+}
+
+table ResultExpectError (tag: "!EXP", error) {
+ expectation: string;
+ file_name: string;
+ line: int32;
+}
+
+table ResultExpectStatusError (tag: "STAT", error) {
+ // TODO(team): Pull out status fields into status struct.
+ expected_code: TracingStatusCode;
+ actual_code: TracingStatusCode;
+ message: string;
+ file_name: string;
+ line: int32;
+}
+
+table ProtoParseFailure (tag: "PRPR", error) {
+ type: string;
+}
diff --git a/fcp/base/tracing_schema.h b/fcp/base/tracing_schema.h
new file mode 100644
index 0000000..d936f1e
--- /dev/null
+++ b/fcp/base/tracing_schema.h
@@ -0,0 +1,134 @@
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef FCP_BASE_TRACING_SCHEMA_H
+#define FCP_BASE_TRACING_SCHEMA_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "fcp/base/tracing_schema_generated.h"
+#include "absl/strings/string_view.h"
+#include "fcp/tracing/tracing_severity.h"
+#include "fcp/tracing/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "fcp/base/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<ProtoParseFailure>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("PRPR");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kError;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "ProtoParseFailure"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kError;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), ProtoParseFailureTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "fcp/base/tracing_schema.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("fcp/tracing/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("fcp/tracing/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("ProtoParseFailure");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<ProtoParseFailure> Create(absl::string_view type, flatbuffers::FlatBufferBuilder* fbb) {
+ auto type__ = fbb->CreateString(type.data(), type.size());
+ return CreateProtoParseFailure(*fbb, type__);
+ }
+ using TupleType = std::tuple<std::string>;
+ static TupleType MakeTuple(const ProtoParseFailure* table) {
+ return std::make_tuple(table->type()->str());
+ }
+};
+static internal::TracingTraitsRegistrar<ProtoParseFailure> registrar_ProtoParseFailure;
+template<> class TracingTraits<ResultExpectError>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("!EXP");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kError;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "ResultExpectError"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kError;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), ResultExpectErrorTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "fcp/base/tracing_schema.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("fcp/tracing/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("fcp/tracing/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("ResultExpectError");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<ResultExpectError> Create(absl::string_view expectation, absl::string_view file_name, std::int32_t line, flatbuffers::FlatBufferBuilder* fbb) {
+ auto expectation__ = fbb->CreateString(expectation.data(), expectation.size());
+ auto file_name__ = fbb->CreateString(file_name.data(), file_name.size());
+ return CreateResultExpectError(*fbb, expectation__, file_name__, line);
+ }
+ using TupleType = std::tuple<std::string, std::string, std::int32_t>;
+ static TupleType MakeTuple(const ResultExpectError* table) {
+ return std::make_tuple(table->expectation()->str(), table->file_name()->str(), table->line());
+ }
+};
+static internal::TracingTraitsRegistrar<ResultExpectError> registrar_ResultExpectError;
+template<> class TracingTraits<ResultExpectStatusError>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("STAT");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kError;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "ResultExpectStatusError"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kError;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), ResultExpectStatusErrorTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "fcp/base/tracing_schema.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("fcp/tracing/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("fcp/tracing/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("ResultExpectStatusError");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<ResultExpectStatusError> Create(TracingStatusCode expected_code, TracingStatusCode actual_code, absl::string_view message, absl::string_view file_name, std::int32_t line, flatbuffers::FlatBufferBuilder* fbb) {
+ auto message__ = fbb->CreateString(message.data(), message.size());
+ auto file_name__ = fbb->CreateString(file_name.data(), file_name.size());
+ return CreateResultExpectStatusError(*fbb, expected_code, actual_code, message__, file_name__, line);
+ }
+ using TupleType = std::tuple<TracingStatusCode, TracingStatusCode, std::string, std::string, std::int32_t>;
+ static TupleType MakeTuple(const ResultExpectStatusError* table) {
+ return std::make_tuple(table->expected_code(), table->actual_code(), table->message()->str(), table->file_name()->str(), table->line());
+ }
+};
+static internal::TracingTraitsRegistrar<ResultExpectStatusError> registrar_ResultExpectStatusError;
+} // namespace fcp
+
+#endif // FCP_BASE_TRACING_SCHEMA_H \ No newline at end of file
diff --git a/fcp/base/tracing_schema_generated.h b/fcp/base/tracing_schema_generated.h
new file mode 100644
index 0000000..f5cdc71
--- /dev/null
+++ b/fcp/base/tracing_schema_generated.h
@@ -0,0 +1,586 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#define FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 2 &&
+ FLATBUFFERS_VERSION_MINOR == 0 &&
+ FLATBUFFERS_VERSION_REVISION == 7,
+ "Non-compatible flatbuffers version included");
+
+#include "fcp/tracing/tracing_schema_common_generated.h"
+
+struct ResultExpectError;
+struct ResultExpectErrorBuilder;
+struct ResultExpectErrorT;
+
+struct ResultExpectStatusError;
+struct ResultExpectStatusErrorBuilder;
+struct ResultExpectStatusErrorT;
+
+struct ProtoParseFailure;
+struct ProtoParseFailureBuilder;
+struct ProtoParseFailureT;
+
+inline const flatbuffers::TypeTable *ResultExpectErrorTypeTable();
+
+inline const flatbuffers::TypeTable *ResultExpectStatusErrorTypeTable();
+
+inline const flatbuffers::TypeTable *ProtoParseFailureTypeTable();
+
+enum TracingStatusCode : int16_t {
+ TracingStatusCode_Ok = 0,
+ TracingStatusCode_Cancelled = 1,
+ TracingStatusCode_Unknown = 2,
+ TracingStatusCode_InvalidArgument = 3,
+ TracingStatusCode_DeadlineExceeded = 4,
+ TracingStatusCode_NotFound = 5,
+ TracingStatusCode_AlreadyExists = 6,
+ TracingStatusCode_PermissionDenied = 7,
+ TracingStatusCode_ResourceExhausted = 8,
+ TracingStatusCode_FailedPrecondition = 9,
+ TracingStatusCode_Aborted = 10,
+ TracingStatusCode_OutOfRange = 11,
+ TracingStatusCode_Unimplemented = 12,
+ TracingStatusCode_Internal = 13,
+ TracingStatusCode_Unavailable = 14,
+ TracingStatusCode_DataLoss = 15,
+ TracingStatusCode_Unauthenticated = 16,
+ TracingStatusCode_MIN = TracingStatusCode_Ok,
+ TracingStatusCode_MAX = TracingStatusCode_Unauthenticated
+};
+
+inline const TracingStatusCode (&EnumValuesTracingStatusCode())[17] {
+ static const TracingStatusCode values[] = {
+ TracingStatusCode_Ok,
+ TracingStatusCode_Cancelled,
+ TracingStatusCode_Unknown,
+ TracingStatusCode_InvalidArgument,
+ TracingStatusCode_DeadlineExceeded,
+ TracingStatusCode_NotFound,
+ TracingStatusCode_AlreadyExists,
+ TracingStatusCode_PermissionDenied,
+ TracingStatusCode_ResourceExhausted,
+ TracingStatusCode_FailedPrecondition,
+ TracingStatusCode_Aborted,
+ TracingStatusCode_OutOfRange,
+ TracingStatusCode_Unimplemented,
+ TracingStatusCode_Internal,
+ TracingStatusCode_Unavailable,
+ TracingStatusCode_DataLoss,
+ TracingStatusCode_Unauthenticated
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesTracingStatusCode() {
+ static const char * const names[18] = {
+ "Ok",
+ "Cancelled",
+ "Unknown",
+ "InvalidArgument",
+ "DeadlineExceeded",
+ "NotFound",
+ "AlreadyExists",
+ "PermissionDenied",
+ "ResourceExhausted",
+ "FailedPrecondition",
+ "Aborted",
+ "OutOfRange",
+ "Unimplemented",
+ "Internal",
+ "Unavailable",
+ "DataLoss",
+ "Unauthenticated",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameTracingStatusCode(TracingStatusCode e) {
+ if (flatbuffers::IsOutRange(e, TracingStatusCode_Ok, TracingStatusCode_Unauthenticated)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesTracingStatusCode()[index];
+}
+
+struct ResultExpectErrorT : public flatbuffers::NativeTable {
+ typedef ResultExpectError TableType;
+ std::string expectation{};
+ std::string file_name{};
+ int32_t line = 0;
+};
+
+struct ResultExpectError FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ResultExpectErrorT NativeTableType;
+ typedef ResultExpectErrorBuilder Builder;
+ static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+ return ResultExpectErrorTypeTable();
+ }
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_EXPECTATION = 4,
+ VT_FILE_NAME = 6,
+ VT_LINE = 8
+ };
+ const flatbuffers::String *expectation() const {
+ return GetPointer<const flatbuffers::String *>(VT_EXPECTATION);
+ }
+ const flatbuffers::String *file_name() const {
+ return GetPointer<const flatbuffers::String *>(VT_FILE_NAME);
+ }
+ int32_t line() const {
+ return GetField<int32_t>(VT_LINE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_EXPECTATION) &&
+ verifier.VerifyString(expectation()) &&
+ VerifyOffset(verifier, VT_FILE_NAME) &&
+ verifier.VerifyString(file_name()) &&
+ VerifyField<int32_t>(verifier, VT_LINE, 4) &&
+ verifier.EndTable();
+ }
+ ResultExpectErrorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ResultExpectErrorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ResultExpectError> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectErrorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ResultExpectErrorBuilder {
+ typedef ResultExpectError Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_expectation(flatbuffers::Offset<flatbuffers::String> expectation) {
+ fbb_.AddOffset(ResultExpectError::VT_EXPECTATION, expectation);
+ }
+ void add_file_name(flatbuffers::Offset<flatbuffers::String> file_name) {
+ fbb_.AddOffset(ResultExpectError::VT_FILE_NAME, file_name);
+ }
+ void add_line(int32_t line) {
+ fbb_.AddElement<int32_t>(ResultExpectError::VT_LINE, line, 0);
+ }
+ explicit ResultExpectErrorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ flatbuffers::Offset<ResultExpectError> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ResultExpectError>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ResultExpectError> CreateResultExpectError(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> expectation = 0,
+ flatbuffers::Offset<flatbuffers::String> file_name = 0,
+ int32_t line = 0) {
+ ResultExpectErrorBuilder builder_(_fbb);
+ builder_.add_line(line);
+ builder_.add_file_name(file_name);
+ builder_.add_expectation(expectation);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ResultExpectError> CreateResultExpectErrorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *expectation = nullptr,
+ const char *file_name = nullptr,
+ int32_t line = 0) {
+ auto expectation__ = expectation ? _fbb.CreateString(expectation) : 0;
+ auto file_name__ = file_name ? _fbb.CreateString(file_name) : 0;
+ return CreateResultExpectError(
+ _fbb,
+ expectation__,
+ file_name__,
+ line);
+}
+
+flatbuffers::Offset<ResultExpectError> CreateResultExpectError(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectErrorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ResultExpectStatusErrorT : public flatbuffers::NativeTable {
+ typedef ResultExpectStatusError TableType;
+ TracingStatusCode expected_code = TracingStatusCode_Ok;
+ TracingStatusCode actual_code = TracingStatusCode_Ok;
+ std::string message{};
+ std::string file_name{};
+ int32_t line = 0;
+};
+
+struct ResultExpectStatusError FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ResultExpectStatusErrorT NativeTableType;
+ typedef ResultExpectStatusErrorBuilder Builder;
+ static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+ return ResultExpectStatusErrorTypeTable();
+ }
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_EXPECTED_CODE = 4,
+ VT_ACTUAL_CODE = 6,
+ VT_MESSAGE = 8,
+ VT_FILE_NAME = 10,
+ VT_LINE = 12
+ };
+ TracingStatusCode expected_code() const {
+ return static_cast<TracingStatusCode>(GetField<int16_t>(VT_EXPECTED_CODE, 0));
+ }
+ TracingStatusCode actual_code() const {
+ return static_cast<TracingStatusCode>(GetField<int16_t>(VT_ACTUAL_CODE, 0));
+ }
+ const flatbuffers::String *message() const {
+ return GetPointer<const flatbuffers::String *>(VT_MESSAGE);
+ }
+ const flatbuffers::String *file_name() const {
+ return GetPointer<const flatbuffers::String *>(VT_FILE_NAME);
+ }
+ int32_t line() const {
+ return GetField<int32_t>(VT_LINE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_EXPECTED_CODE, 2) &&
+ VerifyField<int16_t>(verifier, VT_ACTUAL_CODE, 2) &&
+ VerifyOffset(verifier, VT_MESSAGE) &&
+ verifier.VerifyString(message()) &&
+ VerifyOffset(verifier, VT_FILE_NAME) &&
+ verifier.VerifyString(file_name()) &&
+ VerifyField<int32_t>(verifier, VT_LINE, 4) &&
+ verifier.EndTable();
+ }
+ ResultExpectStatusErrorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ResultExpectStatusErrorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ResultExpectStatusError> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectStatusErrorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ResultExpectStatusErrorBuilder {
+ typedef ResultExpectStatusError Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_expected_code(TracingStatusCode expected_code) {
+ fbb_.AddElement<int16_t>(ResultExpectStatusError::VT_EXPECTED_CODE, static_cast<int16_t>(expected_code), 0);
+ }
+ void add_actual_code(TracingStatusCode actual_code) {
+ fbb_.AddElement<int16_t>(ResultExpectStatusError::VT_ACTUAL_CODE, static_cast<int16_t>(actual_code), 0);
+ }
+ void add_message(flatbuffers::Offset<flatbuffers::String> message) {
+ fbb_.AddOffset(ResultExpectStatusError::VT_MESSAGE, message);
+ }
+ void add_file_name(flatbuffers::Offset<flatbuffers::String> file_name) {
+ fbb_.AddOffset(ResultExpectStatusError::VT_FILE_NAME, file_name);
+ }
+ void add_line(int32_t line) {
+ fbb_.AddElement<int32_t>(ResultExpectStatusError::VT_LINE, line, 0);
+ }
+ explicit ResultExpectStatusErrorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ flatbuffers::Offset<ResultExpectStatusError> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ResultExpectStatusError>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ResultExpectStatusError> CreateResultExpectStatusError(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TracingStatusCode expected_code = TracingStatusCode_Ok,
+ TracingStatusCode actual_code = TracingStatusCode_Ok,
+ flatbuffers::Offset<flatbuffers::String> message = 0,
+ flatbuffers::Offset<flatbuffers::String> file_name = 0,
+ int32_t line = 0) {
+ ResultExpectStatusErrorBuilder builder_(_fbb);
+ builder_.add_line(line);
+ builder_.add_file_name(file_name);
+ builder_.add_message(message);
+ builder_.add_actual_code(actual_code);
+ builder_.add_expected_code(expected_code);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ResultExpectStatusError> CreateResultExpectStatusErrorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TracingStatusCode expected_code = TracingStatusCode_Ok,
+ TracingStatusCode actual_code = TracingStatusCode_Ok,
+ const char *message = nullptr,
+ const char *file_name = nullptr,
+ int32_t line = 0) {
+ auto message__ = message ? _fbb.CreateString(message) : 0;
+ auto file_name__ = file_name ? _fbb.CreateString(file_name) : 0;
+ return CreateResultExpectStatusError(
+ _fbb,
+ expected_code,
+ actual_code,
+ message__,
+ file_name__,
+ line);
+}
+
+flatbuffers::Offset<ResultExpectStatusError> CreateResultExpectStatusError(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectStatusErrorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ProtoParseFailureT : public flatbuffers::NativeTable {
+ typedef ProtoParseFailure TableType;
+ std::string type{};
+};
+
+struct ProtoParseFailure FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ProtoParseFailureT NativeTableType;
+ typedef ProtoParseFailureBuilder Builder;
+ static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+ return ProtoParseFailureTypeTable();
+ }
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TYPE = 4
+ };
+ const flatbuffers::String *type() const {
+ return GetPointer<const flatbuffers::String *>(VT_TYPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_TYPE) &&
+ verifier.VerifyString(type()) &&
+ verifier.EndTable();
+ }
+ ProtoParseFailureT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ProtoParseFailureT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ProtoParseFailure> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ProtoParseFailureT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ProtoParseFailureBuilder {
+ typedef ProtoParseFailure Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_type(flatbuffers::Offset<flatbuffers::String> type) {
+ fbb_.AddOffset(ProtoParseFailure::VT_TYPE, type);
+ }
+ explicit ProtoParseFailureBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ flatbuffers::Offset<ProtoParseFailure> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ProtoParseFailure>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ProtoParseFailure> CreateProtoParseFailure(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> type = 0) {
+ ProtoParseFailureBuilder builder_(_fbb);
+ builder_.add_type(type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ProtoParseFailure> CreateProtoParseFailureDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *type = nullptr) {
+ auto type__ = type ? _fbb.CreateString(type) : 0;
+ return CreateProtoParseFailure(
+ _fbb,
+ type__);
+}
+
+flatbuffers::Offset<ProtoParseFailure> CreateProtoParseFailure(flatbuffers::FlatBufferBuilder &_fbb, const ProtoParseFailureT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+inline ResultExpectErrorT *ResultExpectError::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = std::unique_ptr<ResultExpectErrorT>(new ResultExpectErrorT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void ResultExpectError::UnPackTo(ResultExpectErrorT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = expectation(); if (_e) _o->expectation = _e->str(); }
+ { auto _e = file_name(); if (_e) _o->file_name = _e->str(); }
+ { auto _e = line(); _o->line = _e; }
+}
+
+inline flatbuffers::Offset<ResultExpectError> ResultExpectError::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectErrorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateResultExpectError(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ResultExpectError> CreateResultExpectError(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectErrorT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResultExpectErrorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _expectation = _o->expectation.empty() ? 0 : _fbb.CreateString(_o->expectation);
+ auto _file_name = _o->file_name.empty() ? 0 : _fbb.CreateString(_o->file_name);
+ auto _line = _o->line;
+ return CreateResultExpectError(
+ _fbb,
+ _expectation,
+ _file_name,
+ _line);
+}
+
+inline ResultExpectStatusErrorT *ResultExpectStatusError::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = std::unique_ptr<ResultExpectStatusErrorT>(new ResultExpectStatusErrorT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void ResultExpectStatusError::UnPackTo(ResultExpectStatusErrorT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = expected_code(); _o->expected_code = _e; }
+ { auto _e = actual_code(); _o->actual_code = _e; }
+ { auto _e = message(); if (_e) _o->message = _e->str(); }
+ { auto _e = file_name(); if (_e) _o->file_name = _e->str(); }
+ { auto _e = line(); _o->line = _e; }
+}
+
+inline flatbuffers::Offset<ResultExpectStatusError> ResultExpectStatusError::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectStatusErrorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateResultExpectStatusError(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ResultExpectStatusError> CreateResultExpectStatusError(flatbuffers::FlatBufferBuilder &_fbb, const ResultExpectStatusErrorT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResultExpectStatusErrorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _expected_code = _o->expected_code;
+ auto _actual_code = _o->actual_code;
+ auto _message = _o->message.empty() ? 0 : _fbb.CreateString(_o->message);
+ auto _file_name = _o->file_name.empty() ? 0 : _fbb.CreateString(_o->file_name);
+ auto _line = _o->line;
+ return CreateResultExpectStatusError(
+ _fbb,
+ _expected_code,
+ _actual_code,
+ _message,
+ _file_name,
+ _line);
+}
+
+inline ProtoParseFailureT *ProtoParseFailure::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = std::unique_ptr<ProtoParseFailureT>(new ProtoParseFailureT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void ProtoParseFailure::UnPackTo(ProtoParseFailureT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = type(); if (_e) _o->type = _e->str(); }
+}
+
+inline flatbuffers::Offset<ProtoParseFailure> ProtoParseFailure::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ProtoParseFailureT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateProtoParseFailure(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ProtoParseFailure> CreateProtoParseFailure(flatbuffers::FlatBufferBuilder &_fbb, const ProtoParseFailureT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ProtoParseFailureT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _type = _o->type.empty() ? 0 : _fbb.CreateString(_o->type);
+ return CreateProtoParseFailure(
+ _fbb,
+ _type);
+}
+
+inline const flatbuffers::TypeTable *TracingStatusCodeTypeTable() {
+ static const flatbuffers::TypeCode type_codes[] = {
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 }
+ };
+ static const flatbuffers::TypeFunction type_refs[] = {
+ TracingStatusCodeTypeTable
+ };
+ static const char * const names[] = {
+ "Ok",
+ "Cancelled",
+ "Unknown",
+ "InvalidArgument",
+ "DeadlineExceeded",
+ "NotFound",
+ "AlreadyExists",
+ "PermissionDenied",
+ "ResourceExhausted",
+ "FailedPrecondition",
+ "Aborted",
+ "OutOfRange",
+ "Unimplemented",
+ "Internal",
+ "Unavailable",
+ "DataLoss",
+ "Unauthenticated"
+ };
+ static const flatbuffers::TypeTable tt = {
+ flatbuffers::ST_ENUM, 17, type_codes, type_refs, nullptr, nullptr, names
+ };
+ return &tt;
+}
+
+inline const flatbuffers::TypeTable *ResultExpectErrorTypeTable() {
+ static const flatbuffers::TypeCode type_codes[] = {
+ { flatbuffers::ET_STRING, 0, -1 },
+ { flatbuffers::ET_STRING, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 }
+ };
+ static const char * const names[] = {
+ "expectation",
+ "file_name",
+ "line"
+ };
+ static const flatbuffers::TypeTable tt = {
+ flatbuffers::ST_TABLE, 3, type_codes, nullptr, nullptr, nullptr, names
+ };
+ return &tt;
+}
+
+inline const flatbuffers::TypeTable *ResultExpectStatusErrorTypeTable() {
+ static const flatbuffers::TypeCode type_codes[] = {
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_SHORT, 0, 0 },
+ { flatbuffers::ET_STRING, 0, -1 },
+ { flatbuffers::ET_STRING, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 }
+ };
+ static const flatbuffers::TypeFunction type_refs[] = {
+ TracingStatusCodeTypeTable
+ };
+ static const char * const names[] = {
+ "expected_code",
+ "actual_code",
+ "message",
+ "file_name",
+ "line"
+ };
+ static const flatbuffers::TypeTable tt = {
+ flatbuffers::ST_TABLE, 5, type_codes, type_refs, nullptr, nullptr, names
+ };
+ return &tt;
+}
+
+inline const flatbuffers::TypeTable *ProtoParseFailureTypeTable() {
+ static const flatbuffers::TypeCode type_codes[] = {
+ { flatbuffers::ET_STRING, 0, -1 }
+ };
+ static const char * const names[] = {
+ "type"
+ };
+ static const flatbuffers::TypeTable tt = {
+ flatbuffers::ST_TABLE, 1, type_codes, nullptr, nullptr, nullptr, names
+ };
+ return &tt;
+}
+
+#endif // FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_ \ No newline at end of file
diff --git a/fcp/base/unique_value.h b/fcp/base/unique_value.h
new file mode 100644
index 0000000..e8f7290
--- /dev/null
+++ b/fcp/base/unique_value.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_UNIQUE_VALUE_H_
+#define FCP_BASE_UNIQUE_VALUE_H_
+
+#include <optional>
+#include <utility>
+
+namespace fcp {
+
+/**
+ * UniqueValue<T> provides move-only semantics for some value of type T.
+ *
+ * Its semantics are much like std::unique_ptr, but without requiring an
+ * allocation and pointer indirection (recall a moved-from std::unique_ptr is
+ * reset to nullptr).
+ *
+ * Instead, UniqueValue is represented just like std::optional - but
+ * has_value() == false once moved-from. std::optional does *not* reset when
+ * moved from (even if the wrapped type is move-only); that's consistent, but
+ * not especially desirable.
+ *
+ * Since UniqueValue is always move-only, including a UniqueValue member is
+ * sufficient for a containing aggregate to be move-only.
+ */
+template <typename T>
+class UniqueValue {
+ public:
+ constexpr explicit UniqueValue(std::nullopt_t) : value_() {}
+ constexpr explicit UniqueValue(T val) : value_(std::move(val)) {}
+
+ UniqueValue(UniqueValue const&) = delete;
+ UniqueValue(UniqueValue&& other) : value_(std::move(other.value_)) {
+ other.value_.reset();
+ }
+
+ UniqueValue& operator=(UniqueValue other) {
+ value_.swap(other.value_);
+ return *this;
+ }
+
+ /**
+ * Indicates if this instance holds a value (i.e. has not been moved away).
+ *
+ * It is an error to dereference this UniqueValue if !has_value().
+ */
+ constexpr bool has_value() const {
+ return value_.has_value();
+ }
+
+ constexpr T Take() && {
+ T v = *std::move(value_);
+ value_.reset();
+ return v;
+ }
+
+ constexpr T const& operator*() const & {
+ return *value_;
+ }
+
+ T& operator*() & {
+ return *value_;
+ }
+
+ T const* operator->() const {
+ return &*value_;
+ }
+
+ T* operator->() {
+ return &*value_;
+ }
+
+ /**
+ * Replaces current value with a newly constructed one given constructor
+ * arguments for T (like std::optional::emplace).
+ */
+ template <class... _Args>
+ T& Emplace(_Args&&... __args) {
+ value_.emplace(std::forward<_Args>(__args)...);
+ return *value_;
+ }
+
+ /**
+ * Removes (destructs) a value. No-op if absent;
+ */
+ void Reset() { value_.reset(); }
+
+ private:
+ std::optional<T> value_;
+};
+
+// Deduction guide allowing one to write UniqueValue(x) without an explicit
+// template argument. This one would be implicitly generated; it's here to
+// suppress -Wctad-maybe-unsupported.
+template <typename T>
+UniqueValue(T val) -> UniqueValue<T>;
+
+/**
+ * Makes a UniqueValue<T> given constructor arguments for T
+ * (like std::make_unique).
+ */
+template <typename T, typename... Args>
+constexpr UniqueValue<T> MakeUniqueValue(Args&&... args) {
+ return UniqueValue<T>(T(std::forward<Args>(args)...));
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_UNIQUE_VALUE_H_
diff --git a/fcp/base/unique_value_test.cc b/fcp/base/unique_value_test.cc
new file mode 100644
index 0000000..f621e8e
--- /dev/null
+++ b/fcp/base/unique_value_test.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/unique_value.h"
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+struct ValueBox {
+ bool destructed = false;
+ int value = 0;
+};
+
+class TracedValue {
+ public:
+ explicit TracedValue(int value) : local_value_(0), box_(nullptr) {
+ UpdateValue(value);
+ }
+
+ void AttachToBox(ValueBox* box) {
+ FCP_CHECK(box_ == nullptr);
+ box_ = box;
+ UpdateValue(local_value_);
+ }
+
+ TracedValue(TracedValue const& other) : local_value_(0), box_(nullptr) {
+ UpdateValue(other.value());
+ }
+
+ TracedValue& operator=(TracedValue const& other) {
+ UpdateValue(other.value());
+ return *this;
+ }
+
+ ~TracedValue() {
+ if (box_) {
+ box_->destructed = true;
+ }
+ }
+
+ int value() const { return local_value_; }
+
+ private:
+ void UpdateValue(int value) {
+ local_value_ = value;
+ if (box_) {
+ box_->destructed = false;
+ box_->value = value;
+ }
+ }
+
+ int local_value_;
+ ValueBox* box_;
+};
+
+TEST(UniqueValueTest, MoveToInnerScope) {
+ ValueBox box_a{};
+ ValueBox box_b{};
+
+ {
+ UniqueValue<TracedValue> a = MakeUniqueValue<TracedValue>(123);
+ a->AttachToBox(&box_a);
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(123));
+
+ {
+ UniqueValue<TracedValue> b = MakeUniqueValue<TracedValue>(456);
+ b->AttachToBox(&box_b);
+ EXPECT_THAT(box_b.destructed, Eq(false));
+ EXPECT_THAT(box_b.value, Eq(456));
+
+ b = std::move(a);
+
+ EXPECT_THAT(box_a.destructed, Eq(true));
+ EXPECT_THAT(box_b.destructed, Eq(false));
+ EXPECT_THAT(box_b.value, Eq(123));
+ }
+
+ EXPECT_THAT(box_a.destructed, Eq(true));
+ EXPECT_THAT(box_b.destructed, Eq(true));
+ }
+}
+
+TEST(UniqueValueTest, MoveToOuterScope) {
+ ValueBox box_a{};
+ ValueBox box_b{};
+
+ {
+ UniqueValue<TracedValue> a = MakeUniqueValue<TracedValue>(123);
+ a->AttachToBox(&box_a);
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(123));
+
+ {
+ UniqueValue<TracedValue> b = MakeUniqueValue<TracedValue>(456);
+ b->AttachToBox(&box_b);
+ EXPECT_THAT(box_b.destructed, Eq(false));
+ EXPECT_THAT(box_b.value, Eq(456));
+
+ a = std::move(b);
+
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(456));
+ EXPECT_THAT(box_b.destructed, Eq(true));
+ }
+
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(456));
+ EXPECT_THAT(box_b.destructed, Eq(true));
+ }
+
+ EXPECT_THAT(box_a.destructed, Eq(true));
+ EXPECT_THAT(box_b.destructed, Eq(true));
+}
+
+TEST(UniqueValueTest, Emplace) {
+ ValueBox box_a{};
+ ValueBox box_b{};
+ {
+ UniqueValue<TracedValue> v{std::nullopt};
+ v.Emplace(123);
+ v->AttachToBox(&box_a);
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(123));
+ v.Emplace(321);
+ v->AttachToBox(&box_b);
+ EXPECT_THAT(box_a.destructed, Eq(true));
+ EXPECT_THAT(box_b.destructed, Eq(false));
+ EXPECT_THAT(box_b.value, Eq(321));
+ }
+}
+
+TEST(UniqueValueTest, Reset) {
+ ValueBox box_a{};
+ UniqueValue<TracedValue> v = MakeUniqueValue<TracedValue>(123);
+ v.Emplace(123);
+ v->AttachToBox(&box_a);
+ EXPECT_THAT(box_a.destructed, Eq(false));
+ EXPECT_THAT(box_a.value, Eq(123));
+ v.Reset();
+ EXPECT_THAT(box_a.destructed, Eq(true));
+ v.Reset();
+ EXPECT_THAT(box_a.destructed, Eq(true));
+}
+
+} // namespace fcp
diff --git a/fcp/base/wall_clock_stopwatch.cc b/fcp/base/wall_clock_stopwatch.cc
new file mode 100644
index 0000000..b328554
--- /dev/null
+++ b/fcp/base/wall_clock_stopwatch.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/wall_clock_stopwatch.h"
+
+#include <memory>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+namespace internal {
+class RealWallClockStopwatch : public WallClockStopwatch {
+ public:
+ RealWallClockStopwatch() = default;
+
+ Handle Start() override ABSL_LOCKS_EXCLUDED(mutex_) {
+ return WallClockStopwatch::Handle(this);
+ }
+ absl::Duration GetTotalDuration() const override ABSL_LOCKS_EXCLUDED(mutex_) {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(started_count_ >= 0);
+ if (latest_start_time_ == absl::InfiniteFuture()) {
+ return previous_durations_;
+ }
+ return previous_durations_ + (absl::Now() - latest_start_time_);
+ }
+
+ private:
+ void StartInternal() override ABSL_LOCKS_EXCLUDED(mutex_) {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(started_count_ >= 0);
+ started_count_++;
+ if (started_count_ == 1) {
+ latest_start_time_ = absl::Now();
+ }
+ }
+ void StopInternal() override ABSL_LOCKS_EXCLUDED(mutex_) {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(started_count_ >= 1);
+ started_count_--;
+ if (started_count_ == 0) {
+ previous_durations_ += absl::Now() - latest_start_time_;
+ latest_start_time_ = absl::InfiniteFuture();
+ }
+ }
+
+ mutable absl::Mutex mutex_;
+ int started_count_ ABSL_GUARDED_BY(mutex_) = 0;
+ absl::Time latest_start_time_ ABSL_GUARDED_BY(mutex_) =
+ absl::InfiniteFuture();
+ absl::Duration previous_durations_ ABSL_GUARDED_BY(mutex_) =
+ absl::ZeroDuration();
+};
+
+// A noop stopwatch that does nothing (e.g. for use in tests or to
+// flag-off the measurement of something).
+class NoopWallClockStopwatch : public WallClockStopwatch {
+ public:
+ NoopWallClockStopwatch() = default;
+
+ Handle Start() override { return Handle(nullptr); }
+ absl::Duration GetTotalDuration() const override {
+ return absl::ZeroDuration();
+ }
+};
+} // namespace internal
+
+WallClockStopwatch::Handle::Handle(WallClockStopwatch* stopwatch)
+ : stopwatch_(stopwatch) {
+ if (stopwatch_ != nullptr) {
+ stopwatch_->StartInternal();
+ }
+}
+
+WallClockStopwatch::Handle::~Handle() {
+ if (stopwatch_ != nullptr) {
+ stopwatch_->StopInternal();
+ }
+}
+
+std::unique_ptr<WallClockStopwatch> WallClockStopwatch::Create() {
+ return std::make_unique<internal::RealWallClockStopwatch>();
+}
+
+std::unique_ptr<WallClockStopwatch> WallClockStopwatch::CreateNoop() {
+ return std::make_unique<internal::NoopWallClockStopwatch>();
+}
+
+} // namespace fcp
diff --git a/fcp/base/wall_clock_stopwatch.h b/fcp/base/wall_clock_stopwatch.h
new file mode 100644
index 0000000..a1007e5
--- /dev/null
+++ b/fcp/base/wall_clock_stopwatch.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_BASE_WALL_CLOCK_STOPWATCH_H_
+#define FCP_BASE_WALL_CLOCK_STOPWATCH_H_
+
+#include <memory>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+
+namespace fcp {
+
+namespace internal {
+class RealWallClockStopwatch;
+class NoopWallClockStopwatch;
+} // namespace internal
+
+// A utility for measuring wall clock time across multiple threads.
+//
+// This class is non-reentrant: `Start()` should only be called once per thread
+// (but once `Stop()` has been called, `Start()` may be called again).
+class WallClockStopwatch {
+ public:
+ static std::unique_ptr<WallClockStopwatch> Create();
+ static std::unique_ptr<WallClockStopwatch> CreateNoop();
+ // Disable copy and move semantics.
+ WallClockStopwatch(const WallClockStopwatch&) = delete;
+ WallClockStopwatch& operator=(const WallClockStopwatch&) = delete;
+
+ // A handle that stops the stopwatch once destroyed.
+ class Handle {
+ public:
+ // Disable copy and move semantics.
+ Handle(const Handle&) = delete;
+ Handle& operator=(const Handle&) = delete;
+ ~Handle();
+
+ private:
+ // If `stopwatch` is a nullptr then the Handle that does nothing (for use in
+ // testing or flagging-off the measurement with a real stopwatch).
+ explicit Handle(WallClockStopwatch* stopwatch);
+
+ WallClockStopwatch* const stopwatch_;
+ friend internal::RealWallClockStopwatch;
+ friend internal::NoopWallClockStopwatch;
+ };
+
+ // Start the stopwatch from this thread. If it wasn't running yet from any
+ // other thread, then time will start being accumulated from this point on.
+ // If it was already running from another thread then this call will have no
+ // immediate effect.
+ //
+ // Once the returned Handle is destroyed, the stopwatch is stopped from this
+ // thread. If it isn't running from any other thread, then time will stop
+ // being accumulated from that point on. If it still running from another
+ // thread then Handle destruction will have no immediate effect.
+ virtual Handle Start() = 0;
+
+ // Get the total duration of wall clock time that the stopwatch has run for,
+ // up until this moment (i.e. including any still-ongoing measurement).
+ virtual absl::Duration GetTotalDuration() const = 0;
+
+ virtual ~WallClockStopwatch() = default;
+
+ private:
+ WallClockStopwatch() = default;
+ virtual void StartInternal() {}
+ virtual void StopInternal() {}
+ friend internal::RealWallClockStopwatch;
+ friend internal::NoopWallClockStopwatch;
+};
+
+} // namespace fcp
+
+#endif // FCP_BASE_WALL_CLOCK_STOPWATCH_H_
diff --git a/fcp/base/wall_clock_stopwatch_test.cc b/fcp/base/wall_clock_stopwatch_test.cc
new file mode 100644
index 0000000..f92c73b
--- /dev/null
+++ b/fcp/base/wall_clock_stopwatch_test.cc
@@ -0,0 +1,251 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/base/wall_clock_stopwatch.h"
+
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/scheduler.h"
+
+namespace fcp {
+
+using ::testing::AllOf;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::Lt;
+
+TEST(WallClockStopwatchTest, NoopHandle) {
+ // These noop handles should not crash (or do anything).
+ auto stopwatch = WallClockStopwatch::CreateNoop();
+ {
+ auto started_stopwatch1 = stopwatch->Start();
+ auto started_stopwatch2 = stopwatch->Start();
+ }
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Eq(absl::ZeroDuration()));
+}
+
+TEST(WallClockStopwatchTest, ShouldBeInitializedToZero) {
+ auto stopwatch = WallClockStopwatch::Create();
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Eq(absl::ZeroDuration()));
+}
+
+TEST(WallClockStopwatchTest, SingleThreadSingleStart) {
+ auto stopwatch = WallClockStopwatch::Create();
+
+ {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+}
+
+TEST(WallClockStopwatchTest, SingleThreadMultipleSequentialStartStop) {
+ auto stopwatch = WallClockStopwatch::Create();
+
+ {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+
+ absl::SleepFor(absl::Milliseconds(100));
+ // The SleepFor should not be reflect in the measurement, since the stopwatch
+ // was stopped.
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+
+ {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(200)));
+}
+
+TEST(WallClockStopwatchTest, ShouldReflectOngoingMeasurement) {
+ auto stopwatch = WallClockStopwatch::Create();
+
+ {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(200)));
+}
+
+TEST(WallClockStopwatchTest, SingleThreadMultipleConcurrentStart) {
+ auto stopwatch = WallClockStopwatch::Create();
+
+ {
+ auto started_stopwatch1 = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ {
+ auto started_stopwatch2 = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(200)));
+ {
+ auto started_stopwatch3 = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+ }
+ }
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(300)));
+}
+
+/** Tests that the stopwatch truly measures wall clock time, and not the
+ * cumulative (but concurrent) time spent in each separate thread. */
+TEST(WallClockStopwatchTest, ThreeThreadsThreeTasks) {
+ auto stopwatch = WallClockStopwatch::Create();
+ std::unique_ptr<Scheduler> scheduler =
+ CreateThreadPoolScheduler(/*thread_count=*/3);
+
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->WaitUntilIdle();
+ // The stopwatch should only have measured ~100ms of wall clock time, since
+ // the three threads will have run concurrently (we use a margin of 50 extra
+ // ms since these can be quite slow when run with ASAN/TSAN).
+ EXPECT_THAT(stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(100)), Lt(absl::Milliseconds(150))));
+}
+
+/** Tests that the stopwatch truly measures wall clock time, but this time in a
+ * scenario where there are only 2 threads so the third measurement *will*
+ * happen sequentially. */
+TEST(WallClockStopwatchTest, TwoThreadsThreeTasks) {
+ auto stopwatch = WallClockStopwatch::Create();
+ std::unique_ptr<Scheduler> scheduler =
+ CreateThreadPoolScheduler(/*thread_count=*/2);
+
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(200)));
+ });
+ scheduler->WaitUntilIdle();
+ // The stopwatch should have measured ~200ms of wall clock time, since the
+ // two threads will have run concurrently but there were three tasks, so the
+ // third task will have run sequentially.
+ EXPECT_THAT(stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(200)), Lt(absl::Milliseconds(250))));
+}
+
+/** Tests that the stopwatch handles stop/starts across different threads
+ * correctly, including partially overlapping measurements. */
+TEST(WallClockStopwatchTest, TwoThreadsMultipleOverlappingStartStop) {
+ auto stopwatch = WallClockStopwatch::Create();
+ std::unique_ptr<Scheduler> scheduler =
+ CreateThreadPoolScheduler(/*thread_count=*/2);
+
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(100));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(50));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(50)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(50));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(100)));
+ });
+ scheduler->WaitUntilIdle();
+
+ // The stopwatch should have measured ~100ms of wall clock time until now,
+ // since the two threads will have run concurrently and there were three
+ // tasks, which should all have been able to run concurrently within that
+ // time.
+ EXPECT_THAT(stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(100)), Lt(absl::Milliseconds(150))));
+
+ absl::SleepFor(absl::Milliseconds(100));
+ // The SleepFor should not be reflected in the measurement since all
+ // stopwatches were stopped.
+ EXPECT_THAT(stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(100)), Lt(absl::Milliseconds(150))));
+
+ {
+ auto outer_started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(50));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(150)));
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(200));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(350)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(50));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(200)));
+ });
+ scheduler->Schedule([&stopwatch]() {
+ auto started_stopwatch = stopwatch->Start();
+ absl::SleepFor(absl::Milliseconds(350));
+ EXPECT_THAT(stopwatch->GetTotalDuration(), Ge(absl::Milliseconds(500)));
+ });
+ scheduler->WaitUntilIdle();
+
+ // The stopwatch should have measured ~550ms of wall clock time until now:
+ // the previous ~100ms measurement + 50ms + 50ms + 350ms (the shortest
+ // critical path for the above three tasks).
+ //
+ // Note that the outer stopwatch is still active so the measurement is still
+ // ongoing.
+ EXPECT_THAT(
+ stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(550)), Lt(absl::Milliseconds(600))));
+ absl::SleepFor(absl::Milliseconds(100));
+ }
+
+ // The final SleepFor should now also be reflected.
+ EXPECT_THAT(stopwatch->GetTotalDuration(),
+ AllOf(Ge(absl::Milliseconds(650)), Lt(absl::Milliseconds(700))));
+}
+
+} // namespace fcp
diff --git a/fcp/client/BUILD b/fcp/client/BUILD
new file mode 100644
index 0000000..f95c632
--- /dev/null
+++ b/fcp/client/BUILD
@@ -0,0 +1,773 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("@org_tensorflow//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "grpc_bidi_channel",
+ hdrs = ["grpc_bidi_channel.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "grpc_bidi_channel_test",
+ srcs = ["grpc_bidi_channel_test.cc"],
+ copts = FCP_COPTS,
+ tags = ["local"], # The certificate path is not accessible from a sandbox.
+ deps = [
+ ":grpc_bidi_stream",
+ "//fcp/protos:cc_grpc",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "grpc_bidi_stream",
+ srcs = ["grpc_bidi_stream.cc"],
+ hdrs = ["grpc_bidi_stream.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":grpc_bidi_channel",
+ "//fcp/base",
+ "//fcp/base:status_converters",
+ "//fcp/protocol:grpc_chunked_bidi_stream",
+ "//fcp/protos:cc_grpc",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "grpc_bidi_stream_test",
+ srcs = ["grpc_bidi_stream_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":fake_server",
+ ":grpc_bidi_stream",
+ ":test_helpers",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/testing",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "fake_server",
+ testonly = 1,
+ srcs = ["fake_server.cc"],
+ hdrs = ["fake_server.h"],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":grpc_bidi_stream",
+ "//fcp/base",
+ "//fcp/base:status_converters",
+ "//fcp/protocol:grpc_chunked_bidi_stream",
+ "//fcp/protos:cc_grpc",
+ "//fcp/protos:federated_api_cc_proto",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Interfaces used by the engine & protocol. These require platform specific
+# implementations.
+cc_library(
+ name = "interfaces",
+ srcs = [],
+ hdrs = [
+ "event_publisher.h",
+ "files.h",
+ "flags.h",
+ "log_manager.h",
+ "secagg_event_publisher.h",
+ "stats.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":histogram_counters_cc_proto",
+ "//fcp/client/engine:engine_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "federated_protocol",
+ hdrs = ["federated_protocol.h"],
+ deps = [
+ ":interfaces",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ ],
+)
+
+cc_library(
+ name = "federated_protocol_util",
+ srcs = ["federated_protocol_util.cc"],
+ hdrs = ["federated_protocol_util.h"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":interfaces",
+ "//fcp/base",
+ "//fcp/base:time_util",
+ "//fcp/protos:federated_api_cc_proto",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "federated_protocol_util_test",
+ srcs = ["federated_protocol_util_test.cc"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":federated_protocol_util",
+ ":test_helpers",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "grpc_federated_protocol",
+ srcs = ["grpc_federated_protocol.cc"],
+ hdrs = [
+ "grpc_federated_protocol.h",
+ ],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":federated_protocol",
+ ":federated_protocol_util",
+ ":fl_runner_cc_proto",
+ ":grpc_bidi_stream",
+ ":interfaces",
+ ":interruptible_runner",
+ ":secagg_runner",
+ ":selector_context_cc_proto",
+ "//fcp/base",
+ "//fcp/base:time_util",
+ "//fcp/base:wall_clock_stopwatch",
+ "//fcp/client/cache:resource_cache",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:in_memory_request_response",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protocol:grpc_chunked_bidi_stream",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/secagg/client",
+ "//fcp/secagg/client:state_transition_listener",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:span",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "grpc_federated_protocol_test",
+ srcs = ["grpc_federated_protocol_test.cc"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":grpc_bidi_stream",
+ ":grpc_federated_protocol",
+ ":interfaces",
+ ":interruptible_runner",
+ ":test_helpers",
+ "//fcp/base",
+ "//fcp/client/cache:test_helpers",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/secagg/client",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/testing:client_mocks",
+ "//fcp/secagg/testing:common_mocks",
+ "//fcp/testing",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "simple_task_environment",
+ srcs = ["simple_task_environment.cc"],
+ hdrs = ["simple_task_environment.h"],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":selector_context_cc_proto",
+ "//fcp/base",
+ "//fcp/client/http:http_client",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_test(
+ name = "simple_task_environment_test",
+ srcs = ["simple_task_environment_test.cc"],
+ deps = [
+ ":simple_task_environment",
+ ":test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+bool_flag(
+ name = "support_tfmobile",
+ build_setting_default = True,
+)
+
+bool_flag(
+ name = "support_grpc",
+ build_setting_default = True,
+)
+
+config_setting(
+ name = "client_support_tfmobile",
+ flag_values = {":support_tfmobile": "True"},
+)
+
+config_setting(
+ name = "client_support_grpc",
+ flag_values = {":support_grpc": "True"},
+)
+
+TF_OPTIONAL_DEPS = select({
+ ":client_support_tfmobile": [
+ "//fcp/client/engine:plan_engine",
+ ],
+ "//conditions:default": [],
+})
+
+TF_OPTIONAL_DEFINES = select({
+ ":client_support_tfmobile": [
+ "FCP_CLIENT_SUPPORT_TFMOBILE",
+ ],
+ "//conditions:default": [],
+})
+
+GRPC_OPTIONAL_DEFINES = select(
+ {
+ ":client_support_grpc": ["FCP_CLIENT_SUPPORT_GRPC"],
+ "//conditions:default": [],
+ },
+)
+
+GRPC_OPTIONAL_DEPS = select(
+ {
+ ":client_support_grpc": [":grpc_federated_protocol"],
+ "//conditions:default": [],
+ },
+)
+
+cc_library(
+ name = "fl_runner",
+ srcs = ["fl_runner.cc"],
+ hdrs = [
+ "fl_runner.h",
+ ],
+ copts = FCP_COPTS,
+ defines = TF_OPTIONAL_DEFINES + GRPC_OPTIONAL_DEFINES,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":federated_protocol",
+ ":federated_protocol_util",
+ ":federated_select",
+ ":fl_runner_cc_proto",
+ ":histogram_counters_cc_proto",
+ ":interfaces",
+ ":interruptible_runner",
+ ":parsing_utils",
+ ":phase_logger",
+ ":phase_logger_impl",
+ ":secagg_runner",
+ ":selector_context_cc_proto",
+ ":simple_task_environment",
+ "//fcp/base",
+ "//fcp/base:clock",
+ "//fcp/client/cache:file_backed_resource_cache",
+ "//fcp/client/cache:resource_cache",
+ "//fcp/client/engine:common",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/client/engine:example_query_plan_engine",
+ "//fcp/client/engine:plan_engine_helpers",
+ "//fcp/client/engine:tflite_plan_engine",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:http_federated_protocol",
+ "//fcp/client/opstats:opstats_example_store",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/client/opstats:opstats_utils",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:opstats_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "@boringssl//:crypto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/time",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ] + TF_OPTIONAL_DEPS + GRPC_OPTIONAL_DEPS,
+)
+
+cc_library(
+ name = "lc_runner",
+ srcs = ["lc_runner.cc"],
+ hdrs = [
+ "lc_runner.h",
+ ],
+ copts = FCP_COPTS,
+ defines = TF_OPTIONAL_DEFINES,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":interfaces",
+ ":phase_logger",
+ ":phase_logger_impl",
+ ":selector_context_cc_proto",
+ ":simple_task_environment",
+ "//fcp/base",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/client/engine:plan_engine_helpers",
+ "//fcp/client/engine:tflite_plan_engine",
+ "//fcp/client/opstats:opstats_example_store",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ] + TF_OPTIONAL_DEPS,
+)
+
+cc_library(
+ name = "phase_logger",
+ hdrs = ["phase_logger.h"],
+ deps = [
+ ":interfaces",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:federated_api_cc_proto",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "phase_logger_impl",
+ srcs = ["phase_logger_impl.cc"],
+ hdrs = ["phase_logger_impl.h"],
+ deps = [
+ ":interfaces",
+ ":phase_logger",
+ "//fcp/base",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:federated_api_cc_proto",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_test(
+ name = "phase_logger_impl_test",
+ srcs = ["phase_logger_impl_test.cc"],
+ deps = [
+ ":phase_logger_impl",
+ ":test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "secagg_runner",
+ srcs = ["secagg_runner.cc"],
+ hdrs = ["secagg_runner.h"],
+ deps = [
+ ":federated_protocol",
+ ":interfaces",
+ ":interruptible_runner",
+ "//fcp/secagg/client",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ ],
+)
+
+cc_library(
+ name = "federated_select",
+ srcs = ["federated_select.cc"],
+ hdrs = ["federated_select.h"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":interfaces",
+ ":interruptible_runner",
+ ":simple_task_environment",
+ "//fcp/base",
+ "//fcp/base:wall_clock_stopwatch",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:http_client_util",
+ "//fcp/client/http:in_memory_request_response",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "parsing_utils",
+ hdrs = ["parsing_utils.h"],
+ deps = ["@com_google_absl//absl/strings:cord"],
+)
+
+cc_test(
+ name = "federated_select_test",
+ srcs = ["federated_select_test.cc"],
+ deps = [
+ ":client_runner",
+ ":diag_codes_cc_proto",
+ ":federated_select",
+ ":interfaces",
+ ":interruptible_runner",
+ ":test_helpers",
+ "//fcp/base",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:in_memory_request_response",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "interruptible_runner",
+ srcs = ["interruptible_runner.cc"],
+ hdrs = ["interruptible_runner.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":interfaces",
+ "//fcp/base",
+ "//fcp/base:future",
+ "//fcp/base:scheduler",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "fake_log_manager",
+ hdrs = ["fake_log_manager.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":diag_codes_cc_proto",
+ ":interfaces",
+ "//fcp/base",
+ ],
+)
+
+cc_library(
+ name = "fake_event_publisher",
+ hdrs = ["fake_event_publisher.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":interfaces",
+ "//fcp/base",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_test(
+ name = "interruptible_runner_test",
+ srcs = ["interruptible_runner_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":diag_codes_cc_proto",
+ ":interruptible_runner",
+ ":test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+# Misc. classes to support embedding a fake client in unit or integration tests.
+cc_library(
+ name = "client_runner",
+ testonly = True,
+ hdrs = ["client_runner.h"],
+ deps = [
+ ":client_runner_example_data_cc_proto",
+ ":diag_codes_cc_proto",
+ ":fake_event_publisher",
+ ":histogram_counters_cc_proto",
+ ":interfaces",
+ ":simple_task_environment",
+ "//fcp/base",
+ "//fcp/client/http/curl:curl_http_client",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+# A command line executable running most of the client side code to
+# use for debugging and illustrating an example integration.
+cc_library(
+ name = "client_runner_main_lib",
+ testonly = True,
+ srcs = ["client_runner_main.cc"],
+ deps = [
+ ":client_runner",
+ ":client_runner_example_data_cc_proto",
+ ":fake_event_publisher",
+ ":fl_runner",
+ "//fcp/base",
+ "//fcp/tensorflow:external_dataset_op_lib",
+ "//fcp/tensorflow:task_eligibility_info_ops_lib",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/flags:usage",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "client_runner_main",
+ testonly = True,
+ deps = [
+ ":client_runner_main_lib",
+ "@org_tensorflow//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+cc_library(
+ name = "test_helpers",
+ testonly = True,
+ srcs = ["test_helpers.cc"],
+ hdrs = ["test_helpers.h"],
+ deps = [
+ ":federated_protocol",
+ ":federated_select",
+ ":interfaces",
+ ":phase_logger",
+ ":secagg_runner",
+ ":simple_task_environment",
+ "//fcp/base",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/client/http:http_client",
+ "//fcp/client/opstats:opstats_db",
+ "//fcp/client/opstats:opstats_logger",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# Protocol buffers for logging. Those get serialized.
+proto_library(
+ name = "histogram_counters_proto",
+ srcs = ["histogram_counters.proto"],
+)
+
+java_proto_library(
+ name = "histogram_counters_java_proto",
+ deps = [":histogram_counters_proto"],
+)
+
+cc_proto_library(
+ name = "histogram_counters_cc_proto",
+ deps = [":histogram_counters_proto"],
+)
+
+# Protocol buffers for FL Runner. These do not get serialized to disk.
+tf_proto_library(
+ name = "fl_runner_proto",
+ srcs = ["fl_runner.proto"],
+ protodeps = [
+ "//fcp/client/engine:engine_proto",
+ "@org_tensorflow//tensorflow/core:protos_all",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+java_proto_library(
+ name = "fl_runner_java_proto",
+ visibility = ["//visibility:public"],
+ deps = [":fl_runner_proto"],
+)
+
+# Allowing to refer to the cc library generated by the rule above in usual way:
+alias(
+ name = "fl_runner_cc_proto",
+ actual = "fl_runner_proto_cc",
+ visibility = ["//visibility:public"],
+)
+
+# Protocol buffers for logging. Those get serialized.
+proto_library(
+ name = "diag_codes_proto",
+ srcs = ["diag_codes.proto"],
+)
+
+java_proto_library(
+ name = "diag_codes_java_proto",
+ deps = [":diag_codes_proto"],
+)
+
+cc_proto_library(
+ name = "diag_codes_cc_proto",
+ deps = [":diag_codes_proto"],
+)
+
+# Protocol buffers for providing example data to client_runner_main.
+proto_library(
+ name = "client_runner_example_data_proto",
+ testonly = True,
+ srcs = ["client_runner_example_data.proto"],
+)
+
+cc_proto_library(
+ name = "client_runner_example_data_cc_proto",
+ testonly = True,
+ deps = [":client_runner_example_data_proto"],
+)
+
+py_proto_library(
+ name = "client_runner_example_data_py_pb2",
+ testonly = True,
+ deps = [":client_runner_example_data_proto"],
+)
+
+# --------------------------------------------------------------------
+# selector_context.proto
+
+proto_library(
+ name = "selector_context_proto",
+ srcs = ["selector_context.proto"],
+ deps = [
+ "@com_google_protobuf//:timestamp_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "selector_context_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":selector_context_proto"],
+)
+
+java_proto_library(
+ name = "selector_context_java_proto",
+ visibility = ["//visibility:public"],
+ deps = [":selector_context_proto"],
+)
+
+# --------------------------------------------------------------------
+# example_query_result.proto
+
+proto_library(
+ name = "example_query_result_proto",
+ srcs = ["example_query_result.proto"],
+)
+
+java_proto_library(
+ name = "example_query_result_java_proto",
+ deps = [":example_query_result_proto"],
+)
+
+cc_proto_library(
+ name = "example_query_result_cc_proto",
+ deps = [":example_query_result_proto"],
+)
diff --git a/fcp/client/README.md b/fcp/client/README.md
new file mode 100644
index 0000000..e6309ab
--- /dev/null
+++ b/fcp/client/README.md
@@ -0,0 +1,17 @@
+# Federated Computations Client
+
+This directory contains the portable client implementation of Google's platform
+for federated and local computations. A final build of the client will consist
+of
+
+1. The portable core functionality provided by this directory - `:fl_runner`
+ and`:lc_runner` for running federated and local computations. This code
+ contains the network stack and model / query interpreter.
+1. Platform-dependent implementations of the `:interfaces` target. This allows
+ to inject dependencies for e.g. telemetry, attestation, flag-guarding,
+ access to example stores etc.
+
+The stand-alone binary `:client_runner_main` provides a bare bones example of a
+federated computation client. Most practical implementations will wrap the calls
+to the client in a scheduler that respects device constraints and the returned
+retry window.
diff --git a/fcp/client/cache/BUILD b/fcp/client/cache/BUILD
new file mode 100644
index 0000000..d96b16f
--- /dev/null
+++ b/fcp/client/cache/BUILD
@@ -0,0 +1,124 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "temp_files",
+ srcs = ["temp_files.cc"],
+ hdrs = ["temp_files.h"],
+ deps = [
+ "//fcp/base",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ ],
+)
+
+cc_test(
+ name = "temp_files_test",
+ srcs = ["temp_files_test.cc"],
+ deps = [
+ ":temp_files",
+ "//fcp/client:test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+proto_library(
+ name = "cache_manifest_proto",
+ srcs = ["cache_manifest.proto"],
+ deps = [
+ "@com_google_protobuf//:any_proto",
+ "@com_google_protobuf//:timestamp_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "cache_manifest_cc_proto",
+ deps = [":cache_manifest_proto"],
+)
+
+cc_library(
+ name = "resource_cache",
+ hdrs = ["resource_cache.h"],
+ deps = [
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "file_backed_resource_cache",
+ srcs = [
+ "file_backed_resource_cache.cc",
+ ],
+ hdrs = [
+ "file_backed_resource_cache.h",
+ ],
+ deps = [
+ ":cache_manifest_cc_proto",
+ ":resource_cache",
+ "//fcp/base",
+ "//fcp/base:clock",
+ "//fcp/base:time_util",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "@com_google_absl//absl/cleanup",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ "@protodatastore_cpp//protostore:file-storage",
+ "@protodatastore_cpp//protostore:proto-data-store",
+ ],
+)
+
+cc_test(
+ name = "file_backed_resource_cache_test",
+ srcs = ["file_backed_resource_cache_test.cc"],
+ deps = [
+ ":file_backed_resource_cache",
+ "//fcp/base",
+ "//fcp/base:simulated_clock",
+ "//fcp/client:selector_context_cc_proto",
+ "//fcp/client:test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "test_helpers",
+ testonly = 1,
+ hdrs = ["test_helpers.h"],
+ deps = [
+ ":resource_cache",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/client/cache/cache_manifest.proto b/fcp/client/cache/cache_manifest.proto
new file mode 100644
index 0000000..8a4ff29
--- /dev/null
+++ b/fcp/client/cache/cache_manifest.proto
@@ -0,0 +1,41 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client.cache;
+
+import "google/protobuf/any.proto";
+import "google/protobuf/timestamp.proto";
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+// Maps cache IDs to CachedResource protos containing metadata and
+// the file name of the cached resource.
+message CacheManifest {
+ // A map of `cache_id` to `CachedResource`.
+ map<string, CachedResource> cache = 1;
+}
+
+message CachedResource {
+ // Name of the file holding the cached resource.
+ string file_name = 1;
+ // Serialized metadata proto. This proto should be small.
+ google.protobuf.Any metadata = 2;
+ // Timestamp of when the cached resource should be deleted.
+ google.protobuf.Timestamp expiry_time = 3;
+ // Timestamp of when the cached resource was last accessed.
+ google.protobuf.Timestamp last_accessed_time = 4;
+}
diff --git a/fcp/client/cache/file_backed_resource_cache.cc b/fcp/client/cache/file_backed_resource_cache.cc
new file mode 100644
index 0000000..a6e7034
--- /dev/null
+++ b/fcp/client/cache/file_backed_resource_cache.cc
@@ -0,0 +1,500 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/cache/file_backed_resource_cache.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <filesystem>
+#include <fstream>
+#include <functional>
+#include <map>
+#include <memory>
+#include <optional>
+#include <set>
+#include <string>
+#include <system_error> // NOLINT
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "google/protobuf/timestamp.pb.h"
+#include "absl/cleanup/cleanup.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/cache/cache_manifest.pb.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "protostore/file-storage.h"
+#include "protostore/proto-data-store.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+
+constexpr char kCacheManifestFileName[] = "cache_manifest.pb";
+constexpr char kParentDir[] = "fcp";
+// Cached files will be saved in <cache directory>/fcp/cache.
+constexpr char kCacheDir[] = "cache";
+
+absl::StatusOr<CacheManifest> FileBackedResourceCache::ReadInternal() {
+ absl::StatusOr<const CacheManifest*> data = pds_->Read();
+ if (data.ok()) {
+ return *data.value();
+ }
+ log_manager_.LogDiag(ProdDiagCode::RESOURCE_CACHE_MANIFEST_READ_FAILED);
+ // Ignore the status from DeleteManifest() even if it's an error, and bubble
+ // up the status from pds. We call DeleteManifest() here instead of
+ // Initialize(), as Initialize() calls ReadInternal(), potentially causing
+ // infinite recursion. This means that any resources that were tracked by the
+ // deleted manifest will not be cleaned up until the next time Initialize() is
+ // called.
+ auto ignored_status = DeleteManifest();
+ if (!ignored_status.ok()) {
+ FCP_LOG(INFO) << "Failed to delete manifest: " << ignored_status.ToString();
+ }
+ return absl::InternalError(
+ absl::StrCat("Failed to read from database, with error message: ",
+ data.status().message()));
+}
+
+absl::Status FileBackedResourceCache::WriteInternal(
+ std::unique_ptr<CacheManifest> manifest) {
+ absl::Status status = pds_->Write(std::move(manifest));
+ if (!status.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::RESOURCE_CACHE_MANIFEST_WRITE_FAILED);
+ // Ignore the status returned by DeleteManifest even if it's an error and
+ // instead return the status from pds. We call DeleteManifest() here instead
+ // of Initialize(), as Initialize() calls WriteInternal(), potentially
+ // causing infinite recursion. This means that any resources that were
+ // tracked by the deleted manifest will not be cleaned up until the next
+ // time Initialize() is called.
+ auto ignored_status = DeleteManifest();
+ if (!ignored_status.ok()) {
+ FCP_LOG(INFO) << "Failed to delete manifest: "
+ << ignored_status.ToString();
+ }
+ }
+ return status;
+}
+
+absl::StatusOr<std::unique_ptr<FileBackedResourceCache>>
+FileBackedResourceCache::Create(absl::string_view base_dir,
+ absl::string_view cache_dir,
+ LogManager* log_manager, fcp::Clock* clock,
+ int64_t max_cache_size_bytes) {
+ // Create <cache root>/fcp.
+ // Unfortunately NDK's flavor of std::filesystem::path does not support using
+ // absl::string_view.
+ std::filesystem::path cache_root_path((std::string(cache_dir)));
+ if (!cache_root_path.is_absolute()) {
+ log_manager->LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CACHE_ROOT_PATH_NOT_ABSOLUTE);
+ return absl::InvalidArgumentError(
+ absl::StrCat("The provided path: ", cache_dir,
+ " is invalid. The path must be absolute"));
+ }
+ std::filesystem::path cache_dir_path =
+ cache_root_path / kParentDir / kCacheDir;
+ std::error_code error;
+ std::filesystem::create_directories(cache_dir_path, error);
+ if (error.value() != 0) {
+ log_manager->LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_FAILED_TO_CREATE_CACHE_DIR);
+ return absl::InternalError(absl::StrCat(
+ "Failed to create FileBackedResourceCache cache directory ",
+ cache_dir_path.string()));
+ }
+ // Create <files root>/fcp/cache_manifest.pb.s
+ std::filesystem::path manifest_path((std::string(base_dir)));
+ if (!manifest_path.is_absolute()) {
+ log_manager->LogDiag(ProdDiagCode::RESOURCE_CACHE_INVALID_MANIFEST_PATH);
+ return absl::InvalidArgumentError(
+ absl::StrCat("The provided path: ", manifest_path.string(),
+ " is invalid. The path must start with \"/\""));
+ }
+ manifest_path /= kParentDir;
+ std::filesystem::create_directories(manifest_path, error);
+ if (error.value() != 0) {
+ log_manager->LogDiag(RESOURCE_CACHE_FAILED_TO_CREATE_MANIFEST_DIR);
+ return absl::InternalError(
+ absl::StrCat("Failed to create directory ", manifest_path.string()));
+ }
+ manifest_path /= kCacheManifestFileName;
+
+ auto file_storage = std::make_unique<protostore::FileStorage>();
+ auto pds = std::make_unique<protostore::ProtoDataStore<CacheManifest>>(
+ *file_storage, manifest_path.string());
+ std::unique_ptr<FileBackedResourceCache> resource_cache =
+ absl::WrapUnique(new FileBackedResourceCache(
+ std::move(pds), std::move(file_storage), cache_dir_path,
+ manifest_path, log_manager, clock, max_cache_size_bytes));
+ {
+ absl::MutexLock lock(&resource_cache->mutex_);
+ FCP_RETURN_IF_ERROR(resource_cache->Initialize());
+ }
+
+ return resource_cache;
+}
+
+absl::Status FileBackedResourceCache::Put(absl::string_view cache_id,
+ const absl::Cord& resource,
+ const google::protobuf::Any& metadata,
+ absl::Duration max_age) {
+ absl::MutexLock lock(&mutex_);
+
+ if (resource.size() > max_cache_size_bytes_ / 2) {
+ return absl::ResourceExhaustedError(absl::StrCat(cache_id, " too large"));
+ }
+
+ FCP_ASSIGN_OR_RETURN(CacheManifest manifest, ReadInternal());
+ FCP_RETURN_IF_ERROR(CleanUp(resource.size(), manifest));
+
+ std::string cache_id_str(cache_id);
+ std::filesystem::path cached_file_path = cache_dir_path_ / cache_id_str;
+ absl::Time now = clock_.Now();
+ absl::Time expiry = now + max_age;
+ CachedResource cached_resource;
+ cached_resource.set_file_name(cache_id_str);
+ *cached_resource.mutable_metadata() = metadata;
+ *cached_resource.mutable_expiry_time() =
+ TimeUtil::ConvertAbslToProtoTimestamp(expiry);
+ *cached_resource.mutable_last_accessed_time() =
+ TimeUtil::ConvertAbslToProtoTimestamp(now);
+
+ // Write the manifest back to disk before we write the file.
+ manifest.mutable_cache()->insert({cache_id_str, cached_resource});
+ FCP_RETURN_IF_ERROR(
+ WriteInternal(std::make_unique<CacheManifest>(std::move(manifest))));
+
+ // Write file if it doesn't exist.
+ std::error_code exists_error;
+ bool cached_file_exists =
+ std::filesystem::exists(cached_file_path, exists_error);
+ if (exists_error.value() != 0) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_PUT_FAILED_TO_CHECK_IF_FILE_EXISTS);
+ return absl::InternalError(absl::StrCat(
+ "Failed to check if cached resource already exists with error code: ",
+ exists_error.value()));
+ }
+ if (!cached_file_exists) {
+ auto status = WriteCordToFile(cached_file_path.string(), resource);
+ if (!status.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::RESOURCE_CACHE_RESOURCE_WRITE_FAILED);
+ return status;
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+FileBackedResourceCache::Get(absl::string_view cache_id,
+ std::optional<absl::Duration> max_age) {
+ // By default, set up a "CACHE_MISS" diag code to be logged when this method
+ // exits.
+ DebugDiagCode diag_code = DebugDiagCode::RESOURCE_CACHE_MISS;
+ absl::Cleanup diag_code_logger = [this, &diag_code] {
+ log_manager_.LogDiag(diag_code);
+ };
+ absl::MutexLock lock(&mutex_);
+ FCP_ASSIGN_OR_RETURN(CacheManifest manifest, ReadInternal());
+
+ std::string cache_id_str(cache_id);
+ if (!manifest.cache().contains(cache_id_str)) {
+ return absl::NotFoundError(absl::StrCat(cache_id, " not found"));
+ }
+ CachedResource cached_resource = manifest.cache().at(cache_id_str);
+ std::filesystem::path cached_file_path = cache_dir_path_ / cache_id_str;
+ google::protobuf::Any metadata = cached_resource.metadata();
+ absl::Time now = clock_.Now();
+ *cached_resource.mutable_last_accessed_time() =
+ TimeUtil::ConvertAbslToProtoTimestamp(now);
+ if (max_age.has_value()) {
+ absl::Time expiry = now + max_age.value();
+ *cached_resource.mutable_expiry_time() =
+ TimeUtil::ConvertAbslToProtoTimestamp(expiry);
+ }
+
+ absl::StatusOr<absl::Cord> contents =
+ ReadFileToCord(cached_file_path.string());
+ if (!contents.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::RESOURCE_CACHE_RESOURCE_READ_FAILED);
+ manifest.mutable_cache()->erase(cache_id_str);
+ std::error_code error;
+ std::filesystem::remove(cached_file_path, error);
+ if (error.value() != 0) {
+ return absl::InternalError(error.message());
+ }
+ // Treat as not found, the resource should be fetched again.
+ return absl::NotFoundError(absl::StrCat(cache_id, " not found"));
+ }
+
+ manifest.mutable_cache()->erase(cache_id_str);
+ manifest.mutable_cache()->insert({cache_id_str, cached_resource});
+
+ absl::Status status =
+ WriteInternal(std::make_unique<CacheManifest>(std::move(manifest)));
+ if (!status.ok()) return status;
+
+ // We've reached the end, this is a hit! The absl::Cleanup above has a
+ // reference to diag_code, so we update it to CACHE_HIT here.
+ diag_code = DebugDiagCode::RESOURCE_CACHE_HIT;
+ return FileBackedResourceCache::ResourceAndMetadata{*contents, metadata};
+}
+
+absl::Status FileBackedResourceCache::Initialize() {
+ absl::string_view errorInInitializePrefix = "Error in initialize: ";
+ std::string pds_path = manifest_path_.string();
+ if (!std::filesystem::exists(pds_path)) {
+ std::ofstream ofs(pds_path);
+ }
+ absl::StatusOr<int64_t> file_size = storage_->GetFileSize(pds_path);
+ if (!file_size.ok()) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_INIT_FAILED_TO_GET_MANIFEST_SIZE);
+ return absl::InternalError(absl::StrCat(
+ errorInInitializePrefix, "Failed to get file size of cache manifest: ",
+ file_size.status().message()));
+ }
+ // Initialize db if it's not initialized.
+ if (*file_size == 0) {
+ auto status = WriteInternal(std::make_unique<CacheManifest>());
+ if (!status.ok()) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_INIT_FAILED_TO_INITIALIZE_MANIFEST);
+ return absl::InternalError(absl::StrCat(
+ errorInInitializePrefix,
+ "Failed to initialize cache manifest for the first time: ",
+ status.message()));
+ }
+ }
+ // Then run CleanUp. Even if our manifest was empty we still might have
+ // stranded cache files to delete, i.e. in the case that the manifest was
+ // deleted but the cache dir was not deleted.
+ absl::StatusOr<CacheManifest> manifest = ReadInternal();
+ if (!manifest.ok()) {
+ return absl::InternalError(
+ absl::StrCat(errorInInitializePrefix,
+ "Failed to read manifest: ", manifest.status().message()));
+ }
+ auto cleanup_status = CleanUp(std::nullopt, *manifest);
+ if (!cleanup_status.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::RESOURCE_CACHE_INIT_FAILED_CLEANUP);
+ return absl::InternalError(absl::StrCat(
+ errorInInitializePrefix,
+ "Failed to clean up resource cache: ", cleanup_status.message()));
+ }
+ auto write_status = WriteInternal(std::make_unique<CacheManifest>(*manifest));
+ if (!write_status.ok()) {
+ return absl::InternalError(absl::StrCat(
+ errorInInitializePrefix,
+ "Failed to write cleaned up resource cache: ", write_status.message()));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status FileBackedResourceCache::CleanUp(
+ std::optional<int64_t> reserved_space_bytes, CacheManifest& manifest) {
+ // Expire any cached resources past their expiry.
+ // Clean up any files that are not tracked in the manifest.
+ // Clean up any manifest entries that point to nonexistent files.
+
+ // In order to delete files we don't track in the CacheManifest (or that
+ // became untracked due to a crash), fill cache_dir_files with every file in
+ // the cache dir. We'll then remove any file not actively tracked in the cache
+ // manifest.
+ std::set<std::filesystem::path> cache_dir_files;
+
+ // We don't have any subdirectories in the cache, so we can use a directory
+ // iterator.
+ std::error_code directory_error;
+ auto cache_dir_iterator =
+ std::filesystem::directory_iterator(cache_dir_path_, directory_error);
+ if (directory_error.value() != 0) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_ITERATE_OVER_CACHE_DIR);
+ return absl::InternalError(absl::StrCat(
+ "Error iterating over cache dir. Error code: ", directory_error.value(),
+ " message: ", directory_error.message()));
+ }
+ for (auto& file : cache_dir_iterator) {
+ cache_dir_files.insert(cache_dir_path_ / file);
+ }
+
+ int64_t max_allowed_size_bytes = max_cache_size_bytes_;
+ max_allowed_size_bytes -= reserved_space_bytes.value_or(0);
+
+ std::set<std::string> cache_ids_to_delete;
+ absl::Time now = clock_.Now();
+ for (const auto& [id, resource] : manifest.cache()) {
+ absl::Time expiry =
+ TimeUtil::ConvertProtoToAbslTime(resource.expiry_time());
+ std::filesystem::path resource_file =
+ cache_dir_path_ / resource.file_name();
+ // It's possible that this manifest entry points at a file in the cache dir
+ // that doesn't exist, i.e. due to a failed write. In this case, the entry
+ // should be deleted as well. cache_dir_files should contain a scan of the
+ // entire cache dir, so the file pointed at by this manifest entry should be
+ // there.
+ bool cached_resource_exists =
+ cache_dir_files.find(resource_file) != cache_dir_files.end();
+ if (expiry < now || !cached_resource_exists) {
+ cache_ids_to_delete.insert(id);
+ } else {
+ cache_dir_files.erase(resource_file);
+ }
+ }
+
+ // Then delete CacheManifest entries.
+ for (const auto& cache_id : cache_ids_to_delete) {
+ manifest.mutable_cache()->erase(cache_id);
+ }
+
+ // Then delete files.
+ absl::Status filesystem_status = absl::OkStatus();
+ for (const auto& file : cache_dir_files) {
+ std::error_code remove_error;
+ std::filesystem::remove(file, remove_error);
+ // We intentionally loop through all files and attempt to remove as many as
+ // we can, then return the first error we saw.
+ if (remove_error.value() != 0 && filesystem_status.ok()) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_DELETE_CACHED_FILE);
+ filesystem_status = absl::InternalError(absl::StrCat(
+ "Failed to delete file. Error code: ", remove_error.value(),
+ ", message: ", remove_error.message()));
+ }
+ }
+
+ FCP_RETURN_IF_ERROR(filesystem_status);
+
+ // If we still exceed the allowed size of the cache, delete entries until
+ // we're under the allowed size, sorted by least recently used.
+
+ // Build up a list of (cache_id, least recently used timestamp) and compute
+ // the total size of the cache.
+ std::vector<std::pair<std::string, absl::Time>> cache_id_lru;
+ cache_id_lru.reserve(manifest.cache().size());
+ uintmax_t cache_dir_size = 0;
+
+ for (const auto& [id, resource] : manifest.cache()) {
+ cache_id_lru.emplace_back(std::make_pair(
+ id, TimeUtil::ConvertProtoToAbslTime(resource.last_accessed_time())));
+ std::filesystem::path resource_file =
+ cache_dir_path_ / resource.file_name();
+ // We calculate the sum of tracked files instead of taking the file_size()
+ // of the cache directory, because the latter generally does not reflect the
+ // total size of the sum of all the files inside a directory.
+ std::error_code ignored_exists_error;
+ if (!std::filesystem::exists(resource_file, ignored_exists_error)) {
+ // We log that the manifest entry pointed at a file in the cache that
+ // doesn't exist, but otherwise continue. The next time the cache is
+ // initialized, the manifest entry will be cleaned up.
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_GET_FILE_SIZE);
+ continue;
+ }
+ std::error_code file_size_error;
+ std::uintmax_t size =
+ std::filesystem::file_size(resource_file, file_size_error);
+ // Loop through as many as we can and if there's an error, return the first
+ // error we saw.
+ if (file_size_error.value() != 0) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_GET_FILE_SIZE);
+ if (filesystem_status.ok()) {
+ filesystem_status = absl::InternalError(absl::StrCat(
+ "Error getting file size. Error code: ", file_size_error.value(),
+ ", message: ", file_size_error.message()));
+ }
+ // If the file exists, but we failed to get the file size for some reason,
+ // try to delete it then continue.
+ std::error_code ignored_remove_error;
+ std::filesystem::remove(resource_file, ignored_remove_error);
+ } else {
+ cache_dir_size += size;
+ }
+ }
+
+ FCP_RETURN_IF_ERROR(filesystem_status);
+
+ // Then, if the cache is bigger than the allowed size, delete entries ordered
+ // by least recently used until we're below the threshold.
+ if (cache_dir_size > max_allowed_size_bytes) {
+ std::sort(cache_id_lru.begin(), cache_id_lru.end(),
+ [](std::pair<std::string, absl::Time> first,
+ std::pair<std::string, absl::Time> second) -> bool {
+ // Sort by least recently used timestamp.
+ return first.second < second.second;
+ });
+ for (auto const& [cache_id, timestamp] : cache_id_lru) {
+ std::string id_to_remove = cache_id;
+ std::filesystem::path file_to_remove =
+ cache_dir_path_ / manifest.cache().at(id_to_remove).file_name();
+ manifest.mutable_cache()->erase(id_to_remove);
+ std::error_code remove_error;
+ uintmax_t file_size =
+ std::filesystem::file_size(file_to_remove, remove_error);
+ if (remove_error.value() != 0 && filesystem_status.ok()) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_GET_FILE_SIZE);
+ filesystem_status = absl::InternalError(absl::StrCat(
+ "Error getting file size. Error code: ", remove_error.value(),
+ ", message: ", remove_error.message()));
+ }
+ std::filesystem::remove(file_to_remove, remove_error);
+ if (remove_error.value() != 0 && filesystem_status.ok()) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_CLEANUP_FAILED_TO_GET_FILE_SIZE);
+ filesystem_status = absl::InternalError(absl::StrCat(
+ "Failed to delete file. Error code: ", remove_error.value(),
+ ", message: ", remove_error.message()));
+ }
+ cache_dir_size -= file_size;
+ if (cache_dir_size < max_allowed_size_bytes) break;
+ }
+ }
+
+ FCP_RETURN_IF_ERROR(filesystem_status);
+
+ return absl::OkStatus();
+}
+
+absl::Status FileBackedResourceCache::DeleteManifest() {
+ if (std::filesystem::exists(manifest_path_)) {
+ std::error_code error;
+ std::filesystem::remove(manifest_path_, error);
+ if (error.value() != 0) {
+ log_manager_.LogDiag(
+ ProdDiagCode::RESOURCE_CACHE_FAILED_TO_DELETE_MANIFEST);
+ return absl::InternalError(
+ absl::StrCat("Failed to delete manifest! error code: ", error.value(),
+ ", message: ", error.message()));
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/cache/file_backed_resource_cache.h b/fcp/client/cache/file_backed_resource_cache.h
new file mode 100644
index 0000000..20816a6
--- /dev/null
+++ b/fcp/client/cache/file_backed_resource_cache.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_CACHE_FILE_BACKED_RESOURCE_CACHE_H_
+#define FCP_CLIENT_CACHE_FILE_BACKED_RESOURCE_CACHE_H_
+
+#include <filesystem>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "fcp/base/clock.h"
+#include "fcp/client/cache/cache_manifest.pb.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/log_manager.h"
+#include "protostore/file-storage.h"
+#include "protostore/proto-data-store.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+
+/**
+ * A FileBackedResourceCache is a ResourceCache implementation where each
+ * resource payload is stored as an individual file in a directory, along with a
+ * ProtoDataStore manifest that tracks each entry.
+ *
+ * FileBackedResourceCache is thread safe.
+ */
+class FileBackedResourceCache : public ResourceCache {
+ public:
+ // The CacheManifest will be created in
+ // <base directory>/fcp/cache_manifest.pb.
+
+ // Factory method to create FileBackedResourceCache. The provided cache dir is
+ // the absolute path for storing cached files, and the provided base dir is
+ // the absolute path for longer term storage. FileBackedResourceCache will
+ // attempt to create subdirectories and files, so the directory must grant
+ // read/write access.
+ //
+ // A FileBackedResourceCache will not store any resources larger than
+ // `max_cache_size_bytes` / 2.
+ //
+ // Deletes any stored resources past expiry.
+ static absl::StatusOr<std::unique_ptr<FileBackedResourceCache>> Create(
+ absl::string_view base_dir, absl::string_view cache_dir,
+ LogManager* log_manager, fcp::Clock* clock, int64_t max_cache_size_bytes);
+
+ // Implementation of `ResourceCache::Put`.
+ //
+ // If storing `resource` pushes the size of the cache directory over
+ // `max_cache_size_bytes`, entries with the oldest last_accessed_time will be
+ // deleted until the directory is under `max_cache_size_bytes` Returns Ok on
+ // success. On error, returns:
+ // - INTERNAL - unexpected error.
+ // - INVALID_ARGUMENT - if max_age is in the past.
+ // - RESOURCE_EXHAUSTED - if resource bytes is bigger than
+ // `max_cache_size_bytes` / 2.
+ absl::Status Put(absl::string_view cache_id, const absl::Cord& resource,
+ const google::protobuf::Any& metadata,
+ absl::Duration max_age) override ABSL_LOCKS_EXCLUDED(mutex_);
+
+ // Implementation of `ResourceCache::Get`.
+ absl::StatusOr<ResourceAndMetadata> Get(absl::string_view cache_id,
+ std::optional<absl::Duration> max_age)
+ override ABSL_LOCKS_EXCLUDED(mutex_);
+
+ ~FileBackedResourceCache() override = default;
+
+ // FileBackedResourceCache is neither copyable nor movable.
+ FileBackedResourceCache(const FileBackedResourceCache&) = delete;
+ FileBackedResourceCache& operator=(const FileBackedResourceCache&) = delete;
+
+ private:
+ FileBackedResourceCache(
+ std::unique_ptr<protostore::ProtoDataStore<CacheManifest>> pds,
+ std::unique_ptr<protostore::FileStorage> storage,
+ std::filesystem::path cache_dir_path, std::filesystem::path manifest_path,
+ LogManager* log_manager, Clock* clock, const int64_t max_cache_size_bytes)
+ : storage_(std::move(storage)),
+ pds_(std::move(pds)),
+ cache_dir_path_(cache_dir_path),
+ manifest_path_(manifest_path),
+ log_manager_(*log_manager),
+ clock_(*clock),
+ max_cache_size_bytes_(max_cache_size_bytes) {}
+
+ absl::StatusOr<CacheManifest> ReadInternal()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ absl::Status WriteInternal(std::unique_ptr<CacheManifest> manifest)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Initializes the CacheManifest ProtoDataStore db if necessesary, then runs
+ // CleanUp().
+ absl::Status Initialize() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Deletes the cache manifest.
+ absl::Status DeleteManifest() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // TTLs any cached resources stored past their expiry, then deletes any
+ // stranded files without matching manifest entries, and any entries without
+ // matching resource files. If `reserved_space_bytes` is set, cleans up
+ // resources sorted by least recently used until the cache size is less than
+ // `max_cache_size_bytes_ - reserved_space_bytes`.
+ // This modifies the passed `manifest`.
+ absl::Status CleanUp(std::optional<int64_t> reserved_space_bytes,
+ CacheManifest& manifest)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Unused, but must be kept alive for longer than pds_.
+ std::unique_ptr<protostore::FileStorage> storage_;
+ std::unique_ptr<protostore::ProtoDataStore<CacheManifest>> pds_
+ ABSL_GUARDED_BY(mutex_);
+ const std::filesystem::path cache_dir_path_;
+ const std::filesystem::path manifest_path_;
+ LogManager& log_manager_;
+ Clock& clock_;
+ const int64_t max_cache_size_bytes_;
+ absl::Mutex mutex_;
+};
+
+// Used by the class and in tests only.
+namespace internal {} // namespace internal
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_CACHE_FILE_BACKED_RESOURCE_CACHE_H_
diff --git a/fcp/client/cache/file_backed_resource_cache_test.cc b/fcp/client/cache/file_backed_resource_cache_test.cc
new file mode 100644
index 0000000..aa3d651
--- /dev/null
+++ b/fcp/client/cache/file_backed_resource_cache_test.cc
@@ -0,0 +1,569 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/cache/file_backed_resource_cache.h"
+
+#include <cstddef>
+#include <filesystem>
+#include <fstream>
+#include <functional>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/simulated_clock.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+namespace {
+
+constexpr char kKey1[] = "1";
+absl::Cord Resource1() { return absl::Cord("stream RENAISSANCE by Beyoncé"); }
+constexpr char kKey2[] = "2";
+absl::Cord Resource2() { return absl::Cord("stream PURE/HONEY by Beyoncé"); }
+constexpr char kKey3[] = "3";
+absl::Cord Resource3() {
+ return absl::Cord("A third resource?? In this economy");
+}
+SelectorContext SampleStoredMetadata() {
+ SelectorContext sample_stored_metadata;
+ sample_stored_metadata.mutable_computation_properties()->set_session_name(
+ "test");
+ return sample_stored_metadata;
+}
+google::protobuf::Any Metadata() {
+ google::protobuf::Any metadata;
+ metadata.PackFrom(SampleStoredMetadata());
+ return metadata;
+}
+absl::Duration kMaxAge = absl::Hours(1);
+int64_t kMaxCacheSizeBytes = 10000000;
+
+int NumFilesInDir(std::filesystem::path dir) {
+ int num_files_in_dir = 0;
+ for ([[maybe_unused]] auto& de : std::filesystem::directory_iterator(dir)) {
+ num_files_in_dir++;
+ }
+ return num_files_in_dir;
+}
+
+class FileBackedResourceCacheTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ root_cache_dir_ = testing::TempDir();
+ std::filesystem::path root_cache_dir(root_cache_dir_);
+ cache_dir_ = root_cache_dir / "fcp" / "cache";
+ root_files_dir_ = testing::TempDir();
+ std::filesystem::path root_files_dir(root_files_dir_);
+ manifest_path_ = root_files_dir / "fcp" / "cache_manifest.pb";
+ }
+
+ void TearDown() override {
+ std::filesystem::remove_all(root_cache_dir_);
+ std::filesystem::remove_all(root_files_dir_);
+ }
+
+ testing::StrictMock<MockLogManager> log_manager_;
+ SimulatedClock clock_;
+ std::string root_cache_dir_;
+ std::string root_files_dir_;
+ std::filesystem::path cache_dir_;
+ std::filesystem::path manifest_path_;
+};
+
+TEST_F(FileBackedResourceCacheTest, FailToCreateParentDirectoryInBaseDir) {
+ EXPECT_CALL(
+ log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_FAILED_TO_CREATE_MANIFEST_DIR));
+ ASSERT_THAT(
+ FileBackedResourceCache::Create("/proc/0", root_cache_dir_, &log_manager_,
+ &clock_, kMaxCacheSizeBytes),
+ IsCode(INTERNAL));
+}
+
+TEST_F(FileBackedResourceCacheTest, FailToCreateParentDirectoryInCacheDir) {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_FAILED_TO_CREATE_CACHE_DIR));
+ ASSERT_THAT(
+ FileBackedResourceCache::Create(root_files_dir_, "/proc/0", &log_manager_,
+ &clock_, kMaxCacheSizeBytes),
+ IsCode(INTERNAL));
+}
+
+TEST_F(FileBackedResourceCacheTest, InvalidBaseDirRelativePath) {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_INVALID_MANIFEST_PATH));
+ ASSERT_THAT(FileBackedResourceCache::Create("relative/base", root_cache_dir_,
+ &log_manager_, &clock_,
+ kMaxCacheSizeBytes),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(FileBackedResourceCacheTest, InvalidCacheDirRelativePath) {
+ EXPECT_CALL(
+ log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_CACHE_ROOT_PATH_NOT_ABSOLUTE));
+ ASSERT_THAT(FileBackedResourceCache::Create(root_files_dir_, "relative/cache",
+ &log_manager_, &clock_,
+ kMaxCacheSizeBytes),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(FileBackedResourceCacheTest, SuccessfulInitialization) {
+ ASSERT_OK(FileBackedResourceCache::Create(root_files_dir_, root_cache_dir_,
+ &log_manager_, &clock_,
+ kMaxCacheSizeBytes));
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheFile) {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata> cached_resource =
+ (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_OK(cached_resource);
+ ASSERT_EQ(Resource1(), (*cached_resource).resource);
+ ASSERT_EQ(Metadata().GetTypeName(),
+ (*cached_resource).metadata.GetTypeName());
+ SelectorContext stored_metadata;
+ (*cached_resource).metadata.UnpackTo(&stored_metadata);
+ ASSERT_THAT(SampleStoredMetadata(), EqualsProto(stored_metadata));
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheFileCloseReinitializeFileStillCached) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // Advance the clock a little bit
+ clock_.AdvanceTime(absl::Minutes(1));
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_OK(cached_resource);
+ ASSERT_EQ(Resource1(), (*cached_resource).resource);
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheTooBigFileReturnsResourceExhausted) {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ (int64_t)(Resource1().size() / 2));
+ ASSERT_OK(resource_cache);
+ ASSERT_THAT(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)),
+ IsCode(RESOURCE_EXHAUSTED));
+}
+
+TEST_F(FileBackedResourceCacheTest,
+ UnreadableManifestReturnsInternalButIsThenReadable) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // There should be the one file we cached.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+
+ // Write some garbage to the manifest.
+ {
+ std::ofstream ofs(manifest_path_, std::ofstream::trunc);
+ ofs << "garbage garbage garbage";
+ }
+
+ {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_MANIFEST_READ_FAILED));
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_THAT(resource_cache, IsCode(INTERNAL));
+ }
+
+ // Failing to read the manifest should have deleted it.
+ ASSERT_EQ(std::filesystem::exists(manifest_path_), false);
+ // But there will still be files in the cache dir. These files will be cleaned
+ // up the next time the cache is initialized.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+
+ // We should be able to create a new FileBackedResourceCache successfully
+ // since the garbage manifest was deleted.
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ // Initializing the cache should have deleted the untracked files in the
+ // cache dir.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 0);
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ ASSERT_THAT((*resource_cache)->Get(kKey1, std::nullopt), IsCode(NOT_FOUND));
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest,
+ UnreadableManifestReturnsInternalButIsThenWritable) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // There should be the one file we cached.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+
+ // Write some garbage to the manifest.
+ {
+ std::ofstream ofs(manifest_path_, std::ofstream::trunc);
+ ofs << "garbage garbage garbage";
+ }
+
+ {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::RESOURCE_CACHE_MANIFEST_READ_FAILED));
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_THAT(resource_cache, IsCode(INTERNAL));
+ }
+
+ // Failing to read the manifest should have deleted it.
+ ASSERT_EQ(std::filesystem::exists(manifest_path_), false);
+ // But there will still be files in the cache dir. These files will be cleaned
+ // up the next time the cache is initialized.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+
+ // We should be able to create a new FileBackedResourceCache successfully
+ // since it was reset.
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ // Initializing the cache should have deleted the untracked files in the
+ // cache dir.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 0);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, PutTwoFilesThenGetThem) {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)->Put(kKey1, Resource1(), Metadata(), kMaxAge));
+ ASSERT_OK((*resource_cache)->Put(kKey2, Resource2(), Metadata(), kMaxAge));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource1 = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_OK(cached_resource1);
+ ASSERT_EQ(Resource1(), (*cached_resource1).resource);
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource2 = (*resource_cache)->Get(kKey2, std::nullopt);
+ ASSERT_OK(cached_resource2);
+ ASSERT_EQ(Resource2(), (*cached_resource2).resource);
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheFileThenExpire) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)->Put(kKey1, Resource1(), Metadata(), kMaxAge));
+ }
+
+ // Advance the clock a little bit beyond max_age
+ clock_.AdvanceTime(kMaxAge + absl::Minutes(1));
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_THAT(cached_resource, IsCode(NOT_FOUND));
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, PutTwoFilesThenOneExpires) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)->Put(kKey1, Resource1(), Metadata(), kMaxAge));
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey2, Resource2(), Metadata(), kMaxAge * 2));
+ }
+
+ // Advance the clock a little bit beyond the first resource's expiry.
+
+ clock_.AdvanceTime(kMaxAge + absl::Minutes(1));
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource1 = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_THAT(cached_resource1, IsCode(NOT_FOUND));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource2 = (*resource_cache)->Get(kKey2, std::nullopt);
+ ASSERT_OK(cached_resource2);
+ ASSERT_EQ(Resource2(), (*cached_resource2).resource);
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheFileThenUpdateExpiry) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)->Put(kKey1, Resource1(), Metadata(), kMaxAge));
+ }
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ // Pass a new max_age when we Get the resource, updating its expiry time.
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, 6 * kMaxAge);
+ ASSERT_OK(cached_resource);
+ ASSERT_EQ(Resource1(), (*cached_resource).resource);
+ }
+
+ // Advance the clock. Even though we've now passed the original expiry, the
+ // resource should still be cached because we updated the expiry with the
+ // Get().
+ clock_.AdvanceTime(kMaxAge + absl::Minutes(5));
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ // Pass a new max_age when we Get the resource, updating its expiry time.
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, 6 * kMaxAge);
+ ASSERT_OK(cached_resource);
+ ASSERT_EQ(Resource1(), (*cached_resource).resource);
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, CacheExceedsMaxCacheSize) {
+ // Room for resource2 and resource3 but not quite enough for resource1 as
+ // well.
+ int64_t local_max_cache_size_bytes =
+ Resource2().size() + Resource3().size() + (Resource1().size() / 2);
+
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ local_max_cache_size_bytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ clock_.AdvanceTime(absl::Minutes(1));
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey2, Resource2(), Metadata(), absl::Hours(1)));
+ clock_.AdvanceTime(absl::Minutes(1));
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey3, Resource3(), Metadata(), absl::Hours(1)));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ ASSERT_OK((*resource_cache)->Get(kKey3, std::nullopt));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ ASSERT_OK((*resource_cache)->Get(kKey2, std::nullopt));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ ASSERT_THAT((*resource_cache)->Get(kKey1, std::nullopt), IsCode(NOT_FOUND));
+}
+
+TEST_F(FileBackedResourceCacheTest,
+ CacheExceedsMaxCacheSizeLeastRecentlyUsedDeleted) {
+ int64_t local_max_cache_size_bytes =
+ Resource1().size() + (Resource2().size() / 2) + Resource3().size();
+
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ local_max_cache_size_bytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ clock_.AdvanceTime(absl::Minutes(1));
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey2, Resource2(), Metadata(), absl::Hours(1)));
+ clock_.AdvanceTime(absl::Minutes(1));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ // Get resource1 so we update it's least recently used time before we put in
+ // resource3. This should cause resource2 to get deleted instead of resource1
+ // when we add resource3.
+ ASSERT_OK((*resource_cache)->Get(kKey1, std::nullopt));
+ clock_.AdvanceTime(absl::Minutes(1));
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey3, Resource3(), Metadata(), absl::Hours(1)));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ ASSERT_OK((*resource_cache)->Get(kKey3, std::nullopt));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ ASSERT_THAT((*resource_cache)->Get(kKey2, std::nullopt), IsCode(NOT_FOUND));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ ASSERT_OK((*resource_cache)->Get(kKey1, std::nullopt));
+}
+
+TEST_F(FileBackedResourceCacheTest, FileInCacheDirButNotInManifest) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // Delete the manifest!
+ std::filesystem::remove(manifest_path_);
+
+ // There should be the one file we cached.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 1);
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_THAT(cached_resource, IsCode(NOT_FOUND));
+ // The cache dir should also be empty, because we reinitialized the cache
+ // and there was an untracked file in it.
+ ASSERT_EQ(NumFilesInDir(cache_dir_), 0);
+ }
+}
+
+// Covers the case where a user manually deletes the app's cache dir.
+TEST_F(FileBackedResourceCacheTest, FileInManifestButRootCacheDirDeleted) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // Delete the entire cache dir from the root.
+ std::filesystem::remove_all(root_cache_dir_);
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ // Now we should gracefully fail even though the file is in the manifest but
+ // not on disk.
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_THAT(cached_resource, IsCode(NOT_FOUND));
+ }
+}
+
+TEST_F(FileBackedResourceCacheTest, FileInManifestButNotInCacheDir) {
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK(
+ (*resource_cache)->Put(kKey1, Resource1(), Metadata(), absl::Hours(1)));
+ }
+
+ // Delete the file we just cached.
+ std::filesystem::remove(cache_dir_ / kKey1);
+
+ {
+ auto resource_cache = FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ // Now we should gracefully fail even though the file is in the manifest but
+ // not on disk.
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ absl::StatusOr<FileBackedResourceCache::ResourceAndMetadata>
+ cached_resource = (*resource_cache)->Get(kKey1, std::nullopt);
+ ASSERT_THAT(cached_resource, IsCode(NOT_FOUND));
+ }
+}
+
+} // namespace
+} // namespace cache
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/cache/resource_cache.h b/fcp/client/cache/resource_cache.h
new file mode 100644
index 0000000..2b4d392
--- /dev/null
+++ b/fcp/client/cache/resource_cache.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_CACHE_RESOURCE_CACHE_H_
+#define FCP_CLIENT_CACHE_RESOURCE_CACHE_H_
+
+#include <optional>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+
+/**
+ * A ResourceCache is an interface for a cache that stores resources (entries)
+ * for a certain duration. A resource consists of an absl::Cord payload and
+ * accompanying metadata, keyed by a string ID.
+ */
+class ResourceCache {
+ public:
+ struct ResourceAndMetadata {
+ absl::Cord resource;
+ google::protobuf::Any metadata;
+ };
+
+ virtual ~ResourceCache() = default;
+
+ // Stores resource`under key cache_id. Will be deleted after now + max_age. If
+ // cache_id already exists, the corresponding resource and metadata will be
+ // overwritten. Metadata will be returned along with the resource when Get()
+ // is called. A Resource is *not* guaranteed to be cached until now + max_age,
+ // implementations may choose to delete resources earlier if needed. Returns
+ // Ok on success On error, returns
+ // - INTERNAL - unexpected error.
+ // - INVALID_ARGUMENT - if max_age is in the past.
+ // - RESOURCE_EXHAUSTED - if the resource is too big too be cached.
+ virtual absl::Status Put(absl::string_view cache_id,
+ const absl::Cord& resource,
+ const google::protobuf::Any& metadata,
+ absl::Duration max_age) = 0;
+
+ // Returns resource along with caller-provided stored metadata.
+ // If max_age is set, the stored max_age for this cache_id will be updated.
+ // Returns Ok on success
+ // On error, returns
+ // - INTERNAL - unexpected error.
+ // - NOT_FOUND - if cache_id not in ResourceCache
+ virtual absl::StatusOr<ResourceAndMetadata> Get(
+ absl::string_view cache_id, std::optional<absl::Duration> max_age) = 0;
+};
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_CACHE_RESOURCE_CACHE_H_
diff --git a/fcp/client/cache/temp_files.cc b/fcp/client/cache/temp_files.cc
new file mode 100644
index 0000000..1354109
--- /dev/null
+++ b/fcp/client/cache/temp_files.cc
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/cache/temp_files.h"
+
+#include <sys/file.h>
+#include <unistd.h>
+
+#include <cstdlib>
+#include <filesystem>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <system_error> // NOLINT
+
+#include "fcp/base/monitoring.h"
+#include "fcp/client/diag_codes.pb.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+namespace {
+
+absl::Status DeleteFilesInDirectory(const std::filesystem::path& directory) {
+ if (!std::filesystem::exists(directory)) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Directory does not exist: ", directory.string()));
+ }
+ absl::Status status = absl::OkStatus();
+ // Note this only iterates through the top level directory and will not
+ // traverse subdirectories.
+ for (auto& de : std::filesystem::directory_iterator(directory)) {
+ std::error_code error;
+ // Save the first error, but attempt to delete the other files.
+ if (!std::filesystem::remove(de.path(), error)) {
+ if (status.ok()) {
+ status = absl::InternalError(absl::StrCat(
+ "Failed to delete file with error code: ", error.value()));
+ }
+ }
+ }
+ return status;
+}
+
+} // namespace
+
+absl::StatusOr<std::unique_ptr<TempFiles>> TempFiles::Create(
+ const std::string& cache_dir, LogManager* log_manager) {
+ std::filesystem::path root_path(cache_dir);
+ if (!root_path.is_absolute()) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("The provided path: ", cache_dir,
+ "is invalid. The path must start with \"/\""));
+ }
+
+ // Create fcp parent dir in the passed root dir.
+ std::filesystem::path fcp_base_dir = root_path / kParentDir;
+ std::error_code error;
+ std::filesystem::create_directories(fcp_base_dir, error);
+ if (error.value() != 0) {
+ return absl::InternalError(absl::StrCat(
+ "Failed to create TempFiles base directory ",
+ fcp_base_dir.generic_string(), " with error code ", error.value()));
+ }
+
+ // Create directory in parent dir for temporary files.
+ std::filesystem::path temp_files_dir = fcp_base_dir / kTempFilesDir;
+ std::filesystem::create_directories(temp_files_dir, error);
+ if (error.value() != 0) {
+ return absl::InternalError(
+ absl::StrCat("Failed to create TempFiles temp file directory ",
+ temp_files_dir.generic_string()));
+ }
+
+ // We clean up the temp files dir on creation in case we failed to clean it up
+ // during a previous run (i.e. due to the training process getting killed
+ // etc.) and to make sure we don't end up in the pathological case where we
+ // are always crashing partway through training and stranding temp files
+ // because the TempFiles dtor never runs.
+ auto cleanup_status = DeleteFilesInDirectory(temp_files_dir);
+ if (!cleanup_status.ok()) {
+ log_manager->LogDiag(ProdDiagCode::TEMP_FILES_NATIVE_FAILED_TO_DELETE);
+ return cleanup_status;
+ }
+ return absl::WrapUnique(new TempFiles(temp_files_dir, log_manager));
+}
+
+absl::StatusOr<std::string> TempFiles::CreateTempFile(
+ const std::string& prefix, const std::string& suffix) {
+ std::filesystem::path candidate_path;
+ int fd;
+ do {
+ candidate_path = temp_files_dir_ /
+ absl::StrCat(prefix, std::to_string(std::rand()), suffix);
+ } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR,
+ S_IRWXU)) == -1 &&
+ errno == EEXIST);
+ close(fd);
+ std::ofstream tmp_file(candidate_path);
+ if (!tmp_file) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("could not create file ", candidate_path.string()));
+ }
+
+ return candidate_path.string();
+}
+
+TempFiles::~TempFiles() {
+ if (!DeleteFilesInDirectory(temp_files_dir_).ok()) {
+ log_manager_.LogDiag(ProdDiagCode::TEMP_FILES_NATIVE_FAILED_TO_DELETE);
+ }
+}
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/cache/temp_files.h b/fcp/client/cache/temp_files.h
new file mode 100644
index 0000000..6f2e77a
--- /dev/null
+++ b/fcp/client/cache/temp_files.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_CACHE_TEMP_FILES_H_
+#define FCP_CLIENT_CACHE_TEMP_FILES_H_
+
+#include <filesystem>
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/files.h"
+#include "fcp/client/log_manager.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+
+// Manages temporary files created by the federated compute runtime. Unlike
+// other Files implementations, TempFiles will clean up created temporary files
+// eagerly as part its construction and deletion.
+class TempFiles : public Files {
+ public:
+ static constexpr char kParentDir[] = "fcp";
+ // The subdirectory temporary files will be created in. Files in this
+ // directory are deleted at the end of a federated computation.
+ static constexpr char kTempFilesDir[] = "tmp";
+
+ // Factory method to create TempFiles. The provided cache dir is the
+ // absolute path for storing cached files. TempFiles will attempt to
+ // create subdirectories and files, so the directory must grant read/write
+ // access.
+ static absl::StatusOr<std::unique_ptr<TempFiles>> Create(
+ const std::string& cache_dir, LogManager* log_manager);
+
+ // Creates a temporary file. TempFiles will delete these files at the end
+ // of a federated computation run, or upon the next creation of a TempFiles
+ // instance.
+ // On success, returns a file path.
+ // On error, returns
+ // - INTERNAL - unexpected error.
+ // - INVALID_ARGUMENT - on "expected" errors such as I/O issues.
+ absl::StatusOr<std::string> CreateTempFile(
+ const std::string& prefix, const std::string& suffix) override;
+
+ // Any temporary Files created with TempFiles will be deleted.
+ ~TempFiles() override;
+
+ // TempFiles is neither copyable nor movable.
+ TempFiles(const TempFiles&) = delete;
+ TempFiles& operator=(const TempFiles&) = delete;
+
+ private:
+ TempFiles(std::filesystem::path temp_files_dir, LogManager* log_manager)
+ : temp_files_dir_(temp_files_dir), log_manager_(*log_manager) {}
+
+ const std::filesystem::path temp_files_dir_;
+ LogManager& log_manager_;
+};
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_CACHE_TEMP_FILES_H_
diff --git a/fcp/client/cache/temp_files_test.cc b/fcp/client/cache/temp_files_test.cc
new file mode 100644
index 0000000..b1c7555
--- /dev/null
+++ b/fcp/client/cache/temp_files_test.cc
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/cache/temp_files.h"
+
+#include <filesystem>
+#include <fstream>
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+namespace {
+
+int CountFilesInDir(const std::filesystem::path& dir) {
+ int num_files = 0;
+ for ([[maybe_unused]] auto const& unused :
+ std::filesystem::directory_iterator{dir}) {
+ num_files++;
+ }
+ return num_files;
+}
+
+class TempFilesTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ root_dir_ = testing::TempDir();
+ std::filesystem::path root_dir(root_dir_);
+ temp_file_dir_ =
+ root_dir / TempFiles::kParentDir / TempFiles::kTempFilesDir;
+ }
+ void TearDown() override {
+ std::filesystem::remove_all(std::filesystem::path(root_dir_));
+ }
+
+ testing::StrictMock<MockLogManager> log_manager_;
+ std::string root_dir_;
+ std::filesystem::path temp_file_dir_;
+};
+
+TEST_F(TempFilesTest, FailToCreateParentDirectory) {
+ ASSERT_THAT(TempFiles::Create("/proc/0", &log_manager_), IsCode(INTERNAL));
+}
+
+TEST_F(TempFilesTest, InvalidRelativePath) {
+ ASSERT_THAT(TempFiles::Create("relative/cache", &log_manager_),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(TempFilesTest, SuccessfulInitialization) {
+ ASSERT_OK(TempFiles::Create(root_dir_, &log_manager_));
+}
+
+TEST_F(TempFilesTest, CreateTempFile) {
+ auto temp_files = TempFiles::Create(root_dir_, &log_manager_);
+ ASSERT_OK(temp_files);
+ auto temp_file = (*temp_files)->CreateTempFile("stefan", ".cool");
+ ASSERT_OK(temp_file);
+}
+
+TEST_F(TempFilesTest, CreateSomeTempFilesThenDeleteInDtor) {
+ auto temp_files = TempFiles::Create(root_dir_, &log_manager_);
+ ASSERT_OK(temp_files);
+ int num_temp_files = 4;
+ for (int i = 0; i < num_temp_files; i++) {
+ ASSERT_OK((*temp_files)->CreateTempFile("stefan", ".cool"));
+ }
+ ASSERT_EQ(num_temp_files, CountFilesInDir(temp_file_dir_));
+
+ temp_files->reset(); // deleting temp_files should empty the directory.
+
+ ASSERT_EQ(0, CountFilesInDir(temp_file_dir_));
+}
+
+TEST_F(TempFilesTest, CreatingTempFilesDeletesExistingFiles) {
+ std::filesystem::path root_dir(root_dir_);
+
+ ASSERT_TRUE(std::filesystem::create_directories(temp_file_dir_));
+
+ int num_existing_temp_files = 10;
+ for (int i = 0; i < num_existing_temp_files; i++) {
+ std::filesystem::path temp_file_path =
+ temp_file_dir_ / absl::StrCat("temp", i);
+ std::ofstream{temp_file_path};
+ }
+ ASSERT_EQ(num_existing_temp_files, CountFilesInDir(temp_file_dir_));
+
+ auto temp_files = TempFiles::Create(root_dir_, &log_manager_);
+ ASSERT_OK(temp_files);
+ ASSERT_EQ(0, CountFilesInDir(temp_file_dir_));
+}
+
+TEST_F(TempFilesTest, FailToDeleteTempFilesLogs) {
+ // Create a temp file in the temp dir
+ auto temp_files = TempFiles::Create(root_dir_, &log_manager_);
+ ASSERT_OK(temp_files);
+ ASSERT_OK((*temp_files)->CreateTempFile("stefan", ".cool"));
+ ASSERT_OK((*temp_files)->CreateTempFile("stefan", ".cool"));
+ ASSERT_OK((*temp_files)->CreateTempFile("stefan", ".cool"));
+ ASSERT_EQ(3, CountFilesInDir(temp_file_dir_));
+
+ // Delete the temp file dir and root dir, which should cause the dtor to fail
+ // because we deleted the directories out from underneath it.
+ std::filesystem::remove_all(std::filesystem::path(root_dir_));
+
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::TEMP_FILES_NATIVE_FAILED_TO_DELETE));
+ temp_files->reset();
+}
+
+} // namespace
+} // namespace cache
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/cache/test_helpers.h b/fcp/client/cache/test_helpers.h
new file mode 100644
index 0000000..45acdd8
--- /dev/null
+++ b/fcp/client/cache/test_helpers.h
@@ -0,0 +1,43 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_CLIENT_CACHE_TEST_HELPERS_H_
+#define FCP_CLIENT_CACHE_TEST_HELPERS_H_
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/client/cache/resource_cache.h"
+
+namespace fcp {
+namespace client {
+namespace cache {
+
+// A mock `ResourceCache` implementation that can be used in tests.
+class MockResourceCache : public ResourceCache {
+ public:
+ MOCK_METHOD(absl::Status, Put,
+ (absl::string_view cache_id, const absl::Cord& resource,
+ const google::protobuf::Any& metadata, absl::Duration max_age),
+ (override));
+ MOCK_METHOD(absl::StatusOr<ResourceAndMetadata>, Get,
+ (absl::string_view cache_id,
+ std::optional<absl::Duration> max_age),
+ (override));
+};
+
+} // namespace cache
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_CACHE_TEST_HELPERS_H_q
diff --git a/fcp/client/client_runner.h b/fcp/client/client_runner.h
new file mode 100644
index 0000000..1b28ebd
--- /dev/null
+++ b/fcp/client/client_runner.h
@@ -0,0 +1,238 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_CLIENT_RUNNER_H_
+#define FCP_CLIENT_CLIENT_RUNNER_H_
+
+#include <cxxabi.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include <array>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <filesystem>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <typeinfo>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/client_runner_example_data.pb.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/fake_event_publisher.h"
+#include "fcp/client/files.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/histogram_counters.pb.h"
+#include "fcp/client/http/curl/curl_api.h"
+#include "fcp/client/http/curl/curl_http_client.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+#include "google/protobuf/any.pb.h"
+#include "gtest/gtest.h"
+
+namespace fcp::client {
+
+// A stub implementation of the SimpleTaskEnvironment interface that logs calls
+// to stderr and returns canned example data.
+class FederatedTaskEnvDepsImpl : public SimpleTaskEnvironment {
+ public:
+ // Constructs a SimpleTaskEnvironment that will return an example iterator
+ // with `num_empty_examples` empty examples.
+ explicit FederatedTaskEnvDepsImpl(int num_empty_examples,
+ std::string test_cert_path = "")
+ : examples_(num_empty_examples),
+ test_cert_path_(std::move(test_cert_path)) {}
+
+ // Constructs a SimpleTaskEnvironment that will return an example iterator
+ // with examples determined by the collection URI.
+ explicit FederatedTaskEnvDepsImpl(ClientRunnerExampleData example_data,
+ std::string test_cert_path = "")
+ : examples_(std::move(example_data)),
+ test_cert_path_(std::move(test_cert_path)) {}
+
+ std::string GetBaseDir() override {
+ return std::filesystem::path(testing::TempDir());
+ }
+
+ std::string GetCacheDir() override {
+ return std::filesystem::path(testing::TempDir());
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ SelectorContext unused;
+ return CreateExampleIterator(example_selector, unused);
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector,
+ const SelectorContext& selector_context) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME
+ // << ":\n\turi: " << example_selector.collection_uri()
+ // << "\n\ttype: " << example_selector.criteria().type_url();
+ if (auto* num_empty_examples = std::get_if<int>(&examples_)) {
+ return std::make_unique<FakeExampleIterator>(*num_empty_examples);
+ } else if (auto* store = std::get_if<ClientRunnerExampleData>(&examples_)) {
+ const auto& examples_map = store->examples_by_collection_uri();
+ if (auto it = examples_map.find(example_selector.collection_uri());
+ it != examples_map.end()) {
+ return std::make_unique<FakeExampleIterator>(&it->second);
+ }
+ return absl::InvalidArgumentError("no examples for collection_uri");
+ }
+ return absl::InternalError("unsupported examples variant type");
+ }
+
+ std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() override {
+ return std::make_unique<fcp::client::http::curl::CurlHttpClient>(
+ &curl_api_, test_cert_path_);
+ }
+
+ private:
+ class FakeExampleIterator : public ExampleIterator {
+ public:
+ explicit FakeExampleIterator(int num_examples)
+ : example_list_(nullptr), num_examples_(num_examples) {}
+ explicit FakeExampleIterator(
+ const ClientRunnerExampleData::ExampleList* examples)
+ : example_list_(examples), num_examples_(examples->examples_size()) {}
+ absl::StatusOr<std::string> Next() override {
+ if (num_examples_served_ >= num_examples_) {
+ return absl::OutOfRangeError("");
+ }
+ std::string example =
+ example_list_ ? example_list_->examples(num_examples_served_) : "";
+ num_examples_served_++;
+ return example;
+ }
+ void Close() override {}
+
+ private:
+ const ClientRunnerExampleData::ExampleList* const example_list_;
+ const int num_examples_;
+ int num_examples_served_ = 0;
+ };
+
+ bool TrainingConditionsSatisfied() override {
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ return true;
+ }
+
+ const std::variant<int, ClientRunnerExampleData> examples_;
+ const std::string test_cert_path_;
+ fcp::client::http::curl::CurlApi curl_api_;
+};
+
+// An implementation of the Files interface that attempts to create a temporary
+// file with the given prefix and suffix in a directory suitable for temporary
+// files.
+// NB this is a proof-of-concept implementation that does not use existing infra
+// such as mkstemps() or std::tmpfile due to the requirements of the existing
+// Files API: include prefix, suffix strings in filename; return file path
+// instead of file descriptor.
+class FilesImpl : public Files {
+ public:
+ FilesImpl() { std::srand(static_cast<int32_t>(std::time(nullptr))); }
+
+ absl::StatusOr<std::string> CreateTempFile(
+ const std::string& prefix, const std::string& suffix) override {
+ const auto tmp_dir = std::filesystem::path(testing::TempDir());
+ std::filesystem::path candidate_path;
+ int fd;
+ do {
+ candidate_path =
+ tmp_dir / absl::StrCat(prefix, std::to_string(std::rand()), suffix);
+ } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR,
+ S_IRWXU)) == -1 &&
+ errno == EEXIST);
+ close(fd);
+ std::ofstream tmp_file(candidate_path);
+ if (!tmp_file) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("could not create file ", candidate_path.string()));
+ }
+ // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << candidate_path;
+ return candidate_path.string();
+ }
+};
+
+// A stub implementation of the LogManager interface that logs invocations to
+// stderr.
+class LogManagerImpl : public LogManager {
+ public:
+ void LogDiag(ProdDiagCode diag_code) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << ProdDiagCode_Name(diag_code);
+ }
+ void LogDiag(DebugDiagCode diag_code) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << DebugDiagCode_Name(diag_code);
+ }
+ void LogToLongHistogram(HistogramCounters histogram_counter, int, int,
+ engine::DataSourceType data_source_type,
+ int64_t value) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME
+ // << ": " << HistogramCounters_Name(histogram_counter) << " <- " <<
+ // value;
+ }
+
+ void SetModelIdentifier(const std::string& model_identifier) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier;
+ }
+};
+
+class FlagsImpl : public Flags {
+ public:
+ void set_use_http_federated_compute_protocol(bool value) {
+ use_http_federated_compute_protocol_ = value;
+ }
+ void set_use_tflite_training(bool value) { use_tflite_training_ = value; }
+
+ int64_t condition_polling_period_millis() const override { return 1000; }
+ int64_t tf_execution_teardown_grace_period_millis() const override {
+ return 1000;
+ }
+ int64_t tf_execution_teardown_extended_period_millis() const override {
+ return 2000;
+ }
+ int64_t grpc_channel_deadline_seconds() const override { return 0; }
+ bool log_tensorflow_error_messages() const override { return true; }
+ bool use_http_federated_compute_protocol() const override {
+ return use_http_federated_compute_protocol_;
+ }
+ bool use_tflite_training() const override { return use_tflite_training_; }
+
+ private:
+ bool use_http_federated_compute_protocol_ = false;
+ bool use_tflite_training_ = false;
+};
+
+} // namespace fcp::client
+
+#endif // FCP_CLIENT_CLIENT_RUNNER_H_
diff --git a/fcp/client/client_runner_example_data.proto b/fcp/client/client_runner_example_data.proto
new file mode 100644
index 0000000..d75d360
--- /dev/null
+++ b/fcp/client/client_runner_example_data.proto
@@ -0,0 +1,28 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client;
+
+// A collection of example data for use with client_runner_main.
+message ClientRunnerExampleData {
+ message ExampleList {
+ // The serialized examples that will be returned by the ExampleIterator.
+ repeated bytes examples = 1;
+ }
+
+ // The list of examples that will be returned for each collection uri.
+ map<string, ExampleList> examples_by_collection_uri = 1;
+}
diff --git a/fcp/client/client_runner_main.cc b/fcp/client/client_runner_main.cc
new file mode 100644
index 0000000..d46b172
--- /dev/null
+++ b/fcp/client/client_runner_main.cc
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <optional>
+#include <string>
+#include <utility>
+
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/flags/usage.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_split.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/client_runner.h"
+#include "fcp/client/client_runner_example_data.pb.h"
+#include "fcp/client/fake_event_publisher.h"
+#include "fcp/client/fl_runner.h"
+
+ABSL_FLAG(std::string, server, "",
+ "Federated Server URI (supports https+test:// and https:// URIs");
+ABSL_FLAG(std::string, api_key, "", "API Key");
+ABSL_FLAG(std::string, test_cert, "",
+ "Path to test CA certificate PEM file; used for https+test:// URIs");
+ABSL_FLAG(std::string, session, "", "Session name");
+ABSL_FLAG(std::string, population, "", "Population name");
+ABSL_FLAG(std::string, retry_token, "", "Retry token");
+ABSL_FLAG(std::string, client_version, "", "Client version");
+ABSL_FLAG(std::string, attestation_string, "", "Attestation string");
+ABSL_FLAG(std::string, example_data_path, "",
+ "Path to a serialized ClientRunnerExampleData proto with client "
+ "example data. Falls back to --num_empty_examples if unset.");
+ABSL_FLAG(int, num_empty_examples, 0,
+ "Number of (empty) examples each created iterator serves. Ignored if "
+ "--example_store_path is set.");
+ABSL_FLAG(int, num_rounds, 1, "Number of rounds to train");
+ABSL_FLAG(int, sleep_after_round_secs, 3,
+ "Number of seconds to sleep after each round.");
+ABSL_FLAG(bool, use_http_federated_compute_protocol, false,
+ "Whether to enable the HTTP FederatedCompute protocol instead "
+ "of the gRPC FederatedTrainingApi protocol.");
+ABSL_FLAG(bool, use_tflite_training, false, "Whether use TFLite for training.");
+
+static constexpr char kUsageString[] =
+ "Stand-alone Federated Client Executable.\n\n"
+ "Connects to the specified server, tries to retrieve a plan, run the\n"
+ "plan (feeding the specified number of empty examples), and report the\n"
+ "results of the computation back to the server.";
+
+static absl::StatusOr<fcp::client::ClientRunnerExampleData> LoadExampleData(
+ const std::string& examples_path) {
+ std::ifstream examples_file(examples_path);
+ fcp::client::ClientRunnerExampleData data;
+ if (!data.ParseFromIstream(&examples_file) || !examples_file.eof()) {
+ return absl::InvalidArgumentError(
+ "Failed to parse ClientRunnerExampleData");
+ }
+ return data;
+}
+
+int main(int argc, char** argv) {
+ absl::SetProgramUsageMessage(kUsageString);
+ absl::ParseCommandLine(argc, argv);
+
+ int num_rounds = absl::GetFlag(FLAGS_num_rounds);
+ std::string server = absl::GetFlag(FLAGS_server);
+ std::string session = absl::GetFlag(FLAGS_session);
+ std::string population = absl::GetFlag(FLAGS_population);
+ std::string client_version = absl::GetFlag(FLAGS_client_version);
+ std::string test_cert = absl::GetFlag(FLAGS_test_cert);
+ FCP_LOG(INFO) << "Running for " << num_rounds << " rounds:";
+ FCP_LOG(INFO) << " - server: " << server;
+ FCP_LOG(INFO) << " - session: " << session;
+ FCP_LOG(INFO) << " - population: " << population;
+ FCP_LOG(INFO) << " - client_version: " << client_version;
+
+ std::optional<fcp::client::ClientRunnerExampleData> example_data;
+ if (std::string path = absl::GetFlag(FLAGS_example_data_path);
+ !path.empty()) {
+ auto statusor = LoadExampleData(path);
+ if (!statusor.ok()) {
+ FCP_LOG(ERROR) << "Failed to load example data: " << statusor.status();
+ return 1;
+ }
+ example_data = *std::move(statusor);
+ }
+
+ bool success = false;
+ for (auto i = 0; i < num_rounds || num_rounds < 0; ++i) {
+ fcp::client::FederatedTaskEnvDepsImpl federated_task_env_deps_impl =
+ example_data
+ ? fcp::client::FederatedTaskEnvDepsImpl(*example_data, test_cert)
+ : fcp::client::FederatedTaskEnvDepsImpl(
+ absl::GetFlag(FLAGS_num_empty_examples), test_cert);
+ fcp::client::FakeEventPublisher event_publisher(/*quiet=*/false);
+ fcp::client::FilesImpl files_impl;
+ fcp::client::LogManagerImpl log_manager_impl;
+ fcp::client::FlagsImpl flags;
+ flags.set_use_http_federated_compute_protocol(
+ absl::GetFlag(FLAGS_use_http_federated_compute_protocol));
+ flags.set_use_tflite_training(absl::GetFlag(FLAGS_use_tflite_training));
+
+ auto fl_runner_result = RunFederatedComputation(
+ &federated_task_env_deps_impl, &event_publisher, &files_impl,
+ &log_manager_impl, &flags, server, absl::GetFlag(FLAGS_api_key),
+ test_cert, session, population, absl::GetFlag(FLAGS_retry_token),
+ client_version, absl::GetFlag(FLAGS_attestation_string));
+ if (fl_runner_result.ok()) {
+ FCP_LOG(INFO) << "Run finished successfully; result: "
+ << fl_runner_result.value().DebugString();
+ success = true;
+ } else {
+ FCP_LOG(ERROR) << "Error during run: " << fl_runner_result.status();
+ }
+ int sleep_secs = absl::GetFlag(FLAGS_sleep_after_round_secs);
+ FCP_LOG(INFO) << "Sleeping for " << sleep_secs << " secs";
+ absl::SleepFor(absl::Seconds(sleep_secs));
+ }
+ return success ? 0 : 1;
+}
diff --git a/fcp/client/diag_codes.proto b/fcp/client/diag_codes.proto
new file mode 100644
index 0000000..1d71af8
--- /dev/null
+++ b/fcp/client/diag_codes.proto
@@ -0,0 +1,335 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client;
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+/**
+ * Enumerations of diagnostic codes for debugging, testing, and logging.
+ *
+ * Diag codes serve two purposes:
+ * - testing and development. The ability to log, observe and assert on
+ * diag code traces allows for integration testing of code that runs
+ * asynchronously in different processes or apps. Both DebugDiagCodes and
+ * ProdDiagCodes are used to this end.
+ * - monitoring of a deployment. Sequences of diag codes are
+ * * easy to analyze
+ * * have limited expressive power by design (mere integers) to discourage
+ * logging sensitive information
+ * * are easier to support across platforms as compared to unstructured logs,
+ * for both policy and technical reasons.
+ *
+ * Note that only the ordinal of the diag code enum will be logged to clearcut.
+ * As a result, the diag codes for debug and production should be mutually
+ * exclusive.
+ */
+enum DebugDiagCode {
+ DEBUG_DIAG_CODE_UNDEFINED = 0;
+
+ // Codes reserved for test-only training diag codes.
+ // =================================================
+
+ /** Logged right before ClientExecution.getLoopOp() is executed */
+ TRAINING_BEFORE_LOOP_OP = 1000;
+
+ /** Logged right after ClientExecution.getLoopOp() is executed */
+ TRAINING_AFTER_LOOP_OP = 1001;
+
+ /** Logged if opstats is enabled */
+ TRAINING_OPSTATS_ENABLED = 1002;
+
+ // Codes reserved for test-only resource cache diag codes.
+ //================================================================
+
+ // Logged when a resource is requested that is in the cache.
+ RESOURCE_CACHE_HIT = 1200;
+
+ // Logged when a resource is requested that isn't in the cache.
+ RESOURCE_CACHE_MISS = 1201;
+}
+
+/**
+ * Diagnosis codes that are meant to be logged in production. These usually are
+ * pretty severe errors, public API being called, or infrequent jobs (like
+ * training or old example removal) being run.
+ *
+ * The logging of ProdDiagCode is controlled by a runtime dynamic flag. Logging
+ * can be skipped in accordance to the flag.
+ */
+enum ProdDiagCode {
+ PROD_DIAG_CODE_UNDEFINED = 0;
+
+ // Codes reserved for background training
+ // ======================================
+
+ /**
+ * Successfully interrupted TensorFlow execution happening on a separate
+ * thread.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION = 51;
+
+ /**
+ * TensorFlow session was interrupted but timed out waiting for execution to
+ * complete.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT = 50;
+
+ /**
+ * TensorFlow session was interrupted and finished execution after the grace
+ * period.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED = 49;
+
+ /**
+ * TensorFlow session was interrupted but timed out waiting for execution to
+ * complete in the extended period.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT = 48;
+
+ /** Sent when the provided ClientOnlyPlan cannot be parsed. */
+ BACKGROUND_TRAINING_FAILED_CANNOT_PARSE_PLAN = 40;
+
+ /** Sent when the provided ClientOnlyPlan does not pass a sanity check. */
+ BACKGROUND_TRAINING_FAILED_PLAN_FAILS_SANITY_CHECK = 39;
+
+ /** Successfully interrupted GRPC on a separate thread. */
+ BACKGROUND_TRAINING_INTERRUPT_GRPC = 34;
+
+ /** GRPC was interrupted but timed out waiting for execution to complete. */
+ BACKGROUND_TRAINING_INTERRUPT_GRPC_TIMED_OUT = 33;
+
+ /** GRPC was interrupted and finished after the grace period. */
+ BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_COMPLETED = 32;
+
+ /**
+ * GRPC was interrupted but timed out waiting for execution to complete in the
+ * extended period.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_TIMED_OUT = 31;
+
+ /** Successfully interrupted HTTP on a separate thread. */
+ BACKGROUND_TRAINING_INTERRUPT_HTTP = 24;
+
+ /**
+ * HTTP was interrupted but timed out waiting for execution to complete.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT = 23;
+
+ /** HTTP was interrupted and finished after the grace period. */
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED = 22;
+
+ /**
+ * HTTP was interrupted but timed out waiting for execution to complete in the
+ * extended period.
+ */
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT = 21;
+
+ /**
+ * Sent when TFLite was used.
+ */
+ BACKGROUND_TRAINING_TFLITE_ENGINE_USED = 20;
+
+ /**
+ * Sent when TFLite model flatbuffer is not empty.
+ */
+ BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED = 18;
+
+ /**
+ * A generic SecAgg client error.
+ */
+ SECAGG_CLIENT_NATIVE_ERROR_GENERIC = 1369;
+
+ /**
+ * The server requested an unsupported version.
+ */
+ SECAGG_CLIENT_ERROR_UNSUPPORTED_VERSION = 1368;
+
+ /**
+ * Sent when a plan that ingests data via Dataset is attempted to be run in
+ * an environment where Dataset support is not available.
+ */
+ DATASET_NOT_SUPPORTED = 1493;
+
+ /** Logged when a CheckinRequestAck message was expected, but not received. */
+ BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD = 29;
+
+ /** Logged when a CheckinRequestAck message is received. */
+ BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED = 28;
+
+ /**
+ * Sent when the provided eligibility eval ClientOnlyPlan cannot be parsed.
+ */
+ BACKGROUND_TRAINING_ELIGIBILITY_EVAL_FAILED_CANNOT_PARSE_PLAN = 26;
+
+ /**
+ * Codes reserved for logs related to opstats
+ * ==========================================
+ */
+
+ // Logged when the provided path for creating database is invalid.
+ OPSTATS_INVALID_FILE_PATH = 1699;
+
+ // Logged when failed to create parent directories for the database file.
+ OPSTATS_PARENT_DIR_CREATION_FAILED = 1698;
+
+ // Logged when failed to read from OpStats DB.
+ OPSTATS_READ_FAILED = 1697;
+
+ // Logged when failed to reset OpStats DB.
+ OPSTATS_RESET_FAILED = 1696;
+
+ // Logged when failed to write to OpStats DB.
+ OPSTATS_WRITE_FAILED = 1695;
+
+ // Logged when the OpStats example store is requested, but the collection uri
+ // is wrong.
+ OPSTATS_INCORRECT_COLLECTION_URI = 1694;
+
+ // Logged when the provided selection criteria for the OpStats example store
+ // is invalid.
+ OPSTATS_INVALID_SELECTION_CRITERIA = 1693;
+
+ // Logged when the OpStats example store is requested, but not enabled.
+ OPSTATS_EXAMPLE_STORE_REQUESTED_NOT_ENABLED = 1692;
+
+ // Logged when extracting the task name from the checkin response fails.
+ OPSTATS_TASK_NAME_EXTRACTION_FAILED = 1691;
+
+ // Logged when we start to construct an opstats message for a run after having
+ // successfully created an underlying db.
+ OPSTATS_DB_COMMIT_EXPECTED = 1690;
+
+ // Logged when we try to commit an opstats message to the db.
+ OPSTATS_DB_COMMIT_ATTEMPTED = 1689;
+
+ // Logged when there's already another instance of OpStatsDb which uses the
+ // same underlying file.
+ OPSTATS_MULTIPLE_DB_INSTANCE_DETECTED = 1688;
+
+ // Logged when failed to open a file descriptor for the underlying database
+ // file.
+ OPSTATS_FAILED_TO_OPEN_FILE = 1687;
+
+ /**
+ * Codes reserved for logs related to HTTP
+ * =======================================
+ */
+ /* Logged when a client using the GRPC protocol downloads a regular
+ * (non-eligibility eval) task's resource (plan or initial checkpoint) using
+ * HTTP. */
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP = 1799;
+ /* Logged when the attempt to fetch HTTP resources (as per
+ * `HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP`) failed. */
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED = 1798;
+ /* Logged when the attempt to fetch HTTP resources (as per
+ * `HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP`) succeeded. */
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED = 1797;
+ /* Logged when a cancellation request or an abort request failed. */
+ HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED = 1790;
+ /* Logged when a ReportEligibilityEvalTaskResult request failed. */
+ HTTP_REPORT_ELIGIBILITY_EVAL_RESULT_REQUEST_FAILED = 1789;
+ /* Logged when a ReportTaskResult request failed. */
+ HTTP_REPORT_TASK_RESULT_REQUEST_FAILED = 1788;
+ /* Logged when HTTP federated protocol is used. */
+ HTTP_FEDERATED_PROTOCOL_USED = 1787;
+
+ /**
+ * Codes reserved for logs related to temp files
+ * =============================================
+ */
+ TEMP_FILES_NATIVE_FAILED_TO_DELETE = 1090;
+
+ /**
+ * Codes reserved for logs related to Federated Select
+ * =======================================
+ */
+ /* Logged when a task tries to use Federated Select to fetch one or more
+ * slices but the feature is disabled. */
+ FEDSELECT_SLICE_HTTP_FETCH_REQUESTED_BUT_DISABLED = 1899;
+ /* Logged when a regular (non-eligibility eval) task uses Federated Select to
+ * fetch one or more slices via HTTP. */
+ FEDSELECT_SLICE_HTTP_FETCH_REQUESTED = 1898;
+ /* Logged when the attempt to fetch one or more slices via HTTP (as per
+ * `FEDSELECT_SLICE_HTTP_FETCH_REQUESTED`) failed. */
+ FEDSELECT_SLICE_HTTP_FETCH_FAILED = 1897;
+ /* Logged when the attempt to fetch one or more slices via HTTP (as per
+ * `FEDSELECT_SLICE_HTTP_FETCH_REQUESTED`) succeeded. */
+ FEDSELECT_SLICE_HTTP_FETCH_SUCCEEDED = 1896;
+
+ /**
+ * Codes for logs related to the resource cache
+ * ========================================
+ */
+
+ /* Logged when a FileBackedResourceCache fails to read the CacheManifest
+ * proto db. */
+ RESOURCE_CACHE_MANIFEST_READ_FAILED = 1999;
+ /* Logged when a FileBackedResourceCache fails to write the CacheManifest to
+ * storage. */
+ RESOURCE_CACHE_MANIFEST_WRITE_FAILED = 1998;
+ /* Logged when a FileBackedResourceCache fails to read the cached resource to
+ * storage. */
+ RESOURCE_CACHE_RESOURCE_READ_FAILED = 1997;
+ /* Logged when a FileBackedResourceCache fails to write the cached resource to
+ * storage. */
+ RESOURCE_CACHE_RESOURCE_WRITE_FAILED = 1996;
+ /* Logged when a FileBackedResourceCache is initialized without an absolute
+ * root path. */
+ RESOURCE_CACHE_CACHE_ROOT_PATH_NOT_ABSOLUTE = 1995;
+ /* Logged when a FileBackedResourceCache fails to create the cache dir on
+ * initialization. */
+ RESOURCE_CACHE_FAILED_TO_CREATE_CACHE_DIR = 1994;
+ /* Logged when a FileBackedResourceCache is initialized with an invalid cache
+ * manifest path. */
+ RESOURCE_CACHE_INVALID_MANIFEST_PATH = 1993;
+ /* Logged when a FileBackedResourceCache fails to create the parent directory
+ * of the cache manifest. */
+ RESOURCE_CACHE_FAILED_TO_CREATE_MANIFEST_DIR = 1992;
+ /* Logged when a FileBackedResourceCache fails to reset the cache manifest. */
+ RESOURCE_CACHE_FAILED_TO_RESET_MANIFEST = 1991;
+ /* Logged when a FileBackedResourceCache fails to get the size of the cache
+ * manifest. */
+ RESOURCE_CACHE_INIT_FAILED_TO_GET_MANIFEST_SIZE = 1990;
+ /* Logged when a FileBackedResourceCache fails to iterate over the cache
+ * directory during cleanup.
+ */
+ RESOURCE_CACHE_CLEANUP_FAILED_TO_ITERATE_OVER_CACHE_DIR = 1989;
+ /* Logged when a FileBackedResourceCache fails to delete a cached file during
+ * cleanup. */
+ RESOURCE_CACHE_CLEANUP_FAILED_TO_DELETE_CACHED_FILE = 1988;
+ /* Logged when a FileBackedResourceCache fails to get the file size of a
+ * cached file. */
+ RESOURCE_CACHE_CLEANUP_FAILED_TO_GET_FILE_SIZE = 1987;
+ /* Logged when a FileBackedResourceCache fails to initialize the cache
+ * manifest when it doesn't already exist. */
+ RESOURCE_CACHE_INIT_FAILED_TO_INITIALIZE_MANIFEST = 1986;
+ /* Logged when a FileBackedResourceCache fails to delete an existing cache
+ * manifest due to an error. */
+ RESOURCE_CACHE_FAILED_TO_DELETE_MANIFEST = 1985;
+ /* Logged when a FileBackedResourceCache fails in some way during cleanup in
+ * initialization. */
+ RESOURCE_CACHE_INIT_FAILED_CLEANUP = 1984;
+ /* Logged when a FileBackedResourceCache fails to check if a cached file
+ * exists during cleanup. */
+ RESOURCE_CACHE_CLEANUP_FAILED_TO_CHECK_IF_FILE_EXISTS = 1983;
+ /* Logged when a FileBackedResourceCache fails to check if a cached file
+ * exists during Put(). */
+ RESOURCE_CACHE_PUT_FAILED_TO_CHECK_IF_FILE_EXISTS = 1982;
+
+ reserved 25;
+}
diff --git a/fcp/client/engine/BUILD b/fcp/client/engine/BUILD
new file mode 100644
index 0000000..7792f85
--- /dev/null
+++ b/fcp/client/engine/BUILD
@@ -0,0 +1,305 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@org_tensorflow//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
+load("//fcp:config.bzl", "FCP_COPTS")
+
+default_visibility = [
+ "//fcp:internal",
+]
+
+package(
+ default_visibility = default_visibility,
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "plan_engine",
+ srcs = [
+ "simple_plan_engine.cc",
+ ],
+ hdrs = [
+ "simple_plan_engine.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = default_visibility,
+ deps = [
+ ":common",
+ ":example_iterator_factory",
+ ":plan_engine_helpers",
+ ":tf_wrapper",
+ "//fcp/base",
+ "//fcp/client:histogram_counters_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "example_query_plan_engine",
+ srcs = [
+ "example_query_plan_engine.cc",
+ ],
+ hdrs = [
+ "example_query_plan_engine.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = default_visibility,
+ deps = [
+ ":common",
+ ":example_iterator_factory",
+ ":plan_engine_helpers",
+ "//fcp/base",
+ "//fcp/client:example_query_result_cc_proto",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/tensorflow:status",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core/platform:tstring",
+ ],
+)
+
+cc_test(
+ name = "example_query_plan_engine_test",
+ srcs = ["example_query_plan_engine_test.cc"],
+ deps = [
+ ":common",
+ ":example_query_plan_engine",
+ "//fcp/client:client_runner",
+ "//fcp/client:example_query_result_cc_proto",
+ "//fcp/client:test_helpers",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/tensorflow:external_dataset_op_lib",
+ "//fcp/testing",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/c:checkpoint_reader",
+ "@org_tensorflow//tensorflow/c:tf_status_headers",
+ "@org_tensorflow//tensorflow/c:tf_status_helper",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# A plan-engine independent wrapper around TF that supports cancellation.
+cc_library(
+ name = "tf_wrapper",
+ srcs = ["tf_wrapper.cc"],
+ hdrs = ["tf_wrapper.h"],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:private"],
+ deps = [
+ ":plan_engine_helpers",
+ "//fcp/base",
+ "//fcp/base:future",
+ "//fcp/base:scheduler",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:core_cpu",
+ ],
+)
+
+cc_test(
+ name = "tf_wrapper_test",
+ srcs = ["tf_wrapper_test.cc"],
+ deps = [
+ ":tf_wrapper",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "plan_engine_helpers",
+ srcs = ["plan_engine_helpers.cc"],
+ hdrs = ["plan_engine_helpers.h"],
+ copts = FCP_COPTS,
+ visibility = default_visibility,
+ deps = [
+ ":common",
+ ":example_iterator_factory",
+ "//fcp/base",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/client/opstats:opstats_logger_impl",
+ "//fcp/client/opstats:pds_backed_opstats_db",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/tensorflow:external_dataset",
+ "//fcp/tensorflow:host_object",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@org_tensorflow//tensorflow/core:framework",
+ ],
+)
+
+cc_library(
+ name = "tflite_wrapper",
+ srcs = ["tflite_wrapper.cc"],
+ hdrs = ["tflite_wrapper.h"],
+ deps = [
+ ":caching_error_reporter",
+ "//fcp/base",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:simple_task_environment",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/lite:framework_stable",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/delegates/flex:delegate_only_runtime",
+ "@org_tensorflow//tensorflow/lite/delegates/flex:util",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "tflite_wrapper_test",
+ srcs = ["tflite_wrapper_test.cc"],
+ data = [
+ "//fcp/client/engine/data:join_model.flatbuffer",
+ "//fcp/client/engine/data:length_model.flatbuffer",
+ ],
+ deps = [
+ ":tflite_wrapper",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:test_helpers",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/core/kernels:string_join_op",
+ "@org_tensorflow//tensorflow/core/ops:string_ops_op_lib",
+ ],
+)
+
+cc_library(
+ name = "tflite_plan_engine",
+ srcs = [
+ "tflite_plan_engine.cc",
+ ],
+ hdrs = [
+ "tflite_plan_engine.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":common",
+ ":example_iterator_factory",
+ ":plan_engine_helpers",
+ ":tflite_wrapper",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/opstats:opstats_logger",
+ "//fcp/protos:plan_cc_proto",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "caching_error_reporter",
+ srcs = ["caching_error_reporter.cc"],
+ hdrs = ["caching_error_reporter.h"],
+ deps = [
+ "@com_google_absl//absl/synchronization",
+ "@org_tensorflow//tensorflow/lite/core/api:error_reporter",
+ ],
+)
+
+cc_test(
+ name = "caching_error_reporter_test",
+ srcs = ["caching_error_reporter_test.cc"],
+ deps = [
+ ":caching_error_reporter",
+ "//fcp/testing",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite/core/api:error_reporter",
+ ],
+)
+
+cc_library(
+ name = "common",
+ srcs = ["common.cc"],
+ hdrs = ["common.h"],
+ deps = [
+ ":engine_cc_proto",
+ "//fcp/base",
+ "//fcp/client:interfaces",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "example_iterator_factory",
+ hdrs = ["example_iterator_factory.h"],
+ copts = FCP_COPTS,
+ visibility = default_visibility,
+ deps = [
+ "//fcp/client:simple_task_environment",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/status:statusor",
+ ],
+)
+
+# Runtime protos. Those do not typically get serialized.
+tf_proto_library(
+ name = "engine_proto",
+ srcs = ["engine.proto"],
+ visibility = ["//visibility:public"],
+)
+
+java_proto_library(
+ name = "engine_java_proto",
+ deps = [":engine_proto"],
+)
+
+# Allowing to refer to the cc library generated by the rule above in usual way:
+alias(
+ name = "engine_cc_proto",
+ actual = "engine_proto_cc",
+ visibility = default_visibility,
+)
diff --git a/fcp/client/engine/caching_error_reporter.cc b/fcp/client/engine/caching_error_reporter.cc
new file mode 100644
index 0000000..9b3784a
--- /dev/null
+++ b/fcp/client/engine/caching_error_reporter.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/caching_error_reporter.h"
+
+#include <stdio.h>
+
+#include <string>
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+int CachingErrorReporter::Report(const char* format, va_list args) {
+ absl::MutexLock lock(&mutex_);
+ char error_msg[kBufferSize];
+ int num_characters = vsnprintf(error_msg, kBufferSize, format, args);
+ if (num_characters >= 0) {
+ error_messages_.push_back(std::string(error_msg));
+ } else {
+ // If num_characters is below zero, we can't trust the created string, so we
+ // push an "Unknown error" to the stored error messages.
+ // We don't want to crash here, because the TFLite execution will be
+ // terminated soon.
+ error_messages_.push_back("Unknown error.");
+ }
+ return num_characters;
+}
+
+std::string CachingErrorReporter::GetFirstErrorMessage() {
+ absl::MutexLock lock(&mutex_);
+ if (error_messages_.empty()) {
+ return "";
+ } else {
+ return error_messages_[0];
+ }
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/caching_error_reporter.h b/fcp/client/engine/caching_error_reporter.h
new file mode 100644
index 0000000..e443e1b
--- /dev/null
+++ b/fcp/client/engine/caching_error_reporter.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_CACHING_ERROR_REPORTER_H_
+#define FCP_CLIENT_ENGINE_CACHING_ERROR_REPORTER_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// This implementation of ErrorReporter stores all the error messages.
+class CachingErrorReporter : public tflite::ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override
+ ABSL_LOCKS_EXCLUDED(mutex_);
+ std::string GetFirstErrorMessage() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ private:
+ absl::Mutex mutex_;
+ static constexpr int kBufferSize = 1024;
+ // There could be more than one error messages so we store all of them.
+ std::vector<std::string> error_messages_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_CACHING_ERROR_REPORTER_H_
diff --git a/fcp/client/engine/caching_error_reporter_test.cc b/fcp/client/engine/caching_error_reporter_test.cc
new file mode 100644
index 0000000..896cfc9
--- /dev/null
+++ b/fcp/client/engine/caching_error_reporter_test.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/caching_error_reporter.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/testing/testing.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+
+using ::testing::IsEmpty;
+
+TEST(CachingErrorReporterTest, CachingMultiple) {
+ CachingErrorReporter reporter;
+ std::string first_error = "Op a is not found.";
+ TF_LITE_REPORT_ERROR(&reporter, "%s%d", first_error.c_str(), 1);
+ std::string second_error = "Op b is not found.";
+ TF_LITE_REPORT_ERROR(&reporter, "%s%d", second_error.c_str(), 2);
+ EXPECT_THAT(reporter.GetFirstErrorMessage(), absl::StrCat(first_error, "1"));
+}
+
+TEST(CachingErrorReporterTest, Empty) {
+ CachingErrorReporter reporter;
+ EXPECT_THAT(reporter.GetFirstErrorMessage(), IsEmpty());
+}
+
+} // anonymous namespace
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/common.cc b/fcp/client/engine/common.cc
new file mode 100644
index 0000000..e64984b
--- /dev/null
+++ b/fcp/client/engine/common.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/common.h"
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+using ::google::internal::federated::plan::TensorflowSpec;
+
+PlanResult::PlanResult(PlanOutcome outcome, absl::Status status)
+ : outcome(outcome), original_status(std::move(status)) {
+ if (outcome == PlanOutcome::kSuccess) {
+ FCP_CHECK(original_status.ok());
+ }
+}
+
+absl::Status ValidateTensorflowSpec(
+ const TensorflowSpec& tensorflow_spec,
+ const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
+ const std::vector<std::string>& output_names) {
+ // Check that all inputs have corresponding TensorSpecProtos.
+ if (expected_input_tensor_names_set.size() !=
+ tensorflow_spec.input_tensor_specs_size()) {
+ return absl::InvalidArgumentError(
+ "Unexpected number of input_tensor_specs");
+ }
+
+ for (const tensorflow::TensorSpecProto& it :
+ tensorflow_spec.input_tensor_specs()) {
+ if (!expected_input_tensor_names_set.contains(it.name())) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Missing expected TensorSpecProto for input ", it.name()));
+ }
+ }
+ // Check that all outputs have corresponding TensorSpecProtos.
+ absl::flat_hash_set<std::string> expected_output_tensor_names_set(
+ output_names.begin(), output_names.end());
+ if (expected_output_tensor_names_set.size() !=
+ tensorflow_spec.output_tensor_specs_size()) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Unexpected number of output_tensor_specs: ",
+ expected_output_tensor_names_set.size(), " vs. ",
+ tensorflow_spec.output_tensor_specs_size()));
+ }
+ for (const tensorflow::TensorSpecProto& it :
+ tensorflow_spec.output_tensor_specs()) {
+ if (!expected_output_tensor_names_set.count(it.name())) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Missing expected TensorSpecProto for output ", it.name()));
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome) {
+ switch (plan_outcome) {
+ case PlanOutcome::kSuccess:
+ return PhaseOutcome::COMPLETED;
+ case PlanOutcome::kInterrupted:
+ return PhaseOutcome::INTERRUPTED;
+ case PlanOutcome::kTensorflowError:
+ case PlanOutcome::kInvalidArgument:
+ case PlanOutcome::kExampleIteratorError:
+ return PhaseOutcome::ERROR;
+ }
+}
+
+absl::Status ConvertPlanOutcomeToStatus(PlanOutcome outcome) {
+ switch (outcome) {
+ case PlanOutcome::kSuccess:
+ return absl::OkStatus();
+ case PlanOutcome::kTensorflowError:
+ case PlanOutcome::kInvalidArgument:
+ case PlanOutcome::kExampleIteratorError:
+ return absl::InternalError("");
+ case PlanOutcome::kInterrupted:
+ return absl::CancelledError("");
+ }
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/common.h b/fcp/client/engine/common.h
new file mode 100644
index 0000000..42b4bde
--- /dev/null
+++ b/fcp/client/engine/common.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_COMMON_H_
+#define FCP_CLIENT_ENGINE_COMMON_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+enum class PlanOutcome {
+ kSuccess,
+ // A TensorFlow error occurred.
+ kTensorflowError,
+ // Computation was interrupted.
+ kInterrupted,
+ // The input parameters are invalid.
+ kInvalidArgument,
+ // An example iterator error occurred.
+ kExampleIteratorError,
+};
+
+// The result of a call to `SimplePlanEngine::RunPlan` or
+// `TfLitePlanEngine::RunPlan`.
+struct PlanResult {
+ explicit PlanResult(PlanOutcome outcome, absl::Status status);
+
+ // The outcome of the plan execution.
+ PlanOutcome outcome;
+ // Only set if `outcome` is `kSuccess`, otherwise this is empty.
+ std::vector<tensorflow::Tensor> output_tensors;
+ // Only set if `outcome` is `kSuccess`, otherwise this is empty.
+ std::vector<std::string> output_names;
+ // When the outcome is `kSuccess`, the status is ok. Otherwise, this status
+ // contain the original error status which leads to the PlanOutcome.
+ absl::Status original_status;
+ ::fcp::client::ExampleStats example_stats;
+
+ PlanResult(PlanResult&&) = default;
+ PlanResult& operator=(PlanResult&&) = default;
+
+ // Disallow copy and assign.
+ PlanResult(const PlanResult&) = delete;
+ PlanResult& operator=(const PlanResult&) = delete;
+};
+
+// Validates that the input tensors match what's inside the TensorflowSpec.
+absl::Status ValidateTensorflowSpec(
+ const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
+ const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
+ const std::vector<std::string>& output_names);
+
+PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome);
+
+absl::Status ConvertPlanOutcomeToStatus(engine::PlanOutcome outcome);
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_COMMON_H_
diff --git a/fcp/client/engine/data/BUILD b/fcp/client/engine/data/BUILD
new file mode 100644
index 0000000..f635c18
--- /dev/null
+++ b/fcp/client/engine/data/BUILD
@@ -0,0 +1,24 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "join_model.flatbuffer",
+ "length_model.flatbuffer",
+])
diff --git a/fcp/client/engine/data/README.md b/fcp/client/engine/data/README.md
new file mode 100644
index 0000000..4be24f3
--- /dev/null
+++ b/fcp/client/engine/data/README.md
@@ -0,0 +1,11 @@
+# Simple TfLite models to use in tests.
+
+This directory contains a couple of simple TfLite models for testing purpose.
+
+1. "join_model.flatbuffer": This model takes two input string tensor and
+ returns the concatenated string in an output string tensor.
+2. "length_model.flatbuffer": This model takes one input string tensor and
+ returns the length of the string in an int32 tensor.
+
+These models are generated by using the TfLite converter to convert a TF
+function.
diff --git a/fcp/client/engine/data/join_model.flatbuffer b/fcp/client/engine/data/join_model.flatbuffer
new file mode 100644
index 0000000..66fc6ef
--- /dev/null
+++ b/fcp/client/engine/data/join_model.flatbuffer
Binary files differ
diff --git a/fcp/client/engine/data/length_model.flatbuffer b/fcp/client/engine/data/length_model.flatbuffer
new file mode 100644
index 0000000..a1800e6
--- /dev/null
+++ b/fcp/client/engine/data/length_model.flatbuffer
Binary files differ
diff --git a/fcp/client/engine/engine.proto b/fcp/client/engine/engine.proto
new file mode 100644
index 0000000..fc2e2bb
--- /dev/null
+++ b/fcp/client/engine/engine.proto
@@ -0,0 +1,55 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client.engine;
+
+option java_package = "com.google.intelligence.fcp.client.engine";
+option java_multiple_files = true;
+option java_outer_classname = "TrainingProto";
+
+// A constraint on the task when to run again.
+message TaskRetry {
+ // An opaque context stored between task activations.
+ string retry_token = 1;
+
+ // The suggested minimal duration after which the client should retry again,
+ // in milliseconds.
+ int64 delay_min = 2;
+
+ // The suggested maximal duration before which the client should retry again
+ // (if conditions allow), in milliseconds.
+ int64 delay_max = 3;
+}
+
+enum PhaseOutcome {
+ PHASE_OUTCOME_UNDEFINED = 0;
+ COMPLETED = 1;
+ INTERRUPTED = 2;
+ ERROR = 3;
+}
+
+enum DataSourceType {
+ // Default value for this enum.
+ TRAINING_DATA_SOURCE_UNDEFINED = 0;
+
+ // Feed based execution, examples were batched outside of TensorFlow and fed
+ // into the training session.
+ FEED = 1;
+
+ // Dataset based execution, TensorFlow was given an ExternalDatasetProvider
+ // and used it internally to create iterators and pull examples.
+ DATASET = 2;
+}
diff --git a/fcp/client/engine/example_iterator_factory.h b/fcp/client/engine/example_iterator_factory.h
new file mode 100644
index 0000000..d1d438f
--- /dev/null
+++ b/fcp/client/engine/example_iterator_factory.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_EXAMPLE_ITERATOR_FACTORY_H_
+#define FCP_CLIENT_ENGINE_EXAMPLE_ITERATOR_FACTORY_H_
+
+#include <functional>
+#include <memory>
+
+#include "absl/status/statusor.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// An interface for engine-internal usage, describing a handle with which
+// ExampleIterator instances can be created, based on a given ExampleSelector.
+// Each handle indicates which type of ExampleSelector it is actually able to
+// handle.
+class ExampleIteratorFactory {
+ public:
+ // Whether this factory can create iterators that serve the given
+ // `ExampleSelector`-described query.
+ virtual bool CanHandle(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) = 0;
+
+ // Creates an iterator for the given `ExampleSelector`-described query.
+ virtual absl::StatusOr<std::unique_ptr<ExampleIterator>>
+ CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) = 0;
+
+ // Whether stats should be generated and logged into the OpStats database for
+ // iterators created by this factory.
+ virtual bool ShouldCollectStats() = 0;
+
+ virtual ~ExampleIteratorFactory() {}
+};
+
+// A utility ExampleIteratorFactory implementation that can use simple
+// std::function objects and wrap them.
+class FunctionalExampleIteratorFactory : public ExampleIteratorFactory {
+ public:
+ // Creates an `ExampleIteratorFactory` that can handle all queries and for
+ // which stats are collected, and delegates the creation of the
+ // `ExampleIterator` to an std::function.
+ explicit FunctionalExampleIteratorFactory(
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&
+
+ )>
+ create_iterator_func)
+ : can_handle_func_(
+ [](const google::internal::federated::plan::ExampleSelector&) {
+ return true;
+ }),
+ create_iterator_func_(create_iterator_func),
+ should_collect_stats_(true) {}
+
+ // Creates an `ExampleIteratorFactory` that delegates to an std::function to
+ // determine if a given query can be handled, and delegates the creation of
+ // the `ExampleIterator` to an std::function as well.
+ FunctionalExampleIteratorFactory(
+ std::function<
+ bool(const google::internal::federated::plan::ExampleSelector&)>
+ can_handle_func,
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&
+
+ )>
+ create_iterator_func,
+ bool should_collect_stats)
+ : can_handle_func_(can_handle_func),
+ create_iterator_func_(create_iterator_func),
+ should_collect_stats_(should_collect_stats) {}
+
+ bool CanHandle(const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return can_handle_func_(example_selector);
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return create_iterator_func_(example_selector);
+ }
+
+ bool ShouldCollectStats() override { return should_collect_stats_; }
+
+ private:
+ std::function<bool(const google::internal::federated::plan::ExampleSelector&)>
+ can_handle_func_;
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&)>
+ create_iterator_func_;
+ bool should_collect_stats_;
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_EXAMPLE_ITERATOR_FACTORY_H_
diff --git a/fcp/client/engine/example_query_plan_engine.cc b/fcp/client/engine/example_query_plan_engine.cc
new file mode 100644
index 0000000..0e3de0d
--- /dev/null
+++ b/fcp/client/engine/example_query_plan_engine.cc
@@ -0,0 +1,247 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/engine/example_query_plan_engine.h"
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/example_query_result.pb.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+namespace tf = ::tensorflow;
+
+using ::fcp::client::ExampleQueryResult;
+using ::fcp::client::engine::PlanResult;
+using ::fcp::client::opstats::OpStatsLogger;
+using ::google::internal::federated::plan::ExampleQuerySpec;
+using ::google::internal::federated::plan::ExampleSelector;
+
+namespace {
+
+// Writes an one-dimensional tensor using the slice writer.
+template <typename T>
+absl::Status WriteSlice(tf::checkpoint::TensorSliceWriter& slice_writer,
+ const std::string& name, const int64_t size,
+ const T* data) {
+ tf::TensorShape shape;
+ shape.AddDim(size);
+ tf::TensorSlice slice(shape.dims());
+ tf::Status tf_status = slice_writer.Add(name, shape, slice, data);
+ return ConvertFromTensorFlowStatus(tf_status);
+}
+
+// Returns a map of (vector name) -> tuple(output name, vector spec).
+absl::flat_hash_map<std::string,
+ std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
+GetOutputVectorSpecs(const ExampleQuerySpec::ExampleQuery& example_query) {
+ absl::flat_hash_map<
+ std::string, std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
+ map;
+ for (auto const& [output_name, output_vector_spec] :
+ example_query.output_vector_specs()) {
+ map[output_vector_spec.vector_name()] =
+ std::make_tuple(output_name, output_vector_spec);
+ }
+ return map;
+}
+
+absl::Status CheckOutputVectorDataType(
+ const ExampleQuerySpec::OutputVectorSpec& output_vector_spec,
+ const ExampleQuerySpec::OutputVectorSpec::DataType& expected_data_type) {
+ if (output_vector_spec.data_type() != expected_data_type) {
+ return absl::FailedPreconditionError(
+ "Unexpected data type in the example query");
+ }
+ return absl::OkStatus();
+}
+
+// Writes example query results into a checkpoint. Example query results order
+// must be the same as example_query_spec.example_queries.
+absl::Status WriteCheckpoint(
+ const std::string& output_checkpoint_filename,
+ const std::vector<ExampleQueryResult>& example_query_results,
+ const ExampleQuerySpec& example_query_spec) {
+ tf::checkpoint::TensorSliceWriter slice_writer(
+ output_checkpoint_filename,
+ tf::checkpoint::CreateTableTensorSliceBuilder);
+ for (int i = 0; i < example_query_results.size(); ++i) {
+ const ExampleQueryResult& example_query_result = example_query_results[i];
+ const ExampleQuerySpec::ExampleQuery& example_query =
+ example_query_spec.example_queries()[i];
+ for (auto const& [vector_name, vector_tuple] :
+ GetOutputVectorSpecs(example_query)) {
+ std::string output_name = std::get<0>(vector_tuple);
+ ExampleQuerySpec::OutputVectorSpec output_vector_spec =
+ std::get<1>(vector_tuple);
+ auto it = example_query_result.vector_data().vectors().find(vector_name);
+ if (it == example_query_result.vector_data().vectors().end()) {
+ return absl::DataLossError(
+ "Expected value not found in the example query result");
+ }
+ const ExampleQueryResult::VectorData::Values values = it->second;
+ absl::Status status;
+ if (values.has_int32_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT32));
+ int64_t size = values.int32_values().value_size();
+ auto data =
+ static_cast<const int32_t*>(values.int32_values().value().data());
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
+ } else if (values.has_int64_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT64));
+ int64_t size = values.int64_values().value_size();
+ auto data =
+ static_cast<const int64_t*>(values.int64_values().value().data());
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
+ } else if (values.has_string_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::STRING));
+ int64_t size = values.string_values().value_size();
+ std::vector<tf::tstring> tf_string_vector;
+ for (const auto& value : values.string_values().value()) {
+ tf_string_vector.emplace_back(value);
+ }
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
+ tf_string_vector.data()));
+ } else if (values.has_bool_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BOOL));
+ int64_t size = values.bool_values().value_size();
+ auto data =
+ static_cast<const bool*>(values.bool_values().value().data());
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
+ } else if (values.has_float_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::FLOAT));
+ int64_t size = values.float_values().value_size();
+ auto data =
+ static_cast<const float*>(values.float_values().value().data());
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
+ } else if (values.has_double_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::DOUBLE));
+ int64_t size = values.double_values().value_size();
+ auto data =
+ static_cast<const double*>(values.double_values().value().data());
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
+ } else if (values.has_bytes_values()) {
+ FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
+ output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BYTES));
+ int64_t size = values.bytes_values().value_size();
+ std::vector<tf::tstring> tf_string_vector;
+ for (const auto& value : values.string_values().value()) {
+ tf_string_vector.emplace_back(value);
+ }
+ FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
+ tf_string_vector.data()));
+ } else {
+ return absl::DataLossError(
+ "Unexpected data type in the example query result");
+ }
+ }
+ }
+ return ConvertFromTensorFlowStatus(slice_writer.Finish());
+}
+
+} // anonymous namespace
+
+ExampleQueryPlanEngine::ExampleQueryPlanEngine(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ OpStatsLogger* opstats_logger)
+ : example_iterator_factories_(example_iterator_factories),
+ opstats_logger_(opstats_logger) {}
+
+PlanResult ExampleQueryPlanEngine::RunPlan(
+ const ExampleQuerySpec& example_query_spec,
+ const std::string& output_checkpoint_filename) {
+ // TODO(team): Add the same logging as in simple_plan_engine.
+ std::vector<ExampleQueryResult> example_query_results;
+
+ for (const auto& example_query : example_query_spec.example_queries()) {
+ ExampleSelector selector = example_query.example_selector();
+ ExampleIteratorFactory* example_iterator_factory =
+ FindExampleIteratorFactory(selector, example_iterator_factories_);
+ if (example_iterator_factory == nullptr) {
+ return PlanResult(PlanOutcome::kExampleIteratorError,
+ absl::InternalError(
+ "Could not find suitable ExampleIteratorFactory"));
+ }
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
+ example_iterator_factory->CreateExampleIterator(selector);
+ if (!example_iterator.ok()) {
+ return PlanResult(PlanOutcome::kExampleIteratorError,
+ example_iterator.status());
+ }
+
+ std::atomic<int> total_example_count = 0;
+ std::atomic<int64_t> total_example_size_bytes = 0;
+ ExampleIteratorStatus example_iterator_status;
+
+ auto dataset_iterator = std::make_unique<DatasetIterator>(
+ std::move(*example_iterator), opstats_logger_, &total_example_count,
+ &total_example_size_bytes, &example_iterator_status,
+ selector.collection_uri(),
+ /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
+
+ absl::StatusOr<std::string> example_query_result_str =
+ dataset_iterator->GetNext();
+ if (!example_query_result_str.ok()) {
+ return PlanResult(PlanOutcome::kExampleIteratorError,
+ example_query_result_str.status());
+ }
+
+ ExampleQueryResult example_query_result;
+ if (!example_query_result.ParseFromString(*example_query_result_str)) {
+ return PlanResult(
+ PlanOutcome::kExampleIteratorError,
+ absl::DataLossError("Unexpected example query result format"));
+ }
+ example_query_results.push_back(std::move(example_query_result));
+ }
+ absl::Status status = WriteCheckpoint(
+ output_checkpoint_filename, example_query_results, example_query_spec);
+ if (!status.ok()) {
+ return PlanResult(PlanOutcome::kExampleIteratorError, status);
+ }
+ return PlanResult(PlanOutcome::kSuccess, absl::OkStatus());
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/example_query_plan_engine.h b/fcp/client/engine/example_query_plan_engine.h
new file mode 100644
index 0000000..2b11342
--- /dev/null
+++ b/fcp/client/engine/example_query_plan_engine.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_EXAMPLE_QUERY_PLAN_ENGINE_H_
+#define FCP_CLIENT_ENGINE_EXAMPLE_QUERY_PLAN_ENGINE_H_
+
+#include <string>
+#include <vector>
+
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/opstats/opstats_logger.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// A class used to "run" (interpret) an ExampleQuerySpec-based plan. Each
+// instance should generally only be used once to run a plan.
+class ExampleQueryPlanEngine {
+ public:
+ ExampleQueryPlanEngine(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger);
+
+ // Runs a plan and writes an output into a checkpoint at the given path.
+ ::fcp::client::engine::PlanResult RunPlan(
+ const google::internal::federated::plan::ExampleQuerySpec&
+ example_query_spec,
+ const std::string& output_checkpoint_filename);
+
+ private:
+ std::vector<ExampleIteratorFactory*> example_iterator_factories_;
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger_;
+};
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_EXAMPLE_QUERY_PLAN_ENGINE_H_
diff --git a/fcp/client/engine/example_query_plan_engine_test.cc b/fcp/client/engine/example_query_plan_engine_test.cc
new file mode 100644
index 0000000..dbad82d
--- /dev/null
+++ b/fcp/client/engine/example_query_plan_engine_test.cc
@@ -0,0 +1,547 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/example_query_plan_engine.h"
+
+#include <fcntl.h>
+
+#include <cstdint>
+#include <filesystem>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "fcp/client/client_runner.h"
+#include "fcp/client/engine/common.h"
+#include "fcp/client/example_query_result.pb.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/testing/testing.h"
+#include "tensorflow/c/checkpoint_reader.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+
+namespace tf = ::tensorflow;
+
+using ::fcp::client::ExampleQueryResult;
+using ::google::internal::federated::plan::AggregationConfig;
+using ::google::internal::federated::plan::ClientOnlyPlan;
+using ::google::internal::federated::plan::Dataset;
+using ::google::internal::federated::plan::ExampleQuerySpec;
+using ::google::internal::federated::plan::ExampleSelector;
+using ::testing::StrictMock;
+
+const char* const kCollectionUri = "app:/test_collection";
+const char* const kOutputStringVectorName = "vector1";
+const char* const kOutputIntVectorName = "vector2";
+const char* const kOutputStringTensorName = "tensor1";
+const char* const kOutputIntTensorName = "tensor2";
+
+class InvalidExampleIteratorFactory : public ExampleIteratorFactory {
+ public:
+ InvalidExampleIteratorFactory() = default;
+
+ bool CanHandle(const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return false;
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const ExampleSelector& example_selector) override {
+ absl::Status error(absl::StatusCode::kInternal, "");
+ return error;
+ }
+
+ bool ShouldCollectStats() override { return false; }
+};
+
+class NoIteratorExampleIteratorFactory : public ExampleIteratorFactory {
+ public:
+ NoIteratorExampleIteratorFactory() = default;
+
+ bool CanHandle(const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return true;
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const ExampleSelector& example_selector) override {
+ absl::Status error(absl::StatusCode::kInternal, "");
+ return error;
+ }
+
+ bool ShouldCollectStats() override { return false; }
+};
+
+class TwoExampleIteratorsFactory : public ExampleIteratorFactory {
+ public:
+ explicit TwoExampleIteratorsFactory(
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&
+
+ )>
+ create_first_iterator_func,
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&
+
+ )>
+ create_second_iterator_func,
+ const std::string& first_collection_uri,
+ const std::string& second_collection_uri)
+ : create_first_iterator_func_(create_first_iterator_func),
+ create_second_iterator_func_(create_second_iterator_func),
+ first_collection_uri_(first_collection_uri),
+ second_collection_uri_(second_collection_uri) {}
+
+ bool CanHandle(const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return true;
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ if (example_selector.collection_uri() == first_collection_uri_) {
+ return create_first_iterator_func_(example_selector);
+ } else if (example_selector.collection_uri() == second_collection_uri_) {
+ return create_second_iterator_func_(example_selector);
+ }
+ return absl::InvalidArgumentError("Unknown collection URI");
+ }
+
+ bool ShouldCollectStats() override { return false; }
+
+ private:
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&)>
+ create_first_iterator_func_;
+ std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
+ const google::internal::federated::plan::ExampleSelector&)>
+ create_second_iterator_func_;
+ std::string first_collection_uri_;
+ std::string second_collection_uri_;
+};
+
+absl::StatusOr<absl::flat_hash_map<std::string, tf::Tensor>> ReadTensors(
+ std::string checkpoint_path) {
+ absl::flat_hash_map<std::string, tf::Tensor> tensors;
+ tf::TF_StatusPtr tf_status(TF_NewStatus());
+ tf::checkpoint::CheckpointReader tf_checkpoint_reader(checkpoint_path,
+ tf_status.get());
+ if (TF_GetCode(tf_status.get()) != TF_OK) {
+ return absl::NotFoundError("Couldn't read an input checkpoint");
+ }
+ for (const auto& [name, tf_dtype] :
+ tf_checkpoint_reader.GetVariableToDataTypeMap()) {
+ std::unique_ptr<tf::Tensor> tensor;
+ tf_checkpoint_reader.GetTensor(name, &tensor, tf_status.get());
+ if (TF_GetCode(tf_status.get()) != TF_OK) {
+ return absl::NotFoundError(
+ absl::StrFormat("Checkpoint doesn't have tensor %s", name));
+ }
+ tensors[name] = *tensor;
+ }
+
+ return tensors;
+}
+
+class ExampleQueryPlanEngineTest : public testing::Test {
+ protected:
+ void Initialize() {
+ std::filesystem::path root_dir(testing::TempDir());
+ std::filesystem::path output_path = root_dir / std::string("output.ckpt");
+ output_checkpoint_filename_ = output_path.string();
+
+ ExampleQuerySpec::OutputVectorSpec string_vector_spec;
+ string_vector_spec.set_vector_name(kOutputStringVectorName);
+ string_vector_spec.set_data_type(
+ ExampleQuerySpec::OutputVectorSpec::STRING);
+ ExampleQuerySpec::OutputVectorSpec int_vector_spec;
+ int_vector_spec.set_vector_name(kOutputIntVectorName);
+ int_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::INT64);
+
+ ExampleQuerySpec::ExampleQuery example_query;
+ example_query.mutable_example_selector()->set_collection_uri(
+ kCollectionUri);
+ (*example_query.mutable_output_vector_specs())[kOutputStringTensorName] =
+ string_vector_spec;
+ (*example_query.mutable_output_vector_specs())[kOutputIntTensorName] =
+ int_vector_spec;
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->mutable_example_queries()
+ ->Add(std::move(example_query));
+
+ AggregationConfig aggregation_config;
+ aggregation_config.mutable_tf_v1_checkpoint_aggregation();
+ (*client_only_plan_.mutable_phase()
+ ->mutable_federated_example_query()
+ ->mutable_aggregations())[kOutputStringTensorName] =
+ aggregation_config;
+ (*client_only_plan_.mutable_phase()
+ ->mutable_federated_example_query()
+ ->mutable_aggregations())[kOutputIntTensorName] = aggregation_config;
+
+ ExampleQueryResult::VectorData::Values int_values;
+ int_values.mutable_int64_values()->add_value(42);
+ int_values.mutable_int64_values()->add_value(24);
+ (*example_query_result_.mutable_vector_data()
+ ->mutable_vectors())[kOutputIntVectorName] = int_values;
+ ExampleQueryResult::VectorData::Values string_values;
+ string_values.mutable_string_values()->add_value("value1");
+ string_values.mutable_string_values()->add_value("value2");
+ (*example_query_result_.mutable_vector_data()
+ ->mutable_vectors())[kOutputStringVectorName] = string_values;
+ std::string example = example_query_result_.SerializeAsString();
+
+ Dataset::ClientDataset client_dataset;
+ client_dataset.set_client_id("client_id");
+ client_dataset.add_example(example);
+ dataset_.mutable_client_data()->Add(std::move(client_dataset));
+
+ num_examples_ = 1;
+ example_bytes_ = example.size();
+
+ example_iterator_factory_ =
+ std::make_unique<FunctionalExampleIteratorFactory>(
+ [&dataset = dataset_](
+ const google::internal::federated::plan::ExampleSelector&
+ selector) {
+ return std::make_unique<SimpleExampleIterator>(dataset);
+ });
+ }
+
+ fcp::client::FilesImpl files_impl_;
+ StrictMock<MockOpStatsLogger> mock_opstats_logger_;
+ std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_;
+
+ ExampleQueryResult example_query_result_;
+ ClientOnlyPlan client_only_plan_;
+ Dataset dataset_;
+ std::string output_checkpoint_filename_;
+
+ int num_examples_ = 0;
+ int64_t example_bytes_ = 0;
+};
+
+TEST_F(ExampleQueryPlanEngineTest, PlanSucceeds) {
+ Initialize();
+
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));
+
+ ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);
+
+ auto tensors = ReadTensors(output_checkpoint_filename_);
+ ASSERT_OK(tensors);
+ tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName];
+ ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2}));
+ ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64);
+ auto int_data = static_cast<int64_t*>(int_tensor.data());
+ std::vector<int64_t> expected_int_data({42, 24});
+ for (int i = 0; i < 2; ++i) {
+ ASSERT_EQ(int_data[i], expected_int_data[i]);
+ }
+
+ tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName];
+ ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2}));
+ ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING);
+ auto string_data = static_cast<tf::tstring*>(string_tensor.data());
+ std::vector<std::string> expected_string_data({"value1", "value2"});
+ for (int i = 0; i < 2; ++i) {
+ ASSERT_EQ(static_cast<std::string>(string_data[i]),
+ expected_string_data[i]);
+ }
+}
+
+TEST_F(ExampleQueryPlanEngineTest, MultipleQueries) {
+ Initialize();
+
+ ExampleQuerySpec::OutputVectorSpec float_vector_spec;
+ float_vector_spec.set_vector_name("float_vector");
+ float_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::FLOAT);
+ ExampleQuerySpec::OutputVectorSpec string_vector_spec;
+ // Same vector name as in the other ExampleQuery, but with a different output
+ // one to make sure these vectors are distinguished in
+ // example_query_plan_engine.
+ string_vector_spec.set_vector_name(kOutputStringVectorName);
+ string_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::STRING);
+
+ ExampleQuerySpec::ExampleQuery second_example_query;
+ second_example_query.mutable_example_selector()->set_collection_uri(
+ "app:/second_collection");
+ (*second_example_query.mutable_output_vector_specs())["float_tensor"] =
+ float_vector_spec;
+ (*second_example_query
+ .mutable_output_vector_specs())["another_string_tensor"] =
+ string_vector_spec;
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->mutable_example_queries()
+ ->Add(std::move(second_example_query));
+
+ AggregationConfig aggregation_config;
+ aggregation_config.mutable_tf_v1_checkpoint_aggregation();
+ (*client_only_plan_.mutable_phase()
+ ->mutable_federated_example_query()
+ ->mutable_aggregations())["float_tensor"] = aggregation_config;
+
+ ExampleQueryResult second_example_query_result;
+ ExampleQueryResult::VectorData::Values float_values;
+ float_values.mutable_float_values()->add_value(0.24f);
+ float_values.mutable_float_values()->add_value(0.42f);
+ float_values.mutable_float_values()->add_value(0.33f);
+ ExampleQueryResult::VectorData::Values string_values;
+ string_values.mutable_string_values()->add_value("another_string_value");
+ (*second_example_query_result.mutable_vector_data()
+ ->mutable_vectors())["float_vector"] = float_values;
+ (*second_example_query_result.mutable_vector_data()
+ ->mutable_vectors())[kOutputStringVectorName] = string_values;
+ std::string example = second_example_query_result.SerializeAsString();
+
+ Dataset::ClientDataset dataset;
+ dataset.set_client_id("second_client_id");
+ dataset.add_example(example);
+ Dataset second_dataset;
+ second_dataset.mutable_client_data()->Add(std::move(dataset));
+
+ example_iterator_factory_ = std::make_unique<TwoExampleIteratorsFactory>(
+ [&dataset = dataset_](
+ const google::internal::federated::plan::ExampleSelector& selector) {
+ return std::make_unique<SimpleExampleIterator>(dataset);
+ },
+ [&dataset = second_dataset](
+ const google::internal::federated::plan::ExampleSelector& selector) {
+ return std::make_unique<SimpleExampleIterator>(dataset);
+ },
+ kCollectionUri, "app:/second_collection");
+
+ ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);
+
+ auto tensors = ReadTensors(output_checkpoint_filename_);
+ ASSERT_OK(tensors);
+ tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName];
+ ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2}));
+ ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64);
+ auto int_data = static_cast<int64_t*>(int_tensor.data());
+ std::vector<int64_t> expected_int_data({42, 24});
+ for (int i = 0; i < 2; ++i) {
+ ASSERT_EQ(int_data[i], expected_int_data[i]);
+ }
+
+ tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName];
+ ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2}));
+ ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING);
+ auto string_data = static_cast<tf::tstring*>(string_tensor.data());
+ std::vector<std::string> expected_string_data({"value1", "value2"});
+ for (int i = 0; i < 2; ++i) {
+ ASSERT_EQ(static_cast<std::string>(string_data[i]),
+ expected_string_data[i]);
+ }
+
+ tf::Tensor float_tensor = tensors.value()["float_tensor"];
+ ASSERT_EQ(float_tensor.shape(), tf::TensorShape({3}));
+ ASSERT_EQ(float_tensor.dtype(), tf::DT_FLOAT);
+ auto float_data = static_cast<float*>(float_tensor.data());
+ std::vector<float> expected_float_data({0.24f, 0.42f, 0.33f});
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_EQ(float_data[i], expected_float_data[i]);
+ }
+
+ tf::Tensor second_query_string_tensor =
+ tensors.value()["another_string_tensor"];
+ ASSERT_EQ(second_query_string_tensor.shape(), tf::TensorShape({1}));
+ ASSERT_EQ(second_query_string_tensor.dtype(), tf::DT_STRING);
+ auto second_query_string_data =
+ static_cast<tf::tstring*>(second_query_string_tensor.data());
+ ASSERT_EQ(static_cast<std::string>(*second_query_string_data),
+ "another_string_value");
+}
+
+TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecMissingInResult) {
+ Initialize();
+
+ ExampleQuerySpec::OutputVectorSpec new_vector_spec;
+ new_vector_spec.set_vector_name("new_vector");
+ new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE);
+
+ ExampleQuerySpec::ExampleQuery example_query =
+ client_only_plan_.phase().example_query_spec().example_queries().at(0);
+ (*example_query.mutable_output_vector_specs())["new_tensor"] =
+ new_vector_spec;
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->clear_example_queries();
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->mutable_example_queries()
+ ->Add(std::move(example_query));
+
+ ExampleQueryResult example_query_result;
+ ExampleQueryResult::VectorData::Values bool_values;
+ bool_values.mutable_bool_values()->add_value(true);
+ (*example_query_result_.mutable_vector_data()
+ ->mutable_vectors())["new_vector"] = bool_values;
+ std::string example = example_query_result_.SerializeAsString();
+
+ Dataset::ClientDataset client_dataset;
+ client_dataset.set_client_id("client_id");
+ client_dataset.add_example(example);
+ dataset_.clear_client_data();
+ dataset_.mutable_client_data()->Add(std::move(client_dataset));
+
+ num_examples_ = 1;
+ example_bytes_ = example.size();
+
+ example_iterator_factory_ =
+ std::make_unique<FunctionalExampleIteratorFactory>(
+ [&dataset = dataset_](
+ const google::internal::federated::plan::ExampleSelector&
+ selector) {
+ return std::make_unique<SimpleExampleIterator>(dataset);
+ });
+
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));
+
+ ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
+}
+
+TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecTypeMismatch) {
+ Initialize();
+
+ ExampleQuerySpec::OutputVectorSpec new_vector_spec;
+ new_vector_spec.set_vector_name("new_vector");
+ new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE);
+
+ ExampleQuerySpec::ExampleQuery example_query =
+ client_only_plan_.phase().example_query_spec().example_queries().at(0);
+ (*example_query.mutable_output_vector_specs())["new_tensor"] =
+ new_vector_spec;
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->clear_example_queries();
+ client_only_plan_.mutable_phase()
+ ->mutable_example_query_spec()
+ ->mutable_example_queries()
+ ->Add(std::move(example_query));
+
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));
+
+ ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
+}
+
+TEST_F(ExampleQueryPlanEngineTest, FactoryNotFound) {
+ Initialize();
+ auto invalid_example_factory =
+ std::make_unique<InvalidExampleIteratorFactory>();
+
+ ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
+}
+
+TEST_F(ExampleQueryPlanEngineTest, NoIteratorCreated) {
+ Initialize();
+ auto invalid_example_factory =
+ std::make_unique<NoIteratorExampleIteratorFactory>();
+
+ ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
+}
+
+TEST_F(ExampleQueryPlanEngineTest, InvalidExampleQueryResultFormat) {
+ Initialize();
+ std::string invalid_example = "invalid_example";
+ Dataset::ClientDataset client_dataset;
+ client_dataset.add_example(invalid_example);
+ dataset_.clear_client_data();
+ dataset_.mutable_client_data()->Add(std::move(client_dataset));
+ example_iterator_factory_ =
+ std::make_unique<FunctionalExampleIteratorFactory>(
+ [&dataset = dataset_](
+ const google::internal::federated::plan::ExampleSelector&
+ selector) {
+ return std::make_unique<SimpleExampleIterator>(dataset);
+ });
+ EXPECT_CALL(mock_opstats_logger_,
+ UpdateDatasetStats(kCollectionUri, 1, invalid_example.size()));
+
+ ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
+ &mock_opstats_logger_);
+ engine::PlanResult result =
+ plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
+ output_checkpoint_filename_);
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
+}
+
+} // anonymous namespace
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/plan_engine_helpers.cc b/fcp/client/engine/plan_engine_helpers.cc
new file mode 100644
index 0000000..d5a322e
--- /dev/null
+++ b/fcp/client/engine/plan_engine_helpers.cc
@@ -0,0 +1,285 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/plan_engine_helpers.h"
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/opstats/opstats_logger_impl.h"
+// #include "fcp/client/opstats/pds_backed_opstats_db.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/tensorflow/external_dataset.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+
+using ::fcp::client::opstats::OpStatsLogger;
+using ::fcp::client::opstats::OpStatsLoggerImpl;
+// using ::fcp::client::opstats::PdsBackedOpStatsDb;
+using ::google::internal::federated::plan::ExampleSelector;
+
+/** An iterator that forwards the failing status from the external dataset to
+ * TensorFlow. */
+class FailingDatasetIterator : public ExternalDatasetIterator {
+ public:
+ explicit FailingDatasetIterator(absl::Status status) : status_(status) {}
+
+ absl::StatusOr<std::string> GetNext() final { return status_; }
+
+ private:
+ const absl::Status status_;
+};
+
+class TrainingDatasetProvider
+ : public ExternalDatasetProvider::UsingProtoSelector<ExampleSelector> {
+ public:
+ TrainingDatasetProvider(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ OpStatsLogger* opstats_logger, std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status)
+ : example_iterator_factories_(example_iterator_factories),
+ opstats_logger_(opstats_logger),
+ total_example_count_(total_example_count),
+ total_example_size_bytes_(total_example_size_bytes),
+ example_iterator_status_(example_iterator_status) {}
+
+ absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ ExampleSelector selector) final {
+ return ExternalDataset::FromFunction(
+ [example_iterator_factories = example_iterator_factories_,
+ opstats_logger = opstats_logger_, selector,
+ total_example_count = total_example_count_,
+ total_example_size_bytes = total_example_size_bytes_,
+ example_iterator_status = example_iterator_status_]()
+ -> std::unique_ptr<ExternalDatasetIterator> {
+ ExampleIteratorFactory* example_iterator_factory =
+ FindExampleIteratorFactory(selector, example_iterator_factories);
+ // The DatasetOp requires a valid iterator at this stage so return an
+ // empty iterator if there was an error.
+ if (example_iterator_factory == nullptr) {
+ absl::Status error(
+ absl::StatusCode::kInternal,
+ "Could not find suitable ExampleIteratorFactory");
+ example_iterator_status->SetStatus(error);
+ return std::make_unique<FailingDatasetIterator>(error);
+ }
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
+ example_iterator_factory->CreateExampleIterator(selector);
+ if (!example_iterator.ok()) {
+ example_iterator_status->SetStatus(example_iterator.status());
+ return std::make_unique<FailingDatasetIterator>(
+ example_iterator.status());
+ }
+ return std::make_unique<DatasetIterator>(
+ std::move(*example_iterator), opstats_logger, total_example_count,
+ total_example_size_bytes, example_iterator_status,
+ selector.collection_uri(),
+ /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
+ });
+ }
+
+ private:
+ std::vector<ExampleIteratorFactory*> example_iterator_factories_;
+ OpStatsLogger* opstats_logger_;
+ std::atomic<int>* total_example_count_;
+ std::atomic<int64_t>* total_example_size_bytes_;
+ ExampleIteratorStatus* example_iterator_status_;
+};
+
+} // namespace
+
+DatasetIterator::DatasetIterator(
+ std::unique_ptr<ExampleIterator> example_iterator,
+ opstats::OpStatsLogger* opstats_logger,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status,
+ const std::string& collection_uri, bool collect_stats)
+ : example_iterator_(std::move(example_iterator)),
+ opstats_logger_(opstats_logger),
+ iterator_start_time_(absl::Now()),
+ total_example_count_(total_example_count),
+ total_example_size_bytes_(total_example_size_bytes),
+ example_iterator_status_(example_iterator_status),
+ example_count_(0),
+ example_size_bytes_(0),
+ collection_uri_(collection_uri),
+ iterator_finished_(false),
+ collect_stats_(collect_stats) {}
+
+DatasetIterator::~DatasetIterator() {
+ if (collect_stats_) {
+ opstats_logger_->UpdateDatasetStats(collection_uri_, example_count_,
+ example_size_bytes_);
+ }
+}
+
+// Returns the next entry from the dataset.
+absl::StatusOr<std::string> DatasetIterator::GetNext() {
+ absl::MutexLock locked(&iterator_lock_);
+ if (iterator_finished_) {
+ // If we've reached the end of the iterator, always return OUT_OF_RANGE.
+ return absl::OutOfRangeError("End of iterator reached");
+ }
+ absl::StatusOr<std::string> example = example_iterator_->Next();
+ absl::StatusCode error_code = example.status().code();
+ example_iterator_status_->SetStatus(example.status());
+ if (error_code == absl::StatusCode::kOutOfRange) {
+ example_iterator_->Close();
+ iterator_finished_ = true;
+ }
+ // If we're not forwarding an OUT_OF_RANGE to the caller, record example
+ // stats for metrics logging.
+ if (collect_stats_ && example.ok()) {
+ // TODO(team): Consider reducing logic duplication in
+ // cross-dataset and single-dataset example stat variables.
+ *total_example_count_ += 1;
+ *total_example_size_bytes_ += example->size();
+ example_count_ += 1;
+ example_size_bytes_ += example->size();
+ }
+ return example;
+}
+
+void ExampleIteratorStatus::SetStatus(absl::Status status) {
+ absl::MutexLock lock(&mu_);
+ // We ignores normal status such as ok and outOfRange to avoid running into a
+ // race condition when an error happened, then an outofRange or ok status
+ // returned in a different thread which overrides the error status.
+ if (status.code() != absl::StatusCode::kOk &&
+ status.code() != absl::StatusCode::kOutOfRange) {
+ status_ = status;
+ }
+}
+
+absl::Status ExampleIteratorStatus::GetStatus() {
+ absl::MutexLock lock(&mu_);
+ return status_;
+}
+
+HostObjectRegistration AddDatasetTokenToInputs(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ OpStatsLogger* opstats_logger,
+ std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
+ const std::string& dataset_token_tensor_name,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status) {
+ // Register the TrainingDatasetProvider with the global
+ // ExternalDatasetProviderRegistry.
+ auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
+ std::make_shared<TrainingDatasetProvider>(
+ example_iterator_factories, opstats_logger, total_example_count,
+ total_example_size_bytes, example_iterator_status));
+ // Pack the token returned from registering the provider into a string
+ // tensor. TensorFlow will use that token via the ExternalDatasetOp to create
+ // datasets and iterators.
+ tensorflow::Tensor token_scalar(std::string{});
+ token_scalar.scalar<tensorflow::tstring>()() =
+ host_registration.token().ToString();
+ std::pair<std::string, tensorflow::Tensor> token_pair(
+ dataset_token_tensor_name, token_scalar);
+ inputs->emplace_back(token_pair);
+ return host_registration;
+}
+
+HostObjectRegistration AddDatasetTokenToInputsForTfLite(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ OpStatsLogger* opstats_logger,
+ absl::flat_hash_map<std::string, std::string>* inputs,
+ const std::string& dataset_token_tensor_name,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status) {
+ // Registers the TrainingDatasetProvider with the global
+ // ExternalDatasetProviderRegistry.
+ auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
+ std::make_shared<TrainingDatasetProvider>(
+ example_iterator_factories, opstats_logger, total_example_count,
+ total_example_size_bytes, example_iterator_status));
+ // Adds the token returned from registering the provider to the map of inputs.
+ // TfLite will use that token via the ExternalDatasetOp to create
+ // datasets and iterators.
+ (*inputs)[dataset_token_tensor_name] = host_registration.token().ToString();
+ return host_registration;
+}
+
+std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
+ const std::string& base_dir, const Flags* flags, LogManager* log_manager,
+ const std::string& session_name, const std::string& population_name) {
+ // if (flags->enable_opstats()) {
+ // auto db_or = PdsBackedOpStatsDb::Create(
+ // base_dir, flags->opstats_ttl_days() * absl::Hours(24), *log_manager,
+ // flags->opstats_db_size_limit_bytes());
+ // if (db_or.ok()) {
+ // return std::make_unique<OpStatsLoggerImpl>(
+ // std::move(db_or).value(), log_manager, flags, session_name,
+ // population_name);
+ // } else {
+ // if (flags->log_opstats_initialization_errors()) {
+ // return std::make_unique<OpStatsLogger>(
+ // /*opstats_enabled=*/flags->enable_opstats(),
+ // /*init_status=*/db_or.status());
+ // }
+ // }
+ // }
+ return std::make_unique<OpStatsLogger>(
+ /*opstats_enabled=*/flags->enable_opstats());
+}
+
+PlanResult CreateComputationErrorPlanResult(
+ absl::Status example_iterator_status,
+ absl::Status computation_error_status) {
+ switch (example_iterator_status.code()) {
+ case absl::StatusCode::kOk:
+ case absl::StatusCode::kOutOfRange:
+ // Either example iterators are working fine or we don't know the status
+ // of the example iterators. In this case, we'll use the error status
+ // returned from TensorFlow.
+ return PlanResult(PlanOutcome::kTensorflowError,
+ computation_error_status);
+ case absl::StatusCode::kCancelled:
+ // Example iterator got interrupted.
+ return PlanResult(PlanOutcome::kInterrupted, example_iterator_status);
+ default:
+ // All other Example iterator errors.
+ return PlanResult(PlanOutcome::kExampleIteratorError,
+ example_iterator_status);
+ }
+}
+
+ExampleIteratorFactory* FindExampleIteratorFactory(
+ const ExampleSelector& selector,
+ std::vector<ExampleIteratorFactory*> example_iterator_factories) {
+ for (ExampleIteratorFactory* factory : example_iterator_factories) {
+ if (factory->CanHandle(selector)) {
+ return factory;
+ }
+ }
+ return nullptr;
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/plan_engine_helpers.h b/fcp/client/engine/plan_engine_helpers.h
new file mode 100644
index 0000000..823a074
--- /dev/null
+++ b/fcp/client/engine/plan_engine_helpers.h
@@ -0,0 +1,190 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
+#define FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/tensorflow/external_dataset.h"
+#include "fcp/tensorflow/host_object.h"
+#include "tensorflow/core/framework/tensor.h"
+
+// On Error Handling
+// Calls in the engine are assumed to either
+// 1. be successful (Status::OK)
+// 2. fail with an "expected" error -> handle gracefully - log error, tell the
+// environment (via finish), return
+// 3. encounter "unexpected" errors; when originating inside the engine or in
+// native code in the environment, or from java, crash.
+// While this type of tristate error handling is easy in Java (success, checked,
+// unchecked exceptions), it isn't in C++, hence we adopt the following
+// convention for control flow/error handling inside the engine:
+// - all functions in the plan engine downstream of runPhase() that can fail
+// must return a Status with one of the following codes: INTERNAL_ERROR,
+// CANCELLED, INVALID_ARGUMENT, OK. Only on OK will normal execution continue,
+// otherwise return up to the top level (runPhase). Once at the top level,
+// those error codes will be handled as follows:
+// a) CANCELLED -> report INTERRUPTED to env
+// b) INTERNAL_ERROR/INVALID_ARGUMENT -> report ERROR to env
+// c) OK -> report COMPLETED to env
+// For all status codes, the TaskRetry returned from the env is returned.
+// - utility functions outside of the engine will also use Status/StatusOr, but
+// may use other error codes (e.g. the TensorFlowWrapper or ExampleIterator
+// use OUT_OF_RANGE).
+// Return error handling is beautiful, I use this macro:
+// #1: FCP_ENGINE_RETURN_IF_ERROR(...): Return if the Status code is not OK,
+// else continue.
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace internal {
+inline absl::Status AsStatus(absl::Status status) { return status; }
+} // namespace internal
+
+// Macro to return the provided Status (or Status contained in StatusOr) if a
+// call to ok() fails.
+#define FCP_ENGINE_RETURN_IF_ERROR(status_or_statusor_expr) \
+ do { \
+ const absl::Status __status = \
+ ::fcp::client::engine::internal::AsStatus(status_or_statusor_expr); \
+ if (ABSL_PREDICT_FALSE(__status.code() != absl::StatusCode::kOk)) { \
+ return __status; \
+ } \
+ } while (0)
+
+// Tracks whether any example iterator encountered an error during the
+// computation (a single computation may use multiple iterators), either during
+// creation of the iterator or during one of the iterations.
+// This class is thread-safe.
+class ExampleIteratorStatus {
+ public:
+ void SetStatus(absl::Status status) ABSL_LOCKS_EXCLUDED(mu_);
+ absl::Status GetStatus() ABSL_LOCKS_EXCLUDED(mu_);
+
+ private:
+ absl::Status status_ ABSL_GUARDED_BY(mu_) = absl::OkStatus();
+ mutable absl::Mutex mu_;
+};
+
+// A class to iterate over a given example iterator.
+class DatasetIterator : public ExternalDatasetIterator {
+ public:
+ DatasetIterator(std::unique_ptr<ExampleIterator> example_iterator,
+ opstats::OpStatsLogger* opstats_logger,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status,
+ const std::string& collection_uri, bool collect_stats);
+ ~DatasetIterator() override;
+
+ // Returns the next entry from the dataset.
+ absl::StatusOr<std::string> GetNext() final;
+
+ private:
+ std::unique_ptr<ExampleIterator> example_iterator_
+ ABSL_GUARDED_BY(iterator_lock_);
+ opstats::OpStatsLogger* opstats_logger_;
+ absl::Time iterator_start_time_;
+ // Example stats across all datasets.
+ std::atomic<int>* total_example_count_;
+ std::atomic<int64_t>* total_example_size_bytes_;
+ ExampleIteratorStatus* example_iterator_status_;
+ // Example stats only for this dataset.
+ std::atomic<int> example_count_;
+ std::atomic<int64_t> example_size_bytes_;
+ const std::string collection_uri_;
+ bool iterator_finished_ ABSL_GUARDED_BY(iterator_lock_);
+ const bool collect_stats_;
+ absl::Mutex iterator_lock_;
+};
+
+// Sets up a ExternalDatasetProvider that is registered with the global
+// HostObjectRegistry. Adds a tensor representing the HostObjectRegistration
+// token to the input tensors with the provided dataset_token_tensor_name key.
+//
+// For each example query issued by the plan at runtime, the given
+// `example_iterator_factories` parameter will be iterated and the first
+// iterator factory that can handle the given query will be used to create the
+// example iterator to handle that query.
+HostObjectRegistration AddDatasetTokenToInputs(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger,
+ std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
+ const std::string& dataset_token_tensor_name,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status);
+
+// Sets up an ExternalDatasetProvider that is registered with the global
+// HostObjectRegistry. Adds a string representing the HostObjectRegistration
+// token to the map of input tensor name and values with the provided
+// dataset_token_tensor_name key.
+//
+// For each example query issued by the plan at runtime, the given
+// `example_iterator_factories` parameter will be iterated and the first
+// iterator factory that can handle the given query will be used to create the
+// example iterator to handle that query.
+HostObjectRegistration AddDatasetTokenToInputsForTfLite(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger,
+ absl::flat_hash_map<std::string, std::string>* inputs,
+ const std::string& dataset_token_tensor_name,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status);
+
+// If opstats is enabled, this method attempts to create an opstats logger
+// backed by a database within base_dir and prepares to record information for a
+// training run with the provided session and population names. If there is an
+// error initializing the db or opstats is disabled, creates a no-op logger.
+std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
+ const std::string& base_dir, const Flags* flags, LogManager* log_manager,
+ const std::string& session_name, const std::string& population_name);
+
+// Utility for creating a PlanResult when an `INVALID_ARGUMENT` TensorFlow error
+// was encountered, disambiguating between generic TF errors and TF errors that
+// were likely root-caused by an earlier example iterator error.
+PlanResult CreateComputationErrorPlanResult(
+ absl::Status example_iterator_status,
+ absl::Status computation_error_status);
+
+// Finds a suitable example iterator factory out of provided factories based on
+// the provided selector.
+ExampleIteratorFactory* FindExampleIteratorFactory(
+ const google::internal::federated::plan::ExampleSelector& selector,
+ std::vector<ExampleIteratorFactory*> example_iterator_factories);
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
diff --git a/fcp/client/engine/simple_plan_engine.cc b/fcp/client/engine/simple_plan_engine.cc
new file mode 100644
index 0000000..1926872
--- /dev/null
+++ b/fcp/client/engine/simple_plan_engine.cc
@@ -0,0 +1,184 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/simple_plan_engine.h"
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/simple_task_environment.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+using ::fcp::client::opstats::OpStatsLogger;
+using ::google::internal::federated::plan::TensorflowSpec;
+
+SimplePlanEngine::SimplePlanEngine(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ OpStatsLogger* opstats_logger,
+ const InterruptibleRunner::TimingConfig* timing_config,
+ const bool support_constant_tf_inputs)
+ : example_iterator_factories_(example_iterator_factories),
+ should_abort_(should_abort),
+ log_manager_(log_manager),
+ opstats_logger_(opstats_logger),
+ timing_config_(timing_config),
+ support_constant_tf_inputs_(support_constant_tf_inputs) {}
+
+PlanResult SimplePlanEngine::RunPlan(
+ const TensorflowSpec& tensorflow_spec, const std::string& graph,
+ const ::google::protobuf::Any& config_proto,
+ std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ inputs,
+ const std::vector<std::string>& output_names) {
+ // Check that all inputs have corresponding TensorSpecProtos.
+ absl::flat_hash_set<std::string> expected_input_tensor_names_set;
+ for (const std::pair<std::string, tensorflow::Tensor>& input : *inputs) {
+ expected_input_tensor_names_set.insert(input.first);
+ }
+ absl::Status validity_checks = ValidateTensorflowSpec(
+ tensorflow_spec, expected_input_tensor_names_set, output_names);
+ if (!validity_checks.ok()) {
+ FCP_LOG(ERROR) << validity_checks.message();
+ return PlanResult(PlanOutcome::kInvalidArgument,
+ std::move(validity_checks));
+ }
+
+ absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> tf_wrapper_or =
+ TensorFlowWrapper::Create(graph, config_proto, should_abort_,
+ *timing_config_, log_manager_);
+ if (!tf_wrapper_or.ok()) {
+ return PlanResult(PlanOutcome::kTensorflowError, tf_wrapper_or.status());
+ }
+
+ std::unique_ptr<TensorFlowWrapper> tf_wrapper =
+ std::move(tf_wrapper_or.value());
+ std::atomic<int> total_example_count = 0;
+ std::atomic<int64_t> total_example_size_bytes = 0;
+ ExampleIteratorStatus example_iterator_status;
+ auto tf_result =
+ RunPlanInternal(tf_wrapper.get(), tensorflow_spec, std::move(inputs),
+ output_names, &total_example_count,
+ &total_example_size_bytes, &example_iterator_status);
+ FCP_CHECK(tf_wrapper->CloseAndRelease().ok());
+
+ switch (tf_result.status().code()) {
+ case absl::StatusCode::kOk: {
+ PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
+ plan_result.output_names = output_names;
+ plan_result.output_tensors = std::move(tf_result).value();
+ plan_result.example_stats = {
+ .example_count = total_example_count,
+ .example_size_bytes = total_example_size_bytes};
+ return plan_result;
+ }
+ case absl::StatusCode::kCancelled:
+ return PlanResult(PlanOutcome::kInterrupted, tf_result.status());
+ case absl::StatusCode::kInvalidArgument:
+ return CreateComputationErrorPlanResult(
+ example_iterator_status.GetStatus(), tf_result.status());
+ default:
+ FCP_LOG(FATAL) << "unexpected status code: " << tf_result.status().code();
+ }
+ // Unreachable, but clang doesn't get it.
+ return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
+}
+
+absl::StatusOr<std::vector<tensorflow::Tensor>>
+SimplePlanEngine::RunPlanInternal(
+ TensorFlowWrapper* tf_wrapper,
+ const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
+ std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ inputs,
+ const std::vector<std::string>& output_names,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status) {
+ // Populate input tensor vector
+ // AddDatasetTokenToInputs first registers a DatasetProvider with the global
+ // ExternalDatasetProviderRegistry and then returns a HostObjectRegistration
+ // object. Hold onto the HostObjectRegistration object since it de-registers
+ // upon destruction.
+ HostObjectRegistration host_registration = AddDatasetTokenToInputs(
+ example_iterator_factories_, opstats_logger_, inputs.get(),
+ tensorflow_spec.dataset_token_tensor_name(), total_example_count,
+ total_example_size_bytes, example_iterator_status);
+
+ std::vector<std::string> target_names;
+ for (const std::string& target_node_name :
+ tensorflow_spec.target_node_names()) {
+ target_names.push_back(target_node_name);
+ }
+ if (support_constant_tf_inputs_ &&
+ !tensorflow_spec.constant_inputs().empty()) {
+ // If the server-side constant inputs message is provided, copy over these
+ // values to the set of input tensors.
+ for (const auto& [name, tensor_proto] : tensorflow_spec.constant_inputs()) {
+ tensorflow::Tensor input_tensor;
+ if (!input_tensor.FromProto(tensor_proto)) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("unable to convert constant_input to tensor: %s",
+ tensor_proto.DebugString()));
+ }
+ inputs->push_back({name, std::move(input_tensor)});
+ }
+ }
+
+ FCP_ASSIGN_OR_RETURN(
+ auto result,
+ RunTensorFlowInternal(tf_wrapper, *inputs, output_names, target_names));
+ return result;
+}
+
+absl::StatusOr<std::vector<tensorflow::Tensor>>
+SimplePlanEngine::RunTensorFlowInternal(
+ TensorFlowWrapper* tf_wrapper,
+ const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
+ const std::vector<std::string>& output_tensor_names,
+ const std::vector<std::string>& target_node_names) {
+ std::vector<tensorflow::Tensor> outputs;
+ absl::Status status =
+ tf_wrapper->Run(inputs, output_tensor_names, target_node_names, &outputs);
+ switch (status.code()) {
+ case absl::StatusCode::kCancelled:
+ case absl::StatusCode::kInvalidArgument:
+ return status;
+ case absl::StatusCode::kOutOfRange:
+ case absl::StatusCode::kOk:
+ break;
+ default:
+ FCP_CHECK_STATUS(status);
+ }
+ return outputs;
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/simple_plan_engine.h b/fcp/client/engine/simple_plan_engine.h
new file mode 100644
index 0000000..bf418a3
--- /dev/null
+++ b/fcp/client/engine/simple_plan_engine.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_SIMPLE_PLAN_ENGINE_H_
+#define FCP_CLIENT_ENGINE_SIMPLE_PLAN_ENGINE_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/engine/tf_wrapper.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/histogram_counters.pb.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// A class used to "run" (interpret) a TensorflowSpec-based plan. Each instance
+// should generally only be used once to run a plan.
+class SimplePlanEngine {
+ public:
+ // For each example query issued by the plan at runtime, the given
+ // `example_iterator_factories` parameter will be iterated and the first
+ // iterator factory that can handle the given query will be used to create the
+ // example iterator for that query.
+ SimplePlanEngine(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger,
+ const InterruptibleRunner::TimingConfig* timing_config,
+ bool support_constant_tf_inputs);
+
+ PlanResult RunPlan(
+ const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
+ const std::string& graph, const ::google::protobuf::Any& config_proto,
+ std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ inputs,
+ const std::vector<std::string>& output_names);
+
+ private:
+ // Runs the plan. Returns one of three error codes:
+ // OK, INVALID_ARGUMENT, CANCELLED.
+ absl::StatusOr<std::vector<tensorflow::Tensor>> RunPlanInternal(
+ TensorFlowWrapper* tf_wrapper,
+ const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
+ std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ inputs,
+ const std::vector<std::string>& output_names,
+ std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ ExampleIteratorStatus* example_iterator_status);
+
+ // Invokes TensorFlowWrapper, and takes care of logging TensorFlow errors and
+ // external interruptions via event_publisher.
+ // If the TF call fails because it got aborted externally, returns CANCELLED.
+ // If the TF call fails with an INVALID argument, indicating a TF error,
+ // publishes an event, then returns INVALID_ARGUMENT
+ // If the TF call reports an OUT_OF_RANGE error ("internal" abortion) or the
+ // TF call is successful, returns OK.
+ absl::StatusOr<std::vector<tensorflow::Tensor>> RunTensorFlowInternal(
+ TensorFlowWrapper* tf_wrapper,
+ const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
+ const std::vector<std::string>& output_tensor_names,
+ const std::vector<std::string>& target_node_names);
+
+ std::vector<ExampleIteratorFactory*> example_iterator_factories_;
+ std::function<bool()> should_abort_;
+ LogManager* log_manager_;
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger_;
+ const InterruptibleRunner::TimingConfig* timing_config_;
+ const bool support_constant_tf_inputs_;
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_SIMPLE_PLAN_ENGINE_H_
diff --git a/fcp/client/engine/tf_wrapper.cc b/fcp/client/engine/tf_wrapper.cc
new file mode 100644
index 0000000..6d0d4e9
--- /dev/null
+++ b/fcp/client/engine/tf_wrapper.cc
@@ -0,0 +1,190 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/tf_wrapper.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/interruptible_runner.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+using ::google::protobuf::Any;
+
+// If `external_config_proto` contains a non-empty config proto, use that.
+// Otherwise initializes a config proto from a set of defaults.
+absl::StatusOr<tensorflow::ConfigProto>
+TensorFlowWrapper::InitializeConfigProto(const Any& external_config_proto) {
+ // Previously, we specified a hardcoded set of options in the ConfigProto by
+ // default. However, if a non-empty ConfigProto is now provided as a
+ // parameter, then we should use it as-is, without overriding any of the
+ // options (otherwise we prevent the caller from having control over the
+ // parameters we set by default).
+ if (external_config_proto.ByteSizeLong() > 0) {
+ // Unpack the external_config_proto parameter if one is provided. In this
+ // case it must be a packed ConfigProto (anything else is an error).
+ // Accordingly, UnpackTo will return false if parsing fails or if the Any is
+ // not of a compatible type.
+ tensorflow::ConfigProto unpacked_config_proto;
+ if (!external_config_proto.UnpackTo(&unpacked_config_proto)) {
+ return absl::InvalidArgumentError("Could not parse ConfigProto.");
+ }
+ if (unpacked_config_proto.ByteSizeLong() > 0) {
+ // The caller-provided, unpacked ConfigProto was not empty, so we use it
+ // in the SessionOptions and we do not specify our default config options
+ // anymore.
+ return unpacked_config_proto;
+ }
+ // We purposely fall through to the next block if the unpacked_config_proto
+ // was empty.
+ }
+
+ // Only if the provided ConfigProto was empty (or if none was provided) do we
+ // still set hardcoded options (this is our "old" behavior, equivalent to what
+ // we did before we supported caller-specified ConfigProtos).
+ //
+ // WARNING: If the need for tuning configuration options further arises again
+ // in the future, we ideally shouldn't update any of the hardcoded ConfigProto
+ // values here anymore. Instead, we should expect our callers to specify any
+ // ConfigProto values they want to use. We only maintain this block of code
+ // for compatibility with callers that don't provide any ConfigProto at all
+ // (yet).
+ //
+ tensorflow::ConfigProto config_proto;
+ config_proto.mutable_graph_options()->set_place_pruned_graph(true);
+ auto mutable_experimental = config_proto.mutable_experimental();
+ mutable_experimental->set_optimize_for_static_graph(true);
+ mutable_experimental->set_disable_output_partition_graphs(true);
+ return config_proto;
+}
+
+absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> TensorFlowWrapper::Create(
+ const std::string& graph, const Any& config_proto,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ LogManager* log_manager) {
+ // Create a tensorflow::Session.
+ tensorflow::Session* session_ptr;
+ std::unique_ptr<tensorflow::Session> session;
+ tensorflow::SessionOptions session_options;
+ FCP_ASSIGN_OR_RETURN(session_options.config,
+ InitializeConfigProto(config_proto));
+
+ tensorflow::Status status =
+ tensorflow::NewSession(session_options, &session_ptr);
+ if (!status.ok()) {
+ return ToFcpStatus(status, "Error in tensorflow::NewSession()");
+ }
+ session = absl::WrapUnique(session_ptr);
+
+ // Parse GraphDef.
+ tensorflow::GraphDef graph_def;
+ bool parse_result = graph_def.ParseFromString(graph);
+ if (parse_result == false) {
+ return absl::InvalidArgumentError("Could not parse GraphDef.");
+ }
+ // Load graph.
+ status = session->Create(std::move(graph_def));
+ if (!status.ok()) {
+ return ToFcpStatus(status, "Error in Session::Create()");
+ }
+
+ // Create an InterruptibleRunner to execute TF calls in a background thread,
+ // allowing us to abort them if need be.
+ auto interruptible_runner = std::make_unique<InterruptibleRunner>(
+ log_manager, should_abort, timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
+ .interrupt_timeout = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
+ auto wrapper = absl::WrapUnique(new TensorFlowWrapper(
+ std::move(session), std::move(interruptible_runner), log_manager));
+ return wrapper;
+}
+
+TensorFlowWrapper::~TensorFlowWrapper() { FCP_CHECK(CloseAndRelease().ok()); }
+
+absl::Status TensorFlowWrapper::ToFcpStatus(tensorflow::Status s,
+ const std::string& message_prefix) {
+ if (s.ok()) {
+ return absl::OkStatus();
+ } else if (s.code() == tensorflow::error::OUT_OF_RANGE) {
+ return absl::OutOfRangeError("");
+ } else {
+ return absl::InvalidArgumentError(
+ absl::StrCat(message_prefix, ": ", s.ToString()));
+ }
+}
+
+absl::Status TensorFlowWrapper::Run(
+ const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
+ const std::vector<std::string>& output_tensor_names,
+ const std::vector<std::string>& target_node_names,
+ std::vector<tensorflow::Tensor>* outputs) {
+ FCP_CHECK(!session_closed_) << "Run() called after session close!";
+
+ auto tensorflow_runnable = [&inputs, &output_tensor_names, &target_node_names,
+ &outputs, this]() -> absl::Status {
+ tensorflow::Status status = this->session_->Run(inputs, output_tensor_names,
+ target_node_names, outputs);
+ if (!status.ok()) {
+ return ToFcpStatus(status, "Error in Session::Run()");
+ }
+ return absl::OkStatus();
+ };
+ auto abort_tensorflow = [this]() {
+ absl::MutexLock _(&session_lock_);
+ // Errors from Close() are expected when interrupting ongoing calls. We
+ // don't call CloseAndRelease() here because that would free the TensorFlow
+ // session while other TensorFlow worker threads may still be using it.
+ session_->Close().IgnoreError();
+ session_closed_ = true;
+ };
+ return interruptible_runner_->Run(tensorflow_runnable, abort_tensorflow);
+}
+
+absl::Status TensorFlowWrapper::CloseAndRelease() {
+ absl::MutexLock _(&session_lock_);
+ // If the TensorFlow session hasn't been closed yet, close it.
+ if (!session_closed_) {
+ FCP_ENGINE_RETURN_IF_ERROR(
+ ToFcpStatus(session_->Close(), "Could not close TF session"));
+ session_closed_ = true;
+ }
+ // If the TensorflowSession hasn't been released yet, release it.
+ if (session_) {
+ session_.reset();
+ }
+ return absl::OkStatus();
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/tf_wrapper.h b/fcp/client/engine/tf_wrapper.h
new file mode 100644
index 0000000..7b97dbb
--- /dev/null
+++ b/fcp/client/engine/tf_wrapper.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_TF_WRAPPER_H_
+#define FCP_CLIENT_ENGINE_TF_WRAPPER_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/future.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "tensorflow/core/public/session.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// A class to call into TensorFlow.
+// All functions in this interface indicate errors as follows:
+// - CANCELLED: interrupted execution
+// - INVALID_ARGUMENT: TensorFlow error. The TensorFlow error code and message
+// are included in the Status message.
+// - OUT_OF_RANGE: internal abortion, i.e. TensorFlow reporting the model
+// aborted execution.
+// This class supports aborting ongoing calls, by polling the provided
+// should_abort function.
+class TensorFlowWrapper {
+ public:
+ static absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> Create(
+ const std::string& graph, const ::google::protobuf::Any& config_proto,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ LogManager* log_manager);
+
+ // Utility method for creating a ConfigProto from an optionally
+ // externally provided value, or from hardcoded defaults. This is a separate
+ // method to aid with testing.
+ static absl::StatusOr<::tensorflow::ConfigProto> InitializeConfigProto(
+ const ::google::protobuf::Any& external_config_proto);
+
+ ~TensorFlowWrapper();
+
+ // Wrapper around TensorFlow's Session::Run method with full support for
+ // feeds, fetches and target node names.
+ // Returns OK, OUT_OF_RANGE, INVALID_ARGUMENT, or CANCELLED.
+ absl::Status Run(
+ const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
+ const std::vector<std::string>& output_tensor_names,
+ const std::vector<std::string>& target_node_names,
+ std::vector<tensorflow::Tensor>* outputs);
+
+ // Closes and releases the TensorFlow session. After this is called, no
+ // further calls on this TensorFlowWrapper should be made. Subsequent calls to
+ // CloseAndRelease() will have no effect.
+ absl::Status CloseAndRelease();
+
+ private:
+ TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,
+ std::unique_ptr<InterruptibleRunner> interruptible_runner,
+ LogManager* log_manager)
+ : session_(std::move(session)),
+ interruptible_runner_(std::move(interruptible_runner)),
+ session_closed_(false) {}
+
+ // Converts a TensorFlow status to an absl::Status.
+ //
+ // Rule:
+ // TensorFlow OK status -> absl OK status
+ // TensorFlow OUT_OF_RANGE -> absl OUT_OF_RANGE status (this is TF indicating
+ // that the plan decided to abort, e.g. because of convergence)
+ // Other TensorFlow status -> absl INVALID_ARGUMENT status with error
+ // message being message_prefix + TensorFlow status code + error message.
+ static absl::Status ToFcpStatus(tensorflow::Status s,
+ const std::string& message_prefix);
+
+ std::unique_ptr<tensorflow::Session> session_;
+ std::unique_ptr<InterruptibleRunner> interruptible_runner_;
+ absl::Mutex session_lock_;
+ bool session_closed_;
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_TF_WRAPPER_H_
diff --git a/fcp/client/engine/tf_wrapper_test.cc b/fcp/client/engine/tf_wrapper_test.cc
new file mode 100644
index 0000000..0bc07dd
--- /dev/null
+++ b/fcp/client/engine/tf_wrapper_test.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/tf_wrapper.h"
+
+#include "google/protobuf/any.pb.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+
+using ::google::protobuf::Any;
+using ::tensorflow::ConfigProto;
+
+TEST(TfWrapperInitializeConfigProtoTest, InvalidConfigProtoWrongTypeUrl) {
+ // Create an Any with a valid value but invalid type URL.
+ ConfigProto config_proto;
+ config_proto.mutable_graph_options()->set_timeline_step(123);
+ Any packed_config_proto;
+ packed_config_proto.PackFrom(config_proto);
+ packed_config_proto.set_type_url("invalid");
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(TfWrapperInitializeConfigProtoTest, InvalidConfigProtoEmptyTypeUrl) {
+ // Create an Any with a valid value but empty type URL.
+ ConfigProto config_proto;
+ config_proto.mutable_graph_options()->set_timeline_step(123);
+ Any packed_config_proto;
+ packed_config_proto.PackFrom(config_proto);
+ packed_config_proto.clear_type_url();
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(TfWrapperInitializeConfigProtoTest, InvalidConfigProtoValue) {
+ // Set the correct type URL, but an unparseable value.
+ Any packed_config_proto;
+ packed_config_proto.PackFrom(ConfigProto());
+ packed_config_proto.set_value("nonparseable");
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(TfWrapperInitializeConfigProtoTest, ValidNonEmptyConfigProtoValue) {
+ // Create an Any containing a valid, non-empty ConfigProto.
+ ConfigProto config_proto;
+ config_proto.mutable_graph_options()->set_timeline_step(123);
+ Any packed_config_proto;
+ packed_config_proto.PackFrom(config_proto);
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ // A non-empty ConfigProto was provided, so it should be used as-is.
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(config_proto));
+}
+
+TEST(TfWrapperInitializeConfigProtoTest, ValidEmptyConfigProtoValue) {
+ // Create an Any containing an empty ConfigProto.
+ Any packed_config_proto;
+ packed_config_proto.PackFrom(ConfigProto());
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ // No external ConfigProto was provided, so the hardcoded defaults should be
+ // used.
+ ConfigProto expected_config_proto;
+ expected_config_proto.mutable_graph_options()->set_place_pruned_graph(true);
+ expected_config_proto.mutable_experimental()
+ ->set_disable_output_partition_graphs(true);
+ expected_config_proto.mutable_experimental()->set_optimize_for_static_graph(
+ true);
+
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_config_proto));
+}
+
+TEST(TfWrapperInitializeConfigProtoTest, ValidEmptyPackedConfigProtoValue) {
+ // Create an empty Any.
+ Any packed_config_proto;
+
+ absl::StatusOr<ConfigProto> result =
+ TensorFlowWrapper::InitializeConfigProto(packed_config_proto);
+
+ // No external ConfigProto was provided, so the hardcoded defaults should be
+ // used.
+ ConfigProto expected_config_proto;
+ expected_config_proto.mutable_graph_options()->set_place_pruned_graph(true);
+ expected_config_proto.mutable_experimental()
+ ->set_disable_output_partition_graphs(true);
+ expected_config_proto.mutable_experimental()->set_optimize_for_static_graph(
+ true);
+
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_config_proto));
+}
+
+} // namespace
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/tflite_plan_engine.cc b/fcp/client/engine/tflite_plan_engine.cc
new file mode 100644
index 0000000..3e27039
--- /dev/null
+++ b/fcp/client/engine/tflite_plan_engine.cc
@@ -0,0 +1,155 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/tflite_plan_engine.h"
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/engine/tflite_wrapper.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+using ::google::internal::federated::plan::TensorflowSpec;
+
+namespace {
+
+PlanResult CreatePlanResultFromOutput(
+ absl::StatusOr<OutputTensors> output, std::atomic<int>* total_example_count,
+ std::atomic<int64_t>* total_example_size_bytes,
+ absl::Status example_iterator_status) {
+ switch (output.status().code()) {
+ case absl::StatusCode::kOk: {
+ PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
+ plan_result.output_names = std::move(output->output_tensor_names);
+ plan_result.output_tensors = std::move(output->output_tensors);
+ plan_result.example_stats = {
+ .example_count = *total_example_count,
+ .example_size_bytes = *total_example_size_bytes};
+ return plan_result;
+ }
+ case absl::StatusCode::kCancelled:
+ return PlanResult(PlanOutcome::kInterrupted, std::move(output.status()));
+ case absl::StatusCode::kInvalidArgument:
+ return CreateComputationErrorPlanResult(example_iterator_status,
+ output.status());
+ default:
+ FCP_LOG(FATAL) << "unexpected status code: " << output.status().code();
+ }
+ // Unreachable code.
+ return PlanResult(PlanOutcome::kTensorflowError, absl::InternalError(""));
+}
+
+TfLiteInterpreterOptions CreateOptions(const Flags& flags) {
+ return TfLiteInterpreterOptions{
+ .ensure_dynamic_tensors_are_released =
+ flags.ensure_dynamic_tensors_are_released(),
+ .large_tensor_threshold_for_dynamic_allocation =
+ flags.large_tensor_threshold_for_dynamic_allocation(),
+ .disable_delegate_clustering =
+ flags.disable_tflite_delegate_clustering()};
+}
+} // namespace
+
+PlanResult TfLitePlanEngine::RunPlan(
+ const TensorflowSpec& tensorflow_spec, const std::string& model,
+ std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
+ const std::vector<std::string>& output_names) {
+ FCP_LOG(INFO) << "***** start running plan";
+ log_manager_->LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED);
+ // Check that all inputs have corresponding TensorSpecProtos.
+ absl::flat_hash_set<std::string> expected_input_tensor_names_set;
+ for (auto it = inputs->begin(); it != inputs->end(); it++) {
+ expected_input_tensor_names_set.insert(it->first);
+ }
+ absl::Status validity_checks = ValidateTensorflowSpec(
+ tensorflow_spec, expected_input_tensor_names_set, output_names);
+ if (!validity_checks.ok()) {
+ FCP_LOG(ERROR) << validity_checks.message();
+ return PlanResult(PlanOutcome::kInvalidArgument,
+ std::move(validity_checks));
+ }
+ std::atomic<int> total_example_count = 0;
+ std::atomic<int64_t> total_example_size_bytes = 0;
+ ExampleIteratorStatus example_iterator_status;
+ HostObjectRegistration host_registration = AddDatasetTokenToInputsForTfLite(
+ example_iterator_factories_, opstats_logger_, inputs.get(),
+ tensorflow_spec.dataset_token_tensor_name(), &total_example_count,
+ &total_example_size_bytes, &example_iterator_status);
+ // If the constant inputs are provided and the flag is enabled, add these to
+ // the map of TFLite inputs.
+ if (!tensorflow_spec.constant_inputs().empty()) {
+ FCP_LOG(INFO) << "***** constant inputs is not empty";
+ if (!flags_.support_constant_tf_inputs()) {
+ return PlanResult(
+ PlanOutcome::kInvalidArgument,
+ absl::InternalError(
+ "Cannot run constant_inputs when experiment is disabled."));
+ } else {
+ for (const auto& [name, tensor_proto] :
+ tensorflow_spec.constant_inputs()) {
+ tensorflow::Tensor input_tensor;
+ if (!input_tensor.FromProto(tensor_proto)) {
+ FCP_LOG(ERROR) << "unable to convert constant_input to tensor: "
+ << tensor_proto.DebugString();
+ return PlanResult(PlanOutcome::kInvalidArgument,
+ absl::InternalError(
+ "Unable to convert constant_input to tensor"));
+ }
+ // Convert Tensor to TFLite represenation and add this as a string to
+ // inputs.
+ if (input_tensor.dtype() == tensorflow::DT_STRING) {
+ tensorflow::tstring str_data =
+ input_tensor.scalar<tensorflow::tstring>()();
+ inputs->insert({name, std::string(str_data.data(), str_data.size())});
+ } else {
+ FCP_LOG(ERROR) << "Constant input tensor is not a string tensor. "
+ "Currently only string tensors are supported.";
+ return PlanResult(
+ PlanOutcome::kInvalidArgument,
+ absl::InternalError("Only string tensors are supported"));
+ }
+ }
+ }
+ }
+ absl::StatusOr<std::unique_ptr<TfLiteWrapper>> tflite_wrapper =
+ TfLiteWrapper::Create(model, should_abort_, *timing_config_, log_manager_,
+ std::move(inputs), output_names,
+ CreateOptions(flags_),
+ flags_.num_threads_for_tflite());
+ FCP_LOG(INFO) << "***** create tflite wrapper";
+
+ if (!tflite_wrapper.ok()) {
+ return PlanResult(PlanOutcome::kTensorflowError, tflite_wrapper.status());
+ }
+ // Start running the plan.
+ absl::StatusOr<OutputTensors> output = (*tflite_wrapper)->Run();
+ PlanResult plan_result = CreatePlanResultFromOutput(
+ std::move(output), &total_example_count, &total_example_size_bytes,
+ example_iterator_status.GetStatus());
+ return plan_result;
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/tflite_plan_engine.h b/fcp/client/engine/tflite_plan_engine.h
new file mode 100644
index 0000000..5dc09a5
--- /dev/null
+++ b/fcp/client/engine/tflite_plan_engine.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
+#define FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/simple_task_environment.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+// A class used to "run" (interpret) a TensorflowSpec-based plan with TfLite.
+// Each instance should generally only be used once to run a plan.
+class TfLitePlanEngine {
+ public:
+ // For each example query issued by the plan at runtime, the given
+ // `example_iterator_factories` parameter will be iterated and the first
+ // iterator factory that can handle the given query will be used to create the
+ // example iterator for that query.
+ TfLitePlanEngine(
+ std::vector<ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger, const Flags* flags,
+ const InterruptibleRunner::TimingConfig* timing_config)
+ : example_iterator_factories_(example_iterator_factories),
+ should_abort_(should_abort),
+ log_manager_(log_manager),
+ opstats_logger_(opstats_logger),
+ flags_(*flags),
+ timing_config_(timing_config) {}
+
+ // Runs the plan, and takes care of logging TfLite errors and external
+ // interruptions via event_publisher. If the TfLite call fails because it got
+ // aborted externally, returns CANCELLED. If the TfLite call fails because of
+ // other reasons, publishes an event, then returns INVALID_ARGUMENT. If the
+ // TfLite call is successful, returns OK, and the output tensors.
+ PlanResult RunPlan(
+ const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
+ const std::string& model,
+ std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
+ const std::vector<std::string>& output_names);
+
+ private:
+ std::vector<ExampleIteratorFactory*> example_iterator_factories_;
+ std::function<bool()> should_abort_;
+ LogManager* log_manager_;
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger_;
+ const Flags& flags_;
+ const InterruptibleRunner::TimingConfig* timing_config_;
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
diff --git a/fcp/client/engine/tflite_plan_engine_test.cc b/fcp/client/engine/tflite_plan_engine_test.cc
new file mode 100644
index 0000000..8e5319e
--- /dev/null
+++ b/fcp/client/engine/tflite_plan_engine_test.cc
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// fcp:google3-internal-file
+#include "fcp/client/engine/tflite_plan_engine.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "fcp/client/client_runner.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/opstats/opstats_example_store.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+using ::fcp::client::opstats::OpStatsSequence;
+using ::google::internal::federated::plan::ClientOnlyPlan;
+using ::google::internal::federated::plan::Dataset;
+using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
+using ::google::internal::federated::plan::FederatedComputeIORouter;
+using ::google::internal::federated::plan::LocalComputeIORouter;
+using ::testing::Gt;
+using ::testing::InSequence;
+using ::testing::Invoke;
+using ::testing::IsEmpty;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::StrictMock;
+
+// We turn formatting off to prevent line breaks, which ensures that these paths
+// are more easily code searchable.
+// clang-format off
+constexpr absl::string_view kArtifactPrefix =
+ "intelligence/brella/testing/tasks/mnist/simpleagg_mnist_training_tflite_task_artifacts"; // NOLINT
+constexpr absl::string_view kEligibilityPlanArtifactPrefix =
+ "intelligence/brella/testing/tasks/eligibility_eval/eligibility_eval_tflite_task_artifacts"; // NOLINT
+constexpr absl::string_view kSecaggArtifactPrefix =
+ "intelligence/brella/testing/tasks/secagg_only_tflite_task_artifacts";
+constexpr absl::string_view kLcArtifactPrefix =
+ "intelligence/brella/testing/local_computation/mnist_tflite_personalization_artifacts"; // NOLINT
+constexpr absl::string_view kLcInitialCheckpoint =
+ "intelligence/brella/testing/local_computation/initial.ckpt";
+constexpr absl::string_view kConstantInputsArtifactPrefix =
+ "intelligence/brella/testing/tasks/mnist/simpleagg_constant_tflite_inputs_task_artifacts"; // NOLINT
+// clang-format on
+
+const char* const kCollectionUri = "app:/test_collection";
+const char* const kEligibilityEvalCollectionUri =
+ "app:/test_eligibility_eval_collection";
+const char* const kLcTrainCollectionUri = "app:/p13n_train_collection";
+const char* const kLcTestCollectionUri = "app:/p13n_test_collection";
+
+// Parameterized with whether per_phase_logs should be used.
+class TfLitePlanEngineTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ EXPECT_CALL(mock_opstats_logger_, IsOpStatsEnabled())
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_opstats_logger_, GetOpStatsDb())
+ .WillRepeatedly(Return(&mock_opstats_db_));
+ EXPECT_CALL(mock_opstats_db_, Read())
+ .WillRepeatedly(Return(OpStatsSequence::default_instance()));
+ EXPECT_CALL(mock_flags_, ensure_dynamic_tensors_are_released())
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_flags_, large_tensor_threshold_for_dynamic_allocation())
+ .WillRepeatedly(Return(1000));
+ EXPECT_CALL(mock_flags_, num_threads_for_tflite())
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(mock_flags_, disable_tflite_delegate_clustering())
+ .WillRepeatedly(Return(false));
+ EXPECT_CALL(mock_flags_, support_constant_tf_inputs())
+ .WillRepeatedly(Return(false));
+ }
+
+ void InitializeFlTask(absl::string_view prefix) {
+ LoadArtifacts();
+
+ example_iterator_factory_ =
+ std::make_unique<FunctionalExampleIteratorFactory>(
+ [&dataset = dataset_](
+ const google::internal::federated::plan::ExampleSelector&
+ selector) {
+ return std::make_unique<::fcp::client::SimpleExampleIterator>(
+ dataset);
+ });
+
+ // Compute dataset stats.
+ for (const Dataset::ClientDataset& client_dataset :
+ dataset_.client_data()) {
+ num_examples_ += client_dataset.example_size();
+ for (const std::string& example : client_dataset.example()) {
+ example_bytes_ += example.size();
+ }
+ }
+ // The single session FL plan specifies both input and output filepaths in
+ // its FederatedComputeIORouter.
+ FederatedComputeIORouter io_router =
+ client_only_plan_.phase().federated_compute();
+ if (!io_router.input_filepath_tensor_name().empty()) {
+ (*inputs_)[io_router.input_filepath_tensor_name()] =
+ checkpoint_input_filename_;
+ }
+ checkpoint_output_filename_ =
+ files_impl_.CreateTempFile("output", ".ckp").value();
+ ASSERT_EQ(std::filesystem::file_size(checkpoint_output_filename_), 0);
+ if (!io_router.output_filepath_tensor_name().empty()) {
+ (*inputs_)[io_router.output_filepath_tensor_name()] =
+ checkpoint_output_filename_;
+ }
+
+ for (const auto& tensor_spec :
+ client_only_plan_.phase().tensorflow_spec().output_tensor_specs()) {
+ output_names_.push_back(tensor_spec.name());
+ }
+ }
+
+ void LoadArtifacts() {
+ absl::StatusOr<::fcp::client::ComputationArtifacts> artifacts =
+ ::fcp::client::LoadFlArtifacts();
+ EXPECT_TRUE(artifacts.ok());
+ client_only_plan_ = std::move(artifacts->plan);
+ dataset_ = std::move(artifacts->dataset);
+ checkpoint_input_filename_ = artifacts->checkpoint_filepath;
+ }
+
+ void ComputeDatasetStats(const std::string& collection_uri) {
+ for (const Dataset::ClientDataset& client_dataset :
+ dataset_.client_data()) {
+ for (const Dataset::ClientDataset::SelectedExample& selected_example :
+ client_dataset.selected_example()) {
+ if (selected_example.selector().collection_uri() != collection_uri) {
+ continue;
+ }
+ num_examples_ += selected_example.example_size();
+ for (const auto& example : selected_example.example()) {
+ example_bytes_ += example.size();
+ }
+ }
+ }
+ }
+
+ fcp::client::FilesImpl files_impl_;
+ StrictMock<MockLogManager> mock_log_manager_;
+ StrictMock<MockOpStatsLogger> mock_opstats_logger_;
+ StrictMock<MockOpStatsDb> mock_opstats_db_;
+ StrictMock<MockFlags> mock_flags_;
+ std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_;
+ // Never abort, by default.
+ std::function<bool()> should_abort_ = []() { return false; };
+
+ ClientOnlyPlan client_only_plan_;
+ Dataset dataset_;
+ std::string checkpoint_input_filename_;
+ std::string checkpoint_output_filename_;
+
+ int num_examples_ = 0;
+ int example_bytes_ = 0;
+ std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs_ =
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>();
+ std::vector<std::string> output_names_;
+
+ fcp::client::InterruptibleRunner::TimingConfig timing_config_ = {
+ // Use 10 ms to make the polling faster, otherwise the Abort test might
+ // fail because the plan finishes before interruption.
+ .polling_period = absl::Milliseconds(10),
+ .graceful_shutdown_period = absl::Milliseconds(1000),
+ .extended_shutdown_period = absl::Milliseconds(2000),
+ };
+};
+
+TEST_F(TfLitePlanEngineTest, SimpleAggPlanSucceeds) {
+ InitializeFlTask(kArtifactPrefix);
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED));
+
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));
+
+ TfLitePlanEngine plan_engine({example_iterator_factory_.get()}, should_abort_,
+ &mock_log_manager_, &mock_opstats_logger_,
+ &mock_flags_, &timing_config_);
+ engine::PlanResult result = plan_engine.RunPlan(
+ client_only_plan_.phase().tensorflow_spec(),
+ client_only_plan_.tflite_graph(), std::move(inputs_), output_names_);
+ FCP_LOG(INFO) << "**** plan result " << result.original_status;
+
+ EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);
+ EXPECT_THAT(result.output_tensors.size(), 0);
+ EXPECT_THAT(result.output_names.size(), 0);
+ EXPECT_EQ(result.example_stats.example_count, num_examples_);
+ EXPECT_EQ(result.example_stats.example_size_bytes, example_bytes_);
+ EXPECT_GT(std::filesystem::file_size(checkpoint_output_filename_), 0);
+}
+} // namespace
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/tflite_wrapper.cc b/fcp/client/engine/tflite_wrapper.cc
new file mode 100644
index 0000000..633251b
--- /dev/null
+++ b/fcp/client/engine/tflite_wrapper.cc
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/tflite_wrapper.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "fcp/base/monitoring.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/lite/delegates/flex/util.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model_builder.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+using ::tflite::ops::builtin::BuiltinOpResolver;
+
+namespace {
+
+absl::Status AssignStringInput(int index, const std::string& value,
+ tflite::Interpreter* interpreter) {
+ TfLiteTensor* tensor = interpreter->tensor(index);
+ if (tensor->type != kTfLiteString) {
+ return absl::InvalidArgumentError("Input tensor is not a string tensor.");
+ }
+
+ tflite::DynamicBuffer buf;
+ buf.AddString(value.data(), value.length());
+ buf.WriteToTensor(tensor, nullptr);
+ return absl::OkStatus();
+}
+
+} // anonymous namespace
+
+absl::StatusOr<std::unique_ptr<TfLiteWrapper>> TfLiteWrapper::Create(
+ const std::string& model, std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ LogManager* log_manager,
+ std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
+ std::vector<std::string> output_names,
+ const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads) {
+ std::unique_ptr<tflite::FlatBufferModel> flat_buffer_model =
+ tflite::FlatBufferModel::BuildFromBuffer(model.c_str(), model.size());
+ if (flat_buffer_model == nullptr) {
+ return absl::InvalidArgumentError("Failed to build FlatBufferModel.");
+ }
+ // The training delegate needs to be created before the interpreter.
+ auto delegate = tflite::FlexDelegate::Create();
+ auto error_reporter = std::make_unique<CachingErrorReporter>();
+ auto interpreter = std::make_unique<tflite::Interpreter>();
+
+ if (tflite::InterpreterBuilder(
+ flat_buffer_model->GetModel(), BuiltinOpResolver(),
+ error_reporter.get())(&interpreter) != kTfLiteOk) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Failed to initiate interpreter: ",
+ error_reporter->GetFirstErrorMessage()));
+ }
+ interpreter->SetNumThreads(num_threads);
+ if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Failed to modify graph with TrainingFlexDelegate: ",
+ error_reporter->GetFirstErrorMessage()));
+ }
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Failed to allocate tensors: ",
+ error_reporter->GetFirstErrorMessage()));
+ }
+ interpreter->SetCancellationFunction(delegate->data_,
+ tflite::FlexDelegate::HasCancelled);
+ for (const auto& input : interpreter->inputs()) {
+ std::string key = interpreter->GetInputName(input);
+ if (inputs->find(key) == inputs->end()) {
+ return absl::InvalidArgumentError("Unexpected input tensor.");
+ }
+ FCP_RETURN_IF_ERROR(
+ AssignStringInput(input, inputs->at(key), interpreter.get()));
+ }
+ // Create an InterruptibleRunner to execute TF calls in a background thread,
+ // allowing us to abort them if need be.
+ auto runner = std::make_unique<InterruptibleRunner>(
+ log_manager, should_abort, timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
+ .interrupt_timeout = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
+ return absl::WrapUnique(
+ new TfLiteWrapper(std::move(flat_buffer_model), std::move(error_reporter),
+ std::move(delegate), std::move(interpreter),
+ std::move(runner), std::move(output_names)));
+}
+
+absl::StatusOr<OutputTensors> TfLiteWrapper::Run() {
+ auto* interpreter_raw_pointer = interpreter_.get();
+ auto tflite_runnable = [interpreter_raw_pointer, this]() {
+ return ConvertTfLiteStatus(interpreter_raw_pointer->Invoke());
+ };
+ auto* delegate_raw_pointer =
+ static_cast<tflite::FlexDelegate*>(delegate_->data_);
+ auto abort_tflite = [delegate_raw_pointer]() {
+ delegate_raw_pointer->Cancel();
+ };
+ FCP_RETURN_IF_ERROR(
+ interruptible_runner_->Run(tflite_runnable, abort_tflite));
+ // handles output tensors
+ return ConstructOutputs();
+}
+
+absl::Status TfLiteWrapper::ConvertTfLiteStatus(TfLiteStatus status) {
+ switch (status) {
+ case kTfLiteOk:
+ return absl::OkStatus();
+ case kTfLiteError: {
+ // TfLite doesn't differentiate the error type when the training is
+ // cancelled or an error happened during training. It also doesn't
+ // distinguish different error types thrown by Tensorflow. Therefore, we
+ // need to check whether the training was cancelled, and record the error
+ // message from the ErrorReporter.
+ if (tflite::FlexDelegate::HasCancelled(delegate_->data_)) {
+ return absl::CancelledError("Training is cancelled.");
+ }
+ std::string error = error_reporter_->GetFirstErrorMessage();
+ if (error.empty()) {
+ return absl::InvalidArgumentError("Empty error messages returned.");
+ }
+ // Use the first error we encountered.
+ return absl::InvalidArgumentError(error);
+ }
+ case kTfLiteDelegateError:
+ return absl::InvalidArgumentError("TfLite delegate error.");
+ case kTfLiteApplicationError:
+ return absl::InvalidArgumentError(
+ "An error in applying a delegate due to incompatibility between "
+ "runtime and delegate");
+ case kTfLiteDelegateDataNotFound:
+ return absl::InvalidArgumentError(
+ "Serialized delegate data not being found");
+ case kTfLiteDelegateDataWriteError:
+ return absl::InvalidArgumentError(
+ "Data-writing issues in delegate serialization");
+ case kTfLiteDelegateDataReadError:
+ return absl::InvalidArgumentError(
+ "Data-reading issues in delegate serialization.");
+ case kTfLiteUnresolvedOps:
+ return absl::InvalidArgumentError(
+ "The TF Lite model has ops that cannot be resolved at runtime.");
+ default:
+ return absl::InternalError("Unexpected TfLiteStatus.");
+ }
+}
+
+absl::StatusOr<OutputTensors> TfLiteWrapper::ConstructOutputs() {
+ if (interpreter_->outputs().size() != output_names_.size()) {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("The number of output tensors is wrong. Expected: %d, "
+ "Returned by TFLite interpreter: %d",
+ output_names_.size(), interpreter_->outputs().size()));
+ }
+ OutputTensors output_tensors;
+ // The order of the output tensors should match the order of output tensor
+ // names.
+ for (int output_tensor_index : interpreter_->outputs()) {
+ auto tensor = tflite::flex::CreateTfTensorFromTfLiteTensor(
+ interpreter_->tensor(output_tensor_index));
+ if (!tensor.ok()) {
+#if TF_GRAPH_DEF_VERSION < 1467
+ return absl::InvalidArgumentError(tensor.status().error_message());
+#else
+ return absl::InvalidArgumentError(tensor.status().message());
+#endif
+ }
+ output_tensors.output_tensors.push_back(*tensor);
+ }
+ output_tensors.output_tensor_names = output_names_;
+ return output_tensors;
+}
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/engine/tflite_wrapper.h b/fcp/client/engine/tflite_wrapper.h
new file mode 100644
index 0000000..e38d0ed
--- /dev/null
+++ b/fcp/client/engine/tflite_wrapper.h
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
+#define FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/engine/caching_error_reporter.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/simple_task_environment.h"
+#include "tensorflow/lite/delegates/flex/delegate.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/model_builder.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+
+struct OutputTensors {
+ std::vector<std::string> output_tensor_names;
+ std::vector<tensorflow::Tensor> output_tensors;
+};
+
+// Options for TFLite interpreter.
+struct TfLiteInterpreterOptions {
+ // When true, TFLite uses dynamic tensor allocation and release tensors that
+ // are no longer needed.
+ bool ensure_dynamic_tensors_are_released = false;
+ // When the threshold is zero, dynamic allocation is not enabled for any
+ // tensor.
+ int32_t large_tensor_threshold_for_dynamic_allocation = 0;
+ // Whether to disable the graph-reordering optimization that clusters delegate
+ // ops together.
+ bool disable_delegate_clustering = false;
+};
+
+// A class to call into TFLite.
+// All functions in this interface indicate errors as follows:
+// - CANCELLED: interrupted execution
+// - INVALID_ARGUMENT:
+// 1. Invalid model.
+// 2. Initialization failure for TFLite required classes such as Interpreter,
+// Delegate etc.
+// 3. Missing required inputs.
+// 4. TensorFlow error. The TensorFlow error messages are included in the
+// Status message.
+// This class supports aborting ongoing calls, by polling the provided
+// should_abort function.
+// Parameters:
+// 1. model: The serialized TFLite model.
+// 2. should_abort: A function which will be polled periodically to determine
+// if the computation should be aborted.
+// 3. timing_config: The TimingConfig for an InterruptibleRunner.
+// 4. log_manager: A LogManager.
+// 5. inputs: A hashmap which has input tensor name as key, tensor data as
+// value.
+// 6. output_names: The names of the output tensors. The order for these
+// tensor names must be deterministic.
+class TfLiteWrapper {
+ public:
+ static absl::StatusOr<std::unique_ptr<TfLiteWrapper>> Create(
+ const std::string& model, std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ LogManager* log_manager,
+ std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
+ std::vector<std::string> output_names,
+ const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads);
+
+ // Wrapper around TfLite's Interpreter::Invoke method.
+ // If the run succeeds, a vector of output tensors (empty if there's no
+ // output tensors), or CANCELLED if the training run was cancelled or
+ // INVALID_ARGUMENT for the rest of errors.
+ absl::StatusOr<OutputTensors> Run();
+
+ private:
+ TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<CachingErrorReporter> error_reporter,
+ tflite::TfLiteDelegateUniquePtr delegate,
+ std::unique_ptr<tflite::Interpreter> interpreter,
+ std::unique_ptr<InterruptibleRunner> interruptible_runner,
+ std::vector<std::string> output_names)
+ : model_(std::move(model)),
+ error_reporter_(std::move(error_reporter)),
+ delegate_(std::move(delegate)),
+ interpreter_(std::move(interpreter)),
+ interruptible_runner_(std::move(interruptible_runner)),
+ output_names_(std::move(output_names)) {}
+ absl::Status ConvertTfLiteStatus(TfLiteStatus status);
+ absl::StatusOr<OutputTensors> ConstructOutputs();
+
+ std::unique_ptr<tflite::FlatBufferModel> model_;
+ std::unique_ptr<CachingErrorReporter> error_reporter_;
+ tflite::TfLiteDelegateUniquePtr delegate_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+ std::unique_ptr<InterruptibleRunner> interruptible_runner_;
+ const std::vector<std::string> output_names_;
+};
+
+} // namespace engine
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
diff --git a/fcp/client/engine/tflite_wrapper_test.cc b/fcp/client/engine/tflite_wrapper_test.cc
new file mode 100644
index 0000000..ae9d0c0
--- /dev/null
+++ b/fcp/client/engine/tflite_wrapper_test.cc
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/engine/tflite_wrapper.h"
+
+#include <fstream>
+#include <string>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace engine {
+namespace {
+
+const absl::string_view kAssetsPath = "fcp/client/engine/data/";
+const absl::string_view kJoinModelFile = "join_model.flatbuffer";
+
+const int32_t kNumThreads = 4;
+
+class TfLiteWrapperTest : public testing::Test {
+ protected:
+ absl::StatusOr<std::string> ReadFileAsString(const std::string& path) {
+ std::ifstream input_istream(path);
+ if (!input_istream) {
+ return absl::InternalError("Failed to create input stream.");
+ }
+ std::stringstream output_stream;
+ output_stream << input_istream.rdbuf();
+ return output_stream.str();
+ }
+
+ MockLogManager mock_log_manager_;
+ InterruptibleRunner::TimingConfig default_timing_config_ =
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::Milliseconds(1000),
+ .graceful_shutdown_period = absl::Milliseconds(1000),
+ .extended_shutdown_period = absl::Milliseconds(2000),
+ };
+ std::vector<std::string> output_names_ = {"Identity"};
+ TfLiteInterpreterOptions options_ = {
+ .ensure_dynamic_tensors_are_released = true,
+ .large_tensor_threshold_for_dynamic_allocation = 1000};
+};
+
+TEST_F(TfLiteWrapperTest, InvalidModel) {
+ EXPECT_THAT(
+ TfLiteWrapper::Create(
+ "INVALID_FLATBUFFER", []() { return false; }, default_timing_config_,
+ &mock_log_manager_,
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>(),
+ output_names_, options_, kNumThreads),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(TfLiteWrapperTest, InputNotSet) {
+ auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
+ ASSERT_OK(plan);
+ // The plan that we use here join two strings. It requires two string tensors
+ // as input. We didn't pass the required tensor, therefore, we expect an
+ // internal error to be thrown.
+ EXPECT_THAT(
+ TfLiteWrapper::Create(
+ *plan, []() { return false; }, default_timing_config_,
+ &mock_log_manager_,
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>(),
+ output_names_, options_, kNumThreads),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(TfLiteWrapperTest, WrongNumberOfOutputs) {
+ auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
+ ASSERT_OK(plan);
+ // The plan that we use here join two strings. It requires two string tensors
+ // as input. We didn't pass the required tensor, therefore, we expect an
+ // internal error to be thrown.
+ EXPECT_THAT(
+ TfLiteWrapper::Create(
+ *plan, []() { return false; }, default_timing_config_,
+ &mock_log_manager_,
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>(),
+ {"Identity", "EXTRA"}, options_, kNumThreads),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(TfLiteWrapperTest, Aborted) {
+ auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
+ ASSERT_OK(plan);
+ auto inputs =
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>();
+ (*inputs)["x"] = "abc";
+ (*inputs)["y"] = "def";
+ // The should_abort function is set to always return true, therefore we expect
+ // to see a CANCELLED status when we run the plan.
+ auto wrapper = TfLiteWrapper::Create(
+ *plan, []() { return true; }, default_timing_config_, &mock_log_manager_,
+ std::move(inputs), output_names_, options_, kNumThreads);
+ ASSERT_OK(wrapper);
+ EXPECT_THAT((*wrapper)->Run(), IsCode(CANCELLED));
+}
+
+TEST_F(TfLiteWrapperTest, Success) {
+ auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
+ ASSERT_OK(plan);
+ auto inputs =
+ std::make_unique<absl::flat_hash_map<std::string, std::string>>();
+ (*inputs)["x"] = "abc";
+ (*inputs)["y"] = "def";
+ auto wrapper = TfLiteWrapper::Create(
+ *plan, []() { return false; }, default_timing_config_, &mock_log_manager_,
+ std::move(inputs), output_names_, options_, kNumThreads);
+ EXPECT_THAT(wrapper, IsCode(OK));
+ auto outputs = (*wrapper)->Run();
+ ASSERT_OK(outputs);
+ EXPECT_EQ(outputs->output_tensor_names.size(), 1);
+ EXPECT_EQ(
+ *static_cast<tensorflow::tstring*>(outputs->output_tensors.at(0).data()),
+ "abcdef");
+}
+
+} // anonymous namespace
+} // namespace engine
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/event_publisher.h b/fcp/client/event_publisher.h
new file mode 100644
index 0000000..c10e85b
--- /dev/null
+++ b/fcp/client/event_publisher.h
@@ -0,0 +1,274 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_EVENT_PUBLISHER_H_
+#define FCP_CLIENT_EVENT_PUBLISHER_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "fcp/client/stats.h"
+
+namespace fcp {
+namespace client {
+
+class SecAggEventPublisher;
+
+// An interface for publishing events that occur during training. This is a
+// separate interface from LogManager because the reported events will typically
+// be both reported to a cloud monitoring backend and to the Federated server as
+// part of publishing results.
+// All methods in here either succeed with OK, or fail with INVALID_ARGUMENT.
+class EventPublisher {
+ public:
+ virtual ~EventPublisher() = default;
+
+ // Publishes that the device is about to issue an eligibility eval check in
+ // with the server.
+ virtual void PublishEligibilityEvalCheckin() = 0;
+
+ // Publishes that the device has finished its eligibility eval checkin with
+ // the server, and received the URIs to download the eligibility eval plan
+ // with, but hasn't actually downloaded them yet, along with information
+ // how much data was transferred up to this point and how long that took.
+ virtual void PublishEligibilityEvalPlanUriReceived(
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+
+ // Publishes that the device has finished its eligibility eval checkin with
+ // the server, and received an eligibility eval plan, along with information
+ // how much data was transferred and how long that took.
+ virtual void PublishEligibilityEvalPlanReceived(
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+
+ // Publishes that the server did not return an eligibility eval task to the
+ // client, along with information how much data was transferred and how long
+ // that took.
+ virtual void PublishEligibilityEvalNotConfigured(
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+
+ // Publishes that the server rejected the device's eligibility eval checkin,
+ // along with information how much data was downloaded and how long that took.
+ virtual void PublishEligibilityEvalRejected(
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+
+ // Publishes that the device is about to check in with the server.
+ virtual void PublishCheckin() = 0;
+
+ // Publishes that the device has finished checking in with the server, along
+ // with information how much data was downloaded and how long that took.
+ virtual void PublishCheckinFinished(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+
+ // Publishes that the server rejected the device.
+ virtual void PublishRejected() = 0;
+
+ // Publishes that the device is about to report the results of a federated
+ // computation to the server.
+ virtual void PublishReportStarted(int64_t report_size_bytes) = 0;
+
+ // Publishes that the device has successfully reported its results to the
+ // server and received instructions on when to reconnect.
+ virtual void PublishReportFinished(const NetworkStats& network_stats,
+ absl::Duration report_duration) = 0;
+
+ // Publishes that plan execution has started.
+ virtual void PublishPlanExecutionStarted() = 0;
+
+ // Publishes a TensorFlow error that happened in the given ClientExecution.
+ virtual void PublishTensorFlowError(int example_count,
+ absl::string_view error_message) = 0;
+
+ // Publishes an I/O error (e.g. disk, network) that happened in the given
+ // ClientExecution.
+ virtual void PublishIoError(absl::string_view error_message) = 0;
+
+ // Publishes an ExampleSelector error from the given ClientExecution.
+ virtual void PublishExampleSelectorError(int example_count,
+ absl::string_view error_message) = 0;
+
+ // Publishes an interruption event for the given client execution.
+ virtual void PublishInterruption(const ExampleStats& example_stats,
+ absl::Time start_time) = 0;
+
+ // Publishes an event that plan execution is complete.
+ virtual void PublishPlanCompleted(const ExampleStats& example_stats,
+ absl::Time start_time) = 0;
+ // Publishes that the task didn't start.
+ virtual void PublishTaskNotStarted(absl::string_view error_message) = 0;
+
+ // Publishes that the federated compute runtime failed to initialize a
+ // noncritical component, but execution continued.
+ virtual void PublishNonfatalInitializationError(
+ absl::string_view error_message) = 0;
+ // Publishes that the federated compute runtime failed to initialize a
+ // component, and execution was halted.
+ virtual void PublishFatalInitializationError(
+ absl::string_view error_message) = 0;
+
+ // Publish that an IO error was encountered during eligibility eval check-in.
+ virtual void PublishEligibilityEvalCheckinIoError(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the eligibility eval check-in is interrupted by the client.
+ virtual void PublishEligibilityEvalCheckinClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the eligibility eval check-in is aborted by the server.
+ virtual void PublishEligibilityEvalCheckinServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the eligibility eval check-in returned an invalid payload.
+ virtual void PublishEligibilityEvalCheckinErrorInvalidPayload(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish an eligibility eval task starts computation.
+ virtual void PublishEligibilityEvalComputationStarted() = 0;
+ // Publish that the eligibility eval task is invalid.
+ virtual void PublishEligibilityEvalComputationInvalidArgument(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish an example iterator error occurred during eligibility eval task.
+ virtual void PublishEligibilityEvalComputationExampleIteratorError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that a tensorflow error occurred during eligibility eval task.
+ virtual void PublishEligibilityEvalComputationTensorflowError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the client has interrupted the eligibility eval task.
+ virtual void PublishEligibilityEvalComputationInterrupted(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish an eligibility eval task finished.
+ virtual void PublishEligibilityEvalComputationCompleted(
+ const ExampleStats& example_stats, absl::Duration phase_duration) = 0;
+ // Publish an IO error occurred during regular check-in.
+ virtual void PublishCheckinIoError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the client interrupted the regular check-in.
+ virtual void PublishCheckinClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the server aborted the regular check-in.
+ virtual void PublishCheckinServerAborted(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that an invalid payload was downloaded from the regular check-in.
+ virtual void PublishCheckinInvalidPayload(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publishes that the server rejected the device, also logs network stats and
+ // duration.
+ virtual void PublishRejected(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+
+ // Publishes that the device has finished checking in with the server and
+ // received URIs to download the plan and checkpoint with, but hasn't yet
+ // downloaded those, along with information how much data was transferred up
+ // to this point and how long that took.
+ virtual void PublishCheckinPlanUriReceived(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publishes that the device has finished checking in with the server, along
+ // with information how much data was transferred and how long that took.
+ virtual void PublishCheckinFinishedV2(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publishes that plan execution has started.
+ virtual void PublishComputationStarted() = 0;
+ // Publish that the task is invalid.
+ virtual void PublishComputationInvalidArgument(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+ // Publish that an IO error occurred during computation.
+ virtual void PublishComputationIOError(absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that an example iterator error occurred during computation.
+ virtual void PublishComputationExampleIteratorError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+ // Publish that an tensorflow error occurred during computation.
+ virtual void PublishComputationTensorflowError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Duration phase_duration) = 0;
+ // Publish that the task computation is interrupted.
+ virtual void PublishComputationInterrupted(absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publishes an event that plan execution is complete.
+ virtual void PublishComputationCompleted(const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the client starts to upload result.
+ virtual void PublishResultUploadStarted() = 0;
+ // Publish that an IO error occurred during result upload.
+ virtual void PublishResultUploadIOError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the client has interrupted the result upload.
+ virtual void PublishResultUploadClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish hat the server has aborted the result upload.
+ virtual void PublishResultUploadServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the result upload is completed.
+ virtual void PublishResultUploadCompleted(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the task computation has failed, and the client starts to
+ // upload the failure to the server.
+ virtual void PublishFailureUploadStarted() = 0;
+ // Publish that an IO error occurred during failure upload.
+ virtual void PublishFailureUploadIOError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the client has interrupted the failure upload.
+ virtual void PublishFailureUploadClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the server has aborted the failure upload.
+ virtual void PublishFailureUploadServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+ // Publish that the failure upload completed.
+ virtual void PublishFailureUploadCompleted(const NetworkStats& network_stats,
+ absl::Duration phase_duration) = 0;
+
+ // After calling this function, all subsequently published events will be
+ // annotated with the specified model_identifier. This value is typically
+ // provided by the federated server and used on events resulting from
+ // PublishEligibilityEvalCheckinFinished(), PublishCheckinFinished() and
+ // later.
+ //
+ // Note that this method may be called multiple times with different values,
+ // if over the course of a training session multiple models are executed.
+ virtual void SetModelIdentifier(const std::string& model_identifier) = 0;
+
+ // Returns a pointer to a publisher which records secure aggregation protocol
+ // events. The returned value must not be nullptr.
+ virtual SecAggEventPublisher* secagg_event_publisher() = 0;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_EVENT_PUBLISHER_H_
diff --git a/fcp/client/example_query_result.proto b/fcp/client/example_query_result.proto
new file mode 100644
index 0000000..f4c8b31
--- /dev/null
+++ b/fcp/client/example_query_result.proto
@@ -0,0 +1,73 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client;
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+// Describes the result of an example query, as a series of vectors. Example
+// iterators invoked using `ExampleQuerySpec` are expected to return a single
+// result that is a serialized proto of this type.
+message ExampleQueryResult {
+ message VectorData {
+ message Int32Values {
+ repeated int32 value = 1;
+ }
+
+ message Int64Values {
+ repeated int64 value = 1;
+ }
+
+ message BoolValues {
+ repeated bool value = 1;
+ }
+
+ message FloatValues {
+ repeated float value = 1;
+ }
+
+ message DoubleValues {
+ repeated double value = 1;
+ }
+
+ message StringValues {
+ repeated string value = 1;
+ }
+
+ message BytesValues {
+ repeated bytes value = 1;
+ }
+
+ message Values {
+ oneof values {
+ Int32Values int32_values = 1;
+ Int64Values int64_values = 2;
+ BoolValues bool_values = 3;
+ FloatValues float_values = 4;
+ DoubleValues double_values = 5;
+ StringValues string_values = 6;
+ BytesValues bytes_values = 7;
+ }
+ }
+
+ // Maps a name of the result vector to its values.
+ map<string, Values> vectors = 1;
+ }
+
+ // Vector data fetched from the example store.
+ VectorData vector_data = 1;
+}
diff --git a/fcp/client/fake_event_publisher.h b/fcp/client/fake_event_publisher.h
new file mode 100644
index 0000000..2b6b722
--- /dev/null
+++ b/fcp/client/fake_event_publisher.h
@@ -0,0 +1,398 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_FAKE_EVENT_PUBLISHER_H_
+#define FCP_CLIENT_FAKE_EVENT_PUBLISHER_H_
+
+#include "absl/strings/str_split.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/client/stats.h"
+
+namespace fcp {
+namespace client {
+
+class SecAggEventPublisher;
+
+// Macro to print log messages prefixed by ClassName::FunctionName, stripping
+// namespaces before ClassName, if any.
+#define FCP_CLIENT_LOG_FUNCTION_NAME \
+ FCP_LOG(INFO) << "::" << __func__; \
+ // std::string _demangle_buf(1024, '\0'); \
+ // size_t _demangle_buf_len = _demangle_buf.length(); \
+ // abi::__cxa_demangle(typeid(*this).name(), _demangle_buf.data(), \
+ // &_demangle_buf_len, nullptr); \
+ // FCP_LOG(INFO) << static_cast<std::vector<std::string>>( \
+ // absl::StrSplit(_demangle_buf, "::")) \
+ // .back() \
+ // .c_str() \
+ // << "::" << __func__
+
+// An implementation of the SecAggEventPublisher interface that logs calls to
+// stderr.
+class SecAggLoggingEventPublisher : public SecAggEventPublisher {
+ public:
+ explicit SecAggLoggingEventPublisher(bool quiet) : quiet_(quiet) {}
+
+ void PublishStateTransition(::fcp::secagg::ClientState state,
+ size_t last_sent_message_size,
+ size_t last_received_message_size) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+ void PublishError() override { FCP_CLIENT_LOG_FUNCTION_NAME; }
+ void PublishAbort(bool client_initiated,
+ const std::string& error_message) override {
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+ void set_execution_session_id(int64_t execution_session_id) override {
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ private:
+ const bool quiet_;
+};
+
+// An implementation of the EventPublisher interface that logs calls to stderr.
+class FakeEventPublisher : public EventPublisher {
+ public:
+ // Logs all events to stderr.
+ FakeEventPublisher() : FakeEventPublisher(/*quiet=*/false) {}
+ // Logs only error and "client rejected" events to stderr.
+ explicit FakeEventPublisher(bool quiet)
+ : quiet_(quiet), secagg_event_publisher_(quiet) {}
+
+ void PublishEligibilityEvalCheckin() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+ void PublishEligibilityEvalPlanUriReceived(const NetworkStats&,
+ absl::Duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishEligibilityEvalPlanReceived(const NetworkStats&,
+ absl::Duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishEligibilityEvalNotConfigured(const NetworkStats&,
+ absl::Duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishEligibilityEvalRejected(const NetworkStats&,
+ absl::Duration) override {
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishCheckin() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishCheckinFinished(const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishRejected() override { FCP_CLIENT_LOG_FUNCTION_NAME; }
+
+ void PublishReportStarted(int64_t report_size_bytes) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishReportFinished(const NetworkStats& network_stats,
+ absl::Duration report_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishPlanExecutionStarted() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishTensorFlowError(int example_count,
+ absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishIoError(absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishExampleSelectorError(int example_count,
+ absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishInterruption(const ExampleStats& example_stats,
+ absl::Time start_time) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishPlanCompleted(const ExampleStats& example_stats,
+ absl::Time start_time) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishTaskNotStarted(absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishNonfatalInitializationError(
+ absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishFatalInitializationError(
+ absl::string_view error_message) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalCheckinIoError(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalCheckinClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalCheckinServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalCheckinErrorInvalidPayload(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalComputationStarted() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishEligibilityEvalComputationInvalidArgument(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalComputationExampleIteratorError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalComputationTensorflowError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalComputationInterrupted(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishEligibilityEvalComputationCompleted(
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishCheckinIoError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishCheckinClientInterrupted(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishCheckinServerAborted(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishCheckinInvalidPayload(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishRejected(const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishCheckinPlanUriReceived(const NetworkStats& network_stats,
+ absl::Duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+ void PublishCheckinFinishedV2(const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishComputationStarted() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishComputationInvalidArgument(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishComputationIOError(absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishComputationExampleIteratorError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishComputationTensorflowError(
+ absl::string_view error_message, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishComputationInterrupted(absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishComputationCompleted(const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishResultUploadStarted() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishResultUploadIOError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishResultUploadClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishResultUploadServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishResultUploadCompleted(const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishFailureUploadStarted() override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void PublishFailureUploadIOError(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishFailureUploadClientInterrupted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishFailureUploadServerAborted(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ // FCP_CLIENT_LOG_FUNCTION_NAME << error_message;
+ }
+
+ void PublishFailureUploadCompleted(const NetworkStats& network_stats,
+ absl::Duration phase_duration) override {
+ if (quiet_) return;
+ FCP_CLIENT_LOG_FUNCTION_NAME;
+ }
+
+ void SetModelIdentifier(const std::string& model_identifier) override {
+ if (quiet_) return;
+ // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier;
+ }
+
+ SecAggEventPublisher* secagg_event_publisher() override {
+ return &secagg_event_publisher_;
+ }
+
+ private:
+ const bool quiet_;
+ SecAggLoggingEventPublisher secagg_event_publisher_;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FAKE_EVENT_PUBLISHER_H_
diff --git a/fcp/client/fake_log_manager.h b/fcp/client/fake_log_manager.h
new file mode 100644
index 0000000..ee91a79
--- /dev/null
+++ b/fcp/client/fake_log_manager.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_FAKE_LOG_MANAGER_H_
+#define FCP_CLIENT_FAKE_LOG_MANAGER_H_
+
+#include "fcp/base/monitoring.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/log_manager.h"
+
+namespace fcp {
+namespace client {
+
+class FakeLogManager : public LogManager {
+ public:
+ void LogDiag(ProdDiagCode diagCode) override { FCP_LOG(ERROR) << diagCode; }
+
+ void LogDiag(DebugDiagCode diagCode) override { FCP_LOG(ERROR) << diagCode; }
+
+ void LogToLongHistogram(HistogramCounters histogram_counter,
+ int execution_index, int epoch_index,
+ engine::DataSourceType data_source_type,
+ int64_t value) override {}
+
+ void SetModelIdentifier(const std::string& model_identifier) override {}
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FAKE_LOG_MANAGER_H_
diff --git a/fcp/client/fake_server.cc b/fcp/client/fake_server.cc
new file mode 100644
index 0000000..b933ed6
--- /dev/null
+++ b/fcp/client/fake_server.cc
@@ -0,0 +1,121 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/client/fake_server.h"
+
+#include <string>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/status_converters.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/protocol/grpc_chunked_bidi_stream.h"
+#include "fcp/protos/federated_api.pb.h"
+
+namespace fcp {
+namespace client {
+namespace test {
+
+using fcp::base::ToGrpcStatus;
+using fcp::client::GrpcChunkedBidiStream;
+using google::internal::federatedml::v2::ClientStreamMessage;
+using google::internal::federatedml::v2::RetryWindow;
+using google::internal::federatedml::v2::ServerStreamMessage;
+
+static RetryWindow GetRetryWindow(const std::string& token, int64_t min,
+ int64_t max) {
+ RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(min);
+ retry_window.mutable_delay_max()->set_seconds(max);
+ *retry_window.mutable_retry_token() = token;
+ return retry_window;
+}
+
+grpc::Status FakeServer::Session(
+ grpc::ServerContext* context,
+ grpc::ServerReaderWriter<ServerStreamMessage, ClientStreamMessage>*
+ stream) {
+ GrpcChunkedBidiStream<ServerStreamMessage, ClientStreamMessage>
+ chunked_bidi_stream(
+ stream, stream,
+ {chunk_size_for_upload_, max_pending_chunks_, compression_level_});
+ ClientStreamMessage request;
+ ServerStreamMessage response;
+ FCP_LOG(INFO) << "Server session started";
+ absl::Status status;
+ while ((status = chunked_bidi_stream.Receive(&request)).ok()) {
+ FCP_LOG(INFO) << "Request is: " << request.DebugString();
+ for (const auto& [key, value] : context->client_metadata()) {
+ client_metadata_.insert(
+ std::make_pair(std::string(key.data(), key.size()),
+ std::string(value.data(), value.size())));
+ }
+ if (request.eligibility_eval_checkin_request()
+ .protocol_options_request()
+ .should_ack_checkin() ||
+ request.checkin_request()
+ .protocol_options_request()
+ .should_ack_checkin()) {
+ ServerStreamMessage checkin_request_ack_msg;
+ auto checkin_request_ack =
+ checkin_request_ack_msg.mutable_checkin_request_ack();
+ *checkin_request_ack->mutable_retry_window_if_accepted() =
+ GetRetryWindow("A", 111L, 222L);
+ *checkin_request_ack->mutable_retry_window_if_rejected() =
+ GetRetryWindow("R", 333L, 444L);
+ if (!chunked_bidi_stream.Send(&checkin_request_ack_msg).ok()) {
+ FCP_LOG(INFO) << "Server returning status " << status;
+ return ToGrpcStatus(status);
+ }
+ }
+ if (request.has_eligibility_eval_checkin_request() ||
+ request.has_checkin_request()) {
+ auto protocol_options_response =
+ request.has_eligibility_eval_checkin_request()
+ ? response.mutable_eligibility_eval_checkin_response()
+ ->mutable_protocol_options_response()
+ : response.mutable_checkin_response()
+ ->mutable_protocol_options_response();
+ protocol_options_response->set_compression_level(compression_level_);
+ protocol_options_response->set_chunk_size_for_upload(
+ chunk_size_for_upload_);
+ protocol_options_response->set_max_pending_chunks(max_pending_chunks_);
+ }
+ if (!(status = Handle(request, &response, &chunked_bidi_stream)).ok()) {
+ FCP_LOG(INFO) << "Server returning status " << status;
+ return ToGrpcStatus(status);
+ }
+ }
+ session_done_.Notify();
+ FCP_LOG(INFO) << "Server returning status " << status;
+ return ToGrpcStatus(status);
+}
+
+std::multimap<std::string, std::string> FakeServer::GetClientMetadata() const {
+ return client_metadata_;
+}
+
+void FakeServer::WaitForSessionDone() { session_done_.WaitForNotification(); }
+
+absl::Status FakeServer::Handle(
+ const ClientStreamMessage& request, ServerStreamMessage* first_reply,
+ GrpcChunkedBidiStream<ServerStreamMessage, ClientStreamMessage>* stream) {
+ return stream->Send(first_reply);
+}
+
+} // namespace test
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/fake_server.h b/fcp/client/fake_server.h
new file mode 100644
index 0000000..3342745
--- /dev/null
+++ b/fcp/client/fake_server.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_FAKE_SERVER_H_
+#define FCP_CLIENT_FAKE_SERVER_H_
+
+#include <cstddef>
+#include <string>
+#include <tuple>
+
+#include "grpcpp/impl/codegen/status.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/notification.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/protocol/grpc_chunked_bidi_stream.h"
+#include "fcp/protos/federated_api.grpc.pb.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "grpcpp/impl/codegen/server_context.h"
+
+namespace fcp {
+namespace client {
+namespace test {
+
+class FakeServer
+ : public google::internal::federatedml::v2::FederatedTrainingApi::Service {
+ public:
+ FakeServer()
+ : chunk_size_for_upload_(8192),
+ max_pending_chunks_(2),
+ compression_level_(google::internal::federatedml::v2::CompressionLevel::
+ ZLIB_BEST_COMPRESSION) {}
+ FakeServer(
+ int32_t chunk_size_for_upload, int32_t max_pending_chunks,
+ google::internal::federatedml::v2::CompressionLevel compression_level)
+ : chunk_size_for_upload_(chunk_size_for_upload),
+ max_pending_chunks_(max_pending_chunks),
+ compression_level_(compression_level) {}
+
+ // FakeServer is neither copyable nor movable.
+ FakeServer(const FakeServer&) = delete;
+ FakeServer& operator=(const FakeServer&) = delete;
+
+ grpc::Status Session(
+ grpc::ServerContext* context,
+ grpc::ServerReaderWriter<
+ google::internal::federatedml::v2::ServerStreamMessage,
+ google::internal::federatedml::v2::ClientStreamMessage>* stream)
+ override;
+ void WaitForSessionDone();
+
+ virtual absl::Status Handle(
+ const google::internal::federatedml::v2::ClientStreamMessage& request,
+ google::internal::federatedml::v2::ServerStreamMessage* first_reply,
+ ::fcp::client::GrpcChunkedBidiStream<
+ google::internal::federatedml::v2::ServerStreamMessage,
+ google::internal::federatedml::v2::ClientStreamMessage>* stream);
+
+ // Returns the client metadata from the most recent session call.
+ std::multimap<std::string, std::string> GetClientMetadata() const;
+
+ protected:
+ int32_t chunk_size_for_upload_;
+ int32_t max_pending_chunks_;
+ google::internal::federatedml::v2::CompressionLevel compression_level_;
+ absl::Notification session_done_;
+
+ private:
+ std::multimap<std::string, std::string> client_metadata_;
+};
+
+} // namespace test
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FAKE_SERVER_H_
diff --git a/fcp/client/federated_protocol.h b/fcp/client/federated_protocol.h
new file mode 100644
index 0000000..39c63f3
--- /dev/null
+++ b/fcp/client/federated_protocol.h
@@ -0,0 +1,397 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FEDERATED_PROTOCOL_H_
+#define FCP_CLIENT_FEDERATED_PROTOCOL_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+
+// Data type used to encode results of a computation - a TensorFlow
+// checkpoint, or SecAgg quantized tensors.
+// For non-SecAgg use (simple federated aggregation, or local computation),
+// this map should only contain one entry - a TFCheckpoint - and the string
+// should be ignored by downstream code.
+// For SecAgg use, there should be
+// * at most one TFCheckpoint - again, the key should be ignored - and
+// * N QuantizedTensors, whose string keys must map to the tensor names
+// provided in the server's CheckinResponse's SideChannelExecutionInfo.
+using TFCheckpoint = std::string;
+struct QuantizedTensor {
+ std::vector<uint64_t> values;
+ int32_t bitwidth = 0;
+ std::vector<int64_t> dimensions;
+
+ QuantizedTensor() = default;
+ // Disallow copy and assign.
+ QuantizedTensor(const QuantizedTensor&) = delete;
+ QuantizedTensor& operator=(const QuantizedTensor&) = delete;
+ // Enable move semantics.
+ QuantizedTensor(QuantizedTensor&&) = default;
+ QuantizedTensor& operator=(QuantizedTensor&&) = default;
+};
+// This is equivalent to using ComputationResults =
+// std::map<std::string, std::variant<TFCheckpoint, QuantizedTensor>>;
+// except copy construction and assignment are explicitly prohibited and move
+// semantics is enforced.
+class ComputationResults
+ : public absl::node_hash_map<std::string,
+ std::variant<TFCheckpoint, QuantizedTensor>> {
+ public:
+ using Base = absl::node_hash_map<std::string,
+ std::variant<TFCheckpoint, QuantizedTensor>>;
+ using Base::Base;
+ using Base::operator=;
+ ComputationResults(const ComputationResults&) = delete;
+ ComputationResults& operator=(const ComputationResults&) = delete;
+ ComputationResults(ComputationResults&&) = default;
+ ComputationResults& operator=(ComputationResults&&) = default;
+};
+
+// An interface that represents a single Federated Compute protocol session.
+//
+// An instance of this class represents a single session of client-server
+// interaction. Instances are generally stateful, and therefore cannot be
+// reused (each session should use a dedicated instance).
+//
+// The protocol consists of 3 phases, which must occur in the following order:
+// 1. A call to `EligibilityEvalCheckin()`.
+// 2. A call to `Checkin(...)`, only if the client wasn't rejected by the server
+// in the previous phase.
+// 3. A call to `ReportCompleted(...)` or `ReportNotCompleted(...)`, only if the
+// client wasn't rejected in the previous phase.
+class FederatedProtocol {
+ public:
+ virtual ~FederatedProtocol() = default;
+
+ // The unparsed plan and checkpoint payload which make up a computation. The
+ // data can be provided as either an std::string or an absl::Cord.
+ struct PlanAndCheckpointPayloads {
+ std::variant<std::string, absl::Cord> plan;
+ std::variant<std::string, absl::Cord> checkpoint;
+ };
+
+ // An eligibility task, consisting of task payloads and an execution ID.
+ struct EligibilityEvalTask {
+ PlanAndCheckpointPayloads payloads;
+ std::string execution_id;
+ std::optional<
+ google::internal::federatedcompute::v1::PopulationEligibilitySpec>
+ population_eligibility_spec;
+ };
+ // A rejection of a client by the server.
+ struct Rejection {};
+ // Indicates that the server does not have an eligibility eval task configured
+ // for the population.
+ struct EligibilityEvalDisabled {};
+ // EligibilityEvalCheckin() returns either
+ // 1. an `EligibilityEvalTask` struct holding the payloads for an eligibility
+ // eval task, if the population is configured with such a task. In this
+ // case the caller should execute the task and pass the resulting
+ // `TaskEligibilityInfo` value to the `Checkin(...)` method.
+ // 2. an `EligibilityEvalDisabled` struct if the population doesn't have an
+ // eligibility eval task configured. In this case the caller should
+ // continue the protocol by calling the `Checkin(...)` method without
+ // providing a `TaskEligibilityInfo` value.
+ // 3. a `Rejection` if the server rejected this device. In this case the
+ // caller
+ // should end its protocol interaction.
+ using EligibilityEvalCheckinResult =
+ std::variant<EligibilityEvalTask, EligibilityEvalDisabled, Rejection>;
+
+ // Checks in with a federated server to receive the population's eligibility
+ // eval task. This method is optional and may be called 0 or 1 times. If it is
+ // called, then it must be called before any call to `Checkin(...)`.
+ //
+ // If an eligibility eval task is configured, then the
+ // `payload_uris_received_callback` function will be called with a partially
+ // populated `EligibilityEvalTask` containing all of the task's info except
+ // for the actual payloads (which are yet to be fetched at that point).
+ //
+ // Returns:
+ // - On success, an EligibilityEvalCheckinResult.
+ // - On error:
+ // - ABORTED when one of the I/O operations got aborted by the server.
+ // - CANCELLED when one of the I/O operations was interrupted by the client
+ // (possibly due to a positive result from the should_abort callback).
+ // - UNAVAILABLE when server cannot be reached or URI is invalid.
+ // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
+ // specified population name is incorrect.
+ // - UNIMPLEMENTED if an unexpected server response is received.
+ // - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See
+ // note in federated_protocol.cc for the reasoning for this.)
+ // - INTERNAL for other unexpected client-side errors.
+ // - any server-provided error code.
+ virtual absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin(
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) = 0;
+
+ // Report an eligibility eval task error to the federated server.
+ // Must only be called once and after a successful call to
+ // EligibilityEvalCheckin() which returns an eligibility eval task. This
+ // method is only used to report an error happened during the computation of
+ // the eligibility eval task. If the eligibility eval computation succeeds,
+ // the success will be reported during task assignment.
+ // @param status the outcome of the eligibility eval computation.
+ virtual void ReportEligibilityEvalError(absl::Status error_status) = 0;
+
+ // SecAgg metadata, e.g. see SecureAggregationProtocolExecutionInfo in
+ // federated_api.proto.
+ struct SecAggInfo {
+ int32_t expected_number_of_clients;
+ int32_t minimum_clients_in_server_visible_aggregate;
+ };
+
+ // A task assignment, consisting of task payloads, a URI template to download
+ // federated select task slices with (if the plan uses federated select), a
+ // session identifier, and SecAgg-related metadata.
+ struct TaskAssignment {
+ PlanAndCheckpointPayloads payloads;
+ std::string federated_select_uri_template;
+ std::string aggregation_session_id;
+ std::optional<SecAggInfo> sec_agg_info;
+ };
+ // Checkin() returns either
+ // 1. a `TaskAssignment` struct if the client was assigned a task to run, or
+ // 2. a `Rejection` struct if the server rejected this device.
+ using CheckinResult = std::variant<TaskAssignment, Rejection>;
+
+ // Checks in with a federated server. Must only be called once. If the
+ // `EligibilityEvalCheckin()` method was previously called, then this method
+ // must only be called if the result of that call was not a `Rejection`.
+ //
+ // If the caller previously called `EligibilityEvalCheckin()` and:
+ // - received a payload, then the `TaskEligibilityInfo` value computed by that
+ // payload must be provided via the `task_eligibility_info` parameter.
+ // - received an `EligibilityEvalDisabled` result, then the
+ // `task_eligibility_info` parameter should be left empty.
+ //
+ // If the caller did not previously call `EligibilityEvalCheckin()`, then the
+ // `task_eligibility_info` parameter should be left empty.
+ //
+ // If the client is assigned a task by the server, then the
+ // `payload_uris_received_callback` function will be called with a partially
+ // populated `TaskAssignment` containing all of the task's info except for the
+ // actual payloads (which are yet to be fetched at that point)
+ //
+ // Returns:
+ // - On success, a `CheckinResult`.
+ // - On error:
+ // - ABORTED when one of the I/O operations got aborted by the server.
+ // - CANCELLED when one of the I/O operations was interrupted by the client
+ // (possibly due to a positive result from the should_abort callback).
+ // - UNAVAILABLE when server cannot be reached or URI is invalid.
+ // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
+ // specified population name is incorrect.
+ // - UNIMPLEMENTED if an unexpected server response is received.
+ // - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See
+ // note in federated_protocol.cc for the reasoning for this.)
+ // - INTERNAL for other unexpected client-side errors.
+ // - any server-provided error code.
+ // TODO(team): Replace this reference to protocol-specific
+ // TaskEligibilityInfo proto with a protocol-agnostic struct.
+ virtual absl::StatusOr<CheckinResult> Checkin(
+ const std::optional<
+ google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info,
+ std::function<void(const TaskAssignment&)>
+ payload_uris_received_callback) = 0;
+
+ // A list of absl::StatusOr<TaskAssignment> returned by
+ // PerformMultipleTaskAssignments. Individual absl::StatusOr<TaskAssignment>
+ // may be an error status due to failed to fetch the plan resources.
+ struct MultipleTaskAssignments {
+ std::vector<absl::StatusOr<TaskAssignment>> task_assignments;
+ };
+
+ // Checks in with a federated server to get multiple task assignments.
+ //
+ // Must only be called once after the following conditions are met:
+ //
+ // - the caller previously called `EligibilityEvalCheckin()` and,
+ // - received a payload, and the returned EligibilityEvalTask's
+ // `PopulationEligibilitySpec` contained at least one task with
+ // TASK_ASSIGNMENT_MODE_MULTIPLE, for which the device is eligible.
+ //
+ //
+ // Returns:
+ // - On success, a `MultipleTaskAssignments`.
+ // - On error:
+ // - ABORTED when one of the I/O operations got aborted by the server.
+ // - CANCELLED when one of the I/O operations was interrupted by the client
+ // (possibly due to a positive result from the should_abort callback).
+ // - UNAVAILABLE when server cannot be reached or URI is invalid.
+ // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the
+ // specified population name is incorrect.
+ // - UNIMPLEMENTED if an unexpected server response is received.
+ // - INTERNAL for other unexpected client-side errors.
+ // - any server-provided error code.
+ virtual absl::StatusOr<MultipleTaskAssignments>
+ PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) = 0;
+
+ // Reports the result of a federated computation to the server. Must only be
+ // called once and after a successful call to Checkin().
+ // @param checkpoint A checkpoint proto.
+ // @param stats all stats reported during the computation.
+ // @param plan_duration the duration for executing the plan in the plan
+ // engine. Does not include time spent on downloading the plan.
+ // Returns:
+ // - On success, OK.
+ // - On error (e.g. an interruption, network error, or other unexpected
+ // error):
+ // - ABORTED when one of the I/O operations got aborted by the server.
+ // - CANCELLED when one of the I/O operations was interrupted by the client
+ // (possibly due to a positive result from the should_abort callback).
+ // - UNIMPLEMENTED if the server responded with an unexpected response
+ // message.
+ // - INTERNAL for other unexpected client-side errors.
+ // - any server-provided error code.
+ virtual absl::Status ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) = 0;
+
+ // Reports the unsuccessful result of a federated computation to the server.
+ // Must only be called once and after a successful call to Checkin().
+ // @param phase_outcome the outcome of the federated computation.
+ // @param plan_duration the duration for executing the plan in the plan
+ // engine. Does not include time spent on downloading the plan.
+ // Returns:
+ // - On success, OK.
+ // - On error:
+ // - ABORTED when one of the I/O operations got aborted by the server.
+ // - CANCELLED when one of the I/O operations was interrupted by the client
+ // (possibly due to a positive result from the should_abort callback).
+ // - UNIMPLEMENTED if the server responded with an unexpected response
+ // message, or if the results to report require SecAgg support.
+ // - INTERNAL for other unexpected client-side errors.
+ // - any server-provided error code.
+ virtual absl::Status ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) = 0;
+
+ // Returns the RetryWindow the caller should use when rescheduling, based on
+ // the current protocol phase. The value returned by this method may change
+ // after every interaction with the protocol, so callers should call this
+ // right before ending their interactions with the FederatedProtocol object to
+ // ensure they use the most recent value.
+ // TODO(team): Replace this reference to protocol-specific
+ // RetryWindow proto with a protocol-agnostic struct (or just a single
+ // absl::Duration).
+ virtual google::internal::federatedml::v2::RetryWindow
+ GetLatestRetryWindow() = 0;
+
+ // Returns the best estimate of the total bytes downloaded and uploaded over
+ // the network, plus the best estimate of the duration of wall clock time
+ // spent waiting for network requests to finish (but, for example, excluding
+ // any idle time spent waiting between issuing polling requests).
+ //
+ // Note that this estimate may still include time spent simply waiting for a
+ // server response, even if no data was being sent or received during that
+ // time. E.g. in the case of the legacy gRPC protocol where the single checkin
+ // request blocks until a task is assigned to the client.
+ //
+ // If possible, this estimate should also include time spent
+ // compressing/decompressing payloads before writing them to or after reading
+ // them from the network.
+ virtual NetworkStats GetNetworkStats() = 0;
+
+ protected:
+ // A list of states representing the sequence of calls we expect to receive
+ // via this interface, as well as their possible outcomes. Implementations of
+ // this class are likely to share these coarse-grained states, and use them to
+ // determine which values to return from `GetLatestRetryWindow()`.
+ enum class ObjectState {
+ // The initial object state.
+ kInitialized,
+ // EligibilityEvalCheckin() was called but it failed with a 'transient'
+ // error (e.g. an UNAVAILABLE network error, although the set of transient
+ // errors is flag-defined).
+ kEligibilityEvalCheckinFailed,
+ // EligibilityEvalCheckin() was called but it failed with a 'permanent'
+ // error (e.g. a NOT_FOUND network error, although the set of permanent
+ // errors is flag-defined).
+ kEligibilityEvalCheckinFailedPermanentError,
+ // EligibilityEvalCheckin() was called, and the server rejected the client.
+ kEligibilityEvalCheckinRejected,
+ // EligibilityEvalCheckin() was called, and the server did not return an
+ // eligibility eval payload.
+ kEligibilityEvalDisabled,
+ // EligibilityEvalCheckin() was called, and the server did return an
+ // eligibility eval payload, which must then be run to produce a
+ // TaskEligibilityInfo value.
+ kEligibilityEvalEnabled,
+ // Checkin(...) was called but it failed with a 'transient' error.
+ kCheckinFailed,
+ // Checkin(...) was called but it failed with a 'permanent' error.
+ kCheckinFailedPermanentError,
+ // Checkin(...) was called, and the server rejected the client.
+ kCheckinRejected,
+ // Checkin(...) was called, and the server accepted the client and returned
+ // a payload, which must then be run to produce a report.
+ kCheckinAccepted,
+ // PerformMultipleTaskAssignments(...) was called but it failed with a
+ // 'transient' error, without receiving a single task assignment. If some
+ // task assignments were successfully received, but some others failed (e.g.
+ // because their resources failed to be downloaded), then this state won't
+ // be used.
+ kMultipleTaskAssignmentsFailed,
+ // PerformMultipleTaskAssignments(...) was called but it failed with a
+ // 'permanent' error.
+ kMultipleTaskAssignmentsFailedPermanentError,
+ // PerformMultipleTaskAssignments(...) was called but an empty list of tasks
+ // is returned by the server.
+ kMultipleTaskAssignmentsNoAvailableTask,
+ // PerformMultipleTaskAssignments(...) was called, and the server accepted
+ // the client and returned one or more payload, which must then be run to
+ // produce a report.
+ kMultipleTaskAssignmentsAccepted,
+ // Report(...) was called.
+ kReportCalled,
+ // Report(...) was called and it resulted in a 'permanent' error.
+ //
+ // Note: there is no kReportFailed (corresponding to 'transient' errors,
+ // like the other phases have), because by the time the report phase is
+ // reached, a set of RetryWindows is guaranteed to have been received from
+ // the server.
+ kReportFailedPermanentError,
+ // Report(...) was called for multiple tasks, and only a subset of the tasks
+ // succeed.
+ kReportMultipleTaskPartialError,
+ };
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FEDERATED_PROTOCOL_H_
diff --git a/fcp/client/federated_protocol_util.cc b/fcp/client/federated_protocol_util.cc
new file mode 100644
index 0000000..80345f7
--- /dev/null
+++ b/fcp/client/federated_protocol_util.cc
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/federated_protocol_util.h"
+
+#include <algorithm>
+#include <string>
+
+#include "google/protobuf/duration.pb.h"
+#include "absl/random/random.h"
+#include "absl/status/statusor.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/protos/federated_api.pb.h"
+
+namespace fcp {
+namespace client {
+
+namespace {
+
+// Takes the given minimum and maximum delays, and uniformly randomly
+// chooses a delay in that range.
+absl::Duration PickRetryDelayFromRange(absl::Duration min_delay,
+ absl::Duration max_delay,
+ absl::BitGen& bit_gen) {
+ // Sanitize inputs (ensure min_delay is >= 0, and max_delay is >= min_delay).
+ min_delay = std::max(absl::ZeroDuration(), min_delay);
+ max_delay = std::max(max_delay, min_delay);
+
+ // Pick a value.
+ absl::Duration window_width = max_delay - min_delay;
+ double random = absl::Uniform(bit_gen, 0, 1.0);
+ return min_delay + (window_width * random);
+}
+
+} // namespace
+
+absl::Time PickRetryTimeFromRange(const ::google::protobuf::Duration& min_delay,
+ const ::google::protobuf::Duration& max_delay,
+ absl::BitGen& bit_gen) {
+ return absl::Now() +
+ PickRetryDelayFromRange(absl::Seconds(min_delay.seconds()) +
+ absl::Nanoseconds(min_delay.nanos()),
+ absl::Seconds(max_delay.seconds()) +
+ absl::Nanoseconds(max_delay.nanos()),
+ bit_gen);
+}
+
+::google::internal::federatedml::v2::RetryWindow
+GenerateRetryWindowFromTargetDelay(absl::Duration target_delay,
+ double jitter_percent,
+ absl::BitGen& bit_gen) {
+ // Sanitize the jitter_percent input, ensuring it's within [0.0 and 1.0]
+ jitter_percent = std::min(1.0, std::max(0.0, jitter_percent));
+ // Pick a retry delay from the target range.
+ absl::Duration retry_delay =
+ PickRetryDelayFromRange(target_delay * (1.0 - jitter_percent),
+ target_delay * (1.0 + jitter_percent), bit_gen);
+ ::google::internal::federatedml::v2::RetryWindow result;
+ *result.mutable_delay_min() = *result.mutable_delay_max() =
+ TimeUtil::ConvertAbslToProtoDuration(retry_delay);
+ return result;
+}
+
+::google::internal::federatedml::v2::RetryWindow
+GenerateRetryWindowFromRetryTime(absl::Time retry_time) {
+ // Convert the target retry time back to a duration, based on the current
+ // time. I.e. if at 09:50AM the retry window was received and the chosen
+ // target retry time was 11:00AM, and if it is now 09:55AM, then the
+ // calculated duration will be 1 hour and 5 minutes.
+ absl::Duration retry_delay = retry_time - absl::Now();
+ // If the target retry time has already passed, then use a zero-length
+ // duration.
+ retry_delay = std::max(absl::ZeroDuration(), retry_delay);
+
+ // Generate a RetryWindow with delay_min and delay_max both set to the same
+ // value.
+ ::google::internal::federatedml::v2::RetryWindow retry_window;
+ *retry_window.mutable_delay_min() = *retry_window.mutable_delay_max() =
+ TimeUtil::ConvertAbslToProtoDuration(retry_delay);
+ return retry_window;
+}
+
+std::string ExtractTaskNameFromAggregationSessionId(
+ const std::string& session_id, const std::string& population_name,
+ LogManager& log_manager) {
+ auto population_start = session_id.find(population_name + "/");
+ auto task_end = session_id.find('#');
+ if (population_start != 0 || task_end == std::string::npos ||
+ task_end <= population_name.length() + 1) {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_TASK_NAME_EXTRACTION_FAILED);
+ return session_id;
+ } else {
+ return session_id.substr(population_name.length() + 1,
+ task_end - population_name.length() - 1);
+ }
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/federated_protocol_util.h b/fcp/client/federated_protocol_util.h
new file mode 100644
index 0000000..6b4b8b0
--- /dev/null
+++ b/fcp/client/federated_protocol_util.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FEDERATED_PROTOCOL_UTIL_H_
+#define FCP_CLIENT_FEDERATED_PROTOCOL_UTIL_H_
+
+#include <string>
+
+#include "google/protobuf/duration.pb.h"
+#include "absl/random/random.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/protos/federated_api.pb.h"
+
+namespace fcp {
+namespace client {
+
+// Utility methods likely shared by FederatedProtocol implementations.
+
+// Picks an absolute retry time by picking a retry delay from the range
+// specified by the RetryWindow, and then adding it to the current timestamp.
+absl::Time PickRetryTimeFromRange(const ::google::protobuf::Duration& min_delay,
+ const ::google::protobuf::Duration& max_delay,
+ absl::BitGen& bit_gen);
+
+// Picks a retry delay and encodes it as a zero-width RetryWindow (where
+// delay_min and delay_max are set to the same value), from a given target delay
+// and a configured amount of jitter.
+::google::internal::federatedml::v2::RetryWindow
+GenerateRetryWindowFromTargetDelay(absl::Duration target_delay,
+ double jitter_percent,
+ absl::BitGen& bit_gen);
+
+// Converts the given absl::Time to a zero-width RetryWindow (where
+// delay_min and delay_max are set to the same value), by converting the target
+// retry time to a delay relative to the current timestamp.
+::google::internal::federatedml::v2::RetryWindow
+GenerateRetryWindowFromRetryTime(absl::Time retry_time);
+
+// Extracts a task name from an aggregation session ID (in the HTTP protocol) or
+// a phase ID (in the gRPC protocol), both of which are expected to adhere to
+// the following format: "population_name/task_name#round_id.shard_id".
+//
+// Returns the `session_id` string unmodified if it does not match that format.
+// A diag code will be logged to the `LogManager` in this case.
+std::string ExtractTaskNameFromAggregationSessionId(
+ const std::string& session_id, const std::string& population_name,
+ fcp::client::LogManager& log_manager);
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FEDERATED_PROTOCOL_UTIL_H_
diff --git a/fcp/client/federated_protocol_util_test.cc b/fcp/client/federated_protocol_util_test.cc
new file mode 100644
index 0000000..e41e5f4
--- /dev/null
+++ b/fcp/client/federated_protocol_util_test.cc
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/federated_protocol_util.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client {
+namespace {
+
+using ::testing::StrictMock;
+
+TEST(ExtractTaskNameFromAggregationSessionIdTest, ExtractSuccessfully) {
+ StrictMock<MockLogManager> mock_log_manager_;
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId(
+ "population_name/task_name#foo.bar", "population_name",
+ mock_log_manager_),
+ "task_name");
+ EXPECT_EQ(
+ ExtractTaskNameFromAggregationSessionId(
+ "population_name/task_name#", "population_name", mock_log_manager_),
+ "task_name");
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId(
+ "population_name/task_name#foobar", "population_name",
+ mock_log_manager_),
+ "task_name");
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId(
+ "population/name/task_name#foo.bar", "population/name",
+ mock_log_manager_),
+ "task_name");
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId(
+ "population/name/task/name#foo.bar", "population/name",
+ mock_log_manager_),
+ "task/name");
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId(
+ "population_name/task/name#foo.bar", "population_name",
+ mock_log_manager_),
+ "task/name");
+}
+
+TEST(ExtractTaskNameFromAggregationSessionIdTest, ExtractUnsuccessfully) {
+ {
+ StrictMock<MockLogManager> mock_log_manager_;
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_TASK_NAME_EXTRACTION_FAILED));
+ EXPECT_EQ(ExtractTaskNameFromAggregationSessionId("foo", "population_name",
+ mock_log_manager_),
+ "foo");
+ }
+ {
+ StrictMock<MockLogManager> mock_log_manager_;
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_TASK_NAME_EXTRACTION_FAILED));
+ EXPECT_EQ(
+ ExtractTaskNameFromAggregationSessionId(
+ "population_name2/foo#bar", "population_name", mock_log_manager_),
+ "population_name2/foo#bar");
+ }
+}
+
+} // namespace
+} // namespace fcp::client
diff --git a/fcp/client/federated_select.cc b/fcp/client/federated_select.cc
new file mode 100644
index 0000000..fd5165b
--- /dev/null
+++ b/fcp/client/federated_select.cc
@@ -0,0 +1,306 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/federated_select.h"
+
+#include <deque>
+#include <filesystem>
+#include <fstream>
+#include <functional>
+#include <ios>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_replace.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/files.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+
+using fcp::client::http::HttpClient;
+using fcp::client::http::InMemoryHttpResponse;
+using fcp::client::http::UriOrInlineData;
+using ::google::internal::federated::plan::SlicesSelector;
+
+namespace {
+
+// A Federated Select `ExampleIteratorFactory` that fails all queries.
+class DisabledFederatedSelectExampleIteratorFactory
+ : public FederatedSelectExampleIteratorFactory {
+ public:
+ explicit DisabledFederatedSelectExampleIteratorFactory(
+ LogManager* log_manager)
+ : log_manager_(*log_manager) {}
+
+ // Will fetch the slice data via HTTP and return an error if any of the
+ // slice fetch requests failed.
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const ::google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ log_manager_.LogDiag(
+ ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED_BUT_DISABLED);
+ return absl::InvalidArgumentError("Federated Select is disabled.");
+ }
+
+ private:
+ LogManager& log_manager_;
+};
+
+absl::StatusOr<std::deque<absl::Cord>> FetchSlicesViaHttp(
+ const SlicesSelector& slices_selector, absl::string_view uri_template,
+ HttpClient& http_client, InterruptibleRunner& interruptible_runner,
+ int64_t* bytes_received_acc, int64_t* bytes_sent_acc) {
+ std::vector<UriOrInlineData> resources;
+ for (int32_t slice_key : slices_selector.keys()) {
+ std::string slice_uri = absl::StrReplaceAll(
+ // Note that `served_at_id` is documented to not require URL-escaping,
+ // so we don't apply any here.
+ uri_template, {{"{served_at_id}", slices_selector.served_at_id()},
+ {"{key_base10}", absl::StrCat(slice_key)}});
+
+ resources.push_back(
+ UriOrInlineData::CreateUri(slice_uri, "", absl::ZeroDuration()));
+ }
+
+ // Perform the requests.
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+ slice_fetch_result = http::FetchResourcesInMemory(
+ http_client, interruptible_runner, std::move(resources),
+ bytes_received_acc, bytes_sent_acc,
+ // TODO(team): Enable caching for federated select slices.
+ /*resource_cache=*/nullptr);
+
+ // Check whether issuing the requests failed as a whole (generally indicating
+ // a programming error).
+ if (!slice_fetch_result.ok()) {
+ return absl::InternalError(absl::StrCat(
+ "Failed to perform HTTP requests (URI template: ", uri_template,
+ "): ", absl::StatusCodeToString(slice_fetch_result.status().code())));
+ }
+
+ std::deque<absl::Cord> slices;
+ for (const absl::StatusOr<InMemoryHttpResponse>& http_response :
+ *slice_fetch_result) {
+ if (!http_response.ok()) {
+ return absl::UnavailableError(absl::StrCat(
+ "Slice fetch request failed (URI template: ", uri_template,
+ "): ", absl::StatusCodeToString(http_response.status().code())));
+ }
+ slices.push_back(http_response->body);
+ }
+ return slices;
+}
+
+// A Federated Select `ExampleIteratorFactory` that, upon creation of an
+// iterator, fetches the slice data via HTTP, buffers it in-memory, and then
+// exposes it to the plan via an `InMemoryFederatedSelectExampleIterator`.
+class HttpFederatedSelectExampleIteratorFactory
+ : public FederatedSelectExampleIteratorFactory {
+ public:
+ HttpFederatedSelectExampleIteratorFactory(
+ LogManager* log_manager, Files* files, HttpClient* http_client,
+ InterruptibleRunner* interruptible_runner, absl::string_view uri_template,
+ std::atomic<int64_t>& bytes_sent_acc,
+ std::atomic<int64_t>& bytes_received_acc,
+ WallClockStopwatch* network_stopwatch)
+ : log_manager_(*log_manager),
+ files_(*files),
+ http_client_(*http_client),
+ interruptible_runner_(*interruptible_runner),
+ uri_template_(uri_template),
+ bytes_sent_acc_(bytes_sent_acc),
+ bytes_received_acc_(bytes_received_acc),
+ network_stopwatch_(*network_stopwatch) {}
+
+ // Will fetch the slice data via HTTP and return an error if any of the slice
+ // fetch requests failed.
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const ::google::internal::federated::plan::ExampleSelector&
+ example_selector) override;
+
+ private:
+ LogManager& log_manager_;
+ Files& files_;
+ HttpClient& http_client_;
+ InterruptibleRunner& interruptible_runner_;
+ std::string uri_template_;
+ std::atomic<int64_t>& bytes_sent_acc_;
+ std::atomic<int64_t>& bytes_received_acc_;
+ WallClockStopwatch& network_stopwatch_;
+};
+
+absl::StatusOr<std::unique_ptr<ExampleIterator>>
+HttpFederatedSelectExampleIteratorFactory::CreateExampleIterator(
+ const ::google::internal::federated::plan::ExampleSelector&
+ example_selector) {
+ SlicesSelector slices_selector;
+ if (!example_selector.criteria().UnpackTo(&slices_selector)) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Unexpected/unparseable selection criteria: ",
+ example_selector.criteria().GetTypeName()));
+ }
+
+ log_manager_.LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED);
+
+ // Create the temporary scratch file to store the checkpoint data in.
+ // Deletion of the file is done in the
+ // InMemoryFederatedSelectExampleIterator::Close() method or its destructor.
+ absl::StatusOr<std::string> scratch_filename =
+ files_.CreateTempFile("slice", ".ckp");
+
+ if (!scratch_filename.ok()) {
+ return absl::InternalError(absl::StrCat(
+ "Failed to create scratch file for slice data (URI template: ",
+ uri_template_,
+ "): ", absl::StatusCodeToString(scratch_filename.status().code()), ": ",
+ scratch_filename.status().message()));
+ }
+
+ // Fetch the slices.
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ absl::StatusOr<std::deque<absl::Cord>> slices;
+ {
+ auto started_stopwatch = network_stopwatch_.Start();
+ slices = FetchSlicesViaHttp(slices_selector, uri_template_, http_client_,
+ interruptible_runner_,
+ /*bytes_received_acc=*/&bytes_received,
+ /*bytes_sent_acc=*/&bytes_sent);
+ }
+ bytes_sent_acc_ += bytes_sent;
+ bytes_received_acc_ += bytes_received;
+ if (!slices.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_FAILED);
+ return absl::Status(slices.status().code(),
+ absl::StrCat("Failed to fetch slice data: ",
+ slices.status().message()));
+ }
+ log_manager_.LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_SUCCEEDED);
+
+ return std::make_unique<InMemoryFederatedSelectExampleIterator>(
+ *scratch_filename, std::move(*slices));
+}
+} // namespace
+
+DisabledFederatedSelectManager::DisabledFederatedSelectManager(
+ LogManager* log_manager)
+ : log_manager_(*log_manager) {}
+
+std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+DisabledFederatedSelectManager::CreateExampleIteratorFactoryForUriTemplate(
+ absl::string_view uri_template) {
+ return std::make_unique<DisabledFederatedSelectExampleIteratorFactory>(
+ &log_manager_);
+}
+
+HttpFederatedSelectManager::HttpFederatedSelectManager(
+ LogManager* log_manager, Files* files,
+ fcp::client::http::HttpClient* http_client,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config)
+ : log_manager_(*log_manager),
+ files_(*files),
+ http_client_(*http_client),
+ interruptible_runner_(std::make_unique<InterruptibleRunner>(
+ log_manager, should_abort, timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT})) {}
+
+std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+HttpFederatedSelectManager::CreateExampleIteratorFactoryForUriTemplate(
+ absl::string_view uri_template) {
+ // If the server didn't populate the URI template then we can't support any
+ // slice fetch requests.
+ if (uri_template.empty()) {
+ return std::make_unique<DisabledFederatedSelectExampleIteratorFactory>(
+ &log_manager_);
+ }
+ return std::make_unique<HttpFederatedSelectExampleIteratorFactory>(
+ &log_manager_, &files_, &http_client_, interruptible_runner_.get(),
+ uri_template,
+ /*bytes_sent_acc=*/bytes_sent_, /*bytes_received_acc=*/bytes_received_,
+ network_stopwatch_.get());
+}
+
+absl::StatusOr<std::string> InMemoryFederatedSelectExampleIterator::Next() {
+ absl::MutexLock lock(&mutex_);
+
+ if (slices_.empty()) {
+ // Eagerly delete the scratch file, since we won't need it anymore.
+ std::filesystem::remove(scratch_filename_);
+ return absl::OutOfRangeError("end of iterator reached");
+ }
+
+ absl::Cord& slice_data = slices_.front();
+
+ // Write the checkpoint data to the file (truncating any data previously
+ // written to the file).
+ std::fstream checkpoint_stream(scratch_filename_,
+ std::ios_base::out | std::ios_base::trunc);
+ if (checkpoint_stream.fail()) {
+ return absl::InternalError("Failed to write slice to file");
+ }
+ for (absl::string_view chunk : slice_data.Chunks()) {
+ if (!(checkpoint_stream << chunk).good()) {
+ return absl::InternalError("Failed to write slice to file");
+ }
+ }
+ checkpoint_stream.close();
+
+ // Remove the slice from the deque, releasing its data from memory.
+ slices_.pop_front();
+
+ return scratch_filename_;
+}
+
+void InMemoryFederatedSelectExampleIterator::Close() { CleanupInternal(); }
+
+InMemoryFederatedSelectExampleIterator::
+ ~InMemoryFederatedSelectExampleIterator() {
+ // Remove the scratch file, even if Close() wasn't called first.
+ CleanupInternal();
+}
+
+void InMemoryFederatedSelectExampleIterator::CleanupInternal() {
+ absl::MutexLock lock(&mutex_);
+ // Remove the scratch filename, if it hadn't been removed yet.
+ slices_.clear();
+ std::filesystem::remove(scratch_filename_);
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/federated_select.h b/fcp/client/federated_select.h
new file mode 100644
index 0000000..232aa4d
--- /dev/null
+++ b/fcp/client/federated_select.h
@@ -0,0 +1,162 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FEDERATED_SELECT_H_
+#define FCP_CLIENT_FEDERATED_SELECT_H_
+
+#include <deque>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/time/time.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/files.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+
+// The example query collection URI via which slice fetch requests will arrive.
+inline static constexpr char kFederatedSelectCollectionUri[] =
+ "internal:/federated_select";
+
+// An interface via which a Federated Select `ExampleIteratorFactory` can be
+// created. Each factory is expected to fetch slice data using the given
+// `uri_template`, and to then serve the slice data by writing it to a file and
+// by then returning that filename as a tf.Example to the plan.
+class FederatedSelectManager {
+ public:
+ virtual std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+ CreateExampleIteratorFactoryForUriTemplate(
+ absl::string_view uri_template) = 0;
+
+ // The best estimate of the over-the-wire bytes downloaded and uploadeded over
+ // the network, and the total duration of wall clock time spent waiting on
+ // network requests.
+
+ // Note that if two different slice fetches are in flight from different
+ // threads, this should measure just the wall clock time spent completing both
+ // sets of fetches (i.e. it should not double-count the wall clock time by
+ // summing each per-thread duration individually).
+ //
+ // If possible, this estimate should also include time spent decompressing
+ // payloads after reading them from the network.
+ virtual NetworkStats GetNetworkStats() = 0;
+
+ virtual ~FederatedSelectManager() {}
+};
+
+// An base class for `ExampleIteratorFactory` implementations that can handle
+// Federated Select example queries.
+class FederatedSelectExampleIteratorFactory
+ : public ::fcp::client::engine::ExampleIteratorFactory {
+ public:
+ bool CanHandle(const ::google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ return example_selector.collection_uri() == kFederatedSelectCollectionUri;
+ }
+
+ bool ShouldCollectStats() override {
+ // Federated Select example queries should not be recorded in the OpStats
+ // DB, since the fact that Federated Select uses the example iterator
+ // interface is an internal implementation detail.
+ return false;
+ }
+};
+
+class DisabledFederatedSelectManager : public FederatedSelectManager {
+ public:
+ explicit DisabledFederatedSelectManager(LogManager* log_manager);
+
+ std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+ CreateExampleIteratorFactoryForUriTemplate(
+ absl::string_view uri_template) override;
+
+ NetworkStats GetNetworkStats() override { return NetworkStats(); }
+
+ private:
+ LogManager& log_manager_;
+};
+
+// A FederatedSelectManager implementation that actually issues HTTP requests to
+// fetch slice data (i.e. the "real" implementation).
+class HttpFederatedSelectManager : public FederatedSelectManager {
+ public:
+ HttpFederatedSelectManager(
+ LogManager* log_manager, Files* files,
+ fcp::client::http::HttpClient* http_client,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config);
+
+ std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+ CreateExampleIteratorFactoryForUriTemplate(
+ absl::string_view uri_template) override;
+
+ NetworkStats GetNetworkStats() override {
+ return {.bytes_downloaded = bytes_received_.load(),
+ .bytes_uploaded = bytes_sent_.load(),
+ .network_duration = network_stopwatch_->GetTotalDuration()};
+ }
+
+ private:
+ LogManager& log_manager_;
+ Files& files_;
+ std::atomic<int64_t> bytes_sent_ = 0;
+ std::atomic<int64_t> bytes_received_ = 0;
+ std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
+ WallClockStopwatch::Create();
+ fcp::client::http::HttpClient& http_client_;
+ std::unique_ptr<InterruptibleRunner> interruptible_runner_;
+};
+
+// A Federated Select ExampleIterator that simply returns slice data that is
+// already in-memory.
+class InMemoryFederatedSelectExampleIterator : public ExampleIterator {
+ public:
+ // Each time another slice is requested by a call to Next(), the slice data at
+ // the front of the `slices` deque will be written to the `scratch_filename`
+ // and the filename will be returned as the example data. The scratch file
+ // will be deleted at the end of the iterator, or when the iterator is closed.
+ InMemoryFederatedSelectExampleIterator(std::string scratch_filename,
+ std::deque<absl::Cord> slices)
+ : scratch_filename_(scratch_filename), slices_(std::move(slices)) {}
+ absl::StatusOr<std::string> Next() override;
+ void Close() override;
+
+ ~InMemoryFederatedSelectExampleIterator() override;
+
+ private:
+ void CleanupInternal() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ std::string scratch_filename_;
+
+ absl::Mutex mutex_;
+ std::deque<absl::Cord> slices_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FEDERATED_SELECT_H_
diff --git a/fcp/client/federated_select_test.cc b/fcp/client/federated_select_test.cc
new file mode 100644
index 0000000..c994a0f
--- /dev/null
+++ b/fcp/client/federated_select_test.cc
@@ -0,0 +1,451 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/federated_select.h"
+
+#include <fstream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "google/protobuf/text_format.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/client_runner.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/stats.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client {
+namespace {
+
+using ::fcp::IsCode;
+using ::fcp::client::ExampleIterator;
+using ::fcp::client::engine::ExampleIteratorFactory;
+using ::fcp::client::http::FakeHttpResponse;
+using ::fcp::client::http::HeaderList;
+using ::fcp::client::http::HttpRequest;
+using ::fcp::client::http::HttpRequestHandle;
+using ::fcp::client::http::MockHttpClient;
+using ::fcp::client::http::SimpleHttpRequestMatcher;
+using ::fcp::client::http::internal::CompressWithGzip;
+using ::google::internal::federated::plan::ExampleSelector;
+using ::google::internal::federated::plan::SlicesSelector;
+using ::testing::_;
+using ::testing::Gt;
+using ::testing::HasSubstr;
+using ::testing::InSequence;
+using ::testing::MockFunction;
+using ::testing::NiceMock;
+using ::testing::Not;
+using ::testing::Return;
+using ::testing::StrictMock;
+
+ExampleSelector CreateExampleSelector(const std::string& served_at_id,
+ std::vector<int32_t> keys) {
+ ExampleSelector example_selector;
+ *example_selector.mutable_collection_uri() = "internal:/federated_select";
+ SlicesSelector slices_selector;
+ *slices_selector.mutable_served_at_id() = served_at_id;
+ slices_selector.mutable_keys()->Add(keys.begin(), keys.end());
+ example_selector.mutable_criteria()->PackFrom(slices_selector);
+ return example_selector;
+}
+
+bool FileExists(const std::string& path) {
+ std::ifstream istream(path);
+ return istream.good();
+}
+
+std::string ReadFile(const std::string& path) {
+ std::ifstream istream(path);
+ FCP_CHECK(istream);
+ std::stringstream stringstream;
+ stringstream << istream.rdbuf();
+ return stringstream.str();
+}
+
+class HttpFederatedSelectManagerTest : public ::testing::Test {
+ protected:
+ HttpFederatedSelectManagerTest()
+ : fedselect_manager_(
+ &mock_log_manager_, &files_impl_, &mock_http_client_,
+ mock_should_abort_.AsStdFunction(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()}) {}
+
+ void SetUp() override {
+ EXPECT_CALL(mock_flags_, enable_federated_select())
+ .WillRepeatedly(Return(true));
+ }
+
+ void TearDown() override {
+ // Regardless of the outcome of the test (or the protocol interaction being
+ // tested), network usage must always be reflected in the network stats
+ // methods.
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes =
+ mock_http_client_.TotalSentReceivedBytes();
+ NetworkStats network_stats = fedselect_manager_.GetNetworkStats();
+ EXPECT_EQ(network_stats.bytes_downloaded,
+ sent_received_bytes.received_bytes);
+ EXPECT_EQ(network_stats.bytes_uploaded, sent_received_bytes.sent_bytes);
+ // If any network traffic occurred, we expect to see some time reflected in
+ // the duration.
+ if (network_stats.bytes_uploaded > 0) {
+ EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
+ }
+ }
+
+ NiceMock<MockLogManager> mock_log_manager_;
+ MockFlags mock_flags_;
+ fcp::client::FilesImpl files_impl_;
+ StrictMock<MockHttpClient> mock_http_client_;
+ NiceMock<MockFunction<bool()>> mock_should_abort_;
+
+ HttpFederatedSelectManager fedselect_manager_;
+};
+
+TEST_F(HttpFederatedSelectManagerTest,
+ IteratorFactoryShouldHandleValidSelector) {
+ // Should be handled by the factory.
+ ExampleSelector selector =
+ CreateExampleSelector(/*served_at_id=*/"foo", /*keys=*/{1});
+
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ "https://foo.bar");
+
+ EXPECT_TRUE(iterator_factory->CanHandle(selector));
+}
+
+TEST_F(HttpFederatedSelectManagerTest,
+ IteratorFactoryShouldNotHandleUnrelatedSelector) {
+ // Should not be handled by the factory.
+ ExampleSelector selector;
+ *selector.mutable_collection_uri() = "internal:/foo";
+
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ "https://foo.bar");
+
+ EXPECT_FALSE(iterator_factory->CanHandle(selector));
+ EXPECT_THAT(iterator_factory->CreateExampleIterator(selector),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(HttpFederatedSelectManagerTest, IteratorFactoryShouldNotCollectStats) {
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ "https://foo.bar");
+
+ EXPECT_FALSE(iterator_factory->ShouldCollectStats());
+}
+
+TEST_F(HttpFederatedSelectManagerTest,
+ EmptyUriTemplateShouldFailAllExampleQueries) {
+ ExampleSelector selector =
+ CreateExampleSelector(/*served_at_id=*/"foo", /*keys=*/{1});
+
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED_BUT_DISABLED));
+
+ // Create an iterator factory using an empty base URI (indicating that the
+ // server didn't provide us with a federated select URI template, i.e. the
+ // feature is disabled on the server or the plan doesn't use the Federated
+ // Select feature).
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ /*uri_template=*/"");
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator =
+ iterator_factory->CreateExampleIterator(selector);
+
+ // The iterator creation should have failed, since the URI template was empty.
+ EXPECT_THAT(iterator, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(iterator.status().message(), HasSubstr("disabled"));
+ // Even though creating an iterator should fail since the feature is disabled,
+ // the iterator factory should still handle all internal:/federated_select
+ // example queries (rather than let them bubble up to the default
+ // environment-provided example iterator factory, which won't know how to
+ // handle them anyway).
+ EXPECT_TRUE(iterator_factory->CanHandle(selector));
+}
+
+/** Tests the "happy" path. Two separate federated select slice fetching queries
+ * are received by a single iterator factory, each serving different slice data
+ * to the client. All slice fetches are successful and should result in the
+ * correct data being returned, in the right order.*/
+TEST_F(HttpFederatedSelectManagerTest,
+ SuccessfullyFetchMultipleSlicesAcrossMultipleIterators) {
+ const std::string uri_template =
+ "https://foo.bar/{served_at_id}/baz/{key_base10}/bazz";
+
+ // Create an iterator factory with a valid (non-empty) URI template.
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ uri_template);
+
+ // Once the first iterator is created we expect the following slice fetch
+ // requests to be issued immediately.
+ const std::string expected_key1_data = "key1_data";
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '/' in "id/X" is *not* URI-escaped.
+ "https://foo.bar/id/X/baz/1/bazz", HttpRequest::Method::kGet, _, "")))
+ .WillOnce(
+ Return(FakeHttpResponse(200, HeaderList(), expected_key1_data)));
+
+ const std::string expected_key2_data = "key2_data";
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://foo.bar/id/X/baz/2/bazz",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(
+ Return(FakeHttpResponse(200, HeaderList(), expected_key2_data)));
+
+ {
+ InSequence in_sequence;
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED));
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_SUCCEEDED));
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator1 =
+ iterator_factory->CreateExampleIterator(CreateExampleSelector(
+ // Note that we use a served_at_id value with a '/' in it. It should
+ // *not* get URI-escaped, as per the FederatedSelectUriInfo docs.
+ /*served_at_id=*/"id/X",
+ // Also note that we request slices 2 and 1, in that exact order. I.e.
+ // the first slice data we receive should be slice 2, and the second
+ // slice data we receive should be slice 1.
+ /*keys=*/{2, 1}));
+
+ // The iterator creation should have succeeded.
+ ASSERT_OK(iterator1);
+
+ // Reading the data for each of the slices should now succeed.
+ absl::StatusOr<std::string> first_slice = (*iterator1)->Next();
+ ASSERT_OK(first_slice);
+ ASSERT_TRUE(FileExists(*first_slice));
+ EXPECT_THAT(ReadFile(*first_slice), expected_key2_data);
+
+ absl::StatusOr<std::string> second_slice = (*iterator1)->Next();
+ ASSERT_OK(second_slice);
+ ASSERT_TRUE(FileExists(*second_slice));
+ EXPECT_THAT(ReadFile(*second_slice), expected_key1_data);
+
+ // We should now have reached the end of the first iterator.
+ EXPECT_THAT((*iterator1)->Next(), IsCode(OUT_OF_RANGE));
+
+ // Closing the iterator should not fail/crash.
+ (*iterator1)->Close();
+ // The slice files we saw earlier (possibly all the same file) should now be
+ // deleted.
+ ASSERT_FALSE(FileExists(*first_slice));
+ ASSERT_FALSE(FileExists(*second_slice));
+
+ const std::string expected_key99_data = "key99_data";
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://foo.bar/id/Y/baz/99/bazz",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(
+ Return(FakeHttpResponse(200, HeaderList(), expected_key99_data)));
+
+ {
+ InSequence in_sequence;
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED));
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_SUCCEEDED));
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator2 =
+ iterator_factory->CreateExampleIterator(
+ CreateExampleSelector(/*served_at_id=*/"id/Y", /*keys=*/{99}));
+
+ // The iterator creation should have succeeded.
+ ASSERT_OK(iterator2);
+
+ // Reading the data for the slices should now succeed.
+ absl::StatusOr<std::string> third_slice = (*iterator2)->Next();
+ ASSERT_OK(third_slice);
+ ASSERT_TRUE(FileExists(*third_slice));
+ EXPECT_THAT(ReadFile(*third_slice), expected_key99_data);
+
+ // We purposely do not close the 2nd iterator, nor iterate it all the way to
+ // the end until we receive OUT_OF_RANGE, but instead simply destroy it. This
+ // should have the same effect as closing it, and cause the file to be
+ // deleted.
+ *iterator2 = nullptr;
+ ASSERT_FALSE(FileExists(*third_slice));
+}
+
+/** Tests the case where the fetched resources are compressed using the
+ * "Content-Type: ...+gzip" approach. The data should be decompressed before
+ * being returned.
+ */
+TEST_F(HttpFederatedSelectManagerTest, SuccessfullyFetchCompressedSlice) {
+ const std::string uri_template =
+ "https://foo.bar/{served_at_id}/{key_base10}";
+
+ // Create an iterator factory with a valid (non-empty) URI template.
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ uri_template);
+
+ // Once the first iterator is created we expect the following slice fetch
+ // requests to be issued immediately.
+ const std::string expected_key1_data = "key1_data";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://foo.bar/id-X/1", HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList{{"Content-Type", "application/octet-stream+gzip"}},
+ *CompressWithGzip(expected_key1_data))));
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator =
+ iterator_factory->CreateExampleIterator(CreateExampleSelector(
+ /*served_at_id=*/"id-X", /*keys=*/{1}));
+
+ // The iterator creation should have succeeded.
+ ASSERT_OK(iterator);
+
+ // Reading the data for the slice should now succeed and return the expected
+ // (uncompressed) data.
+ absl::StatusOr<std::string> slice = (*iterator)->Next();
+ ASSERT_OK(slice);
+ ASSERT_TRUE(FileExists(*slice));
+ EXPECT_THAT(ReadFile(*slice), expected_key1_data);
+}
+
+/** Tests the case where the URI template contains the substitution strings more
+ * than once. The client should replace *all* of them, not just the first one.
+ */
+TEST_F(HttpFederatedSelectManagerTest,
+ SuccessfullyFetchFromUriTemplateWithMultipleTemplateEntries) {
+ const std::string uri_template =
+ "https://{served_at_id}.foo.bar/{key_base10}{served_at_id}/baz/"
+ "{key_base10}/bazz";
+
+ // Create an iterator factory with a valid (non-empty) URI template.
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ uri_template);
+
+ // Once the first iterator is created we expect the following slice fetch
+ // requests to be issued immediately.
+ const std::string expected_key1_data = "key1_data";
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://id-X.foo.bar/1id-X/baz/1/bazz",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(
+ Return(FakeHttpResponse(200, HeaderList(), expected_key1_data)));
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator =
+ iterator_factory->CreateExampleIterator(CreateExampleSelector(
+ /*served_at_id=*/"id-X", /*keys=*/{1}));
+
+ // The iterator creation should have succeeded.
+ ASSERT_OK(iterator);
+
+ // Reading the data should now succeed.
+ absl::StatusOr<std::string> slice = (*iterator)->Next();
+ ASSERT_OK(slice);
+ ASSERT_TRUE(FileExists(*slice));
+ EXPECT_THAT(ReadFile(*slice), expected_key1_data);
+
+ // We should now have reached the end of the first iterator.
+ EXPECT_THAT((*iterator)->Next(), IsCode(OUT_OF_RANGE));
+
+ // Closing the iterator should not fail/crash.
+ (*iterator)->Close();
+ // The slice files we saw earlier (possibly all the same file) should now be
+ // deleted.
+ ASSERT_FALSE(FileExists(*slice));
+}
+
+/** Tests the case where the URI template contains the substitution strings more
+ * than once. The client should replace *all* of them, not just the first one.
+ */
+TEST_F(HttpFederatedSelectManagerTest, ErrorDuringFetch) {
+ const std::string uri_template =
+ "https://foo.bar/{served_at_id}/{key_base10}";
+
+ // Create an iterator factory with a valid (non-empty) URI template.
+ std::unique_ptr<ExampleIteratorFactory> iterator_factory =
+ fedselect_manager_.CreateExampleIteratorFactoryForUriTemplate(
+ uri_template);
+
+ // Once the first iterator is created we expect the following slice fetch
+ // requests to be issued immediately. We'll make the 2nd slice's HTTP request
+ // return an error.
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://foo.bar/id-X/998",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://foo.bar/id-X/999",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
+
+ {
+ InSequence in_sequence;
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_REQUESTED));
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::FEDSELECT_SLICE_HTTP_FETCH_FAILED));
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator =
+ iterator_factory->CreateExampleIterator(CreateExampleSelector(
+ /*served_at_id=*/"id-X", /*keys=*/{998, 999}));
+
+ // The iterator creation should fail, since if we can't fetch all of the
+ // slices successfully there is no way the plan can continue executing.
+ // The error code should be UNAVAILABLE and the error message should include
+ // the original NOT_FOUND error code that HTTP's 404 maps to, as well as the
+ // URI template to aid debugging.
+ EXPECT_THAT(iterator, IsCode(UNAVAILABLE));
+ EXPECT_THAT(iterator.status().message(), HasSubstr("fetch request failed"));
+ EXPECT_THAT(iterator.status().message(), HasSubstr(uri_template));
+ EXPECT_THAT(iterator.status().message(), HasSubstr("NOT_FOUND"));
+ // The error message should not contain the exact original HTTP code (since we
+ // expect the HTTP layer's error *messages* to not be included in the message
+ // returned to the plan).
+ EXPECT_THAT(iterator.status().message(), Not(HasSubstr("404")));
+ // The error message should not contain the slice IDs either.
+ EXPECT_THAT(iterator.status().message(), Not(HasSubstr("998")));
+ EXPECT_THAT(iterator.status().message(), Not(HasSubstr("999")));
+}
+
+} // anonymous namespace
+} // namespace fcp::client
diff --git a/fcp/client/files.h b/fcp/client/files.h
new file mode 100644
index 0000000..77c8dd0
--- /dev/null
+++ b/fcp/client/files.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FILES_H_
+#define FCP_CLIENT_FILES_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+
+namespace fcp {
+namespace client {
+
+// An interface used by the plan engine for platform-dependent file system
+// access.
+class Files {
+ public:
+ virtual ~Files() = default;
+
+ // Creates a temporary file. The runtime environment (e.g. operating system)
+ // is expected to clean up these files if necessary, i.e. the engine is not
+ // responsible for their deletion (but may chose to do so).
+ // On success, returns a file path.
+ // On error, returns
+ // - INTERNAL - unexpected error.
+ // - INVALID_ARGUMENT - on "expected" errors such as I/O issues.
+ virtual absl::StatusOr<std::string> CreateTempFile(
+ const std::string& prefix, const std::string& suffix) = 0;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FILES_H_
diff --git a/fcp/client/fl_runner.cc b/fcp/client/fl_runner.cc
new file mode 100644
index 0000000..3eaf36f
--- /dev/null
+++ b/fcp/client/fl_runner.cc
@@ -0,0 +1,1638 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/fl_runner.h"
+
+#include <fcntl.h>
+
+#include <fstream>
+#include <functional>
+#include <map>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+// #include "fcp/client/cache/file_backed_resource_cache.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/engine/common.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/engine/example_query_plan_engine.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+#include "fcp/client/opstats/opstats_utils.h"
+#include "fcp/client/parsing_utils.h"
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+#include "fcp/client/engine/simple_plan_engine.h"
+#endif
+#include "fcp/client/engine/tflite_plan_engine.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_protocol_util.h"
+#include "fcp/client/files.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_federated_protocol.h"
+#ifdef FCP_CLIENT_SUPPORT_GRPC
+#include "fcp/client/grpc_federated_protocol.h"
+#endif
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_example_store.h"
+#include "fcp/client/phase_logger_impl.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
+#include "fcp/protos/opstats.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "openssl/digest.h"
+#include "openssl/evp.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+namespace fcp {
+namespace client {
+using ::fcp::client::opstats::OpStatsLogger;
+using ::google::internal::federated::plan::AggregationConfig;
+using ::google::internal::federated::plan::ClientOnlyPlan;
+using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
+using ::google::internal::federated::plan::FederatedComputeIORouter;
+using ::google::internal::federated::plan::TensorflowSpec;
+using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
+using ::google::internal::federatedml::v2::RetryWindow;
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
+namespace {
+template <typename T>
+void AddValuesToQuantized(QuantizedTensor* quantized,
+ const tensorflow::Tensor& tensor) {
+ auto flat_tensor = tensor.flat<T>();
+ quantized->values.reserve(quantized->values.size() + flat_tensor.size());
+ for (int i = 0; i < flat_tensor.size(); i++) {
+ quantized->values.push_back(flat_tensor(i));
+ }
+}
+std::string ComputeSHA256FromStringOrCord(
+ std::variant<std::string, absl::Cord> data) {
+ std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
+ EVP_MD_CTX_destroy);
+ FCP_CHECK(EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr));
+ std::string plan_str;
+ if (std::holds_alternative<std::string>(data)) {
+ plan_str = std::get<std::string>(data);
+ } else {
+ plan_str = std::string(std::get<absl::Cord>(data));
+ }
+ FCP_CHECK(EVP_DigestUpdate(mdctx.get(), plan_str.c_str(), sizeof(int)));
+ const int hash_len = 32; // 32 bytes for SHA-256.
+ uint8_t computation_id_bytes[hash_len];
+ FCP_CHECK(EVP_DigestFinal_ex(mdctx.get(), computation_id_bytes, nullptr));
+ return std::string(reinterpret_cast<char const*>(computation_id_bytes),
+ hash_len);
+}
+struct PlanResultAndCheckpointFile {
+ explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
+ : plan_result(std::move(plan_result)) {}
+ engine::PlanResult plan_result;
+ std::string checkpoint_file;
+ PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default;
+ PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) =
+ default;
+ // Disallow copy and assign.
+ PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete;
+ PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) =
+ delete;
+};
+// Creates computation results. The method checks for SecAgg tensors only if
+// `tensorflow_spec != nullptr`.
+absl::StatusOr<ComputationResults> CreateComputationResults(
+ const TensorflowSpec* tensorflow_spec,
+ const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file) {
+ const auto& [plan_result, checkpoint_file] = plan_result_and_checkpoint_file;
+ if (plan_result.outcome != engine::PlanOutcome::kSuccess) {
+ return absl::InvalidArgumentError("Computation failed.");
+ }
+ ComputationResults computation_results;
+ if (tensorflow_spec != nullptr) {
+ for (int i = 0; i < plan_result.output_names.size(); i++) {
+ QuantizedTensor quantized;
+ const auto& output_tensor = plan_result.output_tensors[i];
+ switch (output_tensor.dtype()) {
+ case tensorflow::DT_INT8:
+ AddValuesToQuantized<int8_t>(&quantized, output_tensor);
+ quantized.bitwidth = 7;
+ break;
+ case tensorflow::DT_UINT8:
+ AddValuesToQuantized<uint8_t>(&quantized, output_tensor);
+ quantized.bitwidth = 8;
+ break;
+ case tensorflow::DT_INT16:
+ AddValuesToQuantized<int16_t>(&quantized, output_tensor);
+ quantized.bitwidth = 15;
+ break;
+ case tensorflow::DT_UINT16:
+ AddValuesToQuantized<uint16_t>(&quantized, output_tensor);
+ quantized.bitwidth = 16;
+ break;
+ case tensorflow::DT_INT32:
+ AddValuesToQuantized<int32_t>(&quantized, output_tensor);
+ quantized.bitwidth = 31;
+ break;
+ case tensorflow::DT_INT64:
+ AddValuesToQuantized<tensorflow::int64>(&quantized, output_tensor);
+ quantized.bitwidth = 62;
+ break;
+ default:
+ return absl::InvalidArgumentError(
+ absl::StrCat("Tensor of type",
+ tensorflow::DataType_Name(output_tensor.dtype()),
+ "could not be converted to quantized value"));
+ }
+ computation_results[plan_result.output_names[i]] = std::move(quantized);
+ }
+ // Add dimensions to QuantizedTensors.
+ for (const tensorflow::TensorSpecProto& tensor_spec :
+ tensorflow_spec->output_tensor_specs()) {
+ if (computation_results.find(tensor_spec.name()) !=
+ computation_results.end()) {
+ for (const tensorflow::TensorShapeProto_Dim& dim :
+ tensor_spec.shape().dim()) {
+ std::get<QuantizedTensor>(computation_results[tensor_spec.name()])
+ .dimensions.push_back(dim.size());
+ }
+ }
+ }
+ }
+ // Name of the TF checkpoint inside the aggregand map in the Checkpoint
+ // protobuf. This field name is ignored by the server.
+ if (!checkpoint_file.empty()) {
+ FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
+ fcp::ReadFileToString(checkpoint_file));
+ computation_results[std::string(kTensorflowCheckpointAggregand)] =
+ std::move(tf_checkpoint);
+ }
+ return computation_results;
+}
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ConstructInputsForEligibilityEvalPlan(
+ const FederatedComputeEligibilityIORouter& io_router,
+ const std::string& checkpoint_input_filename) {
+ auto inputs = std::make_unique<
+ std::vector<std::pair<std::string, tensorflow::Tensor>>>();
+ if (!io_router.input_filepath_tensor_name().empty()) {
+ tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {});
+ input_filepath.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
+ inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath});
+ }
+ return inputs;
+}
+#endif
+std::unique_ptr<TfLiteInputs> ConstructTfLiteInputsForEligibilityEvalPlan(
+ const FederatedComputeEligibilityIORouter& io_router,
+ const std::string& checkpoint_input_filename) {
+ auto inputs = std::make_unique<TfLiteInputs>();
+ if (!io_router.input_filepath_tensor_name().empty()) {
+ (*inputs)[io_router.input_filepath_tensor_name()] =
+ checkpoint_input_filename;
+ }
+ return inputs;
+}
+// Returns the cumulative network stats (those incurred up until this point in
+// time).
+//
+// The `FederatedSelectManager` object may be null, if it is know that there
+// has been no network usage from it yet.
+NetworkStats GetCumulativeNetworkStats(
+ FederatedProtocol* federated_protocol,
+ FederatedSelectManager* fedselect_manager) {
+ NetworkStats result = federated_protocol->GetNetworkStats();
+ if (fedselect_manager != nullptr) {
+ result = result + fedselect_manager->GetNetworkStats();
+ }
+ return result;
+}
+// Returns the newly incurred network stats since the previous snapshot of
+// stats (the `reference_point` argument).
+NetworkStats GetNetworkStatsSince(FederatedProtocol* federated_protocol,
+ FederatedSelectManager* fedselect_manager,
+ const NetworkStats& reference_point) {
+ return GetCumulativeNetworkStats(federated_protocol, fedselect_manager) -
+ reference_point;
+}
+// Updates the fields of `FLRunnerResult` that should always be updated after
+// each interaction with the `FederatedProtocol` or `FederatedSelectManager`
+// objects.
+//
+// The `FederatedSelectManager` object may be null, if it is know that there
+// has been no network usage from it yet.
+void UpdateRetryWindowAndNetworkStats(FederatedProtocol& federated_protocol,
+ FederatedSelectManager* fedselect_manager,
+ PhaseLogger& phase_logger,
+ FLRunnerResult& fl_runner_result) {
+ // Update the result's retry window to the most recent one.
+ auto retry_window = federated_protocol.GetLatestRetryWindow();
+ RetryInfo retry_info;
+ *retry_info.mutable_retry_token() = retry_window.retry_token();
+ *retry_info.mutable_minimum_delay() = retry_window.delay_min();
+ *fl_runner_result.mutable_retry_info() = retry_info;
+ phase_logger.UpdateRetryWindowAndNetworkStats(
+ retry_window,
+ GetCumulativeNetworkStats(&federated_protocol, fedselect_manager));
+}
+// Creates an ExampleIteratorFactory that routes queries to the
+// SimpleTaskEnvironment::CreateExampleIterator() method.
+std::unique_ptr<engine::ExampleIteratorFactory>
+CreateSimpleTaskEnvironmentIteratorFactory(
+ SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
+ return std::make_unique<engine::FunctionalExampleIteratorFactory>(
+ /*can_handle_func=*/
+ [](const google::internal::federated::plan::ExampleSelector&) {
+ // The SimpleTaskEnvironment-based ExampleIteratorFactory should
+ // be the catch-all factory that is able to handle all queries
+ // that no other ExampleIteratorFactory is able to handle.
+ return true;
+ },
+ /*create_iterator_func=*/
+ [task_env, selector_context](
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) {
+ return task_env->CreateExampleIterator(example_selector,
+ selector_context);
+ },
+ /*should_collect_stats=*/true);
+}
+engine::PlanResult RunEligibilityEvalPlanWithTensorflowSpec(
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ OpStatsLogger* opstats_logger, const Flags* flags,
+ const ClientOnlyPlan& client_plan,
+ const std::string& checkpoint_input_filename,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time run_plan_start_time, const absl::Time reference_time) {
+ // Check that this is a TensorflowSpec-based plan for federated eligibility
+ // computation.
+ if (!client_plan.phase().has_tensorflow_spec() ||
+ !client_plan.phase().has_federated_compute_eligibility()) {
+ return engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Invalid eligibility eval plan"));
+ }
+ const FederatedComputeEligibilityIORouter& io_router =
+ client_plan.phase().federated_compute_eligibility();
+ std::vector<std::string> output_names = {
+ io_router.task_eligibility_info_tensor_name()};
+ if (!client_plan.tflite_graph().empty()) {
+ log_manager->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
+ }
+ if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
+ std::unique_ptr<TfLiteInputs> tflite_inputs =
+ ConstructTfLiteInputsForEligibilityEvalPlan(io_router,
+ checkpoint_input_filename);
+ engine::TfLitePlanEngine plan_engine(example_iterator_factories,
+ should_abort, log_manager,
+ opstats_logger, flags, &timing_config);
+ return plan_engine.RunPlan(client_plan.phase().tensorflow_spec(),
+ client_plan.tflite_graph(),
+ std::move(tflite_inputs), output_names);
+ }
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+ // Construct input tensors and output tensor names based on the values in
+ // the FederatedComputeEligibilityIORouter message.
+ auto inputs = ConstructInputsForEligibilityEvalPlan(
+ io_router, checkpoint_input_filename);
+ // Run plan and get a set of output tensors back.
+ engine::SimplePlanEngine plan_engine(
+ example_iterator_factories, should_abort, log_manager, opstats_logger,
+ &timing_config, flags->support_constant_tf_inputs());
+ return plan_engine.RunPlan(
+ client_plan.phase().tensorflow_spec(), client_plan.graph(),
+ client_plan.tensorflow_config_proto(), std::move(inputs), output_names);
+#else
+ return engine::PlanResult(
+ engine::PlanOutcome::kTensorflowError,
+ absl::InternalError("No eligibility eval plan engine enabled"));
+#endif
+}
+// Validates the output tensors that resulted from executing the plan, and
+// then parses the output into a TaskEligibilityInfo proto. Returns an error
+// if validation or parsing failed.
+absl::StatusOr<TaskEligibilityInfo> ParseEligibilityEvalPlanOutput(
+ const std::vector<tensorflow::Tensor>& output_tensors) {
+ auto output_size = output_tensors.size();
+ if (output_size != 1) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Unexpected number of output tensors: ", output_size));
+ }
+ auto output_elements = output_tensors[0].NumElements();
+ if (output_elements != 1) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Unexpected number of output tensor elements: ", output_elements));
+ }
+ tensorflow::DataType output_type = output_tensors[0].dtype();
+ if (output_type != tensorflow::DT_STRING) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Unexpected output tensor type: ", output_type));
+ }
+ // Extract the serialized TaskEligibilityInfo proto from the tensor and
+ // parse it.
+ // First, convert the output Tensor into a Scalar (= a TensorMap with 1
+ // element), then use its operator() to access the actual data.
+ const tensorflow::tstring& serialized_output =
+ output_tensors[0].scalar<const tensorflow::tstring>()();
+ TaskEligibilityInfo parsed_output;
+ if (!parsed_output.ParseFromString(serialized_output)) {
+ return absl::InvalidArgumentError("Could not parse output proto");
+ }
+ return parsed_output;
+}
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+std::unique_ptr<std::vector<std::pair<std::string, tensorflow::Tensor>>>
+ConstructInputsForTensorflowSpecPlan(
+ const FederatedComputeIORouter& io_router,
+ const std::string& checkpoint_input_filename,
+ const std::string& checkpoint_output_filename) {
+ auto inputs = std::make_unique<
+ std::vector<std::pair<std::string, tensorflow::Tensor>>>();
+ if (!io_router.input_filepath_tensor_name().empty()) {
+ tensorflow::Tensor input_filepath(tensorflow::DT_STRING, {});
+ input_filepath.scalar<tensorflow::tstring>()() = checkpoint_input_filename;
+ inputs->push_back({io_router.input_filepath_tensor_name(), input_filepath});
+ }
+ if (!io_router.output_filepath_tensor_name().empty()) {
+ tensorflow::Tensor output_filepath(tensorflow::DT_STRING, {});
+ output_filepath.scalar<tensorflow::tstring>()() =
+ checkpoint_output_filename;
+ inputs->push_back(
+ {io_router.output_filepath_tensor_name(), output_filepath});
+ }
+ return inputs;
+}
+#endif
+std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan(
+ const FederatedComputeIORouter& io_router,
+ const std::string& checkpoint_input_filename,
+ const std::string& checkpoint_output_filename) {
+ auto inputs = std::make_unique<TfLiteInputs>();
+ if (!io_router.input_filepath_tensor_name().empty()) {
+ (*inputs)[io_router.input_filepath_tensor_name()] =
+ checkpoint_input_filename;
+ }
+ if (!io_router.output_filepath_tensor_name().empty()) {
+ (*inputs)[io_router.output_filepath_tensor_name()] =
+ checkpoint_output_filename;
+ }
+ return inputs;
+}
+absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder(
+ const TensorflowSpec& tensorflow_spec,
+ const FederatedComputeIORouter& io_router) {
+ std::vector<std::string> output_names;
+ // The order of output tensor names should match the order in
+ // TensorflowSpec.
+ for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) {
+ std::string tensor_name = output_tensor_spec.name();
+ if (!io_router.aggregations().contains(tensor_name) ||
+ !io_router.aggregations().at(tensor_name).has_secure_aggregation()) {
+ return absl::InvalidArgumentError(
+ "Output tensor is missing in AggregationConfig, or has unsupported "
+ "aggregation type.");
+ }
+ output_names.push_back(tensor_name);
+ }
+ return output_names;
+}
+PlanResultAndCheckpointFile RunPlanWithTensorflowSpec(
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ OpStatsLogger* opstats_logger, const Flags* flags,
+ const ClientOnlyPlan& client_plan,
+ const std::string& checkpoint_input_filename,
+ const std::string& checkpoint_output_filename,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config) {
+ if (!client_plan.phase().has_tensorflow_spec()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Plan must include TensorflowSpec.")));
+ }
+ if (!client_plan.phase().has_federated_compute()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Invalid TensorflowSpec-based plan")));
+ }
+ // Get the output tensor names.
+ absl::StatusOr<std::vector<std::string>> output_names;
+ output_names = ConstructOutputsWithDeterministicOrder(
+ client_plan.phase().tensorflow_spec(),
+ client_plan.phase().federated_compute());
+ if (!output_names.ok()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument, output_names.status()));
+ }
+ // Run plan and get a set of output tensors back.
+ if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
+ std::unique_ptr<TfLiteInputs> tflite_inputs =
+ ConstructTFLiteInputsForTensorflowSpecPlan(
+ client_plan.phase().federated_compute(), checkpoint_input_filename,
+ checkpoint_output_filename);
+ engine::TfLitePlanEngine plan_engine(example_iterator_factories,
+ should_abort, log_manager,
+ opstats_logger, flags, &timing_config);
+ engine::PlanResult plan_result = plan_engine.RunPlan(
+ client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
+ std::move(tflite_inputs), *output_names);
+ PlanResultAndCheckpointFile result(std::move(plan_result));
+ result.checkpoint_file = checkpoint_output_filename;
+ return result;
+ }
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+ // Construct input tensors based on the values in the
+ // FederatedComputeIORouter message and create a temporary file for the
+ // output checkpoint if needed.
+ auto inputs = ConstructInputsForTensorflowSpecPlan(
+ client_plan.phase().federated_compute(), checkpoint_input_filename,
+ checkpoint_output_filename);
+ engine::SimplePlanEngine plan_engine(
+ example_iterator_factories, should_abort, log_manager, opstats_logger,
+ &timing_config, flags->support_constant_tf_inputs());
+ engine::PlanResult plan_result = plan_engine.RunPlan(
+ client_plan.phase().tensorflow_spec(), client_plan.graph(),
+ client_plan.tensorflow_config_proto(), std::move(inputs), *output_names);
+ PlanResultAndCheckpointFile result(std::move(plan_result));
+ result.checkpoint_file = checkpoint_output_filename;
+ return result;
+#else
+ return PlanResultAndCheckpointFile(
+ engine::PlanResult(engine::PlanOutcome::kTensorflowError,
+ absl::InternalError("No plan engine enabled")));
+#endif
+}
+PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec(
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ OpStatsLogger* opstats_logger, const Flags* flags,
+ const ClientOnlyPlan& client_plan,
+ const std::string& checkpoint_output_filename) {
+ if (!client_plan.phase().has_example_query_spec()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Plan must include ExampleQuerySpec")));
+ }
+ if (!flags->enable_example_query_plan_engine()) {
+ // Example query plan received while the flag is off.
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError(
+ "Example query plan received while the flag is off")));
+ }
+ if (!client_plan.phase().has_federated_example_query()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan")));
+ }
+ for (const auto& example_query :
+ client_plan.phase().example_query_spec().example_queries()) {
+ for (auto const& [vector_name, spec] :
+ example_query.output_vector_specs()) {
+ const auto& aggregations =
+ client_plan.phase().federated_example_query().aggregations();
+ if ((aggregations.find(vector_name) == aggregations.end()) ||
+ !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) {
+ return PlanResultAndCheckpointFile(engine::PlanResult(
+ engine::PlanOutcome::kInvalidArgument,
+ absl::InvalidArgumentError("Output vector is missing in "
+ "AggregationConfig, or has unsupported "
+ "aggregation type.")));
+ }
+ }
+ }
+ engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories,
+ opstats_logger);
+ engine::PlanResult plan_result = plan_engine.RunPlan(
+ client_plan.phase().example_query_spec(), checkpoint_output_filename);
+ PlanResultAndCheckpointFile result(std::move(plan_result));
+ result.checkpoint_file = checkpoint_output_filename;
+ return result;
+}
+void LogEligibilityEvalComputationOutcome(
+ PhaseLogger& phase_logger, engine::PlanResult plan_result,
+ const absl::Status& eligibility_info_parsing_status,
+ absl::Time run_plan_start_time, absl::Time reference_time) {
+ switch (plan_result.outcome) {
+ case engine::PlanOutcome::kSuccess: {
+ if (eligibility_info_parsing_status.ok()) {
+ phase_logger.LogEligibilityEvalComputationCompleted(
+ plan_result.example_stats, run_plan_start_time, reference_time);
+ } else {
+ phase_logger.LogEligibilityEvalComputationTensorflowError(
+ eligibility_info_parsing_status, plan_result.example_stats,
+ run_plan_start_time, reference_time);
+ FCP_LOG(ERROR) << eligibility_info_parsing_status.message();
+ }
+ break;
+ }
+ case engine::PlanOutcome::kInterrupted:
+ phase_logger.LogEligibilityEvalComputationInterrupted(
+ plan_result.original_status, plan_result.example_stats,
+ run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kInvalidArgument:
+ phase_logger.LogEligibilityEvalComputationInvalidArgument(
+ plan_result.original_status, plan_result.example_stats,
+ run_plan_start_time);
+ break;
+ case engine::PlanOutcome::kTensorflowError:
+ phase_logger.LogEligibilityEvalComputationTensorflowError(
+ plan_result.original_status, plan_result.example_stats,
+ run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kExampleIteratorError:
+ phase_logger.LogEligibilityEvalComputationExampleIteratorError(
+ plan_result.original_status, plan_result.example_stats,
+ run_plan_start_time);
+ break;
+ }
+}
+void LogComputationOutcome(const engine::PlanResult& plan_result,
+ absl::Status computation_results_parsing_status,
+ PhaseLogger& phase_logger,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ switch (plan_result.outcome) {
+ case engine::PlanOutcome::kSuccess: {
+ if (computation_results_parsing_status.ok()) {
+ phase_logger.LogComputationCompleted(plan_result.example_stats,
+ network_stats, run_plan_start_time,
+ reference_time);
+ } else {
+ phase_logger.LogComputationTensorflowError(
+ computation_results_parsing_status, plan_result.example_stats,
+ network_stats, run_plan_start_time, reference_time);
+ }
+ break;
+ }
+ case engine::PlanOutcome::kInterrupted:
+ phase_logger.LogComputationInterrupted(
+ plan_result.original_status, plan_result.example_stats, network_stats,
+ run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kInvalidArgument:
+ phase_logger.LogComputationInvalidArgument(
+ plan_result.original_status, plan_result.example_stats, network_stats,
+ run_plan_start_time);
+ break;
+ case engine::PlanOutcome::kTensorflowError:
+ phase_logger.LogComputationTensorflowError(
+ plan_result.original_status, plan_result.example_stats, network_stats,
+ run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kExampleIteratorError:
+ phase_logger.LogComputationExampleIteratorError(
+ plan_result.original_status, plan_result.example_stats, network_stats,
+ run_plan_start_time);
+ break;
+ }
+}
+void LogResultUploadStatus(PhaseLogger& phase_logger, absl::Status result,
+ const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) {
+ if (result.ok()) {
+ phase_logger.LogResultUploadCompleted(
+ network_stats, time_before_result_upload, reference_time);
+ } else {
+ auto message =
+ absl::StrCat("Error reporting results: code: ", result.code(),
+ ", message: ", result.message());
+ FCP_LOG(INFO) << message;
+ if (result.code() == absl::StatusCode::kAborted) {
+ phase_logger.LogResultUploadServerAborted(
+ result, network_stats, time_before_result_upload, reference_time);
+ } else if (result.code() == absl::StatusCode::kCancelled) {
+ phase_logger.LogResultUploadClientInterrupted(
+ result, network_stats, time_before_result_upload, reference_time);
+ } else {
+ phase_logger.LogResultUploadIOError(
+ result, network_stats, time_before_result_upload, reference_time);
+ }
+ }
+}
+void LogFailureUploadStatus(PhaseLogger& phase_logger, absl::Status result,
+ const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) {
+ if (result.ok()) {
+ phase_logger.LogFailureUploadCompleted(
+ network_stats, time_before_failure_upload, reference_time);
+ } else {
+ auto message = absl::StrCat("Error reporting computation failure: code: ",
+ result.code(), ", message: ", result.message());
+ FCP_LOG(INFO) << message;
+ if (result.code() == absl::StatusCode::kAborted) {
+ phase_logger.LogFailureUploadServerAborted(
+ result, network_stats, time_before_failure_upload, reference_time);
+ } else if (result.code() == absl::StatusCode::kCancelled) {
+ phase_logger.LogFailureUploadClientInterrupted(
+ result, network_stats, time_before_failure_upload, reference_time);
+ } else {
+ phase_logger.LogFailureUploadIOError(
+ result, network_stats, time_before_failure_upload, reference_time);
+ }
+ }
+}
+absl::Status ReportPlanResult(
+ FederatedProtocol* federated_protocol, PhaseLogger& phase_logger,
+ absl::StatusOr<ComputationResults> computation_results,
+ absl::Time run_plan_start_time, absl::Time reference_time) {
+ const absl::Time before_report_time = absl::Now();
+ // Note that the FederatedSelectManager shouldn't be active anymore during
+ // the reporting of results, so we don't bother passing it to
+ // GetNetworkStatsSince.
+ //
+ // We must return only stats that cover the report phase for the log events
+ // below.
+ const NetworkStats before_report_stats =
+ GetCumulativeNetworkStats(federated_protocol,
+ /*fedselect_manager=*/nullptr);
+ absl::Status result = absl::InternalError("");
+ if (computation_results.ok()) {
+ FCP_RETURN_IF_ERROR(phase_logger.LogResultUploadStarted());
+ result = federated_protocol->ReportCompleted(
+ std::move(*computation_results),
+ /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
+ LogResultUploadStatus(phase_logger, result,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ before_report_stats),
+ before_report_time, reference_time);
+ } else {
+ FCP_RETURN_IF_ERROR(phase_logger.LogFailureUploadStarted());
+ result = federated_protocol->ReportNotCompleted(
+ engine::PhaseOutcome::ERROR,
+ /*plan_duration=*/absl::Now() - run_plan_start_time, std::nullopt);
+ LogFailureUploadStatus(phase_logger, result,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ before_report_stats),
+ before_report_time, reference_time);
+ }
+ return result;
+}
+// Writes the given data to the stream, and returns true if successful and
+// false if not.
+bool WriteStringOrCordToFstream(
+ std::fstream& stream, const std::variant<std::string, absl::Cord>& data) {
+ if (stream.fail()) {
+ return false;
+ }
+ if (std::holds_alternative<std::string>(data)) {
+ return (stream << std::get<std::string>(data)).good();
+ }
+ for (absl::string_view chunk : std::get<absl::Cord>(data).Chunks()) {
+ if (!(stream << chunk).good()) {
+ return false;
+ }
+ }
+ return true;
+}
+// Writes the given checkpoint data to a newly created temporary file.
+// Returns the filename if successful, or an error if the file could not be
+// created, or if writing to the file failed.
+absl::StatusOr<std::string> CreateInputCheckpointFile(
+ Files* files, const std::variant<std::string, absl::Cord>& checkpoint) {
+ // Create the temporary checkpoint file.
+ // Deletion of the file is left to the caller / the Files implementation.
+ FCP_ASSIGN_OR_RETURN(absl::StatusOr<std::string> filename,
+ files->CreateTempFile("init", ".ckp"));
+ // Write the checkpoint data to the file.
+ std::fstream checkpoint_stream(*filename, std::ios_base::out);
+ if (!WriteStringOrCordToFstream(checkpoint_stream, checkpoint)) {
+ return absl::InvalidArgumentError("Failed to write to file");
+ }
+ checkpoint_stream.close();
+ return filename;
+}
+absl::StatusOr<std::optional<TaskEligibilityInfo>> RunEligibilityEvalPlan(
+ const FederatedProtocol::EligibilityEvalTask& eligibility_eval_task,
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
+ LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
+ FederatedProtocol* federated_protocol,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time reference_time, const absl::Time time_before_checkin,
+ const absl::Time time_before_plan_download,
+ const NetworkStats& network_stats) {
+ ClientOnlyPlan plan;
+ if (!ParseFromStringOrCord(plan, eligibility_eval_task.payloads.plan)) {
+ auto message = "Failed to parse received eligibility eval plan";
+ phase_logger.LogEligibilityEvalCheckinInvalidPayloadError(
+ message, network_stats, time_before_plan_download);
+ FCP_LOG(ERROR) << message;
+ return absl::InternalError(message);
+ }
+ absl::StatusOr<std::string> checkpoint_input_filename =
+ CreateInputCheckpointFile(files,
+ eligibility_eval_task.payloads.checkpoint);
+ if (!checkpoint_input_filename.ok()) {
+ auto status = checkpoint_input_filename.status();
+ auto message = absl::StrCat(
+ "Failed to create eligibility eval checkpoint input file: code: ",
+ status.code(), ", message: ", status.message());
+ phase_logger.LogEligibilityEvalCheckinIOError(status, network_stats,
+ time_before_plan_download);
+ FCP_LOG(ERROR) << message;
+ return absl::InternalError("");
+ }
+ phase_logger.LogEligibilityEvalCheckinCompleted(network_stats,
+ /*time_before_checkin=*/
+ time_before_checkin,
+ /*time_before_plan_download=*/
+ time_before_plan_download);
+ absl::Time run_plan_start_time = absl::Now();
+ phase_logger.LogEligibilityEvalComputationStarted();
+ engine::PlanResult plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
+ example_iterator_factories, should_abort, log_manager, opstats_logger,
+ flags, plan, *checkpoint_input_filename, timing_config,
+ run_plan_start_time, reference_time);
+ absl::StatusOr<TaskEligibilityInfo> task_eligibility_info;
+ if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
+ task_eligibility_info =
+ ParseEligibilityEvalPlanOutput(plan_result.output_tensors);
+ }
+ LogEligibilityEvalComputationOutcome(phase_logger, std::move(plan_result),
+ task_eligibility_info.status(),
+ run_plan_start_time, reference_time);
+ return task_eligibility_info;
+}
+struct EligibilityEvalResult {
+ std::optional<TaskEligibilityInfo> task_eligibility_info;
+ std::vector<std::string> task_names_for_multiple_task_assignments;
+};
+// Create an EligibilityEvalResult from a TaskEligibilityInfo and
+// PopulationEligibilitySpec. If both population_spec and
+// task_eligibility_info are present, the returned EligibilityEvalResult will
+// contain a TaskEligibilityInfo which only contains the tasks for single task
+// assignment, and a vector of task names for multiple task assignment.
+EligibilityEvalResult CreateEligibilityEvalResult(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info,
+ const std::optional<PopulationEligibilitySpec>& population_spec) {
+ EligibilityEvalResult result;
+ if (population_spec.has_value() && task_eligibility_info.has_value()) {
+ absl::flat_hash_set<std::string> task_names_for_multiple_task_assignments;
+ for (const auto& task_info : population_spec.value().task_info()) {
+ if (task_info.task_assignment_mode() ==
+ PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE) {
+ task_names_for_multiple_task_assignments.insert(task_info.task_name());
+ }
+ }
+ TaskEligibilityInfo single_task_assignment_eligibility_info;
+ single_task_assignment_eligibility_info.set_version(
+ task_eligibility_info.value().version());
+ for (const auto& task_weight :
+ task_eligibility_info.value().task_weights()) {
+ if (task_names_for_multiple_task_assignments.contains(
+ task_weight.task_name())) {
+ result.task_names_for_multiple_task_assignments.push_back(
+ task_weight.task_name());
+ } else {
+ *single_task_assignment_eligibility_info.mutable_task_weights()->Add() =
+ task_weight;
+ }
+ }
+ result.task_eligibility_info = single_task_assignment_eligibility_info;
+ } else {
+ result.task_eligibility_info = task_eligibility_info;
+ }
+ return result;
+}
+// Issues an eligibility eval checkin request and executes the eligibility
+// eval task if the server returns one.
+//
+// This function modifies the FLRunnerResult with values received over the
+// course of the eligibility eval protocol interaction.
+//
+// Returns:
+// - the TaskEligibilityInfo produced by the eligibility eval task, if the
+// server provided an eligibility eval task to run.
+// - an std::nullopt if the server indicated that there is no eligibility eval
+// task configured for the population.
+// - an INTERNAL error if the server rejects the client or another error
+// occurs
+// that should abort the training run. The error will already have been
+// logged appropriately.
+absl::StatusOr<EligibilityEvalResult> IssueEligibilityEvalCheckinAndRunPlan(
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, PhaseLogger& phase_logger, Files* files,
+ LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
+ FederatedProtocol* federated_protocol,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time reference_time, FLRunnerResult& fl_runner_result) {
+ const absl::Time time_before_checkin = absl::Now();
+ const NetworkStats network_stats_before_checkin =
+ GetCumulativeNetworkStats(federated_protocol,
+ /*fedselect_manager=*/nullptr);
+ // These fields will, after a successful checkin that resulted in an EET
+ // being received, contain the time at which the EET plan/checkpoint URIs
+ // were received (but not yet downloaded), as well as the cumulative network
+ // stats at that point, allowing us to separately calculate how long it took
+ // to then download the actual payloads.
+ absl::Time time_before_plan_download = time_before_checkin;
+ NetworkStats network_stats_before_plan_download =
+ network_stats_before_checkin;
+ // Log that we are about to check in with the server.
+ phase_logger.LogEligibilityEvalCheckinStarted();
+ // Issue the eligibility eval checkin request (providing a callback that
+ // will be called when an EET is assigned to the task but before its
+ // plan/checkpoint URIs have actually been downloaded).
+ bool plan_uris_received_callback_called = false;
+ std::function<void(const FederatedProtocol::EligibilityEvalTask&)>
+ plan_uris_received_callback =
+ [&time_before_plan_download, &network_stats_before_plan_download,
+ &time_before_checkin, &network_stats_before_checkin,
+ &federated_protocol, &phase_logger,
+ &plan_uris_received_callback_called](
+ const FederatedProtocol::EligibilityEvalTask& task) {
+ // When the plan URIs have been received, we already know the name
+ // of the task we have been assigned, so let's tell the
+ // PhaseLogger.
+ phase_logger.SetModelIdentifier(task.execution_id);
+ // We also should log a corresponding log event.
+ phase_logger.LogEligibilityEvalCheckinPlanUriReceived(
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_checkin),
+ time_before_checkin);
+ // And we must take a snapshot of the current time & network
+ // stats, so we can distinguish between the duration/network stats
+ // incurred for the checkin request vs. the actual downloading of
+ // the plan/checkpoint resources.
+ time_before_plan_download = absl::Now();
+ network_stats_before_plan_download =
+ GetCumulativeNetworkStats(federated_protocol,
+ /*fedselect_manager=*/nullptr);
+ plan_uris_received_callback_called = true;
+ };
+ absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
+ eligibility_checkin_result = federated_protocol->EligibilityEvalCheckin(
+ plan_uris_received_callback);
+ UpdateRetryWindowAndNetworkStats(*federated_protocol,
+ /*fedselect_manager=*/nullptr, phase_logger,
+ fl_runner_result);
+ // It's a bit unfortunate that we have to inspect the checkin_result and
+ // extract the model identifier here rather than further down the function,
+ // but this ensures that the histograms below will have the right model
+ // identifier attached (and we want to also emit the histograms even if we
+ // have failed/rejected checkin outcomes).
+ if (eligibility_checkin_result.ok() &&
+ std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
+ *eligibility_checkin_result)) {
+ // Make sure that if we received an EligibilityEvalTask, then the callback
+ // should have already been called by this point by the protocol (ensuring
+ // that SetModelIdentifier has been called etc.).
+ FCP_CHECK(plan_uris_received_callback_called);
+ }
+ if (!eligibility_checkin_result.ok()) {
+ auto status = eligibility_checkin_result.status();
+ auto message = absl::StrCat("Error during eligibility eval checkin: code: ",
+ status.code(), ", message: ", status.message());
+ if (status.code() == absl::StatusCode::kAborted) {
+ phase_logger.LogEligibilityEvalCheckinServerAborted(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download);
+ } else if (status.code() == absl::StatusCode::kCancelled) {
+ phase_logger.LogEligibilityEvalCheckinClientInterrupted(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download);
+ } else if (!status.ok()) {
+ phase_logger.LogEligibilityEvalCheckinIOError(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download);
+ }
+ FCP_LOG(INFO) << message;
+ return absl::InternalError("");
+ }
+ EligibilityEvalResult result;
+ if (std::holds_alternative<FederatedProtocol::Rejection>(
+ *eligibility_checkin_result)) {
+ phase_logger.LogEligibilityEvalCheckinTurnedAway(
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_checkin),
+ time_before_checkin);
+ // If the server explicitly rejected our request, then we must abort and
+ // we must not proceed to the "checkin" phase below.
+ FCP_LOG(INFO) << "Device rejected by server during eligibility eval "
+ "checkin; aborting";
+ return absl::InternalError("");
+ } else if (std::holds_alternative<FederatedProtocol::EligibilityEvalDisabled>(
+ *eligibility_checkin_result)) {
+ phase_logger.LogEligibilityEvalNotConfigured(
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_checkin),
+ time_before_checkin);
+ // If the server indicates that no eligibility eval task is configured for
+ // the population then there is nothing more to do. We simply proceed to
+ // the "checkin" phase below without providing it a TaskEligibilityInfo
+ // proto.
+ result.task_eligibility_info = std::nullopt;
+ return result;
+ }
+ auto eligibility_eval_task =
+ absl::get<FederatedProtocol::EligibilityEvalTask>(
+ *eligibility_checkin_result);
+ // Parse and run the eligibility eval task if the server returned one.
+ // Now we have a EligibilityEvalTask, if an error happens, we will report to
+ // the server via the ReportEligibilityEvalError.
+ absl::StatusOr<std::optional<TaskEligibilityInfo>> task_eligibility_info =
+ RunEligibilityEvalPlan(
+ eligibility_eval_task, example_iterator_factories, should_abort,
+ phase_logger, files, log_manager, opstats_logger, flags,
+ federated_protocol, timing_config, reference_time,
+ /*time_before_checkin=*/time_before_checkin,
+ /*time_before_plan_download=*/time_before_plan_download,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download));
+ if (!task_eligibility_info.ok()) {
+ // Note that none of the PhaseLogger methods will reflect the very little
+ // amount of network usage the will be incurred by this protocol request.
+ // We consider this to be OK to keep things simple, and because this
+ // should use such a limited amount of network bandwidth. Do note that the
+ // network usage *will* be correctly reported in the OpStats database.
+ federated_protocol->ReportEligibilityEvalError(
+ absl::Status(task_eligibility_info.status().code(),
+ "Failed to compute eligibility info"));
+ UpdateRetryWindowAndNetworkStats(*federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ phase_logger, fl_runner_result);
+ return task_eligibility_info.status();
+ }
+ return CreateEligibilityEvalResult(
+ *task_eligibility_info,
+ eligibility_eval_task.population_eligibility_spec);
+}
+struct CheckinResult {
+ std::string task_name;
+ ClientOnlyPlan plan;
+ int32_t minimum_clients_in_server_visible_aggregate;
+ std::string checkpoint_input_filename;
+ std::string computation_id;
+ std::string federated_select_uri_template;
+};
+absl::StatusOr<CheckinResult> IssueCheckin(
+ PhaseLogger& phase_logger, LogManager* log_manager, Files* files,
+ FederatedProtocol* federated_protocol,
+ std::optional<TaskEligibilityInfo> task_eligibility_info,
+ absl::Time reference_time, const std::string& population_name,
+ FLRunnerResult& fl_runner_result, const Flags* flags) {
+ absl::Time time_before_checkin = absl::Now();
+ // We must return only stats that cover the check in phase for the log
+ // events below.
+ const NetworkStats network_stats_before_checkin =
+ GetCumulativeNetworkStats(federated_protocol,
+ /*fedselect_manager=*/nullptr);
+ // These fields will, after a successful checkin that resulted in a task
+ // being assigned, contain the time at which the task plan/checkpoint URIs
+ // were received (but not yet downloaded), as well as the cumulative network
+ // stats at that point, allowing us to separately calculate how long it took
+ // to then download the actual payloads.
+ absl::Time time_before_plan_download = time_before_checkin;
+ NetworkStats network_stats_before_plan_download =
+ network_stats_before_checkin;
+ // Clear the model identifier before check-in, to ensure that the any prior
+ // eligibility eval task name isn't used any longer.
+ phase_logger.SetModelIdentifier("");
+ phase_logger.LogCheckinStarted();
+ std::string task_name;
+ // Issue the checkin request (providing a callback that will be called when
+ // an EET is assigned to the task but before its plan/checkpoint URIs have
+ // actually been downloaded).
+ bool plan_uris_received_callback_called = false;
+ std::function<void(const FederatedProtocol::TaskAssignment&)>
+ plan_uris_received_callback =
+ [&time_before_plan_download, &network_stats_before_plan_download,
+ &time_before_checkin, &network_stats_before_checkin, &task_name,
+ &federated_protocol, &population_name, &log_manager, &phase_logger,
+ &plan_uris_received_callback_called](
+ const FederatedProtocol::TaskAssignment& task_assignment) {
+ // When the plan URIs have been received, we already know the name
+ // of the task we have been assigned, so let's tell the
+ // PhaseLogger.
+ auto model_identifier = task_assignment.aggregation_session_id;
+ phase_logger.SetModelIdentifier(model_identifier);
+ // We also should log a corresponding log event.
+ task_name = ExtractTaskNameFromAggregationSessionId(
+ model_identifier, population_name, *log_manager);
+ phase_logger.LogCheckinPlanUriReceived(
+ task_name,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_checkin),
+ time_before_checkin);
+ // And we must take a snapshot of the current time & network
+ // stats, so we can distinguish between the duration/network stats
+ // incurred for the checkin request vs. the actual downloading of
+ // the plan/checkpoint resources.
+ time_before_plan_download = absl::Now();
+ network_stats_before_plan_download = GetCumulativeNetworkStats(
+ federated_protocol, /*fedselect_manager=*/nullptr);
+ plan_uris_received_callback_called = true;
+ };
+ absl::StatusOr<FederatedProtocol::CheckinResult> checkin_result =
+ federated_protocol->Checkin(task_eligibility_info,
+ plan_uris_received_callback);
+ UpdateRetryWindowAndNetworkStats(*federated_protocol,
+ /*fedselect_manager=*/nullptr, phase_logger,
+ fl_runner_result);
+ // It's a bit unfortunate that we have to inspect the checkin_result and
+ // extract the model identifier here rather than further down the function,
+ // but this ensures that the histograms below will have the right model
+ // identifier attached (and we want to also emit the histograms even if we
+ // have failed/rejected checkin outcomes).
+ if (checkin_result.ok() &&
+ std::holds_alternative<FederatedProtocol::TaskAssignment>(
+ *checkin_result)) {
+ // Make sure that if we received a TaskAssignment, then the callback
+ // should have already been called by this point by the protocol (ensuring
+ // that SetModelIdentifier has been called etc.).
+ FCP_CHECK(plan_uris_received_callback_called);
+ }
+ if (!checkin_result.ok()) {
+ auto status = checkin_result.status();
+ auto message = absl::StrCat("Error during checkin: code: ", status.code(),
+ ", message: ", status.message());
+ if (status.code() == absl::StatusCode::kAborted) {
+ phase_logger.LogCheckinServerAborted(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download, reference_time);
+ } else if (status.code() == absl::StatusCode::kCancelled) {
+ phase_logger.LogCheckinClientInterrupted(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download, reference_time);
+ } else if (!status.ok()) {
+ phase_logger.LogCheckinIOError(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download, reference_time);
+ }
+ FCP_LOG(INFO) << message;
+ return status;
+ }
+ // Server rejected us? Return the fl_runner_results as-is.
+ if (std::holds_alternative<FederatedProtocol::Rejection>(*checkin_result)) {
+ phase_logger.LogCheckinTurnedAway(
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_checkin),
+ time_before_checkin, reference_time);
+ FCP_LOG(INFO) << "Device rejected by server during checkin; aborting";
+ return absl::InternalError("Device rejected by server.");
+ }
+ auto task_assignment =
+ absl::get<FederatedProtocol::TaskAssignment>(*checkin_result);
+ ClientOnlyPlan plan;
+ auto plan_bytes = task_assignment.payloads.plan;
+ if (!ParseFromStringOrCord(plan, plan_bytes)) {
+ auto message = "Failed to parse received plan";
+ phase_logger.LogCheckinInvalidPayload(
+ message,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download, reference_time);
+ FCP_LOG(ERROR) << message;
+ return absl::InternalError("");
+ }
+ std::string computation_id;
+ if (flags->enable_computation_id()) {
+ computation_id = ComputeSHA256FromStringOrCord(plan_bytes);
+ }
+ int32_t minimum_clients_in_server_visible_aggregate = 0;
+ if (task_assignment.sec_agg_info.has_value()) {
+ auto minimum_number_of_participants =
+ plan.phase().minimum_number_of_participants();
+ if (task_assignment.sec_agg_info->expected_number_of_clients <
+ minimum_number_of_participants) {
+ return absl::InternalError(
+ "expectedNumberOfClients was less than Plan's "
+ "minimumNumberOfParticipants.");
+ }
+ minimum_clients_in_server_visible_aggregate =
+ task_assignment.sec_agg_info
+ ->minimum_clients_in_server_visible_aggregate;
+ }
+ absl::StatusOr<std::string> checkpoint_input_filename = "";
+ // Example query plan does not have an input checkpoint.
+ if (!plan.phase().has_example_query_spec()) {
+ checkpoint_input_filename =
+ CreateInputCheckpointFile(files, task_assignment.payloads.checkpoint);
+ if (!checkpoint_input_filename.ok()) {
+ auto status = checkpoint_input_filename.status();
+ auto message = absl::StrCat(
+ "Failed to create checkpoint input file: code: ", status.code(),
+ ", message: ", status.message());
+ phase_logger.LogCheckinIOError(
+ status,
+ GetNetworkStatsSince(federated_protocol,
+ /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ time_before_plan_download, reference_time);
+ FCP_LOG(ERROR) << message;
+ return status;
+ }
+ }
+ phase_logger.LogCheckinCompleted(
+ task_name,
+ GetNetworkStatsSince(federated_protocol, /*fedselect_manager=*/nullptr,
+ network_stats_before_plan_download),
+ /*time_before_checkin=*/time_before_checkin,
+ /*time_before_plan_download=*/time_before_plan_download, reference_time);
+ return CheckinResult{
+ .task_name = std::move(task_name),
+ .plan = std::move(plan),
+ .minimum_clients_in_server_visible_aggregate =
+ minimum_clients_in_server_visible_aggregate,
+ .checkpoint_input_filename = std::move(*checkpoint_input_filename),
+ .computation_id = std::move(computation_id),
+ .federated_select_uri_template =
+ task_assignment.federated_select_uri_template};
+}
+} // namespace
+absl::StatusOr<FLRunnerResult> RunFederatedComputation(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ Files* files, LogManager* log_manager, const Flags* flags,
+ const std::string& federated_service_uri, const std::string& api_key,
+ const std::string& test_cert_path, const std::string& session_name,
+ const std::string& population_name, const std::string& retry_token,
+ const std::string& client_version,
+ const std::string& attestation_measurement) {
+ auto opstats_logger =
+ engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
+ session_name, population_name);
+ absl::Time reference_time = absl::Now();
+ FLRunnerResult fl_runner_result;
+ fcp::client::InterruptibleRunner::TimingConfig timing_config = {
+ .polling_period =
+ absl::Milliseconds(flags->condition_polling_period_millis()),
+ .graceful_shutdown_period = absl::Milliseconds(
+ flags->tf_execution_teardown_grace_period_millis()),
+ .extended_shutdown_period = absl::Milliseconds(
+ flags->tf_execution_teardown_extended_period_millis()),
+ };
+ auto should_abort_protocol_callback = [&env_deps, &timing_config]() -> bool {
+ // Return the Status if failed, or the negated value if successful.
+ return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
+ };
+ PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
+ log_manager, flags);
+ // If there was an error initializing OpStats, opstats_logger will be a
+ // no-op implementation and execution will be allowed to continue.
+ if (!opstats_logger->GetInitStatus().ok()) {
+ // This will only happen if OpStats is enabled and there was an error in
+ // initialization.
+ phase_logger.LogNonfatalInitializationError(
+ opstats_logger->GetInitStatus());
+ }
+ Clock* clock = Clock::RealClock();
+ std::unique_ptr<cache::ResourceCache> resource_cache;
+ // if (flags->max_resource_cache_size_bytes() > 0) {
+ // // Anything that goes wrong in FileBackedResourceCache::Create is a
+ // // programmer error.
+ // absl::StatusOr<std::unique_ptr<cache::ResourceCache>>
+ // resource_cache_internal = cache::FileBackedResourceCache::Create(
+ // env_deps->GetBaseDir(), env_deps->GetCacheDir(), log_manager,
+ // clock, flags->max_resource_cache_size_bytes());
+ // if (!resource_cache_internal.ok()) {
+ // auto resource_init_failed_status = absl::Status(
+ // resource_cache_internal.status().code(),
+ // absl::StrCat("Failed to initialize FileBackedResourceCache: ",
+ // resource_cache_internal.status().ToString()));
+ // if (flags->resource_cache_initialization_error_is_fatal()) {
+ // phase_logger.LogFatalInitializationError(resource_init_failed_status);
+ // return resource_init_failed_status;
+ // }
+ // // We log an error but otherwise proceed as if the cache was disabled.
+ // phase_logger.LogNonfatalInitializationError(resource_init_failed_status);
+ // } else {
+ // resource_cache = std::move(*resource_cache_internal);
+ // }
+ // }
+ std::unique_ptr<::fcp::client::http::HttpClient> http_client =
+ flags->enable_grpc_with_http_resource_support() ||
+ flags->use_http_federated_compute_protocol()
+ ? env_deps->CreateHttpClient()
+ : nullptr;
+ std::unique_ptr<FederatedProtocol> federated_protocol;
+ if (flags->use_http_federated_compute_protocol()) {
+ log_manager->LogDiag(ProdDiagCode::HTTP_FEDERATED_PROTOCOL_USED);
+ // Verify the entry point uri starts with "https://" or
+ // "http://localhost". Note "http://localhost" is allowed for testing
+ // purpose.
+ if (!(absl::StartsWith(federated_service_uri, "https://") ||
+ absl::StartsWith(federated_service_uri, "http://localhost"))) {
+ return absl::InvalidArgumentError("The entry point uri is invalid.");
+ }
+ federated_protocol = std::make_unique<http::HttpFederatedProtocol>(
+ clock, log_manager, flags, http_client.get(),
+ std::make_unique<SecAggRunnerFactoryImpl>(),
+ event_publisher->secagg_event_publisher(), federated_service_uri,
+ api_key, population_name, retry_token, client_version,
+ attestation_measurement, should_abort_protocol_callback, absl::BitGen(),
+ timing_config, resource_cache.get());
+ } else {
+#ifdef FCP_CLIENT_SUPPORT_GRPC
+ // Check in with the server to either retrieve a plan + initial
+ // checkpoint, or get rejected with a RetryWindow.
+ auto grpc_channel_deadline = flags->grpc_channel_deadline_seconds();
+ if (grpc_channel_deadline <= 0) {
+ grpc_channel_deadline = 600;
+ FCP_LOG(INFO) << "Using default channel deadline of "
+ << grpc_channel_deadline << " seconds.";
+ }
+ federated_protocol = std::make_unique<GrpcFederatedProtocol>(
+ event_publisher, log_manager,
+ std::make_unique<SecAggRunnerFactoryImpl>(), flags, http_client.get(),
+ federated_service_uri, api_key, test_cert_path, population_name,
+ retry_token, client_version, attestation_measurement,
+ should_abort_protocol_callback, timing_config, grpc_channel_deadline,
+ resource_cache.get());
+#else
+ return absl::InternalError("No FederatedProtocol enabled.");
+#endif
+ }
+ std::unique_ptr<FederatedSelectManager> federated_select_manager;
+ if (http_client != nullptr && flags->enable_federated_select()) {
+ federated_select_manager = std::make_unique<HttpFederatedSelectManager>(
+ log_manager, files, http_client.get(), should_abort_protocol_callback,
+ timing_config);
+ } else {
+ federated_select_manager =
+ std::make_unique<DisabledFederatedSelectManager>(log_manager);
+ }
+ return RunFederatedComputation(env_deps, phase_logger, event_publisher, files,
+ log_manager, opstats_logger.get(), flags,
+ federated_protocol.get(),
+ federated_select_manager.get(), timing_config,
+ reference_time, session_name, population_name);
+}
+absl::StatusOr<FLRunnerResult> RunFederatedComputation(
+ SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger,
+ EventPublisher* event_publisher, Files* files, LogManager* log_manager,
+ OpStatsLogger* opstats_logger, const Flags* flags,
+ FederatedProtocol* federated_protocol,
+ FederatedSelectManager* fedselect_manager,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time reference_time, const std::string& session_name,
+ const std::string& population_name) {
+ SelectorContext federated_selector_context;
+ federated_selector_context.mutable_computation_properties()->set_session_name(
+ session_name);
+ FederatedComputation federated_computation;
+ federated_computation.set_population_name(population_name);
+ *federated_selector_context.mutable_computation_properties()
+ ->mutable_federated() = federated_computation;
+ SelectorContext eligibility_selector_context;
+ eligibility_selector_context.mutable_computation_properties()
+ ->set_session_name(session_name);
+ EligibilityEvalComputation eligibility_eval_computation;
+ eligibility_eval_computation.set_population_name(population_name);
+ *eligibility_selector_context.mutable_computation_properties()
+ ->mutable_eligibility_eval() = eligibility_eval_computation;
+ // Construct a default FLRunnerResult that reflects an unsuccessful training
+ // attempt and which uses RetryWindow corresponding to transient errors (if
+ // the flag is on).
+ // This is what will be returned if we have to bail early, before we've
+ // received a RetryWindow from the server.
+ FLRunnerResult fl_runner_result;
+ fl_runner_result.set_contribution_result(FLRunnerResult::FAIL);
+ // Before we even check whether we should abort right away, update the retry
+ // window. That way we will use the most appropriate retry window we have
+ // available (an implementation detail of FederatedProtocol, but generally a
+ // 'transient error' retry window based on the provided flag values) in case
+ // we do need to abort.
+ UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
+ phase_logger, fl_runner_result);
+ // Check if the device conditions allow for checking in with the server
+ // and running a federated computation. If not, bail early with the
+ // transient error retry window.
+ std::function<bool()> should_abort = [env_deps, &timing_config]() {
+ return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
+ };
+ if (should_abort()) {
+ std::string message =
+ "Device conditions not satisfied, aborting federated computation";
+ FCP_LOG(INFO) << message;
+ phase_logger.LogTaskNotStarted(message);
+ return fl_runner_result;
+ }
+ // Eligibility eval plans can use example iterators from the
+ // SimpleTaskEnvironment and those reading the OpStats DB.
+ opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
+ opstats_logger, log_manager,
+ flags->opstats_last_successful_contribution_criteria());
+ std::unique_ptr<engine::ExampleIteratorFactory>
+ env_eligibility_example_iterator_factory =
+ CreateSimpleTaskEnvironmentIteratorFactory(
+ env_deps, eligibility_selector_context);
+ std::vector<engine::ExampleIteratorFactory*>
+ eligibility_example_iterator_factories{
+ &opstats_example_iterator_factory,
+ env_eligibility_example_iterator_factory.get()};
+ // Note that this method will update fl_runner_result's fields with values
+ // received over the course of the eligibility eval protocol interaction.
+ absl::StatusOr<EligibilityEvalResult> eligibility_eval_result =
+ IssueEligibilityEvalCheckinAndRunPlan(
+ eligibility_example_iterator_factories, should_abort, phase_logger,
+ files, log_manager, opstats_logger, flags, federated_protocol,
+ timing_config, reference_time, fl_runner_result);
+ if (!eligibility_eval_result.ok()) {
+ return fl_runner_result;
+ }
+ auto checkin_result =
+ IssueCheckin(phase_logger, log_manager, files, federated_protocol,
+ std::move(eligibility_eval_result->task_eligibility_info),
+ reference_time, population_name, fl_runner_result, flags);
+ if (!checkin_result.ok()) {
+ return fl_runner_result;
+ }
+ SelectorContext federated_selector_context_with_task_name =
+ federated_selector_context;
+ federated_selector_context_with_task_name.mutable_computation_properties()
+ ->mutable_federated()
+ ->set_task_name(checkin_result->task_name);
+ if (flags->enable_computation_id()) {
+ federated_selector_context_with_task_name.mutable_computation_properties()
+ ->mutable_federated()
+ ->set_computation_id(checkin_result->computation_id);
+ }
+ if (checkin_result->plan.phase().has_example_query_spec()) {
+ federated_selector_context_with_task_name.mutable_computation_properties()
+ ->set_example_iterator_output_format(
+ ::fcp::client::QueryTimeComputationProperties::
+ EXAMPLE_QUERY_RESULT);
+ }
+ // Include the last successful contribution timestamp in the
+ // SelectorContext.
+ const auto& opstats_db = opstats_logger->GetOpStatsDb();
+ if (opstats_db != nullptr) {
+ absl::StatusOr<opstats::OpStatsSequence> data = opstats_db->Read();
+ if (data.ok()) {
+ std::optional<google::protobuf::Timestamp>
+ last_successful_contribution_time =
+ opstats::GetLastSuccessfulContributionTime(
+ *data, checkin_result->task_name);
+ if (last_successful_contribution_time.has_value()) {
+ *(federated_selector_context_with_task_name
+ .mutable_computation_properties()
+ ->mutable_federated()
+ ->mutable_historical_context()
+ ->mutable_last_successful_contribution_time()) =
+ *last_successful_contribution_time;
+ }
+ }
+ }
+ if (checkin_result->plan.phase().has_example_query_spec()) {
+ // Example query plan only supports simple agg for now.
+ *(federated_selector_context_with_task_name
+ .mutable_computation_properties()
+ ->mutable_federated()
+ ->mutable_simple_aggregation()) = SimpleAggregation();
+ } else {
+ const auto& federated_compute_io_router =
+ checkin_result->plan.phase().federated_compute();
+ const bool has_simpleagg_tensors =
+ !federated_compute_io_router.output_filepath_tensor_name().empty();
+ bool all_aggregations_are_secagg = true;
+ for (const auto& aggregation : federated_compute_io_router.aggregations()) {
+ all_aggregations_are_secagg &=
+ aggregation.second.protocol_config_case() ==
+ AggregationConfig::kSecureAggregation;
+ }
+ if (!has_simpleagg_tensors && all_aggregations_are_secagg) {
+ federated_selector_context_with_task_name
+ .mutable_computation_properties()
+ ->mutable_federated()
+ ->mutable_secure_aggregation()
+ ->set_minimum_clients_in_server_visible_aggregate(
+ checkin_result->minimum_clients_in_server_visible_aggregate);
+ } else {
+ // Has an output checkpoint, so some tensors must be simply aggregated.
+ *(federated_selector_context_with_task_name
+ .mutable_computation_properties()
+ ->mutable_federated()
+ ->mutable_simple_aggregation()) = SimpleAggregation();
+ }
+ }
+ RetryWindow report_retry_window;
+ phase_logger.LogComputationStarted();
+ absl::Time run_plan_start_time = absl::Now();
+ NetworkStats run_plan_start_network_stats =
+ GetCumulativeNetworkStats(federated_protocol, fedselect_manager);
+ absl::StatusOr<std::string> checkpoint_output_filename =
+ files->CreateTempFile("output", ".ckp");
+ if (!checkpoint_output_filename.ok()) {
+ auto status = checkpoint_output_filename.status();
+ auto message = absl::StrCat(
+ "Could not create temporary output checkpoint file: code: ",
+ status.code(), ", message: ", status.message());
+ phase_logger.LogComputationIOError(
+ status, ExampleStats(),
+ GetNetworkStatsSince(federated_protocol, fedselect_manager,
+ run_plan_start_network_stats),
+ run_plan_start_time);
+ return fl_runner_result;
+ }
+ // Regular plans can use example iterators from the SimpleTaskEnvironment,
+ // those reading the OpStats DB, or those serving Federated Select slices.
+ std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
+ CreateSimpleTaskEnvironmentIteratorFactory(
+ env_deps, federated_selector_context_with_task_name);
+ std::unique_ptr<::fcp::client::engine::ExampleIteratorFactory>
+ fedselect_example_iterator_factory =
+ fedselect_manager->CreateExampleIteratorFactoryForUriTemplate(
+ checkin_result->federated_select_uri_template);
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
+ fedselect_example_iterator_factory.get(),
+ &opstats_example_iterator_factory, env_example_iterator_factory.get()};
+ PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
+ checkin_result->plan.phase().has_example_query_spec()
+ ? RunPlanWithExampleQuerySpec(
+ example_iterator_factories, opstats_logger, flags,
+ checkin_result->plan, *checkpoint_output_filename)
+ : RunPlanWithTensorflowSpec(
+ example_iterator_factories, should_abort, log_manager,
+ opstats_logger, flags, checkin_result->plan,
+ checkin_result->checkpoint_input_filename,
+ *checkpoint_output_filename, timing_config);
+ // Update the FLRunnerResult fields to account for any network usage during
+ // the execution of the plan (e.g. due to Federated Select slices having
+ // been fetched).
+ UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
+ phase_logger, fl_runner_result);
+ auto outcome = plan_result_and_checkpoint_file.plan_result.outcome;
+ absl::StatusOr<ComputationResults> computation_results;
+ if (outcome == engine::PlanOutcome::kSuccess) {
+ computation_results = CreateComputationResults(
+ checkin_result->plan.phase().has_example_query_spec()
+ ? nullptr
+ : &checkin_result->plan.phase().tensorflow_spec(),
+ plan_result_and_checkpoint_file);
+ }
+ LogComputationOutcome(
+ plan_result_and_checkpoint_file.plan_result, computation_results.status(),
+ phase_logger,
+ GetNetworkStatsSince(federated_protocol, fedselect_manager,
+ run_plan_start_network_stats),
+ run_plan_start_time, reference_time);
+ absl::Status report_result = ReportPlanResult(
+ federated_protocol, phase_logger, std::move(computation_results),
+ run_plan_start_time, reference_time);
+ if (outcome == engine::PlanOutcome::kSuccess && report_result.ok()) {
+ // Only if training succeeded *and* reporting succeeded do we consider
+ // the device to have contributed successfully.
+ fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS);
+ }
+ // Update the FLRunnerResult fields one more time to account for the
+ // "Report" protocol interaction.
+ UpdateRetryWindowAndNetworkStats(*federated_protocol, fedselect_manager,
+ phase_logger, fl_runner_result);
+ return fl_runner_result;
+}
+FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ Files* files, LogManager* log_manager, const Flags* flags,
+ const ClientOnlyPlan& client_plan,
+ const std::string& checkpoint_input_filename,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time run_plan_start_time, const absl::Time reference_time) {
+ FLRunnerTensorflowSpecResult result;
+ result.set_outcome(engine::PhaseOutcome::ERROR);
+ engine::PlanResult plan_result(engine::PlanOutcome::kTensorflowError,
+ absl::UnknownError(""));
+ std::function<bool()> should_abort = [env_deps, &timing_config]() {
+ return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period);
+ };
+ auto opstats_logger =
+ engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager,
+ /*session_name=*/"", /*population_name=*/"");
+ PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
+ log_manager, flags);
+ // Regular plans can use example iterators from the SimpleTaskEnvironment,
+ // those reading the OpStats DB, or those serving Federated Select slices.
+ // However, we don't provide a Federated Select-specific example iterator
+ // factory. That way, the Federated Select slice queries will be forwarded
+ // to SimpleTaskEnvironment, which can handle them by providing
+ // test-specific slices if they want to.
+ //
+ // Eligibility eval plans can only use iterators from the
+ // SimpleTaskEnvironment and those reading the OpStats DB.
+ opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
+ opstats_logger.get(), log_manager,
+ flags->opstats_last_successful_contribution_criteria());
+ std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
+ CreateSimpleTaskEnvironmentIteratorFactory(env_deps, SelectorContext());
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
+ &opstats_example_iterator_factory, env_example_iterator_factory.get()};
+ phase_logger.LogComputationStarted();
+ if (client_plan.phase().has_federated_compute()) {
+ absl::StatusOr<std::string> checkpoint_output_filename =
+ files->CreateTempFile("output", ".ckp");
+ if (!checkpoint_output_filename.ok()) {
+ phase_logger.LogComputationIOError(
+ checkpoint_output_filename.status(), ExampleStats(),
+ // Empty network stats, since no network protocol is actually used
+ // in this method.
+ NetworkStats(), run_plan_start_time);
+ return result;
+ }
+ // Regular TensorflowSpec-based plans.
+ PlanResultAndCheckpointFile plan_result_and_checkpoint_file =
+ RunPlanWithTensorflowSpec(example_iterator_factories, should_abort,
+ log_manager, opstats_logger.get(), flags,
+ client_plan, checkpoint_input_filename,
+ *checkpoint_output_filename, timing_config);
+ result.set_checkpoint_output_filename(
+ plan_result_and_checkpoint_file.checkpoint_file);
+ plan_result = std::move(plan_result_and_checkpoint_file.plan_result);
+ } else if (client_plan.phase().has_federated_compute_eligibility()) {
+ // Eligibility eval plans.
+ plan_result = RunEligibilityEvalPlanWithTensorflowSpec(
+ example_iterator_factories, should_abort, log_manager,
+ opstats_logger.get(), flags, client_plan, checkpoint_input_filename,
+ timing_config, run_plan_start_time, reference_time);
+ } else {
+ // This branch shouldn't be taken, unless we add an additional type of
+ // TensorflowSpec-based plan in the future. We return a readable error so
+ // that when such new plan types *are* added, they result in clear
+ // compatibility test failures when such plans are erroneously targeted at
+ // old releases that don't support them yet.
+ event_publisher->PublishIoError("Unsupported TensorflowSpec-based plan");
+ return result;
+ }
+ // Copy output tensors into the result proto.
+ result.set_outcome(
+ engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome));
+ if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
+ for (int i = 0; i < plan_result.output_names.size(); i++) {
+ tensorflow::TensorProto output_tensor_proto;
+ plan_result.output_tensors[i].AsProtoField(&output_tensor_proto);
+ (*result.mutable_output_tensors())[plan_result.output_names[i]] =
+ std::move(output_tensor_proto);
+ }
+ phase_logger.LogComputationCompleted(
+ plan_result.example_stats,
+ // Empty network stats, since no network protocol is actually used in
+ // this method.
+ NetworkStats(), run_plan_start_time, reference_time);
+ } else {
+ phase_logger.LogComputationTensorflowError(
+ plan_result.original_status, plan_result.example_stats, NetworkStats(),
+ run_plan_start_time, reference_time);
+ }
+ return result;
+}
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/fl_runner.h b/fcp/client/fl_runner.h
new file mode 100644
index 0000000..0304c73
--- /dev/null
+++ b/fcp/client/fl_runner.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FL_RUNNER_H_
+#define FCP_CLIENT_FL_RUNNER_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_select.h"
+#include "fcp/client/files.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/phase_logger.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace fcp {
+namespace client {
+
+inline constexpr absl::string_view kTensorflowCheckpointAggregand =
+ "tensorflow_checkpoint";
+
+// Prod entry point for running a federated computation. Concurrent calls, with
+// the same SimpleTaskEnvironment::GetBaseDir(), are not supported.
+//
+// This is a long running blocking call that - for a successful run -
+// encompasses connecting to a server, downloading and running a computation,
+// uploading results, and storing logs about the run in an operational stats DB.
+// During that call, the function will call back (from both the calling and from
+// newly created threads) into the dependencies injected here for to query for
+// examples, check whether it should abort, publish events / logs for telemetry,
+// create files, and query feature flags.
+//
+// Arguments:
+// - federated_service_uri, api_key: used to connect to the Federated server.
+// - test_cert_path: a file path to a CA certificate to be used in tests. Should
+// be empty for production use; when used in tests, the URI must use the
+// https+test:// scheme.
+// - session_name: A client-side identifier of the type of work this computation
+// performs; used to annotate log entries in the operational stats DB.
+// - population_name: a string provided to the Federated server to identify
+// what population this device is checking in for.
+// - client_version: A platform-specific identifier that is used by the server
+// to serve versioned computations - that is, versions of a computation that
+// have been tested and found to be compatible with this device's version -
+// or reject the device.
+// - attestation_measurement: An opaque string from a "measurement" that can be
+// used
+// by the server to attest the device integrity.
+//
+// Returns:
+// On success, the returned FLRunnerResult contains information on when the
+// function should be called again for this session.
+absl::StatusOr<FLRunnerResult> RunFederatedComputation(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ Files* files, LogManager* log_manager, const Flags* flags,
+ const std::string& federated_service_uri, const std::string& api_key,
+ const std::string& test_cert_path, const std::string& session_name,
+ const std::string& population_name, const std::string& retry_token,
+ const std::string& client_version,
+ const std::string& attestation_measurement);
+
+// This is exposed for use in tests that require a mocked FederatedProtocol and
+// OpStatsLogger. Otherwise, this is used internally by the other
+// RunFederatedComputation method once the FederatedProtocol and OpStatsLogger
+// objects have been created.
+absl::StatusOr<FLRunnerResult> RunFederatedComputation(
+ SimpleTaskEnvironment* env_deps, PhaseLogger& phase_logger,
+ EventPublisher* event_publisher, Files* files, LogManager* log_manager,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger, const Flags* flags,
+ FederatedProtocol* federated_protocol,
+ FederatedSelectManager* fedselect_manager,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time reference_time, const std::string& session_name,
+ const std::string& population_name);
+
+// This is exposed for use in compatibility tests only. Prod code should call
+// RunFederatedComputation.
+FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ Files* files, LogManager* log_manager, const Flags* flags,
+ const google::internal::federated::plan::ClientOnlyPlan& client_plan,
+ const std::string& checkpoint_input_filename,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time run_plan_start_time, const absl::Time reference_time);
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FL_RUNNER_H_
diff --git a/fcp/client/fl_runner.proto b/fcp/client/fl_runner.proto
new file mode 100644
index 0000000..849630f
--- /dev/null
+++ b/fcp/client/fl_runner.proto
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package fcp.client;
+
+import "google/protobuf/duration.proto";
+import "fcp/client/engine/engine.proto";
+import "tensorflow/core/framework/tensor.proto";
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+/**
+ * This protocol buffer is used to report results and statistics of a Federated
+ * Computation - including checking in with the server, running a plan, and
+ * reporting back results - to the caller. It is a protocol buffer to support
+ * sending it across language boundaries.
+ */
+message FLRunnerResult {
+ reserved 1;
+ // A RetryInfo returned to the caller for consideration in scheduling future
+ // runs of this task.
+ RetryInfo retry_info = 4;
+ // An enum that summarizes whether the client has contributed to an FL/FA
+ // round.
+ enum ContributionResult {
+ UNSPECIFIED = 0;
+ SUCCESS = 1;
+ // Any outcome that is not a success.
+ FAIL = 2;
+ }
+
+ ContributionResult contribution_result = 5;
+ reserved 2, 3;
+}
+
+// A suggestion to the client when to retry the connection to the service next
+// time
+message RetryInfo {
+ // Optional. If set, should be provided back to the next
+ // RunFederatedComputation invocation.
+ string retry_token = 1;
+
+ // The suggested delay duration after which the client should
+ // retry. Clients should ideally not retry any earlier than this.
+ google.protobuf.Duration minimum_delay = 2;
+}
+
+/**
+ * This protocol buffer is used to pass TensorflowSpec-based plan outputs across
+ * the JNI boundary so they can be accessed in compatibility tests.
+ */
+message FLRunnerTensorflowSpecResult {
+ // The outcome of running the plan.
+ engine.PhaseOutcome outcome = 1;
+ // The location of the output checkpoint file, if one was created.
+ string checkpoint_output_filename = 2;
+ // A map of output tensor names and values, if any.
+ map<string, tensorflow.TensorProto> output_tensors = 3;
+}
diff --git a/fcp/client/flags.h b/fcp/client/flags.h
new file mode 100644
index 0000000..136a49b
--- /dev/null
+++ b/fcp/client/flags.h
@@ -0,0 +1,216 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_FLAGS_H_
+#define FCP_CLIENT_FLAGS_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/status/status.h"
+
+namespace fcp {
+namespace client {
+
+// A class for changing runtime behavior with "flags" - typically, server
+// provided values.
+class Flags {
+ public:
+ virtual ~Flags() = default;
+
+ // The period of time in milliseconds between device condition checks. This is
+ // used during potentially long blocking calls such as TensorFlow or network
+ // I/O, as well as for throttling regular condition checks during plan
+ // execution (e.g. before fetching a new example).
+ virtual int64_t condition_polling_period_millis() const = 0;
+
+ // The period of time in milliseconds allowed for TensorFlow execution to
+ // finish after it's been interrupted.
+ virtual int64_t tf_execution_teardown_grace_period_millis() const = 0;
+
+ // The period of time in milliseconds allowed for TensorFlow execution to
+ // finish after the grace period. This allows us to decide if we want long
+ // running native execution to be forcibly resolved or continue indefinitely.
+ virtual int64_t tf_execution_teardown_extended_period_millis() const = 0;
+
+ // The deadline in seconds for the gRPC channel used for communication
+ // between the client and server.
+ virtual int64_t grpc_channel_deadline_seconds() const = 0;
+
+ // Whether to log the error message strings from TensorFlow exceptions.
+ virtual bool log_tensorflow_error_messages() const = 0;
+
+ // Whether to enable recording to and querying from the Operational Statistics
+ // db.
+ virtual bool enable_opstats() const { return true; }
+
+ // The number of days for data to live in the OpStatsDb without update.
+ virtual int64_t opstats_ttl_days() const { return 30; }
+
+ // The maximum size of the data stored by OpStatsDb.
+ virtual int64_t opstats_db_size_limit_bytes() const {
+ return 1 * 1024 * 1024;
+ }
+
+ // The retry delay to use when encountering a transient error during a
+ // training run before having received a RetryWindow from the server.
+ virtual int64_t federated_training_transient_errors_retry_delay_secs() const {
+ // 15 minutes
+ return 15 * 60;
+ }
+
+ // The amount of jitter to apply when using the
+ // `federated_training_transient_errors_retry_delay_secs` flag. Must be a
+ // value between 0 and 1. E.g. a value of 0.2 means that retry delays will
+ // fall within [0.8 * target period, 1.2 * target period).
+ virtual float federated_training_transient_errors_retry_delay_jitter_percent()
+ const {
+ return 0.2;
+ }
+
+ // The retry delay to use when encountering a permanent error during a
+ // training run (regardless of whether the client already received a
+ // RetryWindow from the server).
+ virtual int64_t federated_training_permanent_errors_retry_delay_secs() const {
+ // 4 hours
+ return 4 * 60 * 60;
+ }
+
+ // The amount of jitter to apply when using the
+ // `federated_training_permanent_errors_retry_delay_secs` flag. Must be a
+ // value between 0 and 1. E.g. a value of 0.2 means that retry delays will
+ // fall within [0.8 * target period, 1.2 * target period).
+ virtual float federated_training_permanent_errors_retry_delay_jitter_percent()
+ const {
+ return 0.2;
+ }
+
+ // The list of error codes that should be considered 'permanent'.
+ virtual std::vector<int32_t> federated_training_permanent_error_codes()
+ const {
+ return {
+ // The server returns NOT_FOUND if the client checks in with an unknown
+ // population name. While this can be resolved without any client
+ // changes by creating the population server-side, it is nevertheless
+ // wise to treat this as a 'permanent' error for which a longer
+ // RetryPeriod is used, because such temporary mismatches in
+ // client/server configuration are fairly common and otherwise cause
+ // clients to check in unnecessarily frequently.
+ static_cast<int32_t>(absl::StatusCode::kNotFound),
+ // INVALID_ARGUMENT generally indicates a client-side issue (e.g. a bug
+ // in the client's protocol implementation), which is unlikely to be
+ // resolved by merely retrying the request.
+ static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
+ // UNIMPLEMENTED similarly could indicate a client-side issue, or a
+ // temporary server issue (e.g. a bug/missing feature implementation in
+ // the server). Either way, it is also unlikely to be resolved by merely
+ // retrying the request soon.
+ static_cast<int32_t>(absl::StatusCode::kUnimplemented)};
+ }
+
+ // Whether use TFLite for training.
+ virtual bool use_tflite_training() const { return false; }
+
+ // Whether to enable support for downloading plan/initial checkpoint resources
+ // via HTTP, while still using gRPC for the main protocol.
+ virtual bool enable_grpc_with_http_resource_support() const { return false; }
+
+ // Whether to enable support for downloading eligibility eval plan/initial
+ // checkpoint resources via HTTP, while still using gRPC for the main
+ // protocol.
+ virtual bool enable_grpc_with_eligibility_eval_http_resource_support() const {
+ return false;
+ }
+
+ // When true, TFLite interpreter will use dynamic memory allocation, and
+ // release the memory for tensors that are no longer needed.
+ virtual bool ensure_dynamic_tensors_are_released() const { return true; }
+
+ // When the value is above zero, any tensor size (bytes) above the threshold
+ // will be considered as a large tensor, and dynamic allocation is applied on
+ // them.
+ virtual int32_t large_tensor_threshold_for_dynamic_allocation() const {
+ return 1000;
+ }
+
+ // When true, the TFLite runtime graph-reordering optimization that clusters
+ // delegate nodes together is disabled.
+ virtual bool disable_tflite_delegate_clustering() const { return false; }
+
+ // When true, http request body won't be compressed.
+ virtual bool disable_http_request_body_compression() const { return false; }
+
+ // When true, HTTP Federated Compute protocol is used.
+ virtual bool use_http_federated_compute_protocol() const { return false; }
+
+ // When true, the client computes the task identity to pass in
+ // SelectorContext.
+ virtual bool enable_computation_id() const { return false; }
+
+ // The waiting period for issuing cancellation requests before checking
+ // whether the client should be interrupted.
+ virtual int32_t waiting_period_sec_for_cancellation() const { return 10; }
+
+ // If true, the client supports the Federated Select feature. If not
+ // then any Federated Select-specific example query will fail with an error
+ virtual bool enable_federated_select() const { return false; }
+
+ // The max size in bytes of resources that the ResourceCache is allowed to
+ // store. If greater than 0, the client will attempt to cache resources sent
+ // by uri via the hybrid grpc-with-http-resources and the full http stack. If
+ // this value is reduced from some previous greater value, the cache dir will
+ // be reduced appropriately the next time it is initialized at the start of
+ // the next run.
+ virtual int64_t max_resource_cache_size_bytes() const { return 0; }
+
+ // If true, an error during the initialization of the resource cache will
+ // publish a fatal initialization error instead of a nonfatal initialization
+ // error and halt execution.
+ virtual bool resource_cache_initialization_error_is_fatal() const {
+ return false;
+ }
+
+ // The number of threads that TFLite interpreter will use.
+ virtual int32_t num_threads_for_tflite() const { return 1; }
+
+ // If true, Opstats initialization errors will be logged via
+ // PhaseLogger.LogNonfatalInitializationError(). Execution will still be
+ // allowed to continue with a no-op implementation like before.
+ virtual bool log_opstats_initialization_errors() const { return false; }
+
+ // If true, enables the last_successful_contribution option in the opstats
+ // selection criteria which returns an opstats entry for the last successful
+ // contribution for the currently executing task.
+ virtual bool opstats_last_successful_contribution_criteria() const {
+ return false;
+ }
+
+ // If true, enables support for the `TensorflowSpec.constant_inputs` field. If
+ // false, then the field will be ignored.
+ virtual bool support_constant_tf_inputs() const { return false; }
+
+ // If true, enables an Example Query plan engine to be invoked for
+ // non-TensorFlow tasks.
+ virtual bool enable_example_query_plan_engine() const { return false; }
+
+ // If true, the HTTP federated protocol supports multiple task assignments.
+ virtual bool http_protocol_supports_multiple_task_assignments() const {
+ return false;
+ }
+};
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_FLAGS_H_
diff --git a/fcp/client/grpc_bidi_channel.h b/fcp/client/grpc_bidi_channel.h
new file mode 100644
index 0000000..864a22f
--- /dev/null
+++ b/fcp/client/grpc_bidi_channel.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_GRPC_BIDI_CHANNEL_H_
+#define FCP_CLIENT_GRPC_BIDI_CHANNEL_H_
+
+#include <dirent.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/strip.h"
+#include "fcp/base/monitoring.h"
+#include "grpcpp/channel.h"
+#include "grpcpp/create_channel.h"
+#include "grpcpp/security/credentials.h"
+
+namespace fcp {
+namespace client {
+
+class GrpcBidiChannel {
+ public:
+ static constexpr int kChannelMaxMessageSize = 20 * 1000 * 1000;
+ static constexpr int kKeepAliveTimeSeconds = 60;
+ static constexpr const absl::string_view kSecurePrefix = "https://";
+ static constexpr const absl::string_view kSecureTestPrefix = "https+test://";
+
+ /**
+ * Create a channel to the remote endpoint.
+ * @param target URI of the remote endpoint.
+ * @param cert_path If the URI is https:// or https+test://, an optional
+ * path to a certificate root file or directory.
+ * @return A shared pointer to the channel interface.
+ */
+ static std::shared_ptr<grpc::ChannelInterface> Create(
+ const std::string& target, std::string cert_path) {
+ bool secure = false;
+ bool test = false;
+ // This double check avoids a dependency on re2:
+ if (absl::StartsWith(target, kSecureTestPrefix)) {
+ secure = true;
+ test = true;
+ } else if (absl::StartsWith(target, kSecurePrefix)) {
+ secure = true;
+ }
+
+ grpc::ChannelArguments channel_arguments;
+ channel_arguments.SetMaxReceiveMessageSize(kChannelMaxMessageSize);
+ channel_arguments.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS,
+ kKeepAliveTimeSeconds * 1000);
+
+ if (!secure)
+ return grpc::CreateCustomChannel(
+ target, grpc::InsecureChannelCredentials(), channel_arguments);
+
+ std::shared_ptr<grpc::ChannelCredentials> channel_creds;
+ grpc::SslCredentialsOptions ssl_opts{};
+
+ if (!cert_path.empty()) {
+ std::ifstream cert_file(cert_path);
+ FCP_LOG_IF(ERROR, cert_file.fail())
+ << "Open for: " << cert_path << " failed: " << strerror(errno);
+ std::stringstream string_stream;
+ if (cert_file) {
+ string_stream << cert_file.rdbuf();
+ }
+ FCP_LOG_IF(WARNING, string_stream.str().empty())
+ << "Cert: " << cert_path << " is empty.";
+ ssl_opts = {string_stream.str(), "", ""};
+ }
+
+ channel_creds = grpc::SslCredentials(ssl_opts);
+ auto target_uri = absl::StrCat(
+ "dns:///",
+ absl::StripPrefix(target, test ? kSecureTestPrefix : kSecurePrefix));
+ FCP_LOG(INFO) << "Creating channel to: " << target_uri;
+ return CreateCustomChannel(target_uri, channel_creds, channel_arguments);
+ }
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_GRPC_BIDI_CHANNEL_H_
diff --git a/fcp/client/grpc_bidi_channel_test.cc b/fcp/client/grpc_bidi_channel_test.cc
new file mode 100644
index 0000000..876f4fe
--- /dev/null
+++ b/fcp/client/grpc_bidi_channel_test.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/protos/federated_api.grpc.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace test {
+namespace {
+
+using google::internal::federatedml::v2::ClientStreamMessage;
+using google::internal::federatedml::v2::ServerStreamMessage;
+using ::testing::Not;
+
+} // namespace
+} // namespace test
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/grpc_bidi_stream.cc b/fcp/client/grpc_bidi_stream.cc
new file mode 100644
index 0000000..89a2529
--- /dev/null
+++ b/fcp/client/grpc_bidi_stream.cc
@@ -0,0 +1,139 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/grpc_bidi_stream.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "fcp/base/status_converters.h"
+#include "fcp/client/grpc_bidi_channel.h"
+#include "grpcpp/support/time.h"
+
+namespace fcp {
+namespace client {
+
+using fcp::base::FromGrpcStatus;
+using google::internal::federatedml::v2::ClientStreamMessage;
+using google::internal::federatedml::v2::FederatedTrainingApi;
+using google::internal::federatedml::v2::ServerStreamMessage;
+using grpc::ChannelInterface;
+
+GrpcBidiStream::GrpcBidiStream(const std::string& target,
+ const std::string& api_key,
+ const std::string& population_name,
+ int64_t grpc_channel_deadline_seconds,
+ std::string cert_path)
+ : GrpcBidiStream(GrpcBidiChannel::Create(target, std::move(cert_path)),
+ api_key, population_name, grpc_channel_deadline_seconds) {}
+
+GrpcBidiStream::GrpcBidiStream(
+ const std::shared_ptr<grpc::ChannelInterface>& channel,
+ const std::string& api_key, const std::string& population_name,
+ int64_t grpc_channel_deadline_seconds)
+ : mu_(), stub_(FederatedTrainingApi::NewStub(channel)) {
+ FCP_LOG(INFO) << "Connecting to stub: " << stub_.get();
+ gpr_timespec deadline = gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_seconds(grpc_channel_deadline_seconds, GPR_TIMESPAN));
+ client_context_.set_deadline(deadline);
+ client_context_.AddMetadata(kApiKeyHeader, api_key);
+ client_context_.AddMetadata(kPopulationNameHeader, population_name);
+ client_reader_writer_ = stub_->Session(&client_context_);
+ GrpcChunkedBidiStream<ClientStreamMessage,
+ ServerStreamMessage>::GrpcChunkedBidiStreamOptions
+ options;
+ chunked_bidi_stream_ = std::make_unique<
+ GrpcChunkedBidiStream<ClientStreamMessage, ServerStreamMessage>>(
+ client_reader_writer_.get(), client_reader_writer_.get(), options);
+ if (!channel) Close();
+}
+
+absl::Status GrpcBidiStream::Send(ClientStreamMessage* message) {
+ absl::Status status;
+ {
+ absl::MutexLock _(&mu_);
+ if (client_reader_writer_ == nullptr) {
+ return absl::CancelledError(
+ "Send failed because GrpcBidiStream was closed.");
+ }
+ status = chunked_bidi_stream_->Send(message);
+ if (status.code() == absl::StatusCode::kAborted) {
+ FCP_LOG(INFO) << "Send aborted: " << status.code();
+ auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
+ // If the connection aborts early or harshly enough, there will be no
+ // error status from Finish().
+ if (!finish_status.ok()) status = finish_status;
+ }
+ }
+ if (!status.ok()) {
+ FCP_LOG(INFO) << "Closing; error on send: " << status.message();
+ Close();
+ }
+ return status;
+}
+
+absl::Status GrpcBidiStream::Receive(ServerStreamMessage* message) {
+ absl::Status status;
+ {
+ absl::MutexLock _(&mu_);
+ if (client_reader_writer_ == nullptr) {
+ return absl::CancelledError(
+ "Receive failed because GrpcBidiStream was closed.");
+ }
+ status = chunked_bidi_stream_->Receive(message);
+ if (status.code() == absl::StatusCode::kAborted) {
+ FCP_LOG(INFO) << "Receive aborted: " << status.code();
+ auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
+ // If the connection aborts early or harshly enough, there will be no
+ // error status from Finish().
+ if (!finish_status.ok()) status = finish_status;
+ }
+ }
+ if (!status.ok()) {
+ FCP_LOG(INFO) << "Closing; error on receive: " << status.message();
+ Close();
+ }
+ return status;
+}
+
+void GrpcBidiStream::Close() {
+ if (!mu_.TryLock()) {
+ client_context_.TryCancel();
+ mu_.Lock();
+ }
+ chunked_bidi_stream_->Close();
+ if (client_reader_writer_) client_reader_writer_->WritesDone();
+ client_reader_writer_.reset();
+ FCP_LOG(INFO) << "Closing stub: " << stub_.get();
+ stub_.reset();
+ mu_.Unlock();
+}
+
+int64_t GrpcBidiStream::ChunkingLayerBytesReceived() {
+ absl::MutexLock _(&mu_);
+ return chunked_bidi_stream_->ChunkingLayerBytesReceived();
+}
+
+int64_t GrpcBidiStream::ChunkingLayerBytesSent() {
+ absl::MutexLock _(&mu_);
+ return chunked_bidi_stream_->ChunkingLayerBytesSent();
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/grpc_bidi_stream.h b/fcp/client/grpc_bidi_stream.h
new file mode 100644
index 0000000..b21be4c
--- /dev/null
+++ b/fcp/client/grpc_bidi_stream.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_GRPC_BIDI_STREAM_H_
+#define FCP_CLIENT_GRPC_BIDI_STREAM_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/protocol/grpc_chunked_bidi_stream.h"
+#include "fcp/protos/federated_api.grpc.pb.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_context.h"
+
+namespace fcp {
+namespace client {
+
+/**
+ * Interface to support dependency injection and hence testing
+ */
+class GrpcBidiStreamInterface {
+ public:
+ virtual ~GrpcBidiStreamInterface() = default;
+
+ virtual ABSL_MUST_USE_RESULT absl::Status Send(
+ google::internal::federatedml::v2::ClientStreamMessage* message) = 0;
+
+ virtual ABSL_MUST_USE_RESULT absl::Status Receive(
+ google::internal::federatedml::v2::ServerStreamMessage* message) = 0;
+
+ virtual void Close() = 0;
+
+ virtual int64_t ChunkingLayerBytesSent() = 0;
+
+ virtual int64_t ChunkingLayerBytesReceived() = 0;
+};
+
+/**
+ * A class which encapsulates a chunking gRPC endpoint for the federated
+ * learning API.
+ *
+ * This class is thread-safe, but note that calls to Send() and Receive() are
+ * serialized *and* blocking.
+ */
+class GrpcBidiStream : public GrpcBidiStreamInterface {
+ public:
+ /**
+ * Create a chunking gRPC endpoint for the federated learning API.
+ * @param target The URI of the target endpoint.
+ * @param api_key The API key of the target endpoint.
+ * @param population_name The population this connection is associated with.
+ * This param will not be empty if the include_population_in_header flag is
+ * False.
+ * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
+ * channel.
+ * @param cert_path Test-only path to a CA certificate root, to be used in
+ * combination with an "https+test://" URI scheme.
+ */
+ GrpcBidiStream(const std::string& target, const std::string& api_key,
+ const std::string& population_name,
+ int64_t grpc_channel_deadline_seconds,
+ std::string cert_path = "");
+
+ /**
+ * @param channel A preexisting channel to the target endpoint.
+ * @param api_key The API of the target endpoint.
+ * @param population_name The population this connection is associated with.
+ * This param will not be empty if the include_population_in_header flag is
+ * False.
+ * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
+ * channel.
+ */
+ GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface>& channel,
+ const std::string& api_key, const std::string& population_name,
+ int64_t grpc_channel_deadline_seconds);
+ ~GrpcBidiStream() override = default;
+
+ // GrpcBidiStream is neither copyable nor movable.
+ GrpcBidiStream(const GrpcBidiStream&) = delete;
+ GrpcBidiStream& operator=(const GrpcBidiStream&) = delete;
+
+ /**
+ * Send a ClientStreamMessage to the remote endpoint.
+ * @param message The message to send.
+ * @return absl::Status, which will have code OK if the message was sent
+ * successfully.
+ */
+ ABSL_MUST_USE_RESULT absl::Status Send(
+ google::internal::federatedml::v2::ClientStreamMessage* message) override
+ ABSL_LOCKS_EXCLUDED(mu_);
+
+ /**
+ * Receive a ServerStreamMessage from the remote endpoint. Blocking.
+ * @param message The message to receive.
+ * @return absl::Status. This may be a translation of the status returned by
+ * the server, or a status generated during execution of the chunking
+ * protocol.
+ */
+ ABSL_MUST_USE_RESULT absl::Status Receive(
+ google::internal::federatedml::v2::ServerStreamMessage* message) override
+ ABSL_LOCKS_EXCLUDED(mu_);
+
+ /**
+ * Close this stream.
+ * Releases any blocked readers. Thread safe.
+ */
+ void Close() override ABSL_LOCKS_EXCLUDED(mu_);
+
+ /**
+ * Returns the number of bytes sent from the chunking layer.
+ * Flow control means this value may not increment until Receive() is called.
+ */
+ int64_t ChunkingLayerBytesSent() override;
+
+ /**
+ * Returns the number of bytes received by the chunking layer.
+ */
+ int64_t ChunkingLayerBytesReceived() override;
+
+ // Note: Must be lowercase:
+ static constexpr char kApiKeyHeader[] = "x-goog-api-key";
+ static constexpr char kPopulationNameHeader[] = "x-goog-population";
+
+ private:
+ absl::Mutex mu_;
+ std::unique_ptr<google::internal::federatedml::v2::FederatedTrainingApi::Stub>
+ stub_;
+ grpc::ClientContext client_context_;
+ std::unique_ptr<grpc::ClientReaderWriter<
+ google::internal::federatedml::v2::ClientStreamMessage,
+ google::internal::federatedml::v2::ServerStreamMessage>>
+ client_reader_writer_ ABSL_GUARDED_BY(mu_);
+ std::unique_ptr<GrpcChunkedBidiStream<
+ google::internal::federatedml::v2::ClientStreamMessage,
+ google::internal::federatedml::v2::ServerStreamMessage>>
+ chunked_bidi_stream_ ABSL_GUARDED_BY(mu_);
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_GRPC_BIDI_STREAM_H_
diff --git a/fcp/client/grpc_bidi_stream_test.cc b/fcp/client/grpc_bidi_stream_test.cc
new file mode 100644
index 0000000..0a23506
--- /dev/null
+++ b/fcp/client/grpc_bidi_stream_test.cc
@@ -0,0 +1,165 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/grpc_bidi_stream.h"
+
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/client/fake_server.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+#include "grpcpp/server_builder.h"
+
+namespace fcp {
+namespace client {
+namespace test {
+namespace {
+
+using google::internal::federatedml::v2::ClientStreamMessage;
+using google::internal::federatedml::v2::ServerStreamMessage;
+using ::testing::Contains;
+using ::testing::Not;
+using ::testing::Pair;
+
+class GrpcBidiStreamTest : public testing::Test {
+ protected:
+ void SetUp() override { BuildAndStartServer(); }
+
+ void TearDown() override {
+ server_->Shutdown();
+ server_->Wait();
+ }
+
+ void CreateClient(const std::string& population_name = "") {
+ client_stream_ = std::make_unique<GrpcBidiStream>(
+ absl::StrCat("dns:///localhost", ":", port_), "none", population_name,
+ /* grpc_channel_deadline_seconds=*/600);
+ FCP_LOG(INFO) << "Client created." << std::endl;
+ }
+
+ std::unique_ptr<GrpcBidiStream> client_stream_;
+ FakeServer server_impl_;
+
+ private:
+ void BuildAndStartServer() {
+ grpc::ServerBuilder builder;
+ builder.AddListeningPort("dns:///localhost:0",
+ grpc::InsecureServerCredentials(), &port_);
+ builder.RegisterService(&server_impl_);
+ server_ = builder.BuildAndStart();
+ }
+ // Variables that must be in scope for the lifetime of a test but are not
+ // used by test code.
+ int port_ = 0;
+ std::unique_ptr<grpc::Server> server_;
+};
+
+TEST_F(GrpcBidiStreamTest, ClientContainsPopulationMetadata) {
+ CreateClient("population_name");
+ ClientStreamMessage request;
+ request.mutable_checkin_request();
+ EXPECT_THAT(client_stream_->Send(&request), IsOk());
+ ServerStreamMessage reply;
+ EXPECT_THAT(client_stream_->Receive(&reply), IsOk());
+ EXPECT_TRUE(reply.has_checkin_response()) << reply.DebugString();
+ EXPECT_THAT(server_impl_.GetClientMetadata(),
+ Contains(Pair(GrpcBidiStream::kApiKeyHeader, "none")));
+ EXPECT_THAT(
+ server_impl_.GetClientMetadata(),
+ Contains(Pair(GrpcBidiStream::kPopulationNameHeader, "population_name")));
+ client_stream_->Close();
+ server_impl_.WaitForSessionDone();
+}
+
+TEST_F(GrpcBidiStreamTest, CancellationDuringBlockingOp) {
+ CreateClient();
+ auto pool = CreateThreadPoolScheduler(1);
+ pool->Schedule([this]() {
+ sleep(1);
+ client_stream_->Close();
+ });
+ ServerStreamMessage reply;
+ auto start = absl::Now();
+ // Will block indefinitely, as the default FakeServer requires a request
+ // before sending a response.
+ EXPECT_THAT(client_stream_->Receive(&reply),
+ IsCode(absl::StatusCode::kCancelled));
+ EXPECT_GE(absl::Now() - start, absl::Seconds(1));
+
+ server_impl_.WaitForSessionDone();
+
+ // Idempotency check:
+ client_stream_->Close();
+ EXPECT_THAT(client_stream_->Receive(&reply), Not(IsOk()));
+ pool->WaitUntilIdle();
+}
+
+TEST_F(GrpcBidiStreamTest, CancellationBeforeSend) {
+ CreateClient();
+ absl::Status status;
+ client_stream_->Close();
+ server_impl_.WaitForSessionDone();
+ ClientStreamMessage request;
+ request.mutable_checkin_request();
+ EXPECT_THAT(client_stream_->Send(&request),
+ IsCode(absl::StatusCode::kCancelled));
+}
+
+TEST_F(GrpcBidiStreamTest, CancellationBeforeReceive) {
+ CreateClient();
+ ClientStreamMessage request;
+ request.mutable_checkin_request();
+ EXPECT_THAT(client_stream_->Send(&request), IsOk());
+ client_stream_->Close();
+ server_impl_.WaitForSessionDone();
+ ServerStreamMessage reply;
+ EXPECT_THAT(client_stream_->Receive(&reply),
+ IsCode(absl::StatusCode::kCancelled));
+ // Idempotency check:
+ EXPECT_THAT(client_stream_->Receive(&reply),
+ IsCode(absl::StatusCode::kCancelled));
+}
+
+TEST_F(GrpcBidiStreamTest, CancellationWithoutBlockingOp) {
+ CreateClient();
+ ClientStreamMessage request;
+ request.mutable_checkin_request();
+ EXPECT_THAT(client_stream_->Send(&request), IsOk());
+ ServerStreamMessage reply;
+ EXPECT_THAT(client_stream_->Receive(&reply), IsOk());
+ EXPECT_TRUE(reply.has_checkin_response()) << reply.DebugString();
+ EXPECT_THAT(server_impl_.GetClientMetadata(),
+ Contains(Pair(GrpcBidiStream::kApiKeyHeader, "none")));
+ EXPECT_THAT(server_impl_.GetClientMetadata(),
+ Contains(Pair(GrpcBidiStream::kPopulationNameHeader, "")));
+
+ client_stream_->Close();
+ server_impl_.WaitForSessionDone();
+}
+
+} // namespace
+} // namespace test
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/grpc_federated_protocol.cc b/fcp/client/grpc_federated_protocol.cc
new file mode 100644
index 0000000..c1147f4
--- /dev/null
+++ b/fcp/client/grpc_federated_protocol.cc
@@ -0,0 +1,1074 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/grpc_federated_protocol.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "google/protobuf/duration.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_protocol_util.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/secagg/client/secagg_client.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/crypto_rand_prng.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace client {
+
+using ::fcp::client::http::UriOrInlineData;
+using ::fcp::secagg::ClientToServerWrapperMessage;
+using ::google::internal::federatedml::v2::CheckinRequest;
+using ::google::internal::federatedml::v2::CheckinRequestAck;
+using ::google::internal::federatedml::v2::CheckinResponse;
+using ::google::internal::federatedml::v2::ClientExecutionStats;
+using ::google::internal::federatedml::v2::ClientStreamMessage;
+using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
+using ::google::internal::federatedml::v2::EligibilityEvalCheckinResponse;
+using ::google::internal::federatedml::v2::EligibilityEvalPayload;
+using ::google::internal::federatedml::v2::HttpCompressionFormat;
+using ::google::internal::federatedml::v2::ProtocolOptionsRequest;
+using ::google::internal::federatedml::v2::RetryWindow;
+using ::google::internal::federatedml::v2::ServerStreamMessage;
+using ::google::internal::federatedml::v2::SideChannelExecutionInfo;
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+
+// A note on error handling:
+//
+// The implementation here makes a distinction between what we call 'transient'
+// and 'permanent' errors. While the exact categorization of transient vs.
+// permanent errors is defined by a flag, the intent is that transient errors
+// are those types of errors that may occur in the regular course of business,
+// e.g. due to an interrupted network connection, a load balancer temporarily
+// rejecting our request etc. Generally, these are expected to be resolvable by
+// merely retrying the request at a slightly later time. Permanent errors are
+// intended to be those that are not expected to be resolvable as quickly or by
+// merely retrying the request. E.g. if a client checks in to the server with a
+// population name that doesn't exist, then the server may return NOT_FOUND, and
+// until the server-side configuration is changed, it will continue returning
+// such an error. Hence, such errors can warrant a longer retry period (to waste
+// less of both the client's and server's resources).
+//
+// The errors also differ in how they interact with the server-specified retry
+// windows that are returned via the CheckinRequestAck message.
+// - If a permanent error occurs, then we will always return a retry window
+// based on the target 'permanent errors retry period' flag, regardless of
+// whether we received a CheckinRequestAck from the server at an earlier time.
+// - If a transient error occurs, then we will only return a retry window
+// based on the target 'transient errors retry period' flag if the server
+// didn't already return a CheckinRequestAck. If it did return such an ack,
+// then one of the retry windows in that message will be used instead.
+//
+// Finally, note that for simplicity's sake we generally check whether a
+// permanent error was received at the level of this class's public method,
+// rather than deeper down in each of our helper methods that actually call
+// directly into the gRPC stack. This keeps our state-managing code simpler, but
+// does mean that if any of our helper methods like SendCheckinRequest produce a
+// permanent error code locally (i.e. without it being sent by the server), it
+// will be treated as if the server sent it and the permanent error retry period
+// will be used. We consider this a reasonable tradeoff.
+
+GrpcFederatedProtocol::GrpcFederatedProtocol(
+ EventPublisher* event_publisher, LogManager* log_manager,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ const Flags* flags, ::fcp::client::http::HttpClient* http_client,
+ const std::string& federated_service_uri, const std::string& api_key,
+ const std::string& test_cert_path, absl::string_view population_name,
+ absl::string_view retry_token, absl::string_view client_version,
+ absl::string_view attestation_measurement,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ const int64_t grpc_channel_deadline_seconds,
+ cache::ResourceCache* resource_cache)
+ : GrpcFederatedProtocol(
+ event_publisher, log_manager, std::move(secagg_runner_factory), flags,
+ http_client,
+ std::make_unique<GrpcBidiStream>(
+ federated_service_uri, api_key, std::string(population_name),
+ grpc_channel_deadline_seconds, test_cert_path),
+ population_name, retry_token, client_version, attestation_measurement,
+ should_abort, absl::BitGen(), timing_config, resource_cache) {}
+
+GrpcFederatedProtocol::GrpcFederatedProtocol(
+ EventPublisher* event_publisher, LogManager* log_manager,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ const Flags* flags, ::fcp::client::http::HttpClient* http_client,
+ std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
+ absl::string_view population_name, absl::string_view retry_token,
+ absl::string_view client_version, absl::string_view attestation_measurement,
+ std::function<bool()> should_abort, absl::BitGen bit_gen,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ cache::ResourceCache* resource_cache)
+ : object_state_(ObjectState::kInitialized),
+ event_publisher_(event_publisher),
+ log_manager_(log_manager),
+ secagg_runner_factory_(std::move(secagg_runner_factory)),
+ flags_(flags),
+ http_client_(http_client),
+ grpc_bidi_stream_(std::move(grpc_bidi_stream)),
+ population_name_(population_name),
+ retry_token_(retry_token),
+ client_version_(client_version),
+ attestation_measurement_(attestation_measurement),
+ bit_gen_(std::move(bit_gen)),
+ resource_cache_(resource_cache) {
+ interruptible_runner_ = std::make_unique<InterruptibleRunner>(
+ log_manager, should_abort, timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_TIMED_OUT});
+ // Note that we could cast the provided error codes to absl::StatusCode
+ // values here. However, that means we'd have to handle the case when
+ // invalid integers that don't map to a StatusCode enum are provided in the
+ // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
+ // compare them with the flag-provided list of codes, which means we never
+ // have to worry about invalid flag values (besides the fact that invalid
+ // values will be silently ignored, which could make it harder to realize when
+ // flag is misconfigured).
+ const std::vector<int32_t>& error_codes =
+ flags->federated_training_permanent_error_codes();
+ federated_training_permanent_error_codes_ =
+ absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
+}
+
+GrpcFederatedProtocol::~GrpcFederatedProtocol() { grpc_bidi_stream_->Close(); }
+
+absl::Status GrpcFederatedProtocol::Send(
+ google::internal::federatedml::v2::ClientStreamMessage*
+ client_stream_message) {
+ // Note that this stopwatch measurement may not fully measure the time it
+ // takes to send all of the data, as it may return before all data was written
+ // to the network socket. It's the best estimate we can provide though.
+ auto started_stopwatch = network_stopwatch_->Start();
+ FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
+ [this, &client_stream_message]() {
+ return this->grpc_bidi_stream_->Send(client_stream_message);
+ },
+ [this]() { this->grpc_bidi_stream_->Close(); }));
+ return absl::OkStatus();
+}
+
+absl::Status GrpcFederatedProtocol::Receive(
+ google::internal::federatedml::v2::ServerStreamMessage*
+ server_stream_message) {
+ // Note that this stopwatch measurement will generally include time spent
+ // waiting for the server to return a response (i.e. idle time rather than the
+ // true time it takes to send/receive data on the network). It's the best
+ // estimate we can provide though.
+ auto started_stopwatch = network_stopwatch_->Start();
+ FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
+ [this, &server_stream_message]() {
+ return grpc_bidi_stream_->Receive(server_stream_message);
+ },
+ [this]() { this->grpc_bidi_stream_->Close(); }));
+ return absl::OkStatus();
+}
+
+ProtocolOptionsRequest GrpcFederatedProtocol::CreateProtocolOptionsRequest(
+ bool should_ack_checkin) const {
+ ProtocolOptionsRequest request;
+ request.set_should_ack_checkin(should_ack_checkin);
+ request.set_supports_http_download(http_client_ != nullptr);
+ request.set_supports_eligibility_eval_http_download(
+ http_client_ != nullptr &&
+ flags_->enable_grpc_with_eligibility_eval_http_resource_support());
+
+ // Note that we set this field for both eligibility eval checkin requests
+ // and regular checkin requests. Even though eligibility eval tasks do not
+ // have any aggregation phase, we still advertise the client's support for
+ // Secure Aggregation during the eligibility eval checkin phase. We do
+ // this because it doesn't hurt anything, and because letting the server
+ // know whether client supports SecAgg sooner rather than later in the
+ // protocol seems to provide maximum flexibility if the server ever were
+ // to use that information at this stage of the protocol in the future.
+ request.mutable_side_channels()
+ ->mutable_secure_aggregation()
+ ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
+ request.mutable_supported_http_compression_formats()->Add(
+ HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
+ return request;
+}
+
+absl::Status GrpcFederatedProtocol::SendEligibilityEvalCheckinRequest() {
+ ClientStreamMessage client_stream_message;
+ EligibilityEvalCheckinRequest* eligibility_checkin_request =
+ client_stream_message.mutable_eligibility_eval_checkin_request();
+ eligibility_checkin_request->set_population_name(population_name_);
+ eligibility_checkin_request->set_retry_token(retry_token_);
+ eligibility_checkin_request->set_client_version(client_version_);
+ eligibility_checkin_request->set_attestation_measurement(
+ attestation_measurement_);
+ *eligibility_checkin_request->mutable_protocol_options_request() =
+ CreateProtocolOptionsRequest(
+ /* should_ack_checkin=*/true);
+
+ return Send(&client_stream_message);
+}
+
+absl::Status GrpcFederatedProtocol::SendCheckinRequest(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
+ ClientStreamMessage client_stream_message;
+ CheckinRequest* checkin_request =
+ client_stream_message.mutable_checkin_request();
+ checkin_request->set_population_name(population_name_);
+ checkin_request->set_retry_token(retry_token_);
+ checkin_request->set_client_version(client_version_);
+ checkin_request->set_attestation_measurement(attestation_measurement_);
+ *checkin_request->mutable_protocol_options_request() =
+ CreateProtocolOptionsRequest(/* should_ack_checkin=*/false);
+
+ if (task_eligibility_info.has_value()) {
+ *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
+ }
+
+ return Send(&client_stream_message);
+}
+
+absl::Status GrpcFederatedProtocol::ReceiveCheckinRequestAck() {
+ // Wait for a CheckinRequestAck.
+ ServerStreamMessage server_stream_message;
+ absl::Status receive_status = Receive(&server_stream_message);
+ if (receive_status.code() == absl::StatusCode::kNotFound) {
+ FCP_LOG(INFO) << "Server responded NOT_FOUND to checkin request, "
+ "population name '"
+ << population_name_ << "' is likely incorrect.";
+ }
+ FCP_RETURN_IF_ERROR(receive_status);
+
+ if (!server_stream_message.has_checkin_request_ack()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::
+ BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD);
+ return absl::UnimplementedError(
+ "Requested but did not receive CheckinRequestAck");
+ }
+ log_manager_->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED);
+ // Process the received CheckinRequestAck message.
+ const CheckinRequestAck& checkin_request_ack =
+ server_stream_message.checkin_request_ack();
+ if (!checkin_request_ack.has_retry_window_if_accepted() ||
+ !checkin_request_ack.has_retry_window_if_rejected()) {
+ return absl::UnimplementedError(
+ "Received CheckinRequestAck message with missing retry windows");
+ }
+ // Upon receiving the server's RetryWindows we immediately choose a concrete
+ // target timestamp to retry at. This ensures that a) clients of this class
+ // don't have to implement the logic to select a timestamp from a min/max
+ // range themselves, b) we tell clients of this class to come back at exactly
+ // a point in time the server intended us to come at (i.e. "now +
+ // server_specified_retry_period", and not a point in time that is partly
+ // determined by how long the remaining protocol interactions (e.g. training
+ // and results upload) will take (i.e. "now +
+ // duration_of_remaining_protocol_interactions +
+ // server_specified_retry_period").
+ checkin_request_ack_info_ = CheckinRequestAckInfo{
+ .retry_info_if_rejected =
+ RetryTimeAndToken{
+ PickRetryTimeFromRange(
+ checkin_request_ack.retry_window_if_rejected().delay_min(),
+ checkin_request_ack.retry_window_if_rejected().delay_max(),
+ bit_gen_),
+ checkin_request_ack.retry_window_if_rejected().retry_token()},
+ .retry_info_if_accepted = RetryTimeAndToken{
+ PickRetryTimeFromRange(
+ checkin_request_ack.retry_window_if_accepted().delay_min(),
+ checkin_request_ack.retry_window_if_accepted().delay_max(),
+ bit_gen_),
+ checkin_request_ack.retry_window_if_accepted().retry_token()}};
+ return absl::OkStatus();
+}
+
+absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
+GrpcFederatedProtocol::ReceiveEligibilityEvalCheckinResponse(
+ absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) {
+ ServerStreamMessage server_stream_message;
+ FCP_RETURN_IF_ERROR(Receive(&server_stream_message));
+
+ if (!server_stream_message.has_eligibility_eval_checkin_response()) {
+ return absl::UnimplementedError(
+ absl::StrCat("Bad response to EligibilityEvalCheckinRequest; Expected "
+ "EligibilityEvalCheckinResponse but got ",
+ server_stream_message.kind_case(), "."));
+ }
+
+ const EligibilityEvalCheckinResponse& eligibility_checkin_response =
+ server_stream_message.eligibility_eval_checkin_response();
+ switch (eligibility_checkin_response.checkin_result_case()) {
+ case EligibilityEvalCheckinResponse::kEligibilityEvalPayload: {
+ const EligibilityEvalPayload& eligibility_eval_payload =
+ eligibility_checkin_response.eligibility_eval_payload();
+ object_state_ = ObjectState::kEligibilityEvalEnabled;
+ EligibilityEvalTask result{.execution_id =
+ eligibility_eval_payload.execution_id()};
+
+ payload_uris_received_callback(result);
+
+ PlanAndCheckpointPayloads payloads;
+ if (http_client_ == nullptr ||
+ !flags_->enable_grpc_with_eligibility_eval_http_resource_support()) {
+ result.payloads = {
+ .plan = eligibility_eval_payload.plan(),
+ .checkpoint = eligibility_eval_payload.init_checkpoint()};
+ } else {
+ // Fetch the task resources, returning any errors that may be
+ // encountered in the process.
+ FCP_ASSIGN_OR_RETURN(
+ result.payloads,
+ FetchTaskResources(
+ {.plan =
+ {
+ .has_uri =
+ eligibility_eval_payload.has_plan_resource(),
+ .uri = eligibility_eval_payload.plan_resource().uri(),
+ .data = eligibility_eval_payload.plan(),
+ .client_cache_id =
+ eligibility_eval_payload.plan_resource()
+ .client_cache_id(),
+ .max_age = TimeUtil::ConvertProtoToAbslDuration(
+ eligibility_eval_payload.plan_resource()
+ .max_age()),
+ },
+ .checkpoint = {
+ .has_uri = eligibility_eval_payload
+ .has_init_checkpoint_resource(),
+ .uri = eligibility_eval_payload.init_checkpoint_resource()
+ .uri(),
+ .data = eligibility_eval_payload.init_checkpoint(),
+ .client_cache_id =
+ eligibility_eval_payload.init_checkpoint_resource()
+ .client_cache_id(),
+ .max_age = TimeUtil::ConvertProtoToAbslDuration(
+ eligibility_eval_payload.init_checkpoint_resource()
+ .max_age()),
+ }}));
+ }
+ return std::move(result);
+ }
+ case EligibilityEvalCheckinResponse::kNoEligibilityEvalConfigured: {
+ // Nothing to do...
+ object_state_ = ObjectState::kEligibilityEvalDisabled;
+ return EligibilityEvalDisabled{};
+ }
+ case EligibilityEvalCheckinResponse::kRejectionInfo: {
+ object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
+ return Rejection{};
+ }
+ default:
+ return absl::UnimplementedError(
+ "Unrecognized EligibilityEvalCheckinResponse");
+ }
+}
+
+absl::StatusOr<FederatedProtocol::CheckinResult>
+GrpcFederatedProtocol::ReceiveCheckinResponse(
+ absl::Time start_time,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
+ ServerStreamMessage server_stream_message;
+ absl::Status receive_status = Receive(&server_stream_message);
+ FCP_RETURN_IF_ERROR(receive_status);
+
+ if (!server_stream_message.has_checkin_response()) {
+ return absl::UnimplementedError(absl::StrCat(
+ "Bad response to CheckinRequest; Expected CheckinResponse but got ",
+ server_stream_message.kind_case(), "."));
+ }
+
+ const CheckinResponse& checkin_response =
+ server_stream_message.checkin_response();
+
+ execution_phase_id_ =
+ checkin_response.has_acceptance_info()
+ ? checkin_response.acceptance_info().execution_phase_id()
+ : "";
+ switch (checkin_response.checkin_result_case()) {
+ case CheckinResponse::kAcceptanceInfo: {
+ const auto& acceptance_info = checkin_response.acceptance_info();
+
+ for (const auto& [k, v] : acceptance_info.side_channels())
+ side_channels_[k] = v;
+ side_channel_protocol_execution_info_ =
+ acceptance_info.side_channel_protocol_execution_info();
+ side_channel_protocol_options_response_ =
+ checkin_response.protocol_options_response().side_channels();
+
+ std::optional<SecAggInfo> sec_agg_info = std::nullopt;
+ if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
+ sec_agg_info = SecAggInfo{
+ .expected_number_of_clients =
+ side_channel_protocol_execution_info_.secure_aggregation()
+ .expected_number_of_clients(),
+ .minimum_clients_in_server_visible_aggregate =
+ side_channel_protocol_execution_info_.secure_aggregation()
+ .minimum_clients_in_server_visible_aggregate()};
+ }
+
+ TaskAssignment result{
+ .federated_select_uri_template =
+ acceptance_info.federated_select_uri_info().uri_template(),
+ .aggregation_session_id = acceptance_info.execution_phase_id(),
+ .sec_agg_info = sec_agg_info};
+
+ payload_uris_received_callback(result);
+
+ PlanAndCheckpointPayloads payloads;
+ if (http_client_ == nullptr) {
+ result.payloads = {.plan = acceptance_info.plan(),
+ .checkpoint = acceptance_info.init_checkpoint()};
+ } else {
+ // Fetch the task resources, returning any errors that may be
+ // encountered in the process.
+ FCP_ASSIGN_OR_RETURN(
+ result.payloads,
+ FetchTaskResources(
+ {.plan =
+ {
+ .has_uri = acceptance_info.has_plan_resource(),
+ .uri = acceptance_info.plan_resource().uri(),
+ .data = acceptance_info.plan(),
+ .client_cache_id =
+ acceptance_info.plan_resource().client_cache_id(),
+ .max_age = TimeUtil::ConvertProtoToAbslDuration(
+ acceptance_info.plan_resource().max_age()),
+ },
+ .checkpoint = {
+ .has_uri = acceptance_info.has_init_checkpoint_resource(),
+ .uri = acceptance_info.init_checkpoint_resource().uri(),
+ .data = acceptance_info.init_checkpoint(),
+ .client_cache_id =
+ acceptance_info.init_checkpoint_resource()
+ .client_cache_id(),
+ .max_age = TimeUtil::ConvertProtoToAbslDuration(
+ acceptance_info.init_checkpoint_resource().max_age()),
+ }}));
+ }
+
+ object_state_ = ObjectState::kCheckinAccepted;
+ return result;
+ }
+ case CheckinResponse::kRejectionInfo: {
+ object_state_ = ObjectState::kCheckinRejected;
+ return Rejection{};
+ }
+ default:
+ return absl::UnimplementedError("Unrecognized CheckinResponse");
+ }
+}
+
+absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
+GrpcFederatedProtocol::EligibilityEvalCheckin(
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) {
+ FCP_CHECK(object_state_ == ObjectState::kInitialized)
+ << "Invalid call sequence";
+ object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
+
+ absl::Time start_time = absl::Now();
+
+ // Send an EligibilityEvalCheckinRequest.
+ absl::Status request_status = SendEligibilityEvalCheckinRequest();
+ // See note about how we handle 'permanent' errors at the top of this file.
+ UpdateObjectStateIfPermanentError(
+ request_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
+ FCP_RETURN_IF_ERROR(request_status);
+
+ // Receive a CheckinRequestAck.
+ absl::Status ack_status = ReceiveCheckinRequestAck();
+ UpdateObjectStateIfPermanentError(
+ ack_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
+ FCP_RETURN_IF_ERROR(ack_status);
+
+ // Receive + handle an EligibilityEvalCheckinResponse message, and update the
+ // object state based on the received response.
+ auto response = ReceiveEligibilityEvalCheckinResponse(
+ start_time, payload_uris_received_callback);
+ UpdateObjectStateIfPermanentError(
+ response.status(),
+ ObjectState::kEligibilityEvalCheckinFailedPermanentError);
+ return response;
+}
+
+// This is not supported in gRPC federated protocol, we'll do nothing.
+void GrpcFederatedProtocol::ReportEligibilityEvalError(
+ absl::Status error_status) {}
+
+absl::StatusOr<FederatedProtocol::CheckinResult> GrpcFederatedProtocol::Checkin(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
+ // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
+ // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result.
+ FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
+ object_state_ == ObjectState::kEligibilityEvalEnabled)
+ << "Checkin(...) called despite failed/rejected earlier "
+ "EligibilityEvalCheckin";
+ if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
+ FCP_CHECK(task_eligibility_info.has_value())
+ << "Missing TaskEligibilityInfo despite receiving prior "
+ "EligibilityEvalCheckin payload";
+ } else {
+ FCP_CHECK(!task_eligibility_info.has_value())
+ << "Received TaskEligibilityInfo despite not receiving a prior "
+ "EligibilityEvalCheckin payload";
+ }
+
+ object_state_ = ObjectState::kCheckinFailed;
+
+ absl::Time start_time = absl::Now();
+ // Send a CheckinRequest.
+ absl::Status request_status = SendCheckinRequest(task_eligibility_info);
+ // See note about how we handle 'permanent' errors at the top of this file.
+ UpdateObjectStateIfPermanentError(request_status,
+ ObjectState::kCheckinFailedPermanentError);
+ FCP_RETURN_IF_ERROR(request_status);
+
+ // Receive + handle a CheckinResponse message, and update the object state
+ // based on the received response.
+ auto response =
+ ReceiveCheckinResponse(start_time, payload_uris_received_callback);
+ UpdateObjectStateIfPermanentError(response.status(),
+ ObjectState::kCheckinFailedPermanentError);
+ return response;
+}
+
+absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
+GrpcFederatedProtocol::PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) {
+ return absl::UnimplementedError(
+ "PerformMultipleTaskAssignments is not supported by "
+ "GrpcFederatedProtocol.");
+}
+
+absl::Status GrpcFederatedProtocol::ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) {
+ FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
+ FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
+ << "Invalid call sequence";
+ object_state_ = ObjectState::kReportCalled;
+ auto response = Report(std::move(results), engine::COMPLETED, plan_duration);
+ // See note about how we handle 'permanent' errors at the top of this file.
+ UpdateObjectStateIfPermanentError(response,
+ ObjectState::kReportFailedPermanentError);
+ return response;
+}
+
+absl::Status GrpcFederatedProtocol::ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_Id) {
+ FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
+ FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
+ << "Invalid call sequence";
+ object_state_ = ObjectState::kReportCalled;
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", "");
+ auto response = Report(std::move(results), phase_outcome, plan_duration);
+ // See note about how we handle 'permanent' errors at the top of this file.
+ UpdateObjectStateIfPermanentError(response,
+ ObjectState::kReportFailedPermanentError);
+ return response;
+}
+
+class GrpcSecAggSendToServerImpl : public SecAggSendToServerBase {
+ public:
+ GrpcSecAggSendToServerImpl(
+ GrpcBidiStreamInterface* grpc_bidi_stream,
+ const std::function<absl::Status(ClientToServerWrapperMessage*)>&
+ report_func)
+ : grpc_bidi_stream_(grpc_bidi_stream), report_func_(report_func) {}
+ ~GrpcSecAggSendToServerImpl() override = default;
+
+ void Send(ClientToServerWrapperMessage* message) override {
+ // The commit message (MaskedInputRequest) must be piggy-backed onto the
+ // ReportRequest message, the logic for which is encapsulated in
+ // report_func_ so that it may be held in common between both accumulation
+ // methods.
+ if (message->message_content_case() ==
+ ClientToServerWrapperMessage::MessageContentCase::
+ kMaskedInputResponse) {
+ auto status = report_func_(message);
+ if (!status.ok())
+ FCP_LOG(ERROR) << "Could not send ReportRequest: " << status;
+ return;
+ }
+ ClientStreamMessage client_stream_message;
+ client_stream_message.mutable_secure_aggregation_client_message()->Swap(
+ message);
+ auto bytes_to_upload = client_stream_message.ByteSizeLong();
+ auto status = grpc_bidi_stream_->Send(&client_stream_message);
+ if (status.ok()) {
+ last_sent_message_size_ = bytes_to_upload;
+ }
+ }
+
+ private:
+ GrpcBidiStreamInterface* grpc_bidi_stream_;
+ // SecAgg's output must be wrapped in a ReportRequest; because the report
+ // logic is mostly generic, this lambda allows it to be shared between
+ // aggregation types.
+ const std::function<absl::Status(ClientToServerWrapperMessage*)>&
+ report_func_;
+};
+
+class GrpcSecAggProtocolDelegate : public SecAggProtocolDelegate {
+ public:
+ GrpcSecAggProtocolDelegate(
+ absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels,
+ GrpcBidiStreamInterface* grpc_bidi_stream)
+ : side_channels_(std::move(side_channels)),
+ grpc_bidi_stream_(grpc_bidi_stream) {}
+
+ absl::StatusOr<uint64_t> GetModulus(const std::string& key) override {
+ auto execution_info = side_channels_.find(key);
+ if (execution_info == side_channels_.end())
+ return absl::InternalError(
+ absl::StrCat("Execution not found for aggregand: ", key));
+ uint64_t modulus;
+ auto secure_aggregand = execution_info->second.secure_aggregand();
+ // TODO(team): Delete output_bitwidth support once
+ // modulus is fully rolled out.
+ if (secure_aggregand.modulus() > 0) {
+ modulus = secure_aggregand.modulus();
+ } else {
+ // Note: we ignore vector.get_bitwidth() here, because (1)
+ // it is only an upper bound on the *input* bitwidth,
+ // based on the Tensorflow dtype, but (2) we have exact
+ // *output* bitwidth information from the execution_info,
+ // and that is what SecAgg needs.
+ modulus = 1ULL << secure_aggregand.output_bitwidth();
+ }
+ return modulus;
+ }
+
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
+ override {
+ ServerStreamMessage server_stream_message;
+ absl::Status receive_status =
+ grpc_bidi_stream_->Receive(&server_stream_message);
+ if (!receive_status.ok()) {
+ return absl::Status(receive_status.code(),
+ absl::StrCat("Error during SecAgg receive: ",
+ receive_status.message()));
+ }
+ last_received_message_size_ = server_stream_message.ByteSizeLong();
+ if (!server_stream_message.has_secure_aggregation_server_message()) {
+ return absl::InternalError(
+ absl::StrCat("Bad response to SecAgg protocol; Expected "
+ "ServerToClientWrapperMessage but got ",
+ server_stream_message.kind_case(), "."));
+ }
+ return server_stream_message.secure_aggregation_server_message();
+ }
+
+ void Abort() override { grpc_bidi_stream_->Close(); }
+ size_t last_received_message_size() override {
+ return last_received_message_size_;
+ };
+
+ private:
+ absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels_;
+ GrpcBidiStreamInterface* grpc_bidi_stream_;
+ size_t last_received_message_size_;
+};
+
+absl::Status GrpcFederatedProtocol::ReportInternal(
+ std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
+ absl::Duration plan_duration,
+ ClientToServerWrapperMessage* secagg_commit_message) {
+ ClientStreamMessage client_stream_message;
+ auto report_request = client_stream_message.mutable_report_request();
+ report_request->set_population_name(population_name_);
+ report_request->set_execution_phase_id(execution_phase_id_);
+ auto report = report_request->mutable_report();
+
+ // 1. Include TF checkpoint and/or SecAgg commit message.
+ report->set_update_checkpoint(std::move(tf_checkpoint));
+ if (secagg_commit_message) {
+ client_stream_message.mutable_secure_aggregation_client_message()->Swap(
+ secagg_commit_message);
+ }
+
+ // 2. Include outcome of computation.
+ report->set_status_code(phase_outcome == engine::COMPLETED
+ ? google::rpc::OK
+ : google::rpc::INTERNAL);
+
+ // 3. Include client execution statistics, if any.
+ ClientExecutionStats client_execution_stats;
+ client_execution_stats.mutable_duration()->set_seconds(
+ absl::IDivDuration(plan_duration, absl::Seconds(1), &plan_duration));
+ client_execution_stats.mutable_duration()->set_nanos(static_cast<int32_t>(
+ absl::IDivDuration(plan_duration, absl::Nanoseconds(1), &plan_duration)));
+ report->add_serialized_train_event()->PackFrom(client_execution_stats);
+
+ // 4. Send ReportRequest.
+
+ // Note that we do not use the GrpcFederatedProtocol::Send(...) helper method
+ // here, since we are already running within a call to
+ // InterruptibleRunner::Run.
+ const auto status = this->grpc_bidi_stream_->Send(&client_stream_message);
+ if (!status.ok()) {
+ return absl::Status(
+ status.code(),
+ absl::StrCat("Error sending ReportRequest: ", status.message()));
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status GrpcFederatedProtocol::Report(ComputationResults results,
+ engine::PhaseOutcome phase_outcome,
+ absl::Duration plan_duration) {
+ std::string tf_checkpoint;
+ bool has_checkpoint;
+ for (auto& [k, v] : results) {
+ if (std::holds_alternative<TFCheckpoint>(v)) {
+ tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
+ has_checkpoint = true;
+ break;
+ }
+ }
+
+ // This lambda allows for convenient reporting from within SecAgg's
+ // SendToServerInterface::Send().
+ std::function<absl::Status(ClientToServerWrapperMessage*)> report_lambda =
+ [&](ClientToServerWrapperMessage* secagg_commit_message) -> absl::Status {
+ return ReportInternal(std::move(tf_checkpoint), phase_outcome,
+ plan_duration, secagg_commit_message);
+ };
+
+ // Run the Secure Aggregation protocol, if necessary.
+ if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
+ auto secure_aggregation_protocol_execution_info =
+ side_channel_protocol_execution_info_.secure_aggregation();
+ auto expected_number_of_clients =
+ secure_aggregation_protocol_execution_info.expected_number_of_clients();
+
+ FCP_LOG(INFO) << "Reporting via Secure Aggregation";
+ if (phase_outcome != engine::COMPLETED)
+ return absl::InternalError(
+ "Aborting the SecAgg protocol (no update was produced).");
+
+ if (side_channel_protocol_options_response_.secure_aggregation()
+ .client_variant() != secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1) {
+ log_manager_->LogDiag(
+ ProdDiagCode::SECAGG_CLIENT_ERROR_UNSUPPORTED_VERSION);
+ return absl::InternalError(absl::StrCat(
+ "Unsupported SecAgg client variant: ",
+ side_channel_protocol_options_response_.secure_aggregation()
+ .client_variant()));
+ }
+
+ auto send_to_server_impl = std::make_unique<GrpcSecAggSendToServerImpl>(
+ grpc_bidi_stream_.get(), report_lambda);
+ auto secagg_event_publisher = event_publisher_->secagg_event_publisher();
+ FCP_CHECK(secagg_event_publisher)
+ << "An implementation of "
+ << "SecAggEventPublisher must be provided.";
+ auto delegate = std::make_unique<GrpcSecAggProtocolDelegate>(
+ side_channels_, grpc_bidi_stream_.get());
+ std::unique_ptr<SecAggRunner> secagg_runner =
+ secagg_runner_factory_->CreateSecAggRunner(
+ std::move(send_to_server_impl), std::move(delegate),
+ secagg_event_publisher, log_manager_, interruptible_runner_.get(),
+ expected_number_of_clients,
+ secure_aggregation_protocol_execution_info
+ .minimum_surviving_clients_for_reconstruction());
+
+ FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
+ } else {
+ // Report without secure aggregation.
+ FCP_LOG(INFO) << "Reporting via Simple Aggregation";
+ if (results.size() != 1 || !has_checkpoint) {
+ return absl::InternalError(
+ "Simple Aggregation aggregands have unexpected format.");
+ }
+ FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
+ [&report_lambda]() { return report_lambda(nullptr); },
+ [this]() {
+ // What about event_publisher_ and log_manager_?
+ this->grpc_bidi_stream_->Close();
+ }));
+ }
+
+ FCP_LOG(INFO) << "Finished reporting.";
+
+ // Receive ReportResponse.
+ ServerStreamMessage server_stream_message;
+ absl::Status receive_status = Receive(&server_stream_message);
+ if (receive_status.code() == absl::StatusCode::kAborted) {
+ FCP_LOG(INFO) << "Server responded ABORTED.";
+ } else if (receive_status.code() == absl::StatusCode::kCancelled) {
+ FCP_LOG(INFO) << "Upload was cancelled by the client.";
+ }
+ if (!receive_status.ok()) {
+ return absl::Status(
+ receive_status.code(),
+ absl::StrCat("Error after ReportRequest: ", receive_status.message()));
+ }
+ if (!server_stream_message.has_report_response()) {
+ return absl::UnimplementedError(absl::StrCat(
+ "Bad response to ReportRequest; Expected REPORT_RESPONSE but got ",
+ server_stream_message.kind_case(), "."));
+ }
+ return absl::OkStatus();
+}
+
+RetryWindow GrpcFederatedProtocol::GetLatestRetryWindow() {
+ // We explicitly enumerate all possible states here rather than using
+ // "default", to ensure that when new states are added later on, the author
+ // is forced to update this method and consider which is the correct
+ // RetryWindow to return.
+ switch (object_state_) {
+ case ObjectState::kCheckinAccepted:
+ case ObjectState::kReportCalled:
+ // If a client makes it past the 'checkin acceptance' stage, we use the
+ // 'accepted' RetryWindow unconditionally (unless a permanent error is
+ // encountered). This includes cases where the checkin is accepted, but
+ // the report request results in a (transient) error.
+ FCP_CHECK(checkin_request_ack_info_.has_value());
+ return GenerateRetryWindowFromRetryTimeAndToken(
+ checkin_request_ack_info_->retry_info_if_accepted);
+ case ObjectState::kEligibilityEvalCheckinRejected:
+ case ObjectState::kEligibilityEvalDisabled:
+ case ObjectState::kEligibilityEvalEnabled:
+ case ObjectState::kCheckinRejected:
+ FCP_CHECK(checkin_request_ack_info_.has_value());
+ return GenerateRetryWindowFromRetryTimeAndToken(
+ checkin_request_ack_info_->retry_info_if_rejected);
+ case ObjectState::kInitialized:
+ case ObjectState::kEligibilityEvalCheckinFailed:
+ case ObjectState::kCheckinFailed:
+ // If the flag is true, then we use the previously chosen absolute retry
+ // time instead (if available).
+ if (checkin_request_ack_info_.has_value()) {
+ // If we already received a server-provided retry window, then use it.
+ return GenerateRetryWindowFromRetryTimeAndToken(
+ checkin_request_ack_info_->retry_info_if_rejected);
+ }
+ // Otherwise, we generate a retry window using the flag-provided transient
+ // error retry period.
+ return GenerateRetryWindowFromTargetDelay(
+ absl::Seconds(
+ flags_->federated_training_transient_errors_retry_delay_secs()),
+ // NOLINTBEGIN(whitespace/line_length)
+ flags_
+ ->federated_training_transient_errors_retry_delay_jitter_percent(),
+ // NOLINTEND
+ bit_gen_);
+ case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
+ case ObjectState::kCheckinFailedPermanentError:
+ case ObjectState::kReportFailedPermanentError:
+ // If we encountered a permanent error during the eligibility eval or
+ // regular checkins, then we use the Flags-configured 'permanent error'
+ // retry period. Note that we do so regardless of whether the server had,
+ // by the time the permanent error was received, already returned a
+ // CheckinRequestAck containing a set of retry windows. See note on error
+ // handling at the top of this file.
+ return GenerateRetryWindowFromTargetDelay(
+ absl::Seconds(
+ flags_->federated_training_permanent_errors_retry_delay_secs()),
+ // NOLINTBEGIN(whitespace/line_length)
+ flags_
+ ->federated_training_permanent_errors_retry_delay_jitter_percent(),
+ // NOLINTEND
+ bit_gen_);
+ case ObjectState::kMultipleTaskAssignmentsAccepted:
+ case ObjectState::kMultipleTaskAssignmentsFailed:
+ case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
+ case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
+ case ObjectState::kReportMultipleTaskPartialError:
+ FCP_LOG(FATAL) << "Multi-task assignments is not supported by gRPC.";
+ RetryWindow retry_window;
+ return retry_window;
+ }
+}
+
+// Converts the given RetryTimeAndToken to a zero-width RetryWindow (where
+// delay_min and delay_max are set to the same value), by converting the target
+// retry time to a delay relative to the current timestamp.
+RetryWindow GrpcFederatedProtocol::GenerateRetryWindowFromRetryTimeAndToken(
+ const GrpcFederatedProtocol::RetryTimeAndToken& retry_info) {
+ // Generate a RetryWindow with delay_min and delay_max both set to the same
+ // value.
+ RetryWindow retry_window =
+ GenerateRetryWindowFromRetryTime(retry_info.retry_time);
+ retry_window.set_retry_token(retry_info.retry_token);
+ return retry_window;
+}
+
+void GrpcFederatedProtocol::UpdateObjectStateIfPermanentError(
+ absl::Status status,
+ GrpcFederatedProtocol::ObjectState permanent_error_object_state) {
+ if (federated_training_permanent_error_codes_.contains(
+ static_cast<int32_t>(status.code()))) {
+ object_state_ = permanent_error_object_state;
+ }
+}
+
+absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
+GrpcFederatedProtocol::FetchTaskResources(
+ GrpcFederatedProtocol::TaskResources task_resources) {
+ FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
+ ConvertResourceToUriOrInlineData(task_resources.plan));
+ FCP_ASSIGN_OR_RETURN(
+ UriOrInlineData checkpoint_uri_or_data,
+ ConvertResourceToUriOrInlineData(task_resources.checkpoint));
+
+ // Log a diag code if either resource is about to be downloaded via HTTP.
+ if (!plan_uri_or_data.uri().uri.empty() ||
+ !checkpoint_uri_or_data.uri().uri.empty()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP);
+ }
+
+ // Fetch the plan and init checkpoint resources if they need to be fetched
+ // (using the inline data instead if available).
+ absl::StatusOr<
+ std::vector<absl::StatusOr<::fcp::client::http::InMemoryHttpResponse>>>
+ resource_responses;
+ {
+ auto started_stopwatch = network_stopwatch_->Start();
+ resource_responses = ::fcp::client::http::FetchResourcesInMemory(
+ *http_client_, *interruptible_runner_,
+ {plan_uri_or_data, checkpoint_uri_or_data}, &http_bytes_downloaded_,
+ &http_bytes_uploaded_, resource_cache_);
+ }
+ if (!resource_responses.ok()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
+ return resource_responses.status();
+ }
+ auto& plan_data_response = (*resource_responses)[0];
+ auto& checkpoint_data_response = (*resource_responses)[1];
+
+ if (!plan_data_response.ok() || !checkpoint_data_response.ok()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
+ }
+ // Note: we forward any error during the fetching of the plan/checkpoint
+ // resources resources to the caller, which means that these error codes
+ // will be checked against the set of 'permanent' error codes, just like the
+ // errors in response to the protocol request are.
+ if (!plan_data_response.ok()) {
+ return absl::Status(plan_data_response.status().code(),
+ absl::StrCat("plan fetch failed: ",
+ plan_data_response.status().ToString()));
+ }
+ if (!checkpoint_data_response.ok()) {
+ return absl::Status(
+ checkpoint_data_response.status().code(),
+ absl::StrCat("checkpoint fetch failed: ",
+ checkpoint_data_response.status().ToString()));
+ }
+ if (!plan_uri_or_data.uri().uri.empty() ||
+ !checkpoint_uri_or_data.uri().uri.empty()) {
+ // We only want to log this diag code when we actually did fetch something
+ // via HTTP.
+ log_manager_->LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED);
+ }
+
+ return PlanAndCheckpointPayloads{plan_data_response->body,
+ checkpoint_data_response->body};
+}
+
+// Convert a Resource proto into a UriOrInlineData object. Returns an
+// `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
+// an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
+// field set.
+absl::StatusOr<UriOrInlineData>
+GrpcFederatedProtocol::ConvertResourceToUriOrInlineData(
+ const GrpcFederatedProtocol::TaskResource& resource) {
+ // We need to support 3 states:
+ // - Inline data is available.
+ // - No inline data nor is there a URI. This should be treated as there being
+ // an 'empty' inline data.
+ // - No inline data is available but a URI is available.
+ if (!resource.has_uri) {
+ // If the URI field wasn't set, then we'll just use the inline data field
+ // (which will either be set or be empty).
+ //
+ // Note: this copies the data into the new absl::Cord. However, this Cord is
+ // then passed around all the way to fl_runner.cc without copying its data,
+ // so this is ultimately approx. as efficient as the non-HTTP resource code
+ // path where we also make a copy of the protobuf string into a new string
+ // which is then returned.
+ return UriOrInlineData::CreateInlineData(
+ absl::Cord(resource.data),
+ UriOrInlineData::InlineData::CompressionFormat::kUncompressed);
+ }
+ if (resource.uri.empty()) {
+ return absl::InvalidArgumentError(
+ "Resource uri must be non-empty when set");
+ }
+ return UriOrInlineData::CreateUri(resource.uri, resource.client_cache_id,
+ resource.max_age);
+}
+
+NetworkStats GrpcFederatedProtocol::GetNetworkStats() {
+ // Note: the `HttpClient` bandwidth stats are similar to the gRPC protocol's
+ // "chunking layer" stats, in that they reflect as closely as possible the
+ // amount of data sent on the wire.
+ return {.bytes_downloaded = grpc_bidi_stream_->ChunkingLayerBytesReceived() +
+ http_bytes_downloaded_,
+ .bytes_uploaded = grpc_bidi_stream_->ChunkingLayerBytesSent() +
+ http_bytes_uploaded_,
+ .network_duration = network_stopwatch_->GetTotalDuration()};
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/grpc_federated_protocol.h b/fcp/client/grpc_federated_protocol.h
new file mode 100644
index 0000000..c9517f4
--- /dev/null
+++ b/fcp/client/grpc_federated_protocol.h
@@ -0,0 +1,269 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
+#define FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/client/stats.h"
+#include "fcp/protocol/grpc_chunked_bidi_stream.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/secagg/client/secagg_client.h"
+
+namespace fcp {
+namespace client {
+
+// Implements a single session of the gRPC-based Federated Learning protocol.
+class GrpcFederatedProtocol : public ::fcp::client::FederatedProtocol {
+ public:
+ GrpcFederatedProtocol(
+ EventPublisher* event_publisher, LogManager* log_manager,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ const Flags* flags, ::fcp::client::http::HttpClient* http_client,
+ const std::string& federated_service_uri, const std::string& api_key,
+ const std::string& test_cert_path, absl::string_view population_name,
+ absl::string_view retry_token, absl::string_view client_version,
+ absl::string_view attestation_measurement,
+ std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ const int64_t grpc_channel_deadline_seconds,
+ cache::ResourceCache* resource_cache);
+
+ // Test c'tor.
+ GrpcFederatedProtocol(
+ EventPublisher* event_publisher, LogManager* log_manager,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ const Flags* flags, ::fcp::client::http::HttpClient* http_client,
+ std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
+ absl::string_view population_name, absl::string_view retry_token,
+ absl::string_view client_version,
+ absl::string_view attestation_measurement,
+ std::function<bool()> should_abort, absl::BitGen bit_gen,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ cache::ResourceCache* resource_cache);
+
+ ~GrpcFederatedProtocol() override;
+
+ absl::StatusOr<::fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
+ EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) override;
+
+ void ReportEligibilityEvalError(absl::Status error_status) override;
+
+ absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> Checkin(
+ const std::optional<
+ google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback)
+ override;
+
+ absl::StatusOr<::fcp::client::FederatedProtocol::MultipleTaskAssignments>
+ PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) override;
+
+ absl::Status ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) override;
+
+ absl::Status ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) override;
+
+ google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
+ override;
+
+ NetworkStats GetNetworkStats() override;
+
+ private:
+ // Internal implementation of reporting for use by ReportCompleted() and
+ // ReportNotCompleted().
+ absl::Status Report(ComputationResults results,
+ engine::PhaseOutcome phase_outcome,
+ absl::Duration plan_duration);
+ absl::Status ReportInternal(
+ std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
+ absl::Duration plan_duration,
+ fcp::secagg::ClientToServerWrapperMessage* secagg_commit_message);
+
+ // Helper function to send a ClientStreamMessage. If sending did not succeed,
+ // closes the underlying grpc stream. If sending does succeed then it updates
+ // `bytes_uploaded_`.
+ absl::Status Send(google::internal::federatedml::v2::ClientStreamMessage*
+ client_stream_message);
+
+ // Helper function to receive a ServerStreamMessage. If receiving did not
+ // succeed, closes the underlying grpc stream. If receiving does succeed then
+ // it updates `bytes_downloaded_`.
+ absl::Status Receive(google::internal::federatedml::v2::ServerStreamMessage*
+ server_stream_message);
+
+ // Helper function to compose a ProtocolOptionsRequest for eligibility eval or
+ // regular checkin requests.
+ google::internal::federatedml::v2::ProtocolOptionsRequest
+ CreateProtocolOptionsRequest(bool should_ack_checkin) const;
+
+ // Helper function to compose and send an EligibilityEvalCheckinRequest to the
+ // server.
+ absl::Status SendEligibilityEvalCheckinRequest();
+
+ // Helper function to compose and send a CheckinRequest to the server.
+ absl::Status SendCheckinRequest(
+ const std::optional<
+ google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info);
+
+ // Helper to receive + process a CheckinRequestAck message.
+ absl::Status ReceiveCheckinRequestAck();
+
+ // Helper to receive + process an EligibilityEvalCheckinResponse message.
+ absl::StatusOr<EligibilityEvalCheckinResult>
+ ReceiveEligibilityEvalCheckinResponse(
+ absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback);
+
+ // Helper to receive + process a CheckinResponse message.
+ absl::StatusOr<CheckinResult> ReceiveCheckinResponse(
+ absl::Time start_time, std::function<void(const TaskAssignment&)>
+ payload_uris_received_callback);
+
+ // Utility class for holding an absolute retry time and a corresponding retry
+ // token.
+ struct RetryTimeAndToken {
+ absl::Time retry_time;
+ std::string retry_token;
+ };
+ // Helper to generate a RetryWindow from a given time and token.
+ google::internal::federatedml::v2::RetryWindow
+ GenerateRetryWindowFromRetryTimeAndToken(const RetryTimeAndToken& retry_info);
+
+ // Helper that moves to the given object state if the given status represents
+ // a permanent error.
+ void UpdateObjectStateIfPermanentError(
+ absl::Status status, ObjectState permanent_error_object_state);
+
+ // Utility struct to represent resource data coming from the gRPC protocol.
+ // A resource is either represented by a URI from which the data should be
+ // fetched (in which case `has_uri` is true and `uri` should not be empty), or
+ // is available as inline data (in which case `has_uri` is false and `data`
+ // may or may not be empty).
+ struct TaskResource {
+ bool has_uri;
+ const std::string& uri;
+ const std::string& data;
+ // The following fields will be set if the client should attempt to cache
+ // the resource.
+ const std::string& client_cache_id;
+ const absl::Duration max_age;
+ };
+ // Represents the common set of resources a task may have.
+ struct TaskResources {
+ TaskResource plan;
+ TaskResource checkpoint;
+ };
+
+ // Helper function for fetching the checkpoint/plan resources for an
+ // eligibility eval task or regular task. This function will return an error
+ // if either `TaskResource` represents an invalid state (e.g. if `has_uri &&
+ // uri.empty()`).
+ absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources(
+ TaskResources task_resources);
+ // Validates the given `TaskResource` and converts it to a `UriOrInlineData`
+ // object for use with the `FetchResourcesInMemory` utility method.
+ absl::StatusOr<::fcp::client::http::UriOrInlineData>
+ ConvertResourceToUriOrInlineData(const TaskResource& resource);
+
+ ObjectState object_state_;
+ EventPublisher* const event_publisher_;
+ LogManager* const log_manager_;
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_;
+ const Flags* const flags_;
+ ::fcp::client::http::HttpClient* const http_client_;
+ std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream_;
+ std::unique_ptr<InterruptibleRunner> interruptible_runner_;
+ const std::string population_name_;
+ const std::string retry_token_;
+ const std::string client_version_;
+ const std::string attestation_measurement_;
+ std::function<absl::StatusOr<bool>()> should_abort_;
+ absl::BitGen bit_gen_;
+ // The set of canonical error codes that should be treated as 'permanent'
+ // errors.
+ absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_;
+ int64_t http_bytes_downloaded_ = 0;
+ int64_t http_bytes_uploaded_ = 0;
+ std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
+ WallClockStopwatch::Create();
+ // Represents 2 absolute retry timestamps and their corresponding retry
+ // tokens, to use when the device is rejected or accepted. The retry
+ // timestamps will have been generated based on the retry windows specified in
+ // the server's CheckinRequestAck message and the time at which that message
+ // was received.
+ struct CheckinRequestAckInfo {
+ RetryTimeAndToken retry_info_if_rejected;
+ RetryTimeAndToken retry_info_if_accepted;
+ };
+ // Represents the information received via the CheckinRequestAck message.
+ // This field will have an absent value until that message has been received.
+ std::optional<CheckinRequestAckInfo> checkin_request_ack_info_;
+ // The identifier of the task that was received in a CheckinResponse. Note
+ // that this does not refer to the identifier of the eligbility eval task that
+ // may have been received in an EligibilityEvalCheckinResponse.
+ std::string execution_phase_id_;
+ absl::flat_hash_map<
+ std::string, google::internal::federatedml::v2::SideChannelExecutionInfo>
+ side_channels_;
+ google::internal::federatedml::v2::SideChannelProtocolExecutionInfo
+ side_channel_protocol_execution_info_;
+ google::internal::federatedml::v2::SideChannelProtocolOptionsResponse
+ side_channel_protocol_options_response_;
+ // `nullptr` if the feature is disabled.
+ cache::ResourceCache* resource_cache_;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
diff --git a/fcp/client/grpc_federated_protocol_test.cc b/fcp/client/grpc_federated_protocol_test.cc
new file mode 100644
index 0000000..9eb702e
--- /dev/null
+++ b/fcp/client/grpc_federated_protocol_test.cc
@@ -0,0 +1,1771 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/grpc_federated_protocol.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "google/protobuf/text_format.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/cache/test_helpers.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/stats.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/secagg/client/secagg_client.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client {
+namespace {
+
+using ::fcp::EqualsProto;
+using ::fcp::IsCode;
+using ::fcp::client::http::FakeHttpResponse;
+using ::fcp::client::http::HttpRequest;
+using ::fcp::client::http::MockHttpClient;
+using ::fcp::client::http::SimpleHttpRequestMatcher;
+using ::google::internal::federatedml::v2::AcceptanceInfo;
+using ::google::internal::federatedml::v2::CheckinRequest;
+using ::google::internal::federatedml::v2::ClientStreamMessage;
+using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
+using ::google::internal::federatedml::v2::EligibilityEvalPayload;
+using ::google::internal::federatedml::v2::HttpCompressionFormat;
+using ::google::internal::federatedml::v2::ReportResponse;
+using ::google::internal::federatedml::v2::RetryWindow;
+using ::google::internal::federatedml::v2::ServerStreamMessage;
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+using ::google::internal::federatedml::v2::TaskWeight;
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::DoAll;
+using ::testing::DoubleEq;
+using ::testing::DoubleNear;
+using ::testing::Eq;
+using ::testing::Field;
+using ::testing::FieldsAre;
+using ::testing::Ge;
+using ::testing::Gt;
+using ::testing::HasSubstr;
+using ::testing::InSequence;
+using ::testing::IsEmpty;
+using ::testing::Lt;
+using ::testing::MockFunction;
+using ::testing::NiceMock;
+using ::testing::Not;
+using ::testing::Optional;
+using ::testing::Pair;
+using ::testing::Pointee;
+using ::testing::Return;
+using ::testing::SetArgPointee;
+using ::testing::StrictMock;
+using ::testing::UnorderedElementsAre;
+using ::testing::VariantWith;
+
+constexpr char kPopulationName[] = "TEST/POPULATION";
+constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";
+constexpr char kExecutionPhaseId[] = "TEST/POPULATION/TEST_TASK#1234.ab35";
+constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
+constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
+constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
+constexpr char kClientVersion[] = "CLIENT_VERSION";
+constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";
+constexpr int kSecAggExpectedNumberOfClients = 10;
+constexpr int kSecAggMinSurvivingClientsForReconstruction = 8;
+constexpr int kSecAggMinClientsInServerVisibleAggregate = 4;
+
+class MockGrpcBidiStream : public GrpcBidiStreamInterface {
+ public:
+ MOCK_METHOD(absl::Status, Send, (ClientStreamMessage*), (override));
+ MOCK_METHOD(absl::Status, Receive, (ServerStreamMessage*), (override));
+ MOCK_METHOD(void, Close, (), (override));
+ MOCK_METHOD(int64_t, ChunkingLayerBytesSent, (), (override));
+ MOCK_METHOD(int64_t, ChunkingLayerBytesReceived, (), (override));
+};
+
+constexpr int kTransientErrorsRetryPeriodSecs = 10;
+constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
+constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
+constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;
+constexpr int kPermanentErrorsRetryPeriodSecs = 100;
+constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;
+constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
+constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
+
+void ExpectTransientErrorRetryWindow(const RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected transient errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
+ Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+void ExpectPermanentErrorRetryWindow(const RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected permanent errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
+ Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+google::internal::federatedml::v2::RetryWindow GetAcceptedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ // Must not overlap with kTransientErrorsRetryPeriodSecs or
+ // kPermanentErrorsRetryPeriodSecs.
+ retry_window.mutable_delay_min()->set_seconds(200L);
+ retry_window.mutable_delay_max()->set_seconds(299L);
+ *retry_window.mutable_retry_token() = "RETRY_TOKEN_ACCEPTED";
+ return retry_window;
+}
+
+google::internal::federatedml::v2::RetryWindow GetRejectedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ // Must not overlap with kTransientErrorsRetryPeriodSecs or
+ // kPermanentErrorsRetryPeriodSecs.
+ retry_window.mutable_delay_min()->set_seconds(300);
+ retry_window.mutable_delay_max()->set_seconds(399L);
+ *retry_window.mutable_retry_token() = "RETRY_TOKEN_REJECTED";
+ return retry_window;
+}
+
+void ExpectAcceptedRetryWindow(const RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected permanent errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(200), Lt(299L)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+void ExpectRejectedRetryWindow(const RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected permanent errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(300), Lt(399)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+ServerStreamMessage GetFakeCheckinRequestAck(
+ const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
+ const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
+ ServerStreamMessage checkin_request_ack_message;
+ *checkin_request_ack_message.mutable_checkin_request_ack()
+ ->mutable_retry_window_if_accepted() = accepted_retry_window;
+ *checkin_request_ack_message.mutable_checkin_request_ack()
+ ->mutable_retry_window_if_rejected() = rejected_retry_window;
+ return checkin_request_ack_message;
+}
+
+ServerStreamMessage GetFakeEnabledEligibilityCheckinResponse(
+ const std::string& plan, const std::string& init_checkpoint,
+ const std::string& execution_id) {
+ ServerStreamMessage checkin_response_message;
+ EligibilityEvalPayload* eval_payload =
+ checkin_response_message.mutable_eligibility_eval_checkin_response()
+ ->mutable_eligibility_eval_payload();
+ eval_payload->set_plan(plan);
+ eval_payload->set_init_checkpoint(init_checkpoint);
+ eval_payload->set_execution_id(execution_id);
+ return checkin_response_message;
+}
+
+ServerStreamMessage GetFakeDisabledEligibilityCheckinResponse() {
+ ServerStreamMessage checkin_response_message;
+ checkin_response_message.mutable_eligibility_eval_checkin_response()
+ ->mutable_no_eligibility_eval_configured();
+ return checkin_response_message;
+}
+
+ServerStreamMessage GetFakeRejectedEligibilityCheckinResponse() {
+ ServerStreamMessage rejection_response_message;
+ rejection_response_message.mutable_eligibility_eval_checkin_response()
+ ->mutable_rejection_info();
+ return rejection_response_message;
+}
+
+TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
+ TaskEligibilityInfo eligibility_info;
+ TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
+ task_weight->set_task_name("foo");
+ task_weight->set_weight(567.8);
+ return eligibility_info;
+}
+
+ServerStreamMessage GetFakeRejectedCheckinResponse() {
+ ServerStreamMessage rejection_response_message;
+ rejection_response_message.mutable_checkin_response()
+ ->mutable_rejection_info();
+ return rejection_response_message;
+}
+
+ServerStreamMessage GetFakeAcceptedCheckinResponse(
+ const std::string& plan, const std::string& init_checkpoint,
+ const std::string& federated_select_uri_template,
+ const std::string& phase_id, bool use_secure_aggregation) {
+ ServerStreamMessage checkin_response_message;
+ AcceptanceInfo* acceptance_info =
+ checkin_response_message.mutable_checkin_response()
+ ->mutable_acceptance_info();
+ acceptance_info->set_plan(plan);
+ acceptance_info->set_execution_phase_id(phase_id);
+ acceptance_info->set_init_checkpoint(init_checkpoint);
+ acceptance_info->mutable_federated_select_uri_info()->set_uri_template(
+ federated_select_uri_template);
+ if (use_secure_aggregation) {
+ auto sec_agg =
+ acceptance_info->mutable_side_channel_protocol_execution_info()
+ ->mutable_secure_aggregation();
+ sec_agg->set_expected_number_of_clients(kSecAggExpectedNumberOfClients);
+ sec_agg->set_minimum_surviving_clients_for_reconstruction(
+ kSecAggMinSurvivingClientsForReconstruction);
+ sec_agg->set_minimum_clients_in_server_visible_aggregate(
+ kSecAggMinClientsInServerVisibleAggregate);
+ checkin_response_message.mutable_checkin_response()
+ ->mutable_protocol_options_response()
+ ->mutable_side_channels()
+ ->mutable_secure_aggregation()
+ ->set_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
+ }
+ return checkin_response_message;
+}
+
+ServerStreamMessage GetFakeReportResponse() {
+ ServerStreamMessage report_response_message;
+ *report_response_message.mutable_report_response() = ReportResponse();
+ return report_response_message;
+}
+
+ClientStreamMessage GetExpectedEligibilityEvalCheckinRequest(
+ bool enable_http_resource_support = false) {
+ ClientStreamMessage expected_message;
+ EligibilityEvalCheckinRequest* checkin_request =
+ expected_message.mutable_eligibility_eval_checkin_request();
+ checkin_request->set_population_name(kPopulationName);
+ checkin_request->set_client_version(kClientVersion);
+ checkin_request->set_retry_token(kRetryToken);
+ checkin_request->set_attestation_measurement(kAttestationMeasurement);
+ checkin_request->mutable_protocol_options_request()
+ ->mutable_side_channels()
+ ->mutable_secure_aggregation()
+ ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
+ checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
+ true);
+ checkin_request->mutable_protocol_options_request()
+ ->add_supported_http_compression_formats(
+ HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
+
+ if (enable_http_resource_support) {
+ checkin_request->mutable_protocol_options_request()
+ ->set_supports_http_download(true);
+ checkin_request->mutable_protocol_options_request()
+ ->set_supports_eligibility_eval_http_download(true);
+ }
+
+ return expected_message;
+}
+
+// This returns the CheckinRequest gRPC proto we expect each Checkin(...) call
+// to result in.
+ClientStreamMessage GetExpectedCheckinRequest(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info =
+ std::nullopt,
+ bool enable_http_resource_support = false) {
+ ClientStreamMessage expected_message;
+ CheckinRequest* checkin_request = expected_message.mutable_checkin_request();
+ checkin_request->set_population_name(kPopulationName);
+ checkin_request->set_client_version(kClientVersion);
+ checkin_request->set_retry_token(kRetryToken);
+ checkin_request->set_attestation_measurement(kAttestationMeasurement);
+ checkin_request->mutable_protocol_options_request()
+ ->mutable_side_channels()
+ ->mutable_secure_aggregation()
+ ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
+ checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
+ false);
+ checkin_request->mutable_protocol_options_request()
+ ->add_supported_http_compression_formats(
+ HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
+
+ if (enable_http_resource_support) {
+ checkin_request->mutable_protocol_options_request()
+ ->set_supports_http_download(true);
+ checkin_request->mutable_protocol_options_request()
+ ->set_supports_eligibility_eval_http_download(true);
+ }
+
+ if (task_eligibility_info.has_value()) {
+ *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
+ }
+ return expected_message;
+}
+
+class GrpcFederatedProtocolTest
+ // The first parameter indicates whether support for HTTP task resources
+ // should be enabled.
+ : public testing::TestWithParam<bool> {
+ public:
+ GrpcFederatedProtocolTest() {
+ // The gRPC stream should always be closed at the end of all tests.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
+ }
+
+ protected:
+ void SetUp() override {
+ enable_http_resource_support_ = GetParam();
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(mock_flags_,
+ federated_training_transient_errors_retry_delay_secs)
+ .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
+ EXPECT_CALL(mock_flags_,
+ federated_training_transient_errors_retry_delay_jitter_percent)
+ .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
+ EXPECT_CALL(mock_flags_,
+ federated_training_permanent_errors_retry_delay_secs)
+ .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
+ EXPECT_CALL(mock_flags_,
+ federated_training_permanent_errors_retry_delay_jitter_percent)
+ .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
+ EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
+ .WillRepeatedly(Return(std::vector<int32_t>{
+ static_cast<int32_t>(absl::StatusCode::kNotFound),
+ static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
+ static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
+ EXPECT_CALL(mock_flags_,
+ enable_grpc_with_eligibility_eval_http_resource_support)
+ .WillRepeatedly(Return(enable_http_resource_support_));
+
+ // We only initialize federated_protocol_ in this SetUp method, rather than
+ // in the test's constructor, to ensure that we can set mock flag values
+ // before the GrpcFederatedProtocol constructor is called. Using
+ // std::unique_ptr conveniently allows us to assign the field a new value
+ // after construction (which we could not do if the field's type was
+ // GrpcFederatedProtocol, since it doesn't have copy or move constructors).
+ federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
+ &mock_event_publisher_, &mock_log_manager_,
+ absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
+ /*http_client=*/
+ enable_http_resource_support_ ? &mock_http_client_ : nullptr,
+ // We want to inject mocks stored in unique_ptrs to the
+ // class-under-test, hence we transfer ownership via WrapUnique. To
+ // write expectations for the mock, we retain the raw pointer to it,
+ // which will be valid until GrpcFederatedProtocol's d'tor is called.
+ absl::WrapUnique(mock_grpc_bidi_stream_), kPopulationName, kRetryToken,
+ kClientVersion, kAttestationMeasurement,
+ mock_should_abort_.AsStdFunction(), absl::BitGen(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ &mock_resource_cache_);
+ }
+
+ void TearDown() override {
+ fcp::client::http::HttpRequestHandle::SentReceivedBytes
+ sent_received_bytes = mock_http_client_.TotalSentReceivedBytes();
+
+ NetworkStats network_stats = federated_protocol_->GetNetworkStats();
+ EXPECT_THAT(network_stats.bytes_downloaded,
+ Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesReceived() +
+ sent_received_bytes.received_bytes));
+ EXPECT_THAT(network_stats.bytes_uploaded,
+ Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesSent() +
+ sent_received_bytes.sent_bytes));
+ // If any network traffic occurred, we expect to see some time reflected in
+ // the duration (if the flag is on).
+ if (network_stats.bytes_uploaded > 0) {
+ EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
+ }
+ }
+
+ // This function runs a successful
+ // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()) that
+ // results in an eligibility eval payload being returned by the server. This
+ // is a utility function used by Checkin*() tests that depend on a prior,
+ // successful execution of
+ // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()). It
+ // returns a absl::Status, which the caller should verify is OK using
+ // ASSERT_OK.
+ absl::Status RunSuccessfulEligibilityEvalCheckin(
+ bool eligibility_eval_enabled = true,
+ const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
+ const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
+ EXPECT_CALL(
+ *mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
+ enable_http_resource_support_)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ const std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck(
+ accepted_retry_window, rejected_retry_window)),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(
+ eligibility_eval_enabled
+ ? GetFakeEnabledEligibilityCheckinResponse(
+ kPlan, kInitCheckpoint, expected_execution_id)
+ : GetFakeDisabledEligibilityCheckinResponse()),
+ Return(absl::OkStatus())));
+
+ return federated_protocol_
+ ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
+ .status();
+ }
+
+ // This function runs a successful Checkin() that results in acceptance by the
+ // server. This is a utility function used by Report*() tests that depend on a
+ // prior, successful execution of Checkin().
+ // It returns a absl::Status, which the caller should verify is OK using
+ // ASSERT_OK.
+ absl::StatusOr<FederatedProtocol::CheckinResult> RunSuccessfulCheckin(
+ bool use_secure_aggregation,
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info =
+ GetFakeTaskEligibilityInfo()) {
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ task_eligibility_info, enable_http_resource_support_)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ {
+ InSequence seq;
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
+ kPlan, kInitCheckpoint, kFederatedSelectUriTemplate,
+ kExecutionPhaseId, use_secure_aggregation)),
+ Return(absl::OkStatus())))
+ .RetiresOnSaturation();
+ }
+
+ return federated_protocol_->Checkin(
+ task_eligibility_info, mock_task_received_callback_.AsStdFunction());
+ }
+
+ // See note in the constructor for why these are pointers.
+ StrictMock<MockGrpcBidiStream>* mock_grpc_bidi_stream_ =
+ new StrictMock<MockGrpcBidiStream>();
+
+ StrictMock<MockEventPublisher> mock_event_publisher_;
+ NiceMock<MockLogManager> mock_log_manager_;
+ StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
+ new StrictMock<MockSecAggRunnerFactory>();
+ StrictMock<MockSecAggRunner>* mock_secagg_runner_;
+ NiceMock<MockFlags> mock_flags_;
+ StrictMock<MockHttpClient> mock_http_client_;
+ NiceMock<MockFunction<bool()>> mock_should_abort_;
+ StrictMock<cache::MockResourceCache> mock_resource_cache_;
+ NiceMock<MockFunction<void(
+ const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
+ mock_eet_received_callback_;
+ NiceMock<MockFunction<void(
+ const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
+ mock_task_received_callback_;
+
+ // The class under test.
+ std::unique_ptr<GrpcFederatedProtocol> federated_protocol_;
+ bool enable_http_resource_support_;
+};
+
+std::string GenerateTestName(
+ const testing::TestParamInfo<GrpcFederatedProtocolTest::ParamType>& info) {
+ std::string name = info.param ? "Http_resource_support_enabled"
+ : "Http_resource_support_disabled";
+ return name;
+}
+
+INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolTest,
+ testing::Bool(), GenerateTestName);
+
+using GrpcFederatedProtocolDeathTest = GrpcFederatedProtocolTest;
+INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolDeathTest,
+ testing::Bool(), GenerateTestName);
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
+ const RetryWindow& retry_window1 =
+ federated_protocol_->GetLatestRetryWindow();
+ ExpectTransientErrorRetryWindow(retry_window1);
+ federated_protocol_.reset(nullptr);
+
+ mock_grpc_bidi_stream_ = new StrictMock<MockGrpcBidiStream>();
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
+ mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();
+ // Create a new GrpcFederatedProtocol instance. It should not produce the same
+ // retry window value as the one we just got. This is a simple correctness
+ // check to ensure that the value is at least randomly generated (and that we
+ // don't accidentally use the random number generator incorrectly).
+ federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
+ &mock_event_publisher_, &mock_log_manager_,
+ absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
+ /*http_client=*/nullptr, absl::WrapUnique(mock_grpc_bidi_stream_),
+ kPopulationName, kRetryToken, kClientVersion, kAttestationMeasurement,
+ mock_should_abort_.AsStdFunction(), absl::BitGen(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ &mock_resource_cache_);
+
+ const RetryWindow& retry_window2 =
+ federated_protocol_->GetLatestRetryWindow();
+ ExpectTransientErrorRetryWindow(retry_window2);
+
+ EXPECT_THAT(retry_window1, Not(EqualsProto(retry_window2)));
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligibilityEvalCheckinSendFailsTransientError) {
+ // Make the gRPC stream return an UNAVAILABLE error when the
+ // EligibilityEvalCheckin(...) code tries to send its first message. This
+ // should result in the error being returned as the result.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::UnavailableError("foo")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), "foo");
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the transient errors retry delay flag.
+ ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligibilityEvalCheckinSendFailsPermanentError) {
+ // Make the gRPC stream return an NOT_FOUND error when the
+ // EligibilityEvalCheckin(...) code tries to send its first message. This
+ // should result in the error being returned as the result.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::NotFoundError("foo")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), "foo");
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the *permanent* errors retry delay flag,
+ // since NOT_FOUND is marked as a permanent error in the flags.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where the blocking Send() call in EligibilityEvalCheckin is
+// interrupted.
+TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinSendInterrupted) {
+ absl::BlockingCounter counter_should_abort(1);
+
+ // Make Send() block until the counter is decremented.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
+ counter_should_abort.Wait();
+ return absl::OkStatus();
+ });
+ // Make should_abort return false for the first two calls, and then make it
+ // decrement the counter and return true, triggering an abort sequence and
+ // unblocking the Send() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call())
+ .WillOnce(Return(false))
+ .WillOnce(Return(false))
+ .WillRepeatedly([&counter_should_abort] {
+ counter_should_abort.DecrementCount();
+ return true;
+ });
+ // In addition to the Close() call we expect in the test fixture above, expect
+ // an additional one (the one that induced the abort).
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the transient errors retry delay flag.
+ ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// If a CheckinRequestAck is requested in the ProtocolOptionsRequest but not
+// received, UNIMPLEMENTED should be returned.
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligibilityEvalCheckinMissingCheckinRequestAck) {
+ // We immediately return an EligibilityEvalCheckinResponse, rather than
+ // returning a CheckinRequestAck first.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeRejectedEligibilityCheckinResponse()),
+ Return(absl::OkStatus())));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD)); // NOLINT
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNIMPLEMENTED));
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the *permanent* errors retry delay flag,
+ // since UNIMPLEMENTED is marked as a permanent error in the flags.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligibilityEvalCheckinWaitForCheckinRequestAckFails) {
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ // Make the very first Receive() call fail (i.e. the one expecting the
+ // CheckinRequestAck).
+ std::string expected_message = "foo";
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(Return(absl::AbortedError(expected_message)));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(ABORTED));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), expected_message);
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the transient errors retry delay flag.
+ ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligibilityEvalCheckinWaitForCheckinResponseFails) {
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ // Failed checkins that have received an ack already should return the
+ // rejected retry window.
+ std::string expected_message = "foo";
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ // Make the second Receive() call fail (i.e. the one expecting the
+ // EligibilityEvalCheckinResponse).
+ .WillOnce(Return(absl::AbortedError(expected_message)));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(ABORTED));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), expected_message);
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeRejectedEligibilityCheckinResponse()),
+ Return(absl::OkStatus())));
+
+ // The 'eet received' callback should not be invoked since no EET was given to
+ // the client.
+ EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeDisabledEligibilityCheckinResponse()),
+ Return(absl::OkStatus())));
+
+ // The 'eet received' callback should not be invoked since no EET was given to
+ // the client.
+ EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
+ // Note that in this particular test we check that the eligibility eval
+ // checkin request is as expected (in all prior tests we just use the '_'
+ // matcher, because the request isn't really relevant to the test).
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
+ enable_http_resource_support_)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ // The EligibilityEvalCheckin(...) method should return the rejected
+ // RetryWindow, since after merely completing an eligibility eval checkin the
+ // client hasn't actually been accepted to a specific task yet.
+ std::string expected_plan = kPlan;
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeEnabledEligibilityCheckinResponse(
+ expected_plan, expected_checkpoint, expected_execution_id)),
+ Return(absl::OkStatus())));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
+
+ // The 'EET received' callback should be called, even if the task resource
+ // data was available inline.
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
+ Eq(std::nullopt))));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ // If HTTP support is enabled then the checkpoint data gets returned in the
+ // shape of an absl::Cord (rather than an std::string), regardless of
+ // whether it was actually downloaded via HTTP.
+ if (enable_http_resource_support_) {
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(
+ FieldsAre(FieldsAre(absl::Cord(expected_plan),
+ absl::Cord(expected_checkpoint)),
+ expected_execution_id, Eq(std::nullopt))));
+ } else {
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(
+ FieldsAre(FieldsAre(expected_plan, expected_checkpoint),
+ expected_execution_id, Eq(std::nullopt))));
+ }
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligiblityEvalCheckinEnabledWithHttpResourcesDownloaded) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
+ ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
+ /*plan=*/"", /*init_checkpoint=*/"", expected_execution_id);
+ EligibilityEvalPayload* eligibility_eval_payload =
+ fake_response.mutable_eligibility_eval_checkin_response()
+ ->mutable_eligibility_eval_payload();
+ eligibility_eval_payload->mutable_plan_resource()->set_uri(plan_uri);
+ eligibility_eval_payload->mutable_init_checkpoint_resource()->set_uri(
+ checkpoint_uri);
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
+
+ {
+ InSequence seq;
+ // The 'EET received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
+ Eq(std::nullopt))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, {}, expected_plan)));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, {}, expected_checkpoint)));
+ }
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
+ }
+
+ // Issue the Eligibility Eval checkin.
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(
+ *eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
+ FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
+ expected_execution_id, Eq(std::nullopt))));
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligiblityEvalCheckinEnabledWithHttpResourcesPlanDataFetchFailed) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
+ ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
+ /*plan=*/"", expected_checkpoint, expected_execution_id);
+ fake_response.mutable_eligibility_eval_checkin_response()
+ ->mutable_eligibility_eval_payload()
+ ->mutable_plan_resource()
+ ->set_uri(plan_uri);
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(404, {}, "")));
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
+ }
+
+ // Issue the eligibility eval checkin.
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("plan fetch failed"));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
+ // The EligibilityEvalCheckin call is expected to return the permanent error
+ // retry window, since 404 maps to a permanent error.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestEligiblityEvalCheckinEnabledWithHttpResourcesCheckpointFetchFailed) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
+ ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
+ expected_plan, /*init_checkpoint=*/"", expected_execution_id);
+ fake_response.mutable_eligibility_eval_checkin_response()
+ ->mutable_eligibility_eval_payload()
+ ->mutable_init_checkpoint_resource()
+ ->set_uri(checkpoint_uri);
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
+ Return(absl::OkStatus())))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(503, {}, "")));
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
+ }
+
+ // Issue the eligibility eval checkin.
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("checkpoint fetch failed"));
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
+ // The EligibilityEvalCheckin call is expected to return the rejected error
+ // retry window, since 503 maps to a transient error.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests that the protocol correctly sanitizes any invalid values it may have
+// received from the server.
+TEST_P(GrpcFederatedProtocolTest,
+ TestNegativeMinMaxRetryDelayValueSanitization) {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(-1);
+ retry_window.mutable_delay_max()->set_seconds(-2);
+
+ // The above retry window's negative min/max values should be clamped to 0.
+ google::internal::federatedml::v2::RetryWindow expected_retry_window;
+ expected_retry_window.mutable_delay_min()->set_seconds(0);
+ expected_retry_window.mutable_delay_max()->set_seconds(0);
+
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
+ /* eligibility_eval_enabled=*/true, retry_window, retry_window));
+ const RetryWindow& actual_retry_window =
+ federated_protocol_->GetLatestRetryWindow();
+ // The above retry window's invalid max value should be clamped to the min
+ // value (minus some errors introduced by the inaccuracy of double
+ // multiplication).
+ EXPECT_THAT(actual_retry_window.delay_min().seconds() +
+ actual_retry_window.delay_min().nanos() / 1000000000.0,
+ DoubleEq(0));
+ EXPECT_THAT(actual_retry_window.delay_max().seconds() +
+ actual_retry_window.delay_max().nanos() / 1000000000.0,
+ DoubleEq(0));
+}
+
+// Tests that the protocol correctly sanitizes any invalid values it may have
+// received from the server.
+TEST_P(GrpcFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(1234);
+ retry_window.mutable_delay_max()->set_seconds(1233); // less than delay_min
+
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
+ /* eligibility_eval_enabled=*/true, retry_window, retry_window));
+ const RetryWindow& actual_retry_window =
+ federated_protocol_->GetLatestRetryWindow();
+ // The above retry window's invalid max value should be clamped to the min
+ // value (minus some errors introduced by the inaccuracy of double
+ // multiplication). Note that DoubleEq enforces too precise of bounds, so we
+ // use DoubleNear instead.
+ EXPECT_THAT(actual_retry_window.delay_min().seconds() +
+ actual_retry_window.delay_min().nanos() / 1000000000.0,
+ DoubleNear(1234.0, 0.02));
+ EXPECT_THAT(actual_retry_window.delay_max().seconds() +
+ actual_retry_window.delay_max().nanos() / 1000000000.0,
+ DoubleNear(1234.0, 0.02));
+}
+
+TEST_P(GrpcFederatedProtocolDeathTest, TestCheckinMissingTaskEligibilityInfo) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // A Checkin(...) request with a missing TaskEligibilityInfo should now fail,
+ // as the protocol requires us to provide one based on the plan includes in
+ // the eligibility eval checkin response payload.
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsTransientError) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Make the gRPC stream return an UNAVAILABLE error when the Checkin(...) code
+ // tries to send its first message. This should result in the error being
+ // returned as the result.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::UnavailableError("foo")));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(checkin_result.status().message(), "foo");
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsPermanentError) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Make the gRPC stream return an NOT_FOUND error when the Checkin(...) code
+ // tries to send its first message. This should result in the error being
+ // returned as the result.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::NotFoundError("foo")));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(checkin_result.status().message(), "foo");
+ // Even though RetryWindows were already received from the server during the
+ // eligibility eval checkin, we expect a RetryWindow generated based on the
+ // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
+ // permanent error in the flags, and permanent errors should always result in
+ // permanent error windows (regardless of whether a CheckinRequestAck was
+ // already received).
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where the blocking Send() call in Checkin is interrupted.
+TEST_P(GrpcFederatedProtocolTest, TestCheckinSendInterrupted) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ absl::BlockingCounter counter_should_abort(1);
+
+ // Make Send() block until the counter is decremented.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
+ counter_should_abort.Wait();
+ return absl::OkStatus();
+ });
+ // Make should_abort return false for the first two calls, and then make it
+ // decrement the counter and return true, triggering an abort sequence and
+ // unblocking the Send() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call())
+ .WillOnce(Return(false))
+ .WillOnce(Return(false))
+ .WillRepeatedly([&counter_should_abort] {
+ counter_should_abort.DecrementCount();
+ return true;
+ });
+ // In addition to the Close() call we expect in the test fixture above, expect
+ // an additional one (the one that induced the abort).
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Expect a checkin request for the next call to Checkin(...).
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeRejectedCheckinResponse()),
+ Return(absl::OkStatus())));
+
+ // The 'task received' callback should not be invoked since no task was given
+ // to the client.
+ EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests whether we can issue a Checkin() request correctly without passing a
+// TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
+// return any eligibility eval task to run.
+TEST_P(GrpcFederatedProtocolTest,
+ TestCheckinRejectionWithoutTaskEligibilityInfo) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(
+ RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));
+
+ // Expect a checkin request for the next call to Checkin(...).
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ /*task_eligibility_info=*/std::nullopt,
+ enable_http_resource_support_)))))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeRejectedCheckinResponse()),
+ Return(absl::OkStatus())));
+
+ // The 'task received' callback should not be invoked since no task was given
+ // to the client.
+ EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
+
+ // Issue the regular checkin, without a TaskEligibilityInfo (since we didn't
+ // receive an eligibility eval task to run during eligibility eval checkin).
+ auto checkin_result = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestCheckinAccept) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Once the eligibility eval checkin has succeeded, let's fake some network
+ // stats data so that we can verify that it is logged correctly.
+ int64_t chunking_layer_bytes_downloaded = 555;
+ int64_t chunking_layer_bytes_uploaded = 666;
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
+ .WillRepeatedly(Return(chunking_layer_bytes_downloaded));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
+ .WillRepeatedly(Return(chunking_layer_bytes_uploaded));
+
+ // Note that in this particular test we check that the CheckinRequest is as
+ // expected (in all prior tests we just use the '_' matcher, because the
+ // request isn't really relevant to the test).
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ expected_eligibility_info, enable_http_resource_support_)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string expected_checkpoint = kInitCheckpoint;
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
+ expected_plan, expected_checkpoint,
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ /* use_secure_aggregation=*/true)),
+ Return(absl::OkStatus())));
+
+ // The 'task received' callback should be called even when the resources were
+ // available inline.
+ EXPECT_CALL(
+ mock_task_received_callback_,
+ Call(FieldsAre(
+ FieldsAre("", ""), kFederatedSelectUriTemplate, kExecutionPhaseId,
+ Optional(AllOf(
+ Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
+ kSecAggExpectedNumberOfClients),
+ Field(&FederatedProtocol::SecAggInfo::
+ minimum_clients_in_server_visible_aggregate,
+ kSecAggMinClientsInServerVisibleAggregate))))));
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ // If HTTP support is enabled then the checkpoint data gets returned in the
+ // shape of an absl::Cord (rather than an std::string), regardless of whether
+ // it was actually downloaded via HTTP.
+ if (enable_http_resource_support_) {
+ EXPECT_THAT(
+ *checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(absl::Cord(expected_plan),
+ absl::Cord(expected_checkpoint)),
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ Optional(AllOf(
+ Field(
+ &FederatedProtocol::SecAggInfo::expected_number_of_clients,
+ kSecAggExpectedNumberOfClients),
+ Field(&FederatedProtocol::SecAggInfo::
+ minimum_clients_in_server_visible_aggregate,
+ kSecAggMinClientsInServerVisibleAggregate))))));
+ } else {
+ EXPECT_THAT(
+ *checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(expected_plan, expected_checkpoint),
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ Optional(AllOf(
+ Field(
+ &FederatedProtocol::SecAggInfo::expected_number_of_clients,
+ kSecAggExpectedNumberOfClients),
+ Field(&FederatedProtocol::SecAggInfo::
+ minimum_clients_in_server_visible_aggregate,
+ kSecAggMinClientsInServerVisibleAggregate))))));
+ }
+ // The Checkin call is expected to return the accepted retry window from the
+ // CheckinRequestAck response to the first eligibility eval request.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestCheckinAcceptWithHttpResourcesDownloaded) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Once the eligibility eval checkin has succeeded, let's fake some network
+ // stats data so that we can verify that it is logged correctly.
+ int64_t chunking_layer_bytes_downloaded = 555;
+ int64_t chunking_layer_bytes_uploaded = 666;
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
+ .WillRepeatedly(Return(chunking_layer_bytes_downloaded));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
+ .WillRepeatedly(Return(chunking_layer_bytes_uploaded));
+
+ // Note that in this particular test we check that the CheckinRequest is as
+ // expected (in all prior tests we just use the '_' matcher, because the
+ // request isn't really relevant to the test).
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(
+ *mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ expected_eligibility_info, /*enable_http_resource_support=*/true)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
+ /*plan=*/"", /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
+ kExecutionPhaseId,
+ /* use_secure_aggregation=*/true);
+ AcceptanceInfo* acceptance_info =
+ fake_checkin_response.mutable_checkin_response()
+ ->mutable_acceptance_info();
+ acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
+ acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
+ Return(absl::OkStatus())));
+
+ {
+ InSequence seq;
+ // The 'task received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(
+ mock_task_received_callback_,
+ Call(FieldsAre(
+ FieldsAre("", ""), kFederatedSelectUriTemplate, kExecutionPhaseId,
+ Optional(AllOf(
+ Field(
+ &FederatedProtocol::SecAggInfo::expected_number_of_clients,
+ kSecAggExpectedNumberOfClients),
+ Field(&FederatedProtocol::SecAggInfo::
+ minimum_clients_in_server_visible_aggregate,
+ kSecAggMinClientsInServerVisibleAggregate))))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, {}, expected_plan)));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, {}, expected_checkpoint)));
+ }
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
+ }
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(
+ *checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ Optional(AllOf(
+ Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
+ kSecAggExpectedNumberOfClients),
+ Field(&FederatedProtocol::SecAggInfo::
+ minimum_clients_in_server_visible_aggregate,
+ kSecAggMinClientsInServerVisibleAggregate))))));
+ // The Checkin call is expected to return the accepted retry window from the
+ // CheckinRequestAck response to the first eligibility eval request.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestCheckinAcceptWithHttpResourcePlanDataFetchFailed) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Note that in this particular test we check that the CheckinRequest is as
+ // expected (in all prior tests we just use the '_' matcher, because the
+ // request isn't really relevant to the test).
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(
+ *mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ expected_eligibility_info, /*enable_http_resource_support=*/true)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ std::string expected_checkpoint = kInitCheckpoint;
+ ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
+ /*plan=*/"", expected_checkpoint, kFederatedSelectUriTemplate,
+ kExecutionPhaseId,
+ /* use_secure_aggregation=*/true);
+ AcceptanceInfo* acceptance_info =
+ fake_checkin_response.mutable_checkin_response()
+ ->mutable_acceptance_info();
+ acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
+ Return(absl::OkStatus())));
+
+ // Mock a failed plan fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(404, {}, "")));
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
+ }
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(checkin_result.status().message(),
+ HasSubstr("plan fetch failed"));
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
+ // The Checkin call is expected to return the permanent error retry window,
+ // since 404 maps to a permanent error.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest,
+ TestCheckinAcceptWithHttpResourceCheckpointDataFetchFailed) {
+ if (!enable_http_resource_support_) {
+ GTEST_SKIP() << "This test only applies the HTTP task resources feature "
+ "is enabled";
+ return;
+ }
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+
+ // Note that in this particular test we check that the CheckinRequest is as
+ // expected (in all prior tests we just use the '_' matcher, because the
+ // request isn't really relevant to the test).
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(
+ *mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
+ expected_eligibility_info, /*enable_http_resource_support=*/true)))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ std::string expected_plan = kPlan;
+ std::string expected_checkpoint = kInitCheckpoint;
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
+ expected_plan, /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
+ kExecutionPhaseId,
+ /* use_secure_aggregation=*/true);
+ AcceptanceInfo* acceptance_info =
+ fake_checkin_response.mutable_checkin_response()
+ ->mutable_acceptance_info();
+ acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
+ Return(absl::OkStatus())));
+
+ // Mock a failed checkpoint fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(503, {}, "")));
+
+ {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
+ }
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(checkin_result.status().message(),
+ HasSubstr("checkpoint fetch failed"));
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
+ // The Checkin call is expected to return the rejected retry window from the
+ // response to the first eligibility eval request.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestCheckinAcceptNonSecAgg) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin
+ // returning a task that doesn't use SecAgg.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ auto checkin_result = RunSuccessfulCheckin(/*use_secure_aggregation=*/false);
+ ASSERT_OK(checkin_result.status());
+ // If HTTP support is enabled then the checkpoint data gets returned in the
+ // shape of an absl::Cord (rather than an std::string), regardless of whether
+ // it was actually downloaded via HTTP.
+ if (enable_http_resource_support_) {
+ EXPECT_THAT(*checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(absl::Cord(kPlan), absl::Cord(kInitCheckpoint)),
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ // There should be no SecAggInfo in the result.
+ Eq(std::nullopt))));
+ } else {
+ EXPECT_THAT(*checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(
+ FieldsAre(FieldsAre(kPlan, kInitCheckpoint),
+ kFederatedSelectUriTemplate, kExecutionPhaseId,
+ // There should be no SecAggInfo in the result.
+ Eq(std::nullopt))));
+ }
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAgg) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));
+ // Create a SecAgg like Checkpoint - a combination of a TF checkpoint, and
+ // one or more SecAgg quantized aggregands.
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", "");
+ results.emplace("some_tensor", QuantizedTensor());
+
+ mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
+ EXPECT_CALL(*mock_secagg_runner_factory_,
+ CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
+ kSecAggMinSurvivingClientsForReconstruction))
+ .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));
+ EXPECT_CALL(
+ *mock_secagg_runner_,
+ Run(UnorderedElementsAre(
+ Pair("tensorflow_checkpoint", VariantWith<TFCheckpoint>(IsEmpty())),
+ Pair("some_tensor", VariantWith<QuantizedTensor>(
+ FieldsAre(IsEmpty(), 0, IsEmpty()))))))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeReportResponse()),
+ Return(absl::OkStatus())));
+ EXPECT_OK(federated_protocol_->ReportCompleted(
+ std::move(results), absl::ZeroDuration(), std::nullopt));
+}
+
+TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAggWithoutTFCheckpoint) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));
+
+ ComputationResults results;
+ results.emplace("some_tensor", QuantizedTensor());
+
+ mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
+ EXPECT_CALL(*mock_secagg_runner_factory_,
+ CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
+ kSecAggMinSurvivingClientsForReconstruction))
+ .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));
+ EXPECT_CALL(*mock_secagg_runner_,
+ Run(UnorderedElementsAre(
+ Pair("some_tensor", VariantWith<QuantizedTensor>(FieldsAre(
+ IsEmpty(), 0, IsEmpty()))))))
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(GetFakeReportResponse()),
+ Return(absl::OkStatus())));
+ EXPECT_OK(federated_protocol_->ReportCompleted(
+ std::move(results), absl::ZeroDuration(), std::nullopt));
+}
+
+// This function tests the Report(...) method's Send code path, ensuring the
+// right events are logged / and the right data is transmitted to the server.
+TEST_P(GrpcFederatedProtocolTest, TestReportSendFails) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
+
+ // 1. Create input for the Report function.
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+
+ absl::Duration plan_duration = absl::Milliseconds(1337);
+
+ // 2. The expected message sent to the server by the ReportCompleted()
+ // function, as text proto.
+ ClientStreamMessage expected_client_stream_message;
+ ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+ absl::StrCat(
+ "report_request {", " population_name: \"", kPopulationName, "\"",
+ " execution_phase_id: \"", kExecutionPhaseId, "\"", " report {",
+ " update_checkpoint: \"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\"",
+ " serialized_train_event {", "[type.googleapis.com/",
+ "google.internal.federatedml.v2.ClientExecutionStats] {",
+ " duration { seconds: 1 nanos: 337000000 }", " }",
+ " }", " }", "}"),
+ &expected_client_stream_message));
+
+ // 3. Set up mocks.
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(expected_client_stream_message))))
+ .WillOnce(Return(absl::AbortedError("foo")));
+
+ // 4. Test that ReportCompleted() sends the expected message.
+ auto report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ EXPECT_THAT(report_result, IsCode(ABORTED));
+ EXPECT_THAT(report_result.message(), HasSubstr("foo"));
+
+ // If we made it to the Report protocol phase, then the client must've been
+ // accepted during the Checkin phase first, and so we should receive the
+ // "accepted" RetryWindow.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// This function tests the happy path of ReportCompleted() - results get
+// reported, server replies with a RetryWindow.
+TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccess) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
+
+ // 1. Create input for the Report function.
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", "");
+
+ // 2. Set up mocks.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+ ServerStreamMessage response_message;
+ response_message.mutable_report_response();
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
+
+ // 3. Test that ReportCompleted() sends the expected message.
+ auto report_result = federated_protocol_->ReportCompleted(
+ std::move(results), absl::ZeroDuration(), std::nullopt);
+ EXPECT_OK(report_result);
+
+ // If we made it to the Report protocol phase, then the client must've been
+ // accepted during the Checkin phase first, and so we should receive the
+ // "accepted" RetryWindow.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// This function tests the Send code path when PhaseOutcome indicates an
+// error. / In that case, no checkpoint, and only the duration stat, should be
+// uploaded.
+TEST_P(GrpcFederatedProtocolTest, TestPublishReportNotCompleteSendFails) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
+
+ // 1. Create input for the Report function.
+ absl::Duration plan_duration = absl::Milliseconds(1337);
+
+ // 2. The expected message sent to the server by the ReportNotCompleted()
+ // function, as text proto.
+ ClientStreamMessage expected_client_stream_message;
+ ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+ absl::StrCat("report_request {", " population_name: \"", kPopulationName,
+ "\"", " execution_phase_id: \"", kExecutionPhaseId, "\"",
+ " report {", " serialized_train_event {",
+ "[type.googleapis.com/",
+ "google.internal.federatedml.v2.ClientExecutionStats] {",
+ " duration { seconds: 1 nanos: 337000000 }",
+ " }", " }", " status_code: INTERNAL", " }", "}"),
+ &expected_client_stream_message));
+
+ // 3. Set up mocks.
+ EXPECT_CALL(*mock_grpc_bidi_stream_,
+ Send(Pointee(EqualsProto(expected_client_stream_message))))
+ .WillOnce(Return(absl::AbortedError("foo")));
+
+ // 4. Test that ReportNotCompleted() sends the expected message.
+ auto report_result = federated_protocol_->ReportNotCompleted(
+ engine::PhaseOutcome::ERROR, plan_duration, std::nullopt);
+ EXPECT_THAT(report_result, IsCode(ABORTED));
+ EXPECT_THAT(report_result.message(), HasSubstr("foo"));
+
+ // If we made it to the Report protocol phase, then the client must've been
+ // accepted during the Checkin phase first, and so we should receive the
+ // "accepted" RetryWindow.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// This function tests the happy path of ReportCompleted() - results get
+// reported, server replies with a RetryWindow.
+TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccessCommitsToOpstats) {
+ // Issue an eligibility eval checkin first, followed by a successful checkin.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
+
+ // 1. Create input for the Report function.
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", "");
+
+ // 2. Set up mocks.
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
+ .WillOnce(Return(absl::OkStatus()));
+ ServerStreamMessage response_message;
+ response_message.mutable_report_response();
+ EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
+ .WillOnce(
+ DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
+
+ // 3. Test that ReportCompleted() sends the expected message.
+ auto report_result = federated_protocol_->ReportCompleted(
+ std::move(results), absl::ZeroDuration(), std::nullopt);
+ EXPECT_OK(report_result);
+
+ // If we made it to the Report protocol phase, then the client must've been
+ // accepted during the Checkin phase first, and so we should receive the
+ // "accepted" RetryWindow.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+} // anonymous namespace
+} // namespace fcp::client
diff --git a/fcp/client/histogram_counters.proto b/fcp/client/histogram_counters.proto
new file mode 100644
index 0000000..6c26753
--- /dev/null
+++ b/fcp/client/histogram_counters.proto
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package fcp.client;
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+/**
+ * Enumerations of timer and counter identifiers.
+ *
+ * For monitoring, certain timers and counters are logged as integer histograms.
+ * This allows for computing aggregate histograms on the cloud and determine
+ * distributions of latencies for blocks of code, resource usage etc.
+ */
+enum HistogramCounters {
+ HISTOGRAM_COUNTER_UNDEFINED = 0;
+
+ /**
+ * How long it takes to run a plan on device, excluding downloading the plan
+ * and reporting results.
+ */
+ TRAINING_RUN_PHASE_LATENCY = 1;
+
+ /**
+ * The end time of running training for a whole plan, excluding downloading
+ * the plan and reporting results, relative to the start of the training
+ * session.
+ */
+ TRAINING_RUN_PHASE_END_TIME = 2;
+
+ /** How long running a "restore state op" takes. */
+ TRAINING_RESTORE_STATE_LATENCY = 3;
+
+ /**
+ * How long it takes to run training for a whole client execution (which may
+ * involve running multiple epochs). This includes connecting and fetching
+ * example from the example store, as well as training over them.
+ */
+ TRAINING_RUN_CLIENT_EXECUTION_LATENCY = 4;
+
+ /** How long running an "init op" takes. */
+ TRAINING_INIT_OP_LATENCY = 5;
+
+ /** How long running a "before op" takes. */
+ TRAINING_BEFORE_OP_LATENCY = 6;
+
+ /** How long running an "after op" takes. */
+ TRAINING_AFTER_OP_LATENCY = 7;
+
+ /**
+ * How long it takes to run training for a whole epoch. This includes
+ * connecting and fetching example from the example store, as well as training
+ * over them.
+ */
+ TRAINING_RUN_EPOCH_LATENCY = 8;
+
+ /**
+ * How long it takes to gather enough examples for a mini batch.
+ * This counter may be an average across minibatches and epochs.
+ */
+ TRAINING_GATHER_MINI_BATCH_LATENCY = 9;
+
+ /**
+ * How long it takes to run training on a mini batch.
+ * This counter may be an average across minibatches and epochs.
+ */
+ TRAINING_RUN_MINI_BATCH_LATENCY = 10;
+
+ /**
+ * How long it takes the TensorFlow session to terminate after it's been
+ * interrupted.
+ */
+ TRAINING_INTERRUPT_TERMINATION_LATENCY = 11;
+
+ /** How long it takes to commit the opstats message to the database. */
+ TRAINING_OPSTATS_COMMIT_LATENCY = 12;
+
+ /** The number of examples encountered during overall training, across all
+ * client executions. */
+ TRAINING_OVERALL_EXAMPLE_COUNT = 100001;
+
+ /**
+ * The sum of the size (in bytes) of all the examples encountered during
+ * overall training, across all client executions.
+ */
+ TRAINING_OVERALL_EXAMPLE_SIZE = 100002;
+
+ /**
+ * The number of examples encountered in a client execution, across all
+ * epochs.
+ */
+ TRAINING_CLIENT_EXECUTION_EXAMPLE_COUNT = 100003;
+
+ /**
+ * The sum of the size (in bytes) of all the examples encountered in a client
+ * execution, across all epoch.
+ */
+ TRAINING_CLIENT_EXECUTION_EXAMPLE_SIZE = 100004;
+
+ /**
+ * The number of examples encountered in an epoch.
+ * This counter may be an average from multiple epochs.
+ */
+ TRAINING_EPOCH_EXAMPLE_COUNT = 100005;
+
+ /**
+ * The sum of the size (in bytes) of all the examples encountered in an
+ * epoch. This counter may be an average from multiple epochs
+ */
+ TRAINING_EPOCH_EXAMPLE_SIZE = 100006;
+
+ /**
+ * The number of examples in a mini batch.
+ * This counter may be an average from multiple minibatches.
+ */
+ TRAINING_MINI_BATCH_EXAMPLE_COUNT = 100007;
+
+ /**
+ * The sum of the size (in bytes) of all the examples in a mini batch.
+ * This counter may be an average from multiple minibatches.
+ */
+ TRAINING_MINI_BATCH_EXAMPLE_SIZE = 100008;
+
+ /**
+ * The size (in bytes) of the OpStatsDb file.
+ */
+ OPSTATS_DB_SIZE_BYTES = 100009;
+
+ /**
+ * The number of entries in OpStatsDb.
+ */
+ OPSTATS_DB_NUM_ENTRIES = 100010;
+
+ /**
+ * The number of entries pruned from OpStatsDb due to exceeding max size.
+ */
+ OPSTATS_NUM_PRUNED_ENTRIES = 100011;
+
+ /**
+ * The tenure (in hours) of the oldest entry which has been pruned from the
+ * OpStatsDb due to exceeding max size.
+ */
+ OPSTATS_OLDEST_PRUNED_ENTRY_TENURE_HOURS = 100012;
+
+ /** How long checking in/downloading a plan takes (for FL plans only). */
+ TRAINING_FL_CHECKIN_LATENCY = 200001;
+
+ /**
+ * The end time of reporting results to the server, relative to the start
+ * of the training session.
+ */
+ TRAINING_FL_REPORT_RESULTS_END_TIME = 200002;
+
+ /** How long reporting results to the server takes. */
+ TRAINING_FL_REPORT_RESULTS_LATENCY = 200003;
+
+ /** The end time of checking in/downloading a plan from the server, relative
+ to the start of the training session. */
+ TRAINING_FL_CHECKIN_END_TIME = 200004;
+
+ /** How long reporting results to the server takes. */
+ TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY = 200005;
+}
diff --git a/fcp/client/http/BUILD b/fcp/client/http/BUILD
new file mode 100644
index 0000000..ce4fe7e
--- /dev/null
+++ b/fcp/client/http/BUILD
@@ -0,0 +1,276 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "http_client",
+ hdrs = ["http_client.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "http_client_util",
+ srcs = ["http_client_util.cc"],
+ hdrs = ["http_client_util.h"],
+ deps = [
+ ":http_client",
+ "//fcp/base",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googleapis//google/rpc:status_cc_proto",
+ ],
+)
+
+cc_test(
+ name = "http_client_util_test",
+ srcs = ["http_client_util_test.cc"],
+ deps = [
+ ":http_client_util",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_googleapis//google/rpc:status_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+proto_library(
+ name = "http_resource_metadata_proto",
+ srcs = ["http_resource_metadata.proto"],
+)
+
+cc_proto_library(
+ name = "http_resource_metadata_cc_proto",
+ deps = [":http_resource_metadata_proto"],
+)
+
+cc_library(
+ name = "in_memory_request_response",
+ srcs = ["in_memory_request_response.cc"],
+ hdrs = ["in_memory_request_response.h"],
+ deps = [
+ ":http_client",
+ ":http_client_util",
+ ":http_resource_metadata_cc_proto",
+ "//fcp/base",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client/cache:resource_cache",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "in_memory_request_response_test",
+ srcs = ["in_memory_request_response_test.cc"],
+ deps = [
+ ":http_client",
+ ":http_client_util",
+ ":http_resource_metadata_cc_proto",
+ ":in_memory_request_response",
+ "//fcp/base",
+ "//fcp/base:simulated_clock",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:test_helpers",
+ "//fcp/client/cache:file_backed_resource_cache",
+ "//fcp/client/cache:test_helpers",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "http_federated_protocol",
+ srcs = ["http_federated_protocol.cc"],
+ hdrs = [
+ "http_federated_protocol.h",
+ ],
+ deps = [
+ ":http_client",
+ ":http_client_util",
+ ":http_secagg_send_to_server_impl",
+ ":in_memory_request_response",
+ ":protocol_request_helper",
+ "//fcp/base",
+ "//fcp/base:clock",
+ "//fcp/base:time_util",
+ "//fcp/base:wall_clock_stopwatch",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:federated_protocol",
+ "//fcp/client:federated_protocol_util",
+ "//fcp/client:fl_runner_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:parsing_utils",
+ "//fcp/client:secagg_runner",
+ "//fcp/client:selector_context_cc_proto",
+ "//fcp/client/cache:resource_cache",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "//fcp/secagg/client",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_absl//absl/time",
+ "@com_google_googleapis//google/longrunning:longrunning_cc_proto",
+ "@com_google_googleapis//google/rpc:code_cc_proto",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "http_federated_protocol_test",
+ srcs = ["http_federated_protocol_test.cc"],
+ deps = [
+ ":http_client",
+ ":http_client_util",
+ ":http_federated_protocol",
+ ":in_memory_request_response",
+ "//fcp/base",
+ "//fcp/base:clock",
+ "//fcp/base:time_util",
+ "//fcp/base:wall_clock_stopwatch",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:federated_protocol",
+ "//fcp/client:federated_protocol_util",
+ "//fcp/client:interfaces",
+ "//fcp/client:interruptible_runner",
+ "//fcp/client:test_helpers",
+ "//fcp/client/cache:test_helpers",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/testing",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googleapis//google/longrunning:longrunning_cc_proto",
+ "@com_google_googleapis//google/rpc:code_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "http_secagg_send_to_server_impl",
+ srcs = ["http_secagg_send_to_server_impl.cc"],
+ hdrs = ["http_secagg_send_to_server_impl.h"],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":http_client_util",
+ ":protocol_request_helper",
+ "//fcp/client:interfaces",
+ "//fcp/client:secagg_runner",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "//fcp/secagg/shared:cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_googleapis//google/rpc:code_cc_proto",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "http_secagg_send_to_server_impl_test",
+ srcs = ["http_secagg_send_to_server_impl_test.cc"],
+ deps = [
+ ":http_secagg_send_to_server_impl",
+ "//fcp/base:simulated_clock",
+ "//fcp/client:test_helpers",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "//fcp/testing",
+ "@com_google_absl//absl/time",
+ "@com_google_googleapis//google/longrunning:longrunning_cc_proto",
+ "@com_google_googleapis//google/rpc:code_cc_proto",
+ ],
+)
+
+cc_library(
+ name = "protocol_request_helper",
+ srcs = ["protocol_request_helper.cc"],
+ hdrs = ["protocol_request_helper.h"],
+ deps = [
+ ":http_client",
+ ":http_client_util",
+ ":in_memory_request_response",
+ "//fcp/base:clock",
+ "//fcp/base:time_util",
+ "//fcp/base:wall_clock_stopwatch",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_googleapis//google/longrunning:longrunning_cc_proto",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "protocol_request_helper_test",
+ srcs = ["protocol_request_helper_test.cc"],
+ deps = [
+ ":protocol_request_helper",
+ "//fcp/base:time_util",
+ "//fcp/client:test_helpers",
+ "//fcp/client/http/testing:test_helpers",
+ "//fcp/protos/federatedcompute:federated_compute_cc_proto",
+ "//fcp/testing",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
diff --git a/fcp/client/http/README.md b/fcp/client/http/README.md
new file mode 100644
index 0000000..f84ed8a
--- /dev/null
+++ b/fcp/client/http/README.md
@@ -0,0 +1,6 @@
+# HTTP-related classes and utilities.
+
+This directory hosts classes supporting the use of the HTTP protocol by the the
+rest of the client codebase. E.g. it defines an abstract interface for issuing
+HTTP requests, which can be implemented using different HTTP libraries on
+different platforms.
diff --git a/fcp/client/http/curl/BUILD b/fcp/client/http/curl/BUILD
new file mode 100644
index 0000000..8f8f4d4
--- /dev/null
+++ b/fcp/client/http/curl/BUILD
@@ -0,0 +1,65 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "curl_http_client",
+ srcs = [
+ "curl_api.cc",
+ "curl_header_parser.cc",
+ "curl_http_client.cc",
+ "curl_http_request_handle.cc",
+ "curl_http_response.cc",
+ ],
+ hdrs = [
+ "curl_api.h",
+ "curl_header_parser.h",
+ "curl_http_client.h",
+ "curl_http_request_handle.h",
+ "curl_http_response.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//fcp/base",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:http_client_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@libcurl//:curl",
+ ],
+)
+
+cc_test(
+ name = "curl_header_parser_test",
+ srcs = [
+ "curl_header_parser_test.cc",
+ ],
+ deps = [
+ ":curl_http_client",
+ "//fcp/client/http:http_client_util",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+# TODO(team)
diff --git a/fcp/client/http/curl/curl_api.cc b/fcp/client/http/curl/curl_api.cc
new file mode 100644
index 0000000..7f242fd
--- /dev/null
+++ b/fcp/client/http/curl/curl_api.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/curl/curl_api.h"
+
+#include <memory>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp::client::http::curl {
+// CurlEasyHandle
+
+CurlEasyHandle::CurlEasyHandle() : easy_handle_(curl_easy_init()) {
+ FCP_CHECK(easy_handle_ != nullptr);
+}
+
+CurlEasyHandle::~CurlEasyHandle() { curl_easy_cleanup(easy_handle_); }
+
+CURLcode CurlEasyHandle::GetInfo(CURLINFO info, curl_off_t* value) const {
+ curl_off_t data = 0;
+ CURLcode code = curl_easy_getinfo(easy_handle_, info, &data);
+ if (code == CURLE_OK) *value = static_cast<curl_off_t>(data);
+ return code;
+}
+
+std::string CurlEasyHandle::StrError(CURLcode code) {
+ return curl_easy_strerror(code);
+}
+
+CURL* CurlEasyHandle::GetEasyHandle() const { return easy_handle_; }
+
+// CurlMultiHandle
+
+CurlMultiHandle::CurlMultiHandle() : multi_handle_(curl_multi_init()) {
+ FCP_CHECK(multi_handle_ != nullptr);
+}
+
+CurlMultiHandle::~CurlMultiHandle() { curl_multi_cleanup(multi_handle_); }
+
+CURLMsg* CurlMultiHandle::InfoRead(int* msgs_in_queue) {
+ return curl_multi_info_read(multi_handle_, msgs_in_queue);
+}
+
+CURLMcode CurlMultiHandle::AddEasyHandle(CurlEasyHandle* easy_handle) {
+ return curl_multi_add_handle(multi_handle_, easy_handle->GetEasyHandle());
+}
+
+CURLMcode CurlMultiHandle::RemoveEasyHandle(CurlEasyHandle* easy_handle) {
+ return curl_multi_remove_handle(multi_handle_, easy_handle->GetEasyHandle());
+}
+
+CURLMcode CurlMultiHandle::Perform(int* num_running_handles) {
+ return curl_multi_perform(multi_handle_, num_running_handles);
+}
+
+CURLMcode CurlMultiHandle::Poll(curl_waitfd extra_fds[],
+ unsigned int extra_nfds, int timeout_ms,
+ int* numfds) {
+ return curl_multi_poll(multi_handle_, extra_fds, extra_nfds, timeout_ms,
+ numfds);
+}
+
+std::string CurlMultiHandle::StrError(CURLMcode code) {
+ return curl_multi_strerror(code);
+}
+
+// CurlApi
+
+CurlApi::CurlApi() { curl_global_init(CURL_GLOBAL_ALL); }
+
+CurlApi::~CurlApi() { curl_global_cleanup(); }
+
+std::unique_ptr<CurlEasyHandle> CurlApi::CreateEasyHandle() const {
+ absl::MutexLock lock(&mutex_);
+ // make_unique cannot access the private constructor, so we use
+ // an old-fashioned new.
+ return std::unique_ptr<CurlEasyHandle>(new CurlEasyHandle());
+}
+
+std::unique_ptr<CurlMultiHandle> CurlApi::CreateMultiHandle() const {
+ absl::MutexLock lock(&mutex_);
+ // make_unique cannot access the private constructor, so we use
+ // an old-fashioned new.
+ return std::unique_ptr<CurlMultiHandle>(new CurlMultiHandle());
+}
+
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_api.h b/fcp/client/http/curl/curl_api.h
new file mode 100644
index 0000000..cebb6cf
--- /dev/null
+++ b/fcp/client/http/curl/curl_api.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_CURL_CURL_API_H_
+#define FCP_CLIENT_HTTP_CURL_CURL_API_H_
+
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "absl/synchronization/mutex.h"
+#include "curl/curl.h"
+
+namespace fcp::client::http::curl {
+// An RAII wrapper around the libcurl easy handle, which works with one request.
+// The class is not thread-safe.
+class CurlEasyHandle {
+ public:
+ ~CurlEasyHandle();
+ CurlEasyHandle(const CurlEasyHandle&) = delete;
+ CurlEasyHandle& operator=(const CurlEasyHandle&) = delete;
+
+ CURLcode GetInfo(CURLINFO info, curl_off_t* value) const;
+
+ template <typename T, typename = std::enable_if_t<std::is_trivial_v<T>>>
+ CURLcode SetOpt(CURLoption option, T value) {
+ return curl_easy_setopt(easy_handle_, option, value);
+ }
+ CURLcode SetOpt(CURLoption option, const std::string& value) {
+ return SetOpt(option, value.c_str());
+ }
+
+ // Converts the curl code into a human-readable form.
+ ABSL_MUST_USE_RESULT static std::string StrError(CURLcode code);
+
+ // Returns the underlying curl handle.
+ ABSL_MUST_USE_RESULT CURL* GetEasyHandle() const;
+
+ private:
+ friend class CurlApi;
+ CurlEasyHandle();
+ CURL* const easy_handle_;
+};
+
+// An RAII wrapper around the libcurl multi handle, which is a bundle of easy
+// requests that performs them in parallel in one thread. The class is not
+// thread-safe.
+class CurlMultiHandle {
+ public:
+ ~CurlMultiHandle();
+ CurlMultiHandle(const CurlMultiHandle&) = delete;
+ CurlMultiHandle& operator=(const CurlMultiHandle&) = delete;
+
+ // Fetches the next message in the message queue.
+ CURLMsg* InfoRead(int* msgs_in_queue);
+
+ // Add a new request to the bundle.
+ CURLMcode AddEasyHandle(CurlEasyHandle* easy_handle);
+
+ // Remove the requests from the bundle. It will cancel an unfinished request,
+ // but it will not delete the easy handle.
+ CURLMcode RemoveEasyHandle(CurlEasyHandle* easy_handle);
+
+ // Performs all active tasks and returns.
+ CURLMcode Perform(int* num_running_handles);
+
+ // Waits for a new job
+ CURLMcode Poll(curl_waitfd extra_fds[], unsigned int extra_nfds,
+ int timeout_ms, int* numfds);
+
+ // Converts the curl code into a human-readable form.
+ ABSL_MUST_USE_RESULT static std::string StrError(CURLMcode code);
+
+ private:
+ friend class CurlApi;
+ CurlMultiHandle();
+ CURLM* const multi_handle_;
+};
+
+// An RAII wrapper around global initialization for libcurl. It forces the user
+// to create it first, so the initialization can be made, on which handles
+// depend. The class needs to be created only once, and its methods are
+// thread-safe.
+class CurlApi {
+ public:
+ CurlApi();
+ ~CurlApi();
+ CurlApi(const CurlApi&) = delete;
+ CurlApi& operator=(const CurlApi&) = delete;
+
+ ABSL_MUST_USE_RESULT std::unique_ptr<CurlEasyHandle> CreateEasyHandle() const;
+ ABSL_MUST_USE_RESULT std::unique_ptr<CurlMultiHandle> CreateMultiHandle()
+ const;
+
+ private:
+ mutable absl::Mutex mutex_;
+};
+
+} // namespace fcp::client::http::curl
+
+#endif // FCP_CLIENT_HTTP_CURL_CURL_API_H_
diff --git a/fcp/client/http/curl/curl_header_parser.cc b/fcp/client/http/curl/curl_header_parser.cc
new file mode 100644
index 0000000..02a0067
--- /dev/null
+++ b/fcp/client/http/curl/curl_header_parser.cc
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/curl/curl_header_parser.h"
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/match.h"
+#include "absl/strings/str_split.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client_util.h"
+
+namespace fcp::client::http::curl {
+
+CurlHeaderParser::CurlHeaderParser()
+ : status_code_(-1),
+ is_last_header_line_(false),
+ use_curl_encoding_(false) {}
+
+void CurlHeaderParser::ParseHeader(const std::string& header_string) {
+ if (ParseAsStatus(header_string)) {
+ return;
+ }
+ if (ParseAsHeader(header_string)) {
+ return;
+ }
+ if (ParseAsLastLine(header_string)) {
+ return;
+ }
+}
+
+bool CurlHeaderParser::ParseAsStatus(const std::string& header_string) {
+ if (!absl::StartsWith(header_string, "HTTP/")) {
+ return false;
+ }
+
+ std::pair<std::string, std::string> split =
+ absl::StrSplit(header_string, ' ');
+ int status_code;
+ if (!absl::SimpleAtoi(split.second.substr(0, 3), &status_code)) {
+ return false;
+ }
+
+ status_code_ = status_code;
+ // It is required that we store only the final header list. So we keep the
+ // last set of headers
+ header_list_.clear();
+ return true;
+}
+
+bool CurlHeaderParser::ParseAsHeader(const std::string& header_string) {
+ if (!absl::StrContains(header_string, ':')) {
+ return false;
+ }
+
+ std::pair<std::string, std::string> split =
+ absl::StrSplit(header_string, ':');
+ std::string key = split.first;
+ std::string value = std::string(absl::StripAsciiWhitespace(split.second));
+
+ // Removes the "Content-Encoding", "Content-Length", and "Content-Length"
+ // headers from the response when the curl encoding in use because they
+ // reflect in-flight encoded values
+ if (!use_curl_encoding_ ||
+ (!absl::EqualsIgnoreCase(key, kContentEncodingHdr) &&
+ !absl::EqualsIgnoreCase(key, kContentLengthHdr) &&
+ !absl::EqualsIgnoreCase(key, kTransferEncodingHdr))) {
+ header_list_.push_back({key, value});
+ }
+ return true;
+}
+
+bool CurlHeaderParser::ParseAsLastLine(const std::string& header_string) {
+ // In general, it is impossible to tell when curl will reach the last
+ // header because there could another one in some special cases. In
+ // particular, it happens when curl hits a redirect status code (301).
+ // In this case, we need to proceed.
+ if (std::string(header_string) == "\r\n" &&
+ status_code_ != HttpResponseCode::kHttpMovedPermanently) {
+ is_last_header_line_ = true;
+ return true;
+ }
+
+ return false;
+}
+
+void CurlHeaderParser::UseCurlEncoding() { use_curl_encoding_ = true; }
+
+bool CurlHeaderParser::IsLastHeader() const { return is_last_header_line_; }
+
+int CurlHeaderParser::GetStatusCode() const { return status_code_; }
+
+HeaderList CurlHeaderParser::GetHeaderList() const { return header_list_; }
+
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_header_parser.h b/fcp/client/http/curl/curl_header_parser.h
new file mode 100644
index 0000000..5d2fe28
--- /dev/null
+++ b/fcp/client/http/curl/curl_header_parser.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_CURL_CURL_HEADER_PARSER_H_
+#define FCP_CLIENT_HTTP_CURL_CURL_HEADER_PARSER_H_
+
+#include <string>
+
+#include "fcp/client/http/http_client.h"
+
+namespace fcp::client::http::curl {
+
+// A custom parser that is needed to call the first callback after all the
+// headers received
+class CurlHeaderParser {
+ public:
+ explicit CurlHeaderParser();
+ ~CurlHeaderParser() = default;
+ CurlHeaderParser(const CurlHeaderParser&) = delete;
+ CurlHeaderParser& operator=(const CurlHeaderParser&) = delete;
+
+ // Parses the next header string
+ void ParseHeader(const std::string& header_string);
+
+ // Removes the "Content-Encoding", "Content-Length", and "Content-Length"
+ // headers from the response when the curl encoding in use because they
+ // reflect in-flight encoded values
+ void UseCurlEncoding();
+ // Indicates that the parser reached the last header
+ ABSL_MUST_USE_RESULT bool IsLastHeader() const;
+ ABSL_MUST_USE_RESULT int GetStatusCode() const;
+ ABSL_MUST_USE_RESULT HeaderList GetHeaderList() const;
+
+ private:
+ // Extracts status codes from HTTP/1.1 and HTTP/2 responses
+ bool ParseAsStatus(const std::string& header_string);
+ // Parses a header into a key-value pair
+ bool ParseAsHeader(const std::string& header_string);
+ // Decides whether it is the last header
+ bool ParseAsLastLine(const std::string& header_string);
+
+ int status_code_;
+ HeaderList header_list_;
+ bool is_last_header_line_;
+ bool use_curl_encoding_;
+};
+
+} // namespace fcp::client::http::curl
+
+#endif // FCP_CLIENT_HTTP_CURL_CURL_HEADER_PARSER_H_
diff --git a/fcp/client/http/curl/curl_header_parser_test.cc b/fcp/client/http/curl/curl_header_parser_test.cc
new file mode 100644
index 0000000..ff22811
--- /dev/null
+++ b/fcp/client/http/curl/curl_header_parser_test.cc
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/curl/curl_header_parser.h"
+
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client::http::curl {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pair;
+
+TEST(CurlHeaderParserTest, Parse_HTTP_1_1_StatusCode) {
+ CurlHeaderParser parser;
+ EXPECT_THAT(parser.GetStatusCode(), -1);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+
+ auto line1 = "HTTP/1.1 200 OK\r\n";
+ parser.ParseHeader(line1);
+
+ EXPECT_THAT(parser.GetStatusCode(), 200);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+
+ auto line2 = "HTTP/1.1 500 Internal Error\r\n";
+ parser.ParseHeader(line2);
+
+ EXPECT_THAT(parser.GetStatusCode(), 500);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+
+ auto line3 = "HTTP/1.1 123 Some:Custom Message\r\n";
+ parser.ParseHeader(line3);
+
+ EXPECT_THAT(parser.GetStatusCode(), 123);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+}
+
+TEST(CurlHeaderParserTest, Parse_HTTP_2_StatusCode) {
+ CurlHeaderParser parser;
+ EXPECT_THAT(parser.GetStatusCode(), -1);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+
+ auto line1 = "HTTP/2 200\r\n";
+ parser.ParseHeader(line1);
+
+ EXPECT_THAT(parser.GetStatusCode(), 200);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+
+ auto line2 = "HTTP/2 500\r\n";
+ parser.ParseHeader(line2);
+
+ EXPECT_THAT(parser.GetStatusCode(), 500);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+}
+
+TEST(CurlHeaderParserTest, ParseHeaders) {
+ CurlHeaderParser parser;
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ auto line = "HTTP/2 301\r\n"; // Redirect
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "Content-Encoding: gzip\r\n";
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(),
+ ElementsAre(Pair(kContentEncodingHdr, "gzip")));
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "\r\n";
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(),
+ ElementsAre(Pair(kContentEncodingHdr, "gzip")));
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "HTTP/2 200\r\n"; // OK
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(), ElementsAre());
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "Content-Type: application/octet-stream\r\n";
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(),
+ ElementsAre(Pair("Content-Type", "application/octet-stream")));
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "Content-Length: 150\r\n";
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(),
+ ElementsAre(Pair("Content-Type", "application/octet-stream"),
+ Pair(kContentLengthHdr, "150")));
+ EXPECT_THAT(parser.IsLastHeader(), false);
+
+ line = "\r\n";
+ parser.ParseHeader(line);
+ EXPECT_THAT(parser.GetHeaderList(),
+ ElementsAre(Pair("Content-Type", "application/octet-stream"),
+ Pair(kContentLengthHdr, "150")));
+ EXPECT_THAT(parser.IsLastHeader(), true);
+}
+} // namespace
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_http_client.cc b/fcp/client/http/curl/curl_http_client.cc
new file mode 100644
index 0000000..3c242bc
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_client.cc
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/curl/curl_http_client.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/curl/curl_http_request_handle.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+
+namespace fcp::client::http::curl {
+namespace {
+// Cleans completed requests and calls the required callbacks.
+void ReadCompleteMessages(CurlMultiHandle* multi_handle) {
+ CURLMsg* msg;
+ int messages_in_queue = 0;
+ while ((msg = multi_handle->InfoRead(&messages_in_queue))) {
+ if (msg->msg == CURLMSG_DONE) {
+ FCP_LOG(INFO) << CurlEasyHandle::StrError(msg->data.result);
+ void* user_data;
+ curl_easy_getinfo(msg->easy_handle, CURLINFO_PRIVATE, &user_data);
+ FCP_CHECK(user_data != nullptr);
+
+ auto handle = static_cast<CurlHttpRequestHandle*>(user_data);
+ handle->MarkAsCompleted();
+ handle->RemoveFromMulti(multi_handle);
+ }
+ }
+}
+
+// Processes multiple requests while blocked.
+absl::Status PerformMultiHandlesBlocked(CurlMultiHandle* multi_handle) {
+ int num_running_handles = -1;
+ while (num_running_handles) {
+ CURLMcode code = multi_handle->Perform(&num_running_handles);
+ if (code != CURLM_OK) {
+ FCP_LOG(ERROR) << "MultiPerform failed with code: " << code;
+ return absl::InternalError(
+ absl::StrCat("MultiPerform failed with code: ", code));
+ }
+
+ ReadCompleteMessages(multi_handle);
+
+ if (num_running_handles > 0) {
+ code = multi_handle->Poll(/*extra_fds*/ nullptr,
+ /*extra_nfds*/ 0, /*timeout_ms*/ 1000,
+ /*numfds*/ nullptr);
+ }
+ }
+
+ return absl::OkStatus();
+}
+} // namespace
+
+CurlHttpClient::CurlHttpClient(CurlApi* curl_api, std::string test_cert_path)
+ : curl_api_(curl_api), test_cert_path_(std::move(test_cert_path)) {
+ FCP_CHECK(curl_api_ != nullptr);
+}
+
+std::unique_ptr<HttpRequestHandle> CurlHttpClient::EnqueueRequest(
+ std::unique_ptr<HttpRequest> request) {
+ FCP_LOG(INFO) << "Creating a " << ConvertMethodToString(request->method())
+ << " request to " << request->uri() << " with body "
+ << request->HasBody() << " with headers "
+ << request->extra_headers().size();
+ for (const auto& [key, value] : request->extra_headers()) {
+ FCP_LOG(INFO) << key << ": " << value;
+ }
+
+ return std::make_unique<CurlHttpRequestHandle>(
+ std::move(request), curl_api_->CreateEasyHandle(), test_cert_path_);
+}
+
+absl::Status CurlHttpClient::PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests) {
+ FCP_LOG(INFO) << "PerformRequests";
+ std::unique_ptr<CurlMultiHandle> multi_handle =
+ curl_api_->CreateMultiHandle();
+ FCP_CHECK(multi_handle != nullptr);
+
+ for (const auto& [request_handle, callback] : requests) {
+ FCP_CHECK(request_handle != nullptr);
+ FCP_CHECK(callback != nullptr);
+
+ auto http_request_handle =
+ static_cast<CurlHttpRequestHandle*>(request_handle);
+ FCP_RETURN_IF_ERROR(
+ http_request_handle->AddToMulti(multi_handle.get(), callback));
+ }
+
+ return PerformMultiHandlesBlocked(multi_handle.get());
+}
+
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_http_client.h b/fcp/client/http/curl/curl_http_client.h
new file mode 100644
index 0000000..94acbaf
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_client.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_HTTP_CURL_CURL_HTTP_CLIENT_H_
+#define FCP_CLIENT_HTTP_CURL_CURL_HTTP_CLIENT_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/client/http/curl/curl_api.h"
+#include "fcp/client/http/http_client.h"
+
+namespace fcp::client::http::curl {
+
+// A curl-based implementation of the HttpClient interface that uses
+// CurlHttpRequestHandle underneath. The implementation assumes that
+// CurlHttpClient lives longer than CurlHttpRequestHandle; and
+// CurlApi lives longer than CurlHttpClient
+class CurlHttpClient : public HttpClient {
+ public:
+ explicit CurlHttpClient(CurlApi* curl_api, std::string test_cert_path = "");
+ ~CurlHttpClient() override = default;
+ CurlHttpClient(const CurlHttpClient&) = delete;
+ CurlHttpClient& operator=(const CurlHttpClient&) = delete;
+
+ // HttpClient overrides:
+ std::unique_ptr<HttpRequestHandle> EnqueueRequest(
+ std::unique_ptr<HttpRequest> request) override;
+
+ // Performs the given requests while blocked. Results will be returned to each
+ // corresponding `HttpRequestCallback`.
+ absl::Status PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests)
+ override;
+
+ private:
+ // Owned by the caller
+ const CurlApi* const curl_api_;
+ const std::string test_cert_path_;
+};
+
+} // namespace fcp::client::http::curl
+
+#endif // FCP_CLIENT_HTTP_CURL_CURL_HTTP_CLIENT_H_
diff --git a/fcp/client/http/curl/curl_http_client_test.cc b/fcp/client/http/curl/curl_http_client_test.cc
new file mode 100644
index 0000000..b73f6c3
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_client_test.cc
@@ -0,0 +1,564 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/curl/curl_http_client.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/time/clock.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/http/testing/http_test_server.h"
+#include "fcp/client/http/testing/test_helpers.h"
+
+namespace fcp::client::http::curl {
+namespace {
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Field;
+using ::testing::FieldsAre;
+using ::testing::StrictMock;
+using ::testing::UnorderedElementsAreArray;
+
+void SetUpOnResponseStarted(
+ StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri, HttpRequest::Method method,
+ int response_code, const HeaderList& expected_request_extra_headers,
+ bool has_body, const HeaderList& expected_response_headers) {
+ EXPECT_CALL(*request_callback, OnResponseStarted(_, _))
+ .WillOnce(::testing::Invoke([=, &request_uri](
+ const HttpRequest& request,
+ const HttpResponse& response) {
+ EXPECT_THAT(request.uri(), request_uri);
+ EXPECT_THAT(request.method(), method);
+ EXPECT_THAT(request.extra_headers(),
+ UnorderedElementsAreArray(expected_request_extra_headers));
+ EXPECT_THAT(request.HasBody(), has_body);
+ EXPECT_THAT(response.code(), response_code);
+ EXPECT_THAT(response.headers(),
+ UnorderedElementsAreArray(expected_response_headers));
+ return absl::OkStatus();
+ }));
+}
+
+void SetUpOnResponseBody(StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri,
+ HttpRequest::Method method, int response_code,
+ const HeaderList& expected_request_extra_headers,
+ bool has_body,
+ const HeaderList& expected_response_headers,
+ const std::string& expected_response_body) {
+ EXPECT_CALL(*request_callback, OnResponseBody(_, _, _))
+ .WillOnce(::testing::Invoke([=](const HttpRequest& request,
+ const HttpResponse& response,
+ absl::string_view data) {
+ EXPECT_THAT(request.uri(), request_uri);
+ EXPECT_THAT(request.method(), method);
+ EXPECT_THAT(request.extra_headers(),
+ UnorderedElementsAreArray(expected_request_extra_headers));
+ EXPECT_THAT(request.HasBody(), has_body);
+ EXPECT_THAT(response.code(), response_code);
+ EXPECT_THAT(response.headers(),
+ UnorderedElementsAreArray(expected_response_headers));
+ EXPECT_THAT(data, expected_response_body);
+ return absl::OkStatus();
+ }));
+}
+
+void SetUpOnResponseBody(StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri,
+ HttpRequest::Method method, int response_code,
+ const HeaderList& expected_request_extra_headers,
+ bool has_body,
+ const HeaderList& expected_response_headers,
+ const std::string& expected_response_body,
+ size_t& total_bytes_downloaded) {
+ EXPECT_CALL(*request_callback, OnResponseBody(_, _, _))
+ .WillOnce(::testing::Invoke(
+ [=, &total_bytes_downloaded](const HttpRequest& request,
+ const HttpResponse& response,
+ absl::string_view data) {
+ EXPECT_THAT(request.uri(), request_uri);
+ EXPECT_THAT(request.method(), method);
+ EXPECT_THAT(
+ request.extra_headers(),
+ UnorderedElementsAreArray(expected_request_extra_headers));
+ EXPECT_THAT(request.HasBody(), has_body);
+ EXPECT_THAT(response.code(), response_code);
+ EXPECT_THAT(response.headers(),
+ UnorderedElementsAreArray(expected_response_headers));
+ EXPECT_THAT(data, expected_response_body);
+ total_bytes_downloaded += data.size();
+ return absl::OkStatus();
+ }));
+}
+
+void SetUpOnResponseCompleted(
+ StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri, HttpRequest::Method method,
+ int response_code, const HeaderList& expected_request_extra_headers,
+ bool has_body, const HeaderList& expected_response_headers) {
+ EXPECT_CALL(*request_callback, OnResponseCompleted(_, _))
+ .WillOnce(::testing::Invoke([=, &request_uri](
+ const HttpRequest& request,
+ const HttpResponse& response) {
+ EXPECT_THAT(request.uri(), request_uri);
+ EXPECT_THAT(request.method(), method);
+ EXPECT_THAT(request.extra_headers(),
+ UnorderedElementsAreArray(expected_request_extra_headers));
+ EXPECT_THAT(request.HasBody(), has_body);
+ EXPECT_THAT(response.code(), response_code);
+ EXPECT_THAT(response.headers(),
+ UnorderedElementsAreArray(expected_response_headers));
+ }));
+}
+
+void SetUpGetRequestCallback(
+ StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri,
+ const HeaderList& expected_request_extra_headers,
+ const HeaderList& expected_response_headers,
+ const std::string& expected_response_body, size_t& total_bytes_downloaded) {
+ SetUpOnResponseStarted(request_callback, request_uri,
+ HttpRequest::Method::kGet, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ false, expected_response_headers);
+ SetUpOnResponseBody(request_callback, request_uri, HttpRequest::Method::kGet,
+ /*response_code*/ 200, expected_request_extra_headers,
+ /*has_body*/ false, expected_response_headers,
+ expected_response_body, total_bytes_downloaded);
+
+ SetUpOnResponseCompleted(request_callback, request_uri,
+ HttpRequest::Method::kGet, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ false, expected_response_headers);
+}
+
+void SetUpPostRequestCallback(
+ StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri,
+ const HeaderList& expected_request_extra_headers,
+ const HeaderList& expected_response_headers,
+ const std::string& expected_response_body, size_t& total_bytes_downloaded) {
+ SetUpOnResponseStarted(request_callback, request_uri,
+ HttpRequest::Method::kPost, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+
+ SetUpOnResponseBody(request_callback, request_uri, HttpRequest::Method::kPost,
+ /*response_code*/ 200, expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers,
+ expected_response_body, total_bytes_downloaded);
+
+ SetUpOnResponseCompleted(request_callback, request_uri,
+ HttpRequest::Method::kPost, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+}
+
+void SetUpPutRequestCallback(
+ StrictMock<MockHttpRequestCallback>* request_callback,
+ const std::string& request_uri,
+ const HeaderList& expected_request_extra_headers,
+ const HeaderList& expected_response_headers,
+ const std::string& expected_response_body, size_t& total_bytes_downloaded) {
+ SetUpOnResponseStarted(request_callback, request_uri,
+ HttpRequest::Method::kPut, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+
+ SetUpOnResponseBody(request_callback, request_uri, HttpRequest::Method::kPut,
+ /*response_code*/ 200, expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers,
+ expected_response_body, total_bytes_downloaded);
+
+ SetUpOnResponseCompleted(request_callback, request_uri,
+ HttpRequest::Method::kPut, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+}
+
+void PerformTwoRequests(CurlHttpClient* http_client, int port,
+ const std::string& request_uri1,
+ const std::string& request_uri2) {
+ std::string request1_body;
+ std::string request2_body = "test: 123-45";
+
+ auto request1 = InMemoryHttpRequest::Create(
+ request_uri1, HttpRequest::Method::kGet, HeaderList(), request1_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request1);
+
+ auto request2 = InMemoryHttpRequest::Create(
+ request_uri2, HttpRequest::Method::kPost, HeaderList(), request2_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request2);
+
+ auto handle1 = http_client->EnqueueRequest(std::move(request1.value()));
+ auto handle2 = http_client->EnqueueRequest(std::move(request2.value()));
+
+ auto request_callback1 =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+ auto request_callback2 =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+
+ size_t total_bytes_downloaded_handle1 = 0;
+ size_t total_bytes_downloaded_handle2 = 0;
+
+ HeaderList expected_request_extra_headers{{"Content-Length", "12"}};
+ HeaderList expected_response_headers{
+ {"Content-Type", "text/html"},
+ {"Date",
+ absl::FormatTime("%a, %d %b %Y %H", absl::Now(), absl::UTCTimeZone())}};
+
+ auto expected_response_body1 = absl::StrCat(
+ "HTTP Method: GET\n", "Request Uri: /test\n", "Request Headers:\n",
+ "Host: localhost:", port, "\nAccept: */*\n", "Accept-Encoding: gzip\n",
+ "Request Body:\n");
+
+ auto expected_response_body2 = absl::StrCat(
+ "HTTP Method: POST\nRequest Uri: /test\n",
+ "Request Headers:\nHost: localhost:", port,
+ "\nAccept: */*\nAccept-Encoding: gzip\n", "Content-Length: 12\n",
+ "Content-Type: application/x-www-form-urlencoded\n",
+ "Request Body:\ntest: 123-45");
+
+ SetUpGetRequestCallback(request_callback1.get(), request_uri1,
+ /*expected_request_extra_headers*/ {},
+ expected_response_headers, expected_response_body1,
+ total_bytes_downloaded_handle1);
+ SetUpPostRequestCallback(request_callback2.get(), request_uri2,
+ expected_request_extra_headers,
+ expected_response_headers, expected_response_body2,
+ total_bytes_downloaded_handle2);
+
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests{
+ std::make_pair(handle1.get(), request_callback1.get()),
+ std::make_pair(handle2.get(), request_callback2.get())};
+
+ EXPECT_THAT(handle1->TotalSentReceivedBytes(), FieldsAre(0, 0));
+ EXPECT_THAT(handle2->TotalSentReceivedBytes(), FieldsAre(0, 0));
+
+ absl::Status status = http_client->PerformRequests(requests);
+
+ EXPECT_THAT(status, absl::OkStatus());
+ EXPECT_THAT(handle1->TotalSentReceivedBytes(),
+ AllOf(Field(&HttpRequestHandle::SentReceivedBytes::sent_bytes,
+ request1_body.size()),
+ Field(&HttpRequestHandle::SentReceivedBytes::received_bytes,
+ total_bytes_downloaded_handle1)));
+ EXPECT_THAT(handle2->TotalSentReceivedBytes(),
+ AllOf(Field(&HttpRequestHandle::SentReceivedBytes::sent_bytes,
+ request2_body.size()),
+ Field(&HttpRequestHandle::SentReceivedBytes::received_bytes,
+ total_bytes_downloaded_handle2)));
+}
+
+// Runs PerformRequests once in one thread.
+TEST(CurlHttpClientTest, PerformTwoRequestsInParallelInOneThread) {
+ const int port = 4568;
+ const std::string request_uri =
+ absl::StrCat("http://localhost:", port, "/test");
+
+ auto curl_api = std::make_unique<CurlApi>();
+ auto http_client = std::make_unique<CurlHttpClient>(curl_api.get());
+ auto http_server = CreateHttpTestServer("/test", port, /*num_threads*/ 5);
+ EXPECT_THAT(http_server.ok(), true);
+ EXPECT_THAT(http_server.value()->StartAcceptingRequests(), true);
+
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+
+ curl_api.reset();
+ http_server.value()->Terminate();
+ http_server.value()->WaitForTermination();
+}
+
+TEST(CurlHttpClientTest, PutRequest) {
+ const int port = 4568;
+ const std::string request_uri =
+ absl::StrCat("http://localhost:", port, "/test");
+
+ auto curl_api = std::make_unique<CurlApi>();
+ auto http_client = std::make_unique<CurlHttpClient>(curl_api.get());
+ auto http_server = CreateHttpTestServer("/test", port, /*num_threads*/ 5);
+ EXPECT_THAT(http_server.ok(), true);
+ EXPECT_THAT(http_server.value()->StartAcceptingRequests(), true);
+
+ std::string request_body = "test: 123-45";
+
+ auto request = InMemoryHttpRequest::Create(
+ request_uri, HttpRequest::Method::kPut, HeaderList(), request_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request);
+
+ auto handle = http_client->EnqueueRequest(std::move(request.value()));
+
+ auto request_callback =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+
+ size_t total_bytes_downloaded_handle = 0;
+
+ HeaderList expected_request_extra_headers{{"Content-Length", "12"}};
+ HeaderList expected_response_headers{
+ {"Content-Type", "text/html"},
+ {"Date",
+ absl::FormatTime("%a, %d %b %Y %H", absl::Now(), absl::UTCTimeZone())}};
+
+ auto expected_response_body =
+ absl::StrCat("HTTP Method: PUT\nRequest Uri: /test\n",
+ "Request Headers:\nHost: localhost:", port,
+ "\nAccept: */*\nAccept-Encoding: gzip\n",
+ "Content-Length: 12\n", "Request Body:\ntest: 123-45");
+
+ SetUpPutRequestCallback(request_callback.get(), request_uri,
+ expected_request_extra_headers,
+ expected_response_headers, expected_response_body,
+ total_bytes_downloaded_handle);
+
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests{
+ std::make_pair(handle.get(), request_callback.get())};
+
+ EXPECT_THAT(handle->TotalSentReceivedBytes(), FieldsAre(0, 0));
+
+ absl::Status status = http_client->PerformRequests(requests);
+
+ EXPECT_THAT(status, absl::OkStatus());
+ EXPECT_THAT(handle->TotalSentReceivedBytes(),
+ AllOf(Field(&HttpRequestHandle::SentReceivedBytes::sent_bytes,
+ request_body.size()),
+ Field(&HttpRequestHandle::SentReceivedBytes::received_bytes,
+ total_bytes_downloaded_handle)));
+
+ curl_api.reset();
+ http_server.value()->Terminate();
+ http_server.value()->WaitForTermination();
+}
+
+// Runs PerformRequests in five times in five threads.
+TEST(CurlHttpClientTest, PerformTwoRequestsInParallelFiveTimesInFiveThreads) {
+ const int port = 4568;
+ const std::string request_uri =
+ absl::StrCat("http://localhost:", port, "/test");
+
+ auto curl_api = std::make_unique<CurlApi>();
+ auto http_client = std::make_unique<CurlHttpClient>(curl_api.get());
+ auto http_server = CreateHttpTestServer("/test", port, /*num_threads*/ 10);
+ EXPECT_THAT(http_server.ok(), true);
+ EXPECT_THAT(http_server.value()->StartAcceptingRequests(), true);
+
+ auto thread_pool_scheduler = CreateThreadPoolScheduler(5);
+
+ thread_pool_scheduler->Schedule([&http_client, request_uri]() {
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+ });
+ thread_pool_scheduler->Schedule([&http_client, request_uri]() {
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+ });
+ thread_pool_scheduler->Schedule([&http_client, request_uri]() {
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+ });
+ thread_pool_scheduler->Schedule([&http_client, request_uri]() {
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+ });
+ thread_pool_scheduler->Schedule([&http_client, request_uri]() {
+ PerformTwoRequests(http_client.get(), port, request_uri, request_uri);
+ });
+
+ thread_pool_scheduler->WaitUntilIdle();
+
+ curl_api.reset();
+ http_server.value()->Terminate();
+ http_server.value()->WaitForTermination();
+}
+
+// Runs PerformRequests with two requests and cancels the second after
+// OnResponseStarted received.
+TEST(CurlHttpClientTest, CancelRequest) {
+ const int port = 4568;
+ const std::string request_uri1 =
+ absl::StrCat("http://localhost:", port, "/test");
+ const std::string request_uri2 =
+ absl::StrCat("http://localhost:", port, "/test");
+
+ auto curl_api = std::make_unique<CurlApi>();
+ auto http_client = std::make_unique<CurlHttpClient>(curl_api.get());
+ auto http_server = CreateHttpTestServer("/test", port, /*num_threads*/ 5);
+ EXPECT_THAT(http_server.ok(), true);
+ EXPECT_THAT(http_server.value()->StartAcceptingRequests(), true);
+
+ std::string request1_body;
+ std::string request2_body = "test: 123-45";
+
+ auto request1 = InMemoryHttpRequest::Create(
+ request_uri1, HttpRequest::Method::kGet, HeaderList(), request1_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request1);
+
+ auto request2 = InMemoryHttpRequest::Create(
+ request_uri2, HttpRequest::Method::kPost, HeaderList(), request2_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request2);
+
+ auto handle1 = http_client->EnqueueRequest(std::move(request1.value()));
+ auto handle2 = http_client->EnqueueRequest(std::move(request2.value()));
+
+ auto request_callback1 =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+ auto request_callback2 =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+
+ size_t total_bytes_downloaded_handle1 = 0;
+
+ HeaderList expected_request_extra_headers{{"Content-Length", "12"}};
+ HeaderList expected_response_headers{
+ {"Content-Type", "text/html"},
+ {"Date",
+ absl::FormatTime("%a, %d %b %Y %H", absl::Now(), absl::UTCTimeZone())}};
+
+ auto expected_response_body = absl::StrCat(
+ "HTTP Method: GET\n", "Request Uri: /test\n", "Request Headers:\n",
+ "Host: localhost:", port, "\nAccept: */*\n", "Accept-Encoding: gzip\n",
+ "Request Body:\n");
+
+ SetUpGetRequestCallback(request_callback1.get(), request_uri1,
+ /*expected_request_extra_headers*/ {},
+ expected_response_headers, expected_response_body,
+ total_bytes_downloaded_handle1);
+
+ EXPECT_CALL(*request_callback2, OnResponseStarted(_, _))
+ .WillOnce(::testing::Invoke(
+ [&handle2](const HttpRequest& request, const HttpResponse& response) {
+ handle2->Cancel();
+ return absl::OkStatus();
+ }));
+
+ auto expected_response_body2 = absl::StrCat(
+ "HTTP Method: POST\nRequest Uri: /test\n",
+ "Request Headers:\nHost: localhost:", port,
+ "\nAccept: */*\nAccept-Encoding: gzip\n", "Content-Length: 12\n",
+ "Content-Type: application/x-www-form-urlencoded\n",
+ "Request Body:\ntest: 123-45");
+
+ SetUpOnResponseBody(
+ request_callback2.get(), request_uri2, HttpRequest::Method::kPost,
+ /*response_code*/ 200, expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers, expected_response_body2);
+
+ EXPECT_CALL(*request_callback2, OnResponseBodyError(_, _, _))
+ .WillOnce(::testing::Invoke([&request_uri2](const HttpRequest& request,
+ const HttpResponse& response,
+ const absl::Status& error) {
+ EXPECT_THAT(request.uri(), request_uri2);
+ EXPECT_THAT(request.method(), HttpRequest::Method::kPost);
+ EXPECT_THAT(request.extra_headers().size(), 1);
+ EXPECT_THAT(request.HasBody(), true);
+ }));
+
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests{
+ std::make_pair(handle1.get(), request_callback1.get()),
+ std::make_pair(handle2.get(), request_callback2.get())};
+
+ auto thread_pool_scheduler = CreateThreadPoolScheduler(2);
+ thread_pool_scheduler->Schedule([&http_client, &requests]() {
+ absl::Status status = http_client->PerformRequests(requests);
+
+ EXPECT_THAT(status, absl::OkStatus());
+ });
+
+ thread_pool_scheduler->WaitUntilIdle();
+
+ EXPECT_THAT(handle1->TotalSentReceivedBytes(),
+ AllOf(Field(&HttpRequestHandle::SentReceivedBytes::sent_bytes,
+ request1_body.size()),
+ Field(&HttpRequestHandle::SentReceivedBytes::received_bytes,
+ total_bytes_downloaded_handle1)));
+
+ curl_api.reset();
+ http_server.value()->Terminate();
+ http_server.value()->WaitForTermination();
+}
+
+// Runs PerformRequests once in one thread.
+TEST(CurlHttpClientTest, TestExtraHeaders) {
+ const int port = 4568;
+ const std::string request_uri =
+ absl::StrCat("http://localhost:", port, "/test");
+
+ auto curl_api = std::make_unique<CurlApi>();
+ auto http_client = std::make_unique<CurlHttpClient>(curl_api.get());
+
+ auto http_server = CreateHttpTestServer("/test", port, /*num_threads*/ 3);
+ EXPECT_THAT(http_server.ok(), true);
+ EXPECT_THAT(http_server.value()->StartAcceptingRequests(), true);
+
+ std::string request_body = "test: 123-45";
+ HeaderList extra_headers = {{"Content-Type", "application/x-protobuf"}};
+
+ auto request = InMemoryHttpRequest::Create(
+ request_uri, HttpRequest::Method::kPost, extra_headers, request_body,
+ /*use_compression*/ false);
+ ASSERT_OK(request);
+
+ auto handle = http_client->EnqueueRequest(std::move(request.value()));
+
+ auto request_callback =
+ std::make_unique<StrictMock<MockHttpRequestCallback>>();
+
+ HeaderList expected_request_extra_headers{
+ {"Content-Length", "12"},
+ {"Content-Type", "application/x-protobuf"},
+ };
+ HeaderList expected_response_headers{
+ {"Content-Type", "text/html"},
+ {"Date",
+ absl::FormatTime("%a, %d %b %Y %H", absl::Now(), absl::UTCTimeZone())}};
+ auto expected_response_body =
+ absl::StrCat("HTTP Method: POST\nRequest Uri: /test\n",
+ "Request Headers:\nHost: localhost:", port,
+ "\nAccept: */*\nAccept-Encoding: gzip\n",
+ "Content-Type: application/x-protobuf\n",
+ "Content-Length: 12\n", "Request Body:\ntest: 123-45");
+
+ SetUpOnResponseStarted(request_callback.get(), request_uri,
+ HttpRequest::Method::kPost, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+
+ SetUpOnResponseBody(
+ request_callback.get(), request_uri, HttpRequest::Method::kPost,
+ /*response_code*/ 200, expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers, expected_response_body);
+
+ SetUpOnResponseCompleted(request_callback.get(), request_uri,
+ HttpRequest::Method::kPost, /*response_code*/ 200,
+ expected_request_extra_headers,
+ /*has_body*/ true, expected_response_headers);
+
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests{
+ std::make_pair(handle.get(), request_callback.get())};
+
+ absl::Status status = http_client->PerformRequests(requests);
+ EXPECT_THAT(status, absl::OkStatus());
+
+ curl_api.reset();
+ http_server.value()->Terminate();
+ http_server.value()->WaitForTermination();
+}
+
+} // namespace
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_http_request_handle.cc b/fcp/client/http/curl/curl_http_request_handle.cc
new file mode 100644
index 0000000..874b774
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_request_handle.cc
@@ -0,0 +1,386 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/curl/curl_http_request_handle.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "absl/strings/match.h"
+#include "curl/curl.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/curl/curl_api.h"
+#include "fcp/client/http/curl/curl_header_parser.h"
+#include "fcp/client/http/curl/curl_http_response.h"
+#include "fcp/client/http/http_client_util.h"
+
+namespace fcp::client::http::curl {
+namespace {
+// A type check for the macro.
+inline CURLcode AsCode(CURLcode code) { return code; }
+/**
+ * Macro which allows to check for a status code and return from the
+ * current method if not OK. Example:
+ *
+ * Status DoSomething() {
+ * CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(...));
+ * }
+ */
+#define CURL_RETURN_IF_ERROR(expr) \
+ do { \
+ CURLcode __code = AsCode(expr); \
+ if (__code != CURLE_OK) { \
+ FCP_LOG(ERROR) << "Easy handle failed with " \
+ << CurlEasyHandle::StrError(__code); \
+ return (__code); \
+ } \
+ } while (false)
+
+// Add a new element to a C-style list.
+curl_slist* AddToCurlHeaderList(curl_slist* header_list, const std::string& key,
+ const std::string& value) {
+ // A null pointer is returned if anything went wrong, otherwise a new
+ // list pointer is returned.
+ curl_slist* tmp =
+ curl_slist_append(header_list, absl::StrCat(key, ": ", value).c_str());
+ FCP_CHECK(tmp != nullptr);
+ return tmp;
+}
+
+} // namespace
+
+size_t CurlHttpRequestHandle::HeaderCallback(char* buffer, size_t size,
+ size_t n_items, void* user_data) {
+ auto self = static_cast<CurlHttpRequestHandle*>(user_data);
+ std::string str_header(static_cast<char*>(buffer), size * n_items);
+
+ self->header_parser_.ParseHeader(str_header);
+ if (!self->header_parser_.IsLastHeader()) {
+ return size * n_items;
+ }
+
+ self->response_ =
+ std::make_unique<CurlHttpResponse>(self->header_parser_.GetStatusCode(),
+ self->header_parser_.GetHeaderList());
+
+ FCP_CHECK(self->callback_ != nullptr);
+ absl::Status status =
+ self->callback_->OnResponseStarted(*self->request_, *self->response_);
+
+ if (!status.ok()) {
+ FCP_LOG(ERROR) << "Called OnResponseStarted. Received status: " << status;
+ self->callback_->OnResponseError(*self->request_, status);
+ }
+
+ return size * n_items;
+}
+
+size_t CurlHttpRequestHandle::DownloadCallback(void* body, size_t size,
+ size_t nmemb, void* user_data) {
+ auto self = static_cast<CurlHttpRequestHandle*>(user_data);
+ absl::string_view str_body(static_cast<char*>(body), size * nmemb);
+
+ absl::Status status = self->callback_->OnResponseBody(
+ *self->request_, *self->response_, str_body);
+
+ if (!status.ok()) {
+ FCP_LOG(ERROR) << "Called OnResponseBody. Received status: " << status;
+ self->callback_->OnResponseBodyError(*self->request_, *self->response_,
+ status);
+ }
+
+ return size * nmemb;
+}
+
+size_t CurlHttpRequestHandle::UploadCallback(char* buffer, size_t size,
+ size_t num, void* user_data) {
+ auto self = static_cast<CurlHttpRequestHandle*>(user_data);
+ size_t buffer_size = size * num;
+
+ absl::StatusOr<int64_t> read_size =
+ self->request_->ReadBody(buffer, buffer_size);
+ if (read_size.ok()) {
+ return read_size.value();
+ } else if (read_size.status().code() == absl::StatusCode::kOutOfRange) {
+ return 0;
+ }
+ return CURL_READFUNC_ABORT;
+}
+
+size_t CurlHttpRequestHandle::ProgressCallback(void* user_data,
+ curl_off_t dltotal,
+ curl_off_t dlnow,
+ curl_off_t ultotal,
+ curl_off_t ulnow) {
+ auto self = static_cast<CurlHttpRequestHandle*>(user_data);
+ absl::MutexLock lock(&self->mutex_);
+ // Abort is any number except zero.
+ return (self->is_cancelled_) ? 1 : 0;
+}
+
+CurlHttpRequestHandle::CurlHttpRequestHandle(
+ std::unique_ptr<HttpRequest> request,
+ std::unique_ptr<CurlEasyHandle> easy_handle,
+ const std::string& test_cert_path)
+ : request_(std::move(request)),
+ response_(nullptr),
+ easy_handle_(std::move(easy_handle)),
+ callback_(nullptr),
+ is_being_performed_(false),
+ is_completed_(false),
+ is_cancelled_(false),
+ header_list_(nullptr) {
+ FCP_CHECK(request_ != nullptr);
+ FCP_CHECK(easy_handle_ != nullptr);
+
+ CURLcode code = InitializeConnection(test_cert_path);
+ if (code != CURLE_OK) {
+ FCP_LOG(ERROR) << "easy_handle initialization failed with code "
+ << CurlEasyHandle::StrError(code);
+ FCP_LOG(ERROR) << error_buffer_;
+ callback_->OnResponseError(*request_, absl::InternalError(error_buffer_));
+ return;
+ }
+}
+
+CurlHttpRequestHandle::~CurlHttpRequestHandle() {
+ curl_slist_free_all(header_list_);
+}
+
+void CurlHttpRequestHandle::Cancel() {
+ absl::MutexLock lock(&mutex_);
+
+ if (callback_ == nullptr || is_cancelled_ || is_completed_) {
+ return;
+ }
+ if (response_ != nullptr) {
+ callback_->OnResponseBodyError(*request_, *response_,
+ absl::CancelledError());
+ } else {
+ callback_->OnResponseError(*request_, absl::CancelledError());
+ }
+ is_cancelled_ = true;
+}
+
+void CurlHttpRequestHandle::MarkAsCompleted() {
+ absl::MutexLock lock(&mutex_);
+
+ FCP_CHECK(callback_ != nullptr);
+ if (!is_cancelled_ && !is_completed_) {
+ if (response_ != nullptr) {
+ callback_->OnResponseCompleted(*request_, *response_);
+ } else {
+ callback_->OnResponseError(*request_,
+ absl::InternalError("response_ is nullptr"));
+ }
+ }
+ is_completed_ = true;
+}
+
+absl::Status CurlHttpRequestHandle::AddToMulti(CurlMultiHandle* multi_handle,
+ HttpRequestCallback* callback) {
+ absl::MutexLock lock(&mutex_);
+
+ FCP_CHECK(callback != nullptr);
+ FCP_CHECK(multi_handle != nullptr);
+
+ if (is_cancelled_) {
+ callback->OnResponseError(*request_, absl::CancelledError());
+ return absl::CancelledError();
+ } else if (is_being_performed_ || is_completed_) {
+ return absl::ResourceExhaustedError(
+ "The handle was previously passed to another PerformRequests call.");
+ }
+
+ is_being_performed_ = true;
+ callback_ = callback;
+
+ CURLMcode code = multi_handle->AddEasyHandle(easy_handle_.get());
+ if (code != CURLM_OK) {
+ FCP_LOG(ERROR) << "AddEasyHandle failed with code " << code;
+ FCP_LOG(ERROR) << error_buffer_;
+ callback_->OnResponseError(*request_, absl::InternalError(error_buffer_));
+ return absl::InternalError(error_buffer_);
+ }
+
+ return absl::OkStatus();
+}
+
+void CurlHttpRequestHandle::RemoveFromMulti(CurlMultiHandle* multi_handle) {
+ absl::MutexLock lock(&mutex_);
+
+ FCP_CHECK(multi_handle != nullptr);
+ CURLMcode code = multi_handle->RemoveEasyHandle(easy_handle_.get());
+ if (code != CURLM_OK) {
+ FCP_LOG(ERROR) << "RemoveEasyHandle failed with code "
+ << CurlMultiHandle::StrError(code);
+ FCP_LOG(ERROR) << error_buffer_;
+ }
+}
+
+HttpRequestHandle::SentReceivedBytes
+CurlHttpRequestHandle::TotalSentReceivedBytes() const {
+ absl::MutexLock lock(&mutex_);
+ curl_off_t total_sent_bytes = 0;
+ CURLcode code =
+ easy_handle_->GetInfo(CURLINFO_SIZE_UPLOAD_T, &total_sent_bytes);
+ if (code != CURLE_OK) {
+ FCP_LOG(ERROR) << "TotalSentBytes failed with code " << code;
+ FCP_LOG(ERROR) << error_buffer_;
+ }
+ curl_off_t total_received_bytes = 0;
+ code = easy_handle_->GetInfo(CURLINFO_SIZE_DOWNLOAD_T, &total_received_bytes);
+ if (code != CURLE_OK) {
+ FCP_LOG(ERROR) << "TotalReceivedBytes failed with code " << code;
+ FCP_LOG(ERROR) << error_buffer_;
+ }
+ return {.sent_bytes = total_sent_bytes,
+ .received_bytes = total_received_bytes};
+}
+
+CURLcode CurlHttpRequestHandle::InitializeConnection(
+ const std::string& test_cert_path) {
+ error_buffer_[0] = 0;
+ // Needed to read an error message.
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_ERRORBUFFER, error_buffer_));
+
+ // Skip all signal handling because is not thread-safe.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_NOSIGNAL, 1L));
+
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_URL, std::string(request_->uri())));
+
+ // Forces curl to follow redirects.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_FOLLOWLOCATION, 1L));
+
+ // Suppresses headers added by a proxy.
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_SUPPRESS_CONNECT_HEADERS, 1L));
+
+ // Force curl to verify the ssl connection.
+ CURL_RETURN_IF_ERROR(
+ test_cert_path.empty()
+ ? easy_handle_->SetOpt(CURLOPT_CAINFO, nullptr)
+ : easy_handle_->SetOpt(CURLOPT_CAINFO, test_cert_path));
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_SSL_VERIFYPEER, 2L));
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_SSL_VERIFYHOST, 1L));
+
+ // Force curl to never timeout.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_TIMEOUT_MS,
+ std::numeric_limits<int>::max()));
+
+ // Called when a response header received.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(
+ CURLOPT_HEADERFUNCTION, &CurlHttpRequestHandle::HeaderCallback));
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_HEADERDATA, this));
+
+ // Called when a response body received.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(
+ CURLOPT_WRITEFUNCTION, &CurlHttpRequestHandle::DownloadCallback));
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_WRITEDATA, this));
+
+ // Called to send a request body
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(
+ CURLOPT_READFUNCTION, &CurlHttpRequestHandle::UploadCallback));
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_READDATA, this));
+
+ // Called periodically. We use it to check whether the request is cancelled.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(
+ CURLOPT_XFERINFOFUNCTION, &CurlHttpRequestHandle::ProgressCallback));
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_XFERINFODATA, this));
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_NOPROGRESS, 0L));
+
+ // Private storage. Used by a multi-handle.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_PRIVATE, this));
+
+ switch (request_->method()) {
+ case HttpRequest::Method::kGet:
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_HTTPGET, 1L));
+ break;
+ case HttpRequest::Method::kHead:
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_NOBODY, 1L));
+ break;
+ case HttpRequest::Method::kPost:
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_POST, 1L));
+ // Forces curl to use the callback.
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_POSTFIELDS, nullptr));
+ break;
+ case HttpRequest::Method::kPut:
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_UPLOAD, 1L));
+ break;
+ case HttpRequest::Method::kPatch:
+ case HttpRequest::Method::kDelete:
+ FCP_LOG(ERROR) << "Unsupported request type";
+ return CURLE_UNSUPPORTED_PROTOCOL;
+ }
+
+ return InitializeHeaders(request_->extra_headers(), request_->method());
+}
+
+CURLcode CurlHttpRequestHandle::InitializeHeaders(
+ const HeaderList& extra_headers, HttpRequest::Method method) {
+ // If no "Accept-Encoding" request header is explicitly specified
+ // advertise an "Accept-Encoding: gzip" else leave decoded.
+ std::optional<std::string> accept_encoding =
+ FindHeader(request_->extra_headers(), kAcceptEncodingHdr);
+ if (!accept_encoding.has_value()) {
+ // Libcurl is responsible for the encoding.
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_ACCEPT_ENCODING, kGzipEncodingHdrValue));
+ header_parser_.UseCurlEncoding();
+ } else {
+ // The caller is responsible for the encoding.
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_ACCEPT_ENCODING, nullptr));
+ }
+
+ for (auto& [key, value] : extra_headers) {
+ if (absl::EqualsIgnoreCase(key, kAcceptEncodingHdr)) {
+ continue;
+ } else if (absl::EqualsIgnoreCase(key, kContentLengthHdr)) {
+ if (method == HttpRequest::Method::kPost) {
+ // For post less than 2GB
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_POSTFIELDSIZE, std::stol(value)));
+
+ // Removes the header to prevent libcurl from setting it
+ // to 'Expect: 100-continue' by default, which causes an additional
+ // and unnecessary network round trip.
+ header_list_ = AddToCurlHeaderList(header_list_, kExpectHdr, "");
+ } else if (method == HttpRequest::Method::kPut) {
+ CURL_RETURN_IF_ERROR(
+ easy_handle_->SetOpt(CURLOPT_INFILESIZE, std::stol(value)));
+ header_list_ = AddToCurlHeaderList(header_list_, kExpectHdr, "");
+ }
+ } else {
+ // A user-defined "Expect" header is not supported.
+ FCP_CHECK(!absl::EqualsIgnoreCase(key, kExpectHdr));
+ header_list_ = AddToCurlHeaderList(header_list_, key, value);
+ }
+ }
+
+ CURL_RETURN_IF_ERROR(easy_handle_->SetOpt(CURLOPT_HTTPHEADER, header_list_));
+ return CURLE_OK;
+}
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_http_request_handle.h b/fcp/client/http/curl/curl_http_request_handle.h
new file mode 100644
index 0000000..1636e44
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_request_handle.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_HTTP_CURL_CURL_HTTP_REQUEST_HANDLE_H_
+#define FCP_CLIENT_HTTP_CURL_CURL_HTTP_REQUEST_HANDLE_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "curl/curl.h"
+#include "fcp/client/http/curl/curl_api.h"
+#include "fcp/client/http/curl/curl_header_parser.h"
+#include "fcp/client/http/http_client.h"
+
+namespace fcp::client::http::curl {
+
+// A thread-safe curl-based implementation. Designed to be used with
+// CurlHttpClient.
+class CurlHttpRequestHandle : public HttpRequestHandle {
+ public:
+ // If non-empty, `test_cert_path` specifies the path to the Certificate
+ // Authority (CA) bundle to use instead of the system defaults.
+ CurlHttpRequestHandle(std::unique_ptr<HttpRequest> request,
+ std::unique_ptr<CurlEasyHandle> easy_handle,
+ const std::string& test_cert_path);
+ ~CurlHttpRequestHandle() override;
+ CurlHttpRequestHandle(const CurlHttpRequestHandle&) = delete;
+ CurlHttpRequestHandle& operator=(const CurlHttpRequestHandle&) = delete;
+
+ // Adds this request to the corresponding multi-handle that can execute
+ // multiple requests in parallel. The corresponding callbacks will be
+ // called accordingly.
+ absl::Status AddToMulti(CurlMultiHandle* multi_handle,
+ HttpRequestCallback* callback)
+ ABSL_LOCKS_EXCLUDED(mutex_);
+ // Removes this request from the corresponding multi-handle.
+ void RemoveFromMulti(CurlMultiHandle* multi_handle)
+ ABSL_LOCKS_EXCLUDED(mutex_);
+ // Marks the request as completed which fires the OnComplete callback.
+ void MarkAsCompleted() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ // HttpRequestHandle overrides:
+ ABSL_MUST_USE_RESULT HttpRequestHandle::SentReceivedBytes
+ TotalSentReceivedBytes() const override ABSL_LOCKS_EXCLUDED(mutex_);
+ void Cancel() override ABSL_LOCKS_EXCLUDED(mutex_);
+
+ private:
+ // Initializes the easy_handle_ in the constructor.
+ CURLcode InitializeConnection(const std::string& test_cert_path)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ // Initializes headers from external_headers
+ CURLcode InitializeHeaders(const HeaderList& extra_headers,
+ HttpRequest::Method method)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ // A helper function called _sequentially_ when a response header received.
+ static size_t HeaderCallback(char* buffer, size_t size, size_t n_items,
+ void* user_data) ABSL_NO_THREAD_SAFETY_ANALYSIS;
+ // A helper function called _sequentially_ when a chunk of a body received
+ static size_t DownloadCallback(void* body, size_t size, size_t nmemb,
+ void* user_data)
+ ABSL_NO_THREAD_SAFETY_ANALYSIS;
+ // A helper function called _sequentially_ to send a chunk of a body.
+ static size_t UploadCallback(char* buffer, size_t size, size_t num,
+ void* user_data) ABSL_NO_THREAD_SAFETY_ANALYSIS;
+
+ // Called periodically. Used to cancel the request.
+ static size_t ProgressCallback(void* user_data, curl_off_t dltotal,
+ curl_off_t dlnow, curl_off_t ultotal,
+ curl_off_t ulnow) ABSL_LOCKS_EXCLUDED(mutex_);
+
+ mutable absl::Mutex mutex_;
+ const std::unique_ptr<HttpRequest> request_ ABSL_GUARDED_BY(mutex_);
+ std::unique_ptr<HttpResponse> response_ ABSL_GUARDED_BY(mutex_);
+ const std::unique_ptr<CurlEasyHandle> easy_handle_ ABSL_GUARDED_BY(mutex_);
+ // Used only in the HeaderCallback sequentially.
+ CurlHeaderParser header_parser_{};
+ // Owned by the caller. Initialized in AddToMulti and then read-only.
+ HttpRequestCallback* callback_ ABSL_GUARDED_BY(mutex_);
+ bool is_being_performed_ ABSL_GUARDED_BY(mutex_);
+ bool is_completed_ ABSL_GUARDED_BY(mutex_);
+ bool is_cancelled_ ABSL_GUARDED_BY(mutex_);
+ char error_buffer_[CURL_ERROR_SIZE] ABSL_GUARDED_BY(mutex_){};
+ // Owned by the class.
+ curl_slist* header_list_;
+};
+} // namespace fcp::client::http::curl
+
+#endif // FCP_CLIENT_HTTP_CURL_CURL_HTTP_REQUEST_HANDLE_H_
diff --git a/fcp/client/http/curl/curl_http_response.cc b/fcp/client/http/curl/curl_http_response.cc
new file mode 100644
index 0000000..0607995
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_response.cc
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/curl/curl_http_response.h"
+
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp::client::http::curl {
+CurlHttpResponse::CurlHttpResponse(int status_code, HeaderList header_list)
+ : status_code_(status_code), header_list_(std::move(header_list)) {}
+
+int CurlHttpResponse::code() const { return status_code_; }
+
+const HeaderList& CurlHttpResponse::headers() const { return header_list_; }
+
+} // namespace fcp::client::http::curl
diff --git a/fcp/client/http/curl/curl_http_response.h b/fcp/client/http/curl/curl_http_response.h
new file mode 100644
index 0000000..cf92a30
--- /dev/null
+++ b/fcp/client/http/curl/curl_http_response.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_HTTP_CURL_CURL_HTTP_RESPONSE_H_
+#define FCP_CLIENT_HTTP_CURL_CURL_HTTP_RESPONSE_H_
+
+#include <utility>
+
+#include "fcp/client/http/http_client.h"
+
+namespace fcp::client::http::curl {
+// A simple http response. This class is thread-safe.
+class CurlHttpResponse : public HttpResponse {
+ public:
+ CurlHttpResponse(int status_code, HeaderList header_list);
+ ~CurlHttpResponse() override = default;
+
+ // HttpResponse:
+ ABSL_MUST_USE_RESULT int code() const override;
+ ABSL_MUST_USE_RESULT const HeaderList& headers() const override;
+
+ private:
+ const int status_code_;
+ const HeaderList header_list_;
+};
+} // namespace fcp::client::http::curl
+
+#endif // FCP_CLIENT_HTTP_CURL_CURL_HTTP_RESPONSE_H_
diff --git a/fcp/client/http/http_client.h b/fcp/client/http/http_client.h
new file mode 100644
index 0000000..f96e80c
--- /dev/null
+++ b/fcp/client/http/http_client.h
@@ -0,0 +1,468 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_HTTP_CLIENT_H_
+#define FCP_CLIENT_HTTP_HTTP_CLIENT_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+using Header = std::pair<std::string, std::string>;
+// This is a vector of pairs and not a map since multiple request headers with
+// the same name are allowed (see RFC2616 section 4.2).
+using HeaderList = std::vector<Header>;
+
+class HttpRequest; // forward declaration
+class HttpRequestCallback; // forward declaration
+class HttpRequestHandle; // forward declaration
+class HttpResponse; // forward declaration
+
+// An interface that allows the callers to make HTTP requests and receive their
+// responses.
+//
+// Platforms will be required to pass an instance of this class to
+// `RunFederatedComputation(...)`, and such instances must remain alive for at
+// least the duration of those calls.
+//
+// Instances of this class must support being called from any thread. In such
+// cases implementations should ideally handle multiple requests in parallel, at
+// least up to an implementation-defined max number of parallel requests (i.e.
+// ideally two `PerformRequests` calls on different threads can be handled in
+// parallel, rather than having the 2nd call block until the 1st call is
+// completed).
+//
+// Besides the requirements documented further below, the following high-level
+// behavior is required of any `HttpClient` implementation:
+// - Underlying protocols:
+// * Implementations must support at least HTTP/1.1 over TLS 1.2.
+// * Implementations are also allowed to serve requests using HTTP/2, QUIC or
+// other newer protocols.
+// * Implementations must support both IPv4 and IPv6 (but they are allowed to
+// fall back to IPv4).
+// - Certificate validation:
+// * Implementations are responsible for TLS certificate validation and for
+// maintaining an up-to-date set of root certificates as well as an
+// up-to-date HTTP/TLS implementation.
+// * Implementations should not include user-added CAs, and should consider
+// restricting the set of CAs further to only those needed for connecting to
+// the expected endpoints.
+// - Cookies:
+// * Implementations must not supply any cookies in requests (beyond those
+// that may be specified in the `HttpRequest::headers()` method).
+// * Implementations must not store any cookies returned by the server.
+// * Instead, they must return any server-specified "Set-Cookie" response
+// header via `HttpResponse::headers()`.
+// - Redirects:
+// * Implementations must follow HTTP redirects responses, up to an
+// implementation-defined maximum.
+// * In such cases the response headers & body returned via the interfaces
+// below should be those of the final response.
+// * See `HttpRequestCallback` docs below for more details.
+// - Caching:
+// * Implementations should not implement a cache as it is expected that
+// naive HTTP-level caching will not be effective (and since a cache may
+// ultimately be implemented over this interface, in the Federated Compute
+// library itself).
+// * If implementations do implement one, however, they are expected to abide
+// by the standard HTTP caching rules (see the `HttpRequest::Method` docs
+// for more details).
+// - Response body decompression & decoding:
+// * If no "Accept-Encoding" request header is explicitly specified in
+// `HttpRequest::headers()`, then implementations must advertise an
+// "Accept-Encoding" request header themselves whose value includes at
+// least "gzip" (additional encodings are allowed to be specified in
+// addition to "gzip"), and must transparently decompress any compressed
+// server responses before returning the data via these interfaces.
+// Implementations are also allowed to advertise/support additional
+// encoding methods.
+// * In such cases where no "Accept-Encoding" header is specified,
+// implementations must remove the "Content-Encoding" and
+// "Content-Length" headers from headers returned via
+// `HttpResponse::headers()` (since those wouldn't reflect the payload
+// delivered via this interface).
+// * However, if an "Accept-Encoding" request header *is* explicitly
+// specified, then implementations must use that header verbatim and they
+// must not decompress the response (even if they natively support the
+// compression method), and they must leave the "Content-Encoding" and
+// "Content-Length" headers intact.
+// * This ensures that the caller of this interface can take full control of
+// the decompression and/or choose to store decompressed payloads on disk
+// if it so chooses.
+// * Implementations must transparently decode server responses served with
+// "Transfer-Encoding: chunked". In such cases they must remove the
+// "Transfer-Encoding" response header.
+// - Request body compression & encoding:
+// * If implementations receive a "Content-Encoding" request header, this
+// means that the request body stream they receive has already been
+// compressed. The implementation must leave the header and request body
+// intact in such cases (i.e. not re-compress it).
+// * If implementations receive a "Content-Length" request header, they must
+// use it verbatim and they should then assume that the request body will
+// be of exactly that size.
+// * If they do not receive such a header then they must use the
+// "Transfer-encoding: chunked" mechanism to transmit the request body
+// (i.e. they shouldn't specify a "Content-Length" header and they should
+// transmit the body in chunks), or use an equivalent method of streaming
+// the data (such as HTTP/2's data streaming).
+class HttpClient {
+ public:
+ virtual ~HttpClient() = default;
+
+ // Enqueues an HTTP request, without starting it yet. To start the request the
+ // `HttpRequestHandle` must be passed to `PerformRequests`. Each
+ // `HttpRequestHandle` must be passed to at most one `PerformRequests` call.
+ //
+ // The `HttpClient` implementation assumes ownership of the `HttpRequest`
+ // object, and the implementation must delete the object when the
+ // `HttpRequestHandle` is deleted.
+ ABSL_MUST_USE_RESULT
+ virtual std::unique_ptr<HttpRequestHandle> EnqueueRequest(
+ std::unique_ptr<HttpRequest> request) = 0;
+
+ // Performs the given requests. Results will be returned to each
+ // corresponding `HttpRequestCallback` while this method is blocked. This
+ // method must block until all requests have finished or have been cancelled,
+ // and until all corresponding request callbacks have returned.
+ //
+ // By decoupling the enqueueing and starting of (groups of) requests,
+ // implementations may be able to handle concurrent requests more optimally
+ // (e.g. by issuing them over a shared HTTP connection). Having separate
+ // per-request `HttpRequestHandle` objects also makes it easier to support
+ // canceling specific requests, releasing resources for specific requests,
+ // accessing stats for specific requests, etc.
+ //
+ // The `HttpRequestHandle` and `HttpRequestCallback` instances must outlive
+ // the call to `PerformRequests`, but may be deleted any time after this call
+ // has returned.
+ //
+ // Returns an `INVALID_ARGUMENT` error if a `HttpRequestHandle` was previously
+ // already passed to another `PerformRequests` call, or if an
+ // `HttpRequestHandle`'s `Cancel` method was already called before being
+ // passed to this call.
+ virtual absl::Status PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>>
+ requests) = 0;
+};
+
+// An HTTP request for a single resource. Implemented by the caller of
+// `HttpClient`.
+//
+// Once instances are passed to `EnqueueRequest`, their lifetime is managed by
+// the `HttpClient` implementation. Implementations must tie the `HttpRequest`
+// instance lifetime to the lifetime of the `HttpRequestHandle` they return
+// (i.e. they should delete the `HttpRequest` from the `HttpRequestHandle`
+// destructor).
+//
+// Methods of this class may get called from any thread (and subsequent calls
+// are not required to all happen on the same thread).
+class HttpRequest {
+ public:
+ // Note: the request methods imply a set of standard request properties such
+ // as cacheability, safety, and idempotency. See
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods.
+ //
+ // The caller of `HttpClient` may implement its own caching layer in the
+ // future, so implementations are not expected to cache cacheable requests
+ // (although they are technically allowed to).
+ //
+ // Implementations should not automatically retry requests, even if the
+ // request method implies it is safe or idempotent. The caller of `HttpClient`
+ // will own the responsibility for retrying requests.
+ enum class Method { kHead, kGet, kPost, kPut, kPatch, kDelete };
+
+ // Must not be called until any corresponding `HttpRequestHandle` has been
+ // deleted.
+ virtual ~HttpRequest() = default;
+
+ // The URI to request. Will always have an "https://" scheme (but this may be
+ // extended in the future).
+ virtual absl::string_view uri() const = 0;
+
+ // The HTTP method to use for this request.
+ virtual Method method() const = 0;
+
+ // Extra request headers to include with this request, in addition to any
+ // headers specified by the `HttpClient` implementation.
+ //
+ // See the `HttpClient` comment for the expected behavior w.r.t. a few
+ // specific headers.
+ virtual const HeaderList& extra_headers() const = 0;
+
+ // Returns true if the request has a request body (which can be read using
+ // `ReadBody`). If the request body payload size is known ahead of time, then
+ // the "Content-Length" header will be set in `extra_headers()`. If it isn't
+ // known yet then the `HttpClient` implementation should use the
+ // "Transfer-Encoding: chunked" encoding to transmit the request body to the
+ // server in chunks (or use an equivalent method of streaming the data, e.g.
+ // if the connection uses HTTP/2). See the `HttpClient` comment for more
+ // details.
+ virtual bool HasBody() const = 0;
+
+ // HttpRequests that up to `requested` bytes of the request body be read into
+ // `buffer`, and that the actual amount of bytes read is returned. The caller
+ // retains ownership of the buffer.
+ //
+ // Callees must return at least 1 byte, but may otherwise return less than the
+ // requested amount of data, if more data isn't available yet. Callees should
+ // return `OUT_OF_RANGE` when the end of data has been reached, in which case
+ // `buffer` should not be modified.
+ //
+ // Callees should return data ASAP, as delaying this for too long may cause
+ // the network stream to fall idle/run out of data to transmit.
+ //
+ // May also return other errors, in which case the request will be ended and
+ // `HttpRequestCallback::OnResponseError` will be called with the same error.
+ virtual absl::StatusOr<int64_t> ReadBody(char* buffer, int64_t requested) = 0;
+};
+
+// A handle to a pending `HttpRequest`, allowing a caller of `HttpClient` to
+// access stats for the request or to cancel ongoing requests. Implemented by
+// the `HttpClient` implementer.
+//
+// The lifetimes of instances of this class are owned by the caller of
+// `HttpClient`.
+//
+// Methods of this class may get called from any thread (and subsequent calls
+// are not required to all happen on the same thread).
+class HttpRequestHandle {
+ public:
+ // When this is called, `HttpClient` implementations should delete all their
+ // owned resources as well as the associated `HttpRequest`.
+ virtual ~HttpRequestHandle() = default;
+
+ // The total amount of data sent/received over the network for this request up
+ // to this point. These numbers should reflect as close as possible the amount
+ // of bytes sent "over the wire". This means, for example, that if the data is
+ // compressed or if a `Transfer-Encoding` is used, the numbers should reflect
+ // the compressed and/or encoded size of the data (if the implementation is
+ // able to account for that). Implementations are allowed to account for the
+ // overhead of TLS encoding in these numbers, but are not required to (since
+ // many HTTP libraries also do not provide stats at that level of
+ // granularity).
+ //
+ // If the request was served from a cache then this should reflect only the
+ // actual bytes sent over the network (e.g. 0 if returned from disk directly,
+ // or if a cache validation request was sent only those bytes used by the
+ // validation request/response).
+ //
+ // If the request involved redirects, the numbers returned here should include
+ // the bytes sent/received for those redirects, if the implementation supports
+ // this. Otherwise they are allowed to reflect only the final
+ // request/response's bytes sent/received.
+ //
+ // Implementations should strive to return as up-to-date numbers are possible
+ // from these methods (e.g. ideally the 'sent' number should reflect the
+ // amount of request body data that has been uploaded so far, even if the
+ // upload hasn't completed fully yet; similarly the 'received' number should
+ // reflect the amount of response body data received so far, even if the
+ // response hasn't been fully received yet).
+ //
+ // The numbers returned here are not required to increase monotonically
+ // between each call to the method. E.g. implementations are allowed to return
+ // best-available estimates while the request is still in flight, and then
+ // revise the numbers down to a more accurate number once the request has been
+ // completed.
+ struct SentReceivedBytes {
+ int64_t sent_bytes;
+ int64_t received_bytes;
+ };
+ virtual SentReceivedBytes TotalSentReceivedBytes() const = 0;
+
+ // Used to indicate that the request should be cancelled and that
+ // implementations may release resources associated with this request (e.g.
+ // the socket used by the request).
+ //
+ // Callers are still only allowed to delete this instance once after any
+ // corresponding `PerformRequests()` call has completed, and not before.
+ //
+ // If a `PerformRequests` call is ongoing for this handle, then the
+ // corresponding `HttpRequestCallback` instance may still receive further
+ // method invocations after this call returns (e.g. because an invocation may
+ // already have been in flight).
+ //
+ // If a `PerformRequests` call is ongoing for this handle, and if the
+ // `HttpRequestCallback::OnResponseStarted` method was not called yet, then
+ // the `HttpRequestCallback::OnResponseError` method must be called with
+ // status `CANCELLED`.
+ //
+ // Otherwise, if a `PerformRequests` call is ongoing for this handle, and if
+ // the `HttpRequestCallback::OnResponseCompleted` method was not called yet,
+ // then the `HttpRequestCallback::OnResponseBodyError` method must be called
+ // with status `CANCELLED`.
+ virtual void Cancel() = 0;
+};
+
+// The callback interface that `HttpClient` implementations must use to deliver
+// the response to a `HttpRequest`. Implemented by the caller of `HttpClient`.
+//
+// The lifetimes of instances of this class are owned by the caller of
+// `HttpClient`. Instances must remain alive for at least as long as their
+// corresponding `PerformRequests` call.
+//
+// Methods of this class may get called from any thread (incl. concurrently),
+// but callers of this class must always call the callback methods for a
+// specific `HttpRequest` in the order specified in each method's documentation.
+// Implementations of this class therefore likely should use internal
+// synchronization.
+//
+// For example, a call to `OnResponseBody` for a given `HttpRequest` A will
+// always be preceded by a completed call to `OnResponseStarted` for that same
+// request A. However, callbacks for different `HttpRequest` objects may happen
+// concurrently, so for example, `OnResponseStarted` may be called concurrently
+// for two different requests A and B. This latter scenario means that if the
+// same `HttpRequestCallback` object is used to handle callbacks for both
+// requests, then the object has to handle concurrent calls correctly.
+class HttpRequestCallback {
+ public:
+ virtual ~HttpRequestCallback() = default;
+
+ // Called when the final HTTP response headers have been received (i.e. after
+ // any redirects have been followed but before the response body may have been
+ // received fully) for the given `HttpRequest`. The response data can be
+ // accessed via the given `HttpResponse`, which will remain alive for the
+ // lifetime of the corresponding `HttpRequestHandle`.
+ //
+ // Note that all the data in the `HttpResponse` object should reflect the
+ // last/final response (i.e. it shouldn't reflect any already-followed
+ // redirects).
+ //
+ // If the response has a body then after this method is called
+ // `OnResponseBody` will be called one or more times to deliver the response
+ // body (or `OnResponseBodyError` if an error occurs).
+ //
+ // Note that responses with an HTTP status code other than 200 ("OK") may
+ // still have response bodies, and implementations must deliver these via the
+ // `OnResponseBody` callback, just as they should for a successful response.
+ //
+ // If this method returns an error then the `HttpClient` implementation should
+ // consider the `HttpRequest` canceled. No further methods must be called on
+ // this `HttpRequestCallback` instance for the given `HttpRequest` after in
+ // this case.
+ virtual absl::Status OnResponseStarted(const HttpRequest& request,
+ const HttpResponse& response) = 0;
+
+ // Called when the request encountered an error or timed out, before receiving
+ // the response headers completely. No further methods must be called on this
+ // `HttpRequestCallback` instance for the given `HttpRequest` after this
+ // method is called.
+ //
+ // If the implementation is able to discern that the error may have been
+ // transient, they should return `UNAVAILABLE`.
+ //
+ // If more than the implementation's defined max number of redirects occurred
+ // (without reaching the final response), then implementations should return
+ // `OUT_OF_RANGE` here.
+ //
+ // If the implementation hit an implementation-specific timeout (even though
+ // implementations are discouraged from imposing such timeouts), then this
+ // should be `DEADLINE_EXCEEDED`.
+ //
+ // If the `HttpRequestHandle::Cancel` method was called before
+ // `OnResponseStarted` was called for the given `HttpRequest`, then this
+ // method will be called with a `CANCELLED` status.
+ //
+ // If the request's `HttpRequest::ReadBody` returned an unexpected error,
+ // then method will be called with that error.
+ virtual void OnResponseError(const HttpRequest& request,
+ const absl::Status& error) = 0;
+
+ // Called (possibly multiple times per request) when a block of response data
+ // is available in `data`. This method must only be called after
+ // `OnResponseStarted` was called for the given `HttpRequest`.
+ //
+ // Callees must process the data ASAP, as delaying this for too long may
+ // prevent additional data from arriving on the network stream.
+ //
+ // If this method returns an error then the `HttpClient` implementation should
+ // consider the `HttpRequest` canceled. No further methods must be called on
+ // this `HttpRequestCallback` instance for the given `HttpRequest` after in
+ // this case.
+ virtual absl::Status OnResponseBody(const HttpRequest& request,
+ const HttpResponse& response,
+ absl::string_view data) = 0;
+
+ // Called when the request encountered an error or timed out while receiving
+ // the response body (i.e. after `OnResponseStarted` was called). No further
+ // methods must be called on this `HttpRequestCallback` instance for the given
+ // `HttpRequest` after this method is called.
+ //
+ // If the implementation is able to discern that the error may have been
+ // transient, they should return `UNAVAILABLE`.
+ //
+ // If the implementation hit an implementation-specific timeout (even though
+ // implementations are discouraged from imposing such timeouts), then this
+ // should be `DEADLINE_EXCEEDED`.
+ //
+ // If the `HttpRequestHandle::Cancel` method was called before
+ // `OnResponseCompleted` was called for the given `HttpRequest`, then this
+ // method will be called with a `CANCELLED` status.
+ virtual void OnResponseBodyError(const HttpRequest& request,
+ const HttpResponse& response,
+ const absl::Status& error) = 0;
+
+ // Called when the request has completed successfully (i.e. the response
+ // headers were delivered, and if there was a response body then it was also
+ // delivered successfully). Must not be called if one of the error callbacks
+ // was already called for the given `HttpRequest`, and no further methods must
+ // be called on this `HttpRequestCallback` instance for the given
+ // `HttpRequest` after this method is called.
+ virtual void OnResponseCompleted(const HttpRequest& request,
+ const HttpResponse& response) = 0;
+};
+
+// A response to a given `HttpRequest`. Implemented by the `HttpClient`
+// implementer.
+//
+// The lifetimes of instances of this class are managed by the `HttpClient`
+// implementer. Instances of this class must remain alive for at least long as
+// the corresponding `HttpRequestHandle` is alive.
+//
+// Note that all the data in this object should be for the last/final response.
+// I.e. any responses corresponding to redirects should not be reflected here.
+class HttpResponse {
+ public:
+ virtual ~HttpResponse() = default;
+
+ // The response code returned by the server (e.g. 200).
+ virtual int code() const = 0;
+
+ // The response headers. Implementations are allowed to either coalesce
+ // repeated headers using commas (as per RFC2616 section 4.2), or to return
+ // them as separate entries.
+ //
+ // See `HttpClient` comment for the expected behavior w.r.t. a few specific
+ // headers.
+ virtual const HeaderList& headers() const = 0;
+};
+
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_HTTP_CLIENT_H_
diff --git a/fcp/client/http/http_client_util.cc b/fcp/client/http/http_client_util.cc
new file mode 100644
index 0000000..3fce9f4
--- /dev/null
+++ b/fcp/client/http/http_client_util.cc
@@ -0,0 +1,246 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_client_util.h"
+
+#include <algorithm>
+#include <functional>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "absl/strings/substitute.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+// #include "google/rpc/status.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+namespace {
+
+using ::google::internal::federatedcompute::v1::Status;
+
+absl::StatusCode ConvertHttpCodeToStatusCode(int code) {
+ switch (code) {
+ case kHttpBadRequest:
+ return absl::StatusCode::kInvalidArgument;
+ case kHttpForbidden:
+ return absl::StatusCode::kPermissionDenied;
+ case kHttpNotFound:
+ return absl::StatusCode::kNotFound;
+ case kHttpConflict:
+ return absl::StatusCode::kAborted;
+ case kHttpTooManyRequests:
+ return absl::StatusCode::kResourceExhausted;
+ case kHttpClientClosedRequest:
+ return absl::StatusCode::kCancelled;
+ case kHttpGatewayTimeout:
+ return absl::StatusCode::kDeadlineExceeded;
+ case kHttpNotImplemented:
+ return absl::StatusCode::kUnimplemented;
+ case kHttpServiceUnavailable:
+ return absl::StatusCode::kUnavailable;
+ case kHttpUnauthorized:
+ return absl::StatusCode::kUnauthenticated;
+ default: {
+ // Importantly this range ensures that we treat not only "200 OK" as OK,
+ // but also other codes such as "201 Created" etc.
+ if (code >= 200 && code < 300) {
+ return absl::StatusCode::kOk;
+ }
+ if (code >= 400 && code < 500) {
+ return absl::StatusCode::kFailedPrecondition;
+ }
+ if (code >= 500 && code < 600) {
+ return absl::StatusCode::kInternal;
+ }
+ return absl::StatusCode::kUnknown;
+ }
+ }
+}
+
+// Converts a `Status` error code into an `absl::StatusCode`
+// (there is a 1:1 mapping).
+absl::StatusCode ConvertRpcCodeToStatusCode(int code) {
+ switch (code) {
+ case static_cast<int>(absl::StatusCode::kOk):
+ return absl::StatusCode::kOk;
+ case static_cast<int>(absl::StatusCode::kCancelled):
+ return absl::StatusCode::kCancelled;
+ case static_cast<int>(absl::StatusCode::kUnknown):
+ return absl::StatusCode::kUnknown;
+ case static_cast<int>(absl::StatusCode::kInvalidArgument):
+ return absl::StatusCode::kInvalidArgument;
+ case static_cast<int>(absl::StatusCode::kDeadlineExceeded):
+ return absl::StatusCode::kDeadlineExceeded;
+ case static_cast<int>(absl::StatusCode::kNotFound):
+ return absl::StatusCode::kNotFound;
+ case static_cast<int>(absl::StatusCode::kAlreadyExists):
+ return absl::StatusCode::kAlreadyExists;
+ case static_cast<int>(absl::StatusCode::kPermissionDenied):
+ return absl::StatusCode::kPermissionDenied;
+ case static_cast<int>(absl::StatusCode::kResourceExhausted):
+ return absl::StatusCode::kResourceExhausted;
+ case static_cast<int>(absl::StatusCode::kFailedPrecondition):
+ return absl::StatusCode::kFailedPrecondition;
+ case static_cast<int>(absl::StatusCode::kAborted):
+ return absl::StatusCode::kAborted;
+ case static_cast<int>(absl::StatusCode::kOutOfRange):
+ return absl::StatusCode::kOutOfRange;
+ case static_cast<int>(absl::StatusCode::kUnimplemented):
+ return absl::StatusCode::kUnimplemented;
+ case static_cast<int>(absl::StatusCode::kInternal):
+ return absl::StatusCode::kInternal;
+ case static_cast<int>(absl::StatusCode::kUnavailable):
+ return absl::StatusCode::kUnavailable;
+ case static_cast<int>(absl::StatusCode::kDataLoss):
+ return absl::StatusCode::kDataLoss;
+ case static_cast<int>(absl::StatusCode::kUnauthenticated):
+ return absl::StatusCode::kUnauthenticated;
+ default:
+ // This should never be reached, since there should be a 1:1 mapping
+ // between Absl and Google RPC status codes.
+ return absl::StatusCode::kUnknown;
+ }
+}
+
+absl::StatusOr<std::string> PercentEncode(
+ absl::string_view input, std::function<bool(char c)> unencoded_chars) {
+ std::string result;
+ for (unsigned char c : input) {
+ // We limit URIs only to ASCII characters.
+ if (!absl::ascii_isascii(c)) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Encountered unsupported char during URI encoding: ", c));
+ }
+ // The following characters are *not* percent-encoded.
+ if (unencoded_chars(c)) {
+ result.push_back(c);
+ continue;
+ }
+ // Any other character is percent-encoded.
+ result.append(absl::StrFormat("%%%X", c));
+ }
+ return result;
+}
+
+} // namespace
+
+absl::Status ConvertHttpCodeToStatus(int code) {
+ absl::StatusCode status_code = ConvertHttpCodeToStatusCode(code);
+ if (status_code == absl::StatusCode::kOk) {
+ return absl::OkStatus();
+ }
+ std::string error_message =
+ absl::StrCat("Request returned non-OK response (code: ", code, ")");
+ return absl::Status(status_code, error_message);
+}
+
+absl::Status ConvertRpcStatusToAbslStatus(Status rpc_status) {
+ return absl::Status(ConvertRpcCodeToStatusCode(rpc_status.code()),
+ rpc_status.message());
+}
+
+Status ConvertAbslStatusToRpcStatus(absl::Status status) {
+ Status rpc_status;
+ rpc_status.set_code(static_cast<int>(status.code()));
+ rpc_status.set_message(std::string(status.message()));
+ return rpc_status;
+}
+
+std::string ConvertMethodToString(HttpRequest::Method method) {
+ switch (method) {
+ case HttpRequest::Method::kGet:
+ return "GET";
+ case HttpRequest::Method::kHead:
+ return "HEAD";
+ case HttpRequest::Method::kDelete:
+ return "DELETE";
+ case HttpRequest::Method::kPatch:
+ return "PATCH";
+ case HttpRequest::Method::kPost:
+ return "POST";
+ case HttpRequest::Method::kPut:
+ return "PUT";
+ }
+}
+
+std::optional<std::string> FindHeader(const HeaderList& headers,
+ absl::string_view needle) {
+ // Normalize the needle (since header names are case insensitive, as per RFC
+ // 2616 section 4.2).
+ const std::string normalized_needle = absl::AsciiStrToLower(needle);
+ const auto& header_entry = std::find_if(
+ headers.begin(), headers.end(), [&normalized_needle](const Header& x) {
+ // AsciiStrToLower safely handles non-ASCII data, and comparing
+ // non-ASCII data w/ our needle is safe as well.
+ return absl::AsciiStrToLower(std::get<0>(x)) == normalized_needle;
+ });
+
+ if (header_entry == headers.end()) {
+ return std::nullopt;
+ }
+ return std::get<1>(*header_entry);
+}
+
+absl::StatusOr<std::string> JoinBaseUriWithSuffix(
+ absl::string_view base_uri, absl::string_view uri_suffix) {
+ if (!uri_suffix.empty() && uri_suffix[0] != '/') {
+ return absl::InvalidArgumentError(
+ "uri_suffix be empty or must have a leading '/'");
+ }
+ // Construct the full URI by joining the base URI we should use with the given
+ // suffix, ensuring that there's always a single '/' in between the two parts.
+ return absl::StrCat(absl::StripSuffix(base_uri, "/"), "/",
+ absl::StripPrefix(uri_suffix, "/"));
+}
+
+absl::StatusOr<std::string> EncodeUriSinglePathSegment(
+ absl::string_view input) {
+ return PercentEncode(input, [](char c) {
+ return absl::ascii_isalnum(c) || c == '-' || c == '_' || c == '.' ||
+ c == '~';
+ });
+}
+
+absl::StatusOr<std::string> EncodeUriMultiplePathSegments(
+ absl::string_view input) {
+ return PercentEncode(input, [](char c) {
+ return absl::ascii_isalnum(c) || c == '-' || c == '_' || c == '.' ||
+ c == '~' || c == '/';
+ });
+}
+
+absl::StatusOr<std::string> CreateByteStreamUploadUriSuffix(
+ absl::string_view resource_name) {
+ constexpr absl::string_view pattern = "/upload/v1/media/$0";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_resource_name,
+ EncodeUriMultiplePathSegments(resource_name));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_resource_name);
+}
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/http_client_util.h b/fcp/client/http/http_client_util.h
new file mode 100644
index 0000000..9d73677
--- /dev/null
+++ b/fcp/client/http/http_client_util.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_HTTP_CLIENT_UTIL_H_
+#define FCP_CLIENT_HTTP_HTTP_CLIENT_UTIL_H_
+
+#include <optional>
+#include <string>
+
+// #include "google/rpc/status.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+
+namespace fcp::client::http {
+
+inline static constexpr char kHttpsScheme[] = "https://";
+inline static constexpr char kLocalhostUri[] = "http://localhost:";
+inline static constexpr char kAcceptEncodingHdr[] = "Accept-Encoding";
+inline static constexpr char kContentLengthHdr[] = "Content-Length";
+inline static constexpr char kContentEncodingHdr[] = "Content-Encoding";
+inline static constexpr char kContentTypeHdr[] = "Content-Type";
+inline static constexpr char kExpectHdr[] = "Expect";
+inline static constexpr char kTransferEncodingHdr[] = "Transfer-Encoding";
+inline static constexpr char kApiKeyHdr[] = "x-goog-api-key";
+// The "Transfer-Encoding" header value when the header is present but indicates
+// that no encoding was actually applied.
+inline static constexpr char kIdentityEncodingHdrValue[] = "identity";
+inline static constexpr char kGzipEncodingHdrValue[] = "gzip";
+inline static constexpr char kProtobufContentType[] = "application/x-protobuf";
+
+// A non-exhaustive enumeration of common HTTP response codes.
+// Note this is purposely *not* an "enum class", to allow easy comparisons
+// against the int codes returned by `HttpResponse`.
+enum HttpResponseCode {
+ kHttpOk = 200,
+ kHttpMovedPermanently = 301,
+ kHttpBadRequest = 400,
+ kHttpUnauthorized = 401,
+ kHttpForbidden = 403,
+ kHttpNotFound = 404,
+ kHttpConflict = 409,
+ kHttpTooManyRequests = 429,
+ kHttpClientClosedRequest = 499,
+ kHttpInternalServerError = 500,
+ kHttpNotImplemented = 501,
+ kHttpServiceUnavailable = 503,
+ kHttpGatewayTimeout = 504,
+};
+
+// Converts an HTTP response code into an `absl::Status` (incl. an error message
+// with the original HTTP code).
+absl::Status ConvertHttpCodeToStatus(int code);
+
+// Converts a `::google::internal::federatedcompute::v1::Status` into an
+// `absl::Status`.
+absl::Status ConvertRpcStatusToAbslStatus(
+ ::google::internal::federatedcompute::v1::Status rpc_status);
+
+// Converts an `absl::Status` into a `google::rpc::Status`.
+::google::internal::federatedcompute::v1::Status ConvertAbslStatusToRpcStatus(
+ absl::Status status);
+
+// Converts the method enum to a string.
+std::string ConvertMethodToString(HttpRequest::Method method);
+
+// Finds the header value for header with name `needle` in a list of headers
+// (incl. normalizing the header names to lowercase before doing any
+// comparisons). Note that this returns the first matching header value (rather
+// than coalescing repeated header values as per RFC2616 section 4.2), so it
+// must only be used for headers for which only a single value is expected.
+// Returns an empty optional if no header value was found.
+std::optional<std::string> FindHeader(const HeaderList& headers,
+ absl::string_view needle);
+
+// Creates a URI out of a base URI and a suffix.
+//
+// The `base_uri` argument is expected to be a valid fully qualified URI on its
+// own (i.e. having non-empty scheme and authority/host segments, and possibly a
+// path segment as well), although this function does not validate this. It may
+// or may not end with a trailing '/'. may or may not
+//
+// The `uri_suffix` argument must always either be empty or start with a
+// leading '/'.
+//
+// Returns a URI formed by joining the two arguments, ensuring there is
+// always a single '/' in between the two parts. E.g. if both `base_uri` ends
+// with a '/' and `uri_suffix` start with a '/', then the two '/' characters
+// will be normalized into a single one. If `uri_suffix` is empty, then the
+// resulting URI will always end in a '/'.
+absl::StatusOr<std::string> JoinBaseUriWithSuffix(absl::string_view base_uri,
+ absl::string_view uri_suffix);
+
+// URI-encodes the input, for use a *single path segment* in a URI. This means
+// that '/' characters *are* escaped.
+//
+// See "exactly one path segment" in the "Path template syntax" section in
+// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto.
+//
+// Note that only ASCII strings are accepted (others will return
+// `INVALID_ARGUMENT`). This is stricter than the http.proto spec requires.
+absl::StatusOr<std::string> EncodeUriSinglePathSegment(absl::string_view input);
+
+// URI-encodes the input, for use as *multiple path segments* in a URI. This
+// means that '/' characters *are not* escaped.
+//
+// See "multiple path segments" in the "Path template syntax" section in
+// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto.
+//
+// Note that only ASCII strings are accepted (others will return
+// `INVALID_ARGUMENT`). This is stricter than the http.proto spec requires.
+absl::StatusOr<std::string> EncodeUriMultiplePathSegments(
+ absl::string_view input);
+
+// Create a ByteStream upload URI suffix based on the resource name.
+// Returns INVALID_ARGUMENT when the resource name cannot be URI-encoded.
+absl::StatusOr<std::string> CreateByteStreamUploadUriSuffix(
+ absl::string_view resource_name);
+} // namespace fcp::client::http
+
+#endif // FCP_CLIENT_HTTP_HTTP_CLIENT_UTIL_H_
diff --git a/fcp/client/http/http_client_util_test.cc b/fcp/client/http/http_client_util_test.cc
new file mode 100644
index 0000000..a1a2d5d
--- /dev/null
+++ b/fcp/client/http/http_client_util_test.cc
@@ -0,0 +1,323 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_client_util.h"
+
+#include <optional>
+#include <string>
+
+#include "google/rpc/status.pb.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client::http {
+namespace {
+
+using ::fcp::IsCode;
+using ::testing::HasSubstr;
+using ::testing::Optional;
+using ::testing::StrEq;
+
+TEST(ConvertHttpCodeToStatusTest, ConvertsKnownCodesCorrectly) {
+ EXPECT_OK(ConvertHttpCodeToStatus(kHttpOk));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpBadRequest),
+ IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpForbidden),
+ IsCode(PERMISSION_DENIED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpNotFound), IsCode(NOT_FOUND));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpConflict), IsCode(ABORTED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpTooManyRequests),
+ IsCode(RESOURCE_EXHAUSTED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpClientClosedRequest),
+ IsCode(CANCELLED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpGatewayTimeout),
+ IsCode(DEADLINE_EXCEEDED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpNotImplemented),
+ IsCode(UNIMPLEMENTED));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpServiceUnavailable),
+ IsCode(UNAVAILABLE));
+ EXPECT_THAT(ConvertHttpCodeToStatus(kHttpUnauthorized),
+ IsCode(UNAUTHENTICATED));
+}
+
+TEST(ConvertHttpCodeToStatusTest, ConvertsUnknown200CodesToOk) {
+ EXPECT_OK(ConvertHttpCodeToStatus(201));
+ EXPECT_OK(ConvertHttpCodeToStatus(210));
+ EXPECT_OK(ConvertHttpCodeToStatus(299));
+}
+
+TEST(ConvertHttpCodeToStatusTest, ConvertsUnknown400CodesToFailedPrecondition) {
+ // Note: 400, 401, and 499 are known errors codes that map to other values. We
+ // hence test a few other values in the 400 range that aren't "known".
+ EXPECT_THAT(ConvertHttpCodeToStatus(402), IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(ConvertHttpCodeToStatus(405), IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(ConvertHttpCodeToStatus(410), IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(ConvertHttpCodeToStatus(498), IsCode(FAILED_PRECONDITION));
+}
+
+TEST(ConvertHttpCodeToStatusTest, ConvertsUnknown500CodesToInternal) {
+ // note: 501 is a known error code that maps to other values. We hence test
+ // 502 instead.
+ EXPECT_THAT(ConvertHttpCodeToStatus(500), IsCode(INTERNAL));
+ EXPECT_THAT(ConvertHttpCodeToStatus(502), IsCode(INTERNAL));
+ EXPECT_THAT(ConvertHttpCodeToStatus(510), IsCode(INTERNAL));
+ EXPECT_THAT(ConvertHttpCodeToStatus(599), IsCode(INTERNAL));
+}
+
+TEST(ConvertHttpCodeToStatusTest, ConvertsUnknownAllOtherCodesToUnknown) {
+ EXPECT_THAT(ConvertHttpCodeToStatus(300), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(301), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(310), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(399), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(0), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(1), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(10), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(99), IsCode(UNKNOWN));
+ EXPECT_THAT(ConvertHttpCodeToStatus(600), IsCode(UNKNOWN));
+}
+
+// Test to ensure that when we map an 'unknown' HTTP response to its fallback
+// catch-all StatusCode, we keep the original error code in the error message
+// (to aid in debugging).
+TEST(ConvertHttpCodeToStatusTest, IncludesOriginalErrorCodeInMessage) {
+ EXPECT_THAT(ConvertHttpCodeToStatus(400).message(), HasSubstr("code: 400"));
+ EXPECT_THAT(ConvertHttpCodeToStatus(402).message(), HasSubstr("code: 402"));
+ EXPECT_THAT(ConvertHttpCodeToStatus(502).message(), HasSubstr("code: 502"));
+ EXPECT_THAT(ConvertHttpCodeToStatus(503).message(), HasSubstr("code: 503"));
+ EXPECT_THAT(ConvertHttpCodeToStatus(300).message(), HasSubstr("code: 300"));
+}
+
+void ExpectRpcStatusMatchAbslStatus(absl::StatusCode code) {
+ ::google::rpc::Status rpc_status;
+ *rpc_status.mutable_message() = "the_message";
+ rpc_status.set_code(static_cast<int>(code));
+ absl::Status converted_status = ConvertRpcStatusToAbslStatus(rpc_status);
+ EXPECT_THAT(converted_status, IsCode(code));
+ // OK Status objects always have empty messages.
+ EXPECT_EQ(converted_status.message(),
+ code == absl::StatusCode::kOk ? "" : "the_message");
+}
+
+TEST(ConvertRpcStatusToAbslStatusTest, ValidCodesShouldConvertSuccessfully) {
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kOk);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kCancelled);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kUnknown);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kInvalidArgument);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kDeadlineExceeded);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kNotFound);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kAlreadyExists);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kPermissionDenied);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kResourceExhausted);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kFailedPrecondition);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kAborted);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kOutOfRange);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kUnimplemented);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kInternal);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kUnavailable);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kDataLoss);
+ ExpectRpcStatusMatchAbslStatus(absl::StatusCode::kUnauthenticated);
+}
+
+TEST(ConvertRpcStatusToAbslStatusTest, InvalidCodesShouldConvertToUnknown) {
+ ::google::rpc::Status rpc_status;
+ *rpc_status.mutable_message() = "the_message";
+ rpc_status.set_code(100); // 100 is not a valid status code.
+ absl::Status converted_status = ConvertRpcStatusToAbslStatus(rpc_status);
+ EXPECT_THAT(converted_status, IsCode(UNKNOWN));
+ EXPECT_EQ(converted_status.message(), "the_message");
+}
+
+TEST(FindHeaderTest, DoesNotFindMissingHeader) {
+ EXPECT_EQ(FindHeader({}, "Header-One"), std::nullopt);
+ EXPECT_EQ(FindHeader({{"Header-Two", "bar"}}, "Header-One"), std::nullopt);
+ EXPECT_EQ(FindHeader({{"Header-Two", "bar"}, {"Header-Three", "baz"}},
+ "Header-One"),
+ std::nullopt);
+}
+TEST(FindHeaderTest, EmptyNeedleDoesNotFindAnyHeader) {
+ EXPECT_EQ(FindHeader({}, ""), std::nullopt);
+ EXPECT_EQ(FindHeader({{"Header-Two", "bar"}}, ""), std::nullopt);
+}
+
+TEST(FindHeaderTest, FindsHeaderInList) {
+ EXPECT_THAT(FindHeader({{"Header-One", "foo"},
+ {"Header-Two", "bar"},
+ {"Header-Three", "baz"}},
+ "Header-Two"),
+ Optional(StrEq("bar")));
+
+ EXPECT_THAT(FindHeader({{"Header-Two", "BAR"},
+ {"Header-One", "foo"},
+ {"Header-Three", "baz"}},
+ "Header-Two"),
+ Optional(StrEq("BAR")));
+
+ EXPECT_THAT(FindHeader({{"Header-One", "foo"},
+ {"Header-Three", "baz"},
+ {"Header-Two", "BaR"}},
+ "Header-Two"),
+ Optional(StrEq("BaR")));
+
+ EXPECT_THAT(FindHeader({{"Header-Two", "bar"}}, "Header-Two"),
+ Optional(StrEq("bar")));
+}
+
+TEST(FindHeaderTest, FindsHeaderInListDespiteMixedCase) {
+ // In each of these scenarios, the fact that the list or needle are not the
+ // same case should not affect the result.
+ EXPECT_THAT(FindHeader({{"Header-One", "foo"},
+ {"header-two", "bar"},
+ {"Header-Three", "baz"}},
+ "Header-Two"),
+ Optional(StrEq("bar")));
+
+ EXPECT_THAT(FindHeader({{"Header-One", "foo"},
+ {"Header-Two", "bar"},
+ {"Header-Three", "baz"}},
+ "header-two"),
+ Optional(StrEq("bar")));
+}
+
+TEST(FindHeaderTest, ReturnsFirstMatch) {
+ // In each of these scenarios, the first matching header value should be
+ // returned.
+ EXPECT_THAT(
+ FindHeader(
+ {{"Header-One", "foo"}, {"Header-Two", "bar"}, {"Header-Two", "baz"}},
+ "Header-Two"),
+ Optional(StrEq("bar")));
+
+ EXPECT_THAT(
+ FindHeader(
+ {{"Header-One", "foo"}, {"Header-Two", "bar"}, {"Header-Two", "baz"}},
+ "Header-Two"),
+ Optional(StrEq("bar")));
+
+ EXPECT_THAT(
+ FindHeader(
+ {{"Header-One", "foo"}, {"header-two", "bar"}, {"HEADER-TWO", "baz"}},
+ "HEADER-TWO"),
+ Optional(StrEq("bar")));
+}
+
+TEST(JoinBaseUriWithSuffixTest, ReturnsJoinedUri) {
+ // No trailing slash in base URI.
+ auto result = JoinBaseUriWithSuffix("https://foo", "/bar");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo/bar"));
+
+ // Trailing slash in base URI.
+ result = JoinBaseUriWithSuffix("https://foo/", "/bar");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo/bar"));
+
+ // Additional URI components are correctly merged.
+ result = JoinBaseUriWithSuffix("https://foo:123", "/bar/baz");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo:123/bar/baz"));
+
+ result = JoinBaseUriWithSuffix("https://foo:123/", "/bar/baz");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo:123/bar/baz"));
+
+ // Empty suffixes should be allowed.
+ result = JoinBaseUriWithSuffix("https://foo", "");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo/"));
+
+ // Trailing slash in base URI.
+ result = JoinBaseUriWithSuffix("https://foo/", "");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("https://foo/"));
+}
+
+TEST(JoinBaseUriWithSuffixTest, NoLeadingSlashInSuffixFails) {
+ // No leading slash in the URI suffix, should result in error.
+ auto result = JoinBaseUriWithSuffix("https://foo", "bar");
+ EXPECT_THAT(result.status(), IsCode(INVALID_ARGUMENT));
+ result = JoinBaseUriWithSuffix("https://foo/", "bar");
+ EXPECT_THAT(result.status(), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(EncodeUriTest, UnencodedCharsShouldRemainUnencoded) {
+ std::string unencoded_single_path_segment_chars =
+ "-_.~0123456789abcxyzABCXYZ";
+ auto result = EncodeUriSinglePathSegment(unencoded_single_path_segment_chars);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq(unencoded_single_path_segment_chars));
+
+ std::string unencoded_multi_path_segment_chars =
+ "-_.~/01234567899abcxyzABCXYZ";
+ result = EncodeUriMultiplePathSegments(unencoded_multi_path_segment_chars);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq(unencoded_multi_path_segment_chars));
+}
+
+TEST(EncodeUriTest, OtherCharsShouldBeEncoded) {
+ auto result = EncodeUriSinglePathSegment("#?+%/");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("%23%3F%2B%25%2F"));
+
+ // For the "multiple path segments" version the slash should remain unencoded.
+ result = EncodeUriMultiplePathSegments("#?+%/");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("%23%3F%2B%25/"));
+
+ // Non-encodable characters before/in between/after the encodable characters
+ // should remain unencoded.
+ result = EncodeUriSinglePathSegment("abc#?123+%/XYZ");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("abc%23%3F123%2B%25%2FXYZ"));
+
+ result = EncodeUriMultiplePathSegments("abc#?123+%/XYZ");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq("abc%23%3F123%2B%25/XYZ"));
+}
+
+TEST(EncodeUriTest, EmptyStringShouldReturnEmptyString) {
+ auto result = EncodeUriSinglePathSegment("");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq(""));
+
+ // For the "multiple path segments" version the slash should remain unencoded.
+ result = EncodeUriMultiplePathSegments("");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, StrEq(""));
+}
+
+TEST(EncodeUriTest, NonAsciiStringShouldReturnError) {
+ auto result = EncodeUriSinglePathSegment("€");
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+
+ // For the "multiple path segments" version the slash should remain unencoded.
+ result = EncodeUriMultiplePathSegments("€");
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CreateByteStreamUriTest, HappyCase) {
+ auto result = CreateByteStreamUploadUriSuffix("my/resource");
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, "/upload/v1/media/my/resource");
+}
+
+TEST(CreateByteStreamUriTest, NonAsciiResourceNameShouldReturnError) {
+ EXPECT_THAT(CreateByteStreamUploadUriSuffix("€"), IsCode(INVALID_ARGUMENT));
+}
+
+} // namespace
+} // namespace fcp::client::http
diff --git a/fcp/client/http/http_federated_protocol.cc b/fcp/client/http/http_federated_protocol.cc
new file mode 100644
index 0000000..1d854fc
--- /dev/null
+++ b/fcp/client/http/http_federated_protocol.cc
@@ -0,0 +1,1428 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_federated_protocol.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+// #include "google/longrunning/operations.pb.h"
+#include "google/protobuf/any.pb.h"
+// #include "google/rpc/code.pb.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/time_util.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_protocol_util.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/http_secagg_send_to_server_impl.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/parsing_utils.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/federatedcompute/aggregations.pb.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/protos/federatedcompute/task_assignments.pb.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace {
+
+using ::fcp::client::GenerateRetryWindowFromRetryTime;
+using ::fcp::client::GenerateRetryWindowFromTargetDelay;
+using ::fcp::client::PickRetryTimeFromRange;
+using ::google::internal::federatedcompute::v1::AbortAggregationRequest;
+using ::google::internal::federatedcompute::v1::ClientStats;
+using ::google::internal::federatedcompute::v1::EligibilityEvalTaskRequest;
+using ::google::internal::federatedcompute::v1::EligibilityEvalTaskResponse;
+using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
+using ::google::internal::federatedcompute::v1::
+ ReportEligibilityEvalTaskResultRequest;
+using ::google::internal::federatedcompute::v1::ReportTaskResultRequest;
+using ::google::internal::federatedcompute::v1::Resource;
+using ::google::internal::federatedcompute::v1::ResourceCompressionFormat;
+using ::google::internal::federatedcompute::v1::
+ SecureAggregationProtocolExecutionInfo;
+using ::google::internal::federatedcompute::v1::
+ StartAggregationDataUploadRequest;
+using ::google::internal::federatedcompute::v1::
+ StartAggregationDataUploadResponse;
+using ::google::internal::federatedcompute::v1::StartSecureAggregationRequest;
+using ::google::internal::federatedcompute::v1::StartSecureAggregationResponse;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentRequest;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentResponse;
+using ::google::internal::federatedcompute::v1::SubmitAggregationResultRequest;
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+// using ::google::longrunning::Operation;
+
+using CompressionFormat =
+ ::fcp::client::http::UriOrInlineData::InlineData::CompressionFormat;
+
+// Creates the URI suffix for a RequestEligibilityEvalTask protocol request.
+absl::StatusOr<std::string> CreateRequestEligibilityEvalTaskUriSuffix(
+ absl::string_view population_name) {
+ constexpr absl::string_view kRequestEligibilityEvalTaskUriSuffix =
+ "/v1/eligibilityevaltasks/$0:request";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
+ EncodeUriSinglePathSegment(population_name));
+ return absl::Substitute(kRequestEligibilityEvalTaskUriSuffix,
+ encoded_population_name);
+}
+
+// Creates the URI suffix for a ReportEligibilityEvalTaskResult protocol
+// request.
+absl::StatusOr<std::string> CreateReportEligibilityEvalTaskResultUriSuffix(
+ absl::string_view population_name, absl::string_view session_id) {
+ constexpr absl::string_view kReportEligibilityEvalTaskResultUriSuffix =
+ "/v1/populations/$0/eligibilityevaltasks/$1:reportresult";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
+ EncodeUriSinglePathSegment(population_name));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
+ EncodeUriSinglePathSegment(session_id));
+ return absl::Substitute(kReportEligibilityEvalTaskResultUriSuffix,
+ encoded_population_name, encoded_session_id);
+}
+
+// Creates the URI suffix for a StartTaskAssignment protocol request.
+absl::StatusOr<std::string> CreateStartTaskAssignmentUriSuffix(
+ absl::string_view population_name, absl::string_view session_id) {
+ constexpr absl::string_view kStartTaskAssignmentUriSuffix =
+ "/v1/populations/$0/taskassignments/$1:start";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
+ EncodeUriSinglePathSegment(population_name));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
+ EncodeUriSinglePathSegment(session_id));
+ return absl::Substitute(kStartTaskAssignmentUriSuffix,
+ encoded_population_name, encoded_session_id);
+}
+
+// Creates he URI suffix for a ReportTaskResult protocol request.
+absl::StatusOr<std::string> CreateReportTaskResultUriSuffix(
+ absl::string_view population_name, absl::string_view session_id) {
+ constexpr absl::string_view pattern =
+ "/v1/populations/$0/taskassignments/$1:reportresult";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_population_name,
+ EncodeUriSinglePathSegment(population_name));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_session_id,
+ EncodeUriSinglePathSegment(session_id));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_population_name, encoded_session_id);
+}
+
+absl::StatusOr<std::string> CreateStartAggregationDataUploadUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/aggregations/$0/clients/$1:startdataupload";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateSubmitAggregationResultUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern = "/v1/aggregations/$0/clients/$1:submit";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateAbortAggregationUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern = "/v1/aggregations/$0/clients/$1:abort";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateStartSecureAggregationUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:start";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+// Convert a Resource proto into a UriOrInlineData object. Returns an
+// `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
+// an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
+// field set.
+absl::StatusOr<UriOrInlineData> ConvertResourceToUriOrInlineData(
+ const Resource& resource) {
+ switch (resource.resource_case()) {
+ case Resource::ResourceCase::kUri:
+ if (resource.uri().empty()) {
+ return absl::InvalidArgumentError(
+ "Resource.uri must be non-empty when set");
+ }
+ return UriOrInlineData::CreateUri(
+ resource.uri(), resource.client_cache_id(),
+ TimeUtil::ConvertProtoToAbslDuration(resource.max_age()));
+ case Resource::ResourceCase::kInlineResource: {
+ CompressionFormat compression_format = CompressionFormat::kUncompressed;
+ if (resource.inline_resource().has_compression_format()) {
+ switch (resource.inline_resource().compression_format()) {
+ case ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP:
+ compression_format = CompressionFormat::kGzip;
+ break;
+ default:
+ return absl::UnimplementedError(
+ "Unknown ResourceCompressionFormat");
+ }
+ }
+ return UriOrInlineData::CreateInlineData(
+ absl::Cord(resource.inline_resource().data()), compression_format);
+ }
+ case Resource::ResourceCase::RESOURCE_NOT_SET:
+ // If neither field is set at all, we'll just act as if we got an empty
+ // inline data field.
+ return UriOrInlineData::CreateInlineData(
+ absl::Cord(), CompressionFormat::kUncompressed);
+ default:
+ return absl::UnimplementedError("Unknown Resource type");
+ }
+}
+
+::google::internal::federatedcompute::v1::Code ConvertPhaseOutcomeToRpcCode(
+ engine::PhaseOutcome phase_outcome) {
+ switch (phase_outcome) {
+ case engine::PhaseOutcome::COMPLETED:
+ return ::google::internal::federatedcompute::v1::Code::OK;
+ case engine::PhaseOutcome::ERROR:
+ return ::google::internal::federatedcompute::v1::Code::INTERNAL;
+ case engine::PhaseOutcome::INTERRUPTED:
+ return ::google::internal::federatedcompute::v1::Code::CANCELLED;
+ default:
+ return ::google::internal::federatedcompute::v1::Code::UNKNOWN;
+ }
+}
+
+absl::StatusOr<ReportTaskResultRequest> CreateReportTaskResultRequest(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ absl::string_view aggregation_id, absl::string_view task_name) {
+ ReportTaskResultRequest request;
+ request.set_aggregation_id(std::string(aggregation_id));
+ request.set_task_name(std::string(task_name));
+ request.set_computation_status_code(
+ ConvertPhaseOutcomeToRpcCode(phase_outcome));
+ ClientStats* client_stats = request.mutable_client_stats();
+ *client_stats->mutable_computation_execution_duration() =
+ TimeUtil::ConvertAbslToProtoDuration(plan_duration);
+ return request;
+}
+
+// Creates a special InterruptibleRunner which won't check the should_abort
+// function until the timeout duration is passed. This special
+// InterruptibleRunner is used to issue Cancellation requests or Abort requests.
+std::unique_ptr<InterruptibleRunner> CreateDelayedInterruptibleRunner(
+ LogManager* log_manager, std::function<bool()> should_abort,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ absl::Time deadline) {
+ return std::make_unique<InterruptibleRunner>(
+ log_manager,
+ [deadline, should_abort]() {
+ return absl::Now() > deadline && should_abort();
+ },
+ timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
+}
+} // namespace
+
+HttpFederatedProtocol::HttpFederatedProtocol(
+ Clock* clock, LogManager* log_manager, const Flags* flags,
+ HttpClient* http_client,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ SecAggEventPublisher* secagg_event_publisher,
+ absl::string_view entry_point_uri, absl::string_view api_key,
+ absl::string_view population_name, absl::string_view retry_token,
+ absl::string_view client_version, absl::string_view attestation_measurement,
+ std::function<bool()> should_abort, absl::BitGen bit_gen,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ cache::ResourceCache* resource_cache)
+ : object_state_(ObjectState::kInitialized),
+ clock_(*clock),
+ log_manager_(log_manager),
+ flags_(flags),
+ http_client_(http_client),
+ secagg_runner_factory_(std::move(secagg_runner_factory)),
+ secagg_event_publisher_(secagg_event_publisher),
+ interruptible_runner_(std::make_unique<InterruptibleRunner>(
+ log_manager, should_abort, timing_config,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT})),
+ eligibility_eval_request_creator_(
+ std::make_unique<ProtocolRequestCreator>(
+ entry_point_uri, api_key, HeaderList{},
+ !flags->disable_http_request_body_compression())),
+ protocol_request_helper_(http_client, &bytes_downloaded_,
+ &bytes_uploaded_, network_stopwatch_.get(),
+ clock),
+ api_key_(api_key),
+ population_name_(population_name),
+ retry_token_(retry_token),
+ client_version_(client_version),
+ attestation_measurement_(attestation_measurement),
+ should_abort_(std::move(should_abort)),
+ bit_gen_(std::move(bit_gen)),
+ timing_config_(timing_config),
+ waiting_period_for_cancellation_(
+ absl::Seconds(flags->waiting_period_sec_for_cancellation())),
+ resource_cache_(resource_cache) {
+ // Note that we could cast the provided error codes to absl::StatusCode
+ // values here. However, that means we'd have to handle the case when
+ // invalid integers that don't map to a StatusCode enum are provided in the
+ // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
+ // compare them with the flag-provided list of codes, which means we never
+ // have to worry about invalid flag values (besides the fact that invalid
+ // values will be silently ignored, which could make it harder to realize when
+ // a flag is misconfigured).
+ const std::vector<int32_t>& error_codes =
+ flags->federated_training_permanent_error_codes();
+ federated_training_permanent_error_codes_ =
+ absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
+}
+
+absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
+HttpFederatedProtocol::EligibilityEvalCheckin(
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) {
+ FCP_CHECK(object_state_ == ObjectState::kInitialized)
+ << "Invalid call sequence";
+ object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
+
+ // Send the request and parse the response.
+ auto response = HandleEligibilityEvalTaskResponse(
+ PerformEligibilityEvalTaskRequest(), payload_uris_received_callback);
+ // Update the object state to ensure we return the correct retry delay.
+ UpdateObjectStateIfPermanentError(
+ response.status(),
+ ObjectState::kEligibilityEvalCheckinFailedPermanentError);
+ if (response.ok() && std::holds_alternative<EligibilityEvalTask>(*response)) {
+ eligibility_eval_enabled_ = true;
+ }
+ return response;
+}
+
+absl::StatusOr<InMemoryHttpResponse>
+HttpFederatedProtocol::PerformEligibilityEvalTaskRequest() {
+ // Create and serialize the request body. Note that the `population_name`
+ // field is set in the URI instead of in this request proto message.
+ EligibilityEvalTaskRequest request;
+ request.mutable_client_version()->set_version_code(client_version_);
+ request.mutable_attestation_measurement()->set_value(
+ attestation_measurement_);
+
+ request.mutable_resource_capabilities()->add_supported_compression_formats(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ request.mutable_eligibility_eval_task_capabilities()
+ ->set_supports_multiple_task_assignment(
+ flags_->http_protocol_supports_multiple_task_assignments());
+
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateRequestEligibilityEvalTaskUriSuffix(population_name_));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ eligibility_eval_request_creator_->CreateProtocolRequest(
+ uri_suffix, {}, HttpRequest::Method::kPost,
+ request.SerializeAsString(), /*is_protobuf_encoded=*/true));
+
+ // Issue the request.
+ return protocol_request_helper_.PerformProtocolRequest(
+ std::move(http_request), *interruptible_runner_);
+}
+
+absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
+HttpFederatedProtocol::HandleEligibilityEvalTaskResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) {
+ if (!http_response.ok()) {
+ // If the protocol request failed then forward the error, but add a prefix
+ // to the error message to ensure we can easily distinguish an HTTP error
+ // occurring in response to the protocol request from HTTP errors occurring
+ // during checkpoint/plan resource fetch requests later on.
+ return absl::Status(http_response.status().code(),
+ absl::StrCat("protocol request failed: ",
+ http_response.status().ToString()));
+ }
+
+ EligibilityEvalTaskResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(http_response->body))) {
+ return absl::InvalidArgumentError("Could not parse response_proto");
+ }
+
+ // Upon receiving the server's RetryWindows we immediately choose a concrete
+ // target timestamp to retry at. This ensures that a) clients of this class
+ // don't have to implement the logic to select a timestamp from a min/max
+ // range themselves, b) we tell clients of this class to come back at exactly
+ // a point in time the server intended us to come at (i.e. "now +
+ // server_specified_retry_period", and not a point in time that is partly
+ // determined by how long the remaining protocol interactions (e.g. training
+ // and results upload) will take (i.e. "now +
+ // duration_of_remaining_protocol_interactions +
+ // server_specified_retry_period").
+ retry_times_ = RetryTimes{
+ .retry_time_if_rejected = PickRetryTimeFromRange(
+ response_proto.retry_window_if_rejected().delay_min(),
+ response_proto.retry_window_if_rejected().delay_max(), bit_gen_),
+ .retry_time_if_accepted = PickRetryTimeFromRange(
+ response_proto.retry_window_if_accepted().delay_min(),
+ response_proto.retry_window_if_accepted().delay_max(), bit_gen_)};
+
+ // If the request was rejected then the protocol session has ended and there's
+ // no more work for us to do.
+ if (response_proto.has_rejection_info()) {
+ object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
+ return Rejection{};
+ }
+
+ pre_task_assignment_session_id_ = response_proto.session_id();
+
+ FCP_ASSIGN_OR_RETURN(
+ task_assignment_request_creator_,
+ ProtocolRequestCreator::Create(
+ api_key_, response_proto.task_assignment_forwarding_info(),
+ !flags_->disable_http_request_body_compression()));
+
+ switch (response_proto.result_case()) {
+ case EligibilityEvalTaskResponse::kEligibilityEvalTask: {
+ const auto& task = response_proto.eligibility_eval_task();
+
+ EligibilityEvalTask result{.execution_id = task.execution_id()};
+ payload_uris_received_callback(result);
+
+ // Fetch the task resources, returning any errors that may be encountered
+ // in the process.
+ FCP_ASSIGN_OR_RETURN(
+ result.payloads,
+ FetchTaskResources(
+ {.plan = task.plan(), .checkpoint = task.init_checkpoint()}));
+ if (task.has_population_eligibility_spec() &&
+ flags_->http_protocol_supports_multiple_task_assignments()) {
+ FCP_ASSIGN_OR_RETURN(
+ result.population_eligibility_spec,
+ FetchPopulationEligibilitySpec(task.population_eligibility_spec()));
+ }
+
+ object_state_ = ObjectState::kEligibilityEvalEnabled;
+ return std::move(result);
+ }
+ case EligibilityEvalTaskResponse::kNoEligibilityEvalConfigured: {
+ // Nothing to do...
+ object_state_ = ObjectState::kEligibilityEvalDisabled;
+ return EligibilityEvalDisabled{};
+ }
+ default:
+ return absl::UnimplementedError(
+ "Unrecognized EligibilityEvalCheckinResponse");
+ }
+}
+
+absl::StatusOr<std::unique_ptr<HttpRequest>>
+HttpFederatedProtocol::CreateReportEligibilityEvalTaskResultRequest(
+ absl::Status status) {
+ ReportEligibilityEvalTaskResultRequest request;
+ request.set_status_code(
+ static_cast<::google::internal::federatedcompute::v1::Code>(
+ status.code()));
+ FCP_ASSIGN_OR_RETURN(std::string uri_suffix,
+ CreateReportEligibilityEvalTaskResultUriSuffix(
+ population_name_, pre_task_assignment_session_id_));
+ return eligibility_eval_request_creator_->CreateProtocolRequest(
+ uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true);
+}
+
+void HttpFederatedProtocol::ReportEligibilityEvalError(
+ absl::Status error_status) {
+ if (!ReportEligibilityEvalErrorInternal(error_status).ok()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::HTTP_REPORT_ELIGIBILITY_EVAL_RESULT_REQUEST_FAILED);
+ }
+}
+
+absl::Status HttpFederatedProtocol::ReportEligibilityEvalErrorInternal(
+ absl::Status error_status) {
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> request,
+ CreateReportEligibilityEvalTaskResultRequest(error_status));
+ return protocol_request_helper_
+ .PerformProtocolRequest(std::move(request), *interruptible_runner_)
+ .status();
+}
+
+absl::StatusOr<FederatedProtocol::CheckinResult> HttpFederatedProtocol::Checkin(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
+ // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
+ // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result. Or
+ // it must follow a PerformMultipleTaskAssignments(...) regardless of the
+ // outcome for the call.
+ FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
+ object_state_ == ObjectState::kEligibilityEvalEnabled ||
+ object_state_ == ObjectState::kMultipleTaskAssignmentsAccepted ||
+ object_state_ == ObjectState::kMultipleTaskAssignmentsFailed ||
+ object_state_ ==
+ ObjectState::kMultipleTaskAssignmentsFailedPermanentError ||
+ object_state_ ==
+ ObjectState::kMultipleTaskAssignmentsNoAvailableTask)
+ << "Checkin(...) called despite failed/rejected earlier "
+ "EligibilityEvalCheckin";
+ if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
+ FCP_CHECK(task_eligibility_info.has_value())
+ << "Missing TaskEligibilityInfo despite receiving prior "
+ "EligibilityEvalCheckin payload";
+ } else {
+ FCP_CHECK(!task_eligibility_info.has_value())
+ << "Received TaskEligibilityInfo despite not receiving a prior "
+ "EligibilityEvalCheckin payload";
+ }
+ object_state_ = ObjectState::kCheckinFailed;
+
+ // Send the request and parse the response.
+ auto response = HandleTaskAssignmentOperationResponse(
+ PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
+ task_eligibility_info),
+ payload_uris_received_callback);
+
+ // Update the object state to ensure we return the correct retry delay.
+ UpdateObjectStateIfPermanentError(response.status(),
+ ObjectState::kCheckinFailedPermanentError);
+ return response;
+}
+
+absl::StatusOr<InMemoryHttpResponse> HttpFederatedProtocol::
+ PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
+ // Create and serialize the request body. Note that the `population_name`
+ // and `session_id` fields are set in the URI instead of in this request
+ // proto message.
+ StartTaskAssignmentRequest request;
+ request.mutable_client_version()->set_version_code(client_version_);
+
+ if (task_eligibility_info.has_value()) {
+ *request.mutable_task_eligibility_info() = *task_eligibility_info;
+ }
+
+ request.mutable_resource_capabilities()->add_supported_compression_formats(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+
+ // Construct the URI suffix.
+ FCP_ASSIGN_OR_RETURN(std::string task_assignment_uri_suffix,
+ CreateStartTaskAssignmentUriSuffix(
+ population_name_, pre_task_assignment_session_id_));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> task_assignment_http_request,
+ task_assignment_request_creator_->CreateProtocolRequest(
+ task_assignment_uri_suffix, {}, HttpRequest::Method::kPost,
+ request.SerializeAsString(), /*is_protobuf_encoded=*/true));
+ requests.push_back(std::move(task_assignment_http_request));
+
+ if (eligibility_eval_enabled_) {
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest>
+ report_eligibility_eval_result_http_request,
+ CreateReportEligibilityEvalTaskResultRequest(absl::OkStatus()));
+ requests.push_back(std::move(report_eligibility_eval_result_http_request));
+ }
+
+ // Issue the request.
+ FCP_ASSIGN_OR_RETURN(
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
+ protocol_request_helper_.PerformMultipleProtocolRequests(
+ std::move(requests), *interruptible_runner_));
+ // The responses are returned in order. The first one is for the task
+ // assignment request. The second one (optional) is for the report eligibility
+ // eval task result request. We only care about the first one.
+ if (eligibility_eval_enabled_ && !responses[1].ok()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::HTTP_REPORT_ELIGIBILITY_EVAL_RESULT_REQUEST_FAILED);
+ }
+ return responses[0];
+}
+
+absl::StatusOr<FederatedProtocol::CheckinResult>
+HttpFederatedProtocol::HandleTaskAssignmentOperationResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
+ // If the initial response was not successful, then return immediately, even
+ // if the result was CANCELLED, since we won't have received an operation name
+ // to issue a CancelOperationRequest with anyway.
+ FCP_RETURN_IF_ERROR(http_response);
+ StartTaskAssignmentResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(http_response->body))) {
+ return absl::InvalidArgumentError(
+ "could not parse StartTaskAssignmentResponse proto");
+ }
+
+ // absl::StatusOr<Operation> initial_operation =
+ // ParseOperationProtoFromHttpResponse(http_response);
+ // if (!initial_operation.ok()) {
+ // return absl::Status(initial_operation.status().code(),
+ // absl::StrCat("protocol request failed: ",
+ // initial_operation.status().ToString()));
+ // }
+ // absl::StatusOr<Operation> response_operation_proto =
+ // protocol_request_helper_.PollOperationResponseUntilDone(
+ // *initial_operation, *task_assignment_request_creator_,
+ // *interruptible_runner_);
+ // if (!response_operation_proto.ok()) {
+ // // If the protocol request failed then issue a cancellation request to
+ // let
+ // // the server know the operation will be abandoned, and forward the
+ // error,
+ // // but add a prefix to the error message to ensure we can easily
+ // // distinguish an HTTP error occurring in response to the protocol
+ // request
+ // // from HTTP errors occurring during checkpoint/plan resource fetch
+ // // requests later on.
+ // FCP_ASSIGN_OR_RETURN(std::string operation_name,
+ // ExtractOperationName(*initial_operation));
+ // // Client interruption
+ // std::unique_ptr<InterruptibleRunner> cancellation_runner =
+ // CreateDelayedInterruptibleRunner(
+ // log_manager_, should_abort_, timing_config_,
+ // absl::Now() + waiting_period_for_cancellation_);
+ // if (!protocol_request_helper_
+ // .CancelOperation(operation_name,
+ // *task_assignment_request_creator_,
+ // *cancellation_runner)
+ // .ok()) {
+ // log_manager_->LogDiag(
+ // ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED);
+ // }
+ // return absl::Status(
+ // response_operation_proto.status().code(),
+ // absl::StrCat("protocol request failed: ",
+ // response_operation_proto.status().ToString()));
+ // }
+
+ // // The Operation has finished. Check if it resulted in an error, and if
+ // so
+ // // forward it after converting it to an absl::Status error.
+ // if (response_operation_proto->has_error()) {
+ // auto rpc_error =
+ // ConvertRpcStatusToAbslStatus(response_operation_proto->error());
+ // return absl::Status(
+ // rpc_error.code(),
+ // absl::StrCat("Operation ", response_operation_proto->name(),
+ // " contained error: ", rpc_error.ToString()));
+ // }
+
+ // Otherwise, handle the StartTaskAssignmentResponse that should have been
+ // returned by the Operation response proto.
+ return HandleTaskAssignmentInnerResponse(response_proto,
+ payload_uris_received_callback);
+}
+
+absl::StatusOr<FederatedProtocol::CheckinResult>
+HttpFederatedProtocol::HandleTaskAssignmentInnerResponse(
+ const StartTaskAssignmentResponse& response_proto,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
+ // StartTaskAssignmentResponse response_proto;
+ // if (!operation_response.UnpackTo(&response_proto)) {
+ // return absl::InvalidArgumentError(
+ // "could not parse StartTaskAssignmentResponse proto");
+ // }
+ if (response_proto.has_rejection_info()) {
+ object_state_ = ObjectState::kCheckinRejected;
+ return Rejection{};
+ }
+ if (!response_proto.has_task_assignment()) {
+ return absl::UnimplementedError("Unrecognized StartTaskAssignmentResponse");
+ }
+ const auto& task_assignment = response_proto.task_assignment();
+
+ FCP_ASSIGN_OR_RETURN(
+ default_task_info_.aggregation_request_creator,
+ ProtocolRequestCreator::Create(
+ api_key_, task_assignment.aggregation_data_forwarding_info(),
+ !flags_->disable_http_request_body_compression()));
+
+ TaskAssignment result = {
+ .federated_select_uri_template =
+ task_assignment.federated_select_uri_info().uri_template(),
+ .aggregation_session_id = task_assignment.aggregation_id(),
+ .sec_agg_info = std::nullopt};
+ if (task_assignment.has_secure_aggregation_info()) {
+ result.sec_agg_info =
+ SecAggInfo{.minimum_clients_in_server_visible_aggregate =
+ task_assignment.secure_aggregation_info()
+ .minimum_clients_in_server_visible_aggregate()};
+ }
+
+ payload_uris_received_callback(result);
+
+ // Fetch the task resources, returning any errors that may be encountered in
+ // the process.
+ FCP_ASSIGN_OR_RETURN(
+ result.payloads,
+ FetchTaskResources({.plan = task_assignment.plan(),
+ .checkpoint = task_assignment.init_checkpoint()}));
+
+ object_state_ = ObjectState::kCheckinAccepted;
+ default_task_info_.state = ObjectState::kCheckinAccepted;
+ default_task_info_.session_id = task_assignment.session_id();
+ default_task_info_.aggregation_session_id = task_assignment.aggregation_id();
+ default_task_info_.aggregation_authorization_token =
+ task_assignment.authorization_token();
+ default_task_info_.task_name = task_assignment.task_name();
+
+ return std::move(result);
+}
+
+absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
+HttpFederatedProtocol::PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) {
+ // PerformMultipleTaskAssignments(...) must follow an earlier call to
+ // EligibilityEvalCheckin() that resulted in a EligibilityEvalTask with
+ // PopulationEligibilitySpec.
+ FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
+ object_state_ == ObjectState::kEligibilityEvalEnabled)
+ << "PerformMultipleTaskAssignments(...) called despite failed/rejected "
+ "earlier "
+ "EligibilityEvalCheckin";
+ object_state_ = ObjectState::kMultipleTaskAssignmentsFailed;
+ return absl::UnimplementedError(
+ "PerformMultipleTaskAssignments is not yet implemented.");
+}
+
+absl::Status HttpFederatedProtocol::ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) {
+ FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
+ PerTaskInfo* task_info;
+ if (aggregation_session_id.has_value()) {
+ if (!task_info_map_.contains(aggregation_session_id.value())) {
+ return absl::InvalidArgumentError("Unexpected aggregation_session_id.");
+ }
+ task_info = &task_info_map_[aggregation_session_id.value()];
+ } else {
+ task_info = &default_task_info_;
+ }
+ FCP_CHECK(task_info->state == ObjectState::kCheckinAccepted ||
+ task_info->state == ObjectState::kMultipleTaskAssignmentsAccepted)
+ << "Invalid call sequence";
+ task_info->state = ObjectState::kReportCalled;
+ auto find_secagg_tensor_lambda = [](const auto& item) {
+ return std::holds_alternative<QuantizedTensor>(item.second);
+ };
+ if (std::find_if(results.begin(), results.end(), find_secagg_tensor_lambda) ==
+ results.end()) {
+ return ReportViaSimpleAggregation(std::move(results), plan_duration,
+ *task_info);
+ } else {
+ return ReportViaSecureAggregation(std::move(results), plan_duration,
+ *task_info);
+ }
+}
+
+absl::Status HttpFederatedProtocol::ReportViaSimpleAggregation(
+ ComputationResults results, absl::Duration plan_duration,
+ PerTaskInfo& task_info) {
+ if (results.size() != 1 ||
+ !std::holds_alternative<TFCheckpoint>(results.begin()->second)) {
+ return absl::InternalError(
+ "Simple Aggregation aggregands have unexpected format.");
+ }
+ auto start_upload_status = HandleStartDataAggregationUploadOperationResponse(
+ PerformStartDataUploadRequestAndReportTaskResult(plan_duration,
+ task_info),
+ task_info);
+ if (!start_upload_status.ok()) {
+ task_info.state = ObjectState::kReportFailedPermanentError;
+ return start_upload_status;
+ }
+ auto upload_status = UploadDataViaSimpleAgg(
+ std::get<TFCheckpoint>(std::move(results.begin()->second)), task_info);
+ if (!upload_status.ok()) {
+ task_info.state = ObjectState::kReportFailedPermanentError;
+ if (upload_status.code() != absl::StatusCode::kAborted &&
+ !AbortAggregation(upload_status,
+ "Upload data via simple aggregation failed.",
+ task_info)
+ .ok()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED);
+ }
+ return upload_status;
+ }
+ return SubmitAggregationResult(task_info);
+}
+
+absl::StatusOr<InMemoryHttpResponse>
+HttpFederatedProtocol::PerformStartDataUploadRequestAndReportTaskResult(
+ absl::Duration plan_duration, PerTaskInfo& task_info) {
+ FCP_ASSIGN_OR_RETURN(
+ ReportTaskResultRequest report_task_result_request,
+ CreateReportTaskResultRequest(
+ engine::PhaseOutcome::COMPLETED, plan_duration,
+ task_info.aggregation_session_id, task_info.task_name));
+ FCP_ASSIGN_OR_RETURN(
+ std::string report_task_result_uri_suffix,
+ CreateReportTaskResultUriSuffix(population_name_, task_info.session_id));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_report_task_result_request,
+ task_assignment_request_creator_->CreateProtocolRequest(
+ report_task_result_uri_suffix, {}, HttpRequest::Method::kPost,
+ report_task_result_request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+
+ StartAggregationDataUploadRequest start_upload_request;
+ FCP_ASSIGN_OR_RETURN(std::string start_aggregation_data_upload_uri_suffix,
+ CreateStartAggregationDataUploadUriSuffix(
+ task_info.aggregation_session_id,
+ task_info.aggregation_authorization_token));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_start_aggregation_data_upload_request,
+ task_info.aggregation_request_creator->CreateProtocolRequest(
+ start_aggregation_data_upload_uri_suffix, {},
+ HttpRequest::Method::kPost, start_upload_request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+ FCP_LOG(INFO) << "StartAggregationDataUpload request uri is: "
+ << http_start_aggregation_data_upload_request->uri();
+ FCP_LOG(INFO) << "ReportTaskResult request uri is: "
+ << http_report_task_result_request->uri();
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(std::move(http_start_aggregation_data_upload_request));
+ requests.push_back(std::move(http_report_task_result_request));
+ FCP_ASSIGN_OR_RETURN(
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
+ protocol_request_helper_.PerformMultipleProtocolRequests(
+ std::move(requests), *interruptible_runner_));
+ // We should have two responses, otherwise we have made a developer error.
+ FCP_CHECK(responses.size() == 2);
+ // The responses are returned in order so the first response will be the one
+ // for StartAggregationDataUpload request. We only care about this response,
+ // the ReportTaskResult request is just a best effort to report client metrics
+ // to the server, and we don't want to abort the aggregation even if it
+ // failed.
+ if (!responses[1].ok()) {
+ log_manager_->LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED);
+ }
+ return responses[0];
+}
+
+absl::Status
+HttpFederatedProtocol::HandleStartDataAggregationUploadOperationResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ PerTaskInfo& task_info) {
+ // absl::StatusOr<Operation> operation =
+ // ParseOperationProtoFromHttpResponse(http_response);
+ // if (!operation.ok()) {
+ // // If the protocol request failed then forward the error, but add a
+ // prefix
+ // // to the error message to ensure we can easily distinguish an HTTP error
+ // // occurring in response to the protocol request from HTTP errors
+ // // occurring during upload requests later on.
+ // return absl::Status(
+ // operation.status().code(),
+ // absl::StrCat(
+ // "StartAggregationDataUpload request failed during polling: ",
+ // operation.status().ToString()));
+ // }
+ // absl::StatusOr<Operation> response_operation_proto =
+ // protocol_request_helper_.PollOperationResponseUntilDone(
+ // *operation, *task_info.aggregation_request_creator,
+ // *interruptible_runner_);
+ // if (!response_operation_proto.ok()) {
+ // return absl::Status(
+ // response_operation_proto.status().code(),
+ // absl::StrCat("StartAggregationDataUpload request failed: ",
+ // response_operation_proto.status().ToString()));
+ // }
+
+ // // The Operation has finished. Check if it resulted in an error, and if so
+ // // forward it after converting it to an absl::Status error.
+ // if (response_operation_proto->has_error()) {
+ // auto rpc_error =
+ // ConvertRpcStatusToAbslStatus(response_operation_proto->error());
+ // return absl::Status(
+ // rpc_error.code(),
+ // absl::StrCat("Operation ", response_operation_proto->name(),
+ // " contained error: ", rpc_error.ToString()));
+ // }
+
+ // Otherwise, handle the StartDataAggregationUploadResponse that should have
+ // been returned by the Operation response proto.
+ FCP_RETURN_IF_ERROR(http_response);
+ StartAggregationDataUploadResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(http_response->body))) {
+ return absl::InvalidArgumentError(
+ "could not parse StartTaskAssignmentResponse proto");
+ }
+
+ // Note that we reassign `aggregation_request_creator_` because from this
+ // point onwards, subsequent aggregation protocol requests should go to the
+ // endpoint identified in the aggregation_protocol_forwarding_info.
+ FCP_ASSIGN_OR_RETURN(
+ task_info.aggregation_request_creator,
+ ProtocolRequestCreator::Create(
+ api_key_, response_proto.aggregation_protocol_forwarding_info(),
+ !flags_->disable_http_request_body_compression()));
+ auto upload_resource = response_proto.resource();
+ task_info.aggregation_resource_name = upload_resource.resource_name();
+ FCP_ASSIGN_OR_RETURN(
+ task_info.data_upload_request_creator,
+ ProtocolRequestCreator::Create(
+ api_key_, upload_resource.data_upload_forwarding_info(),
+ !flags_->disable_http_request_body_compression()));
+ // TODO(team): Remove the authorization token fallback once
+ // client_token is always populated.
+ task_info.aggregation_client_token =
+ !response_proto.client_token().empty()
+ ? response_proto.client_token()
+ : task_info.aggregation_authorization_token;
+ return absl::OkStatus();
+}
+
+absl::Status HttpFederatedProtocol::UploadDataViaSimpleAgg(
+ std::string tf_checkpoint, PerTaskInfo& task_info) {
+ FCP_LOG(INFO) << "Uploading checkpoint with simple aggregation.";
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateByteStreamUploadUriSuffix(task_info.aggregation_resource_name));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ task_info.data_upload_request_creator->CreateProtocolRequest(
+ uri_suffix, {{"upload_protocol", "raw"}}, HttpRequest::Method::kPost,
+ std::move(tf_checkpoint), /*is_protobuf_encoded=*/false));
+ FCP_LOG(INFO) << "ByteStream.Write request URI is: " << http_request->uri();
+ auto http_response = protocol_request_helper_.PerformProtocolRequest(
+ std::move(http_request), *interruptible_runner_);
+ if (!http_response.ok()) {
+ // If the request failed, we'll forward the error status.
+ return absl::Status(http_response.status().code(),
+ absl::StrCat("Data upload failed: ",
+ http_response.status().ToString()));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status HttpFederatedProtocol::SubmitAggregationResult(
+ PerTaskInfo& task_info) {
+ FCP_LOG(INFO) << "Notifying the server that data upload is complete.";
+ FCP_ASSIGN_OR_RETURN(std::string uri_suffix,
+ CreateSubmitAggregationResultUriSuffix(
+ task_info.aggregation_session_id,
+ task_info.aggregation_client_token));
+ SubmitAggregationResultRequest request;
+ request.set_resource_name(task_info.aggregation_resource_name);
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ task_info.aggregation_request_creator->CreateProtocolRequest(
+ uri_suffix, {}, HttpRequest::Method::kPost,
+ request.SerializeAsString(), /*is_protobuf_encoded=*/true));
+ FCP_LOG(INFO) << "SubmitAggregationResult request URI is: "
+ << http_request->uri();
+ auto http_response = protocol_request_helper_.PerformProtocolRequest(
+ std::move(http_request), *interruptible_runner_);
+ if (!http_response.ok()) {
+ // If the request failed, we'll forward the error status.
+ return absl::Status(http_response.status().code(),
+ absl::StrCat("SubmitAggregationResult failed: ",
+ http_response.status().ToString()));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status HttpFederatedProtocol::AbortAggregation(
+ absl::Status original_error_status,
+ absl::string_view error_message_for_server, PerTaskInfo& task_info) {
+ FCP_LOG(INFO) << "Aborting aggregation: " << original_error_status;
+ FCP_CHECK(task_info.state == ObjectState::kReportFailedPermanentError)
+ << "Invalid call sequence";
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateAbortAggregationUriSuffix(task_info.aggregation_session_id,
+ task_info.aggregation_client_token));
+ // We only provide the server with a simplified error message.
+ absl::Status error_status(original_error_status.code(),
+ error_message_for_server);
+ AbortAggregationRequest request;
+ *request.mutable_status() = ConvertAbslStatusToRpcStatus(error_status);
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ task_info.aggregation_request_creator->CreateProtocolRequest(
+ uri_suffix, {}, HttpRequest::Method::kPost,
+ request.SerializeAsString(), /*is_protobuf_encoded=*/true));
+ std::unique_ptr<InterruptibleRunner> cancellation_runner =
+ CreateDelayedInterruptibleRunner(
+ log_manager_, should_abort_, timing_config_,
+ absl::Now() + waiting_period_for_cancellation_);
+ return protocol_request_helper_
+ .PerformProtocolRequest(std::move(http_request), *cancellation_runner)
+ .status();
+}
+
+absl::Status HttpFederatedProtocol::ReportViaSecureAggregation(
+ ComputationResults results, absl::Duration plan_duration,
+ PerTaskInfo& task_info) {
+ FCP_ASSIGN_OR_RETURN(
+ StartSecureAggregationResponse response_proto,
+ StartSecureAggregationAndReportTaskResult(plan_duration, task_info));
+ SecureAggregationProtocolExecutionInfo protocol_execution_info =
+ response_proto.protocol_execution_info();
+ // TODO(team): Remove the authorization token fallback once
+ // client_token is always populated.
+ task_info.aggregation_client_token =
+ !response_proto.client_token().empty()
+ ? response_proto.client_token()
+ : task_info.aggregation_authorization_token;
+
+ // Move checkpoint out of ComputationResults, and put it into a std::optional.
+ std::optional<TFCheckpoint> tf_checkpoint;
+ for (auto& [k, v] : results) {
+ if (std::holds_alternative<TFCheckpoint>(v)) {
+ tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
+ results.erase(k);
+ break;
+ }
+ }
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> server_response_holder;
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ HttpSecAggSendToServerImpl::Create(
+ api_key_, &clock_, &protocol_request_helper_,
+ interruptible_runner_.get(),
+ [this](absl::Time deadline) {
+ return CreateDelayedInterruptibleRunner(
+ this->log_manager_, this->should_abort_, this->timing_config_,
+ deadline);
+ },
+ &server_response_holder, task_info.aggregation_session_id,
+ task_info.aggregation_client_token,
+ response_proto.secagg_protocol_forwarding_info(),
+ response_proto.masked_result_resource(),
+ response_proto.nonmasked_result_resource(), std::move(tf_checkpoint),
+ flags_->disable_http_request_body_compression(),
+ waiting_period_for_cancellation_));
+ auto protocol_delegate = std::make_unique<HttpSecAggProtocolDelegate>(
+ response_proto.secure_aggregands(), &server_response_holder);
+ auto secagg_interruptible_runner = std::make_unique<InterruptibleRunner>(
+ log_manager_, should_abort_, timing_config_,
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
+ std::unique_ptr<SecAggRunner> secagg_runner =
+ secagg_runner_factory_->CreateSecAggRunner(
+ std::move(send_to_server_impl), std::move(protocol_delegate),
+ secagg_event_publisher_, log_manager_,
+ secagg_interruptible_runner.get(),
+ protocol_execution_info.expected_number_of_clients(),
+ protocol_execution_info
+ .minimum_surviving_clients_for_reconstruction());
+ FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
+ return absl::OkStatus();
+}
+
+absl::StatusOr<StartSecureAggregationResponse>
+HttpFederatedProtocol::StartSecureAggregationAndReportTaskResult(
+ absl::Duration plan_duration, PerTaskInfo& task_info) {
+ FCP_ASSIGN_OR_RETURN(std::string start_secure_aggregation_uri_suffix,
+ CreateStartSecureAggregationUriSuffix(
+ task_info.aggregation_session_id,
+ task_info.aggregation_authorization_token));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> start_secure_aggregation_http_request,
+ task_info.aggregation_request_creator->CreateProtocolRequest(
+ start_secure_aggregation_uri_suffix, QueryParams(),
+ HttpRequest::Method::kPost,
+ StartSecureAggregationRequest::default_instance().SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+
+ FCP_ASSIGN_OR_RETURN(
+ std::string report_task_result_uri_suffix,
+ CreateReportTaskResultUriSuffix(population_name_, task_info.session_id));
+ FCP_ASSIGN_OR_RETURN(
+ ReportTaskResultRequest report_task_result_request,
+ CreateReportTaskResultRequest(
+ engine::PhaseOutcome::COMPLETED, plan_duration,
+ task_info.aggregation_session_id, task_info.task_name));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> report_task_result_http_request,
+ task_assignment_request_creator_->CreateProtocolRequest(
+ report_task_result_uri_suffix, QueryParams(),
+ HttpRequest::Method::kPost,
+ report_task_result_request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(std::move(start_secure_aggregation_http_request));
+ requests.push_back(std::move(report_task_result_http_request));
+
+ FCP_ASSIGN_OR_RETURN(
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
+ protocol_request_helper_.PerformMultipleProtocolRequests(
+ std::move(requests), *interruptible_runner_));
+ // We will handle the response for StartSecureAggregation RPC.
+ // The ReportTaskResult RPC is for best efforts only, we will ignore the
+ // response, only log a diagcode if it fails.
+ FCP_CHECK(responses.size() == 2);
+ if (!responses[1].ok()) {
+ log_manager_->LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED);
+ }
+ // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
+ // ParseOperationProtoFromHttpResponse(responses[0]));
+ // FCP_ASSIGN_OR_RETURN(
+ // Operation completed_operation,
+ // protocol_request_helper_.PollOperationResponseUntilDone(
+ // initial_operation, *task_info.aggregation_request_creator,
+ // *interruptible_runner_));
+ // // The Operation has finished. Check if it resulted in an error, and if so
+ // // forward it after converting it to an absl::Status error.
+ // if (completed_operation.has_error()) {
+ // auto rpc_error =
+ // ConvertRpcStatusToAbslStatus(completed_operation.error()); return
+ // absl::Status(
+ // rpc_error.code(),
+ // absl::StrCat("Operation ", completed_operation.name(),
+ // " contained error: ", rpc_error.ToString()));
+ // }
+ StartSecureAggregationResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(responses[0]->body))) {
+ return absl::InvalidArgumentError(
+ "could not parse StartSecureAggregationResponse proto");
+ }
+ return response_proto;
+}
+
+absl::Status HttpFederatedProtocol::ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) {
+ FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
+ PerTaskInfo* task_info;
+ if (aggregation_session_id.has_value()) {
+ if (!task_info_map_.contains(aggregation_session_id.value())) {
+ return absl::InvalidArgumentError("Unexpected aggregation_session_id.");
+ }
+ task_info = &task_info_map_[aggregation_session_id.value()];
+ } else {
+ task_info = &default_task_info_;
+ }
+ FCP_CHECK(task_info->state == ObjectState::kCheckinAccepted ||
+ task_info->state == ObjectState::kMultipleTaskAssignmentsAccepted)
+ << "Invalid call sequence";
+ task_info->state = ObjectState::kReportCalled;
+ FCP_ASSIGN_OR_RETURN(
+ ReportTaskResultRequest request,
+ CreateReportTaskResultRequest(phase_outcome, plan_duration,
+ task_info->aggregation_session_id,
+ task_info->task_name));
+ // Construct the URI suffix.
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateReportTaskResultUriSuffix(population_name_, task_info->session_id));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ task_assignment_request_creator_->CreateProtocolRequest(
+ uri_suffix, {}, HttpRequest::Method::kPost,
+ request.SerializeAsString(), /*is_protobuf_encoded=*/true));
+
+ // Issue the request.
+ absl::StatusOr<InMemoryHttpResponse> http_response =
+ protocol_request_helper_.PerformProtocolRequest(std::move(http_request),
+ *interruptible_runner_);
+ if (!http_response.ok()) {
+ // If the request failed, we'll forward the error status.
+ return absl::Status(http_response.status().code(),
+ absl::StrCat("ReportTaskResult request failed: ",
+ http_response.status().ToString()));
+ }
+ return absl::OkStatus();
+}
+
+::google::internal::federatedml::v2::RetryWindow
+HttpFederatedProtocol::GetLatestRetryWindow() {
+ ObjectState state = GetTheLatestStateFromAllTasks();
+ // We explicitly enumerate all possible states here rather than using
+ // "default", to ensure that when new states are added later on, the author
+ // is forced to update this method and consider which is the correct
+ // RetryWindow to return.
+ switch (state) {
+ case ObjectState::kCheckinAccepted:
+ case ObjectState::kMultipleTaskAssignmentsAccepted:
+ case ObjectState::kReportCalled:
+ // If a client makes it past the 'checkin acceptance' stage, we use the
+ // 'accepted' RetryWindow unconditionally (unless a permanent error is
+ // encountered). This includes cases where the checkin is accepted, but
+ // the report request results in a (transient) error.
+ FCP_CHECK(retry_times_.has_value());
+ return GenerateRetryWindowFromRetryTime(
+ retry_times_->retry_time_if_accepted);
+ case ObjectState::kEligibilityEvalCheckinRejected:
+ case ObjectState::kEligibilityEvalDisabled:
+ case ObjectState::kEligibilityEvalEnabled:
+ case ObjectState::kCheckinRejected:
+ case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
+ case ObjectState::kReportMultipleTaskPartialError:
+ FCP_CHECK(retry_times_.has_value());
+ return GenerateRetryWindowFromRetryTime(
+ retry_times_->retry_time_if_rejected);
+ case ObjectState::kInitialized:
+ case ObjectState::kEligibilityEvalCheckinFailed:
+ case ObjectState::kCheckinFailed:
+ case ObjectState::kMultipleTaskAssignmentsFailed:
+ if (retry_times_.has_value()) {
+ // If we already received a server-provided retry window, then use it.
+ return GenerateRetryWindowFromRetryTime(
+ retry_times_->retry_time_if_rejected);
+ }
+ // Otherwise, we generate a retry window using the flag-provided transient
+ // error retry period.
+ return GenerateRetryWindowFromTargetDelay(
+ absl::Seconds(
+ flags_->federated_training_transient_errors_retry_delay_secs()),
+ // NOLINTBEGIN(whitespace/line_length)
+ flags_
+ ->federated_training_transient_errors_retry_delay_jitter_percent(),
+ // NOLINTEND(whitespace/line_length)
+ bit_gen_);
+ case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
+ case ObjectState::kCheckinFailedPermanentError:
+ case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
+ case ObjectState::kReportFailedPermanentError:
+ // If we encountered a permanent error during the eligibility eval or
+ // regular checkins, then we use the Flags-configured 'permanent error'
+ // retry period. Note that we do so regardless of whether the server had,
+ // by the time the permanent error was received, already returned a
+ // CheckinRequestAck containing a set of retry windows. See note on error
+ // handling at the top of this file.
+ return GenerateRetryWindowFromTargetDelay(
+ absl::Seconds(
+ flags_->federated_training_permanent_errors_retry_delay_secs()),
+ // NOLINTBEGIN(whitespace/line_length)
+ flags_
+ ->federated_training_permanent_errors_retry_delay_jitter_percent(),
+ // NOLINTEND(whitespace/line_length)
+ bit_gen_);
+ }
+}
+
+absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
+HttpFederatedProtocol::FetchTaskResources(
+ HttpFederatedProtocol::TaskResources task_resources) {
+ FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
+ ConvertResourceToUriOrInlineData(task_resources.plan));
+ FCP_ASSIGN_OR_RETURN(
+ UriOrInlineData checkpoint_uri_or_data,
+ ConvertResourceToUriOrInlineData(task_resources.checkpoint));
+
+ // Fetch the plan and init checkpoint resources if they need to be fetched
+ // (using the inline data instead if available).
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+ resource_responses;
+ {
+ auto started_stopwatch = network_stopwatch_->Start();
+ resource_responses = FetchResourcesInMemory(
+ *http_client_, *interruptible_runner_,
+ {plan_uri_or_data, checkpoint_uri_or_data}, &bytes_downloaded_,
+ &bytes_uploaded_, resource_cache_);
+ }
+ FCP_RETURN_IF_ERROR(resource_responses);
+ auto& plan_data_response = (*resource_responses)[0];
+ auto& checkpoint_data_response = (*resource_responses)[1];
+
+ // Note: we forward any error during the fetching of the plan/checkpoint
+ // resources resources to the caller, which means that these error codes
+ // will be checked against the set of 'permanent' error codes, just like the
+ // errors in response to the protocol request are.
+ if (!plan_data_response.ok()) {
+ return absl::Status(plan_data_response.status().code(),
+ absl::StrCat("plan fetch failed: ",
+ plan_data_response.status().ToString()));
+ }
+ if (!checkpoint_data_response.ok()) {
+ return absl::Status(
+ checkpoint_data_response.status().code(),
+ absl::StrCat("checkpoint fetch failed: ",
+ checkpoint_data_response.status().ToString()));
+ }
+
+ return PlanAndCheckpointPayloads{plan_data_response->body,
+ checkpoint_data_response->body};
+}
+
+absl::StatusOr<PopulationEligibilitySpec>
+HttpFederatedProtocol::FetchPopulationEligibilitySpec(
+ const Resource& population_eligibility_spec_resource) {
+ FCP_ASSIGN_OR_RETURN(
+ UriOrInlineData population_eligibility_spec_uri_or_data,
+ ConvertResourceToUriOrInlineData(population_eligibility_spec_resource));
+
+ // Fetch the plan and init checkpoint resources if they need to be fetched
+ // (using the inline data instead if available).
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+ resource_responses;
+ {
+ auto started_stopwatch = network_stopwatch_->Start();
+ resource_responses = FetchResourcesInMemory(
+ *http_client_, *interruptible_runner_,
+ {population_eligibility_spec_uri_or_data}, &bytes_downloaded_,
+ &bytes_uploaded_, resource_cache_);
+ }
+ FCP_RETURN_IF_ERROR(resource_responses);
+ auto& response = (*resource_responses)[0];
+
+ // Note: we forward any error during the fetching of the plan/checkpoint
+ // resources resources to the caller, which means that these error codes
+ // will be checked against the set of 'permanent' error codes, just like the
+ // errors in response to the protocol request are.
+ if (!response.ok()) {
+ return absl::Status(
+ response.status().code(),
+ absl::StrCat("population eligibility spec fetch failed: ",
+ response.status().ToString()));
+ }
+ PopulationEligibilitySpec population_eligibility_spec;
+ if (!ParseFromStringOrCord(population_eligibility_spec, response->body)) {
+ return absl::InvalidArgumentError(
+ "Unable to parse PopulationEligibilitySpec.");
+ }
+ return population_eligibility_spec;
+}
+
+void HttpFederatedProtocol::UpdateObjectStateIfPermanentError(
+ absl::Status status,
+ HttpFederatedProtocol::ObjectState permanent_error_object_state) {
+ if (federated_training_permanent_error_codes_.contains(
+ static_cast<int32_t>(status.code()))) {
+ object_state_ = permanent_error_object_state;
+ }
+}
+
+FederatedProtocol::ObjectState
+HttpFederatedProtocol::GetTheLatestStateFromAllTasks() {
+ // If we didn't have successful check-in or multiple task assignments, we
+ // don't have to check the per task states.
+ if (object_state_ != ObjectState::kCheckinAccepted &&
+ object_state_ != ObjectState::kMultipleTaskAssignmentsAccepted) {
+ return object_state_;
+ }
+ if (!flags_->http_protocol_supports_multiple_task_assignments()) {
+ return default_task_info_.state;
+ }
+
+ int32_t success_cnt = 0;
+ int32_t permanent_failure_cnt = 0;
+ int32_t task_cnt = 0;
+ auto count_func = [&success_cnt, &permanent_failure_cnt](ObjectState state) {
+ if (state == ObjectState::kReportCalled) {
+ success_cnt++;
+ }
+ if (state == ObjectState::kReportFailedPermanentError) {
+ permanent_failure_cnt++;
+ }
+ };
+
+ if (default_task_info_.state != ObjectState::kInitialized) {
+ task_cnt++;
+ count_func(default_task_info_.state);
+ }
+
+ for (const auto& item : task_info_map_) {
+ task_cnt++;
+ count_func(item.second.state);
+ }
+
+ // If none of the tasks succeeds, assume all of them failed with permanent
+ // error and return kReportFailedPermanentError. If all of them succeeds,
+ // return kReportCalled. If only some of the tasks succeed, return
+ // kReportMultipleTaskPartialError.
+ if (permanent_failure_cnt == task_cnt) {
+ return ObjectState::kReportFailedPermanentError;
+ } else if (success_cnt == task_cnt) {
+ return ObjectState::kReportCalled;
+ } else {
+ return ObjectState::kReportMultipleTaskPartialError;
+ }
+}
+
+NetworkStats HttpFederatedProtocol::GetNetworkStats() {
+ return {.bytes_downloaded = bytes_downloaded_,
+ .bytes_uploaded = bytes_uploaded_,
+ .network_duration = network_stopwatch_->GetTotalDuration()};
+}
+
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/http_federated_protocol.h b/fcp/client/http/http_federated_protocol.h
new file mode 100644
index 0000000..ccce8eb
--- /dev/null
+++ b/fcp/client/http/http_federated_protocol.h
@@ -0,0 +1,306 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
+#define FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/fl_runner.pb.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/http/protocol_request_helper.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/protos/federatedcompute/task_assignments.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/secagg/client/secagg_client.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// Implements a single session of the HTTP-based Federated Compute protocol.
+class HttpFederatedProtocol : public fcp::client::FederatedProtocol {
+ public:
+ HttpFederatedProtocol(
+ Clock* clock, LogManager* log_manager, const Flags* flags,
+ HttpClient* http_client,
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
+ SecAggEventPublisher* secagg_event_publisher,
+ absl::string_view entry_point_uri, absl::string_view api_key,
+ absl::string_view population_name, absl::string_view retry_token,
+ absl::string_view client_version,
+ absl::string_view attestation_measurement,
+ std::function<bool()> should_abort, absl::BitGen bit_gen,
+ const InterruptibleRunner::TimingConfig& timing_config,
+ cache::ResourceCache* resource_cache);
+
+ ~HttpFederatedProtocol() override = default;
+
+ absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
+ EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) override;
+
+ void ReportEligibilityEvalError(absl::Status error_status) override;
+
+ absl::StatusOr<fcp::client::FederatedProtocol::CheckinResult> Checkin(
+ const std::optional<
+ google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info,
+ std::function<void(const TaskAssignment&)> payload_uris_received_callback)
+ override;
+
+ absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) override;
+
+ absl::Status ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) override;
+
+ absl::Status ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) override;
+
+ google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
+ override;
+
+ NetworkStats GetNetworkStats() override;
+
+ private:
+ // Information for a given task.
+ struct PerTaskInfo {
+ std::unique_ptr<ProtocolRequestCreator> aggregation_request_creator;
+ std::unique_ptr<ProtocolRequestCreator> data_upload_request_creator;
+ std::string session_id;
+ // The identifier of the aggregation session we are participating in.
+ std::string aggregation_session_id;
+ // The token authorizing the client to participate in an aggregation
+ // session.
+ std::string aggregation_authorization_token;
+ // The name identifying the task that was assigned.
+ std::string task_name;
+ // Unique identifier for the client's participation in an aggregation
+ // session.
+ std::string aggregation_client_token;
+ // Resource name for the checkpoint in simple aggregation.
+ std::string aggregation_resource_name;
+ // Each task's state is tracked individually starting from the end of
+ // check-in or multiple task assignments. The states from all of the tasks
+ // will be used collectively to determine which retry window to use.
+ ObjectState state = ObjectState::kInitialized;
+ };
+
+ // Helper function to perform an eligibility eval task request and get its
+ // response.
+ absl::StatusOr<InMemoryHttpResponse> PerformEligibilityEvalTaskRequest();
+
+ // Helper function for handling an eligibility eval task response (incl.
+ // fetching any resources, if necessary).
+ absl::StatusOr<fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
+ HandleEligibilityEvalTaskResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback);
+
+ absl::StatusOr<std::unique_ptr<HttpRequest>>
+ CreateReportEligibilityEvalTaskResultRequest(absl::Status status);
+
+ // Helper function to perform an ReportEligibilityEvalResult request.
+ absl::Status ReportEligibilityEvalErrorInternal(absl::Status error_status);
+
+ // Helper function to perform a task assignment request and get its response.
+ absl::StatusOr<InMemoryHttpResponse>
+ PerformTaskAssignmentAndReportEligibilityEvalResultRequests(
+ const std::optional<
+ ::google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info);
+
+ // Helper function for handling the 'outer' task assignment response, which
+ // consists of an `Operation` which may or may not need to be polled before a
+ // final 'inner' response is available.
+ absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult>
+ HandleTaskAssignmentOperationResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ std::function<void(const TaskAssignment&)>
+ payload_uris_received_callback);
+
+ // Helper function for handling an 'inner' task assignment response (i.e.
+ // after the outer `Operation` has concluded). This includes fetching any
+ // resources, if necessary.
+ absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult>
+ HandleTaskAssignmentInnerResponse(
+ const google::internal::federatedcompute::v1::StartTaskAssignmentResponse&
+ response_proto,
+ std::function<void(const TaskAssignment&)>
+ payload_uris_received_callback);
+
+ // Helper function for reporting result via simple aggregation.
+ absl::Status ReportViaSimpleAggregation(ComputationResults results,
+ absl::Duration plan_duration,
+ PerTaskInfo& task_info);
+ // Helper function to perform a StartDataUploadRequest and a ReportTaskResult
+ // request concurrently.
+ // This method will only return the response from the StartDataUploadRequest.
+ absl::StatusOr<InMemoryHttpResponse>
+ PerformStartDataUploadRequestAndReportTaskResult(absl::Duration plan_duration,
+ PerTaskInfo& task_info);
+
+ // Helper function for handling a longrunning operation returned by a
+ // StartDataAggregationUpload request.
+ absl::Status HandleStartDataAggregationUploadOperationResponse(
+ absl::StatusOr<InMemoryHttpResponse> http_response,
+ PerTaskInfo& task_info);
+
+ // Helper function to perform data upload via simple aggregation.
+ absl::Status UploadDataViaSimpleAgg(std::string tf_checkpoint,
+ PerTaskInfo& task_info);
+
+ // Helper function to perform a SubmitAggregationResult request.
+ absl::Status SubmitAggregationResult(PerTaskInfo& task_info);
+
+ // Helper function to perform an AbortAggregation request.
+ // We only provide the server with a simplified error message.
+ absl::Status AbortAggregation(absl::Status original_error_status,
+ absl::string_view error_message_for_server,
+ PerTaskInfo& task_info);
+
+ // Helper function for reporting via secure aggregation.
+ absl::Status ReportViaSecureAggregation(ComputationResults results,
+ absl::Duration plan_duration,
+ PerTaskInfo& task_info);
+
+ // Helper function to perform a StartSecureAggregationRequest and a
+ // ReportTaskResultRequest.
+ absl::StatusOr<
+ google::internal::federatedcompute::v1::StartSecureAggregationResponse>
+ StartSecureAggregationAndReportTaskResult(absl::Duration plan_duration,
+ PerTaskInfo& task_info);
+
+ struct TaskResources {
+ const ::google::internal::federatedcompute::v1::Resource& plan;
+ const ::google::internal::federatedcompute::v1::Resource& checkpoint;
+ };
+
+ // Helper function for fetching the checkpoint/plan resources for an
+ // eligibility eval task or regular task.
+ absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources(
+ TaskResources task_resources);
+
+ // Helper function for fetching the PopulationEligibilitySpec.
+ absl::StatusOr<
+ google::internal::federatedcompute::v1::PopulationEligibilitySpec>
+ FetchPopulationEligibilitySpec(
+ const ::google::internal::federatedcompute::v1::Resource&
+ population_eligibility_spec_resource);
+
+ // Helper that moves to the given object state if the given status represents
+ // a permanent error.
+ void UpdateObjectStateIfPermanentError(
+ absl::Status status, ObjectState permanent_error_object_state);
+
+ ObjectState GetTheLatestStateFromAllTasks();
+
+ // This ObjectState tracks states until the end of check-in or multiple task
+ // assignments. Once a task is assigned, the state is tracked inside the
+ // task_info_map_ for multiple task assignments or default_task_info_ for
+ // single task check-in.
+ ObjectState object_state_;
+ Clock& clock_;
+ LogManager* log_manager_;
+ const Flags* const flags_;
+ HttpClient* const http_client_;
+ std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_;
+ SecAggEventPublisher* secagg_event_publisher_;
+ std::unique_ptr<InterruptibleRunner> interruptible_runner_;
+ std::unique_ptr<ProtocolRequestCreator> eligibility_eval_request_creator_;
+ std::unique_ptr<ProtocolRequestCreator> task_assignment_request_creator_;
+ std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
+ WallClockStopwatch::Create();
+ ProtocolRequestHelper protocol_request_helper_;
+ const std::string api_key_;
+ const std::string population_name_;
+ const std::string retry_token_;
+ const std::string client_version_;
+ const std::string attestation_measurement_;
+ std::function<bool()> should_abort_;
+ absl::BitGen bit_gen_;
+ const InterruptibleRunner::TimingConfig timing_config_;
+ // The graceful waiting period for cancellation requests before checking
+ // whether the client should be interrupted.
+ const absl::Duration waiting_period_for_cancellation_;
+ // The set of canonical error codes that should be treated as 'permanent'
+ // errors.
+ absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_;
+ int64_t bytes_downloaded_ = 0;
+ int64_t bytes_uploaded_ = 0;
+ // Represents 2 absolute retry timestamps to use when the device is rejected
+ // or accepted. The retry timestamps will have been generated based on the
+ // retry windows specified in the server's EligibilityEvalTaskResponse message
+ // and the time at which that message was received.
+ struct RetryTimes {
+ absl::Time retry_time_if_rejected;
+ absl::Time retry_time_if_accepted;
+ };
+ // Represents the information received via the EligibilityEvalTaskResponse
+ // message. This field will have an absent value until that message has been
+ // received.
+ std::optional<RetryTimes> retry_times_;
+ std::string pre_task_assignment_session_id_;
+
+ // A map of aggregation_session_id to per-task information.
+ // Only tasks from the multiple task assignments will be tracked in this map.
+ absl::flat_hash_map<std::string, PerTaskInfo> task_info_map_;
+ // The task received from the regular check-in will be tracked here.
+ PerTaskInfo default_task_info_;
+
+ // Set this field to true if an eligibility eval task was received from the
+ // server in the EligibilityEvalTaskResponse.
+ bool eligibility_eval_enabled_ = false;
+ // `nullptr` if the feature is disabled.
+ cache::ResourceCache* resource_cache_;
+};
+
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_HTTP_FEDERATED_PROTOCOL_H_
diff --git a/fcp/client/http/http_federated_protocol_test.cc b/fcp/client/http/http_federated_protocol_test.cc
new file mode 100644
index 0000000..bf34047
--- /dev/null
+++ b/fcp/client/http/http_federated_protocol_test.cc
@@ -0,0 +1,3062 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_federated_protocol.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/longrunning/operations.pb.h"
+#include "google/protobuf/any.pb.h"
+#include "google/protobuf/duration.pb.h"
+#include "google/rpc/code.pb.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/base/time_util.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/cache/test_helpers.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_protocol_util.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/stats.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/federatedcompute/aggregations.pb.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/protos/federatedcompute/task_assignments.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp::client::http {
+namespace {
+
+using ::fcp::EqualsProto;
+using ::fcp::IsCode;
+using ::fcp::client::http::FakeHttpResponse;
+using ::fcp::client::http::MockableHttpClient;
+using ::fcp::client::http::MockHttpClient;
+using ::fcp::client::http::SimpleHttpRequestMatcher;
+using ::google::internal::federatedcompute::v1::ByteStreamResource;
+using ::google::internal::federatedcompute::v1::ClientStats;
+using ::google::internal::federatedcompute::v1::EligibilityEvalTask;
+using ::google::internal::federatedcompute::v1::EligibilityEvalTaskRequest;
+using ::google::internal::federatedcompute::v1::EligibilityEvalTaskResponse;
+using ::google::internal::federatedcompute::v1::ForwardingInfo;
+using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
+using ::google::internal::federatedcompute::v1::
+ ReportEligibilityEvalTaskResultRequest;
+using ::google::internal::federatedcompute::v1::ReportTaskResultRequest;
+using ::google::internal::federatedcompute::v1::ReportTaskResultResponse;
+using ::google::internal::federatedcompute::v1::Resource;
+using ::google::internal::federatedcompute::v1::ResourceCompressionFormat;
+using ::google::internal::federatedcompute::v1::RetryWindow;
+using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
+using ::google::internal::federatedcompute::v1::
+ StartAggregationDataUploadRequest;
+using ::google::internal::federatedcompute::v1::
+ StartAggregationDataUploadResponse;
+using ::google::internal::federatedcompute::v1::StartSecureAggregationRequest;
+using ::google::internal::federatedcompute::v1::StartSecureAggregationResponse;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentRequest;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentResponse;
+using ::google::internal::federatedcompute::v1::SubmitAggregationResultRequest;
+using ::google::internal::federatedcompute::v1::TaskAssignment;
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+using ::google::internal::federatedml::v2::TaskWeight;
+using ::google::longrunning::GetOperationRequest;
+using ::google::longrunning::Operation;
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::ByMove;
+using ::testing::DescribeMatcher;
+using ::testing::DoubleEq;
+using ::testing::DoubleNear;
+using ::testing::Eq;
+using ::testing::ExplainMatchResult;
+using ::testing::Field;
+using ::testing::FieldsAre;
+using ::testing::Ge;
+using ::testing::Gt;
+using ::testing::HasSubstr;
+using ::testing::InSequence;
+using ::testing::IsEmpty;
+using ::testing::Lt;
+using ::testing::MockFunction;
+using ::testing::NiceMock;
+using ::testing::Not;
+using ::testing::Optional;
+using ::testing::Return;
+using ::testing::StrEq;
+using ::testing::StrictMock;
+using ::testing::UnorderedElementsAre;
+using ::testing::VariantWith;
+using ::testing::WithArg;
+
+constexpr char kEntryPointUri[] = "https://initial.uri/";
+constexpr char kTaskAssignmentTargetUri[] = "https://taskassignment.uri/";
+constexpr char kAggregationTargetUri[] = "https://aggregation.uri/";
+constexpr char kSecondStageAggregationTargetUri[] =
+ "https://aggregation.second.uri/";
+constexpr char kByteStreamTargetUri[] = "https://bytestream.uri/";
+constexpr char kApiKey[] = "TEST_APIKEY";
+// Note that we include a '/' character in the population name, which allows us
+// to verify that it is correctly URL-encoded into "%2F".
+constexpr char kPopulationName[] = "TEST/POPULATION";
+constexpr char kEligibilityEvalExecutionId[] = "ELIGIBILITY_EXECUTION_ID";
+// Note that we include a '/' and '#' characters in the population name, which
+// allows us to verify that it is correctly URL-encoded into "%2F" and "%23".
+constexpr char kEligibilityEvalSessionId[] = "ELIGIBILITY/SESSION#ID";
+constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
+constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
+constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
+constexpr char kClientVersion[] = "CLIENT_VERSION";
+constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";
+constexpr char kClientSessionId[] = "CLIENT_SESSION_ID";
+constexpr char kAggregationSessionId[] = "AGGREGATION_SESSION_ID";
+constexpr char kAuthorizationToken[] = "AUTHORIZATION_TOKEN";
+constexpr char kTaskName[] = "TASK_NAME";
+constexpr char kClientToken[] = "CLIENT_TOKEN";
+constexpr char kResourceName[] = "CHECKPOINT_RESOURCE";
+constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";
+constexpr char kOperationName[] = "my_operation";
+
+const int32_t kCancellationWaitingPeriodSec = 1;
+const int32_t kMinimumClientsInServerVisibleAggregate = 2;
+
+MATCHER_P(EligibilityEvalTaskRequestMatcher, matcher,
+ absl::StrCat(negation ? "doesn't parse" : "parses",
+ " as an EligibilityEvalTaskRequest, and that ",
+ DescribeMatcher<EligibilityEvalTaskRequest>(matcher,
+ negation))) {
+ EligibilityEvalTaskRequest request;
+ if (!request.ParseFromString(arg)) {
+ return false;
+ }
+ return ExplainMatchResult(matcher, request, result_listener);
+}
+
+MATCHER_P(
+ ReportEligibilityEvalTaskResultRequestMatcher, matcher,
+ absl::StrCat(negation ? "doesn't parse" : "parses",
+ " as a ReportEligibilityEvalTaskResultRequest, and that ",
+ DescribeMatcher<ReportEligibilityEvalTaskResultRequest>(
+ matcher, negation))) {
+ ReportEligibilityEvalTaskResultRequest request;
+ if (!request.ParseFromString(arg)) {
+ return false;
+ }
+ return ExplainMatchResult(matcher, request, result_listener);
+}
+
+MATCHER_P(StartTaskAssignmentRequestMatcher, matcher,
+ absl::StrCat(negation ? "doesn't parse" : "parses",
+ " as a StartTaskAssignmentRequest, and that ",
+ DescribeMatcher<StartTaskAssignmentRequest>(matcher,
+ negation))) {
+ StartTaskAssignmentRequest request;
+ if (!request.ParseFromString(arg)) {
+ return false;
+ }
+ return ExplainMatchResult(matcher, request, result_listener);
+}
+
+MATCHER_P(GetOperationRequestMatcher, matcher,
+ absl::StrCat(negation ? "doesn't parse" : "parses",
+ " as a GetOperationRequest, and that ",
+ DescribeMatcher<GetOperationRequest>(matcher,
+ negation))) {
+ GetOperationRequest request;
+ if (!request.ParseFromString(arg)) {
+ return false;
+ }
+ return ExplainMatchResult(matcher, request, result_listener);
+}
+
+MATCHER_P(ReportTaskResultRequestMatcher, matcher,
+ absl::StrCat(negation ? "doesn't parse" : "parses",
+ " as a ReportTaskResultRequest, and that ",
+ DescribeMatcher<ReportTaskResultRequest>(matcher,
+ negation))) {
+ ReportTaskResultRequest request;
+ if (!request.ParseFromString(arg)) {
+ return false;
+ }
+ return ExplainMatchResult(matcher, request, result_listener);
+}
+
+constexpr int kTransientErrorsRetryPeriodSecs = 10;
+constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
+constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
+constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;
+constexpr int kPermanentErrorsRetryPeriodSecs = 100;
+constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;
+constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
+constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
+
+void ExpectTransientErrorRetryWindow(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected transient errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
+ Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+void ExpectPermanentErrorRetryWindow(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected permanent errors
+ // retry delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
+ Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+RetryWindow GetAcceptedRetryWindow() {
+ // Must not overlap with kTransientErrorsRetryPeriodSecs or
+ // kPermanentErrorsRetryPeriodSecs.
+ RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(200L);
+ retry_window.mutable_delay_max()->set_seconds(299L);
+ return retry_window;
+}
+
+void ExpectAcceptedRetryWindow(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected 'rejected' retry
+ // delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(200L), Lt(299L)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+RetryWindow GetRejectedRetryWindow() {
+ // Must not overlap with kTransientErrorsRetryPeriodSecs or
+ // kPermanentErrorsRetryPeriodSecs.
+ RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(300L);
+ retry_window.mutable_delay_max()->set_seconds(399L);
+ return retry_window;
+}
+
+void ExpectRejectedRetryWindow(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
+ // The calculated retry delay must lie within the expected 'rejected' retry
+ // delay range.
+ EXPECT_THAT(retry_window.delay_min().seconds() +
+ retry_window.delay_min().nanos() / 1000000000,
+ AllOf(Ge(300L), Lt(399L)));
+ EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
+}
+
+EligibilityEvalTaskRequest GetExpectedEligibilityEvalTaskRequest(
+ bool supports_multiple_task_assignments = false) {
+ EligibilityEvalTaskRequest request;
+ // Note: we don't expect population_name to be set, since it should be set in
+ // the URI instead.
+ request.mutable_client_version()->set_version_code(kClientVersion);
+ request.mutable_attestation_measurement()->set_value(kAttestationMeasurement);
+ request.mutable_resource_capabilities()
+ ->mutable_supported_compression_formats()
+ ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ request.mutable_eligibility_eval_task_capabilities()
+ ->set_supports_multiple_task_assignment(
+ supports_multiple_task_assignments);
+ return request;
+}
+
+EligibilityEvalTaskResponse GetFakeEnabledEligibilityEvalTaskResponse(
+ const Resource& plan, const Resource& checkpoint,
+ const std::string& execution_id,
+ std::optional<Resource> population_eligibility_spec = std::nullopt,
+ const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
+ const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
+ EligibilityEvalTaskResponse response;
+ response.set_session_id(kEligibilityEvalSessionId);
+ EligibilityEvalTask* eval_task = response.mutable_eligibility_eval_task();
+ *eval_task->mutable_plan() = plan;
+ *eval_task->mutable_init_checkpoint() = checkpoint;
+ if (population_eligibility_spec.has_value()) {
+ *eval_task->mutable_population_eligibility_spec() =
+ population_eligibility_spec.value();
+ }
+ eval_task->set_execution_id(execution_id);
+ ForwardingInfo* forwarding_info =
+ response.mutable_task_assignment_forwarding_info();
+ forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
+ *response.mutable_retry_window_if_accepted() = accepted_retry_window;
+ *response.mutable_retry_window_if_rejected() = rejected_retry_window;
+ return response;
+}
+
+EligibilityEvalTaskResponse GetFakeDisabledEligibilityEvalTaskResponse() {
+ EligibilityEvalTaskResponse response;
+ response.set_session_id(kEligibilityEvalSessionId);
+ response.mutable_no_eligibility_eval_configured();
+ ForwardingInfo* forwarding_info =
+ response.mutable_task_assignment_forwarding_info();
+ forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
+ *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
+ *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
+ return response;
+}
+
+EligibilityEvalTaskResponse GetFakeRejectedEligibilityEvalTaskResponse() {
+ EligibilityEvalTaskResponse response;
+ response.mutable_rejection_info();
+ *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
+ *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
+ return response;
+}
+
+TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
+ TaskEligibilityInfo eligibility_info;
+ TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
+ task_weight->set_task_name("foo");
+ task_weight->set_weight(567.8);
+ return eligibility_info;
+}
+
+StartTaskAssignmentRequest GetExpectedStartTaskAssignmentRequest(
+ const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
+ // Note: we don't expect population_name or session_id to be set, since they
+ // should be set in the URI instead.
+ StartTaskAssignmentRequest request;
+ request.mutable_client_version()->set_version_code(kClientVersion);
+ if (task_eligibility_info.has_value()) {
+ *request.mutable_task_eligibility_info() = *task_eligibility_info;
+ }
+ request.mutable_resource_capabilities()
+ ->mutable_supported_compression_formats()
+ ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ return request;
+}
+
+StartTaskAssignmentResponse GetFakeRejectedTaskAssignmentResponse() {
+ StartTaskAssignmentResponse response;
+ response.mutable_rejection_info();
+ return response;
+}
+
+StartTaskAssignmentResponse GetFakeTaskAssignmentResponse(
+ const Resource& plan, const Resource& checkpoint,
+ const std::string& federated_select_uri_template,
+ const std::string& aggregation_session_id,
+ int32_t minimum_clients_in_server_visible_aggregate) {
+ StartTaskAssignmentResponse response;
+ TaskAssignment* task_assignment = response.mutable_task_assignment();
+ ForwardingInfo* forwarding_info =
+ task_assignment->mutable_aggregation_data_forwarding_info();
+ forwarding_info->set_target_uri_prefix(kAggregationTargetUri);
+ task_assignment->set_session_id(kClientSessionId);
+ task_assignment->set_aggregation_id(aggregation_session_id);
+ task_assignment->set_authorization_token(kAuthorizationToken);
+ task_assignment->set_task_name(kTaskName);
+ *task_assignment->mutable_plan() = plan;
+ *task_assignment->mutable_init_checkpoint() = checkpoint;
+ task_assignment->mutable_federated_select_uri_info()->set_uri_template(
+ federated_select_uri_template);
+ if (minimum_clients_in_server_visible_aggregate > 0) {
+ task_assignment->mutable_secure_aggregation_info()
+ ->set_minimum_clients_in_server_visible_aggregate(
+ minimum_clients_in_server_visible_aggregate);
+ } else {
+ task_assignment->mutable_aggregation_info();
+ }
+ return response;
+}
+
+ReportTaskResultRequest GetExpectedReportTaskResultRequest(
+ absl::string_view aggregation_id, absl::string_view task_name,
+ ::google::rpc::Code code, absl::Duration train_duration) {
+ ReportTaskResultRequest request;
+ request.set_aggregation_id(std::string(aggregation_id));
+ request.set_task_name(std::string(task_name));
+ request.set_computation_status_code(code);
+ ClientStats client_stats;
+ *client_stats.mutable_computation_execution_duration() =
+ TimeUtil::ConvertAbslToProtoDuration(train_duration);
+ *request.mutable_client_stats() = client_stats;
+ return request;
+}
+
+StartAggregationDataUploadResponse GetFakeStartAggregationDataUploadResponse(
+ absl::string_view aggregation_resource_name,
+ absl::string_view byte_stream_uri_prefix,
+ absl::string_view second_stage_aggregation_uri_prefix) {
+ StartAggregationDataUploadResponse response;
+ ByteStreamResource* resource = response.mutable_resource();
+ *resource->mutable_resource_name() = aggregation_resource_name;
+ ForwardingInfo* data_upload_forwarding_info =
+ resource->mutable_data_upload_forwarding_info();
+ *data_upload_forwarding_info->mutable_target_uri_prefix() =
+ byte_stream_uri_prefix;
+ ForwardingInfo* aggregation_protocol_forwarding_info =
+ response.mutable_aggregation_protocol_forwarding_info();
+ *aggregation_protocol_forwarding_info->mutable_target_uri_prefix() =
+ second_stage_aggregation_uri_prefix;
+ response.set_client_token(kClientToken);
+ return response;
+}
+
+FakeHttpResponse CreateEmptySuccessHttpResponse() {
+ return FakeHttpResponse(200, HeaderList(), "");
+}
+
+class HttpFederatedProtocolTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ EXPECT_CALL(mock_flags_,
+ federated_training_transient_errors_retry_delay_secs)
+ .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
+ EXPECT_CALL(mock_flags_,
+ federated_training_transient_errors_retry_delay_jitter_percent)
+ .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
+ EXPECT_CALL(mock_flags_,
+ federated_training_permanent_errors_retry_delay_secs)
+ .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
+ EXPECT_CALL(mock_flags_,
+ federated_training_permanent_errors_retry_delay_jitter_percent)
+ .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
+ EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
+ .WillRepeatedly(Return(std::vector<int32_t>{
+ static_cast<int32_t>(absl::StatusCode::kNotFound),
+ static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
+ static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
+ // Note that we disable compression in test to make it easier to verify the
+ // request body. The compression logic is tested in
+ // in_memory_request_response_test.cc.
+ EXPECT_CALL(mock_flags_, disable_http_request_body_compression)
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_flags_, waiting_period_sec_for_cancellation)
+ .WillRepeatedly(Return(kCancellationWaitingPeriodSec));
+
+ EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
+ .WillRepeatedly(Return(false));
+
+ // We only initialize federated_protocol_ in this SetUp method, rather than
+ // in the test's constructor, to ensure that we can set mock flag values
+ // before the HttpFederatedProtocol constructor is called. Using
+ // std::unique_ptr conveniently allows us to assign the field a new value
+ // after construction (which we could not do if the field's type was
+ // HttpFederatedProtocol, since it doesn't have copy or move constructors).
+ federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
+ clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
+ absl::WrapUnique(mock_secagg_runner_factory_),
+ &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
+ kRetryToken, kClientVersion, kAttestationMeasurement,
+ mock_should_abort_.AsStdFunction(), absl::BitGen(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ &mock_resource_cache_);
+ }
+
+ void TearDown() override {
+ // Regardless of the outcome of the test (or the protocol interaction being
+ // tested), network usage must always be reflected in the network stats
+ // methods.
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes =
+ mock_http_client_.TotalSentReceivedBytes();
+
+ NetworkStats network_stats = federated_protocol_->GetNetworkStats();
+ EXPECT_EQ(network_stats.bytes_downloaded,
+ sent_received_bytes.received_bytes);
+ EXPECT_EQ(network_stats.bytes_uploaded, sent_received_bytes.sent_bytes);
+ // If any network traffic occurred, we expect to see some time reflected in
+ // the duration.
+ if (network_stats.bytes_uploaded > 0) {
+ EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
+ }
+ }
+
+ // This function runs a successful EligibilityEvalCheckin() that results in an
+ // eligibility eval payload being returned by the server (if
+ // `eligibility_eval_enabled` is true), or results in a 'no eligibility eval
+ // configured' response (if `eligibility_eval_enabled` is false). This is a
+ // utility function used by Checkin*() tests that depend on a prior,
+ // successful execution of EligibilityEvalCheckin(). It returns a
+ // absl::Status, which the caller should verify is OK using ASSERT_OK.
+ absl::Status RunSuccessfulEligibilityEvalCheckin(
+ bool eligibility_eval_enabled = true) {
+ EligibilityEvalTaskResponse eval_task_response;
+ if (eligibility_eval_enabled) {
+ // We return a fake response which returns the plan/initial checkpoint
+ // data inline, to keep things simple.
+ std::string expected_plan = kPlan;
+ Resource plan_resource;
+ plan_resource.mutable_inline_resource()->set_data(kPlan);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(
+ expected_checkpoint);
+ eval_task_response = GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
+ } else {
+ eval_task_response = GetFakeDisabledEligibilityEvalTaskResponse();
+ }
+ std::string request_uri =
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ request_uri, HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // The 'EET received' callback should be called, even if the task resource
+ // data was available inline.
+ if (eligibility_eval_enabled) {
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), kEligibilityEvalExecutionId,
+ Eq(std::nullopt))));
+ }
+
+ return federated_protocol_
+ ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
+ .status();
+ }
+
+ // This function runs a successful Checkin() that results in a
+ // task assignment payload being returned by the server. This is a
+ // utility function used by Report*() tests that depend on a prior,
+ // successful execution of Checkin(). It returns a
+ // absl::Status, which the caller should verify is OK using ASSERT_OK.
+ absl::Status RunSuccessfulCheckin(bool eligibility_eval_enabled = true) {
+ // We return a fake response which returns the plan/initial checkpoint
+ // data inline, to keep things simple.
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(
+ expected_checkpoint);
+ std::string expected_aggregation_session_id = kAggregationSessionId;
+ StartTaskAssignmentResponse task_assignment_response =
+ GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
+ kFederatedSelectUriTemplate,
+ expected_aggregation_session_id, 0);
+
+ std::string request_uri =
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
+ TaskEligibilityInfo expected_eligibility_info =
+ GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ request_uri, HttpRequest::Method::kPost, _,
+ StartTaskAssignmentRequestMatcher(
+ EqualsProto(GetExpectedStartTaskAssignmentRequest(
+ expected_eligibility_info))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName, task_assignment_response)
+ .SerializeAsString())));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
+
+ if (eligibility_eval_enabled) {
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
+ report_eet_request_uri, absl::OkStatus());
+ }
+
+ return federated_protocol_
+ ->Checkin(expected_eligibility_info,
+ mock_task_received_callback_.AsStdFunction())
+ .status();
+ }
+
+ void ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
+ absl::string_view expected_request_uri, absl::Status eet_status) {
+ ReportEligibilityEvalTaskResultRequest report_eet_request;
+ report_eet_request.set_status_code(
+ static_cast<google::rpc::Code>(eet_status.code()));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ std::string(expected_request_uri), HttpRequest::Method::kPost, _,
+ ReportEligibilityEvalTaskResultRequestMatcher(
+ EqualsProto(report_eet_request)))))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+ }
+
+ void ExpectSuccessfulReportTaskResultRequest(
+ absl::string_view expected_report_result_uri,
+ absl::string_view aggregation_session_id, absl::string_view task_name,
+ absl::Duration plan_duration) {
+ ReportTaskResultResponse report_task_result_response;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ std::string(expected_report_result_uri),
+ HttpRequest::Method::kPost, _,
+ ReportTaskResultRequestMatcher(
+ EqualsProto(GetExpectedReportTaskResultRequest(
+ aggregation_session_id, task_name,
+ google::rpc::Code::OK, plan_duration))))))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ }
+
+ void ExpectSuccessfulStartAggregationDataUploadRequest(
+ absl::string_view expected_start_data_upload_uri,
+ absl::string_view aggregation_resource_name,
+ absl::string_view byte_stream_uri_prefix,
+ absl::string_view second_stage_aggregation_uri_prefix) {
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ std::string(expected_start_data_upload_uri),
+ HttpRequest::Method::kPost, _,
+ StartAggregationDataUploadRequest().SerializeAsString())))
+ .WillOnce(Return(
+ FakeHttpResponse(200, HeaderList(),
+ pending_operation_response.SerializeAsString())));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(
+ kOperationName,
+ GetFakeStartAggregationDataUploadResponse(
+ aggregation_resource_name, byte_stream_uri_prefix,
+ second_stage_aggregation_uri_prefix))
+ .SerializeAsString())));
+ }
+
+ void ExpectSuccessfulByteStreamUploadRequest(
+ absl::string_view byte_stream_upload_uri,
+ absl::string_view checkpoint_str) {
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ std::string(byte_stream_upload_uri), HttpRequest::Method::kPost, _,
+ std::string(checkpoint_str))))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ }
+
+ void ExpectSuccessfulSubmitAggregationResultRequest(
+ absl::string_view expected_submit_aggregation_result_uri) {
+ SubmitAggregationResultRequest submit_aggregation_result_request;
+ submit_aggregation_result_request.set_resource_name(kResourceName);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ std::string(expected_submit_aggregation_result_uri),
+ HttpRequest::Method::kPost, _,
+ submit_aggregation_result_request.SerializeAsString())))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ }
+
+ void ExpectSuccessfulAbortAggregationRequest(absl::string_view base_uri) {
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ absl::StrCat(base_uri, "/v1/aggregations/",
+ "AGGREGATION_SESSION_ID/clients/"
+ "CLIENT_TOKEN:abort?%24alt=proto"),
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ }
+
+ StrictMock<MockHttpClient> mock_http_client_;
+ StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
+ new StrictMock<MockSecAggRunnerFactory>();
+ StrictMock<MockSecAggEventPublisher> mock_secagg_event_publisher_;
+ StrictMock<MockLogManager> mock_log_manager_;
+ NiceMock<MockFlags> mock_flags_;
+ NiceMock<MockFunction<bool()>> mock_should_abort_;
+ StrictMock<cache::MockResourceCache> mock_resource_cache_;
+ Clock* clock_ = Clock::RealClock();
+ NiceMock<MockFunction<void(
+ const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
+ mock_eet_received_callback_;
+ NiceMock<MockFunction<void(
+ const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
+ mock_task_received_callback_;
+
+ // The class under test.
+ std::unique_ptr<HttpFederatedProtocol> federated_protocol_;
+};
+
+using HttpFederatedProtocolDeathTest = HttpFederatedProtocolTest;
+
+TEST_F(HttpFederatedProtocolTest,
+ TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window1 =
+ federated_protocol_->GetLatestRetryWindow();
+ ExpectTransientErrorRetryWindow(retry_window1);
+ federated_protocol_.reset(nullptr);
+ mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();
+
+ // Create a new HttpFederatedProtocol instance. It should not produce the same
+ // retry window value as the one we just got. This is a simple correctness
+ // check to ensure that the value is at least randomly generated (and that we
+ // don't accidentally use the random number generator incorrectly).
+ federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
+ clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
+ absl::WrapUnique(mock_secagg_runner_factory_),
+ &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
+ kRetryToken, kClientVersion, kAttestationMeasurement,
+ mock_should_abort_.AsStdFunction(), absl::BitGen(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ &mock_resource_cache_);
+
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window2 =
+ federated_protocol_->GetLatestRetryWindow();
+ ExpectTransientErrorRetryWindow(retry_window2);
+
+ EXPECT_THAT(retry_window1, Not(EqualsProto(retry_window2)));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinRequestFailsTransientError) {
+ // Make the HTTP client return a 503 Service Unavailable error when the
+ // EligibilityEvalCheckin(...) code issues the control protocol's HTTP
+ // request. This should result in the error being returned as the result.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("protocol request failed"));
+ // The original 503 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the transient errors retry delay flag.
+ ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinRequestFailsPermanentError) {
+ // Make the HTTP client return a 404 Not Found error when the
+ // EligibilityEvalCheckin(...) code issues the control protocol's HTTP
+ // request. This should result in the error being returned as the result.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("protocol request failed"));
+ // The original 404 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the *permanent* errors retry delay flag,
+ // since NOT_FOUND is marked as a permanent error in the flags.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where we get interrupted while waiting for a response to the
+// protocol request in EligibilityEvalCheckin.
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinRequestInterrupted) {
+ absl::Notification request_issued;
+ absl::Notification request_cancelled;
+
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce([&request_issued, &request_cancelled](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ request_issued.Notify();
+ request_cancelled.WaitForNotification();
+ return FakeHttpResponse(503, HeaderList(), "");
+ });
+
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it
+ // was given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests()() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
+ // request complete.
+ mock_http_client_.SetCancellationListener(
+ [&request_cancelled]() { request_cancelled.Notify(); });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
+ // No RetryWindows were received from the server, so we expect to get a
+ // RetryWindow generated based on the transient errors retry delay flag.
+ ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString())));
+
+ // The 'eet received' callback should not be invoked since no EET was given to
+ // the client.
+ EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ GetFakeDisabledEligibilityEvalTaskResponse().SerializeAsString())));
+
+ // The 'eet received' callback should not be invoked since no EET was given to
+ // the client.
+ EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(*eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
+ // We return a fake response which requires fetching the plan via HTTP, but
+ // which has the checkpoint data available inline.
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id);
+
+ InSequence seq;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // The 'EET received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
+ Eq(std::nullopt))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(
+ *eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
+ AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
+ absl::Cord(expected_plan)),
+ Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
+ absl::Cord(expected_checkpoint))),
+ expected_execution_id, Eq(std::nullopt))));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinWithPopulationEligibilitySpec) {
+ EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
+ .WillRepeatedly(Return(true));
+ // We return a fake response which requires fetching the plan via HTTP,
+ // but which has the checkpoint data available inline.
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
+
+ PopulationEligibilitySpec expected_population_eligibility_spec;
+ auto task_info = expected_population_eligibility_spec.add_task_info();
+ task_info->set_task_name("task_1");
+ task_info->set_task_assignment_mode(
+ PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE);
+ std::string population_eligibility_spec_uri =
+ "https://fake.uri/population_eligibility_spec";
+ Resource population_eligibility_spec;
+ population_eligibility_spec.set_uri(population_eligibility_spec_uri);
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id,
+ population_eligibility_spec);
+
+ InSequence seq;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest(
+ /* supports_multiple_task_assignments= */ true))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // The 'EET received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
+ Eq(std::nullopt))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ population_eligibility_spec_uri,
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ expected_population_eligibility_spec.SerializeAsString())));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(
+ *eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
+ AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
+ absl::Cord(expected_plan)),
+ Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
+ absl::Cord(expected_checkpoint))),
+ expected_execution_id,
+ Optional(EqualsProto(expected_population_eligibility_spec)))));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinWithPopulationEligibilitySpecInvalidFormat) {
+ EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
+ .WillRepeatedly(Return(true));
+ // We return a fake response which requires fetching the plan via HTTP,
+ // but which has the checkpoint data available inline.
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
+
+ Resource population_eligibility_spec;
+ population_eligibility_spec.mutable_inline_resource()->set_data(
+ "Invalid_spec");
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id,
+ population_eligibility_spec);
+
+ InSequence seq;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest(
+ /* supports_multiple_task_assignments= */ true))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // The 'EET received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(mock_eet_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
+ Eq(std::nullopt))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_THAT(eligibility_checkin_result, IsCode(INVALID_ARGUMENT));
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinEnabledWithCompression) {
+ std::string expected_plan = kPlan;
+ absl::StatusOr<std::string> compressed_plan =
+ internal::CompressWithGzip(expected_plan);
+ ASSERT_OK(compressed_plan);
+ Resource plan_resource;
+ plan_resource.mutable_inline_resource()->set_data(*compressed_plan);
+ plan_resource.mutable_inline_resource()->set_compression_format(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ std::string expected_checkpoint = kInitCheckpoint;
+ absl::StatusOr<std::string> compressed_checkpoint =
+ internal::CompressWithGzip(expected_checkpoint);
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(
+ *compressed_checkpoint);
+ checkpoint_resource.mutable_inline_resource()->set_compression_format(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ ASSERT_OK(eligibility_checkin_result);
+ EXPECT_THAT(
+ *eligibility_checkin_result,
+ VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
+ AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
+ absl::Cord(expected_plan)),
+ Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
+ absl::Cord(expected_checkpoint))),
+ expected_execution_id, Eq(std::nullopt))));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the plan resource fails to be downloaded, the error is
+// correctly returned from the EligibilityEvalCheckin(...) method.
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinEnabledPlanDataFetchFailed) {
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ Resource checkpoint_resource;
+ checkpoint_resource.set_uri(checkpoint_uri);
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ // Mock a failed plan fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ // The 404 error for the resource request should be reflected in the return
+ // value.
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("plan fetch failed"));
+ // The original 404 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
+ // Since the error type is considered a permanent error, we should get a
+ // permanent error retry window.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the checkpoint resource fails to be downloaded, the error is
+// correctly returned from the EligibilityEvalCheckin(...) method.
+TEST_F(HttpFederatedProtocolTest,
+ TestEligibilityEvalCheckinEnabledCheckpointDataFetchFailed) {
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ Resource checkpoint_resource;
+ checkpoint_resource.set_uri(checkpoint_uri);
+ std::string expected_execution_id = kEligibilityEvalExecutionId;
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, expected_execution_id);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // Mock a failed checkpoint fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ // The 503 error for the resource request should be reflected in the return
+ // value.
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(eligibility_checkin_result.status().message(),
+ HasSubstr("checkpoint fetch failed"));
+ // The original 503 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
+ // RetryWindows were received from the server before the error was received,
+ // and the error is considered 'transient', so we expect to get a rejected
+ // RetryWindow.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportEligibilityEvalTaskResult) {
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ReportEligibilityEvalTaskResultRequest report_eet_request;
+ report_eet_request.set_status_code(
+ static_cast<google::rpc::Code>(absl::StatusCode::kCancelled));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ report_eet_request_uri, HttpRequest::Method::kPost, _,
+ ReportEligibilityEvalTaskResultRequestMatcher(
+ EqualsProto(report_eet_request)))))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ federated_protocol_->ReportEligibilityEvalError(absl::CancelledError());
+}
+
+// Tests that the protocol correctly sanitizes any invalid values it may have
+// received from the server.
+TEST_F(HttpFederatedProtocolTest,
+ TestNegativeMinMaxRetryDelayValueSanitization) {
+ RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(-1);
+ retry_window.mutable_delay_max()->set_seconds(-2);
+
+ // The above retry window's negative min/max values should be clamped to 0.
+ RetryWindow expected_retry_window;
+ expected_retry_window.mutable_delay_min()->set_seconds(0);
+ expected_retry_window.mutable_delay_max()->set_seconds(0);
+
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
+ retry_window, retry_window);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction()));
+
+ const google::internal::federatedml::v2::RetryWindow& actual_retry_window =
+ federated_protocol_->GetLatestRetryWindow();
+ // The above retry window's invalid max value should be clamped to the min
+ // value (minus some errors introduced by the inaccuracy of double
+ // multiplication).
+ EXPECT_THAT(actual_retry_window.delay_min().seconds() +
+ actual_retry_window.delay_min().nanos() / 1000000000.0,
+ DoubleEq(0));
+ EXPECT_THAT(actual_retry_window.delay_max().seconds() +
+ actual_retry_window.delay_max().nanos() / 1000000000.0,
+ DoubleEq(0));
+}
+
+// Tests that the protocol correctly sanitizes any invalid values it may have
+// received from the server.
+TEST_F(HttpFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
+ RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(1234);
+ retry_window.mutable_delay_max()->set_seconds(1233); // less than delay_min
+
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
+ retry_window, retry_window);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction()));
+
+ const google::internal::federatedml::v2::RetryWindow& actual_retry_window =
+ federated_protocol_->GetLatestRetryWindow();
+ // The above retry window's invalid max value should be clamped to the min
+ // value (minus some errors introduced by the inaccuracy of double
+ // multiplication). Note that DoubleEq enforces too precise of bounds, so we
+ // use DoubleNear instead.
+ EXPECT_THAT(actual_retry_window.delay_min().seconds() +
+ actual_retry_window.delay_min().nanos() / 1000000000.0,
+ DoubleNear(1234.0, 0.015));
+ EXPECT_THAT(actual_retry_window.delay_max().seconds() +
+ actual_retry_window.delay_max().nanos() / 1000000000.0,
+ DoubleNear(1234.0, 0.015));
+}
+
+TEST_F(HttpFederatedProtocolDeathTest,
+ TestCheckinAfterFailedEligibilityEvalCheckin) {
+ // Make the HTTP client return a 503 Service Unavailable error when the
+ // EligibilityEvalCheckin(...) code issues the protocol HTTP request.
+ // This should result in the error being returned as the result.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
+
+ // A Checkin(...) request should now fail, because Checkin(...) should only
+ // be a called after a successful EligibilityEvalCheckin(...) request.
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+TEST_F(HttpFederatedProtocolDeathTest,
+ TestCheckinAfterEligibilityEvalCheckinRejection) {
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString())));
+
+ ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction()));
+
+ // A Checkin(...) request should now fail, because Checkin(...) should only
+ // be a called after a successful EligibilityEvalCheckin(...) request, with a
+ // non-rejection response.
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+TEST_F(HttpFederatedProtocolDeathTest,
+ TestCheckinWithEligibilityInfoAfterEligibilityEvalCheckinDisabled) {
+ ASSERT_OK(
+ RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));
+
+ // A Checkin(...) request with a TaskEligibilityInfo argument should now fail,
+ // because such info should only be passed a successful
+ // EligibilityEvalCheckin(...) request with an eligibility eval task in the
+ // response.
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ TaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+TEST_F(HttpFederatedProtocolDeathTest, TestCheckinWithMissingEligibilityInfo) {
+ ASSERT_OK(
+ RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/true));
+
+ // A Checkin(...) request with a missing TaskEligibilityInfo should now fail,
+ // as the protocol requires us to provide one based on the plan includes in
+ // the eligibility eval checkin response payload..
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+TEST_F(HttpFederatedProtocolDeathTest,
+ TestCheckinAfterEligibilityEvalResourceDataFetchFailed) {
+ Resource plan_resource;
+ plan_resource.set_uri("https://fake.uri/plan");
+ Resource checkpoint_resource;
+ checkpoint_resource.set_uri("https://fake.uri/checkpoint");
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(
+ plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ // Mock a failed plan/resource fetch.
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ _, HttpRequest::Method::kGet, _, "")))
+ .WillRepeatedly(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction());
+
+ // A Checkin(...) request should now fail, because Checkin(...) should only
+ // be a called after a successful EligibilityEvalCheckin(...) request, with a
+ // non-rejection response.
+ ASSERT_DEATH(
+ {
+ auto unused = federated_protocol_->Checkin(
+ TaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ },
+ _);
+}
+
+// Ensures that if the HTTP layer returns an error code that maps to a transient
+// error, it is handled correctly
+TEST_F(HttpFederatedProtocolTest, TestCheckinFailsTransientError) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ // Make the HTTP request return an 503 Service Unavailable error when the
+ // Checkin(...) code tries to send its first request. This should result in
+ // the error being returned as the result.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
+ // The original 503 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the HTTP layer returns an error code that maps to a permanent
+// error, it is handled correctly.
+TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromHttp) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ // Make the HTTP request return an 404 Not Found error when the Checkin(...)
+ // code tries to send its first request. This should result in the error being
+ // returned as the result.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
+ // The original 503 HTTP response code should be included in the message as
+ // well.
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
+ // Even though RetryWindows were already received from the server during the
+ // eligibility eval checkin, we expect a RetryWindow generated based on the
+ // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
+ // permanent error in the flags, and permanent errors should always result in
+ // permanent error windows (regardless of whether retry windows were already
+ // received).
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the HTTP layer returns a successful response, but it contains
+// an Operation proto with a permanent error, that it is handled correctly.
+TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromOperation) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ // Make the HTTP request return successfully, but make it contain an Operation
+ // proto that itself contains a permanent error. This should result in the
+ // error being returned as the result.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kNotFound,
+ "foo")
+ .SerializeAsString())));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(checkin_result.status().message(),
+ HasSubstr("Operation my_operation contained error"));
+ // The original error message should be included in the message as well.
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("foo"));
+ // Even though RetryWindows were already received from the server during the
+ // eligibility eval checkin, we expect a RetryWindow generated based on the
+ // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
+ // permanent error in the flags, and permanent errors should always result in
+ // permanent error windows (regardless of whether retry windows were already
+ // received).
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where we get interrupted while waiting for a response to the
+// protocol request in Checkin.
+TEST_F(HttpFederatedProtocolTest, TestCheckinInterrupted) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ absl::Notification request_issued;
+ absl::Notification request_cancelled;
+
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce([&request_issued, &request_cancelled](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ request_issued.Notify();
+ request_cancelled.WaitForNotification();
+ return FakeHttpResponse(503, HeaderList(), "");
+ });
+
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it
+ // was given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests()() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
+ // request complete.
+ mock_http_client_.SetCancellationListener([&request_cancelled]() {
+ if (!request_cancelled.HasBeenNotified()) {
+ request_cancelled.Notify();
+ }
+ });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where we get interrupted during polling of the long running
+// operation.
+TEST_F(HttpFederatedProtocolTest,
+ TestCheckinInterruptedDuringLongRunningOperation) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ absl::Notification request_issued;
+ absl::Notification request_cancelled;
+
+ Operation pending_operation = CreatePendingOperation("operations/foo#bar");
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation.SerializeAsString())));
+
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it
+ // was given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests()() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillRepeatedly([&request_issued, &request_cancelled, pending_operation](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ if (!request_issued.HasBeenNotified()) {
+ request_issued.Notify();
+ }
+ request_cancelled.WaitForNotification();
+ return FakeHttpResponse(200, HeaderList(),
+ pending_operation.SerializeAsString());
+ });
+
+ // Once the client is cancelled, a CancelOperationRequest should still be sent
+ // out before returning to the caller."
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://taskassignment.uri/v1/operations/"
+ "foo%23bar:cancel?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
+ // request complete.
+ mock_http_client_.SetCancellationListener(
+ [&request_cancelled]() { request_cancelled.Notify(); });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests the case where we get interrupted during polling of the long-running
+// operation, and the issued cancellation request timed out.
+TEST_F(HttpFederatedProtocolTest, TestCheckinInterruptedCancellationTimeout) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ absl::Notification request_issued;
+ absl::Notification request_cancelled;
+
+ Operation pending_operation = CreatePendingOperation("operations/foo#bar");
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation.SerializeAsString())));
+
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it
+ // was given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests()() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillRepeatedly([&request_issued, &request_cancelled, pending_operation](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ if (!request_issued.HasBeenNotified()) {
+ request_issued.Notify();
+ }
+ request_cancelled.WaitForNotification();
+ return FakeHttpResponse(200, HeaderList(),
+ pending_operation.SerializeAsString());
+ });
+
+ // Once the client is cancelled, a CancelOperationRequest should still be sent
+ // out before returning to the caller."
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://taskassignment.uri/v1/operations/"
+ "foo%23bar:cancel?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillOnce([](MockableHttpClient::SimpleHttpRequest ignored) {
+ // Sleep for 2 seconds before returning the response.
+ absl::SleepFor(absl::Seconds(2));
+ return FakeHttpResponse(200, HeaderList(), "");
+ });
+
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
+ // request complete.
+ mock_http_client_.SetCancellationListener([&request_cancelled]() {
+ if (!request_cancelled.HasBeenNotified()) {
+ request_cancelled.Notify();
+ }
+ });
+
+ // The Interruption log will be logged twice, one for Get operation, the other
+ // for Cancel operation.
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP))
+ .Times(2);
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED));
+
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+ EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
+ // RetryWindows were already received from the server during the eligibility
+ // eval checkin, so we expect to get a 'rejected' retry window.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests whether 'rejection' responses to the main Checkin(...) request are
+// handled correctly.
+TEST_F(HttpFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartTaskAssignmentRequestMatcher(
+ EqualsProto(GetExpectedStartTaskAssignmentRequest(
+ expected_eligibility_info))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName,
+ GetFakeRejectedTaskAssignmentResponse())
+ .SerializeAsString())));
+
+ // The 'task received' callback should not be invoked since no task was given
+ // to the client.
+ EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests whether we can issue a Checkin() request correctly without passing a
+// TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
+// return any eligibility eval task to run.
+TEST_F(HttpFederatedProtocolTest,
+ TestCheckinRejectionWithoutTaskEligibilityInfo) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(
+ RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));
+
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartTaskAssignmentRequestMatcher(EqualsProto(
+ GetExpectedStartTaskAssignmentRequest(std::nullopt))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName,
+ GetFakeRejectedTaskAssignmentResponse())
+ .SerializeAsString())));
+
+ // The 'task received' callback should not be invoked since no task was given
+ // to the client.
+ EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
+
+ // Issue the regular checkin, without a TaskEligibilityInfo (since we didn't
+ // receive an eligibility eval task to run during eligibility eval checkin).
+ auto checkin_result = federated_protocol_->Checkin(
+ std::nullopt, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Tests whether a successful task assignment response is handled correctly.
+TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssigned) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ // We return a fake response which requires fetching the plan via HTTP, but
+ // which has the checkpoint data available inline.
+ std::string expected_plan = kPlan;
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
+ std::string expected_federated_select_uri_template =
+ kFederatedSelectUriTemplate;
+ std::string expected_aggregation_session_id = kAggregationSessionId;
+
+ InSequence seq;
+ // Note that in this particular test we check that the CheckinRequest is as
+ // expected (in all prior tests we just use the '_' matcher, because the
+ // request isn't really relevant to the test).
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartTaskAssignmentRequestMatcher(
+ EqualsProto(GetExpectedStartTaskAssignmentRequest(
+ expected_eligibility_info))))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName,
+ GetFakeTaskAssignmentResponse(
+ plan_resource, checkpoint_resource,
+ expected_federated_select_uri_template,
+ expected_aggregation_session_id,
+ kMinimumClientsInServerVisibleAggregate))
+ .SerializeAsString())));
+
+ // The 'task received' callback should be called *before* the actual task
+ // resources are fetched.
+ EXPECT_CALL(
+ mock_task_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_federated_select_uri_template,
+ expected_aggregation_session_id,
+ Optional(FieldsAre(
+ _, Eq(kMinimumClientsInServerVisibleAggregate))))));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(
+ *checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
+ expected_federated_select_uri_template,
+ expected_aggregation_session_id,
+ Optional(
+ FieldsAre(_, Eq(kMinimumClientsInServerVisibleAggregate))))));
+ // The Checkin call is expected to return the accepted retry window from the
+ // response to the first eligibility eval request.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that polling the Operation returned by a StartTaskAssignmentRequest
+// works as expected. This serves mostly as a high-level check. Further
+// polling-specific behavior is tested in more detail in
+// ProtocolRequestHelperTest.
+TEST_F(HttpFederatedProtocolTest,
+ TestCheckinTaskAssignedAfterOperationPolling) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ // Make the initial StartTaskAssignmentRequest return a pending Operation
+ // result. Note that we use a '#' character in the operation name to allow us
+ // to verify that it is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())));
+
+ // Then, after letting the operation get polled twice more, eventually return
+ // a fake response.
+ std::string expected_plan = kPlan;
+ Resource plan_resource;
+ plan_resource.mutable_inline_resource()->set_data(expected_plan);
+ std::string expected_checkpoint = kInitCheckpoint;
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
+ std::string expected_federated_select_uri_template =
+ kFederatedSelectUriTemplate;
+ std::string expected_aggregation_session_id = kAggregationSessionId;
+
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName,
+ GetFakeTaskAssignmentResponse(
+ plan_resource, checkpoint_resource,
+ expected_federated_select_uri_template,
+ expected_aggregation_session_id, 0))
+ .SerializeAsString())));
+
+ // The 'task received' callback should be called, even if the task resource
+ // data was available inline.
+ EXPECT_CALL(
+ mock_task_received_callback_,
+ Call(FieldsAre(FieldsAre("", ""), expected_federated_select_uri_template,
+ expected_aggregation_session_id, Eq(std::nullopt))));
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ ASSERT_OK(checkin_result.status());
+ EXPECT_THAT(
+ *checkin_result,
+ VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
+ FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
+ expected_federated_select_uri_template,
+ expected_aggregation_session_id, Eq(std::nullopt))));
+ // The Checkin call is expected to return the accepted retry window from the
+ // response to the first eligibility eval request.
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the plan resource fails to be downloaded, the error is
+// correctly returned from the Checkin(...) method.
+TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssignedPlanDataFetchFailed) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ Resource checkpoint_resource;
+ checkpoint_resource.set_uri(checkpoint_uri);
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(
+ kOperationName,
+ GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
+ kFederatedSelectUriTemplate,
+ kAggregationSessionId, 0))
+ .SerializeAsString())));
+
+ // Mock a failed plan fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ // The 404 error for the resource request should be reflected in the return
+ // value.
+ EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
+ EXPECT_THAT(checkin_result.status().message(),
+ HasSubstr("plan fetch failed"));
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
+ // The Checkin call is expected to return the permanent error retry window,
+ // since 404 maps to a permanent error.
+ ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+// Ensures that if the checkpoint resource fails to be downloaded, the error is
+// correctly returned from the Checkin(...) method.
+TEST_F(HttpFederatedProtocolTest,
+ TestCheckinTaskAssignedCheckpointDataFetchFailed) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ std::string checkpoint_uri = "https://fake.uri/checkpoint";
+ Resource checkpoint_resource;
+ checkpoint_resource.set_uri(checkpoint_uri);
+
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(
+ kOperationName,
+ GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
+ kFederatedSelectUriTemplate,
+ kAggregationSessionId, 0))
+ .SerializeAsString())));
+
+ // Mock a failed checkpoint fetch.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ checkpoint_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ // Issue the regular checkin.
+ auto checkin_result = federated_protocol_->Checkin(
+ GetFakeTaskEligibilityInfo(),
+ mock_task_received_callback_.AsStdFunction());
+
+ // The 503 error for the resource request should be reflected in the return
+ // value.
+ EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
+ EXPECT_THAT(checkin_result.status().message(),
+ HasSubstr("checkpoint fetch failed"));
+ EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
+ // The Checkin call is expected to return the rejected retry window from the
+ // response to the first eligibility eval request.
+ ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSimpleAggSuccess) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+// TODO(team): Remove this test once client_token is always populated in
+// StartAggregationDataUploadResponse.
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedViaSimpleAggWithoutClientToken) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+
+ StartAggregationDataUploadResponse start_aggregation_data_upload_response =
+ GetFakeStartAggregationDataUploadResponse(
+ kResourceName, kByteStreamTargetUri,
+ kSecondStageAggregationTargetUri);
+ // Omit the client token from the response.
+ start_aggregation_data_upload_response.clear_client_token();
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName,
+ start_aggregation_data_upload_response)
+ .SerializeAsString())));
+
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ // SubmitAggregationResult should reuse the authorization token.
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/AUTHORIZATION_TOKEN:submit?%24alt=proto");
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSecureAgg) {
+ absl::Duration plan_duration = absl::Minutes(5);
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ StartSecureAggregationResponse start_secure_aggregation_response;
+ start_secure_aggregation_response.set_client_token(kClientToken);
+ auto masked_result_resource =
+ start_secure_aggregation_response.mutable_masked_result_resource();
+ masked_result_resource->set_resource_name("masked_resource");
+ masked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ auto nonmasked_result_resource =
+ start_secure_aggregation_response.mutable_nonmasked_result_resource();
+ nonmasked_result_resource->set_resource_name("nonmasked_resource");
+ nonmasked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
+ ->set_target_uri_prefix("https://secure.aggregations.uri/");
+ auto protocol_execution_info =
+ start_secure_aggregation_response.mutable_protocol_execution_info();
+ protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
+ 450);
+ protocol_execution_info->set_expected_number_of_clients(500);
+
+ auto secure_aggregands =
+ start_secure_aggregation_response.mutable_secure_aggregands();
+ SecureAggregandExecutionInfo secure_aggregand_execution_info;
+ secure_aggregand_execution_info.set_modulus(9999);
+ (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartSecureAggregationRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName, start_secure_aggregation_response)
+ .SerializeAsString())));
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ results.emplace("secagg_tensor", QuantizedTensor());
+
+ EXPECT_CALL(*mock_secagg_runner_factory_,
+ CreateSecAggRunner(_, _, _, _, _, 500, 450))
+ .WillOnce(WithArg<0>([&](auto send_to_server_impl) {
+ auto mock_secagg_runner =
+ std::make_unique<StrictMock<MockSecAggRunner>>();
+ EXPECT_CALL(*mock_secagg_runner,
+ Run(UnorderedElementsAre(Pair(
+ "secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
+ IsEmpty(), 0, IsEmpty()))))))
+ .WillOnce([=,
+ send_to_server_impl = std::move(send_to_server_impl)] {
+ // SecAggSendToServerBase::Send should use the client token. This
+ // needs to be tested here since `send_to_server_impl` should not
+ // be used outside of Run.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secure.aggregations.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "CLIENT_TOKEN:abort?%24alt=proto",
+ _, _, _)))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ secagg::ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort();
+ send_to_server_impl->Send(&abort_message);
+
+ return absl::OkStatus();
+ });
+ return mock_secagg_runner;
+ }));
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+// TODO(team): Remove this test once client_token is always populated in
+// StartSecureAggregationResponse.
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedViaSecureAggWithoutClientToken) {
+ absl::Duration plan_duration = absl::Minutes(5);
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ StartSecureAggregationResponse start_secure_aggregation_response;
+ // Don't set client_token.
+ auto masked_result_resource =
+ start_secure_aggregation_response.mutable_masked_result_resource();
+ masked_result_resource->set_resource_name("masked_resource");
+ masked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ auto nonmasked_result_resource =
+ start_secure_aggregation_response.mutable_nonmasked_result_resource();
+ nonmasked_result_resource->set_resource_name("nonmasked_resource");
+ nonmasked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
+ ->set_target_uri_prefix("https://secure.aggregations.uri/");
+ auto protocol_execution_info =
+ start_secure_aggregation_response.mutable_protocol_execution_info();
+ protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
+ 450);
+ protocol_execution_info->set_expected_number_of_clients(500);
+
+ auto secure_aggregands =
+ start_secure_aggregation_response.mutable_secure_aggregands();
+ SecureAggregandExecutionInfo secure_aggregand_execution_info;
+ secure_aggregand_execution_info.set_modulus(9999);
+ (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartSecureAggregationRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName, start_secure_aggregation_response)
+ .SerializeAsString())));
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ results.emplace("secagg_tensor", QuantizedTensor());
+
+ EXPECT_CALL(*mock_secagg_runner_factory_,
+ CreateSecAggRunner(_, _, _, _, _, _, _))
+ .WillOnce(WithArg<0>([&](auto send_to_server_impl) {
+ auto mock_secagg_runner =
+ std::make_unique<StrictMock<MockSecAggRunner>>();
+ EXPECT_CALL(*mock_secagg_runner, Run(_))
+ .WillOnce([=,
+ send_to_server_impl = std::move(send_to_server_impl)] {
+ // SecAggSendToServerBase::Send should reuse the authorization
+ // token. This needs to be tested here since `send_to_server_impl`
+ // should not be used outside of Run.
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secure.aggregations.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:abort?%24alt=proto",
+ _, _, _)))
+ .WillOnce(Return(CreateEmptySuccessHttpResponse()));
+ secagg::ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort();
+ send_to_server_impl->Send(&abort_message);
+
+ return absl::OkStatus();
+ });
+ return mock_secagg_runner;
+ }));
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedViaSecureAggReportTaskResultFailed) {
+ absl::Duration plan_duration = absl::Minutes(5);
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ StartSecureAggregationResponse start_secure_aggregation_response;
+ start_secure_aggregation_response.set_client_token(kClientToken);
+ auto masked_result_resource =
+ start_secure_aggregation_response.mutable_masked_result_resource();
+ masked_result_resource->set_resource_name("masked_resource");
+ masked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ auto nonmasked_result_resource =
+ start_secure_aggregation_response.mutable_nonmasked_result_resource();
+ nonmasked_result_resource->set_resource_name("nonmasked_resource");
+ nonmasked_result_resource->mutable_data_upload_forwarding_info()
+ ->set_target_uri_prefix("https://bytestream.uri/");
+
+ start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
+ ->set_target_uri_prefix("https://secure.aggregations.uri/");
+ auto protocol_execution_info =
+ start_secure_aggregation_response.mutable_protocol_execution_info();
+ protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
+ 450);
+ protocol_execution_info->set_expected_number_of_clients(500);
+
+ auto secure_aggregands =
+ start_secure_aggregation_response.mutable_secure_aggregands();
+ SecureAggregandExecutionInfo secure_aggregand_execution_info;
+ secure_aggregand_execution_info.set_modulus(9999);
+ (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
+
+ // Mock a failed ReportTaskResult request.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ ReportTaskResultRequestMatcher(
+ EqualsProto(GetExpectedReportTaskResultRequest(
+ kAggregationSessionId, kTaskName,
+ google::rpc::Code::OK, plan_duration))))))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartSecureAggregationRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName, start_secure_aggregation_response)
+ .SerializeAsString())));
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ results.emplace("secagg_tensor", QuantizedTensor());
+
+ MockSecAggRunner* mock_secagg_runner = new StrictMock<MockSecAggRunner>();
+ EXPECT_CALL(*mock_secagg_runner_factory_,
+ CreateSecAggRunner(_, _, _, _, _, 500, 450))
+ .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner))));
+ EXPECT_CALL(*mock_secagg_runner,
+ Run(UnorderedElementsAre(
+ Pair("secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
+ IsEmpty(), 0, IsEmpty()))))))
+ .WillOnce(Return(absl::OkStatus()));
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedStartSecAggFailed) {
+ absl::Duration plan_duration = absl::Minutes(5);
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartSecureAggregationRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
+ "Request failed.")
+ .SerializeAsString())));
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ results.emplace("secagg_tensor", QuantizedTensor());
+
+ EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedStartSecAggFailedImmediately) {
+ absl::Duration plan_duration = absl::Minutes(5);
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/secureaggregations/"
+ "AGGREGATION_SESSION_ID/clients/"
+ "AUTHORIZATION_TOKEN:start?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartSecureAggregationRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(403, HeaderList(), "")));
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ results.emplace("secagg_tensor", QuantizedTensor());
+
+ EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt),
+ IsCode(absl::StatusCode::kPermissionDenied));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedReportTaskResultFailed) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ // Create a fake checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ // Mock a failed ReportTaskResult request.
+ ReportTaskResultResponse report_task_result_response;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ ReportTaskResultRequestMatcher(
+ EqualsProto(GetExpectedReportTaskResultRequest(
+ kAggregationSessionId, kTaskName,
+ google::rpc::Code::OK, plan_duration))))))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));
+
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+ // Despite the ReportTaskResult request failed, we still consider the overall
+ // ReportCompleted succeeded because the rest of the steps succeeds, and the
+ // ReportTaskResult is a just a metric reporting on a best effort basis.
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedStartAggregationFailedImmediately) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartAggregationDataUploadRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnavailable));
+ EXPECT_THAT(report_result.message(),
+ HasSubstr("StartAggregationDataUpload request failed"));
+ EXPECT_THAT(report_result.message(), HasSubstr("503"));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedStartAggregationFailedDuringPolling) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ StartAggregationDataUploadRequest().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _,
+ GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
+ .WillOnce(Return(FakeHttpResponse(401, HeaderList())));
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnauthenticated));
+ EXPECT_THAT(report_result.message(),
+ HasSubstr("StartAggregationDataUpload request failed"));
+ EXPECT_THAT(report_result.message(), HasSubstr("401"));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadFailed) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ StrEq("https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw"),
+ HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
+ .WillOnce(Return(FakeHttpResponse(501, HeaderList())));
+ ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnimplemented));
+ EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
+ EXPECT_THAT(report_result.message(), HasSubstr("501"));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadAbortedByServer) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ StrEq("https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw"),
+ HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
+ .WillOnce(Return(FakeHttpResponse(
+ 409, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kAborted,
+ "The client update is no longer needed.")
+ .SerializeAsString())));
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kAborted));
+ EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
+ EXPECT_THAT(report_result.message(), HasSubstr("409"));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadInterrupted) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ absl::Notification request_issued;
+ absl::Notification request_cancelled;
+
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ StrEq("https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw"),
+ HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
+ .WillOnce([&request_issued, &request_cancelled](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ request_issued.Notify();
+ request_cancelled.WaitForNotification();
+ return FakeHttpResponse(503, HeaderList(), "");
+ });
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it
+ // was given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests()() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
+ // request complete.
+ mock_http_client_.SetCancellationListener([&request_cancelled]() {
+ if (!request_cancelled.HasBeenNotified()) {
+ request_cancelled.Notify();
+ }
+ });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+ ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kCancelled));
+ EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedSubmitAggregationResultFailed) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+
+ std::string checkpoint_str;
+ const size_t kTFCheckpointSize = 32;
+ checkpoint_str.resize(kTFCheckpointSize, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+
+ SubmitAggregationResultRequest submit_aggregation_result_request;
+ submit_aggregation_result_request.set_resource_name(kResourceName);
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ submit_aggregation_result_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(409, HeaderList())));
+ absl::Status report_result = federated_protocol_->ReportCompleted(
+ std::move(results), plan_duration, std::nullopt);
+
+ ASSERT_THAT(report_result, IsCode(absl::StatusCode::kAborted));
+ EXPECT_THAT(report_result.message(),
+ HasSubstr("SubmitAggregationResult failed"));
+ EXPECT_THAT(report_result.message(), HasSubstr("409"));
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedSuccess) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ absl::Duration plan_duration = absl::Minutes(5);
+ ReportTaskResultResponse response;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ HttpRequest::Method::kPost, _,
+ ReportTaskResultRequestMatcher(
+ EqualsProto(GetExpectedReportTaskResultRequest(
+ kAggregationSessionId, kTaskName,
+ ::google::rpc::Code::INTERNAL, plan_duration))))))
+ .WillOnce(Return(
+ FakeHttpResponse(200, HeaderList(), response.SerializeAsString())));
+
+ ASSERT_OK(federated_protocol_->ReportNotCompleted(
+ engine::PhaseOutcome::ERROR, plan_duration, std::nullopt));
+
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedError) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ ReportTaskResultResponse response;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
+
+ absl::Status status = federated_protocol_->ReportNotCompleted(
+ engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
+ EXPECT_THAT(status, IsCode(UNAVAILABLE));
+ EXPECT_THAT(
+ status.message(),
+ AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("503")));
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedPermanentError) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ ReportTaskResultResponse response;
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ HttpRequest::Method::kPost, _, _)))
+ .WillOnce(Return(FakeHttpResponse(404, HeaderList())));
+
+ absl::Status status = federated_protocol_->ReportNotCompleted(
+ engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
+ EXPECT_THAT(status, IsCode(NOT_FOUND));
+ EXPECT_THAT(
+ status.message(),
+ AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("404")));
+ ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestClientDecodedResourcesEnabledDeclaresSupport) {
+ EligibilityEvalTaskRequest expected_eligibility_request;
+ expected_eligibility_request.mutable_client_version()->set_version_code(
+ kClientVersion);
+ expected_eligibility_request.mutable_attestation_measurement()->set_value(
+ kAttestationMeasurement);
+ // Make sure gzip support is declared in the eligibility eval checkin request.
+ expected_eligibility_request.mutable_resource_capabilities()
+ ->add_supported_compression_formats(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ expected_eligibility_request.mutable_eligibility_eval_task_capabilities()
+ ->set_supports_multiple_task_assignment(false);
+
+ // Issue an eligibility eval checkin so we can validate the field is set.
+ Resource eligibility_plan_resource;
+ eligibility_plan_resource.mutable_inline_resource()->set_data(kPlan);
+ Resource checkpoint_resource;
+ checkpoint_resource.mutable_inline_resource()->set_data(kInitCheckpoint);
+
+ EligibilityEvalTaskResponse eval_task_response =
+ GetFakeEnabledEligibilityEvalTaskResponse(eligibility_plan_resource,
+ checkpoint_resource,
+ kEligibilityEvalExecutionId);
+ const std::string eligibility_request_uri =
+ "https://initial.uri/v1/eligibilityevaltasks/"
+ "TEST%2FPOPULATION:request?%24alt=proto";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ eligibility_request_uri, HttpRequest::Method::kPost, _,
+ EligibilityEvalTaskRequestMatcher(
+ EqualsProto(expected_eligibility_request)))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), eval_task_response.SerializeAsString())));
+
+ ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
+ mock_eet_received_callback_.AsStdFunction()));
+
+ // Now issue a regular checkin and make sure the field is set there too.
+ const std::string plan_uri = "https://fake.uri/plan";
+ Resource plan_resource;
+ plan_resource.set_uri(plan_uri);
+ StartTaskAssignmentResponse task_assignment_response =
+ GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
+ kFederatedSelectUriTemplate,
+ kAggregationSessionId, 0);
+ const std::string request_uri =
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
+ TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
+ StartTaskAssignmentRequest expected_request;
+ expected_request.mutable_client_version()->set_version_code(kClientVersion);
+ *expected_request.mutable_task_eligibility_info() = expected_eligibility_info;
+ // Make sure gzip support is declared in the regular checkin request.
+ expected_request.mutable_resource_capabilities()
+ ->add_supported_compression_formats(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ request_uri, HttpRequest::Method::kPost, _,
+ StartTaskAssignmentRequestMatcher(EqualsProto(expected_request)))))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateDoneOperation(kOperationName, task_assignment_response)
+ .SerializeAsString())));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ plan_uri, HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), kPlan)));
+
+ std::string report_eet_request_uri =
+ "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
+ "eligibilityevaltasks/"
+ "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
+ ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
+ absl::OkStatus());
+
+ ASSERT_OK(federated_protocol_->Checkin(
+ expected_eligibility_info, mock_task_received_callback_.AsStdFunction()));
+}
+
+} // anonymous namespace
+} // namespace fcp::client::http
diff --git a/fcp/client/http/http_resource_metadata.proto b/fcp/client/http/http_resource_metadata.proto
new file mode 100644
index 0000000..cc4f11a
--- /dev/null
+++ b/fcp/client/http/http_resource_metadata.proto
@@ -0,0 +1,29 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client.http;
+
+message HttpResourceMetadata {
+ // The format in which the stored resource is compressed.
+ ResourceCompressionFormat compression_format = 1;
+}
+
+// Different file formats that may be used to compress resources.
+enum ResourceCompressionFormat {
+ RESOURCE_COMPRESSION_FORMAT_UNSPECIFIED = 0;
+ // Gzip-compressed data.
+ RESOURCE_COMPRESSION_FORMAT_GZIP = 1;
+}
diff --git a/fcp/client/http/http_secagg_send_to_server_impl.cc b/fcp/client/http/http_secagg_send_to_server_impl.cc
new file mode 100644
index 0000000..7afeeb2
--- /dev/null
+++ b/fcp/client/http/http_secagg_send_to_server_impl.cc
@@ -0,0 +1,452 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_secagg_send_to_server_impl.h"
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+// #include "google/rpc/code.pb.h"
+#include "absl/strings/substitute.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
+using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
+using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
+using ::google::internal::federatedcompute::v1::ByteStreamResource;
+using ::google::internal::federatedcompute::v1::ForwardingInfo;
+using ::google::internal::federatedcompute::v1::ShareKeysRequest;
+using ::google::internal::federatedcompute::v1::ShareKeysResponse;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultRequest;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultResponse;
+using ::google::internal::federatedcompute::v1::UnmaskRequest;
+// using ::google::longrunning::Operation;
+
+namespace {
+absl::StatusOr<std::string> CreateAbortSecureAggregationUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:abort";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateAdvertiseKeysUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:advertisekeys";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateShareKeysUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:sharekeys";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateSubmitSecureAggregationResultUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:submit";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+absl::StatusOr<std::string> CreateUnmaskUriSuffix(
+ absl::string_view aggregation_id, absl::string_view client_token) {
+ constexpr absl::string_view pattern =
+ "/v1/secureaggregations/$0/clients/$1:unmask";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
+ EncodeUriSinglePathSegment(aggregation_id));
+ FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
+ EncodeUriSinglePathSegment(client_token));
+ // Construct the URI suffix.
+ return absl::Substitute(pattern, encoded_aggregation_id,
+ encoded_client_token);
+}
+
+} // anonymous namespace
+
+absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>>
+HttpSecAggSendToServerImpl::Create(
+ absl::string_view api_key, Clock* clock,
+ ProtocolRequestHelper* request_helper,
+ InterruptibleRunner* interruptible_runner,
+ std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
+ delayed_interruptible_runner_creator,
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>*
+ server_response_holder,
+ absl::string_view aggregation_id, absl::string_view client_token,
+ const ForwardingInfo& secagg_upload_forwarding_info,
+ const ByteStreamResource& masked_result_resource,
+ const ByteStreamResource& nonmasked_result_resource,
+ std::optional<std::string> tf_checkpoint,
+ bool disable_request_body_compression,
+ absl::Duration waiting_period_for_cancellation) {
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,
+ ProtocolRequestCreator::Create(api_key, secagg_upload_forwarding_info,
+ !disable_request_body_compression));
+ // We don't use request body compression for resource upload.
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<ProtocolRequestCreator>
+ masked_result_upload_request_creator,
+ ProtocolRequestCreator::Create(
+ api_key, masked_result_resource.data_upload_forwarding_info(),
+ /*use_compression=*/false));
+ // We don't use request body compression for resource upload.
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<ProtocolRequestCreator>
+ nonmasked_result_upload_request_creator,
+ ProtocolRequestCreator::Create(
+ api_key, nonmasked_result_resource.data_upload_forwarding_info(),
+ /*use_compression=*/false));
+
+ return absl::WrapUnique(new HttpSecAggSendToServerImpl(
+ api_key, clock, request_helper, interruptible_runner,
+ std::move(delayed_interruptible_runner_creator), server_response_holder,
+ aggregation_id, client_token, masked_result_resource.resource_name(),
+ nonmasked_result_resource.resource_name(),
+ std::move(secagg_request_creator),
+ std::move(masked_result_upload_request_creator),
+ std::move(nonmasked_result_upload_request_creator),
+ std::move(tf_checkpoint), waiting_period_for_cancellation));
+}
+
+// Despite the method name is "Send", this method is doing more. It sends the
+// request, waits for the response and set the response to the response holder
+// for the secagg client to access in the next round of secagg communications.
+//
+// The current SecAgg library is built around the assumption that the underlying
+// network protocol is fully asynchronous and bidirectional. This was true for
+// the gRPC protocol but isn't the case anymore for the HTTP protocol (which has
+// a more traditional request/response structure). Nevertheless, because we
+// still need to support the gRPC protocol the structure of the SecAgg library
+// cannot be changed yet, and this means that currently we need to store away
+// the result and let the secagg client to access on a later time. However, once
+// the gRPC protocol support is removed, we should consider updating the SecAgg
+// library to assume the more traditional request/response structure (e.g. by
+// having SecAggSendToServer::Send return the corresponding response message).
+//
+// TODO(team): Simplify SecAgg library around request/response structure
+// once gRPC support is removed.
+void HttpSecAggSendToServerImpl::Send(
+ secagg::ClientToServerWrapperMessage* message) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> server_message;
+ if (message->has_advertise_keys()) {
+ server_response_holder_ =
+ DoR0AdvertiseKeys(std::move(message->advertise_keys()));
+ } else if (message->has_share_keys_response()) {
+ server_response_holder_ =
+ DoR1ShareKeys(std::move(message->share_keys_response()));
+ } else if (message->has_masked_input_response()) {
+ server_response_holder_ = DoR2SubmitSecureAggregationResult(
+ std::move(message->masked_input_response()));
+ } else if (message->has_unmasking_response()) {
+ server_response_holder_ =
+ DoR3Unmask(std::move(message->unmasking_response()));
+ } else if (message->has_abort()) {
+ server_response_holder_ =
+ AbortSecureAggregation(std::move(message->abort()));
+ } else {
+ // When the protocol succeeds, the ClientToServerWrapperMessage will be
+ // empty, and we'll just set the empty server message.
+ server_response_holder_ = secagg::ServerToClientWrapperMessage();
+ }
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggSendToServerImpl::AbortSecureAggregation(
+ secagg::AbortMessage abort_message) {
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateAbortSecureAggregationUriSuffix(aggregation_id_, client_token_));
+
+ AbortSecureAggregationRequest request;
+ ::google::internal::federatedcompute::v1::Status* status =
+ request.mutable_status();
+ status->set_code(13);
+ status->set_message(abort_message.diagnostic_info());
+
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ secagg_request_creator_->CreateProtocolRequest(
+ uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+ std::unique_ptr<InterruptibleRunner> delayed_interruptible_runner =
+ delayed_interruptible_runner_creator_(clock_.Now() +
+ waiting_period_for_cancellation_);
+ FCP_ASSIGN_OR_RETURN(
+ InMemoryHttpResponse response,
+ request_helper_.PerformProtocolRequest(std::move(http_request),
+ *delayed_interruptible_runner));
+
+ secagg::ServerToClientWrapperMessage server_message;
+ server_message.mutable_abort();
+ return server_message;
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggSendToServerImpl::DoR0AdvertiseKeys(
+ secagg::AdvertiseKeys advertise_keys) {
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateAdvertiseKeysUriSuffix(aggregation_id_, client_token_));
+
+ AdvertiseKeysRequest request;
+ *request.mutable_advertise_keys() = advertise_keys;
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ secagg_request_creator_->CreateProtocolRequest(
+ uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+ FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
+ request_helper_.PerformProtocolRequest(
+ std::move(http_request), interruptible_runner_));
+ // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
+ // ParseOperationProtoFromHttpResponse(response));
+
+ // FCP_ASSIGN_OR_RETURN(
+ // Operation completed_operation,
+ // request_helper_.PollOperationResponseUntilDone(
+ // initial_operation, *secagg_request_creator_,
+ // interruptible_runner_));
+
+ // // The Operation has finished. Check if it resulted in an error, and if so
+ // // forward it after converting it to an absl::Status error.
+ // if (completed_operation.has_error()) {
+ // return ConvertRpcStatusToAbslStatus(completed_operation.error());
+ // }
+ AdvertiseKeysResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(response.body))) {
+ return absl::InternalError("could not parse AdvertiseKeysResponse proto");
+ }
+ secagg::ServerToClientWrapperMessage server_message;
+ *server_message.mutable_share_keys_request() =
+ response_proto.share_keys_server_request();
+ return server_message;
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggSendToServerImpl::DoR1ShareKeys(
+ secagg::ShareKeysResponse share_keys_response) {
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri_suffix,
+ CreateShareKeysUriSuffix(aggregation_id_, client_token_));
+
+ ShareKeysRequest request;
+ *request.mutable_share_keys_client_response() = share_keys_response;
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> http_request,
+ secagg_request_creator_->CreateProtocolRequest(
+ uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+
+ FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
+ request_helper_.PerformProtocolRequest(
+ std::move(http_request), interruptible_runner_));
+ // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
+ // ParseOperationProtoFromHttpResponse(response));
+
+ // FCP_ASSIGN_OR_RETURN(
+ // Operation completed_operation,
+ // request_helper_.PollOperationResponseUntilDone(
+ // initial_operation, *secagg_request_creator_,
+ // interruptible_runner_));
+
+ // // The Operation has finished. Check if it resulted in an error, and if so
+ // // forward it after converting it to an absl::Status error.
+ // if (completed_operation.has_error()) {
+ // return ConvertRpcStatusToAbslStatus(completed_operation.error());
+ // }
+ ShareKeysResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(response.body))) {
+ return absl::InternalError(
+ "could not parse StartSecureAggregationResponse proto");
+ }
+ secagg::ServerToClientWrapperMessage server_message;
+ *server_message.mutable_masked_input_request() =
+ response_proto.masked_input_collection_server_request();
+ return server_message;
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggSendToServerImpl::DoR2SubmitSecureAggregationResult(
+ secagg::MaskedInputCollectionResponse masked_input_response) {
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ FCP_ASSIGN_OR_RETURN(std::string masked_result_upload_uri_suffix,
+ CreateByteStreamUploadUriSuffix(masked_resource_name_));
+
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> masked_input_upload_request,
+ masked_result_upload_request_creator_->CreateProtocolRequest(
+ masked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
+ HttpRequest::Method::kPost,
+ std::move(masked_input_response).SerializeAsString(),
+ /*is_protobuf_encoded=*/false));
+ requests.push_back(std::move(masked_input_upload_request));
+ bool has_checkpoint = tf_checkpoint_.has_value();
+ if (has_checkpoint) {
+ FCP_ASSIGN_OR_RETURN(
+ std::string nonmasked_result_upload_uri_suffix,
+ CreateByteStreamUploadUriSuffix(nonmasked_resource_name_));
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> nonmasked_input_upload_request,
+ nonmasked_result_upload_request_creator_->CreateProtocolRequest(
+ nonmasked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
+ HttpRequest::Method::kPost, std::move(tf_checkpoint_).value(),
+ /*is_protobuf_encoded=*/false));
+ requests.push_back(std::move(nonmasked_input_upload_request));
+ }
+ FCP_ASSIGN_OR_RETURN(
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
+ request_helper_.PerformMultipleProtocolRequests(std::move(requests),
+ interruptible_runner_));
+ for (const auto& response : responses) {
+ if (!response.ok()) {
+ return response.status();
+ }
+ }
+ FCP_ASSIGN_OR_RETURN(std::string submit_result_uri_suffix,
+ CreateSubmitSecureAggregationResultUriSuffix(
+ aggregation_id_, client_token_));
+ SubmitSecureAggregationResultRequest request;
+ *request.mutable_masked_result_resource_name() = masked_resource_name_;
+ if (has_checkpoint) {
+ *request.mutable_nonmasked_result_resource_name() =
+ nonmasked_resource_name_;
+ }
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> submit_result_request,
+ secagg_request_creator_->CreateProtocolRequest(
+ submit_result_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+ FCP_ASSIGN_OR_RETURN(
+ InMemoryHttpResponse response,
+ request_helper_.PerformProtocolRequest(std::move(submit_result_request),
+ interruptible_runner_));
+ // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
+ // ParseOperationProtoFromHttpResponse(response));
+ // FCP_ASSIGN_OR_RETURN(
+ // Operation completed_operation,
+ // request_helper_.PollOperationResponseUntilDone(
+ // initial_operation, *secagg_request_creator_,
+ // interruptible_runner_));
+
+ // // The Operation has finished. Check if it resulted in an error, and if so
+ // // forward it after converting it to an absl::Status error.
+ // if (completed_operation.has_error()) {
+ // return ConvertRpcStatusToAbslStatus(completed_operation.error());
+ // }
+ SubmitSecureAggregationResultResponse response_proto;
+ if (!response_proto.ParseFromString(std::string(response.body))) {
+ return absl::InvalidArgumentError(
+ "could not parse SubmitSecureAggregationResultResponse proto");
+ }
+ secagg::ServerToClientWrapperMessage server_message;
+ *server_message.mutable_unmasking_request() =
+ response_proto.unmasking_server_request();
+ return server_message;
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggSendToServerImpl::DoR3Unmask(
+ secagg::UnmaskingResponse unmasking_response) {
+ FCP_ASSIGN_OR_RETURN(std::string unmask_uri_suffix,
+ CreateUnmaskUriSuffix(aggregation_id_, client_token_));
+ UnmaskRequest request;
+ *request.mutable_unmasking_client_response() = unmasking_response;
+ FCP_ASSIGN_OR_RETURN(
+ std::unique_ptr<HttpRequest> unmask_request,
+ secagg_request_creator_->CreateProtocolRequest(
+ unmask_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
+ request.SerializeAsString(),
+ /*is_protobuf_encoded=*/true));
+ FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse unmask_response,
+ request_helper_.PerformProtocolRequest(
+ std::move(unmask_request), interruptible_runner_));
+ return secagg::ServerToClientWrapperMessage();
+}
+
+// TODO(team): remove GetModulus method, merge it into SecAggRunner.
+absl::StatusOr<uint64_t> HttpSecAggProtocolDelegate::GetModulus(
+ const std::string& key) {
+ if (!secure_aggregands_.contains(key)) {
+ return absl::InternalError(
+ absl::StrCat("Execution not found for aggregand: ", key));
+ }
+ return secure_aggregands_[key].modulus();
+}
+
+absl::StatusOr<secagg::ServerToClientWrapperMessage>
+HttpSecAggProtocolDelegate::ReceiveServerMessage() {
+ return server_response_holder_;
+}
+
+void HttpSecAggProtocolDelegate::Abort() {
+ // Intentional to be blank because we don't have internal states to clear.
+}
+
+size_t HttpSecAggProtocolDelegate::last_received_message_size() {
+ if (server_response_holder_.ok()) {
+ return server_response_holder_->ByteSizeLong();
+ } else {
+ // If the last request failed, return zero.
+ return 0;
+ }
+}
+
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/http_secagg_send_to_server_impl.h b/fcp/client/http/http_secagg_send_to_server_impl.h
new file mode 100644
index 0000000..8fc61a8
--- /dev/null
+++ b/fcp/client/http/http_secagg_send_to_server_impl.h
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
+#define FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/http/protocol_request_helper.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// Implementation of SecAggSendToServerBase for HTTP federated protocol.
+class HttpSecAggSendToServerImpl : public SecAggSendToServerBase {
+ public:
+ // Create an instance of HttpSecAggSendToServerImpl.
+ // This method returns error status when failed to create
+ // ProtocolRequestCreator based on the input ForwardingInfo or
+ // ByteStreamResources.
+ static absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>> Create(
+ absl::string_view api_key, Clock* clock,
+ ProtocolRequestHelper* request_helper,
+ InterruptibleRunner* interruptible_runner,
+ std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
+ delayed_interruptible_runner_creator,
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>*
+ server_response_holder,
+ absl::string_view aggregation_id, absl::string_view client_token,
+ const google::internal::federatedcompute::v1::ForwardingInfo&
+ secagg_upload_forwarding_info,
+ const google::internal::federatedcompute::v1::ByteStreamResource&
+ masked_result_resource,
+ const google::internal::federatedcompute::v1::ByteStreamResource&
+ nonmasked_result_resource,
+ std::optional<std::string> tf_checkpoint,
+ bool disable_request_body_compression,
+ absl::Duration waiting_period_for_cancellation);
+ ~HttpSecAggSendToServerImpl() override = default;
+
+ // Sends a client to server request based on the
+ // secagg::ClientToServerWrapperMessage, waits for the response, and set it to
+ // the server response holder.
+ void Send(secagg::ClientToServerWrapperMessage* message) override;
+
+ private:
+ HttpSecAggSendToServerImpl(
+ absl::string_view api_key, Clock* clock,
+ ProtocolRequestHelper* request_helper,
+ InterruptibleRunner* interruptible_runner,
+ std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
+ delayed_interruptible_runner_creator,
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>*
+ server_response_holder,
+ absl::string_view aggregation_id, absl::string_view client_token,
+ absl::string_view masked_resource_name,
+ absl::string_view nonmasked_resource_name,
+ std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,
+ std::unique_ptr<ProtocolRequestCreator>
+ masked_result_upload_request_creator,
+ std::unique_ptr<ProtocolRequestCreator>
+ nonmasked_result_upload_request_creator,
+ std::optional<std::string> tf_checkpoint,
+ absl::Duration waiting_period_for_cancellation)
+ : api_key_(api_key),
+ clock_(*clock),
+ request_helper_(*request_helper),
+ interruptible_runner_(*interruptible_runner),
+ delayed_interruptible_runner_creator_(
+ delayed_interruptible_runner_creator),
+ server_response_holder_(*server_response_holder),
+ aggregation_id_(std::string(aggregation_id)),
+ client_token_(std::string(client_token)),
+ masked_resource_name_(std::string(masked_resource_name)),
+ nonmasked_resource_name_(std::string(nonmasked_resource_name)),
+ secagg_request_creator_(std::move(secagg_request_creator)),
+ masked_result_upload_request_creator_(
+ std::move(masked_result_upload_request_creator)),
+ nonmasked_result_upload_request_creator_(
+ std::move(nonmasked_result_upload_request_creator)),
+ tf_checkpoint_(std::move(tf_checkpoint)),
+ waiting_period_for_cancellation_(waiting_period_for_cancellation) {}
+
+ // Sends an AbortSecureAggregationRequest.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> AbortSecureAggregation(
+ secagg::AbortMessage abort_message);
+ // Sends an AdvertiseKeysRequest and waits for the AdvertiseKeysResponse,
+ // polling the corresponding LRO if needed.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR0AdvertiseKeys(
+ secagg::AdvertiseKeys advertise_keys);
+ // Sends an ShareKeysRequest and waits for the ShareKeysResponse, polling
+ // the corresponding LRO if needed.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR1ShareKeys(
+ secagg::ShareKeysResponse share_keys_response);
+ // Uploads masked resource and (optional) nonmasked resource. After successful
+ // upload, sends an SubmitSecureAggregationResultRequest and waits for the
+ // SubmitSecureAggregationResultResponse, polling the corresponding LRO if
+ // needed.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>
+ DoR2SubmitSecureAggregationResult(
+ secagg::MaskedInputCollectionResponse masked_input_response);
+ // Sends an UnmaskRequest and waits for the UnmaskResponse.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR3Unmask(
+ secagg::UnmaskingResponse unmasking_response);
+ const std::string api_key_;
+ Clock& clock_;
+ ProtocolRequestHelper& request_helper_;
+ InterruptibleRunner& interruptible_runner_;
+ std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
+ delayed_interruptible_runner_creator_;
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_;
+ std::string aggregation_id_;
+ std::string client_token_;
+ std::string masked_resource_name_;
+ std::string nonmasked_resource_name_;
+ std::unique_ptr<ProtocolRequestCreator> secagg_request_creator_;
+ std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator_;
+ std::unique_ptr<ProtocolRequestCreator>
+ nonmasked_result_upload_request_creator_;
+ std::optional<std::string> tf_checkpoint_;
+ absl::Duration waiting_period_for_cancellation_;
+};
+
+// Implementation of SecAggProtocolDelegate for the HTTP federated protocol.
+class HttpSecAggProtocolDelegate : public SecAggProtocolDelegate {
+ public:
+ HttpSecAggProtocolDelegate(
+ google::protobuf::Map<
+ std::string,
+ google::internal::federatedcompute::v1::SecureAggregandExecutionInfo>
+ secure_aggregands,
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>*
+ server_response_holder)
+ : secure_aggregands_(std::move(secure_aggregands)),
+ server_response_holder_(*server_response_holder) {}
+ // Retrieve the modulus for a given SecAgg vector.
+ absl::StatusOr<uint64_t> GetModulus(const std::string& key) override;
+ // Receive Server message.
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
+ override;
+ // Called when the SecAgg protocol is interrupted.
+ void Abort() override;
+ size_t last_received_message_size() override;
+
+ private:
+ google::protobuf::Map<
+ std::string,
+ google::internal::federatedcompute::v1::SecureAggregandExecutionInfo>
+ secure_aggregands_;
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_;
+};
+
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
diff --git a/fcp/client/http/http_secagg_send_to_server_impl_test.cc b/fcp/client/http/http_secagg_send_to_server_impl_test.cc
new file mode 100644
index 0000000..908fa01
--- /dev/null
+++ b/fcp/client/http/http_secagg_send_to_server_impl_test.cc
@@ -0,0 +1,756 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/http_secagg_send_to_server_impl.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "google/longrunning/operations.pb.h"
+#include "google/rpc/code.pb.h"
+#include "absl/time/time.h"
+#include "fcp/base/simulated_clock.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace {
+
+using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
+using ::google::internal::federatedcompute::v1::AbortSecureAggregationResponse;
+using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
+using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
+using ::google::internal::federatedcompute::v1::ByteStreamResource;
+using ::google::internal::federatedcompute::v1::ForwardingInfo;
+using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
+using ::google::internal::federatedcompute::v1::ShareKeysRequest;
+using ::google::internal::federatedcompute::v1::ShareKeysResponse;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultRequest;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultResponse;
+using ::google::internal::federatedcompute::v1::UnmaskRequest;
+using ::google::internal::federatedcompute::v1::UnmaskResponse;
+using ::google::longrunning::Operation;
+using ::testing::_;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::StrictMock;
+
+constexpr absl::string_view kAggregationId = "aggregation_id";
+constexpr absl::string_view kClientToken = "client_token";
+constexpr absl::string_view kSecureAggregationTargetUri =
+ "https://secureaggregation.uri/";
+constexpr absl::string_view kByteStreamTargetUri = "https://bytestream.uri/";
+constexpr absl::string_view kMaskedResourceName = "masked_resource";
+constexpr absl::string_view kNonmaskedResourceName = "nonmasked_resource";
+constexpr absl::string_view kOperationName = "my_operation";
+constexpr absl::string_view kApiKey = "API_KEY";
+constexpr absl::Duration kDelayedInterruptibleRunnerDeadline =
+ absl::Seconds(10);
+
+TEST(HttpSecAggProtocolDelegateTest, GetModulus) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> holder;
+ std::string tensor_key = "tensor_1";
+ google::protobuf::Map<std::string, SecureAggregandExecutionInfo> secure_aggregands;
+ SecureAggregandExecutionInfo secure_aggregand_execution_info;
+ secure_aggregand_execution_info.set_modulus(12345);
+ secure_aggregands[tensor_key] = secure_aggregand_execution_info;
+ HttpSecAggProtocolDelegate delegate(secure_aggregands, &holder);
+ auto modulus = delegate.GetModulus(tensor_key);
+ ASSERT_OK(modulus);
+ ASSERT_EQ(*modulus, 12345);
+}
+
+TEST(HttpSecAggProtocolDelegateTest, GetModulusKeyNotFound) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> holder;
+ google::protobuf::Map<std::string, SecureAggregandExecutionInfo> secure_aggregands;
+ SecureAggregandExecutionInfo secure_aggregand_execution_info;
+ secure_aggregand_execution_info.set_modulus(12345);
+ secure_aggregands["tensor_1"] = secure_aggregand_execution_info;
+ HttpSecAggProtocolDelegate delegate(secure_aggregands, &holder);
+ ASSERT_THAT(delegate.GetModulus("do_not_exist"),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageOkResponse) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> holder;
+ google::protobuf::Map<std::string, SecureAggregandExecutionInfo> secure_aggregands;
+ HttpSecAggProtocolDelegate delegate(secure_aggregands, &holder);
+ secagg::ServerToClientWrapperMessage server_response;
+ server_response.mutable_masked_input_request()->add_encrypted_key_shares(
+ "encrypted_key");
+ holder = server_response;
+
+ auto server_message = delegate.ReceiveServerMessage();
+ ASSERT_OK(server_message);
+ ASSERT_THAT(*server_message, EqualsProto(server_response));
+ ASSERT_EQ(delegate.last_received_message_size(),
+ server_response.ByteSizeLong());
+}
+
+TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageErrorResponse) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> holder;
+ google::protobuf::Map<std::string, SecureAggregandExecutionInfo> secure_aggregands;
+ HttpSecAggProtocolDelegate delegate(secure_aggregands, &holder);
+ holder = absl::InternalError("Something is broken.");
+
+ ASSERT_THAT(delegate.ReceiveServerMessage(),
+ IsCode(absl::StatusCode::kInternal));
+ ASSERT_EQ(delegate.last_received_message_size(), 0);
+}
+
+class HttpSecAggSendToServerImplTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ request_helper_ = std::make_unique<ProtocolRequestHelper>(
+ &http_client_, &bytes_downloaded_, &bytes_uploaded_,
+ network_stopwatch_.get(), Clock::RealClock());
+ runner_ = std::make_unique<InterruptibleRunner>(
+ &log_manager_, []() { return false; },
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
+ *secagg_upload_forwarding_info_.mutable_target_uri_prefix() =
+ kSecureAggregationTargetUri;
+ *masked_result_resource_.mutable_resource_name() = kMaskedResourceName;
+ ForwardingInfo masked_resource_forwarding_info;
+ *masked_resource_forwarding_info.mutable_target_uri_prefix() =
+ kByteStreamTargetUri;
+ *masked_result_resource_.mutable_data_upload_forwarding_info() =
+ masked_resource_forwarding_info;
+ *nonmasked_result_resource_.mutable_resource_name() =
+ kNonmaskedResourceName;
+ ForwardingInfo nonmasked_resource_forwarding_info;
+ *nonmasked_resource_forwarding_info.mutable_target_uri_prefix() =
+ kByteStreamTargetUri;
+ *nonmasked_result_resource_.mutable_data_upload_forwarding_info() =
+ nonmasked_resource_forwarding_info;
+ }
+
+ std::unique_ptr<InterruptibleRunner> CreateInterruptibleRunner() {
+ return std::make_unique<InterruptibleRunner>(
+ &log_manager_, [this]() { return interrupted_; },
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
+ }
+
+ std::unique_ptr<HttpSecAggSendToServerImpl> CreateSecAggSendToServer(
+ std::optional<std::string> tf_checkpoint) {
+ auto send_to_server = HttpSecAggSendToServerImpl::Create(
+ kApiKey, &clock_, request_helper_.get(), runner_.get(),
+ /* delayed_interruptible_runner_creator=*/
+ [this](absl::Time deadline) {
+ // Ensure that the HttpSecAggSendToServerImpl implementation correctly
+ // passes a deadline that matches the 'waiting period' value we
+ // provide below (with a 1s grace period to account for the delay of
+ // executing the actual test; unfortunately the underlying HTTP code
+ // currently still uses absl::Now() directly, so we're forced to deal
+ // with 'real' time...).
+ //
+ // We don't actually use the deadline value though, since that
+ // would only make testing more complicated.
+ EXPECT_GE(deadline, absl::Now() +
+ kDelayedInterruptibleRunnerDeadline -
+ absl::Seconds(1));
+ return CreateInterruptibleRunner();
+ },
+ &server_response_holder_, kAggregationId, kClientToken,
+ secagg_upload_forwarding_info_, masked_result_resource_,
+ nonmasked_result_resource_, tf_checkpoint,
+ /* disable_request_body_compression=*/true,
+ /* waiting_period_for_cancellation=*/
+ kDelayedInterruptibleRunnerDeadline);
+ FCP_CHECK(send_to_server.ok());
+ return std::move(*send_to_server);
+ }
+ bool interrupted_ = false;
+ // We set the simulated clock to "now", since a bunch of the HTTP-related FCP
+ // code currently still uses absl::Now() directly, rather than using a more
+ // testable "Clock" object. This ensures various timestamps we may encounter
+ // are more understandable.
+ SimulatedClock clock_ = SimulatedClock(absl::Now());
+ StrictMock<MockHttpClient> http_client_;
+ NiceMock<MockLogManager> log_manager_;
+ int64_t bytes_downloaded_ = 0;
+ int64_t bytes_uploaded_ = 0;
+ std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
+ WallClockStopwatch::Create();
+ std::unique_ptr<ProtocolRequestHelper> request_helper_;
+ std::unique_ptr<InterruptibleRunner> runner_;
+ absl::StatusOr<secagg::ServerToClientWrapperMessage> server_response_holder_;
+ ForwardingInfo secagg_upload_forwarding_info_;
+ ByteStreamResource masked_result_resource_;
+ ByteStreamResource nonmasked_result_resource_;
+};
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeys) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ auto pair_of_keys =
+ server_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
+ pair_of_keys->set_enc_pk("enc_pk");
+ pair_of_keys->set_noise_pk("noise_pk");
+ // Create expected request.
+ AdvertiseKeysRequest expected_request;
+ *expected_request.mutable_advertise_keys() = server_message.advertise_keys();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:advertisekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+ // Create expected response.
+ secagg::ShareKeysRequest share_keys_request;
+ *share_keys_request.add_pairs_of_public_keys()->mutable_noise_pk() =
+ "noise_pk";
+ AdvertiseKeysResponse advertise_keys_response;
+ *advertise_keys_response.mutable_share_keys_server_request() =
+ share_keys_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, advertise_keys_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), complete_operation.SerializeAsString())));
+ send_to_server->Send(&server_message);
+ ASSERT_OK(server_response_holder_);
+
+ secagg::ServerToClientWrapperMessage expected_message;
+ *expected_message.mutable_share_keys_request() = share_keys_request;
+ EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest,
+ TestSendR0AdvertiseKeysFailedImmediately) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ auto pair_of_keys =
+ server_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
+ pair_of_keys->set_enc_pk("enc_pk");
+ pair_of_keys->set_noise_pk("noise_pk");
+ // Create expected request.
+ AdvertiseKeysRequest expected_request;
+ *expected_request.mutable_advertise_keys() = server_message.advertise_keys();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:advertisekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeysFailed) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ auto pair_of_keys =
+ server_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
+ pair_of_keys->set_enc_pk("enc_pk");
+ pair_of_keys->set_noise_pk("noise_pk");
+ // Create expected request.
+ AdvertiseKeysRequest expected_request;
+ *expected_request.mutable_advertise_keys() = server_message.advertise_keys();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:advertisekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
+ "Something's wrong")
+ .SerializeAsString())));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeys) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_share_keys_response()
+ ->mutable_encrypted_key_shares()
+ ->Add("encrypted_key");
+ // Create expected request
+ ShareKeysRequest expected_request;
+ *expected_request.mutable_share_keys_client_response() =
+ server_message.share_keys_response();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:sharekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+ // Create expected response
+ secagg::MaskedInputCollectionRequest masked_input_collection_request;
+ masked_input_collection_request.add_encrypted_key_shares(
+ "encryoted_key_share");
+ ShareKeysResponse share_keys_response;
+ *share_keys_response.mutable_masked_input_collection_server_request() =
+ masked_input_collection_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, share_keys_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), complete_operation.SerializeAsString())));
+ send_to_server->Send(&server_message);
+ ASSERT_OK(server_response_holder_);
+
+ secagg::ServerToClientWrapperMessage expected_message;
+ *expected_message.mutable_masked_input_request() =
+ masked_input_collection_request;
+ EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailedImmediatedly) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_share_keys_response()
+ ->mutable_encrypted_key_shares()
+ ->Add("encrypted_key");
+ // Create expected request
+ ShareKeysRequest expected_request;
+ *expected_request.mutable_share_keys_client_response() =
+ server_message.share_keys_response();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:sharekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailed) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_share_keys_response()
+ ->mutable_encrypted_key_shares()
+ ->Add("encrypted_key");
+ // Create expected request
+ ShareKeysRequest expected_request;
+ *expected_request.mutable_share_keys_client_response() =
+ server_message.share_keys_response();
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:sharekeys?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
+ "Something's wrong")
+ .SerializeAsString())));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultNoCheckpoint) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ secagg::MaskedInputVector masked_vector;
+ *masked_vector.mutable_encoded_vector() = "encoded_vector";
+ auto vector_map =
+ server_message.mutable_masked_input_response()->mutable_vectors();
+ (*vector_map)["vector_1"] = masked_vector;
+ // Create expected request
+ SubmitSecureAggregationResultRequest expected_request;
+ *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
+
+ // Create expected responses
+ EXPECT_CALL(http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "masked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _,
+ server_message.masked_input_response().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:submit?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+ secagg::UnmaskingRequest unmasking_request;
+ unmasking_request.add_dead_3_client_ids(12345);
+ SubmitSecureAggregationResultResponse submit_secagg_result_response;
+ *submit_secagg_result_response.mutable_unmasking_server_request() =
+ unmasking_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, submit_secagg_result_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), complete_operation.SerializeAsString())));
+ send_to_server->Send(&server_message);
+ ASSERT_OK(server_response_holder_);
+
+ secagg::ServerToClientWrapperMessage expected_message;
+ *expected_message.mutable_unmasking_request() = unmasking_request;
+ EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultWithCheckpoint) {
+ std::string tf_checkpoint = "trained.ckpt";
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
+ secagg::ClientToServerWrapperMessage server_message;
+ secagg::MaskedInputVector masked_vector;
+ *masked_vector.mutable_encoded_vector() = "encoded_vector";
+ auto vector_map =
+ server_message.mutable_masked_input_response()->mutable_vectors();
+ (*vector_map)["vector_1"] = masked_vector;
+ // Create expected request
+ SubmitSecureAggregationResultRequest expected_request;
+ *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
+ *expected_request.mutable_nonmasked_result_resource_name() =
+ kNonmaskedResourceName;
+
+ // Create expected responses
+ EXPECT_CALL(http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "masked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _,
+ server_message.masked_input_response().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+ EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "nonmasked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _, tf_checkpoint)))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ secagg::UnmaskingRequest unmasking_request;
+ unmasking_request.add_dead_3_client_ids(12345);
+ SubmitSecureAggregationResultResponse submit_secagg_result_response;
+ *submit_secagg_result_response.mutable_unmasking_server_request() =
+ unmasking_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, submit_secagg_result_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:submit?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), complete_operation.SerializeAsString())));
+ send_to_server->Send(&server_message);
+ ASSERT_OK(server_response_holder_);
+
+ secagg::ServerToClientWrapperMessage expected_message;
+ *expected_message.mutable_unmasking_request() = unmasking_request;
+ EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest,
+ TestSendR2SubmitResultWithCheckpointUploadFailed) {
+ std::string tf_checkpoint = "trained.ckpt";
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
+ secagg::ClientToServerWrapperMessage server_message;
+ secagg::MaskedInputVector masked_vector;
+ *masked_vector.mutable_encoded_vector() = "encoded_vector";
+ auto vector_map =
+ server_message.mutable_masked_input_response()->mutable_vectors();
+ (*vector_map)["vector_1"] = masked_vector;
+ // Create expected request
+ SubmitSecureAggregationResultRequest expected_request;
+ *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
+ *expected_request.mutable_nonmasked_result_resource_name() =
+ kNonmaskedResourceName;
+
+ // Fail one of the upload
+ EXPECT_CALL(http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "masked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _,
+ server_message.masked_input_response().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+ EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "nonmasked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _, tf_checkpoint)))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest,
+ TestSendR2SubmitResultWithCheckpointSubmitResultFailedImmediately) {
+ std::string tf_checkpoint = "trained.ckpt";
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
+ secagg::ClientToServerWrapperMessage server_message;
+ secagg::MaskedInputVector masked_vector;
+ masked_vector.set_encoded_vector("encoded_vector");
+ auto vector_map =
+ server_message.mutable_masked_input_response()->mutable_vectors();
+ (*vector_map)["vector_1"] = masked_vector;
+ // Create expected request
+ SubmitSecureAggregationResultRequest expected_request;
+ *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
+ *expected_request.mutable_nonmasked_result_resource_name() =
+ kNonmaskedResourceName;
+
+ // Create expected responses
+ EXPECT_CALL(http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "masked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _,
+ server_message.masked_input_response().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+ EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "nonmasked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _, tf_checkpoint)))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ secagg::UnmaskingRequest unmasking_request;
+ unmasking_request.add_dead_3_client_ids(12345);
+ SubmitSecureAggregationResultResponse submit_secagg_result_response;
+ *submit_secagg_result_response.mutable_unmasking_server_request() =
+ unmasking_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, submit_secagg_result_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:submit?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest,
+ TestSendR2SubmitResultWithCheckpointSubmitResultFailed) {
+ std::string tf_checkpoint = "trained.ckpt";
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
+ secagg::ClientToServerWrapperMessage server_message;
+ secagg::MaskedInputVector masked_vector;
+ masked_vector.set_encoded_vector("encoded_vector");
+ auto vector_map =
+ server_message.mutable_masked_input_response()->mutable_vectors();
+ (*vector_map)["vector_1"] = masked_vector;
+ // Create expected request
+ SubmitSecureAggregationResultRequest expected_request;
+ *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
+ *expected_request.mutable_nonmasked_result_resource_name() =
+ kNonmaskedResourceName;
+
+ // Create expected responses
+ EXPECT_CALL(http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "masked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _,
+ server_message.masked_input_response().SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+ EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://bytestream.uri/upload/v1/media/"
+ "nonmasked_resource?upload_protocol=raw",
+ HttpRequest::Method::kPost, _, tf_checkpoint)))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
+
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:submit?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar").SerializeAsString())));
+
+ secagg::UnmaskingRequest unmasking_request;
+ unmasking_request.add_dead_3_client_ids(12345);
+ SubmitSecureAggregationResultResponse submit_secagg_result_response;
+ *submit_secagg_result_response.mutable_unmasking_server_request() =
+ unmasking_request;
+ Operation complete_operation =
+ CreateDoneOperation(kOperationName, submit_secagg_result_response);
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
+ "Something's wroing.")
+ .SerializeAsString())));
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendR3Unmask) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares()
+ ->set_noise_sk_share("noise_sk_share");
+ // Create expected request
+ UnmaskRequest expected_request;
+ *expected_request.mutable_unmasking_client_response() =
+ server_message.unmasking_response();
+
+ // Create expected response
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:unmask?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(),
+ UnmaskResponse().SerializeAsString())));
+ send_to_server->Send(&server_message);
+ auto response = server_response_holder_;
+ ASSERT_OK(response);
+ EXPECT_THAT(*response, EqualsProto(secagg::ServerToClientWrapperMessage()));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest, TestSendAbortWithoutInterruption) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ std::string diagnostic_info = "Some computation failed.";
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_abort()->set_diagnostic_info(diagnostic_info);
+ // Create expected request
+ AbortSecureAggregationRequest expected_request;
+ google::rpc::Status status;
+ status.set_message(diagnostic_info);
+ status.set_code(google::rpc::INTERNAL);
+ *expected_request.mutable_status() = status;
+
+ // We expect the abort request to actually be issued, because interrupted_ is
+ // set to false, and hence the "InterruptibleRunner" we provided at the top of
+ // the test should let the request go through.
+ EXPECT_CALL(
+ http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
+ "clients/client_token:abort?%24alt=proto",
+ HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ AbortSecureAggregationResponse().SerializeAsString())));
+
+ // Send the request, and verify that sending it succeeded.
+ send_to_server->Send(&server_message);
+ ASSERT_OK(server_response_holder_);
+ secagg::ServerToClientWrapperMessage expected_response;
+ expected_response.mutable_abort();
+ EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
+}
+
+TEST_F(HttpSecAggSendToServerImplTest,
+ TestSendAbortShouldBeCancelledIfAlreadyInterruptedForTooLong) {
+ std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
+ CreateSecAggSendToServer(std::optional<std::string>());
+ std::string diagnostic_info = "Some computation failed.";
+ secagg::ClientToServerWrapperMessage server_message;
+ server_message.mutable_abort()->set_diagnostic_info(diagnostic_info);
+ // Create expected request
+ AbortSecureAggregationRequest expected_request;
+ google::rpc::Status status;
+ status.set_message(diagnostic_info);
+ status.set_code(google::rpc::INTERNAL);
+ *expected_request.mutable_status() = status;
+
+ // We do *not* expect any HTTP request to actually be issued, since the
+ // interrupted_ flag is true, and therefore the request should be cancelled
+ // before it is even issued.
+ interrupted_ = true;
+
+ // Send the request, and verify that sending it failed.
+ send_to_server->Send(&server_message);
+ EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kCancelled));
+}
+
+} // anonymous namespace
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/in_memory_request_response.cc b/fcp/client/http/in_memory_request_response.cc
new file mode 100644
index 0000000..bbee11a
--- /dev/null
+++ b/fcp/client/http/in_memory_request_response.cc
@@ -0,0 +1,607 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/in_memory_request_response.h"
+
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/http_resource_metadata.pb.h"
+#include "fcp/client/interruptible_runner.h"
+#include "google/protobuf/io/gzip_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace {
+
+// Returns the resource from the cache, or NOT_FOUND if it was not in the cache.
+// If the resource was compressed, it will be decompressed.
+absl::StatusOr<absl::Cord> TryGetResourceFromCache(
+ absl::string_view client_cache_id,
+ const std::optional<absl::Duration>& max_age,
+ cache::ResourceCache& resource_cache) {
+ FCP_ASSIGN_OR_RETURN(
+ cache::ResourceCache::ResourceAndMetadata cached_resource_and_metadata,
+ resource_cache.Get(client_cache_id, max_age));
+ HttpResourceMetadata metadata;
+ if (!cached_resource_and_metadata.metadata.UnpackTo(&metadata)) {
+ return absl::InternalError("Failed to unpack metadata!");
+ }
+ absl::Cord cached_resource = cached_resource_and_metadata.resource;
+ if (metadata.compression_format() ==
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP) {
+ FCP_ASSIGN_OR_RETURN(cached_resource, internal::UncompressWithGzip(
+ std::string(cached_resource)));
+ }
+ return cached_resource;
+}
+
+absl::Status TryPutResourceInCache(absl::string_view client_cache_id,
+ const absl::Cord& response_body,
+ bool response_encoded_with_gzip,
+ absl::Duration max_age,
+ cache::ResourceCache& resource_cache) {
+ // We fetched a resource that has a client_cache_id and was not
+ // loaded from the cache, put it in the cache.
+ HttpResourceMetadata metadata;
+ if (response_encoded_with_gzip) {
+ metadata.set_compression_format(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ } else {
+ metadata.set_compression_format(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_UNSPECIFIED);
+ }
+ google::protobuf::Any metadata_wrapper;
+ metadata_wrapper.PackFrom(metadata);
+ return resource_cache.Put(client_cache_id, response_body, metadata_wrapper,
+ max_age);
+}
+
+} // namespace
+
+using ::google::protobuf::io::ArrayInputStream;
+using ::google::protobuf::io::GzipInputStream;
+using ::google::protobuf::io::GzipOutputStream;
+using ::google::protobuf::io::StringOutputStream;
+
+using CompressionFormat =
+ ::fcp::client::http::UriOrInlineData::InlineData::CompressionFormat;
+
+static constexpr char kOctetStream[] = "application/octet-stream";
+constexpr absl::string_view kClientDecodedGzipSuffix = "+gzip";
+
+absl::StatusOr<std::unique_ptr<HttpRequest>> InMemoryHttpRequest::Create(
+ absl::string_view uri, Method method, HeaderList extra_headers,
+ std::string body, bool use_compression) {
+ // Allow http://localhost:xxxx as an exception to the https-only policy,
+ // so that we can use a local http test server.
+ if (!absl::StartsWithIgnoreCase(uri, kHttpsScheme) &&
+ !absl::StartsWithIgnoreCase(uri, kLocalhostUri)) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Non-HTTPS URIs are not supported: ", uri));
+ }
+ if (use_compression) {
+ FCP_ASSIGN_OR_RETURN(body, internal::CompressWithGzip(body));
+ extra_headers.push_back({kContentEncodingHdr, kGzipEncodingHdrValue});
+ }
+ std::optional<std::string> content_length_hdr =
+ FindHeader(extra_headers, kContentLengthHdr);
+ if (content_length_hdr.has_value()) {
+ return absl::InvalidArgumentError(
+ "Content-Length header should not be provided!");
+ }
+
+ if (!body.empty()) {
+ switch (method) {
+ case HttpRequest::Method::kPost:
+ case HttpRequest::Method::kPatch:
+ case HttpRequest::Method::kPut:
+ case HttpRequest::Method::kDelete:
+ break;
+ default:
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Request method does not allow request body: ", method));
+ }
+ // Add a Content-Length header, but only if there's a request body.
+ extra_headers.push_back({kContentLengthHdr, std::to_string(body.size())});
+ }
+
+ return absl::WrapUnique(new InMemoryHttpRequest(
+ uri, method, std::move(extra_headers), std::move(body)));
+}
+
+absl::StatusOr<int64_t> InMemoryHttpRequest::ReadBody(char* buffer,
+ int64_t requested) {
+ // This method is called from the HttpClient's thread (we don't really care
+ // which one). Hence, we use a mutex to ensure that subsequent calls to this
+ // method see the modifications to cursor_.
+ absl::WriterMutexLock _(&mutex_);
+
+ // Check whether there's any bytes left to read, and indicate the end has been
+ // reached if not.
+ int64_t bytes_left = body_.size() - cursor_;
+ if (bytes_left == 0) {
+ return absl::OutOfRangeError("End of stream reached");
+ }
+ FCP_CHECK(buffer != nullptr);
+ FCP_CHECK(requested > 0);
+ // Calculate how much data we can return, based on the size of `buffer`.
+ int64_t actual_read = bytes_left <= requested ? bytes_left : requested;
+ std::memcpy(buffer, body_.data() + cursor_, actual_read);
+ cursor_ += actual_read;
+ return actual_read;
+}
+
+absl::Status InMemoryHttpRequestCallback::OnResponseStarted(
+ const HttpRequest& request, const HttpResponse& response) {
+ absl::WriterMutexLock _(&mutex_);
+ response_code_ = response.code();
+
+ std::optional<std::string> content_encoding_header =
+ FindHeader(response.headers(), kContentEncodingHdr);
+ if (content_encoding_header.has_value()) {
+ // We don't expect the response body to be "Content-Encoding" encoded,
+ // because the `HttpClient` is supposed to transparently handle the decoding
+ // for us (unless we specified a "Accept-Encoding" header in the request,
+ // which would indicate that we wanted to handle the response decoding).
+ if (!FindHeader(request.extra_headers(), kAcceptEncodingHdr).has_value()) {
+ // Note: technically, we should only receive Content-Encoding values that
+ // match the Accept-Encoding values provided in the request headers. The
+ // check above isn't quite that strict, but that's probably fine (since
+ // such issues should be rare, and can be handled farther up the stack).
+ status_ = absl::InvalidArgumentError(
+ absl::StrCat("Unexpected header: ", kContentEncodingHdr));
+ return status_;
+ }
+ content_encoding_ = *content_encoding_header;
+ }
+
+ content_type_ = FindHeader(response.headers(), kContentTypeHdr).value_or("");
+
+ // Similarly, we should under no circumstances receive a non-identity
+ // Transfer-Encoding header, since the `HttpClient` is unconditionally
+ // required to undo any such encoding for us.
+ std::optional<std::string> transfer_encoding_header =
+ FindHeader(response.headers(), kTransferEncodingHdr);
+ if (transfer_encoding_header.has_value() &&
+ absl::AsciiStrToLower(*transfer_encoding_header) !=
+ kIdentityEncodingHdrValue) {
+ status_ = absl::InvalidArgumentError(
+ absl::StrCat("Unexpected header: ", kTransferEncodingHdr));
+ return status_;
+ }
+
+ // If no Content-Length header is provided, this means that the server either
+ // didn't provide one and is streaming the response, or that the HttpClient
+ // implementation transparently decompressed the data for us and stripped the
+ // Content-Length header (as per the HttpClient contract).
+ std::optional<std::string> content_length_hdr =
+ FindHeader(response.headers(), kContentLengthHdr);
+ if (!content_length_hdr.has_value()) {
+ return absl::OkStatus();
+ }
+
+ // A Content-Length header available. Let's parse it so that we know how much
+ // data to expect.
+ int64_t content_length;
+ // Note that SimpleAtoi safely handles non-ASCII data.
+ if (!absl::SimpleAtoi(*content_length_hdr, &content_length)) {
+ status_ = absl::InvalidArgumentError(
+ "Could not parse Content-Length response header");
+ return status_;
+ }
+ if (content_length < 0) {
+ status_ = absl::OutOfRangeError(absl::StrCat(
+ "Invalid Content-Length response header: ", content_length));
+ return status_;
+ }
+ expected_content_length_ = content_length;
+
+ return absl::OkStatus();
+}
+
+void InMemoryHttpRequestCallback::OnResponseError(const HttpRequest& request,
+ const absl::Status& error) {
+ absl::WriterMutexLock _(&mutex_);
+ status_ = absl::Status(
+ error.code(), absl::StrCat("Error receiving response headers (error: ",
+ error.message(), ")"));
+}
+
+absl::Status InMemoryHttpRequestCallback::OnResponseBody(
+ const HttpRequest& request, const HttpResponse& response,
+ absl::string_view data) {
+ // This runs on a thread chosen by the HttpClient implementation (i.e. it
+ // could be our original thread, or a different one). Ensure that if
+ // subsequent callbacks occur on different threads each thread sees the
+ // previous threads' updates to response_buffer_.
+ absl::WriterMutexLock _(&mutex_);
+
+ // Ensure we're not receiving more data than expected.
+ if (expected_content_length_.has_value() &&
+ response_buffer_.size() + data.size() > *expected_content_length_) {
+ status_ = absl::OutOfRangeError(absl::StrCat(
+ "Too much response body data received (rcvd: ", response_buffer_.size(),
+ ", new: ", data.size(), ", max: ", *expected_content_length_, ")"));
+ return status_;
+ }
+
+ // Copy the data into the target buffer. Note that this means we'll always
+ // store the response body as a number of memory fragments (rather than a
+ // contiguous buffer). However, because HttpClient implementations are
+ // encouraged to return response data in fairly large chunks, we don't expect
+ // this too cause much overhead.
+ response_buffer_.Append(data);
+
+ return absl::OkStatus();
+}
+
+void InMemoryHttpRequestCallback::OnResponseBodyError(
+ const HttpRequest& request, const HttpResponse& response,
+ const absl::Status& error) {
+ absl::WriterMutexLock _(&mutex_);
+ status_ = absl::Status(
+ error.code(),
+ absl::StrCat("Error receiving response body (response code: ",
+ response.code(), ", error: ", error.message(), ")"));
+}
+
+void InMemoryHttpRequestCallback::OnResponseCompleted(
+ const HttpRequest& request, const HttpResponse& response) {
+ // Once the body has been received correctly, turn the response code into a
+ // canonical code.
+ absl::WriterMutexLock _(&mutex_);
+ // Note: the case when too *much* response data is unexpectedly received is
+ // handled in OnResponseBody (while this handles the case of too little data).
+ if (expected_content_length_.has_value() &&
+ response_buffer_.size() != *expected_content_length_) {
+ status_ = absl::InvalidArgumentError(
+ absl::StrCat("Too little response body data received (rcvd: ",
+ response_buffer_.size(),
+ ", expected: ", *expected_content_length_, ")"));
+ return;
+ }
+
+ status_ = ConvertHttpCodeToStatus(*response_code_);
+}
+
+absl::StatusOr<InMemoryHttpResponse> InMemoryHttpRequestCallback::Response()
+ const {
+ absl::ReaderMutexLock _(&mutex_);
+ FCP_RETURN_IF_ERROR(status_);
+ // If status_ is OK, then response_code_ and response_headers_ are guaranteed
+ // to have values.
+
+ return InMemoryHttpResponse{*response_code_, content_encoding_, content_type_,
+ response_buffer_};
+}
+
+absl::StatusOr<InMemoryHttpResponse> PerformRequestInMemory(
+ HttpClient& http_client, InterruptibleRunner& interruptible_runner,
+ std::unique_ptr<http::HttpRequest> request, int64_t* bytes_received_acc,
+ int64_t* bytes_sent_acc) {
+ // Note: we must explicitly instantiate a vector here as opposed to passing an
+ // initializer list to PerformRequestsInMemory, because initializer lists do
+ // not support move-only values.
+ std::vector<std::unique_ptr<http::HttpRequest>> requests;
+ requests.push_back(std::move(request));
+ FCP_ASSIGN_OR_RETURN(
+ auto result, PerformMultipleRequestsInMemory(
+ http_client, interruptible_runner, std::move(requests),
+ bytes_received_acc, bytes_sent_acc));
+ return std::move(result[0]);
+}
+
+absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+PerformMultipleRequestsInMemory(
+ HttpClient& http_client, InterruptibleRunner& interruptible_runner,
+ std::vector<std::unique_ptr<http::HttpRequest>> requests,
+ int64_t* bytes_received_acc, int64_t* bytes_sent_acc) {
+ // A vector that will own the request handles and callbacks (and will
+ // determine their lifetimes).
+ std::vector<std::pair<std::unique_ptr<HttpRequestHandle>,
+ std::unique_ptr<InMemoryHttpRequestCallback>>>
+ handles_and_callbacks;
+ handles_and_callbacks.reserve(requests.size());
+
+ // An accompanying vector that contains just the raw pointers, for passing to
+ // `HttpClient::PerformRequests`.
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>>
+ handles_and_callbacks_ptrs;
+ handles_and_callbacks_ptrs.reserve(requests.size());
+
+ // Enqueue each request, and create a simple callback for each request which
+ // will simply buffer the response body in-memory and allow us to consume that
+ // buffer once all requests have finished.
+ for (std::unique_ptr<HttpRequest>& request : requests) {
+ std::unique_ptr<HttpRequestHandle> handle =
+ http_client.EnqueueRequest(std::move(request));
+ auto callback = std::make_unique<InMemoryHttpRequestCallback>();
+ handles_and_callbacks_ptrs.push_back({handle.get(), callback.get()});
+ handles_and_callbacks.push_back({std::move(handle), std::move(callback)});
+ }
+
+ // Issue the requests in one call (allowing the HttpClient to issue them
+ // concurrently), in an interruptible fashion.
+ absl::Status result = interruptible_runner.Run(
+ [&http_client, &handles_and_callbacks_ptrs]() {
+ return http_client.PerformRequests(handles_and_callbacks_ptrs);
+ },
+ [&handles_and_callbacks_ptrs] {
+ // If we get aborted then call HttpRequestHandle::Cancel on all handles.
+ // This should result in the PerformRequests call returning early and
+ // InterruptibleRunner::Run returning CANCELLED.
+ for (auto [handle, callback] : handles_and_callbacks_ptrs) {
+ handle->Cancel();
+ }
+ });
+ // Update the network stats *before* we return (just in case a failed
+ // `PerformRequests` call caused some network traffic to have been sent
+ // anyway).
+ for (auto& [handle, callback] : handles_and_callbacks) {
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes =
+ handle->TotalSentReceivedBytes();
+ if (bytes_received_acc != nullptr) {
+ *bytes_received_acc += sent_received_bytes.received_bytes;
+ }
+ if (bytes_sent_acc != nullptr) {
+ *bytes_sent_acc += sent_received_bytes.sent_bytes;
+ }
+ }
+
+ FCP_RETURN_IF_ERROR(result);
+
+ // Gather and return the results.
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> results;
+ results.reserve(handles_and_callbacks.size());
+ for (auto& [handle, callback] : handles_and_callbacks) {
+ results.push_back(callback->Response());
+ }
+ return results;
+}
+
+absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+FetchResourcesInMemory(HttpClient& http_client,
+ InterruptibleRunner& interruptible_runner,
+ const std::vector<UriOrInlineData>& resources,
+ int64_t* bytes_received_acc, int64_t* bytes_sent_acc,
+ cache::ResourceCache* resource_cache) {
+ // Each resource may have the data already available (by having been included
+ // in a prior response inline), or may need to be fetched.
+
+ // We'll create an 'accessor' for each resource, providing access to that
+ // resource's data by the end of this function (either the fetched data, or
+ // the inline data). Additionally, this struct will contain the
+ // client_cache_id and max_age for the resource if the resource should be put
+ // in the cache. If the resource should not be put in the cache,
+ // client_cache_id will be an empty string.
+ struct AccessorAndCacheMetadata {
+ std::function<absl::StatusOr<InMemoryHttpResponse>()> accessor;
+ std::string client_cache_id;
+ absl::Duration max_age;
+ };
+ std::vector<AccessorAndCacheMetadata> response_accessors;
+
+ // We'll compile HttpRequest instances for those resources that do need to be
+ // fetched, then we'll fire them off all at once, and then we'll gather their
+ // responses once all requests have finished.
+ std::vector<std::unique_ptr<http::HttpRequest>> http_requests;
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> http_responses;
+ bool caching_enabled = resource_cache != nullptr;
+
+ for (const UriOrInlineData& resource : resources) {
+ if (!resource.uri().uri.empty()) {
+ // If the resource has a cache_id, try getting it out of the cache. If any
+ // condition happens outside the happy path, fetch the resource normally.
+ if (caching_enabled && !resource.uri().client_cache_id.empty()) {
+ absl::StatusOr<absl::Cord> cached_resource =
+ TryGetResourceFromCache(resource.uri().client_cache_id,
+ resource.uri().max_age, *resource_cache);
+ if (cached_resource.ok()) {
+ // Resource was successfully fetched from the cache, so we do not set
+ // the client_cache_id or the max_age.
+ response_accessors.push_back({.accessor =
+ [cached_resource]() {
+ return InMemoryHttpResponse{
+ kHttpOk, "", kOctetStream,
+ *cached_resource};
+ },
+ .client_cache_id = "",
+ .max_age = absl::ZeroDuration()});
+ continue;
+ }
+ }
+ // If the resource URI is set, then create a request to fetch the data for
+ // it, and point the accessor at the slot in http_responses where that
+ // request's response will eventually live.
+ FCP_ASSIGN_OR_RETURN(std::unique_ptr<http::HttpRequest> request,
+ InMemoryHttpRequest::Create(
+ resource.uri().uri, HttpRequest::Method::kGet,
+ {}, "", /*use_compression=*/
+ false));
+ http_requests.push_back(std::move(request));
+ int64_t response_index = http_requests.end() - http_requests.begin() - 1;
+ auto response_accessing_fn = [&http_responses, response_index]() {
+ return std::move(http_responses.at(response_index));
+ };
+ if (caching_enabled) {
+ // We didn't load the resource from the cache, so set the
+ // client_cache_id and max_age in the response_accessor.
+ response_accessors.push_back(
+ {.accessor = response_accessing_fn,
+ .client_cache_id = std::string(resource.uri().client_cache_id),
+ .max_age = resource.uri().max_age});
+ } else {
+ response_accessors.push_back({.accessor = response_accessing_fn});
+ }
+ } else {
+ // The data is available inline. Make the accessor just return a "fake"
+ // successful HTTP response (that way the caller can have unified error
+ // handling logic and doesn't have to know whether a resource was truly
+ // fetched via HTTP or not). Because the inline_data field is an
+ // absl::Cord, making a copy of it should be very cheap.
+ response_accessors.push_back({.accessor = [resource]() {
+ std::string content_type(kOctetStream);
+ switch (resource.inline_data().compression_format) {
+ case UriOrInlineData::InlineData::CompressionFormat::kUncompressed:
+ break;
+ case UriOrInlineData::InlineData::CompressionFormat::kGzip:
+ absl::StrAppend(&content_type, kClientDecodedGzipSuffix);
+ break;
+ }
+ return InMemoryHttpResponse{kHttpOk, "", content_type,
+ resource.inline_data().data};
+ }});
+ }
+ }
+
+ // Perform the requests.
+ auto resource_fetch_result = PerformMultipleRequestsInMemory(
+ http_client, interruptible_runner, std::move(http_requests),
+ bytes_received_acc, bytes_sent_acc);
+ // Check whether issuing the requests failed as a whole (generally indicating
+ // a programming error).
+ FCP_RETURN_IF_ERROR(resource_fetch_result);
+ http_responses = std::move(*resource_fetch_result);
+
+ // Compile the result vector by getting each resource's response using the
+ // corresponding accessor.
+ // Note that the order of results returned corresponds to the order of
+ // resources in the vector we originally received.
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> result;
+ result.reserve(response_accessors.size());
+ for (const auto& response_accessor : response_accessors) {
+ absl::StatusOr<InMemoryHttpResponse> response =
+ response_accessor.accessor();
+ if (response.ok()) {
+ bool encoded_with_gzip = absl::EndsWithIgnoreCase(
+ response->content_type, kClientDecodedGzipSuffix);
+ if (!response_accessor.client_cache_id.empty()) {
+ TryPutResourceInCache(response_accessor.client_cache_id,
+ response->body, encoded_with_gzip,
+ response_accessor.max_age, *resource_cache)
+ .IgnoreError();
+ }
+ if (encoded_with_gzip) {
+ std::string response_body_temp(response->body);
+ // We're going to overwrite the response body with the decoded
+ // contents shortly, no need to keep an extra copy of it in memory.
+ response->body.Clear();
+ absl::StatusOr<absl::Cord> decoded_response_body =
+ internal::UncompressWithGzip(response_body_temp);
+ if (!decoded_response_body.ok()) {
+ response = decoded_response_body.status();
+ } else {
+ response->body = *std::move(decoded_response_body);
+ }
+ }
+ }
+ result.push_back(response);
+ }
+ return result;
+}
+
+namespace internal {
+absl::StatusOr<std::string> CompressWithGzip(
+ const std::string& uncompressed_data) {
+ int starting_pos = 0;
+ size_t str_size = uncompressed_data.length();
+ size_t in_size = str_size;
+ std::string output;
+ StringOutputStream string_output_stream(&output);
+ GzipOutputStream::Options options;
+ options.format = GzipOutputStream::GZIP;
+ GzipOutputStream compressed_stream(&string_output_stream, options);
+ void* out;
+ int out_size;
+ while (starting_pos < str_size) {
+ if (!compressed_stream.Next(&out, &out_size) || out_size <= 0) {
+ return absl::InternalError(
+ absl::StrCat("An error has occurred during compression: ",
+ compressed_stream.ZlibErrorMessage()));
+ }
+
+ if (in_size <= out_size) {
+ uncompressed_data.copy(static_cast<char*>(out), in_size, starting_pos);
+ // Ensure that the stream's output buffer is truncated to match the total
+ // amount of data.
+ compressed_stream.BackUp(out_size - static_cast<int>(in_size));
+ break;
+ }
+ uncompressed_data.copy(static_cast<char*>(out), out_size, starting_pos);
+ starting_pos += out_size;
+ in_size -= out_size;
+ }
+
+ if (!compressed_stream.Close()) {
+ return absl::InternalError(absl::StrCat(
+ "Failed to close the stream: ", compressed_stream.ZlibErrorMessage()));
+ }
+ return output;
+}
+
+absl::StatusOr<absl::Cord> UncompressWithGzip(
+ const std::string& compressed_data) {
+ absl::Cord out;
+ const void* buffer;
+ int size;
+ ArrayInputStream sub_stream(compressed_data.data(),
+ static_cast<int>(compressed_data.size()));
+ GzipInputStream input_stream(&sub_stream, GzipInputStream::GZIP);
+
+ while (input_stream.Next(&buffer, &size)) {
+ if (size <= -1) {
+ return absl::InternalError(
+ "Uncompress failed: invalid input size returned by the "
+ "GzipInputStream.");
+ }
+ out.Append(absl::string_view(reinterpret_cast<const char*>(buffer), size));
+ }
+
+ if (input_stream.ZlibErrorMessage() != nullptr) {
+ // Some real error happened during decompression.
+ return absl::InternalError(
+ absl::StrCat("An error has occurred during decompression:",
+ input_stream.ZlibErrorMessage()));
+ }
+
+ return out;
+}
+
+} // namespace internal
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/in_memory_request_response.h b/fcp/client/http/in_memory_request_response.h
new file mode 100644
index 0000000..56dc51c
--- /dev/null
+++ b/fcp/client/http/in_memory_request_response.h
@@ -0,0 +1,251 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_IN_MEMORY_REQUEST_RESPONSE_H_
+#define FCP_CLIENT_HTTP_IN_MEMORY_REQUEST_RESPONSE_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/client/cache/resource_cache.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/interruptible_runner.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// Simple `HttpRequest` implementation with an in-memory request body.
+class InMemoryHttpRequest : public HttpRequest {
+ public:
+ // Factory method for creating an instance.
+ //
+ // Callers are recommended to `std::move` the `body` parameter (or to rely on
+ // copy elision), to avoid unnecessary copies of the data.
+ //
+ // Note that a "Content-Length" header will be constructed automatically, and
+ // must not be provided by the caller.
+ //
+ // If "use_compression" is true, the body will be compressed with
+ // gzip. A "Content-Encoding" header will be added, and the "Content-Length"
+ // header will be the compressed length.
+ //
+ // Returns an INVALID_ARGUMENT error if:
+ // - the URI is a non-HTTPS URI,
+ // - the request has a body but the request method doesn't allow it,
+ // - the headers contain a "Content-Length" header.
+ static absl::StatusOr<std::unique_ptr<HttpRequest>> Create(
+ absl::string_view uri, Method method, HeaderList extra_headers,
+ std::string body, bool use_compression);
+
+ absl::string_view uri() const override { return uri_; };
+ Method method() const override { return method_; };
+ const HeaderList& extra_headers() const override { return headers_; }
+ bool HasBody() const override { return !body_.empty(); };
+
+ absl::StatusOr<int64_t> ReadBody(char* buffer, int64_t requested) override;
+
+ private:
+ InMemoryHttpRequest(absl::string_view uri, Method method,
+ HeaderList extra_headers, std::string body)
+ : uri_(uri),
+ method_(method),
+ body_(std::move(body)),
+ headers_(std::move(extra_headers)) {}
+
+ const std::string uri_;
+ const Method method_;
+ const std::string body_;
+ const HeaderList headers_;
+ int64_t cursor_ ABSL_GUARDED_BY(mutex_) = 0;
+ mutable absl::Mutex mutex_;
+};
+
+// Simple container class for holding an HTTP response code, headers, and
+// in-memory request body, as well as metadata for the client-side cache.
+struct InMemoryHttpResponse {
+ int code;
+ // This is empty if no "Content-Encoding" header was present in the response
+ // headers.
+ std::string content_encoding;
+ // This is empty if no "Content-Type" header was present in the response
+ // headers.
+ std::string content_type;
+ absl::Cord body;
+};
+
+// Simple `HttpRequestCallback` implementation that stores the response and its
+// body in an `InMemoryHttpResponse` object for later consumption.
+class InMemoryHttpRequestCallback : public HttpRequestCallback {
+ public:
+ InMemoryHttpRequestCallback() = default;
+
+ absl::Status OnResponseStarted(const HttpRequest& request,
+ const HttpResponse& response) override;
+ void OnResponseError(const HttpRequest& request,
+ const absl::Status& error) override;
+ absl::Status OnResponseBody(const HttpRequest& request,
+ const HttpResponse& response,
+ absl::string_view data) override;
+ void OnResponseBodyError(const HttpRequest& request,
+ const HttpResponse& response,
+ const absl::Status& error) override;
+ void OnResponseCompleted(const HttpRequest& request,
+ const HttpResponse& response) override;
+ absl::StatusOr<InMemoryHttpResponse> Response() const;
+
+ private:
+ absl::Status status_ ABSL_GUARDED_BY(mutex_) =
+ absl::UnavailableError("No response received");
+ std::optional<int> response_code_ ABSL_GUARDED_BY(mutex_);
+ std::string content_encoding_ ABSL_GUARDED_BY(mutex_);
+ std::string content_type_ ABSL_GUARDED_BY(mutex_);
+ std::optional<int64_t> expected_content_length_ ABSL_GUARDED_BY(mutex_);
+ absl::Cord response_buffer_ ABSL_GUARDED_BY(mutex_);
+ mutable absl::Mutex mutex_;
+ std::string client_cache_id_;
+};
+
+// Utility for performing a single HTTP request and returning the results (incl.
+// the response body) via an in-memory object, in an interruptible way.
+//
+// If `bytes_received_acc` and `bytes_sent_acc` are non-null then those
+// accumulators will also be incremented by the amount of data that was
+// received/sent by the request.
+//
+// Returns an error if the request failed.
+absl::StatusOr<InMemoryHttpResponse> PerformRequestInMemory(
+ HttpClient& http_client, InterruptibleRunner& interruptible_runner,
+ std::unique_ptr<http::HttpRequest> request, int64_t* bytes_received_acc,
+ int64_t* bytes_sent_acc);
+
+// Utility for performing multiple HTTP requests and returning the results
+// (incl. the response body) in memory.
+//
+// Returns an error if issuing the joint `PerformRequests` call failed.
+// Otherwise it returns a vector containing the result of each request
+// (in the same order the requests were provided in).
+absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+PerformMultipleRequestsInMemory(
+ HttpClient& http_client, InterruptibleRunner& interruptible_runner,
+ std::vector<std::unique_ptr<http::HttpRequest>> requests,
+ int64_t* bytes_received_acc, int64_t* bytes_sent_acc);
+
+// Simple class representing a resource for which data is already available
+// in-memory (`inline_data`) or for which data needs to be fetched by an HTTP
+// GET request (via `uri`). Only one field can ever be set to a non-empty value,
+// or both fields may be empty (indicating a zero-length resource for which
+// nothing has to be fetched).
+class UriOrInlineData {
+ public:
+ struct InlineData {
+ enum class CompressionFormat {
+ kUncompressed,
+ kGzip,
+ };
+
+ absl::Cord data;
+ CompressionFormat compression_format = CompressionFormat::kUncompressed;
+ };
+
+ struct Uri {
+ std::string uri;
+ std::string client_cache_id;
+ absl::Duration max_age;
+ };
+
+ // Creates an instance representing a URI from which data has to be fetched.
+ // If the resource represented by the uri should be cached, both
+ // `client_cache_id` and `max_age` must be set, otherwise they may be
+ // empty/zero.
+ static UriOrInlineData CreateUri(std::string uri, std::string client_cache_id,
+ absl::Duration max_age) {
+ return UriOrInlineData({.uri = std::move(uri),
+ .client_cache_id = std::move(client_cache_id),
+ .max_age = max_age},
+ {});
+ }
+ // Creates an instance representing a resource's already-available (or empty)
+ // data.
+ static UriOrInlineData CreateInlineData(
+ absl::Cord inline_data,
+ InlineData::CompressionFormat compression_format) {
+ return UriOrInlineData({}, {std::move(inline_data), compression_format});
+ }
+
+ const Uri& uri() const { return uri_; }
+ const InlineData& inline_data() const { return inline_data_; }
+
+ private:
+ UriOrInlineData(Uri uri, InlineData inline_data)
+ : uri_(std::move(uri)), inline_data_(std::move(inline_data)) {}
+
+ const Uri uri_;
+ const InlineData inline_data_;
+};
+
+// Utility for (potentially) fetching multiple resources at once, each of which
+// either needs to be fetched from a URI using a HTTP GET request, or for which
+// its data is already available, and returning the eventual results (incl. the
+// response body) via in-memory objects, in an interruptible way. If resources
+// do need to be fetched, then a single `HttpClient::PerformRequests` call will
+// be made for all resources at once.
+//
+// This makes it a convenient way for callers to gather the data for a related
+// set of resources (some of which might already have their data available) in
+// one go, and to then access the data for the first resource at index 0, the
+// second resource at index 1, etc., transparently handling the various
+// permutations that are possible (e.g. resource A having data inline but
+// B having to be fetched, or both being inlined, or ...) via a unified access
+// pattern and error handling mechanism.
+//
+// If `bytes_received_acc` and `bytes_sent_acc` are non-null then those
+// accumulators will also be incremented by the aggregate amount of data that
+// was received/sent by the HTTP requests that were issued.
+//
+// Returns an error if issuing the joint `HttpClient::PerformRequests` call
+// failed. Otherwise it returns a vector containing the result for each
+// resource (in the same order the resources were provided in).
+absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+FetchResourcesInMemory(HttpClient& http_client,
+ InterruptibleRunner& interruptible_runner,
+ const std::vector<UriOrInlineData>& resources,
+ int64_t* bytes_received_acc, int64_t* bytes_sent_acc,
+ cache::ResourceCache* resource_cache);
+
+// Used by the class and in tests only.
+namespace internal {
+absl::StatusOr<std::string> CompressWithGzip(
+ const std::string& uncompressed_data);
+absl::StatusOr<absl::Cord> UncompressWithGzip(
+ const std::string& compressed_data);
+} // namespace internal
+
+}; // namespace http
+}; // namespace client
+}; // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_IN_MEMORY_REQUEST_RESPONSE_H_
diff --git a/fcp/client/http/in_memory_request_response_test.cc b/fcp/client/http/in_memory_request_response_test.cc
new file mode 100644
index 0000000..d06b8b1
--- /dev/null
+++ b/fcp/client/http/in_memory_request_response_test.cc
@@ -0,0 +1,1576 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/in_memory_request_response.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/simulated_clock.h"
+#include "fcp/client/cache/file_backed_resource_cache.h"
+#include "fcp/client/cache/test_helpers.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/http_resource_metadata.pb.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+#include "google/protobuf/io/gzip_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace fcp::client::http {
+namespace {
+
+using ::fcp::IsCode;
+using ::fcp::client::http::FakeHttpResponse;
+using ::fcp::client::http::MockableHttpClient;
+using ::fcp::client::http::MockHttpClient;
+using ::testing::_;
+using ::testing::ContainerEq;
+using ::testing::Contains;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::FieldsAre;
+using ::testing::Ge;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::MockFunction;
+using ::testing::Ne;
+using ::testing::NiceMock;
+using ::testing::Pair;
+using ::testing::Return;
+using ::testing::StrEq;
+using ::testing::StrictMock;
+
+using CompressionFormat =
+ ::fcp::client::http::UriOrInlineData::InlineData::CompressionFormat;
+
+constexpr absl::string_view kOctetStream = "application/octet-stream";
+int64_t kMaxCacheSizeBytes = 10000000;
+
+google::protobuf::Any MetadataForUncompressedResource() {
+ HttpResourceMetadata metadata;
+ google::protobuf::Any any;
+ any.PackFrom(metadata);
+ return any;
+}
+
+google::protobuf::Any MetadataForCompressedResource() {
+ HttpResourceMetadata metadata;
+ metadata.set_compression_format(
+ ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
+ google::protobuf::Any any;
+ any.PackFrom(metadata);
+ return any;
+}
+
+TEST(InMemoryHttpRequestTest, NonHttpsUriFails) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("http://invalid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ EXPECT_THAT(request.status(), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(InMemoryHttpRequestTest, GetWithRequestBodyFails) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kGet, {},
+ "non_empty_request_body", /*use_compression=*/false);
+ EXPECT_THAT(request.status(), IsCode(INVALID_ARGUMENT));
+}
+
+// Ensures that providing a Content-Length header results in an error (since it
+// is automatically generated instead).
+TEST(InMemoryHttpRequestTest, ContentLengthHeaderFails) {
+ // Note that we purposely use a mixed-case "Content-length" header name, to
+ // ensure that the check is correctly case-insensitive.
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kPost,
+ {{"Content-length", "1234"}}, "non_empty_request_body",
+ /*use_compression=*/false);
+ EXPECT_THAT(request.status(), IsCode(INVALID_ARGUMENT));
+}
+
+TEST(InMemoryHttpRequestTest, ValidGetRequest) {
+ const std::string expected_uri = "https://valid.com";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(expected_uri, HttpRequest::Method::kGet, {},
+ "", /*use_compression=*/false);
+ ASSERT_OK(request);
+ EXPECT_THAT((*request)->uri(), StrEq(expected_uri));
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kGet);
+ // Because no request body is present, the Content-Length header shouldn't
+ // have been added.
+ EXPECT_THAT((*request)->extra_headers(), IsEmpty());
+ EXPECT_FALSE((*request)->HasBody());
+}
+
+TEST(InMemoryHttpRequestTest, ValidGetRequestWithHeaders) {
+ const HeaderList expected_headers{{"Foo", "Bar"}, {"Baz", "123"}};
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, expected_headers,
+ "", /*use_compression=*/false);
+ ASSERT_OK(request);
+ EXPECT_THAT((*request)->extra_headers(), ContainerEq(expected_headers));
+}
+
+TEST(InMemoryHttpRequestTest, ValidPostRequestWithoutBody) {
+ const std::string expected_uri = "https://valid.com";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(expected_uri, HttpRequest::Method::kPost, {},
+ "", /*use_compression=*/false);
+ ASSERT_OK(request);
+ EXPECT_THAT((*request)->uri(), StrEq(expected_uri));
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kPost);
+ // Because no request body is present, the Content-Length header shouldn't
+ // have been added.
+ EXPECT_THAT((*request)->extra_headers(), IsEmpty());
+ EXPECT_FALSE((*request)->HasBody());
+}
+
+TEST(InMemoryHttpRequestTest, ValidPostRequestWithBody) {
+ const std::string expected_uri = "https://valid.com";
+ const std::string expected_body = "request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(expected_uri, HttpRequest::Method::kPost, {},
+ expected_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+ EXPECT_THAT((*request)->uri(), StrEq(expected_uri));
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kPost);
+ EXPECT_THAT((*request)->extra_headers(),
+ ElementsAre(Pair("Content-Length",
+ std::to_string(expected_body.size()))));
+ EXPECT_TRUE((*request)->HasBody());
+}
+
+// Checks that the automatically generated Content-Length header is
+// appropriately added to any specified extra headers (rather than replacing
+// them completely).
+TEST(InMemoryHttpRequestTest, ValidPostRequestWithBodyAndHeaders) {
+ const HeaderList original_headers{{"Foo", "Bar"}, {"Baz", "123"}};
+ const std::string expected_body = "request_body";
+ const HeaderList expected_headers{
+ {"Foo", "Bar"},
+ {"Baz", "123"},
+ {"Content-Length", std::to_string(expected_body.size())}};
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost, original_headers,
+ expected_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ EXPECT_THAT((*request)->extra_headers(), ContainerEq(expected_headers));
+}
+
+TEST(InMemoryHttpRequestTest, ReadBodySimple) {
+ const std::string expected_body = "request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost, {}, expected_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ std::string actual_body;
+ actual_body.resize(expected_body.size());
+ // Read the body in one go (the "simple" case).
+ auto read_result =
+ (*request)->ReadBody(actual_body.data(), actual_body.size());
+ ASSERT_OK(read_result);
+ EXPECT_THAT(*read_result, actual_body.size());
+ EXPECT_THAT(actual_body, StrEq(expected_body));
+ // Expect the second read to indicate the end of the stream.
+ EXPECT_THAT((*request)->ReadBody(nullptr, 1), IsCode(OUT_OF_RANGE));
+}
+
+TEST(InMemoryHttpRequestTest, ReadBodyChunked) {
+ // This test reads the body in chunks of 3 bytes at a time (rather than all at
+ // once, like previous test). To test some edge cases, we ensure that the
+ // request body's length is not evenly dividable by 3.
+ const std::string expected_body = "12345678";
+ ASSERT_THAT(expected_body.size() % 3, Ne(0));
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost, {}, expected_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ std::string actual_body;
+ // Pre-size the buffer with 'X' characters. This will allow us to check
+ // whether each read modifies the expected range of the buffer.
+ actual_body.resize(expected_body.size(), 'X');
+
+ // Read the body 3 bytes at a time.
+ // The first read tests the case where less data is requested than is
+ // available.
+ absl::StatusOr<int64_t> read_result =
+ (*request)->ReadBody(actual_body.data(), 3);
+ ASSERT_OK(read_result);
+ EXPECT_THAT(*read_result, 3);
+ EXPECT_THAT(actual_body, StrEq("123XXXXX"));
+
+ read_result = (*request)->ReadBody(actual_body.data() + 3, 3);
+ ASSERT_OK(read_result);
+ EXPECT_THAT(*read_result, 3);
+ EXPECT_THAT(actual_body, StrEq("123456XX"));
+
+ // The last read should only read 2 bytes. This tests the case where more data
+ // is requested than is available. A correct implementation should not write
+ // more bytes than it has available, which should ensure that no writes will
+ // occur beyond actual_body.data()'s buffer size (which is only 8 bytes long).
+ read_result = (*request)->ReadBody(actual_body.data() + 6, 3);
+ ASSERT_OK(read_result);
+ EXPECT_THAT(*read_result, 2);
+ EXPECT_THAT(actual_body, StrEq(expected_body));
+
+ // Expect the last read to indicate the end of the stream.
+ EXPECT_THAT((*request)->ReadBody(nullptr, 1), IsCode(OUT_OF_RANGE));
+}
+
+TEST(InMemoryHttpRequestTest, RequestWithCompressedBody) {
+ const std::string uncompressed_body =
+ "request_body_AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost, {},
+ uncompressed_body,
+ /*use_compression=*/true);
+ ASSERT_OK(request);
+ auto content_encoding_header =
+ FindHeader((*request)->extra_headers(), kContentEncodingHdr);
+ ASSERT_TRUE(content_encoding_header.has_value());
+ ASSERT_EQ(content_encoding_header.value(), kGzipEncodingHdrValue);
+
+ auto content_length_header =
+ FindHeader((*request)->extra_headers(), kContentLengthHdr);
+ ASSERT_TRUE(content_length_header.has_value());
+ int compressed_length = std::stoi(content_length_header.value());
+ ASSERT_GT(compressed_length, 0);
+ ASSERT_LT(compressed_length, uncompressed_body.size());
+
+ std::string actual_body;
+ actual_body.resize(compressed_length);
+ // Read the body in one go (the "simple" case).
+ auto read_result =
+ (*request)->ReadBody(actual_body.data(), actual_body.size());
+ ASSERT_OK(read_result);
+ EXPECT_THAT(*read_result, actual_body.size());
+
+ // Expect the second read to indicate the end of the stream.
+ EXPECT_THAT((*request)->ReadBody(nullptr, 1), IsCode(OUT_OF_RANGE));
+
+ auto recovered_body = internal::UncompressWithGzip(actual_body);
+ ASSERT_OK(recovered_body);
+ EXPECT_EQ(*recovered_body, uncompressed_body);
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseFailsBeforeHeaders) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ InMemoryHttpRequestCallback callback;
+ callback.OnResponseError(**request, absl::UnimplementedError("foobar"));
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ EXPECT_THAT(actual_response, IsCode(UNIMPLEMENTED));
+ EXPECT_THAT(actual_response.status().message(), HasSubstr("foobar"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseFailsAfterHeaders) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ auto fake_response = FakeHttpResponse(kHttpOk, {});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ callback.OnResponseBodyError(**request, fake_response,
+ absl::UnimplementedError("foobar"));
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ EXPECT_THAT(actual_response, IsCode(UNIMPLEMENTED));
+ EXPECT_THAT(actual_response.status().message(), HasSubstr("foobar"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseFailsAfterPartialBody) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ auto fake_response = FakeHttpResponse(kHttpOk, {});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "response_"));
+ callback.OnResponseBodyError(**request, fake_response,
+ absl::UnimplementedError("foobar"));
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ EXPECT_THAT(actual_response, IsCode(UNIMPLEMENTED));
+ EXPECT_THAT(actual_response.status().message(), HasSubstr("foobar"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest,
+ TestResponseWithContentLengthAndBodyTooShort) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ auto fake_response = FakeHttpResponse(kHttpOk, {{"Content-Length", "5"}});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ // Return a partial response (only 4 characters instead of the expected 5
+ // indicated by the Content-Length header).
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "12"));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "34"));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ EXPECT_THAT(actual_response, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(actual_response.status().message(),
+ HasSubstr("Too little response body data received"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest,
+ TestResponseWithContentLengthAndBodyTooLong) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ auto fake_response = FakeHttpResponse(kHttpOk, {{"Content-Length", "5"}});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ // Return a more response data than expected (6 characters instead of the
+ // expected 5 indicated by the Content-Length header).
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "12"));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "34"));
+ absl::Status response_body_result =
+ callback.OnResponseBody(**request, fake_response, "56");
+ EXPECT_THAT(response_body_result, IsCode(OUT_OF_RANGE));
+ EXPECT_THAT(response_body_result.message(),
+ HasSubstr("Too much response body data received"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseWithContentLengthNegative) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Set the Content-Length to an invalid negative value.
+ auto fake_response = FakeHttpResponse(kHttpOk, {{"Content-Length", "-1"}});
+
+ InMemoryHttpRequestCallback callback;
+ absl::Status response_started_result =
+ callback.OnResponseStarted(**request, fake_response);
+ EXPECT_THAT(response_started_result, IsCode(OUT_OF_RANGE));
+ EXPECT_THAT(response_started_result.message(),
+ HasSubstr("Invalid Content-Length response header"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseWithContentLengthNonInteger) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Set the Content-Length to an invalid non-integer, non-ASCII value.
+ auto fake_response =
+ FakeHttpResponse(kHttpOk, {{"Content-Length", "\U0001F600"}});
+
+ InMemoryHttpRequestCallback callback;
+ absl::Status response_started_result =
+ callback.OnResponseStarted(**request, fake_response);
+ EXPECT_THAT(response_started_result, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(response_started_result.message(),
+ HasSubstr("Could not parse Content-Length response header"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseWithContentEncodingHeader) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Add a Content-Encoding header, which implementations should never provide
+ // to us, unless we advertised an Accept-Encoding header in the request.
+ auto fake_response = FakeHttpResponse(kHttpOk, {{"Content-Encoding", "foo"}});
+
+ InMemoryHttpRequestCallback callback;
+ absl::Status response_started_result =
+ callback.OnResponseStarted(**request, fake_response);
+ EXPECT_THAT(response_started_result, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(response_started_result.message(),
+ HasSubstr("Unexpected header: Content-Encoding"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, ResponseWithTransferEncodingHeader) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Add a Transfer-Encoding header, which implementations should never provide
+ // to us.
+ auto fake_response =
+ FakeHttpResponse(kHttpOk, {{"Transfer-Encoding", "foo"}});
+
+ InMemoryHttpRequestCallback callback;
+ absl::Status response_started_result =
+ callback.OnResponseStarted(**request, fake_response);
+ EXPECT_THAT(response_started_result, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(response_started_result.message(),
+ HasSubstr("Unexpected header: Transfer-Encoding"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, OkResponseWithoutContentLength) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Note that the fake response does not contain a Content-Length header. The
+ // InMemoryHttpRequestCallback should be able to handle this (accumulating
+ // response data until OnResponseCompleted is called).
+ int expected_code = 201; // "201 Created"
+ auto fake_response = FakeHttpResponse(expected_code, {});
+ const std::string expected_body = "response_body";
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ // We return the response body in one go (the "simple" case).
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, expected_body));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(*actual_response, FieldsAre(expected_code, IsEmpty(), IsEmpty(),
+ StrEq(expected_body)));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, OkResponseWithContentLength) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ int expected_code = kHttpOk;
+ const std::string expected_body = "response_body";
+ auto fake_response = FakeHttpResponse(
+ expected_code,
+ {{"Content-Length", std::to_string(expected_body.size())}});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, expected_body));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(*actual_response, FieldsAre(expected_code, IsEmpty(), IsEmpty(),
+ StrEq(expected_body)));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, OkResponseChunkedBody) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Note that the fake response does not contain a Content-Length header. The
+ // InMemoryHttpRequestCallback should be able to handle this (accumulating
+ // response data until OnResponseCompleted is called).
+ auto fake_response = FakeHttpResponse(kHttpOk, {});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ // This test returns the body in chunks of 3 bytes at a time (rather than
+ // all at once, like previous test). To test some edge cases, we ensure that
+ // the request body's length is not evenly dividable by 3.
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "123"));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "456"));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, "78"));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(actual_response->body, StrEq("12345678"));
+}
+
+TEST(InMemoryHttpRequestCallbackTest,
+ TestOkResponseWithEmptyBodyWithoutContentLength) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ int expected_code = kHttpOk;
+ auto fake_response = FakeHttpResponse(expected_code, {});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, ""));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(*actual_response,
+ FieldsAre(expected_code, IsEmpty(), IsEmpty(), IsEmpty()));
+}
+
+TEST(InMemoryHttpRequestCallbackTest,
+ TestOkResponseWithEmptyBodyWithContentLength) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ int expected_code = kHttpOk;
+ auto fake_response =
+ FakeHttpResponse(expected_code, {{"Content-Length", "0"}});
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, ""));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(*actual_response,
+ FieldsAre(expected_code, IsEmpty(), IsEmpty(), IsEmpty()));
+}
+
+TEST(InMemoryHttpRequestCallbackTest,
+ TestOkResponsWithAcceptEncodingRequestHeader) {
+ const std::string expected_content_encoding = "foo";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kGet,
+ {{"Accept-Encoding", expected_content_encoding}}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Note that the fake response contains a Content-Encoding header, which
+ // should be allowed because the request contained an Accept-Encoding header.
+ int expected_code = kHttpOk;
+ auto fake_response = FakeHttpResponse(
+ expected_code, {{"Some-Response-Header", "foo"},
+ {"Content-Encoding", expected_content_encoding}});
+ const std::string expected_body = "response_body";
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ // We return the response body in one go (the "simple" case).
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, expected_body));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ ASSERT_OK(actual_response);
+ EXPECT_THAT(*actual_response,
+ FieldsAre(expected_code, StrEq(expected_content_encoding),
+ IsEmpty(), StrEq(expected_body)));
+}
+
+TEST(InMemoryHttpRequestCallbackTest, NotFoundResponse) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ // Return an HTTP error response.
+ auto fake_response = FakeHttpResponse(kHttpNotFound, {});
+ const std::string expected_body = "response_body";
+
+ InMemoryHttpRequestCallback callback;
+ ASSERT_OK(callback.OnResponseStarted(**request, fake_response));
+ ASSERT_OK(callback.OnResponseBody(**request, fake_response, expected_body));
+ callback.OnResponseCompleted(**request, fake_response);
+
+ absl::StatusOr<InMemoryHttpResponse> actual_response = callback.Response();
+ EXPECT_THAT(actual_response, IsCode(NOT_FOUND));
+ EXPECT_THAT(actual_response.status().message(), HasSubstr("404"));
+}
+
+class PerformRequestsTest : public ::testing::Test {
+ protected:
+ PerformRequestsTest()
+ : interruptible_runner_(
+ &mock_log_manager_, mock_should_abort_.AsStdFunction(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT}) {}
+ void SetUp() override {
+ root_cache_dir_ = testing::TempDir();
+ root_files_dir_ = testing::TempDir();
+ }
+
+ void TearDown() override {
+ std::filesystem::remove_all(root_cache_dir_);
+ std::filesystem::remove_all(root_files_dir_);
+ }
+
+ NiceMock<MockLogManager> mock_log_manager_;
+ NiceMock<MockFunction<bool()>> mock_should_abort_;
+ InterruptibleRunner interruptible_runner_;
+ StrictMock<MockHttpClient> mock_http_client_;
+ testing::StrictMock<MockLogManager> log_manager_;
+ SimulatedClock clock_;
+ std::string root_cache_dir_;
+ std::string root_files_dir_;
+};
+
+TEST_F(PerformRequestsTest, PerformRequestInMemoryOk) {
+ std::string expected_request_body = "request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kPost,
+ {{"Some-Request-Header", "foo"}}, expected_request_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ int expected_response_code = kHttpOk;
+ std::string expected_response_body = "response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre((*request)->uri(), (*request)->method(),
+ Contains(Header{"Some-Request-Header", "foo"}),
+ StrEq(expected_request_body))))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{"Some-Response-Header", "bar"}},
+ expected_response_body)));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ absl::StatusOr<InMemoryHttpResponse> result = PerformRequestInMemory(
+ mock_http_client_, interruptible_runner_, *std::move(request),
+ // We pass in non-null pointers for the network stats, to ensure they are
+ // correctly updated.
+ &bytes_received, &bytes_sent);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, FieldsAre(expected_response_code, IsEmpty(), IsEmpty(),
+ StrEq(expected_response_body)));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(expected_request_body.size()));
+ EXPECT_THAT(bytes_received, Ge(expected_response_body.size()));
+}
+
+TEST_F(PerformRequestsTest, PerformRequestInMemoryNotFound) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ int expected_response_code = kHttpNotFound;
+ std::string expected_response_body = "response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre((*request)->uri(), (*request)->method(), _, _)))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code, {},
+ expected_response_body)));
+
+ absl::StatusOr<InMemoryHttpResponse> result =
+ PerformRequestInMemory(mock_http_client_, interruptible_runner_,
+ *std::move(request), nullptr, nullptr);
+ EXPECT_THAT(result, IsCode(NOT_FOUND));
+ EXPECT_THAT(result.status().message(), HasSubstr("404"));
+}
+
+TEST_F(PerformRequestsTest, PerformRequestInMemoryEarlyError) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre((*request)->uri(), (*request)->method(), _, _)))
+ // Make the call to PerformSingleRequest (and therefore the call to
+ // HttpClient::PerformRequests) fail as a whole (rather than having the
+ // individual request return a failure).
+ .WillOnce(Return(absl::InvalidArgumentError("PerformRequests failed")));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ absl::StatusOr<InMemoryHttpResponse> result =
+ PerformRequestInMemory(mock_http_client_, interruptible_runner_,
+ *std::move(request), &bytes_received, &bytes_sent);
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(result.status().message(), HasSubstr("PerformRequests failed"));
+
+ // We know that MockHttpClient will have updated the 'sent' network stat
+ // before having called into PerformSingleRequest, so that should be
+ // reflected.
+ EXPECT_THAT(bytes_sent, Ge(0));
+ // The 'received' network stat should be 0 OTOH, since issuing the request
+ // failed.
+ EXPECT_EQ(bytes_received, 0);
+}
+
+// Tests the case where the request gets interrupted.
+TEST_F(PerformRequestsTest, PerformRequestInMemoryCancellation) {
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kGet, {}, "",
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+
+ absl::Notification request_issued;
+ // We expect one calls to the cancellation listener.
+ absl::BlockingCounter counter_should_abort(1);
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we decrement
+ // the counter.
+ mock_http_client_.SetCancellationListener(
+ [&counter_should_abort]() { counter_should_abort.DecrementCount(); });
+
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre((*request)->uri(), (*request)->method(), _, _)))
+ .WillOnce([&request_issued, &counter_should_abort](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ request_issued.Notify();
+ counter_should_abort.Wait();
+ return FakeHttpResponse(503, {}, "");
+ });
+ // Make should_abort return false until we know that the request was issued
+ // (i.e. once InterruptibleRunner has actually started running the code it was
+ // given), and then make it return true, triggering an abort sequence and
+ // unblocking the PerformRequests() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ // The request should result in a CANCELLED outcome.
+ absl::StatusOr<InMemoryHttpResponse> result =
+ PerformRequestInMemory(mock_http_client_, interruptible_runner_,
+ *std::move(request), &bytes_received, &bytes_sent);
+
+ EXPECT_THAT(result, IsCode(CANCELLED));
+ EXPECT_THAT(result.status().message(),
+ HasSubstr("cancelled after graceful wait"));
+
+ // The network stats should still have been updated though (to reflect the
+ // data sent and received up until the point of interruption).
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(0));
+}
+
+TEST_F(PerformRequestsTest, PerformTwoRequestsInMemoryOk) {
+ std::string expected_request_body = "request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kPost,
+ {{"Some-Request-Header", "foo"}}, expected_request_body,
+ /*use_compression=*/false);
+ ASSERT_OK(request);
+ std::string another_expected_request_body = "request_body_2";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> another_request =
+ InMemoryHttpRequest::Create("https://valid.com",
+ HttpRequest::Method::kPost,
+ {{"Some-Other-Request-Header", "foo2"}},
+ another_expected_request_body,
+ /*use_compression=*/false);
+ ASSERT_OK(another_request);
+
+ std::string expected_response_body = "response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre((*request)->uri(), (*request)->method(),
+ Contains(Header{"Some-Request-Header", "foo"}),
+ StrEq(expected_request_body))))
+ .WillOnce(Return(FakeHttpResponse(
+ kHttpOk, {{"Some-Response-Header", "bar"}}, expected_response_body)));
+ std::string another_expected_response_body = "another_response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(FieldsAre(
+ (*request)->uri(), (*request)->method(),
+ Contains(Header{"Some-Other-Request-Header", "foo2"}),
+ StrEq(another_expected_request_body))))
+ .WillOnce(Return(
+ FakeHttpResponse(kHttpOk, {{"Some-Other-Response-Header", "bar2"}},
+ another_expected_response_body)));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(*std::move(request));
+ requests.push_back(*std::move(another_request));
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>> results =
+ PerformMultipleRequestsInMemory(
+ mock_http_client_, interruptible_runner_, std::move(requests),
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent);
+ ASSERT_OK(results);
+ ASSERT_EQ(results->size(), 2);
+
+ auto first_response = (*results)[0];
+ ASSERT_OK(first_response);
+ EXPECT_THAT(*first_response, FieldsAre(kHttpOk, IsEmpty(), IsEmpty(),
+ StrEq(expected_response_body)));
+ auto second_response = (*results)[1];
+ ASSERT_OK(second_response);
+ EXPECT_THAT(*second_response,
+ FieldsAre(kHttpOk, IsEmpty(), IsEmpty(),
+ StrEq(another_expected_response_body)));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(expected_request_body.size() +
+ another_expected_request_body.size()));
+ EXPECT_THAT(bytes_received, Ge(expected_response_body.size() +
+ another_expected_response_body.size()));
+}
+
+TEST_F(PerformRequestsTest, PerformTwoRequestsWithOneFailedOneSuccess) {
+ std::string success_request_body = "success_request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> success_request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kPost,
+ {{"Some-Request-Header", "foo"}}, success_request_body,
+ /*use_compression=*/false);
+ ASSERT_OK(success_request);
+ std::string failure_request_body = "failure_request_body";
+ absl::StatusOr<std::unique_ptr<HttpRequest>> failure_request =
+ InMemoryHttpRequest::Create(
+ "https://valid.com", HttpRequest::Method::kPost,
+ {{"Some-Other-Request-Header", "foo2"}}, failure_request_body,
+ /*use_compression=*/false);
+ ASSERT_OK(failure_request);
+
+ int ok_response_code = kHttpOk;
+ std::string success_response_body = "response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(FieldsAre(
+ (*success_request)->uri(), (*success_request)->method(),
+ Contains(Header{"Some-Request-Header", "foo"}),
+ StrEq(success_request_body))))
+ .WillOnce(Return(FakeHttpResponse(ok_response_code,
+ {{"Some-Response-Header", "bar"}},
+ success_response_body)));
+ std::string failure_response_body = "failure_response_body";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(FieldsAre(
+ (*failure_request)->uri(), (*failure_request)->method(),
+ Contains(Header{"Some-Other-Request-Header", "foo2"}),
+ StrEq(failure_request_body))))
+ .WillOnce(
+ Return(FakeHttpResponse(kHttpNotFound, {}, failure_response_body)));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(*std::move(success_request));
+ requests.push_back(*std::move(failure_request));
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>> results =
+ PerformMultipleRequestsInMemory(
+ mock_http_client_, interruptible_runner_, std::move(requests),
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent);
+ ASSERT_OK(results);
+ ASSERT_EQ(results->size(), 2);
+ auto first_response = (*results)[0];
+ ASSERT_OK(first_response);
+ EXPECT_THAT(*first_response, FieldsAre(ok_response_code, IsEmpty(), IsEmpty(),
+ StrEq(success_response_body)));
+
+ EXPECT_THAT(results->at(1), IsCode(NOT_FOUND));
+ EXPECT_THAT(results->at(1).status().message(), HasSubstr("404"));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent,
+ Ge(success_request_body.size() + failure_request_body.size()));
+ EXPECT_THAT(bytes_received,
+ Ge(success_response_body.size() + failure_response_body.size()));
+}
+
+// Tests the case where a zero-length vector of UriOrInlineData is passed in. It
+// should result in a zero-length result vector (as opposed to an error or a
+// crash).
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryEmptyInputVector) {
+ auto result = FetchResourcesInMemory(mock_http_client_, interruptible_runner_,
+ {}, nullptr, nullptr,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, IsEmpty());
+}
+
+// Tests the case where both fields of UriOrInlineData are empty. The empty
+// inline_data field should be returned.
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryEmptyUriAndInline) {
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_,
+ {UriOrInlineData::CreateInlineData(absl::Cord(),
+ CompressionFormat::kUncompressed)},
+ nullptr, nullptr,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(kHttpOk, IsEmpty(), kOctetStream, IsEmpty()));
+}
+
+// Tests the case where one of the URIs is invalid. The whole request should
+// result in an error in that case.
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryInvalidUri) {
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_,
+ {UriOrInlineData::CreateUri("https://valid.com", "",
+ absl::ZeroDuration()),
+ UriOrInlineData::CreateUri("http://invalid.com", "",
+ absl::ZeroDuration())},
+ nullptr, nullptr,
+ /*resource_cache=*/nullptr);
+ EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(result.status().message(), HasSubstr("Non-HTTPS"));
+}
+
+// Tests the case where all of the requested resources must be fetched via URI.
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryAllUris) {
+ const std::string uri1 = "https://valid.com/1";
+ const std::string uri2 = "https://valid.com/2";
+ const std::string uri3 = "https://valid.com/3";
+ const std::string uri4 = "https://valid.com/4";
+ auto resource1 = UriOrInlineData::CreateUri(uri1, "", absl::ZeroDuration());
+ auto resource2 = UriOrInlineData::CreateUri(uri2, "", absl::ZeroDuration());
+ auto resource3 = UriOrInlineData::CreateUri(uri3, "", absl::ZeroDuration());
+ auto resource4 = UriOrInlineData::CreateUri(uri4, "", absl::ZeroDuration());
+
+ int expected_response_code1 = kHttpOk;
+ int expected_response_code2 = kHttpNotFound;
+ int expected_response_code3 = kHttpServiceUnavailable;
+ int expected_response_code4 = 204; // "204 No Content"
+ std::string expected_response_body1 = "response_body1";
+ std::string expected_response_body4 = "response_body4";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri1, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code1, {},
+ expected_response_body1)));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri2, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code2, {}, "")));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri3, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code3, {}, "")));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri4, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code4, {},
+ expected_response_body4)));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result =
+ FetchResourcesInMemory(mock_http_client_, interruptible_runner_,
+ {resource1, resource2, resource3, resource4},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code1, IsEmpty(), IsEmpty(),
+ StrEq(expected_response_body1)));
+ EXPECT_THAT((*result)[1], IsCode(NOT_FOUND));
+ EXPECT_THAT((*result)[1].status().message(), HasSubstr("404"));
+ EXPECT_THAT((*result)[2], IsCode(UNAVAILABLE));
+ EXPECT_THAT((*result)[2].status().message(), HasSubstr("503"));
+ ASSERT_OK((*result)[3]);
+ EXPECT_THAT(*(*result)[3],
+ FieldsAre(expected_response_code4, IsEmpty(), IsEmpty(),
+ StrEq(expected_response_body4)));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(0));
+}
+
+// Tests the case where some of the requested resources have inline data
+// available.
+TEST_F(PerformRequestsTest, FetchResourcesInMemorySomeInlineData) {
+ const std::string uri1 = "https://valid.com/1";
+ const std::string uri3 = "https://valid.com/3";
+ std::string expected_response_body2 = "response_body2";
+ std::string expected_response_body4 = "response_body4";
+ auto resource1 = UriOrInlineData::CreateUri(uri1, "", absl::ZeroDuration());
+ auto resource2 = UriOrInlineData::CreateInlineData(
+ absl::Cord(expected_response_body2), CompressionFormat::kUncompressed);
+ auto resource3 = UriOrInlineData::CreateUri(uri3, "", absl::ZeroDuration());
+ auto resource4 = UriOrInlineData::CreateInlineData(
+ absl::Cord(expected_response_body4), CompressionFormat::kUncompressed);
+
+ int expected_response_code1 = kHttpServiceUnavailable;
+ int expected_response_code3 = 204; // "204 No Content"
+ std::string expected_response_body3 = "response_body3";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri1, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code1, {}, "")));
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri3, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code3, {},
+ expected_response_body3)));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result =
+ FetchResourcesInMemory(mock_http_client_, interruptible_runner_,
+ {resource1, resource2, resource3, resource4},
+ &bytes_received, &bytes_sent,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+
+ EXPECT_THAT((*result)[0], IsCode(UNAVAILABLE));
+ EXPECT_THAT((*result)[0].status().message(), HasSubstr("503"));
+ ASSERT_OK((*result)[1]);
+ EXPECT_THAT(*(*result)[1], FieldsAre(kHttpOk, IsEmpty(), kOctetStream,
+ StrEq(expected_response_body2)));
+ ASSERT_OK((*result)[2]);
+ EXPECT_THAT(*(*result)[2],
+ FieldsAre(expected_response_code3, IsEmpty(), IsEmpty(),
+ StrEq(expected_response_body3)));
+ ASSERT_OK((*result)[3]);
+ EXPECT_THAT(*(*result)[3], FieldsAre(kHttpOk, IsEmpty(), kOctetStream,
+ StrEq(expected_response_body4)));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(0));
+}
+
+// Tests the case where all of the requested resources have inline data
+// available (and hence no HTTP requests are expected to be issued).
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryOnlyInlineData) {
+ std::string expected_response_body1 = "response_body1";
+ std::string expected_response_body2 = "response_body2";
+ auto resource1 = UriOrInlineData::CreateInlineData(
+ absl::Cord(expected_response_body1), CompressionFormat::kUncompressed);
+ auto resource2 = UriOrInlineData::CreateInlineData(
+ absl::Cord(expected_response_body2), CompressionFormat::kUncompressed);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result = FetchResourcesInMemory(mock_http_client_, interruptible_runner_,
+ {resource1, resource2}, &bytes_received,
+ &bytes_sent,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0], FieldsAre(kHttpOk, IsEmpty(), kOctetStream,
+ StrEq(expected_response_body1)));
+ ASSERT_OK((*result)[1]);
+ EXPECT_THAT(*(*result)[1], FieldsAre(kHttpOk, IsEmpty(), kOctetStream,
+ StrEq(expected_response_body2)));
+
+ // The network stats should be untouched, since no network requests were
+ // issued.
+ EXPECT_EQ(bytes_sent, 0);
+ EXPECT_EQ(bytes_received, 0);
+}
+
+// Tests the case where the fetches get interrupted.
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryCancellation) {
+ const std::string uri1 = "https://valid.com/1";
+ const std::string uri2 = "https://valid.com/2";
+ auto resource1 = UriOrInlineData::CreateUri(uri1, "", absl::ZeroDuration());
+ auto resource2 = UriOrInlineData::CreateUri(uri2, "", absl::ZeroDuration());
+
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(FieldsAre(uri1, _, _, _)))
+ .WillOnce(Return(FakeHttpResponse(kHttpOk, {}, "")));
+
+ absl::Notification request_issued;
+ // We expect two calls to the cancellation listener, one for each request.
+ absl::BlockingCounter counter_should_abort(2);
+ // When the HttpClient receives a HttpRequestHandle::Cancel call, we decrement
+ // the counter.
+ mock_http_client_.SetCancellationListener(
+ [&counter_should_abort]() { counter_should_abort.DecrementCount(); });
+
+ // Make HttpClient::PerformRequests() block until the counter is decremented.
+ EXPECT_CALL(mock_http_client_, PerformSingleRequest(FieldsAre(uri2, _, _, _)))
+ .WillOnce([&request_issued, &counter_should_abort](
+ MockableHttpClient::SimpleHttpRequest ignored) {
+ request_issued.Notify();
+ counter_should_abort.Wait();
+ return FakeHttpResponse(503, {}, "");
+ });
+ // Make should_abort return false until we know that the 2nd request was
+ // issued (i.e. once InterruptibleRunner has actually started running the code
+ // it was given), and then make it return true, triggering an abort sequence
+ // and unblocking the PerformRequests() call we caused to block above.
+ EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
+ return request_issued.HasBeenNotified();
+ });
+
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ // The request should result in an overall CANCELLED outcome.
+ auto result = FetchResourcesInMemory(mock_http_client_, interruptible_runner_,
+ {resource1, resource2}, &bytes_received,
+ &bytes_sent,
+ /*resource_cache=*/nullptr);
+ EXPECT_THAT(result, IsCode(CANCELLED));
+ EXPECT_THAT(result.status().message(),
+ HasSubstr("cancelled after graceful wait"));
+
+ // The network stats should still have been updated though (to reflect the
+ // data sent and received up until the point of interruption).
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(0));
+}
+
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryCompressedResources) {
+ const std::string uri = "https://valid.com/";
+ auto resource1 = UriOrInlineData::CreateUri(uri, "", absl::ZeroDuration());
+
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes+gzip";
+ std::string expected_response_body = "response_body: AAAAAAAAAAAAAAAAAAAAA";
+ auto compressed_response_body =
+ internal::CompressWithGzip(expected_response_body);
+ ASSERT_OK(compressed_response_body);
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ *compressed_response_body)));
+
+ auto resource2 = UriOrInlineData::CreateInlineData(
+ absl::Cord(*compressed_response_body), CompressionFormat::kGzip);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource1, resource2},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent,
+ /*resource_cache=*/nullptr);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), StrEq(content_type),
+ StrEq(expected_response_body)));
+ EXPECT_THAT(*(*result)[1],
+ FieldsAre(kHttpOk, IsEmpty(), absl::StrCat(kOctetStream, "+gzip"),
+ StrEq(expected_response_body)));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(2 * compressed_response_body->size()));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryCompressedResourcesFailToDecode) {
+ const std::string uri = "https://valid.com/";
+ auto resource1 = UriOrInlineData::CreateUri(uri, "", absl::ZeroDuration());
+
+ int expected_response_code = kHttpOk;
+ std::string content_type = "not-actually-gzipped+gzip";
+ std::string expected_response_body = "I am not a valid gzipped body ლ(ಠ益ಠლ)";
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ expected_response_body)));
+
+ auto resource2 = UriOrInlineData::CreateInlineData(
+ absl::Cord(expected_response_body), CompressionFormat::kGzip);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource1, resource2},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent,
+ /*resource_cache=*/nullptr);
+ // Fetching will succeed
+ ASSERT_OK(result);
+
+ // ...but our responses will have failed to decode.
+ EXPECT_THAT((*result)[0], IsCode(INTERNAL));
+ EXPECT_THAT((*result)[1], IsCode(INTERNAL));
+
+ EXPECT_THAT(bytes_sent, Ne(bytes_received));
+ EXPECT_THAT(bytes_sent, Ge(0));
+ EXPECT_THAT(bytes_received, Ge(2 * expected_response_body.size()));
+}
+
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryCachedResourceOk) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord cached_resource("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+ auto resource_cache = cache::FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)
+ ->Put(cache_id, cached_resource,
+ MetadataForUncompressedResource(), max_age));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, resource_cache->get());
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0], FieldsAre(expected_response_code, IsEmpty(),
+ kOctetStream, StrEq(cached_resource)));
+
+ // Fully from the cache!
+ EXPECT_THAT(bytes_sent, Eq(0));
+ EXPECT_THAT(bytes_received, Eq(0));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryCachedResourceOkAndCompressed) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord cached_resource("(((*°▽°*)八(*°▽°*)))");
+ absl::Cord compressed_cached_resource(
+ *internal::CompressWithGzip(std::string(cached_resource)));
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+ auto resource_cache = cache::FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+ ASSERT_OK((*resource_cache)
+ ->Put(cache_id, compressed_cached_resource,
+ MetadataForCompressedResource(), max_age));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, resource_cache->get());
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0], FieldsAre(expected_response_code, IsEmpty(),
+ kOctetStream, StrEq(cached_resource)));
+
+ // Fully from the cache!
+ EXPECT_THAT(bytes_sent, Eq(0));
+ EXPECT_THAT(bytes_received, Eq(0));
+}
+
+TEST_F(PerformRequestsTest, FetchResourcesInMemoryNotCachedButThenPutInCache) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord expected_response_body("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes";
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ std::string(expected_response_body))));
+
+ auto resource_cache = cache::FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, resource_cache->get());
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), content_type,
+ StrEq(expected_response_body)));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ auto stored_resource = (*resource_cache)->Get(cache_id, std::nullopt);
+ ASSERT_OK(stored_resource);
+ EXPECT_THAT(*stored_resource,
+ FieldsAre(StrEq(expected_response_body),
+ EqualsProto(MetadataForUncompressedResource())));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryNotCachedButThenPutInCacheCompressed) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord expected_response_body("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes+gzip";
+ absl::Cord compressed_response_body(
+ *internal::CompressWithGzip(std::string(expected_response_body)));
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(
+ expected_response_code, {{kContentTypeHdr, content_type}},
+ std::string(compressed_response_body))));
+
+ auto resource_cache = cache::FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, resource_cache->get());
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), content_type,
+ StrEq(expected_response_body)));
+
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ auto stored_resource = (*resource_cache)->Get(cache_id, std::nullopt);
+ ASSERT_OK(stored_resource);
+ EXPECT_THAT(*stored_resource,
+ FieldsAre(StrEq(compressed_response_body),
+ EqualsProto(MetadataForCompressedResource())));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryNotCachedPutInCacheWithZeroMaxAgeDoesntCrash) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord expected_response_body("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::ZeroDuration();
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes";
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ std::string(expected_response_body))));
+
+ auto resource_cache = cache::FileBackedResourceCache::Create(
+ root_files_dir_, root_cache_dir_, &log_manager_, &clock_,
+ kMaxCacheSizeBytes);
+ ASSERT_OK(resource_cache);
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_MISS));
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, resource_cache->get());
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), content_type,
+ StrEq(expected_response_body)));
+ EXPECT_CALL(log_manager_, LogDiag(DebugDiagCode::RESOURCE_CACHE_HIT));
+ auto stored_resource = (*resource_cache)->Get(cache_id, std::nullopt);
+ ASSERT_OK(stored_resource);
+ EXPECT_THAT(*stored_resource,
+ FieldsAre(StrEq(expected_response_body),
+ EqualsProto(MetadataForUncompressedResource())));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryCachedResourceGetReturnsInternal) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord expected_response_body("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes";
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+ StrictMock<cache::MockResourceCache> resource_cache;
+
+ EXPECT_CALL(resource_cache, Get(cache_id, std::make_optional(max_age)))
+ .WillOnce(Return(absl::InternalError("the cache exploded -_-")));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ std::string(expected_response_body))));
+
+ EXPECT_CALL(resource_cache, Put(cache_id, expected_response_body, _, max_age))
+ .WillOnce(Return(absl::OkStatus()));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, &resource_cache);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), content_type,
+ StrEq(expected_response_body)));
+}
+
+TEST_F(PerformRequestsTest,
+ FetchResourcesInMemoryPutInCacheReturnsInternalDoesntCrash) {
+ const std::string uri = "https://valid.com/1";
+ const std::string cache_id = "(^˵◕ω◕˵^)";
+ absl::Cord expected_response_body("(((*°▽°*)八(*°▽°*)))");
+ absl::Duration max_age = absl::Hours(1);
+ int expected_response_code = kHttpOk;
+ std::string content_type = "bytes";
+ auto resource = UriOrInlineData::CreateUri(uri, cache_id, max_age);
+ StrictMock<cache::MockResourceCache> resource_cache;
+
+ EXPECT_CALL(resource_cache, Get(cache_id, std::make_optional(max_age)))
+ .WillOnce(Return(absl::NotFoundError("not in the cache sorry!")));
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(
+ FieldsAre(uri, HttpRequest::Method::kGet, HeaderList{}, "")))
+ .WillOnce(Return(FakeHttpResponse(expected_response_code,
+ {{kContentTypeHdr, content_type}},
+ std::string(expected_response_body))));
+
+ EXPECT_CALL(resource_cache,
+ Put(cache_id, expected_response_body,
+ EqualsProto(MetadataForUncompressedResource()), max_age))
+ .WillOnce(Return(absl::InternalError("the cache exploded -_-")));
+
+ int64_t bytes_received = 0;
+ int64_t bytes_sent = 0;
+ auto result = FetchResourcesInMemory(
+ mock_http_client_, interruptible_runner_, {resource},
+ // We pass in non-null pointers for the network
+ // stats, to ensure they are correctly updated.
+ &bytes_received, &bytes_sent, &resource_cache);
+ ASSERT_OK(result);
+
+ ASSERT_OK((*result)[0]);
+ EXPECT_THAT(*(*result)[0],
+ FieldsAre(expected_response_code, IsEmpty(), content_type,
+ StrEq(expected_response_body)));
+}
+
+} // namespace
+} // namespace fcp::client::http
diff --git a/fcp/client/http/java/BUILD b/fcp/client/http/java/BUILD
new file mode 100644
index 0000000..7ca7427
--- /dev/null
+++ b/fcp/client/http/java/BUILD
@@ -0,0 +1,62 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+default_visibility = [
+ "//fcp:internal",
+]
+
+package(
+ default_visibility = default_visibility,
+ licenses = ["notice"], # Apache 2.0
+)
+
+# --------------------------------------------------------------------
+
+# The protos used to pass data across the JNI boundary.
+proto_library(
+ name = "jni_proto",
+ srcs = ["jni.proto"],
+)
+
+java_proto_library(
+ name = "jni_java_proto",
+ deps = [":jni_proto"],
+)
+
+cc_proto_library(
+ name = "jni_cc_proto",
+ deps = [":jni_proto"],
+)
+
+cc_library(
+ name = "java_http_client",
+ srcs = ["java_http_client.cc"],
+ hdrs = ["java_http_client.h"],
+ deps = [
+ ":jni_cc_proto",
+ "//fcp/base",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:http_client_util",
+ "//fcp/jni:jni_util",
+ "@bazel_tools//tools/jdk:jni",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/cleanup",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googleapis//google/rpc:status_cc_proto",
+ ],
+ alwayslink = 1,
+)
diff --git a/fcp/client/http/java/java_http_client.cc b/fcp/client/http/java/java_http_client.cc
new file mode 100644
index 0000000..82a67f1
--- /dev/null
+++ b/fcp/client/http/java/java_http_client.cc
@@ -0,0 +1,531 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/java/java_http_client.h"
+
+#include <jni.h>
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// #include "google/rpc/status.pb.h"
+#include "absl/cleanup/cleanup.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/client/http/java/jni.pb.h"
+#include "fcp/jni/jni_util.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace java {
+
+using fcp::client::http::HttpRequestCallback;
+using fcp::client::http::HttpRequestHandle;
+using fcp::jni::JavaFieldSig;
+using fcp::jni::JavaMethodSig;
+using fcp::jni::LocalRefDeleter;
+using fcp::jni::ParseProtoFromJByteArray;
+using fcp::jni::ScopedJniEnv;
+using fcp::jni::SerializeProtoToJByteArray;
+
+namespace {
+
+// The Java method signatures for the Java class corresponding to the C++
+// `JavaHttpClient` class.
+struct JavaHttpClientClassDesc {
+ static constexpr JavaMethodSig kEnqueueRequest = {
+ "enqueueRequest",
+ "([B)Lcom/google/fcp/client/http/HttpClientForNative$HttpRequestHandle;"};
+ static constexpr JavaMethodSig kPerformRequests = {"performRequests",
+ "([Ljava/lang/Object;)[B"};
+ static constexpr JavaMethodSig kClose = {"close", "()V"};
+};
+
+// The Java method and field signatures for the Java class corresponding to the
+// C++ `JavaHttpRequestHandle` class.
+struct JavaHttpRequestHandleClassDesc {
+ static constexpr JavaMethodSig kGetTotalSentReceivedBytes = {
+ "getTotalSentReceivedBytes", "()[B"};
+ static constexpr JavaMethodSig kClose = {"close", "()V"};
+ static constexpr JavaFieldSig kNativeHandle = {"nativeHandle", "J"};
+};
+
+JniHttpMethod ConvertHttpClientMethodToProtoMethod(
+ fcp::client::http::HttpRequest::Method method) {
+ switch (method) {
+ case fcp::client::http::HttpRequest::Method::kHead:
+ return JniHttpMethod::HTTP_METHOD_HEAD;
+ case fcp::client::http::HttpRequest::Method::kGet:
+ return JniHttpMethod::HTTP_METHOD_GET;
+ case fcp::client::http::HttpRequest::Method::kPost:
+ return JniHttpMethod::HTTP_METHOD_POST;
+ case fcp::client::http::HttpRequest::Method::kPut:
+ return JniHttpMethod::HTTP_METHOD_PUT;
+ case fcp::client::http::HttpRequest::Method::kPatch:
+ return JniHttpMethod::HTTP_METHOD_PATCH;
+ case fcp::client::http::HttpRequest::Method::kDelete:
+ return JniHttpMethod::HTTP_METHOD_DELETE;
+ default:
+ return JniHttpMethod::HTTP_METHOD_UNKNOWN;
+ }
+}
+
+// Calls JNIEnv::GetMethodID and ensures that its return value is valid.
+jmethodID GetMethodIdOrAbort(JNIEnv& env, jclass clazz, JavaMethodSig method) {
+ jmethodID id = env.GetMethodID(clazz, method.name, method.signature);
+ FCP_CHECK(id != nullptr);
+ return id;
+}
+} // namespace
+
+JavaHttpClient::JavaHttpClient(JavaVM* jvm, jobject java_http_client)
+ : jvm_(jvm) {
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+ jthis_ = env->NewGlobalRef(java_http_client);
+ FCP_CHECK(jthis_ != nullptr);
+ // We get the class from the jobject here instead of looking it up by name in
+ // the classloader because we may be using this class from a non java thread
+ // that has been attached to the jvm, and thus has a classloader with only
+ // "system" classes.
+ jclass java_http_client_class = env->GetObjectClass(java_http_client);
+ FCP_CHECK(java_http_client_class != nullptr);
+ LocalRefDeleter java_http_client_class_deleter(env, java_http_client_class);
+
+ // Look up the method IDs for the Java methods we'll call later on.
+ enqueue_request_id_ = GetMethodIdOrAbort(
+ *env, java_http_client_class, JavaHttpClientClassDesc::kEnqueueRequest);
+ perform_requests_id_ = GetMethodIdOrAbort(
+ *env, java_http_client_class, JavaHttpClientClassDesc::kPerformRequests);
+ close_id_ = GetMethodIdOrAbort(*env, java_http_client_class,
+ JavaHttpClientClassDesc::kClose);
+}
+
+JavaHttpClient::~JavaHttpClient() {
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+
+ // We call the Java close() method when the destructor is invoked. This gives
+ // the Java code a chance to clean things up on its side, if it needs to.
+ env->CallVoidMethod(jthis_, close_id_);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ // Delete the global reference to the Java object.
+ env->DeleteGlobalRef(jthis_);
+}
+
+std::unique_ptr<HttpRequestHandle> JavaHttpClient::EnqueueRequest(
+ std::unique_ptr<fcp::client::http::HttpRequest> request) {
+ // Convert the `HttpRequest`'s info into a proto we can serialize and pass
+ // over the JNI boundary.
+ JniHttpRequest request_proto;
+ request_proto.set_uri(std::string(request->uri()));
+ request_proto.set_method(
+ ConvertHttpClientMethodToProtoMethod(request->method()));
+ for (auto request_header : request->extra_headers()) {
+ JniHttpHeader* header = request_proto.add_extra_headers();
+ header->set_name(request_header.first);
+ header->set_value(request_header.second);
+ }
+ request_proto.set_has_body(request->HasBody());
+
+ // Call into Java to create the Java request handle object that will cooperate
+ // with our C++ `JavaHttpRequestHandle` object.
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+ jbyteArray serialized_request_proto =
+ SerializeProtoToJByteArray(env, request_proto);
+ LocalRefDeleter serialized_request_proto_deleter(env,
+ serialized_request_proto);
+ jobject java_http_request_handle = env->CallObjectMethod(
+ jthis_, enqueue_request_id_, serialized_request_proto);
+ FCP_CHECK(java_http_request_handle != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ // Create the C++ `JavaHttpRequestHandle` object (which will 'attach' itself
+ // to the Java object by writing its address to the `nativeHandle` Java
+ // field).
+ auto result = std::make_unique<JavaHttpRequestHandle>(
+ jvm_, java_http_request_handle, std::move(request));
+ return result;
+}
+
+absl::Status JavaHttpClient::PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>>
+ generic_requests) {
+ // We're about to kick off a group of requests. Each request has a matching
+ // callback, as well as a corresponding Java object. To prepare for the
+ // `performRequests` call into Java:
+ // 1. Create an Object[] array, consisting of each request's corresponding
+ // Java object.
+ // 2. For each request, register the callback that should be used for it.
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+
+ // The object array we'll create will just be of type Object[]. The Java
+ // code/JNI runtime will be in charge of downcasting the individual elements
+ // back to its concrete Java HttpRequestHandle implementation class. This
+ // avoids us having to try and look up the Java class (which can be difficult,
+ // since we may be running on a thread with a ClassLoader containing only Java
+ // system classes).
+ jclass object_class = env->FindClass("java/lang/Object");
+ FCP_CHECK(object_class != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+ LocalRefDeleter object_class_deleter(env, object_class);
+
+ // Create the Object[] array.
+ jobjectArray request_handle_array = env->NewObjectArray(
+ static_cast<jsize>(generic_requests.size()), object_class, nullptr);
+ FCP_CHECK(request_handle_array != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+ LocalRefDeleter request_handle_array_deleter(env, request_handle_array);
+
+ // Populate the Object[] array with the Java objects corresponding to each
+ // request, and register each callback with the `JavaHttpRequestHandle`.
+ int i = 0;
+ for (const auto& [generic_handle, callback] : generic_requests) {
+ auto request_handle = static_cast<JavaHttpRequestHandle*>(generic_handle);
+ env->SetObjectArrayElement(request_handle_array, i++,
+ request_handle->GetJobject());
+ FCP_CHECK(!env->ExceptionCheck());
+ FCP_RETURN_IF_ERROR(request_handle->SetCallback(callback));
+ }
+
+ // Call the Java `performRequests` method over JNI, passing it the Object[]
+ // array.
+ jbyteArray perform_requests_result =
+ static_cast<jbyteArray>(env->CallObjectMethod(
+ jthis_, perform_requests_id_, request_handle_array));
+ FCP_CHECK(!env->ExceptionCheck());
+ FCP_CHECK(perform_requests_result != nullptr);
+ LocalRefDeleter perform_requests_result_deleter(env, perform_requests_result);
+
+ // Convert the return value from Java to an absl::Status.
+ return ConvertRpcStatusToAbslStatus(
+ ParseProtoFromJByteArray<
+ ::google::internal::federatedcompute::v1::Status>(
+ env, perform_requests_result));
+}
+
+JavaHttpRequestHandle* JavaHttpRequestHandle::FromJlong(jlong ptr) {
+ // If the Java code erroneously calls a JNI callback with a handle that has
+ // already been destroyed, then `ptr` will be 0. We want to catch such bugs
+ // early.
+ FCP_CHECK(ptr != 0)
+ << "cannot call JNI callback before enqueueRequest has been called";
+ return reinterpret_cast<JavaHttpRequestHandle*>(ptr);
+}
+
+JavaHttpRequestHandle::JavaHttpRequestHandle(
+ JavaVM* jvm, jobject java_http_request_handle,
+ std::unique_ptr<fcp::client::http::HttpRequest> request)
+ : jvm_(jvm), request_(std::move(request)) {
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+ jthis_ = env->NewGlobalRef(java_http_request_handle);
+ FCP_CHECK(jthis_ != nullptr);
+
+ // We get the class from the jobject here instead of looking up by name in
+ // the classloader because we may be using this class from a non java thread
+ // that has been attached to the jvm, and thus has a classloader with only
+ // "system" classes.
+ jclass java_http_request_handle_class =
+ env->GetObjectClass(java_http_request_handle);
+ LocalRefDeleter java_http_request_handle_class_deleter(
+ env, java_http_request_handle_class);
+
+ get_total_sent_received_bytes_id_ = GetMethodIdOrAbort(
+ *env, java_http_request_handle_class,
+ JavaHttpRequestHandleClassDesc::kGetTotalSentReceivedBytes);
+
+ close_id_ = GetMethodIdOrAbort(*env, java_http_request_handle_class,
+ JavaHttpRequestHandleClassDesc::kClose);
+
+ native_handle_id_ =
+ env->GetFieldID(java_http_request_handle_class,
+ JavaHttpRequestHandleClassDesc::kNativeHandle.name,
+ JavaHttpRequestHandleClassDesc::kNativeHandle.signature);
+
+ // Register this object's address inside the `nativeHandle` field, so we can
+ // look this object up during later calls back into native.
+ env->SetLongField(jthis_, native_handle_id_, reinterpret_cast<jlong>(this));
+ FCP_CHECK(!env->ExceptionCheck());
+}
+
+JavaHttpRequestHandle::~JavaHttpRequestHandle() {
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+
+ absl::MutexLock locked(&lock_);
+ // We call the Java close() method when the destructor is invoked, to let the
+ // Java code know the request's resources (if any) can be released. The
+ // close() method may not have been invoked yet if the JavaHttpRequestHandle
+ // never ended up being passed to `performRequests`.
+ env->CallVoidMethod(jthis_, close_id_);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ // Unset the native handle (this is an additional safety check, so that if the
+ // Java object erroneously calls back into the native layer again, we will be
+ // able to detect it, rather than us accidentally accessing a destructed
+ // object).
+ env->SetLongField(jthis_, native_handle_id_, 0);
+
+ // Delete the reference to the Java object.
+ env->DeleteGlobalRef(jthis_);
+}
+
+void JavaHttpRequestHandle::Cancel() {
+ {
+ absl::MutexLock locked(&lock_);
+ // We mark the request 'performed'. This way if the handle is subsequently
+ // still erroneously passed to `PerformRequests`, we can detect the error.
+ performed_ = true;
+ }
+ // Note that we release the lock before calling into Java, to ensure that if
+ // the Java call itself calls back into the native layer (e.g. by calling on
+ // of the request handle callbacks), we don't accidentally try to acquire
+ // the same mutex twice.
+
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+
+ // We call the Java close() method to indicate that the request should be
+ // cancelled. If `PerformRequests` wasn't called yet, then this will be a
+ // no-op.
+ env->CallVoidMethod(jthis_, close_id_);
+ FCP_CHECK(!env->ExceptionCheck());
+}
+
+HttpRequestHandle::SentReceivedBytes
+JavaHttpRequestHandle::TotalSentReceivedBytes() const {
+ ScopedJniEnv scoped_env(jvm_);
+ JNIEnv* env = scoped_env.env();
+
+ jbyteArray sent_received_bytes_result = static_cast<jbyteArray>(
+ env->CallObjectMethod(jthis_, get_total_sent_received_bytes_id_));
+ FCP_CHECK(!env->ExceptionCheck());
+ FCP_CHECK(sent_received_bytes_result != nullptr);
+ LocalRefDeleter sent_received_bytes_result_deleter(
+ env, sent_received_bytes_result);
+
+ // Convert the return value from a Java byte[] to the expected proto.
+ auto sent_received_bytes = ParseProtoFromJByteArray<JniHttpSentReceivedBytes>(
+ env, sent_received_bytes_result);
+ return {.sent_bytes = sent_received_bytes.sent_bytes(),
+ .received_bytes = sent_received_bytes.received_bytes()};
+}
+
+fcp::client::http::HttpRequestCallback* JavaHttpRequestHandle::callback()
+ const {
+ // This method acquires and immediately releases the lock to ensure that all
+ // JNI callback invocations observe the effects of any prior JNI callback
+ // invocation, prior to invoking another `HttpRequestCallback` method.
+ //
+ // We don't hold the lock while invoking the actual `HttpRequestCallback`
+ // method though, to ensure that if the `HttpRequestCallback` invocation
+ // ultimately causes another JNI callback to be invoked, we don't attempt to
+ // acquire the lock twice on the same thread.
+ absl::MutexLock _(&lock_);
+ return callback_;
+}
+
+const JavaHttpResponse& JavaHttpRequestHandle::response() const {
+ // We synchronize for the same purpose as in callback().
+ absl::MutexLock _(&lock_);
+ return response_;
+}
+
+absl::Status JavaHttpRequestHandle::SetCallback(
+ fcp::client::http::HttpRequestCallback* callback) {
+ absl::MutexLock locked(&lock_);
+ // If the request was already 'performed' then we should detect that error.
+ if (performed_) {
+ return absl::InvalidArgumentError(
+ "can't perform a request twice, or perform an already-cancelled "
+ "request");
+ }
+ performed_ = true;
+ callback_ = callback;
+ return absl::OkStatus();
+}
+
+jboolean JavaHttpRequestHandle::ReadRequestBody(JNIEnv* env, jbyteArray buffer,
+ jlong requested_bytes,
+ jintArray actual_bytes_read) {
+ // Get a pointer to the output buffer's raw data. Note that this may make a
+ // copy of the Java data, but depending on JVM implementation it may also
+ // return a direct pointer to it, avoiding the copy (on Android, ART will
+ // generally avoid copying if the array is large enough).
+ jbyte* raw_buffer = env->GetByteArrayElements(buffer, nullptr);
+ FCP_CHECK(raw_buffer != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+ // Ask the `HttpRequest` to write the request body data into the buffer.
+ absl::StatusOr<int64_t> read_body_result =
+ request_->ReadBody(reinterpret_cast<char*>(raw_buffer), requested_bytes);
+ // Release the raw buffer pointer (we must always do this, even if we hit an
+ // error and didn't write anything to the buffer).
+ //
+ // This ensures that the data in raw_buffer is now visible via the Java buffer
+ // (as noted above, this may result in copying the data into the Java heap,
+ // but if a direct pointer was returned earlier on then this will be a no-op).
+ env->ReleaseByteArrayElements(buffer, raw_buffer, 0);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ // Out of range is expected, and marks the end of the body. Any other error is
+ // unrecoverable, and should result in the OnResponseError being called.
+ if (!read_body_result.ok() &&
+ read_body_result.status().code() != absl::StatusCode::kOutOfRange) {
+ // If we receive an error during the reading of the request body, we
+ // immediately forward that error to the HttpRequestCallback (i.e. the Java
+ // layer will not need to call this callback method anymore). This ensures
+ // that we can forward the original error back to the callback (without
+ // having to convert it to a Java representation and back to a Status
+ // again).
+ callback()->OnResponseError(
+ *request_,
+ absl::Status(read_body_result.status().code(),
+ absl::StrCat("failed to read request body",
+ read_body_result.status().message())));
+ return JNI_FALSE;
+ }
+
+ // Otherwise, if everything went successfully, then we still need to write the
+ // actual amount of data we read (or -1 if we hit the end of the data) to the
+ // `actual_bytes_read` output array (the output array provides a convenient
+ // way to return something in addition to the return value, while still using
+ // only primitive Java types to keep the JNI boilerplate to a minimum).
+ //
+ // Note: we know that casting from int64_t to jint (aka a 32 bit int) should
+ // be safe, since `requested_bytes` is a jint, and the actual bytes read can
+ // never be larger than that number.
+ jint actual_bytes_read_result[] = {
+ static_cast<jint>(read_body_result.value_or(-1))};
+ env->SetIntArrayRegion(actual_bytes_read, 0, 1, actual_bytes_read_result);
+ return JNI_TRUE;
+}
+
+jboolean JavaHttpRequestHandle::OnResponseStarted(JNIEnv* env,
+ jbyteArray response_proto) {
+ // Populate the response_ field based on the serialized response proto. This
+ // will allow us to access it in subsequent callbacks as well.
+ {
+ absl::MutexLock _(&lock_);
+ response_.PopulateFromProto(
+ ParseProtoFromJByteArray<JniHttpResponse>(env, response_proto));
+ }
+ return callback()->OnResponseStarted(*request_, response()).ok() ? JNI_TRUE
+ : JNI_FALSE;
+}
+
+void JavaHttpRequestHandle::OnResponseError(JNIEnv* env,
+ jbyteArray status_proto) {
+ absl::Status status = ConvertRpcStatusToAbslStatus(
+ ParseProtoFromJByteArray<
+ ::google::internal::federatedcompute::v1::Status>(env, status_proto));
+ callback()->OnResponseError(*request_, status);
+}
+
+jboolean JavaHttpRequestHandle::OnResponseBody(JNIEnv* env, jbyteArray buffer,
+ jint bytes_available) {
+ // Get a pointer to the input buffer's raw data. Note that this may make a
+ // copy of the Java data, but depending on JVM implementation it may also
+ // return a direct pointer to it, avoiding the copy (on Android, ART will
+ // generally avoid copying if the array is large enough).
+ jbyte* raw_buffer = env->GetByteArrayElements(buffer, nullptr);
+ FCP_CHECK(raw_buffer != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+ absl::string_view buffer_view(reinterpret_cast<char*>(raw_buffer),
+ bytes_available);
+ // Pass the response body data to the HttpRequestCallback.
+ auto result = callback()->OnResponseBody(*request_, response(), buffer_view);
+
+ // JNI_ABORT ensures that we don't copy the bytes in the raw buffer back to
+ // the main buffer (since we know they weren't modified).
+ env->ReleaseByteArrayElements(buffer, raw_buffer, JNI_ABORT);
+
+ return result.ok() ? JNI_TRUE : JNI_FALSE;
+}
+
+void JavaHttpRequestHandle::OnResponseBodyError(JNIEnv* env,
+ jbyteArray status_proto) {
+ absl::Status status = ConvertRpcStatusToAbslStatus(
+ ParseProtoFromJByteArray<
+ ::google::internal::federatedcompute::v1::Status>(env, status_proto));
+ callback()->OnResponseBodyError(*request_, response(), status);
+}
+
+void JavaHttpRequestHandle::OnResponseCompleted() {
+ callback()->OnResponseCompleted(*request_, response());
+}
+
+// JNI functions. These are called from Java. We just forward them to the
+// appropriate JavaHttpRequestHandle instance's member function.
+#define JFUN(METHOD_NAME) \
+ Java_com_google_fcp_client_http_HttpClientForNative_##METHOD_NAME // NOLINT
+
+extern "C" JNIEXPORT jboolean JNICALL JFUN(readRequestBody)(
+ JNIEnv* env, jclass, jlong request_handle_ptr, jbyteArray buffer,
+ jlong requested_bytes, jintArray actual_bytes_read) {
+ return JavaHttpRequestHandle::FromJlong(request_handle_ptr)
+ ->ReadRequestBody(env, buffer, requested_bytes, actual_bytes_read);
+}
+
+extern "C" JNIEXPORT jboolean JNICALL JFUN(onResponseStarted)(
+ JNIEnv* env, jclass, jlong request_handle_ptr, jbyteArray response_proto) {
+ return JavaHttpRequestHandle::FromJlong(request_handle_ptr)
+ ->OnResponseStarted(env, response_proto);
+}
+
+extern "C" JNIEXPORT void JNICALL JFUN(onResponseError)(
+ JNIEnv* env, jclass, jlong request_handle_ptr, jbyteArray status_proto) {
+ return JavaHttpRequestHandle::FromJlong(request_handle_ptr)
+ ->OnResponseError(env, status_proto);
+}
+
+extern "C" JNIEXPORT jboolean JNICALL
+JFUN(onResponseBody)(JNIEnv* env, jclass, jlong request_handle_ptr,
+ jbyteArray buffer, jint bytes_available) {
+ return JavaHttpRequestHandle::FromJlong(request_handle_ptr)
+ ->OnResponseBody(env, buffer, bytes_available);
+}
+
+extern "C" JNIEXPORT void JNICALL JFUN(onResponseBodyError)(
+ JNIEnv* env, jclass, jlong request_handle_ptr, jbyteArray status_proto) {
+ JavaHttpRequestHandle::FromJlong(request_handle_ptr)
+ ->OnResponseBodyError(env, status_proto);
+}
+
+extern "C" JNIEXPORT void JNICALL
+JFUN(onResponseCompleted)(JNIEnv* env, jclass, jlong request_handle_ptr) {
+ JavaHttpRequestHandle::FromJlong(request_handle_ptr)->OnResponseCompleted();
+}
+
+} // namespace java
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/java/java_http_client.h b/fcp/client/http/java/java_http_client.h
new file mode 100644
index 0000000..3187ecd
--- /dev/null
+++ b/fcp/client/http/java/java_http_client.h
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_JAVA_JAVA_HTTP_CLIENT_H_
+#define FCP_CLIENT_HTTP_JAVA_JAVA_HTTP_CLIENT_H_
+
+#include <jni.h>
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/java/jni.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace java {
+
+// An `HttpClient` implementation that performs HTTP requests by calling back
+// into Java over JNI, and lets the Java code issue the actual requests.
+class JavaHttpClient : public HttpClient {
+ public:
+ JavaHttpClient(JavaVM* jvm, jobject java_http_client);
+ ~JavaHttpClient() override;
+
+ ABSL_MUST_USE_RESULT
+ std::unique_ptr<fcp::client::http::HttpRequestHandle> EnqueueRequest(
+ std::unique_ptr<fcp::client::http::HttpRequest> request) override;
+
+ absl::Status PerformRequests(
+ std::vector<std::pair<fcp::client::http::HttpRequestHandle*,
+ fcp::client::http::HttpRequestCallback*>>
+ requests) override;
+
+ private:
+ JavaVM* const jvm_;
+ jobject jthis_;
+ jmethodID enqueue_request_id_;
+ jmethodID perform_requests_id_;
+ jmethodID close_id_;
+};
+
+// An HttpResponse implementation that is based on data provided via the JNI
+// callbacks.
+class JavaHttpResponse : public fcp::client::http::HttpResponse {
+ public:
+ JavaHttpResponse() {}
+
+ // Populates the response data based on the given JNI proto.
+ void PopulateFromProto(const JniHttpResponse& response_proto) {
+ code_ = response_proto.code();
+ headers_.clear();
+ for (const JniHttpHeader& header : response_proto.headers()) {
+ headers_.push_back(Header(header.name(), header.value()));
+ }
+ }
+
+ int code() const override { return code_; }
+ const HeaderList& headers() const override { return headers_; }
+
+ private:
+ int code_ = -1;
+ HeaderList headers_;
+};
+
+// An `HttpRequestHandle` implementation that performs HTTP requests by calling
+// back into Java over JNI, and lets the Java code issue the actual requests.
+class JavaHttpRequestHandle : public fcp::client::http::HttpRequestHandle {
+ public:
+ // Utility for extracting a JavaHttpRequestHandle pointer from a Java jlong.
+ // When a JavaHttpRequestHandle object is constructed, it will 'attach' itself
+ // to the corresponding Java object by writing its address to its
+ // `nativeHandle` field. When Java then calls back into our C++ code, it
+ // passes that address back to us, and this function can then be used to turn
+ // the address into a proper pointer.
+ static JavaHttpRequestHandle* FromJlong(jlong ptr);
+
+ JavaHttpRequestHandle(
+ JavaVM* jvm, jobject java_http_request_handle,
+ std::unique_ptr<fcp::client::http::HttpRequest> request);
+
+ ~JavaHttpRequestHandle() override ABSL_LOCKS_EXCLUDED(lock_);
+
+ // --- HttpRequestHandle methods
+ fcp::client::http::HttpRequestHandle::SentReceivedBytes
+ TotalSentReceivedBytes() const override;
+ void Cancel() override ABSL_LOCKS_EXCLUDED(lock_);
+
+ // --- JNI handling methods
+ jboolean ReadRequestBody(JNIEnv* env, jbyteArray buffer,
+ jlong requested_bytes, jintArray actual_bytes_read)
+ ABSL_LOCKS_EXCLUDED(lock_);
+
+ jboolean OnResponseStarted(JNIEnv* env, jbyteArray response_proto)
+ ABSL_LOCKS_EXCLUDED(lock_);
+ void OnResponseError(JNIEnv* env, jbyteArray status_proto)
+ ABSL_LOCKS_EXCLUDED(lock_);
+
+ jboolean OnResponseBody(JNIEnv* env, jbyteArray buffer, jint bytes_available)
+ ABSL_LOCKS_EXCLUDED(lock_);
+ void OnResponseBodyError(JNIEnv* env, jbyteArray status_proto)
+ ABSL_LOCKS_EXCLUDED(lock_);
+ void OnResponseCompleted() ABSL_LOCKS_EXCLUDED(lock_);
+
+ // --- Internal methods.
+ // Returns the (possibly empty/default-constructed) response object for this
+ // handle. This object is populated with actual data after `OnResponseStarted`
+ // is called.
+ const JavaHttpResponse& response() const ABSL_LOCKS_EXCLUDED(lock_);
+
+ // Returns the `HttpRequestCallback` associated with this handle, or a nullptr
+ // if no callback is associated with it yet (i.e. no `PerformRequests` call
+ // was made yet with this handle).
+ fcp::client::http::HttpRequestCallback* callback() const
+ ABSL_LOCKS_EXCLUDED(lock_);
+
+ // Associates an `HttpRequestCallback` with this handle, or returns an
+ // `INVALID_ARGUMENT` error if one was already associated with it (indicating
+ // the handle is being used with more than one `PerformRequests` call, which
+ // is not allowed).
+ absl::Status SetCallback(fcp::client::http::HttpRequestCallback* callback)
+ ABSL_LOCKS_EXCLUDED(lock_);
+
+ // Returns the JNI reference for the Java handle object that corresponds to
+ // this C++ handle object.
+ jobject GetJobject() { return jthis_; }
+
+ private:
+ JavaVM* const jvm_;
+ jobject jthis_;
+ jmethodID get_total_sent_received_bytes_id_;
+ jmethodID close_id_;
+ jfieldID native_handle_id_;
+
+ const std::unique_ptr<fcp::client::http::HttpRequest> request_;
+
+ // A note on synchronization: most of JavaHttpRequestHandle's methods may be
+ // called from any thread (incl. the JNI callback methods, which are
+ // guaranteed to never be called concurrently from more than one thread, but
+ // which for which subsequent calls may occur on different threads). This
+ // object also has mutable state (e.g. `response_` which only gets populated
+ // with data once the `OnResponseStarted` callback is invoked).
+ //
+ // We use this mutex to ensure that each JNI callback invocation observes the
+ // effects of every prior callback invocation. This means we don't have to
+ // rely on the Java side of the implementation for thread safety (even though
+ // the Java side likely also implements its own synchronization as well).
+ mutable absl::Mutex lock_;
+ JavaHttpResponse response_ ABSL_GUARDED_BY(lock_);
+ fcp::client::http::HttpRequestCallback* callback_ ABSL_GUARDED_BY(lock_) =
+ nullptr;
+ bool performed_ ABSL_GUARDED_BY(lock_) = false;
+};
+
+} // namespace java
+} // namespace http
+} // namespace client
+} // namespace fcp
+#endif // FCP_CLIENT_HTTP_JAVA_JAVA_HTTP_CLIENT_H_
diff --git a/fcp/client/http/java/jni.proto b/fcp/client/http/java/jni.proto
new file mode 100644
index 0000000..39bedd8
--- /dev/null
+++ b/fcp/client/http/java/jni.proto
@@ -0,0 +1,62 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client.http.java;
+
+option java_package = "com.google.fcp.client.http";
+option java_multiple_files = true;
+option java_outer_classname = "JniProto";
+
+// This file defines protos that are used to serialize data across the JNI
+// boundary, for use by `fcp::client::http::java::JavaHttpClient`.
+
+// Represents a serialized `fcp::client::http::HttpRequest` object.
+message JniHttpRequest {
+ string uri = 1;
+ JniHttpMethod method = 2;
+ repeated JniHttpHeader extra_headers = 3;
+ bool has_body = 4;
+}
+
+// Represents a serialized `fcp::client::http::Header` object.
+message JniHttpHeader {
+ string name = 1;
+ string value = 2;
+}
+
+// Represents a serialized `fcp::client::http::HttpRequest::Method` enum.
+enum JniHttpMethod {
+ HTTP_METHOD_UNKNOWN = 0;
+ HTTP_METHOD_HEAD = 1;
+ HTTP_METHOD_GET = 2;
+ HTTP_METHOD_POST = 3;
+ HTTP_METHOD_PUT = 4;
+ HTTP_METHOD_PATCH = 5;
+ HTTP_METHOD_DELETE = 6;
+}
+
+// Represents a serialized `fcp::client::http::HttpResponse` object.
+message JniHttpResponse {
+ int32 code = 1;
+ repeated JniHttpHeader headers = 2;
+}
+
+// Represents a serialized
+// `fcp::client::http::HttpRequestHandle::SentReceivesBytes` object.
+message JniHttpSentReceivedBytes {
+ int64 sent_bytes = 1;
+ int64 received_bytes = 2;
+}
diff --git a/fcp/client/http/protocol_request_helper.cc b/fcp/client/http/protocol_request_helper.cc
new file mode 100644
index 0000000..142c406
--- /dev/null
+++ b/fcp/client/http/protocol_request_helper.cc
@@ -0,0 +1,377 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/protocol_request_helper.h"
+
+#include "absl/strings/substitute.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/http/http_client_util.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/protos/federatedcompute/task_assignments.pb.h"
+#include "google/protobuf/any.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// The default interval when polling pending operations.
+const absl::Duration kDefaultLroPollingInterval = absl::Milliseconds(500);
+// The maximum interval when polling pending operations.
+const absl::Duration kMaxLroPollingInterval = absl::Minutes(1);
+
+constexpr absl::string_view kStartTaskAssignmentMetadata =
+ "type.googleapis.com/"
+ "google.internal.federatedcompute.v1.StartTaskAssignmentMetadata"; // NOLINT
+constexpr absl::string_view kAdvertiseKeysMetadata =
+ "type.googleapis.com/"
+ "google.internal.federatedcompute.v1.AdvertiseKeysMetadata"; // NOLINT
+constexpr absl::string_view kShareKeysMetadata =
+ "type.googleapis.com/google.internal.federatedcompute.v1.ShareKeysMetadata";
+constexpr absl::string_view kSubmitSecureAggregationResultMetadata =
+ "type.googleapis.com/"
+ "google.internal.federatedcompute.v1."
+ "SubmitSecureAggregationResultMetadata"; // NOLINT
+
+using ::google::internal::federatedcompute::v1::AdvertiseKeysMetadata;
+using ::google::internal::federatedcompute::v1::ForwardingInfo;
+using ::google::internal::federatedcompute::v1::ShareKeysMetadata;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentMetadata;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultMetadata;
+// using ::google::longrunning::Operation;
+
+namespace {
+// A note on error handling:
+//
+// The implementation here makes a distinction between what we call 'transient'
+// and 'permanent' errors. While the exact categorization of transient vs.
+// permanent errors is defined by a flag, the intent is that transient errors
+// are those types of errors that may occur in the regular course of business,
+// e.g. due to an interrupted network connection, a load balancer temporarily
+// rejecting our request etc. Generally, these are expected to be resolvable by
+// merely retrying the request at a slightly later time. Permanent errors are
+// intended to be those that are not expected to be resolvable as quickly or by
+// merely retrying the request. E.g. if a client checks in to the server with a
+// population name that doesn't exist, then the server may return NOT_FOUND, and
+// until the server-side configuration is changed, it will continue returning
+// such an error. Hence, such errors can warrant a longer retry period (to waste
+// less of both the client's and server's resources).
+//
+// The errors also differ in how they interact with the server-specified retry
+// windows that are returned via the EligbilityEvalTaskResponse message.
+// - If a permanent error occurs, then we will always return a retry window
+// based on the target 'permanent errors retry period' flag, regardless of
+// whether we received an EligbilityEvalTaskResponse from the server at an
+// earlier time.
+// - If a transient error occurs, then we will only return a retry window
+// based on the target 'transient errors retry period' flag if the server
+// didn't already return an EligibilityEvalTaskResponse. If it did return such
+// a response, then one of the retry windows in that message will be used
+// instead.
+//
+// Finally, note that for simplicity's sake we generally check whether a
+// permanent error was received at the level of this class's public methods,
+// rather than deeper down in each of our helper methods that actually call
+// directly into the HTTP stack. This keeps our state-managing code simpler, but
+// does mean that if any of our helper methods like
+// PerformEligibilityEvalTaskRequest produce a permanent error code locally
+// (i.e. without it being sent by the server), it will be treated as if the
+// server sent it and the permanent error retry period will be used. We consider
+// this a reasonable tradeoff.
+
+std::string CreateUriSuffixFromPathAndParams(absl::string_view path,
+ const QueryParams& params) {
+ return absl::StrCat(path, "?",
+ absl::StrJoin(params.begin(), params.end(), "&",
+ absl::PairFormatter("=")));
+}
+
+// Creates the URI suffix for a GetOperation protocol request.
+absl::StatusOr<std::string> CreateGetOperationUriSuffix(
+ absl::string_view operation_name) {
+ constexpr absl::string_view kGetOperationUriSuffix = "/v1/$0";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_operation_name,
+ EncodeUriMultiplePathSegments(operation_name));
+ return absl::Substitute(kGetOperationUriSuffix, encoded_operation_name);
+}
+
+// Creates the URI suffix for a CancelOperation protocol request.
+absl::StatusOr<std::string> CreateCancelOperationUriSuffix(
+ absl::string_view operation_name) {
+ constexpr absl::string_view kCancelOperationUriSuffix = "/v1/$0:cancel";
+ FCP_ASSIGN_OR_RETURN(std::string encoded_operation_name,
+ EncodeUriMultiplePathSegments(operation_name));
+ return absl::Substitute(kCancelOperationUriSuffix, encoded_operation_name);
+}
+
+absl::StatusOr<InMemoryHttpResponse> CheckResponseContentEncoding(
+ absl::StatusOr<InMemoryHttpResponse> response) {
+ if (response.ok() && !response->content_encoding.empty()) {
+ // Note that the `HttpClient` API contract ensures that if we don't specify
+ // an Accept-Encoding request header, then the response should be delivered
+ // to us without any Content-Encoding applied to it. Hence, if we somehow do
+ // still see a Content-Encoding response header then the `HttpClient`
+ // implementation isn't adhering to its part of the API contract.
+ return absl::UnavailableError(
+ "HTTP response unexpectedly has a Content-Encoding");
+ }
+ return response;
+}
+
+// Extract polling interval from the operation proto.
+// The returned polling interval will be within the range of [1ms, 1min]. If
+// the polling interval inside the operation proto is outside this range, it'll
+// be clipped to the nearest boundary. If the polling interval is unset, 1ms
+// will be returned.
+// absl::Duration GetPollingInterval(Operation operation) {
+// absl::string_view type_url = operation.metadata().type_url();
+// google::protobuf::Duration polling_interval_proto;
+// if (type_url == kStartTaskAssignmentMetadata) {
+// StartTaskAssignmentMetadata metadata;
+// if (!operation.metadata().UnpackTo(&metadata)) {
+// return kDefaultLroPollingInterval;
+// }
+// polling_interval_proto = metadata.polling_interval();
+// } else if (type_url == kAdvertiseKeysMetadata) {
+// AdvertiseKeysMetadata metadata;
+// if (!operation.metadata().UnpackTo(&metadata)) {
+// return kDefaultLroPollingInterval;
+// }
+// polling_interval_proto = metadata.polling_interval();
+// } else if (type_url == kShareKeysMetadata) {
+// ShareKeysMetadata metadata;
+// if (!operation.metadata().UnpackTo(&metadata)) {
+// return kDefaultLroPollingInterval;
+// }
+// polling_interval_proto = metadata.polling_interval();
+// } else if (type_url == kSubmitSecureAggregationResultMetadata) {
+// SubmitSecureAggregationResultMetadata metadata;
+// if (!operation.metadata().UnpackTo(&metadata)) {
+// return kDefaultLroPollingInterval;
+// }
+// polling_interval_proto = metadata.polling_interval();
+// } else {
+// // Unknown type
+// return kDefaultLroPollingInterval;
+// }
+
+// absl::Duration polling_interval =
+// TimeUtil::ConvertProtoToAbslDuration(polling_interval_proto);
+// if (polling_interval < absl::ZeroDuration()) {
+// return kDefaultLroPollingInterval;
+// } else if (polling_interval > kMaxLroPollingInterval) {
+// return kMaxLroPollingInterval;
+// } else {
+// return polling_interval;
+// }
+// }
+
+} // anonymous namespace
+
+ProtocolRequestCreator::ProtocolRequestCreator(
+ absl::string_view request_base_uri, absl::string_view api_key,
+ HeaderList request_headers, bool use_compression)
+ : next_request_base_uri_(request_base_uri),
+ api_key_(api_key),
+ next_request_headers_(std::move(request_headers)),
+ use_compression_(use_compression) {}
+
+absl::StatusOr<std::unique_ptr<HttpRequest>>
+ProtocolRequestCreator::CreateProtocolRequest(absl::string_view uri_path_suffix,
+ QueryParams params,
+ HttpRequest::Method method,
+ std::string request_body,
+ bool is_protobuf_encoded) const {
+ return CreateHttpRequest(uri_path_suffix, std::move(params), method,
+ std::move(request_body), is_protobuf_encoded,
+ use_compression_);
+}
+
+absl::StatusOr<std::unique_ptr<HttpRequest>>
+ProtocolRequestCreator::CreateGetOperationRequest(
+ absl::string_view operation_name) const {
+ FCP_ASSIGN_OR_RETURN(std::string uri_path_suffix,
+ CreateGetOperationUriSuffix(operation_name));
+ return CreateHttpRequest(uri_path_suffix, {}, HttpRequest::Method::kGet, "",
+ /*is_protobuf_encoded=*/true,
+ /*use_compression=*/false);
+}
+
+absl::StatusOr<std::unique_ptr<HttpRequest>>
+ProtocolRequestCreator::CreateCancelOperationRequest(
+ absl::string_view operation_name) const {
+ FCP_ASSIGN_OR_RETURN(std::string uri_path_suffix,
+ CreateCancelOperationUriSuffix(operation_name));
+ return CreateHttpRequest(uri_path_suffix, {}, HttpRequest::Method::kGet, "",
+ /*is_protobuf_encoded=*/true,
+ /*use_compression=*/false);
+}
+
+absl::StatusOr<std::unique_ptr<HttpRequest>>
+ProtocolRequestCreator::CreateHttpRequest(absl::string_view uri_path_suffix,
+ QueryParams params,
+ HttpRequest::Method method,
+ std::string request_body,
+ bool is_protobuf_encoded,
+ bool use_compression) const {
+ HeaderList request_headers = next_request_headers_;
+ request_headers.push_back({kApiKeyHdr, api_key_});
+ if (is_protobuf_encoded) {
+ if (!request_body.empty()) {
+ request_headers.push_back({kContentTypeHdr, kProtobufContentType});
+ }
+
+ // %24alt is the percent encoded $alt. "$" is prepended to alt to indicate
+ // that "alt" is a system parameter.
+ // https://cloud.google.com/apis/docs/system-parameters#http_mapping
+ params["%24alt"] = "proto";
+ }
+ std::string uri_with_params = std::string(uri_path_suffix);
+ if (!params.empty()) {
+ uri_with_params = CreateUriSuffixFromPathAndParams(uri_path_suffix, params);
+ }
+ FCP_ASSIGN_OR_RETURN(
+ std::string uri,
+ JoinBaseUriWithSuffix(next_request_base_uri_, uri_with_params));
+
+ return InMemoryHttpRequest::Create(uri, method, request_headers,
+ std::move(request_body), use_compression);
+}
+
+absl::StatusOr<std::unique_ptr<ProtocolRequestCreator>>
+ProtocolRequestCreator::Create(absl::string_view api_key,
+ const ForwardingInfo& forwarding_info,
+ bool use_compression) {
+ // Extract the base URI and headers to use for the subsequent request.
+ if (forwarding_info.target_uri_prefix().empty()) {
+ return absl::InvalidArgumentError(
+ "Missing `ForwardingInfo.target_uri_prefix`");
+ }
+ const auto& new_headers = forwarding_info.extra_request_headers();
+ return std::make_unique<ProtocolRequestCreator>(ProtocolRequestCreator(
+ forwarding_info.target_uri_prefix(), api_key,
+ HeaderList(new_headers.begin(), new_headers.end()), use_compression));
+}
+
+ProtocolRequestHelper::ProtocolRequestHelper(
+ HttpClient* http_client, int64_t* bytes_downloaded, int64_t* bytes_uploaded,
+ WallClockStopwatch* network_stopwatch, Clock* clock)
+ : http_client_(*http_client),
+ bytes_downloaded_(*bytes_downloaded),
+ bytes_uploaded_(*bytes_uploaded),
+ network_stopwatch_(*network_stopwatch),
+ clock_(*clock) {}
+
+absl::StatusOr<InMemoryHttpResponse>
+ProtocolRequestHelper::PerformProtocolRequest(
+ std::unique_ptr<HttpRequest> request, InterruptibleRunner& runner) {
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(std::move(request));
+ FCP_ASSIGN_OR_RETURN(
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> response,
+ PerformMultipleProtocolRequests(std::move(requests), runner));
+ return std::move(response[0]);
+}
+
+absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+ProtocolRequestHelper::PerformMultipleProtocolRequests(
+ std::vector<std::unique_ptr<http::HttpRequest>> requests,
+ InterruptibleRunner& runner) {
+ // Check whether issuing the request failed as a whole (generally indicating
+ // a programming error).
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> responses;
+ {
+ auto started_stopwatch = network_stopwatch_.Start();
+ FCP_ASSIGN_OR_RETURN(responses,
+ PerformMultipleRequestsInMemory(
+ http_client_, runner, std::move(requests),
+ &bytes_downloaded_, &bytes_uploaded_));
+ }
+ std::vector<absl::StatusOr<InMemoryHttpResponse>> results;
+ std::transform(responses.begin(), responses.end(),
+ std::back_inserter(results), CheckResponseContentEncoding);
+ return results;
+}
+
+// absl::StatusOr<::google::longrunning::Operation>
+// ProtocolRequestHelper::PollOperationResponseUntilDone(
+// const Operation& initial_operation,
+// const ProtocolRequestCreator& request_creator,
+// InterruptibleRunner& runner) {
+// // There are three cases that lead to this method returning:
+// // - The HTTP response indicates an error.
+// // - The HTTP response cannot be parsed into an Operation proto.
+// // - The response `Operation.done` field is true.
+// //
+// // In all other cases we continue to poll the Operation via a subsequent
+// // GetOperationRequest.
+// Operation response_operation_proto = initial_operation;
+// while (true) {
+// // If the Operation is done then return it.
+// if (response_operation_proto.done()) {
+// return std::move(response_operation_proto);
+// }
+
+// FCP_ASSIGN_OR_RETURN(std::string operation_name,
+// ExtractOperationName(response_operation_proto));
+
+// // Wait for server returned polling interval before sending next request.
+// clock_.Sleep(GetPollingInterval(response_operation_proto));
+// // The response Operation indicates that the result isn't ready yet. Poll
+// // again.
+// FCP_ASSIGN_OR_RETURN(
+// std::unique_ptr<HttpRequest> get_operation_request,
+// request_creator.CreateGetOperationRequest(operation_name));
+// absl::StatusOr<InMemoryHttpResponse> http_response =
+// PerformProtocolRequest(std::move(get_operation_request), runner);
+// FCP_ASSIGN_OR_RETURN(response_operation_proto,
+// ParseOperationProtoFromHttpResponse(http_response));
+// }
+// }
+
+// absl::StatusOr<InMemoryHttpResponse> ProtocolRequestHelper::CancelOperation(
+// absl::string_view operation_name,
+// const ProtocolRequestCreator& request_creator,
+// InterruptibleRunner& runner) {
+// FCP_ASSIGN_OR_RETURN(
+// std::unique_ptr<HttpRequest> cancel_operation_request,
+// request_creator.CreateCancelOperationRequest(operation_name));
+// return PerformProtocolRequest(std::move(cancel_operation_request), runner);
+// }
+
+// absl::StatusOr<Operation> ParseOperationProtoFromHttpResponse(
+// absl::StatusOr<InMemoryHttpResponse> http_response) {
+// // If the HTTP response indicates an error then return that error.
+// FCP_RETURN_IF_ERROR(http_response);
+// Operation response_operation_proto;
+// // Parse the response.
+// if (!response_operation_proto.ParseFromString(
+// std::string(http_response->body))) {
+// return absl::InvalidArgumentError("could not parse Operation proto");
+// }
+// return response_operation_proto;
+// }
+
+// absl::StatusOr<std::string> ExtractOperationName(const Operation& operation)
+// {
+// if (!absl::StartsWith(operation.name(), "operations/")) {
+// return absl::InvalidArgumentError(
+// "Cannot cancel an Operation with an invalid name");
+// }
+// return operation.name();
+// }
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/protocol_request_helper.h b/fcp/client/http/protocol_request_helper.h
new file mode 100644
index 0000000..8dec3e6
--- /dev/null
+++ b/fcp/client/http/protocol_request_helper.h
@@ -0,0 +1,170 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_PROTOCOL_REQUEST_HELPER_H_
+#define FCP_CLIENT_HTTP_PROTOCOL_REQUEST_HELPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/wall_clock_stopwatch.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/in_memory_request_response.h"
+#include "fcp/protos/federatedcompute/common.pb.h"
+// #include "google/longrunning/operations.pb.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// Note the uri query parameters should be percent encoded.
+using QueryParams = absl::flat_hash_map<std::string, std::string>;
+
+// A helper for creating HTTP request with base uri, request headers and
+// compression setting.
+class ProtocolRequestCreator {
+ public:
+ ProtocolRequestCreator(absl::string_view request_base_uri,
+ absl::string_view api_key, HeaderList request_headers,
+ bool use_compression);
+
+ // Creates a `ProtocolRequestCreator` based on the forwarding info.
+ // Validates and extracts the base URI and headers to use for the subsequent
+ // request(s).
+ static absl::StatusOr<std::unique_ptr<ProtocolRequestCreator>> Create(
+ absl::string_view api_key,
+ const ::google::internal::federatedcompute::v1::ForwardingInfo&
+ forwarding_info,
+ bool use_compression);
+
+ // Creates an `HttpRequest` with base uri, request headers and compression
+ // setting. The `uri_path_suffix` argument must always either be empty or
+ // start with a leading '/'. The method will return `InvalidArgumentError` if
+ // this isn't the case. The `uri_path_suffix` should not contain any query
+ // parameters, instead, query parameters should be specified in `params`.
+ //
+ // The URI to which the protocol request will be sent will be constructed by
+ // joining `next_request_base_uri_` with `uri_path_suffix` (see
+ // `JoinBaseUriWithSuffix` for details), and any query parameters if `params`
+ // is not empty.
+ //
+ // When `is_protobuf_encoded` is true, `%24alt=proto` will be added to the uri
+ // as a query parameter to indicate that the proto encoded payload is
+ // expected. When the `request_body` is not empty, a `Content-Type` header
+ // will also be added to the request
+ absl::StatusOr<std::unique_ptr<HttpRequest>> CreateProtocolRequest(
+ absl::string_view uri_path_suffix, QueryParams params,
+ HttpRequest::Method method, std::string request_body,
+ bool is_protobuf_encoded) const;
+
+ // Creates an `HttpRequest` for getting the result of a
+ // `google.longrunning.operation`. Note that the request body is empty,
+ // because its only field (`name`) is included in the URI instead. Also note
+ // that the `next_request_headers_` will be attached to this request.
+ absl::StatusOr<std::unique_ptr<HttpRequest>> CreateGetOperationRequest(
+ absl::string_view operation_name) const;
+
+ // Creates an `HttpRequest` for canceling a `google.longrunning.operation`.
+ // Note that the request body is empty, because its only field (`name`) is
+ // included in the URI instead. Also note that the `next_request_headers_`
+ // will be attached to this request.
+ absl::StatusOr<std::unique_ptr<HttpRequest>> CreateCancelOperationRequest(
+ absl::string_view operation_name) const;
+
+ private:
+ absl::StatusOr<std::unique_ptr<HttpRequest>> CreateHttpRequest(
+ absl::string_view uri_path_suffix, QueryParams params,
+ HttpRequest::Method method, std::string request_body,
+ bool is_protobuf_encoded, bool use_compression) const;
+ // The URI to use for the next protocol request. See `ForwardingInfo`.
+ std::string next_request_base_uri_;
+ // The API key used for requests.
+ const std::string api_key_;
+ // The set of headers to attach to the next protocol request. See
+ // `ForwardingInfo`.
+ HeaderList next_request_headers_;
+ const bool use_compression_;
+};
+
+// A helper for issuing protocol requests.
+class ProtocolRequestHelper {
+ public:
+ ProtocolRequestHelper(HttpClient* http_client, int64_t* bytes_downloaded,
+ int64_t* bytes_uploaded,
+ WallClockStopwatch* network_stopwatch, Clock* clock);
+
+ // Performs the given request (handling any interruptions that may occur) and
+ // updates the network stats.
+ absl::StatusOr<InMemoryHttpResponse> PerformProtocolRequest(
+ std::unique_ptr<HttpRequest> request, InterruptibleRunner& runner);
+
+ // Performs the vector of requests (handling any interruptions that may occur)
+ // concurrently and updates the network stats.
+ // The returned vector of responses has the same order of the issued requests.
+ absl::StatusOr<std::vector<absl::StatusOr<InMemoryHttpResponse>>>
+ PerformMultipleProtocolRequests(
+ std::vector<std::unique_ptr<HttpRequest>> requests,
+ InterruptibleRunner& runner);
+
+ // Helper function for handling an HTTP response that contains an `Operation`
+ // proto.
+ //
+ // Takes an HTTP response (which must have been produced by a call to
+ // `PerformRequestInMemory`), parses the proto, and returns it if its
+ // `Operation.done` field is true. If the field is false then this method
+ // keeps polling the Operation via performing requests created by
+ // `CreateGetOperationRequest` until it a response is received where the field
+ // is true, at which point that most recent response is returned. If at any
+ // point an HTTP or response parsing error is encountered, then that error is
+ // returned instead.
+ // absl::StatusOr<::google::longrunning::Operation>
+ // PollOperationResponseUntilDone(
+ // const ::google::longrunning::Operation& initial_operation,
+ // const ProtocolRequestCreator& request_creator,
+ // InterruptibleRunner& runner);
+
+ // // Helper function for cancelling an operation.
+ // absl::StatusOr<InMemoryHttpResponse> CancelOperation(
+ // absl::string_view operation_name,
+ // const ProtocolRequestCreator& request_creator,
+ // InterruptibleRunner& runner);
+
+ private:
+ HttpClient& http_client_;
+ int64_t& bytes_downloaded_;
+ int64_t& bytes_uploaded_;
+ WallClockStopwatch& network_stopwatch_;
+ Clock& clock_;
+};
+
+// Parse a google::longrunning::Operation out of a InMemoryHttpResponse.
+// If the initial http_response is not OK, this method will immediately return
+// with the error status.
+// absl::StatusOr<google::longrunning::Operation>
+// ParseOperationProtoFromHttpResponse(
+// absl::StatusOr<InMemoryHttpResponse> http_response);
+
+// // Extract the operation name from Operation proto.
+// // If the operation name is not started with "operations/", invalid argument
+// // error will be returned.
+// absl::StatusOr<std::string> ExtractOperationName(
+// const google::longrunning::Operation& operation);
+
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_PROTOCOL_REQUEST_HELPER_H_
diff --git a/fcp/client/http/protocol_request_helper_test.cc b/fcp/client/http/protocol_request_helper_test.cc
new file mode 100644
index 0000000..e6d2834
--- /dev/null
+++ b/fcp/client/http/protocol_request_helper_test.cc
@@ -0,0 +1,762 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/protocol_request_helper.h"
+
+#include "google/protobuf/any.pb.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/http/testing/test_helpers.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
+#include "fcp/protos/federatedcompute/task_assignments.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace {
+
+using ::google::internal::federatedcompute::v1::AdvertiseKeysMetadata;
+using ::google::internal::federatedcompute::v1::ForwardingInfo;
+using ::google::internal::federatedcompute::v1::ShareKeysMetadata;
+using ::google::internal::federatedcompute::v1::StartTaskAssignmentMetadata;
+using ::google::internal::federatedcompute::v1::
+ SubmitSecureAggregationResultMetadata;
+using ::google::longrunning::Operation;
+using ::google::protobuf::Any;
+using ::testing::_;
+using ::testing::ContainerEq;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::MockFunction;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::StrictMock;
+using ::testing::UnorderedElementsAre;
+
+constexpr absl::string_view kApiKey = "API_KEY";
+
+class MockClock : public Clock {
+ public:
+ MOCK_METHOD(absl::Time, Now, (), (override));
+ MOCK_METHOD(void, Sleep, (absl::Duration duration), (override));
+
+ protected:
+ MOCK_METHOD(absl::Time, NowLocked, (), (override));
+ MOCK_METHOD(void, ScheduleWakeup, (absl::Time wakeup_time), (override));
+};
+
+void VerifyInMemoryHttpResponse(const InMemoryHttpResponse& response, int code,
+ absl::string_view content_encoding,
+ absl::string_view body) {
+ EXPECT_EQ(response.code, code);
+ EXPECT_EQ(response.content_encoding, content_encoding);
+ EXPECT_EQ(response.body, body);
+}
+
+Operation CreatePendingOperation(const std::string operation_name) {
+ Operation operation;
+ operation.set_done(false);
+ operation.set_name(operation_name);
+ return operation;
+}
+
+Operation CreatePendingOperation(const std::string operation_name,
+ const Any& metadata) {
+ Operation operation;
+ operation.set_done(false);
+ operation.set_name(operation_name);
+ *operation.mutable_metadata() = metadata;
+ return operation;
+}
+
+// Creates a 'done' `Operation`, with the given already-packed-into-`Any`
+// result.
+Operation CreateDoneOperation(const Any& packed_inner_result) {
+ Operation operation;
+ operation.set_done(true);
+ *operation.mutable_response() = packed_inner_result;
+ return operation;
+}
+
+Operation CreateErrorOperation(const absl::StatusCode error_code,
+ const std::string error_message) {
+ Operation operation;
+ operation.set_done(true);
+ operation.mutable_error()->set_code(static_cast<int>(error_code));
+ operation.mutable_error()->set_message(error_message);
+ return operation;
+}
+
+TEST(ProtocolRequestCreatorTest, TestInvalidForwardingInfo) {
+ // If a ForwardingInfo does not have a target_uri_prefix field set then the
+ // ProcessForwardingInfo call should fail.
+ ForwardingInfo forwarding_info;
+ EXPECT_THAT(ProtocolRequestCreator::Create(kApiKey, forwarding_info,
+ /*use_compression=*/false),
+ IsCode(INVALID_ARGUMENT));
+
+ (*forwarding_info.mutable_extra_request_headers())["x-header1"] =
+ "header-value1";
+ EXPECT_THAT(ProtocolRequestCreator::Create(kApiKey, forwarding_info,
+ /*use_compression=*/false),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(ProtocolRequestCreatorTest, CreateProtocolRequestInvalidSuffix) {
+ ProtocolRequestCreator creator("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ std::string uri_suffix = "v1/request";
+ ASSERT_THAT(
+ creator.CreateProtocolRequest(uri_suffix, QueryParams(),
+ HttpRequest::Method::kPost, "request_body",
+ /*is_protobuf_encoded=*/false),
+ IsCode(absl::StatusCode::kInvalidArgument));
+}
+
+TEST(ProtocolRequestCreatorTest, CreateProtocolRequest) {
+ ProtocolRequestCreator creator("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ std::string expected_body = "expected_body";
+ auto request = creator.CreateProtocolRequest(
+ "/v1/request", QueryParams(), HttpRequest::Method::kPost, expected_body,
+ /*is_protobuf_encoded=*/false);
+
+ ASSERT_OK(request);
+ EXPECT_EQ((*request)->uri(), "https://initial.uri/v1/request");
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kPost);
+ EXPECT_THAT(
+ (*request)->extra_headers(),
+ UnorderedElementsAre(
+ Header{"x-goog-api-key", "API_KEY"},
+ Header{"Content-Length", std::to_string(expected_body.size())}));
+ EXPECT_TRUE((*request)->HasBody());
+ std::string actual_body;
+ actual_body.resize(expected_body.size());
+ ASSERT_OK((*request)->ReadBody(actual_body.data(), expected_body.size()));
+ EXPECT_EQ(actual_body, expected_body);
+}
+
+TEST(ProtocolRequestCreatorTest, CreateProtobufEncodedProtocolRequest) {
+ ProtocolRequestCreator creator("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ std::string expected_body = "expected_body";
+ auto request = creator.CreateProtocolRequest(
+ "/v1/request", QueryParams(), HttpRequest::Method::kPost, expected_body,
+ /*is_protobuf_encoded=*/true);
+
+ ASSERT_OK(request);
+ EXPECT_EQ((*request)->uri(), "https://initial.uri/v1/request?%24alt=proto");
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kPost);
+ EXPECT_THAT((*request)->extra_headers(),
+ UnorderedElementsAre(
+ Header{"x-goog-api-key", "API_KEY"},
+ Header{"Content-Length", absl::StrCat(expected_body.size())},
+ Header{"Content-Type", "application/x-protobuf"}));
+ EXPECT_TRUE((*request)->HasBody());
+ std::string actual_body;
+ actual_body.resize(expected_body.size());
+ ASSERT_OK((*request)->ReadBody(actual_body.data(), actual_body.size()));
+ EXPECT_EQ(actual_body, expected_body);
+}
+
+TEST(ProtocolRequestCreatorTest, CreateGetOperationRequest) {
+ ProtocolRequestCreator creator("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ std::string operation_name = "my_operation";
+ auto request = creator.CreateGetOperationRequest(operation_name);
+ ASSERT_OK(request);
+ EXPECT_EQ((*request)->uri(),
+ "https://initial.uri/v1/my_operation?%24alt=proto");
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kGet);
+ EXPECT_THAT((*request)->extra_headers(),
+ UnorderedElementsAre(Header{"x-goog-api-key", "API_KEY"}));
+ EXPECT_FALSE((*request)->HasBody());
+}
+
+TEST(ProtocolRequestCreatorTest, CreateCancelOperationRequest) {
+ ProtocolRequestCreator creator("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ std::string operation_name = "my_operation";
+ auto request = creator.CreateCancelOperationRequest(operation_name);
+ ASSERT_OK(request);
+ EXPECT_EQ((*request)->uri(),
+ "https://initial.uri/v1/my_operation:cancel?%24alt=proto");
+ EXPECT_EQ((*request)->method(), HttpRequest::Method::kGet);
+ EXPECT_THAT((*request)->extra_headers(),
+ UnorderedElementsAre(Header{"x-goog-api-key", "API_KEY"}));
+ EXPECT_FALSE((*request)->HasBody());
+}
+
+class ProtocolRequestHelperTest : public ::testing::Test {
+ public:
+ ProtocolRequestHelperTest()
+ : interruptible_runner_(
+ &mock_log_manager_, mock_should_abort_.AsStdFunction(),
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
+ .interrupt_timeout =
+ ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
+ .interrupted_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
+ .interrupt_timeout_extended = ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT}),
+ initial_request_creator_("https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false),
+ protocol_request_helper_(&mock_http_client_, &bytes_downloaded_,
+ &bytes_uploaded_, network_stopwatch_.get(),
+ &mock_clock_) {}
+
+ protected:
+ void TearDown() override {
+ // Regardless of the outcome of the test (or the protocol interaction being
+ // tested), network usage must always be reflected in the network stats.
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes =
+ mock_http_client_.TotalSentReceivedBytes();
+ EXPECT_THAT(bytes_downloaded_, sent_received_bytes.received_bytes);
+ EXPECT_THAT(bytes_uploaded_, sent_received_bytes.sent_bytes);
+ }
+
+ StrictMock<MockClock> mock_clock_;
+ StrictMock<MockHttpClient> mock_http_client_;
+
+ NiceMock<MockLogManager> mock_log_manager_;
+ NiceMock<MockFunction<bool()>> mock_should_abort_;
+
+ int64_t bytes_downloaded_ = 0;
+ int64_t bytes_uploaded_ = 0;
+ std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
+ WallClockStopwatch::Create();
+
+ InterruptibleRunner interruptible_runner_;
+ ProtocolRequestCreator initial_request_creator_;
+ // The class under test.
+ ProtocolRequestHelper protocol_request_helper_;
+};
+
+Any GetFakeAnyProto() {
+ Any fake_any;
+ fake_any.set_type_url("the_type_url");
+ *fake_any.mutable_value() = "the_value";
+ return fake_any;
+}
+
+TEST_F(ProtocolRequestHelperTest, TestForwardingInfoIsPassedAlongCorrectly) {
+ // The initial request should use the initial entry point URI and an empty set
+ // of headers.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/suffix1", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{{"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"}}),
+ "body1")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response1")));
+ auto request_creator = std::make_unique<ProtocolRequestCreator>(
+ "https://initial.uri", kApiKey, HeaderList(),
+ /*use_compression=*/false);
+ auto http_request = request_creator->CreateProtocolRequest(
+ "/suffix1", QueryParams(), HttpRequest::Method::kPost, "body1",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(http_request);
+ auto result = protocol_request_helper_.PerformProtocolRequest(
+ *std::move(http_request), interruptible_runner_);
+ ASSERT_OK(result);
+ VerifyInMemoryHttpResponse(*result, 200, "", "response1");
+
+ {
+ // Process some fake ForwardingInfo.
+ ForwardingInfo forwarding_info1;
+ forwarding_info1.set_target_uri_prefix("https://second.uri/");
+ (*forwarding_info1.mutable_extra_request_headers())["x-header1"] =
+ "header-value1";
+ (*forwarding_info1.mutable_extra_request_headers())["x-header2"] =
+ "header-value2";
+ auto new_request_creator = ProtocolRequestCreator::Create(
+ kApiKey, forwarding_info1, /*use_compression=*/false);
+ ASSERT_OK(new_request_creator);
+ request_creator = std::move(*new_request_creator);
+ }
+
+ // The next series of requests should now use the ForwardingInfo (incl. use
+ // the "https://second.uri/" prefix, and include the headers). Note that we
+ // must use UnorderedElementsAre since the iteration order of the headers in
+ // the `ForwardingInfo` is undefined.
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://second.uri/suffix2", HttpRequest::Method::kGet,
+ UnorderedElementsAre(Header{"x-goog-api-key", "API_KEY"},
+ Header{"x-header1", "header-value1"},
+ Header{"x-header2", "header-value2"}),
+ "")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response2")));
+ http_request = request_creator->CreateProtocolRequest(
+ "/suffix2", QueryParams(), HttpRequest::Method::kGet, "",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(http_request);
+ result = protocol_request_helper_.PerformProtocolRequest(
+ *std::move(http_request), interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_EQ(result->body, "response2");
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://second.uri/suffix3", HttpRequest::Method::kPut,
+ UnorderedElementsAre(Header{"x-goog-api-key", "API_KEY"},
+ Header{"x-header1", "header-value1"},
+ Header{"x-header2", "header-value2"},
+ // This request has a response body, so
+ // the HttpClient will add this header
+ // automatically.
+ Header{"Content-Length", "5"}),
+ "body3")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response3")));
+ http_request = request_creator->CreateProtocolRequest(
+ "/suffix3", QueryParams(), HttpRequest::Method::kPut, "body3",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(http_request);
+ result = protocol_request_helper_.PerformProtocolRequest(
+ *std::move(http_request), interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_EQ(result->body, "response3");
+
+ {
+ // Process some more fake ForwardingInfo (without any headers this time).
+ ForwardingInfo forwarding_info2;
+ forwarding_info2.set_target_uri_prefix("https://third.uri");
+ auto new_request_creator = ProtocolRequestCreator::Create(
+ kApiKey, forwarding_info2, /*use_compression=*/false);
+ ASSERT_OK(new_request_creator);
+ request_creator = std::move(*new_request_creator);
+ }
+
+ // The next request should now use the latest ForwardingInfo again (i.e. use
+ // the "https://third.uri/" prefix, and not specify any headers).
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://third.uri/suffix4", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{
+ {"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"},
+ }),
+ "body4")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response4")));
+ http_request = request_creator->CreateProtocolRequest(
+ "/suffix4", QueryParams(), HttpRequest::Method::kPost, "body4",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(http_request);
+ result = protocol_request_helper_.PerformProtocolRequest(
+ *std::move(http_request), interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_EQ(result->body, "response4");
+}
+
+TEST_F(ProtocolRequestHelperTest, TestPollOperationInvalidOperationName) {
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ CreatePendingOperation("invalid_operation_name"),
+ initial_request_creator_, interruptible_runner_);
+ EXPECT_THAT(result.status(), IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(result.status().message(), HasSubstr("invalid name"));
+}
+
+TEST_F(ProtocolRequestHelperTest, TestPollOperationResponseImmediateSuccess) {
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ expected_response, initial_request_creator_, interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponseImmediateOperationError) {
+ Operation expected_response =
+ CreateErrorOperation(ALREADY_EXISTS, "some error");
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ expected_response, initial_request_creator_, interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponseSuccessAfterPolling) {
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500))).Times(3);
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest, TestPollOperationResponseErrorAfterPolling) {
+ // Make the initial request return a pending Operation result.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake error.
+ Operation expected_response =
+ CreateErrorOperation(ALREADY_EXISTS, "some error");
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), pending_operation_response.SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500))).Times(3);
+
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponseDifferentPollingIntervals) {
+ StartTaskAssignmentMetadata metadata;
+ *metadata.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Milliseconds(2));
+ Any packed_metadata;
+ ASSERT_TRUE(packed_metadata.PackFrom(metadata));
+ StartTaskAssignmentMetadata metadata_2;
+ *metadata_2.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Milliseconds(3));
+ Any packed_metadata_2;
+ ASSERT_TRUE(packed_metadata_2.PackFrom(metadata_2));
+
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata_2)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(2)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(3)));
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponsePollingIntervalTooHigh) {
+ StartTaskAssignmentMetadata metadata;
+ *metadata.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Hours(1));
+ Any packed_metadata;
+ ASSERT_TRUE(packed_metadata.PackFrom(metadata));
+
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Minutes(1)));
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponseAdvertiseKeysMetadata) {
+ AdvertiseKeysMetadata metadata;
+ *metadata.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Milliseconds(2));
+ Any packed_metadata;
+ ASSERT_TRUE(packed_metadata.PackFrom(metadata));
+
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(2)));
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest, TestPollOperationResponseShareKeysMetadata) {
+ ShareKeysMetadata metadata;
+ *metadata.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Milliseconds(2));
+ Any packed_metadata;
+ ASSERT_TRUE(packed_metadata.PackFrom(metadata));
+
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(2)));
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest,
+ TestPollOperationResponseSubmitSecureAggregationResultMetadata) {
+ SubmitSecureAggregationResultMetadata metadata;
+ *metadata.mutable_polling_interval() =
+ TimeUtil::ConvertAbslToProtoDuration(absl::Milliseconds(2));
+ Any packed_metadata;
+ ASSERT_TRUE(packed_metadata.PackFrom(metadata));
+
+ // Make the initial request return a pending Operation result. Note that we
+ // use a '#' character in the operation name to allow us to verify that it
+ // is properly URL-encoded.
+ Operation pending_operation_response =
+ CreatePendingOperation("operations/foo#bar");
+
+ // Then, after letting the operation get polled twice more, eventually
+ // return a fake response.
+ Operation expected_response = CreateDoneOperation(GetFakeAnyProto());
+
+ EXPECT_CALL(mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ // Note that the '#' character is encoded as "%23".
+ "https://initial.uri/v1/operations/foo%23bar?%24alt=proto",
+ HttpRequest::Method::kGet, _, IsEmpty())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(),
+ CreatePendingOperation("operations/foo#bar", packed_metadata)
+ .SerializeAsString())))
+ .WillOnce(Return(FakeHttpResponse(
+ 200, HeaderList(), expected_response.SerializeAsString())));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(500)));
+ EXPECT_CALL(mock_clock_, Sleep(absl::Milliseconds(2)));
+ absl::StatusOr<Operation> result =
+ protocol_request_helper_.PollOperationResponseUntilDone(
+ pending_operation_response, initial_request_creator_,
+ interruptible_runner_);
+ ASSERT_OK(result);
+ EXPECT_THAT(*result, EqualsProto(expected_response));
+}
+
+TEST_F(ProtocolRequestHelperTest, PerformMultipleRequestsSuccess) {
+ auto request_a = initial_request_creator_.CreateProtocolRequest(
+ "/v1/request_a", QueryParams(), HttpRequest::Method::kPost, "body1",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(request_a);
+ auto request_b = initial_request_creator_.CreateProtocolRequest(
+ "/v1/request_b", QueryParams(), HttpRequest::Method::kPost, "body2",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(request_b);
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/request_a", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{{"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"}}),
+ "body1")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response1")));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/request_b", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{{"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"}}),
+ "body2")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response2")));
+
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(std::move(*request_a));
+ requests.push_back(std::move(*request_b));
+
+ auto result = protocol_request_helper_.PerformMultipleProtocolRequests(
+ std::move(requests), interruptible_runner_);
+ ASSERT_OK(result);
+ auto response_1 = (*result)[0];
+ ASSERT_OK(response_1);
+ VerifyInMemoryHttpResponse(*response_1, 200, "", "response1");
+ auto response_2 = (*result)[1];
+ ASSERT_OK(response_2);
+ VerifyInMemoryHttpResponse(*response_2, 200, "", "response2");
+}
+
+TEST_F(ProtocolRequestHelperTest, PerformMultipleRequestsPartialFail) {
+ std::string uri_suffix_a = "/v1/request_a";
+ auto request_a = initial_request_creator_.CreateProtocolRequest(
+ uri_suffix_a, QueryParams(), HttpRequest::Method::kPost, "body1",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(request_a);
+ std::string uri_suffix_b = "/v1/request_b";
+ auto request_b = initial_request_creator_.CreateProtocolRequest(
+ uri_suffix_b, QueryParams(), HttpRequest::Method::kPost, "body2",
+ /*is_protobuf_encoded=*/false);
+ ASSERT_OK(request_b);
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/request_a", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{{"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"}}),
+ "body1")))
+ .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "response1")));
+ EXPECT_CALL(
+ mock_http_client_,
+ PerformSingleRequest(SimpleHttpRequestMatcher(
+ "https://initial.uri/v1/request_b", HttpRequest::Method::kPost,
+ // This request has a response body, so the HttpClient will
+ // add this header automatically.
+ ContainerEq(HeaderList{{"x-goog-api-key", "API_KEY"},
+ {"Content-Length", "5"}}),
+ "body2")))
+ .WillOnce(
+ Return(FakeHttpResponse(404, HeaderList(), "failure_response")));
+
+ std::vector<std::unique_ptr<HttpRequest>> requests;
+ requests.push_back(std::move(*request_a));
+ requests.push_back(std::move(*request_b));
+
+ auto result = protocol_request_helper_.PerformMultipleProtocolRequests(
+ std::move(requests), interruptible_runner_);
+
+ ASSERT_OK(result);
+ auto response_1 = (*result)[0];
+ ASSERT_OK(response_1);
+ VerifyInMemoryHttpResponse(*response_1, 200, "", "response1");
+ auto response_2 = (*result)[1];
+ ASSERT_THAT(response_2, IsCode(absl::StatusCode::kNotFound));
+}
+} // anonymous namespace
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/testing/BUILD b/fcp/client/http/testing/BUILD
new file mode 100644
index 0000000..401ba31
--- /dev/null
+++ b/fcp/client/http/testing/BUILD
@@ -0,0 +1,40 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# TODO(team)
+
+cc_library(
+ name = "test_helpers",
+ testonly = True,
+ srcs = ["test_helpers.cc"],
+ hdrs = ["test_helpers.h"],
+ deps = [
+ "//fcp/base",
+ "//fcp/client/http:http_client",
+ "//fcp/client/http:http_client_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_googleapis//google/longrunning:longrunning_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/client/http/testing/http_test_server.cc b/fcp/client/http/testing/http_test_server.cc
new file mode 100644
index 0000000..081b86f
--- /dev/null
+++ b/fcp/client/http/testing/http_test_server.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/http/testing/http_test_server.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "fcp/base/scheduler.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+namespace {
+using ::tensorflow::serving::net_http::EventExecutor;
+using ::tensorflow::serving::net_http::HTTPServerInterface;
+using ::tensorflow::serving::net_http::RequestHandlerOptions;
+using ::tensorflow::serving::net_http::ServerOptions;
+using ::tensorflow::serving::net_http::ServerRequestInterface;
+
+// Echoes the request back to the client
+void EchoHandler(ServerRequestInterface* req) {
+ std::string response;
+
+ absl::StrAppend(&response, "HTTP Method: ", req->http_method(), "\n");
+ absl::StrAppend(&response, "Request Uri: ", req->uri_path(), "\n");
+
+ absl::StrAppend(&response, "Request Headers:\n");
+ for (absl::string_view header : req->request_headers()) {
+ absl::StrAppend(&response, header, ": ", req->GetRequestHeader(header),
+ "\n");
+ }
+
+ absl::StrAppend(&response, "Request Body:\n");
+ // Read the request body
+ int64_t num_bytes;
+ while (true) {
+ auto request_chunk = req->ReadRequestBytes(&num_bytes);
+ if (request_chunk == nullptr) {
+ break;
+ }
+ absl::StrAppend(&response,
+ absl::string_view(request_chunk.get(), num_bytes));
+ }
+
+ req->WriteResponseString(response);
+
+ SetContentTypeHTML(req);
+ req->Reply();
+}
+
+// Non-blocking event executor needed for an event-driven web server
+class ThreadPoolEventExecutor final : public EventExecutor {
+ public:
+ explicit ThreadPoolEventExecutor(int num_threads)
+ : thread_pool_scheduler_(CreateThreadPoolScheduler(num_threads)) {}
+ ~ThreadPoolEventExecutor() override {
+ thread_pool_scheduler_->WaitUntilIdle();
+ }
+
+ void Schedule(std::function<void()> fn) override {
+ thread_pool_scheduler_->Schedule(fn);
+ }
+
+ private:
+ std::unique_ptr<Scheduler> thread_pool_scheduler_;
+};
+} // namespace
+
+absl::StatusOr<std::unique_ptr<HTTPServerInterface>> CreateHttpTestServer(
+ const std::string& uri, int port, int num_threads) {
+ auto options = std::make_unique<ServerOptions>();
+ options->AddPort(port);
+ options->SetExecutor(std::make_unique<ThreadPoolEventExecutor>(num_threads));
+
+ std::unique_ptr<HTTPServerInterface> http_server =
+ CreateEvHTTPServer(std::move(options));
+ if (http_server == nullptr) {
+ return absl::InternalError("Failed to create EvHTTPServer");
+ }
+
+ RequestHandlerOptions handler_options;
+ http_server->RegisterRequestHandler(uri, EchoHandler, handler_options);
+ return http_server;
+}
+
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/testing/http_test_server.h b/fcp/client/http/testing/http_test_server.h
new file mode 100644
index 0000000..d16f5db
--- /dev/null
+++ b/fcp/client/http/testing/http_test_server.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_HTTP_TESTING_HTTP_TEST_SERVER_H_
+#define FCP_CLIENT_HTTP_TESTING_HTTP_TEST_SERVER_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow_serving/util/net_http/server/public/httpserver.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+// Creates a test server where
+// uri: the uri starting with "/"
+// port: the port to listen
+// num_threads: the number of parallel threads the server has
+absl::StatusOr<
+ std::unique_ptr<::tensorflow::serving::net_http::HTTPServerInterface>>
+CreateHttpTestServer(const std::string& uri, int port, int num_threads);
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_TESTING_HTTP_TEST_SERVER_H_
diff --git a/fcp/client/http/testing/test_helpers.cc b/fcp/client/http/testing/test_helpers.cc
new file mode 100644
index 0000000..f532c56
--- /dev/null
+++ b/fcp/client/http/testing/test_helpers.cc
@@ -0,0 +1,250 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/http/testing/test_helpers.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/http/http_client_util.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+using ::google::longrunning::Operation;
+using ::google::protobuf::Message;
+using ::testing::AllOf;
+using ::testing::Field;
+using ::testing::Matcher;
+
+namespace {
+
+// A simple `HttpRequestHandle` implementation for use with
+// `MockableHttpClient`.
+class SimpleHttpRequestHandle : public HttpRequestHandle {
+ public:
+ SimpleHttpRequestHandle(std::unique_ptr<HttpRequest> request,
+ std::function<void()> cancellation_listener)
+ : request_(std::move(request)),
+ cancellation_listener_(cancellation_listener) {}
+
+ HttpRequestHandle::SentReceivedBytes TotalSentReceivedBytes() const override {
+ return sent_received_bytes_;
+ }
+ void SetSentBytes(int64_t bytes) { sent_received_bytes_.sent_bytes = bytes; }
+ void SetReceivedBytes(int64_t bytes) {
+ sent_received_bytes_.received_bytes = bytes;
+ }
+
+ void Cancel() override { cancellation_listener_(); }
+
+ HttpRequest* request() { return request_.get(); }
+
+ // Marks the handle as having been passed to PerformRequests(...). Returns
+ // true if the handle hand't previously been marked performed.
+ bool MarkPerformed() {
+ bool already_performed = performed_;
+ performed_ = true;
+ return !already_performed;
+ }
+
+ private:
+ const std::unique_ptr<HttpRequest> request_;
+ std::function<void()> cancellation_listener_;
+ bool performed_ = false;
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes_ = {0, 0};
+};
+
+} // namespace
+
+std::unique_ptr<HttpRequestHandle> MockableHttpClient::EnqueueRequest(
+ std::unique_ptr<HttpRequest> request) {
+ return std::make_unique<SimpleHttpRequestHandle>(
+ std::move(request), [this]() { this->cancellation_listener_(); });
+}
+
+absl::Status MockableHttpClient::PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests) {
+ for (const auto& [generic_handle, callback] : requests) {
+ auto handle = static_cast<SimpleHttpRequestHandle*>(generic_handle);
+ if (!handle->MarkPerformed()) {
+ return absl::InternalError(
+ "MockableHttpClient: handles cannot be used more than once.");
+ }
+
+ HttpRequest* request = handle->request();
+
+ std::string request_body;
+ if (request->HasBody()) {
+ const HeaderList& headers = request->extra_headers();
+ std::optional<std::string> content_length_hdr =
+ FindHeader(headers, kContentLengthHdr);
+ if (!content_length_hdr.has_value()) {
+ return absl::InternalError(
+ "MockableHttpClient only supports requests with known "
+ "Content-Length");
+ }
+ int64_t content_length;
+ if (!absl::SimpleAtoi(*content_length_hdr, &content_length)) {
+ return absl::InternalError(absl::StrCat(
+ "MockableHttpClient: unexpected Content-Length value: ",
+ content_length));
+ }
+ request_body.resize(content_length);
+
+ // Read the data all at once (our buffer should be big enough for it).
+ absl::StatusOr<int64_t> read_result =
+ request->ReadBody(&request_body[0], content_length);
+ if (!read_result.ok()) {
+ return absl::InternalError(
+ absl::StrCat("MockableHttpClient: ReadBody failed: ",
+ read_result.status().ToString()));
+ }
+ if (*read_result != content_length) {
+ return absl::InternalError(
+ absl::StrCat("MockableHttpClient: 1st ReadBody didn't read all the "
+ "data. Actual: ",
+ *read_result, ", expected: ", content_length));
+ }
+
+ // Ensure we've hit the end of the data by checking for OUT_OF_RANGE.
+ absl::Status read_body_result =
+ request->ReadBody(&request_body[0], 1).status();
+ if (read_body_result.code() != absl::StatusCode::kOutOfRange) {
+ return absl::InternalError(
+ absl::StrCat("MockableHttpClient: 2nd ReadBody failed: ",
+ read_body_result.ToString()));
+ }
+ }
+
+ // Forward the request to the PerformSingleRequest method (which
+ // generally will have been mocked using gMock's MOCK_METHOD). This
+ // method will return the response that we should then deliver to the
+ // HttpRequestCallback.
+ SimpleHttpRequest simple_request = {std::string(request->uri()),
+ request->method(),
+ request->extra_headers(), request_body};
+ absl::StatusOr<FakeHttpResponse> response =
+ PerformSingleRequest(simple_request);
+
+ // Mock some 'sent bytes data'. Users of this class shouldn't rely on the
+ // exact value (just as they can't expect to predict how much data a real
+ // `HttpClient` would send).
+ int64_t fake_sent_bytes = request->uri().size() + request_body.size();
+ handle->SetSentBytes(fake_sent_bytes);
+ sent_received_bytes_.sent_bytes += fake_sent_bytes;
+
+ if (!response.ok()) {
+ return absl::Status(
+ response.status().code(),
+ absl::StrCat("MockableHttpClient: PerformSingleRequest failed: ",
+ response.status().ToString()));
+ }
+
+ // Return the response data to the callback's various methods.
+ FCP_LOG(INFO) << "MockableHttpClient: Delivering response headers for: "
+ << request->uri();
+ absl::Status response_started_result =
+ callback->OnResponseStarted(*request, *response);
+ if (!response_started_result.ok()) {
+ return absl::InternalError(
+ absl::StrCat("MockableHttpClient: OnResponseStarted failed: ",
+ response_started_result.ToString()));
+ }
+
+ // Mock some 'received bytes data'. We add 100 bytes to ensure that even
+ // responses with empty response bodies do increase the counter, because
+ // generally headers will always be received.
+ int64_t fake_received_bytes = 100 + response->body().size();
+ handle->SetReceivedBytes(fake_received_bytes);
+ sent_received_bytes_.received_bytes += fake_received_bytes;
+
+ FCP_LOG(INFO) << "MockableHttpClient: Delivering response body for: "
+ << request->uri();
+ absl::Status response_body_result =
+ callback->OnResponseBody(*request, *response, response->body());
+ if (!response_body_result.ok()) {
+ return absl::InternalError(
+ absl::StrCat("MockableHttpClient: OnResponseBody failed: ",
+ response_body_result.ToString()));
+ }
+
+ FCP_LOG(INFO) << "MockableHttpClient: Delivering response completion for: "
+ << request->uri();
+ callback->OnResponseCompleted(*request, *response);
+ }
+ return absl::OkStatus();
+}
+
+Matcher<MockableHttpClient::SimpleHttpRequest> SimpleHttpRequestMatcher(
+ const Matcher<std::string>& uri_matcher,
+ const Matcher<HttpRequest::Method>& method_matcher,
+ const Matcher<HeaderList>& headers_matcher,
+ const Matcher<std::string>& body_matcher) {
+ return AllOf(
+ Field("uri", &MockableHttpClient::SimpleHttpRequest::uri, uri_matcher),
+ Field("method", &MockableHttpClient::SimpleHttpRequest::method,
+ method_matcher),
+ Field("headers", &MockableHttpClient::SimpleHttpRequest::headers,
+ headers_matcher),
+ Field("body", &MockableHttpClient::SimpleHttpRequest::body,
+ body_matcher));
+}
+
+Operation CreatePendingOperation(absl::string_view operation_name) {
+ Operation operation;
+ operation.set_done(false);
+ operation.set_name(std::string(operation_name));
+ return operation;
+}
+
+Operation CreateDoneOperation(absl::string_view operation_name,
+ const Message& inner_result) {
+ Operation operation;
+ operation.set_name(std::string(operation_name));
+ operation.set_done(true);
+ operation.mutable_response()->PackFrom(inner_result);
+ return operation;
+}
+
+Operation CreateErrorOperation(absl::string_view operation_name,
+ const absl::StatusCode error_code,
+ absl::string_view error_message) {
+ Operation operation;
+ operation.set_name(std::string(operation_name));
+ operation.set_done(true);
+ operation.mutable_error()->set_code(static_cast<int>(error_code));
+ operation.mutable_error()->set_message(std::string(error_message));
+ return operation;
+}
+
+} // namespace http
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/http/testing/test_helpers.h b/fcp/client/http/testing/test_helpers.h
new file mode 100644
index 0000000..3c690b8
--- /dev/null
+++ b/fcp/client/http/testing/test_helpers.h
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_HTTP_TESTING_TEST_HELPERS_H_
+#define FCP_CLIENT_HTTP_TESTING_TEST_HELPERS_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/longrunning/operations.pb.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/base/attributes.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client.h"
+
+namespace fcp {
+namespace client {
+namespace http {
+
+// A simple `HttpResponse` implementation consisting of a code, headers (which
+// are returned via the `HttpResponse` interface methods), and an optional
+// in-memory response body which `MockableHttpClient` can use to return the data
+// via the `HttpRequestCallback::OnResponseBody` method.
+class FakeHttpResponse : public HttpResponse {
+ public:
+ FakeHttpResponse(int code, HeaderList headers)
+ : code_(code), headers_(headers), body_("") {}
+ FakeHttpResponse(int code, HeaderList headers, std::string body)
+ : code_(code), headers_(headers), body_(body) {}
+
+ int code() const override { return code_; }
+ const HeaderList& headers() const override { return headers_; }
+ const std::string& body() const { return body_; }
+
+ private:
+ const int code_;
+ const HeaderList headers_;
+ const std::string body_;
+};
+
+// A simplified version of the `HttpClient` interface for use in tests. Enables
+// easy use with gMock via the simplified `PerformSingleRequest` interface.
+class MockableHttpClient : public HttpClient {
+ public:
+ // A simple container holding all important properties of an incoming
+ // `HttpRequest`. This is the parameter type passed to `PerformSingleRequest`,
+ // and because it is a simple struct (as opposed to `HttpRequest`, which uses
+ // a number of methods that are hard to mock such as `HttpRequest::ReadBody`)
+ // it makes it easy to use gMock matchers to match against parts or all of the
+ // request's properties.
+ struct SimpleHttpRequest {
+ const std::string uri;
+ const HttpRequest::Method method;
+ const HeaderList headers;
+ const std::string body;
+ };
+
+ MockableHttpClient() = default;
+
+ ABSL_MUST_USE_RESULT std::unique_ptr<HttpRequestHandle> EnqueueRequest(
+ std::unique_ptr<HttpRequest> request) override;
+
+ absl::Status PerformRequests(
+ std::vector<std::pair<HttpRequestHandle*, HttpRequestCallback*>> requests)
+ override;
+
+ // Implement this method (e.g. using gMock's MOCK_METHOD) for a simple way to
+ // mock a single request. See `MockHttpClient` below.
+ virtual absl::StatusOr<FakeHttpResponse> PerformSingleRequest(
+ SimpleHttpRequest request) = 0;
+
+ // Registers a callback that will be called when any request receives a
+ // `HttpRequestHandle::Cancel` call.
+ virtual void SetCancellationListener(std::function<void()> listener) {
+ cancellation_listener_ = listener;
+ }
+
+ // Returns the (fake) number of bytes that the mock client has sent/received.
+ // This number will match the sum of all `HttpRequestHandle`'s
+ // `TotalSentReceivedBytes()` methods after they were processed by the mock
+ // client.
+ virtual HttpRequestHandle::SentReceivedBytes TotalSentReceivedBytes() {
+ return sent_received_bytes_;
+ }
+
+ private:
+ std::function<void()> cancellation_listener_ = []() {};
+
+ // A running (fake) tally of the number of bytes that have been
+ // downloaded/uploaded so far.
+ HttpRequestHandle::SentReceivedBytes sent_received_bytes_;
+};
+
+// A convenient to use mock HttpClient implementation.
+class MockHttpClient : public MockableHttpClient {
+ public:
+ MockHttpClient() = default;
+
+ MOCK_METHOD(absl::StatusOr<FakeHttpResponse>, PerformSingleRequest,
+ (SimpleHttpRequest request), (override));
+};
+
+::testing::Matcher<MockableHttpClient::SimpleHttpRequest>
+SimpleHttpRequestMatcher(
+ const ::testing::Matcher<std::string>& uri_matcher,
+ const ::testing::Matcher<HttpRequest::Method>& method_matcher,
+ const ::testing::Matcher<HeaderList>& headers_matcher,
+ const ::testing::Matcher<std::string>& body_matcher);
+
+// A mock request callback.
+class MockHttpRequestCallback : public HttpRequestCallback {
+ public:
+ explicit MockHttpRequestCallback() = default;
+ ~MockHttpRequestCallback() override = default;
+ MOCK_METHOD(absl::Status, OnResponseStarted,
+ (const HttpRequest& request, const HttpResponse& response),
+ (override));
+
+ MOCK_METHOD(void, OnResponseError,
+ (const HttpRequest& request, const absl::Status& error),
+ (override));
+
+ MOCK_METHOD(absl::Status, OnResponseBody,
+ (const HttpRequest& request, const HttpResponse& response,
+ absl::string_view data),
+ (override));
+
+ MOCK_METHOD(void, OnResponseBodyError,
+ (const HttpRequest& request, const HttpResponse& response,
+ const absl::Status& error),
+ (override));
+
+ MOCK_METHOD(void, OnResponseCompleted,
+ (const HttpRequest& request, const HttpResponse& response),
+ (override));
+};
+
+// Creates a 'pending' `Operation`.
+::google::longrunning::Operation CreatePendingOperation(
+ absl::string_view operation_name);
+
+// Creates a 'done' `Operation`, packing the given message into an `Any`.
+::google::longrunning::Operation CreateDoneOperation(
+ absl::string_view operation_name, const google::protobuf::Message& inner_result);
+
+// Creates an `Operation` with the specified error information.
+::google::longrunning::Operation CreateErrorOperation(
+ absl::string_view operation_name, const absl::StatusCode error_code,
+ absl::string_view error_message);
+
+} // namespace http
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_HTTP_TESTING_TEST_HELPERS_H_
diff --git a/fcp/client/interruptible_runner.cc b/fcp/client/interruptible_runner.cc
new file mode 100644
index 0000000..9ab8c41
--- /dev/null
+++ b/fcp/client/interruptible_runner.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/interruptible_runner.h"
+
+#include <functional>
+#include <utility>
+
+#include "absl/status/status.h"
+
+namespace fcp {
+namespace client {
+
+absl::Status InterruptibleRunner::Run(std::function<absl::Status()> f,
+ std::function<void()> abort_function) {
+ // Check before even making the call.
+ if (should_abort_()) {
+ return absl::CancelledError("cancelled before posting callable");
+ }
+ fcp::thread::Future<absl::Status> run_future =
+ fcp::thread::ScheduleFuture<absl::Status>(thread_pool_.get(), f);
+ return WaitUntilDone(std::move(run_future), abort_function);
+}
+
+absl::Status InterruptibleRunner::WaitUntilDone(
+ fcp::thread::Future<absl::Status>&& run_future,
+ std::function<void()> abort_function) {
+ // Wait until call is done, checking periodically whether we need to abort.
+ while (true) {
+ if (run_future.Wait(timing_config_.polling_period)) {
+ std::optional<absl::Status> future_result = std::move(run_future).Take();
+ // std::nullopt indicates the underlying promise was abandoned. To my
+ // best knowledge this always indicates a programming error and hence
+ // should result in a crash.
+ FCP_CHECK(future_result != std::nullopt);
+ return future_result.value();
+ }
+
+ if (should_abort_()) {
+ return Abort(std::move(run_future), abort_function);
+ }
+ }
+}
+
+absl::Status InterruptibleRunner::Abort(
+ fcp::thread::Future<absl::Status> run_future,
+ std::function<void()> abort_function) {
+ FCP_LOG(WARNING) << "Aborting run.";
+
+ // Attempt to abort the ongoing call.
+ abort_function();
+
+ // Wait for at most the graceful shutdown period.
+ if (run_future.Wait(timing_config_.graceful_shutdown_period)) {
+ log_manager_->LogDiag(diagnostics_config_.interrupted);
+ FCP_CHECK(std::move(run_future).Take() != std::nullopt);
+ return absl::CancelledError("cancelled after graceful wait");
+ }
+
+ // Runnable failed to abort during the graceful shutdown period. Wait for
+ // (possibly much) longer, because there's nothing much being
+ // gained by returning with TF still running, but resources leak.
+ log_manager_->LogDiag(diagnostics_config_.interrupt_timeout);
+ if (run_future.Wait(timing_config_.extended_shutdown_period)) {
+ log_manager_->LogDiag(diagnostics_config_.interrupted_extended);
+ FCP_CHECK(std::move(run_future).Take() != std::nullopt);
+ return absl::CancelledError("cancelled after extended wait");
+ }
+
+ // If even waiting for the long period didn't help, exit this process.
+ // This is the worst case that will unfortunately happen - we hope the
+ // logs above and below make it to a logging backend, allowing to narrow
+ // the root cause down to particular models or builds; and the exit(0) should
+ // avoid raising a crash dialog when training is running in a background
+ // process. Nevertheless the goal should be to never reach this point.
+
+ log_manager_->LogDiag(diagnostics_config_.interrupt_timeout_extended);
+ exit(0);
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/interruptible_runner.h b/fcp/client/interruptible_runner.h
new file mode 100644
index 0000000..75586ce
--- /dev/null
+++ b/fcp/client/interruptible_runner.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_INTERRUPTIBLE_RUNNER_H_
+#define FCP_CLIENT_INTERRUPTIBLE_RUNNER_H_
+
+#include <functional>
+#include <memory>
+
+#include "absl/status/status.h"
+#include "absl/time/time.h"
+#include "fcp/base/future.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/log_manager.h"
+
+namespace fcp {
+namespace client {
+
+// An executor that runs operations in a background thread, polling a callback
+// periodically whether to abort, and aborting the operation if necessary.
+// This uses a single-threaded thread pool. During execution of an operation,
+// should_abort is polled periodically (polling_period), and if it returns true,
+// the abort_function supplied along with the operation is called. The operation
+// is then expected to abort within graceful_shutdown_period. If not, a diag
+// code is logged and we wait for some time longer (extended_shutdown_period),
+// and if the operation still does not finish, the program exits.
+// The destructor blocks until the background thread has become idle.
+class InterruptibleRunner {
+ public:
+ // A struct used to group polling & timeout related parameters.
+ struct TimingConfig {
+ absl::Duration polling_period;
+ absl::Duration graceful_shutdown_period;
+ absl::Duration extended_shutdown_period;
+ };
+
+ // A struct used to group diagnostics related parameters.
+ struct DiagnosticsConfig {
+ ProdDiagCode interrupted;
+ ProdDiagCode interrupt_timeout;
+ ProdDiagCode interrupted_extended;
+ ProdDiagCode interrupt_timeout_extended;
+ };
+
+ InterruptibleRunner(LogManager* log_manager,
+ std::function<bool()> should_abort,
+ const TimingConfig& timing_config,
+ const DiagnosticsConfig& diagnostics_config)
+ : log_manager_(log_manager),
+ should_abort_(should_abort),
+ timing_config_(timing_config),
+ diagnostics_config_(diagnostics_config) {
+ thread_pool_ = fcp::CreateThreadPoolScheduler(1);
+ }
+
+ ~InterruptibleRunner() { thread_pool_->WaitUntilIdle(); }
+
+ // Executes f() on a background. Returns CANCELLED if the background thread
+ // was aborted, or a Status object from the background thread on successful
+ // completion.
+ absl::Status Run(std::function<absl::Status()> f,
+ std::function<void()> abort_function);
+
+ private:
+ absl::Status WaitUntilDone(fcp::thread::Future<absl::Status>&& run_future,
+ std::function<void()> abort_function);
+ absl::Status Abort(fcp::thread::Future<absl::Status> run_future,
+ std::function<void()> abort_function);
+
+ std::unique_ptr<Scheduler> thread_pool_;
+ LogManager* const log_manager_;
+ std::function<bool()> should_abort_;
+ TimingConfig timing_config_;
+ DiagnosticsConfig diagnostics_config_;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_INTERRUPTIBLE_RUNNER_H_
diff --git a/fcp/client/interruptible_runner_test.cc b/fcp/client/interruptible_runner_test.cc
new file mode 100644
index 0000000..9629e93
--- /dev/null
+++ b/fcp/client/interruptible_runner_test.cc
@@ -0,0 +1,258 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/interruptible_runner.h"
+
+#include <functional>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/time/time.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace {
+
+using ::fcp::client::ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION;
+using ::fcp::client::ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT;
+using ::fcp::client::ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED;
+using ::fcp::client::ProdDiagCode::
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT;
+using ::testing::StrictMock;
+
+static InterruptibleRunner::DiagnosticsConfig getDiagnosticsConfig() {
+ return InterruptibleRunner::DiagnosticsConfig{
+ .interrupted = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
+ .interrupt_timeout = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
+ .interrupted_extended =
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
+ .interrupt_timeout_extended =
+ BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT};
+}
+
+// Tests the case where runnable finishes before the future times out (and we'd
+// call should_abort).
+TEST(InterruptibleRunnerTest, TestNormalNoAbortCheck) {
+ int should_abort_calls = 0;
+ int abort_function_calls = 0;
+ std::function<bool()> should_abort = [&should_abort_calls]() {
+ should_abort_calls++;
+ return false;
+ };
+ std::function<void()> abort_function = [&abort_function_calls]() {
+ abort_function_calls++;
+ };
+
+ InterruptibleRunner interruptibleRunner(
+ /*log_manager=*/nullptr, should_abort,
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::InfiniteDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ getDiagnosticsConfig());
+ absl::Status status = interruptibleRunner.Run(
+ []() { return absl::OkStatus(); }, abort_function);
+ EXPECT_THAT(status, IsCode(OK));
+ EXPECT_EQ(should_abort_calls, 1);
+ EXPECT_EQ(abort_function_calls, 0);
+
+ // Test that the Status returned by the runnable is returned as is.
+ status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
+ abort_function);
+ EXPECT_THAT(status, IsCode(DATA_LOSS));
+}
+
+// Tests the case where should_abort prevents us from even kicking off the run.
+TEST(InterruptibleRunnerTest, TestNormalAbortBeforeRun) {
+ int should_abort_calls = 0;
+ int abort_function_calls = 0;
+ int runnable_calls = 0;
+ std::function<bool()> should_abort = [&should_abort_calls]() {
+ should_abort_calls++;
+ return true;
+ };
+ std::function<void()> abort_function = [&abort_function_calls]() {
+ abort_function_calls++;
+ };
+
+ InterruptibleRunner interruptibleRunner(
+ /*log_manager=*/nullptr, should_abort,
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::InfiniteDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ getDiagnosticsConfig());
+ absl::Status status = interruptibleRunner.Run(
+ [&runnable_calls]() {
+ runnable_calls++;
+ return absl::OkStatus();
+ },
+ abort_function);
+ EXPECT_THAT(status, IsCode(CANCELLED));
+ EXPECT_EQ(abort_function_calls, 0);
+ EXPECT_EQ(runnable_calls, 0);
+}
+
+// Tests the case where the future wait times out once, we call should_abort,
+// which says to continue, and then the future returns.
+TEST(InterruptibleRunnerTest, TestNormalWithAbortCheckButNoAbort) {
+ int should_abort_calls = 0;
+ int abort_function_calls = 0;
+ absl::BlockingCounter counter_should_abort(1);
+ absl::BlockingCounter counter_did_abort(1);
+ std::function<bool()> should_abort =
+ [&should_abort_calls, &counter_should_abort, &counter_did_abort]() {
+ should_abort_calls++;
+ if (should_abort_calls == 2) {
+ counter_should_abort.DecrementCount();
+ counter_did_abort.Wait();
+ }
+ return false;
+ };
+ std::function<void()> abort_function = [&abort_function_calls]() {
+ abort_function_calls++;
+ };
+
+ InterruptibleRunner interruptibleRunner(
+ nullptr, should_abort,
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ getDiagnosticsConfig());
+ absl::Status status = interruptibleRunner.Run(
+ [&counter_should_abort, &counter_did_abort]() {
+ // Block until should_abort has been called.
+ counter_should_abort.Wait();
+ // Tell should_abort to return false.
+ counter_did_abort.DecrementCount();
+ return absl::OkStatus();
+ },
+ abort_function);
+ EXPECT_THAT(status, IsCode(OK));
+ EXPECT_GE(should_abort_calls, 2);
+ EXPECT_EQ(abort_function_calls, 0);
+
+ status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
+ abort_function);
+ EXPECT_THAT(status, IsCode(DATA_LOSS));
+}
+
+// Tests the case where the runnable gets aborted and behaves nicely (aborts
+// within the grace period).
+TEST(InterruptibleRunnerTest, TestAbortInGracePeriod) {
+ StrictMock<MockLogManager> log_manager;
+ int should_abort_calls = 0;
+ int abort_function_calls = 0;
+ absl::BlockingCounter counter_should_abort(1);
+ absl::BlockingCounter counter_did_abort(1);
+
+ std::function<bool()> should_abort = [&should_abort_calls]() {
+ should_abort_calls++;
+ return should_abort_calls >= 2;
+ };
+ std::function<void()> abort_function =
+ [&abort_function_calls, &counter_should_abort, &counter_did_abort]() {
+ abort_function_calls++;
+ // Signal runnable to abort.
+ counter_should_abort.DecrementCount();
+ // Wait for runnable to have aborted.
+ counter_did_abort.Wait();
+ };
+
+ InterruptibleRunner interruptibleRunner(
+ &log_manager, should_abort,
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::InfiniteDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ getDiagnosticsConfig());
+ // Tests that abort works.
+ EXPECT_CALL(log_manager, LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION))
+ .Times(testing::Exactly(1));
+ absl::Status status = interruptibleRunner.Run(
+ [&counter_should_abort, &counter_did_abort]() {
+ counter_should_abort.Wait();
+ counter_did_abort.DecrementCount();
+ return absl::OkStatus();
+ },
+ abort_function);
+ EXPECT_THAT(status, IsCode(CANCELLED));
+ EXPECT_EQ(should_abort_calls, 2);
+ EXPECT_EQ(abort_function_calls, 1);
+}
+
+// Tests the case where abort does not happen within the grace period.
+// This is achieved by only letting the runnable finish once the grace period
+// wait fails and a timeout diag code is logged, by taking an action on the
+// LogManager mock.
+TEST(InterruptibleRunnerTest, TestAbortInExtendedGracePeriod) {
+ StrictMock<MockLogManager> log_manager;
+ int should_abort_calls = 0;
+ int abort_function_calls = 0;
+
+ absl::BlockingCounter counter_should_abort(1);
+ absl::BlockingCounter counter_did_abort(1);
+
+ std::function<bool()> should_abort = [&should_abort_calls]() {
+ should_abort_calls++;
+ return should_abort_calls >= 2;
+ };
+ std::function<void()> abort_function = [&abort_function_calls]() {
+ abort_function_calls++;
+ };
+
+ InterruptibleRunner interruptibleRunner(
+ &log_manager, should_abort,
+ InterruptibleRunner::TimingConfig{
+ .polling_period = absl::ZeroDuration(),
+ .graceful_shutdown_period = absl::ZeroDuration(),
+ .extended_shutdown_period = absl::InfiniteDuration()},
+ getDiagnosticsConfig());
+ EXPECT_CALL(log_manager,
+ LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT))
+ .WillOnce(
+ [&counter_should_abort, &counter_did_abort](ProdDiagCode ignored) {
+ counter_should_abort.DecrementCount();
+ counter_did_abort.Wait();
+ return absl::OkStatus();
+ });
+ EXPECT_CALL(
+ log_manager,
+ LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED))
+ .Times(testing::Exactly(1));
+ absl::Status status = interruptibleRunner.Run(
+ [&counter_should_abort, &counter_did_abort]() {
+ counter_should_abort.Wait();
+ counter_did_abort.DecrementCount();
+ return absl::OkStatus();
+ },
+ abort_function);
+
+ EXPECT_THAT(status, IsCode(CANCELLED));
+ EXPECT_EQ(should_abort_calls, 2);
+ EXPECT_EQ(abort_function_calls, 1);
+}
+
+} // namespace
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/lc_runner.cc b/fcp/client/lc_runner.cc
new file mode 100644
index 0000000..7439b74
--- /dev/null
+++ b/fcp/client/lc_runner.cc
@@ -0,0 +1,362 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/lc_runner.h"
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/engine/plan_engine_helpers.h"
+
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+#include "fcp/client/engine/simple_plan_engine.h"
+#endif
+
+#include "fcp/client/engine/tflite_plan_engine.h"
+#include "fcp/client/opstats/opstats_example_store.h"
+#include "fcp/client/phase_logger_impl.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace fcp {
+namespace client {
+
+using ::fcp::client::opstats::OpStatsLogger;
+using ::google::internal::federated::plan::ClientOnlyPlan;
+using ::google::internal::federated::plan::LocalComputeIORouter;
+
+using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
+using TfMobileInputs = std::vector<std::pair<std::string, tensorflow::Tensor>>;
+
+namespace {
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+absl::StatusOr<std::unique_ptr<TfMobileInputs>>
+ConstructInputsForTensorflowSpecPlan(
+ const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
+ const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources) {
+ auto inputs = std::make_unique<
+ std::vector<std::pair<std::string, tensorflow::Tensor>>>();
+ if (local_compute.has_multiple_input_resources()) {
+ if (!input_dir_uri.empty()) {
+ return absl::InvalidArgumentError(
+ "Both input dir and input resources are provided.");
+ }
+ auto input_resource_tensor_name_map =
+ local_compute.multiple_input_resources()
+ .input_resource_tensor_name_map();
+ for (const auto& resource : input_resources) {
+ tensorflow::Tensor resource_tensor(tensorflow::DT_STRING, {});
+ resource_tensor.scalar<tensorflow::tstring>()() = resource.second;
+ if (!input_resource_tensor_name_map.contains(resource.first)) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("User provided input resource:", resource.first,
+ " is missing in LocalComputeIORouter."));
+ }
+ std::string tensor_name = input_resource_tensor_name_map[resource.first];
+ inputs->push_back({tensor_name, resource_tensor});
+ }
+ } else {
+ tensorflow::Tensor input_dirpath(tensorflow::DT_STRING, {});
+ input_dirpath.scalar<tensorflow::tstring>()() = input_dir_uri;
+ inputs->push_back({local_compute.input_dir_tensor_name(), input_dirpath});
+ }
+ tensorflow::Tensor output_dirpath(tensorflow::DT_STRING, {});
+ output_dirpath.scalar<tensorflow::tstring>()() = output_dir_uri;
+ inputs->push_back({local_compute.output_dir_tensor_name(), output_dirpath});
+ return inputs;
+}
+#endif
+
+absl::StatusOr<std::unique_ptr<TfLiteInputs>> ConstructInputsForTFLitePlan(
+ const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
+ const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources) {
+ auto inputs = std::make_unique<TfLiteInputs>();
+ if (local_compute.has_multiple_input_resources()) {
+ if (!input_dir_uri.empty()) {
+ return absl::InvalidArgumentError(
+ "Both input dir and input resources are provided.");
+ }
+ auto input_resource_tensor_name_map =
+ local_compute.multiple_input_resources()
+ .input_resource_tensor_name_map();
+ for (const auto& resource : input_resources) {
+ if (!input_resource_tensor_name_map.contains(resource.first)) {
+ // If the user provided more input resources than required in the
+ // LocalComputeIORouter, we simply continue without throwing an error.
+ // In this way, the user could update their scheduling logic separately
+ // from their local computation definitions.
+ continue;
+ }
+ std::string tensor_name = input_resource_tensor_name_map[resource.first];
+ (*inputs)[tensor_name] = resource.second;
+ }
+ } else {
+ (*inputs)[local_compute.input_dir_tensor_name()] = input_dir_uri;
+ }
+ (*inputs)[local_compute.output_dir_tensor_name()] = output_dir_uri;
+ return inputs;
+}
+
+void LogComputationOutcome(engine::PlanResult plan_result,
+ PhaseLogger& phase_logger,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ switch (plan_result.outcome) {
+ case engine::PlanOutcome::kSuccess:
+ phase_logger.LogComputationCompleted(plan_result.example_stats,
+ NetworkStats(), run_plan_start_time,
+ reference_time);
+ break;
+ case engine::PlanOutcome::kInterrupted:
+ phase_logger.LogComputationInterrupted(
+ plan_result.original_status, plan_result.example_stats,
+ NetworkStats(), run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kInvalidArgument:
+ phase_logger.LogComputationInvalidArgument(
+ plan_result.original_status, plan_result.example_stats,
+ NetworkStats(), run_plan_start_time);
+ break;
+ case engine::PlanOutcome::kTensorflowError:
+ phase_logger.LogComputationTensorflowError(
+ std::move(plan_result.original_status), plan_result.example_stats,
+ NetworkStats(), run_plan_start_time, reference_time);
+ break;
+ case engine::PlanOutcome::kExampleIteratorError:
+ phase_logger.LogComputationExampleIteratorError(
+ plan_result.original_status, plan_result.example_stats,
+ NetworkStats(), run_plan_start_time);
+ break;
+ }
+}
+
+// Creates an ExampleIteratorFactory that routes queries to the
+// SimpleTaskEnvironment::CreateExampleIterator() method.
+std::unique_ptr<engine::ExampleIteratorFactory>
+CreateSimpleTaskEnvironmentIteratorFactory(
+ SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
+ return std::make_unique<engine::FunctionalExampleIteratorFactory>(
+ /*can_handle_func=*/
+ [](const google::internal::federated::plan::ExampleSelector&) {
+ // The SimpleTaskEnvironment-based ExampleIteratorFactory should
+ // be the catch-all factory that is able to handle all queries
+ // that no other ExampleIteratorFactory is able to handle.
+ return true;
+ },
+ /*create_iterator_func=*/
+ [task_env, selector_context](
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) {
+ return task_env->CreateExampleIterator(example_selector,
+ selector_context);
+ },
+ /*should_collect_stats=*/true);
+}
+
+absl::Status RunPlanWithTensorflowSpec(
+ PhaseLogger& phase_logger,
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
+ std::function<bool()> should_abort, LogManager* log_manager,
+ OpStatsLogger* opstats_logger, const Flags* flags,
+ const ClientOnlyPlan& client_plan, const std::string& input_dir_uri,
+ const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources,
+ const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
+ const absl::Time run_plan_start_time, const absl::Time reference_time) {
+ // Check that this is a TensorflowSpec-based plan for local computation.
+ if (!client_plan.phase().has_tensorflow_spec()) {
+ absl::Status error_status =
+ absl::InvalidArgumentError("Plan without TensorflowSpec");
+ phase_logger.LogComputationInvalidArgument(
+ error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
+ return error_status;
+ }
+ if (!client_plan.phase().has_local_compute()) {
+ absl::Status error_status =
+ absl::InvalidArgumentError("Invalid TensorflowSpec-based plan");
+ phase_logger.LogComputationInvalidArgument(
+ error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
+ return error_status;
+ }
+
+ // Run plan
+ std::vector<std::string> output_names_unused;
+
+ if (!client_plan.tflite_graph().empty()) {
+ log_manager->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
+ }
+
+ if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
+ auto inputs = ConstructInputsForTFLitePlan(
+ client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
+ input_resources);
+ if (!inputs.ok()) {
+ phase_logger.LogComputationInvalidArgument(
+ inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
+ return inputs.status();
+ }
+ engine::TfLitePlanEngine plan_engine(example_iterator_factories,
+ should_abort, log_manager,
+ opstats_logger, flags, &timing_config);
+ engine::PlanResult plan_result = plan_engine.RunPlan(
+ client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
+ std::move(*inputs), output_names_unused);
+ engine::PlanOutcome outcome = plan_result.outcome;
+ LogComputationOutcome(std::move(plan_result), phase_logger,
+ run_plan_start_time, reference_time);
+ return ConvertPlanOutcomeToStatus(outcome);
+ }
+
+#ifdef FCP_CLIENT_SUPPORT_TFMOBILE
+ // Construct input tensors based on the values in the LocalComputeIORouter
+ // message.
+ auto inputs = ConstructInputsForTensorflowSpecPlan(
+ client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
+ input_resources);
+ if (!inputs.ok()) {
+ phase_logger.LogComputationInvalidArgument(
+ inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
+ return inputs.status();
+ }
+ engine::SimplePlanEngine plan_engine(
+ example_iterator_factories, should_abort, log_manager, opstats_logger,
+ &timing_config, flags->support_constant_tf_inputs());
+ engine::PlanResult plan_result = plan_engine.RunPlan(
+ client_plan.phase().tensorflow_spec(), client_plan.graph(),
+ client_plan.tensorflow_config_proto(), std::move(*inputs),
+ output_names_unused);
+ engine::PlanOutcome outcome = plan_result.outcome;
+ LogComputationOutcome(std::move(plan_result), phase_logger,
+ run_plan_start_time, reference_time);
+ return ConvertPlanOutcomeToStatus(outcome);
+#else
+ return absl::InternalError("No plan engine enabled");
+#endif
+}
+} // anonymous namespace
+
+absl::Status RunLocalComputation(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ LogManager* log_manager, const Flags* flags,
+ const std::string& session_name, const std::string& plan_uri,
+ const std::string& input_dir_uri, const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources) {
+ auto opstats_logger = engine::CreateOpStatsLogger(
+ env_deps->GetBaseDir(), flags, log_manager, session_name,
+ /*population_name=*/"");
+ SelectorContext selector_context;
+ selector_context.mutable_computation_properties()->set_session_name(
+ session_name);
+ LocalComputation computation = LocalComputation();
+ computation.set_input_dir(input_dir_uri);
+ computation.set_output_dir(output_dir_uri);
+ computation.mutable_input_resource_map()->insert(input_resources.begin(),
+ input_resources.end());
+ *selector_context.mutable_computation_properties()->mutable_local_compute() =
+ computation;
+ PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
+ log_manager, flags);
+ return RunLocalComputation(phase_logger, env_deps, log_manager,
+ opstats_logger.get(), flags, plan_uri,
+ input_dir_uri, output_dir_uri, input_resources,
+ selector_context);
+}
+
+absl::Status RunLocalComputation(
+ PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
+ LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
+ const std::string& plan_uri, const std::string& input_dir_uri,
+ const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources,
+ const SelectorContext& selector_context) {
+ absl::Time reference_time = absl::Now();
+ absl::Duration polling_period =
+ absl::Milliseconds(flags->condition_polling_period_millis());
+ std::function<bool()> should_abort = [env_deps, polling_period]() {
+ return env_deps->ShouldAbort(absl::Now(), polling_period);
+ };
+ // Check if the device conditions allow running a local computation.
+ if (should_abort()) {
+ std::string message =
+ "Device conditions not satisfied, aborting local computation";
+ FCP_LOG(INFO) << message;
+ phase_logger.LogTaskNotStarted(message);
+ return absl::CancelledError("");
+ }
+ // Local compute plans can use example iterators from the
+ // SimpleTaskEnvironment and those reading the OpStats DB.
+ opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
+ opstats_logger, log_manager,
+ flags->opstats_last_successful_contribution_criteria());
+ std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
+ CreateSimpleTaskEnvironmentIteratorFactory(env_deps, selector_context);
+ std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
+ &opstats_example_iterator_factory, env_example_iterator_factory.get()};
+
+ fcp::client::InterruptibleRunner::TimingConfig timing_config = {
+ .polling_period = polling_period,
+ .graceful_shutdown_period = absl::Milliseconds(
+ flags->tf_execution_teardown_grace_period_millis()),
+ .extended_shutdown_period = absl::Milliseconds(
+ flags->tf_execution_teardown_extended_period_millis()),
+ };
+
+ absl::Time run_plan_start_time = absl::Now();
+ phase_logger.LogComputationStarted();
+
+ absl::StatusOr<std::string> plan_str = fcp::ReadFileToString(plan_uri);
+ if (!plan_str.ok()) {
+ phase_logger.LogComputationIOError(plan_str.status(), ExampleStats(),
+ NetworkStats(), run_plan_start_time);
+ return plan_str.status();
+ }
+
+ ClientOnlyPlan plan;
+ if (!plan.ParseFromString(*plan_str)) {
+ absl::Status error_status =
+ absl::InvalidArgumentError("could not parse received plan");
+ phase_logger.LogComputationInvalidArgument(
+ error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
+ return error_status;
+ }
+
+ std::vector<std::string> output_names;
+ std::vector<tensorflow::Tensor> output_tensors;
+ return RunPlanWithTensorflowSpec(
+ phase_logger, example_iterator_factories, should_abort, log_manager,
+ opstats_logger, flags, plan, input_dir_uri, output_dir_uri,
+ input_resources, timing_config, run_plan_start_time, reference_time);
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/lc_runner.h b/fcp/client/lc_runner.h
new file mode 100644
index 0000000..2e2993c
--- /dev/null
+++ b/fcp/client/lc_runner.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_LC_RUNNER_H_
+#define FCP_CLIENT_LC_RUNNER_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/phase_logger.h"
+#include "fcp/client/simple_task_environment.h"
+
+namespace fcp {
+namespace client {
+
+// Prod entry point for running a local computation. Concurrent calls, with
+// the same SimpleTaskEnvironment::GetBaseDir(), are not supported.
+// If the training conditions are not met, return CANCELLED status.
+// If the plan cannot be parsed, return INVALID_ARGUMENT status.
+// If the plan does not contain tensorSpec, return INVALID_ARGUMENT status.
+// If the plan does not contain LocalComputeIORouter, return INVALID_ARGUMENT
+// status.
+// If the plan contains ClientExecutions, return INVALID_ARGUMENT status.
+// If the plan expects input tensors other than dataset token, input dir and
+// output dir, return INVALID_ARGUMENT status.
+// If Tensorflow completes, return OK status.
+// If Tensorflow was interrupted, return CANCELLED status.
+absl::Status RunLocalComputation(
+ SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
+ LogManager* log_manager, const Flags* flags,
+ const std::string& session_name, const std::string& plan_uri,
+ const std::string& input_dir_uri, const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources);
+
+// This is exposed for use in tests that require a mocked OpStatsLogger.
+// Otherwise, this is used internally by the other RunLocalComputation
+// method once the OpStatsLogger object has been created.
+absl::Status RunLocalComputation(
+ PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
+ LogManager* log_manager,
+ ::fcp::client::opstats::OpStatsLogger* opstats_logger, const Flags* flags,
+ const std::string& plan_uri, const std::string& input_dir_uri,
+ const std::string& output_dir_uri,
+ const absl::flat_hash_map<std::string, std::string>& input_resources,
+ const SelectorContext& selector_context);
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_LC_RUNNER_H_
diff --git a/fcp/client/log_manager.h b/fcp/client/log_manager.h
new file mode 100644
index 0000000..0e7d2ca
--- /dev/null
+++ b/fcp/client/log_manager.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_LOG_MANAGER_H_
+#define FCP_CLIENT_LOG_MANAGER_H_
+
+#include <cstdint>
+#include <string>
+
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/histogram_counters.pb.h"
+
+namespace fcp {
+namespace client {
+
+// An interface used to log "diag codes" - numeric enum values representing some
+// state of the code, e.g. a specific source code location being reached, or a
+// certain condition being met - to a monitoring backend in the cloud.
+class LogManager {
+ public:
+ virtual ~LogManager() = default;
+
+ // These functions log the given diag code.
+ virtual void LogDiag(ProdDiagCode diagCode) = 0;
+ virtual void LogDiag(DebugDiagCode diagCode) = 0;
+ // This function logs the given value to a long histogram identified by
+ // histogram_counter, annotated with the indexes and data_source_type.
+ virtual void LogToLongHistogram(HistogramCounters histogram_counter,
+ int execution_index, int epoch_index,
+ engine::DataSourceType data_source_type,
+ int64_t value) = 0;
+ void LogToLongHistogram(HistogramCounters histogram_counter, int64_t value) {
+ return LogToLongHistogram(histogram_counter, /*execution_index=*/0,
+ /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, value);
+ }
+ // After calling this function, all subsequently published histogram events
+ // will be annotated with the specified model_identifier. This value is
+ // typically provided by the federated server.
+ //
+ // Note that this method may be called multiple times with different values,
+ // if over the course of a training session multiple models are executed.
+ virtual void SetModelIdentifier(const std::string& model_identifier) = 0;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_LOG_MANAGER_H_
diff --git a/fcp/client/opstats/BUILD b/fcp/client/opstats/BUILD
new file mode 100644
index 0000000..31c3d37
--- /dev/null
+++ b/fcp/client/opstats/BUILD
@@ -0,0 +1,171 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "opstats_db",
+ hdrs = ["opstats_db.h"],
+ deps = [
+ "//fcp/protos:opstats_cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ ],
+)
+
+cc_library(
+ name = "opstats_logger",
+ hdrs = ["opstats_logger.h"],
+ deps = [
+ ":opstats_db",
+ "//fcp/client:interfaces",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:opstats_cc_proto",
+ ],
+)
+
+cc_library(
+ name = "opstats_logger_impl",
+ srcs = ["opstats_logger_impl.cc"],
+ hdrs = ["opstats_logger_impl.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":opstats_db",
+ ":opstats_logger",
+ "//fcp/base:time_util",
+ "//fcp/client:interfaces",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:opstats_cc_proto",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "opstats_logger_impl_test",
+ srcs = ["opstats_logger_impl_test.cc"],
+ deps = [
+ ":opstats_logger_impl",
+ ":pds_backed_opstats_db",
+ "//fcp/base",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:histogram_counters_cc_proto",
+ "//fcp/client:test_helpers",
+ "//fcp/protos:opstats_cc_proto",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "opstats_example_store",
+ srcs = ["opstats_example_store.cc"],
+ hdrs = ["opstats_example_store.h"],
+ deps = [
+ ":opstats_db",
+ ":opstats_logger",
+ ":opstats_utils",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/engine:example_iterator_factory",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:opstats_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "opstats_example_store_test",
+ srcs = ["opstats_example_store_test.cc"],
+ deps = [
+ ":opstats_example_store",
+ "//fcp/client:test_helpers",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "opstats_utils",
+ srcs = ["opstats_utils.cc"],
+ hdrs = ["opstats_utils.h"],
+ deps = [
+ ":opstats_db",
+ "//fcp/base",
+ "//fcp/protos:opstats_cc_proto",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "opstats_utils_test",
+ srcs = ["opstats_utils_test.cc"],
+ deps = [
+ ":opstats_utils",
+ "//fcp/client:test_helpers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "pds_backed_opstats_db",
+ srcs = ["pds_backed_opstats_db.cc"],
+ hdrs = ["pds_backed_opstats_db.h"],
+ deps = [
+ ":opstats_db",
+ "//fcp/base",
+ "//fcp/client:diag_codes_cc_proto",
+ "//fcp/client:interfaces",
+ "//fcp/protos:opstats_cc_proto",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ "@protodatastore_cpp//protostore:file-storage",
+ "@protodatastore_cpp//protostore:proto-data-store",
+ ],
+)
+
+cc_test(
+ name = "pds_backed_opstats_db_test",
+ srcs = ["pds_backed_opstats_db_test.cc"],
+ deps = [
+ ":pds_backed_opstats_db",
+ "//fcp/client:test_helpers",
+ "//fcp/protos:opstats_cc_proto",
+ "//fcp/testing",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
diff --git a/fcp/client/opstats/opstats_db.h b/fcp/client/opstats/opstats_db.h
new file mode 100644
index 0000000..ee832b8
--- /dev/null
+++ b/fcp/client/opstats/opstats_db.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_OPSTATS_OPSTATS_DB_H_
+#define FCP_CLIENT_OPSTATS_OPSTATS_DB_H_
+
+#include <functional>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/protos/opstats.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+// Base no-op class for the OpStats database that always contains an empty
+// OpStatsSequence and performs no file i/o.
+class OpStatsDb {
+ public:
+ virtual ~OpStatsDb() = default;
+ // The returned OpStatsSequence message should contain the operational stats
+ // for all runs. The operational stats for each run is wrapped inside a
+ // OperationalStats message, and the OperationalStats messages are ordered
+ // sequentially (first run to last run) within OpStatsSequence.
+ virtual absl::StatusOr<OpStatsSequence> Read() { return OpStatsSequence(); }
+
+ // OpStatsDb has a Transform method instead of a Write method because
+ // OpStatsSequence message already contains the operational stats for every
+ // run, and the user only need to update the existing OpStatsSequence message
+ // to add/remove/update data. In addition, by having a Transform method allows
+ // the implementations to perform atomic read-update-write operations.
+ virtual absl::Status Transform(std::function<void(OpStatsSequence&)> func) {
+ return absl::OkStatus();
+ }
+};
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_OPSTATS_DB_H_
diff --git a/fcp/client/opstats/opstats_example_store.cc b/fcp/client/opstats/opstats_example_store.cc
new file mode 100644
index 0000000..1927698
--- /dev/null
+++ b/fcp/client/opstats/opstats_example_store.cc
@@ -0,0 +1,254 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/opstats/opstats_example_store.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "google/protobuf/util/time_util.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/opstats/opstats_utils.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/opstats.pb.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+using ::google::internal::federated::plan::ExampleSelector;
+using ::google::protobuf::util::TimeUtil;
+
+namespace {
+
+absl::Time GetLastUpdatedTime(const OperationalStats& op_stats) {
+ if (op_stats.events().empty()) {
+ return absl::InfinitePast();
+ } else {
+ return absl::FromUnixMillis(TimeUtil::TimestampToMilliseconds(
+ op_stats.events().rbegin()->timestamp()));
+ }
+}
+
+tensorflow::Feature CreateFeatureFromString(const std::string& str) {
+ tensorflow::Feature feature;
+ feature.mutable_bytes_list()->add_value(str);
+ return feature;
+}
+
+tensorflow::Feature CreateFeatureFromInt(int64_t value) {
+ tensorflow::Feature feature;
+ feature.mutable_int64_list()->add_value(value);
+ return feature;
+}
+
+tensorflow::Feature CreateFeatureFromStringVector(
+ const std::vector<std::string>& values) {
+ tensorflow::Feature feature;
+ auto* bytes_list = feature.mutable_bytes_list();
+ for (const auto& value : values) {
+ bytes_list->add_value(value);
+ }
+ return feature;
+}
+
+tensorflow::Feature CreateFeatureFromIntVector(
+ const std::vector<int64_t>& values) {
+ tensorflow::Feature feature;
+ auto* int64_list = feature.mutable_int64_list();
+ for (const auto& value : values) {
+ int64_list->add_value(value);
+ }
+ return feature;
+}
+
+std::string CreateExample(const OperationalStats& op_stats,
+ int64_t earliest_trustworthy_time) {
+ tensorflow::Example example;
+ auto* feature_map = example.mutable_features()->mutable_feature();
+ (*feature_map)[kPopulationName] =
+ CreateFeatureFromString(op_stats.population_name());
+ (*feature_map)[kSessionName] =
+ CreateFeatureFromString(op_stats.session_name());
+ (*feature_map)[kTaskName] = CreateFeatureFromString(op_stats.task_name());
+
+ // Create events related features.
+ std::vector<int64_t> event_types;
+ std::vector<int64_t> event_time_millis;
+ for (const auto& event : op_stats.events()) {
+ event_types.push_back(event.event_type());
+ event_time_millis.push_back(
+ TimeUtil::TimestampToMilliseconds(event.timestamp()));
+ }
+ (*feature_map)[kEventsEventType] = CreateFeatureFromIntVector(event_types);
+ (*feature_map)[kEventsTimestampMillis] =
+ CreateFeatureFromIntVector(event_time_millis);
+
+ // Create external dataset stats related features.
+ std::vector<std::string> uris;
+ std::vector<int64_t> num_examples_read;
+ std::vector<int64_t> num_bytes_read;
+ for (const auto& stats : op_stats.dataset_stats()) {
+ uris.push_back(stats.first);
+ num_examples_read.push_back(stats.second.num_examples_read());
+ num_bytes_read.push_back(stats.second.num_bytes_read());
+ }
+ (*feature_map)[kDatasetStatsUri] = CreateFeatureFromStringVector(uris);
+ (*feature_map)[kDatasetStatsNumExamplesRead] =
+ CreateFeatureFromIntVector(num_examples_read);
+ (*feature_map)[kDatasetStatsNumBytesRead] =
+ CreateFeatureFromIntVector(num_bytes_read);
+
+ (*feature_map)[kErrorMessage] =
+ CreateFeatureFromString(op_stats.error_message());
+
+ // Create RetryWindow related features.
+ (*feature_map)[kRetryWindowDelayMinMillis] = CreateFeatureFromInt(
+ TimeUtil::DurationToMilliseconds(op_stats.retry_window().delay_min()));
+ (*feature_map)[kRetryWindowDelayMaxMillis] = CreateFeatureFromInt(
+ TimeUtil::DurationToMilliseconds(op_stats.retry_window().delay_max()));
+
+ (*feature_map)[kChunkingLayerBytesDownloaded] =
+ CreateFeatureFromInt(op_stats.chunking_layer_bytes_downloaded());
+ (*feature_map)[kChunkingLayerBytesUploaded] =
+ CreateFeatureFromInt(op_stats.chunking_layer_bytes_uploaded());
+ (*feature_map)[kNetworkDuration] = CreateFeatureFromInt(
+ TimeUtil::DurationToMilliseconds(op_stats.network_duration()));
+
+ (*feature_map)[kEarliestTrustWorthyTimeMillis] =
+ CreateFeatureFromInt(earliest_trustworthy_time);
+
+ return example.SerializeAsString();
+}
+
+class OpStatsExampleIterator : public fcp::client::ExampleIterator {
+ public:
+ explicit OpStatsExampleIterator(std::vector<OperationalStats> op_stats,
+ int64_t earliest_trustworthy_time)
+ : next_(0),
+ data_(std::move(op_stats)),
+ earliest_trustworthy_time_millis_(earliest_trustworthy_time) {}
+ absl::StatusOr<std::string> Next() override {
+ if (next_ < 0 || next_ >= data_.size()) {
+ return absl::OutOfRangeError("The iterator is out of range.");
+ }
+ return CreateExample(data_[next_++], earliest_trustworthy_time_millis_);
+ }
+
+ void Close() override {
+ next_ = 0;
+ data_.clear();
+ }
+
+ private:
+ // The index for the next OperationalStats to be used.
+ int next_;
+ std::vector<OperationalStats> data_;
+ const int64_t earliest_trustworthy_time_millis_;
+};
+
+} // anonymous namespace
+
+bool OpStatsExampleIteratorFactory::CanHandle(
+ const ExampleSelector& example_selector) {
+ return example_selector.collection_uri() == opstats::kOpStatsCollectionUri;
+}
+
+absl::StatusOr<std::unique_ptr<fcp::client::ExampleIterator>>
+OpStatsExampleIteratorFactory::CreateExampleIterator(
+ const ExampleSelector& example_selector) {
+ if (example_selector.collection_uri() != kOpStatsCollectionUri) {
+ log_manager_->LogDiag(ProdDiagCode::OPSTATS_INCORRECT_COLLECTION_URI);
+ return absl::InvalidArgumentError(absl::StrCat(
+ "The collection uri is ", example_selector.collection_uri(),
+ ", which is not the expected uri: ", kOpStatsCollectionUri));
+ }
+ if (!op_stats_logger_->IsOpStatsEnabled()) {
+ log_manager_->LogDiag(
+ ProdDiagCode::OPSTATS_EXAMPLE_STORE_REQUESTED_NOT_ENABLED);
+ return absl::InvalidArgumentError("OpStats example store is not enabled.");
+ }
+
+ absl::Time lower_bound_time = absl::InfinitePast();
+ absl::Time upper_bound_time = absl::InfiniteFuture();
+ bool last_successful_contribution = false;
+ if (example_selector.has_criteria()) {
+ OpStatsSelectionCriteria criteria;
+ if (!example_selector.criteria().UnpackTo(&criteria)) {
+ log_manager_->LogDiag(ProdDiagCode::OPSTATS_INVALID_SELECTION_CRITERIA);
+ return absl::InvalidArgumentError("Unable to parse selection criteria.");
+ }
+
+ if (criteria.has_start_time()) {
+ lower_bound_time = absl::FromUnixMillis(
+ TimeUtil::TimestampToMilliseconds(criteria.start_time()));
+ }
+ if (criteria.has_end_time()) {
+ upper_bound_time = absl::FromUnixMillis(
+ TimeUtil::TimestampToMilliseconds(criteria.end_time()));
+ }
+ if (lower_bound_time > upper_bound_time) {
+ log_manager_->LogDiag(ProdDiagCode::OPSTATS_INVALID_SELECTION_CRITERIA);
+ return absl::InvalidArgumentError(
+ "Invalid selection criteria: start_time is after end_time.");
+ }
+ last_successful_contribution = criteria.last_successful_contribution();
+ }
+
+ FCP_ASSIGN_OR_RETURN(OpStatsSequence data,
+ op_stats_logger_->GetOpStatsDb()->Read());
+ std::vector<OperationalStats> selected_data;
+ if (last_successful_contribution) {
+ if (opstats_last_successful_contribution_criteria_) {
+ // Selector specified last_successful_contribution, and the feature is
+ // enabled. Create a last_successful_contribution iterator.
+ std::optional<OperationalStats> last_successful_contribution_entry =
+ GetLastSuccessfulContribution(data,
+ op_stats_logger_->GetCurrentTaskName());
+ if (last_successful_contribution_entry.has_value()) {
+ selected_data.push_back(*last_successful_contribution_entry);
+ }
+ } else {
+ return absl::InvalidArgumentError(
+ "OpStats selection criteria has last_successful_contribution enabled "
+ "but feature not enabled in the runtime!");
+ }
+ } else {
+ for (auto it = data.opstats().rbegin(); it != data.opstats().rend(); ++it) {
+ absl::Time last_update_time = GetLastUpdatedTime(*it);
+ if (last_update_time >= lower_bound_time &&
+ last_update_time <= upper_bound_time) {
+ selected_data.push_back(*it);
+ }
+ }
+ }
+ return std::make_unique<OpStatsExampleIterator>(
+ std::move(selected_data),
+ TimeUtil::TimestampToMilliseconds(data.earliest_trustworthy_time()));
+}
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/opstats_example_store.h b/fcp/client/opstats/opstats_example_store.h
new file mode 100644
index 0000000..9dbd650
--- /dev/null
+++ b/fcp/client/opstats/opstats_example_store.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_OPSTATS_OPSTATS_EXAMPLE_STORE_H_
+#define FCP_CLIENT_OPSTATS_OPSTATS_EXAMPLE_STORE_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+inline static constexpr char kOpStatsCollectionUri[] = "internal:/opstats";
+inline static constexpr char kPopulationName[] = "population_name";
+inline static constexpr char kSessionName[] = "session_name";
+inline static constexpr char kTaskName[] = "task_name";
+inline static constexpr char kEventsEventType[] = "events-event_type";
+inline static constexpr char kEventsTimestampMillis[] = "events-timestamp";
+inline static constexpr char kDatasetStatsUri[] = "dataset_stats-uri";
+inline static constexpr char kDatasetStatsNumExamplesRead[] =
+ "dataset_stats-num_examples_read";
+inline static constexpr char kDatasetStatsNumBytesRead[] =
+ "dataset_stats-num_bytes_read";
+inline static constexpr char kErrorMessage[] = "error_message";
+inline static constexpr char kRetryWindowDelayMinMillis[] =
+ "retry_window-delay_min";
+inline static constexpr char kRetryWindowDelayMaxMillis[] =
+ "retry_window-delay_max";
+inline static constexpr char kChunkingLayerBytesDownloaded[] =
+ "chunking_layer_bytes_downloaded";
+inline static constexpr char kChunkingLayerBytesUploaded[] =
+ "chunking_layer_bytes_uploaded";
+inline static constexpr char kNetworkDuration[] = "network_duration";
+inline static constexpr char kEarliestTrustWorthyTimeMillis[] =
+ "earliest_trustworthy_time";
+
+class OpStatsExampleIteratorFactory
+ : public fcp::client::engine::ExampleIteratorFactory {
+ public:
+ OpStatsExampleIteratorFactory(
+ OpStatsLogger* op_stats_logger, LogManager* log_manager,
+ bool opstats_last_successful_contribution_criteria)
+ : op_stats_logger_(op_stats_logger),
+ log_manager_(log_manager),
+ opstats_last_successful_contribution_criteria_(
+ opstats_last_successful_contribution_criteria) {}
+
+ bool CanHandle(const google::internal::federated::plan::ExampleSelector&
+ example_selector) override;
+
+ bool ShouldCollectStats() override { return true; }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) override;
+
+ private:
+ OpStatsLogger* op_stats_logger_;
+ LogManager* log_manager_;
+ bool opstats_last_successful_contribution_criteria_;
+};
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_OPSTATS_EXAMPLE_STORE_H_
diff --git a/fcp/client/opstats/opstats_example_store_test.cc b/fcp/client/opstats/opstats_example_store_test.cc
new file mode 100644
index 0000000..a8c4e23
--- /dev/null
+++ b/fcp/client/opstats/opstats_example_store_test.cc
@@ -0,0 +1,601 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/opstats/opstats_example_store.h"
+
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "google/protobuf/util/time_util.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/testing/testing.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+namespace {
+
+using ::google::internal::federated::plan::ExampleSelector;
+using ::google::internal::federatedml::v2::RetryWindow;
+using ::google::protobuf::util::TimeUtil;
+using ::testing::Return;
+
+constexpr char kTestTaskName[] = "stefans_really_cool_task";
+
+class OpStatsExampleStoreTest : public testing::Test {
+ public:
+ OpStatsExampleStoreTest() {
+ EXPECT_CALL(mock_opstats_logger_, IsOpStatsEnabled())
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_opstats_logger_, GetOpStatsDb())
+ .WillRepeatedly(Return(&mock_db_));
+ EXPECT_CALL(mock_opstats_logger_, GetCurrentTaskName())
+ .WillRepeatedly(Return(kTestTaskName));
+ }
+
+ protected:
+ static OperationalStats::Event CreateEvent(
+ OperationalStats::Event::EventKind event_kind, int64_t event_time_ms) {
+ OperationalStats::Event event;
+ event.set_event_type(event_kind);
+ *event.mutable_timestamp() =
+ TimeUtil::MillisecondsToTimestamp(event_time_ms);
+ return event;
+ }
+
+ static OperationalStats::DatasetStats CreateDatasetStats(
+ int64_t num_examples_read, int64_t num_bytes_read) {
+ OperationalStats::DatasetStats stats;
+ stats.set_num_bytes_read(num_bytes_read);
+ stats.set_num_examples_read(num_examples_read);
+ return stats;
+ }
+
+ testing::StrictMock<MockOpStatsLogger> mock_opstats_logger_;
+ testing::StrictMock<MockOpStatsDb> mock_db_;
+ testing::StrictMock<MockLogManager> mock_log_manager_;
+ OpStatsExampleIteratorFactory iterator_factory_ =
+ OpStatsExampleIteratorFactory(
+ &mock_opstats_logger_, &mock_log_manager_,
+ /*opstats_last_successful_contribution_criteria=*/false);
+};
+
+TEST_F(OpStatsExampleStoreTest, TestInvalidCollectionUrl) {
+ ExampleSelector selector;
+ selector.set_collection_uri("INVALID");
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_INCORRECT_COLLECTION_URI));
+
+ EXPECT_FALSE(iterator_factory_.CanHandle(selector));
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> status_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ EXPECT_THAT(status_or.status(), IsCode(absl::StatusCode::kInvalidArgument));
+}
+
+TEST_F(OpStatsExampleStoreTest, TestMalformedCriteria) {
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ selector.mutable_criteria()->set_value("NOT_A_PROTO");
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_INVALID_SELECTION_CRITERIA));
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> status_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ EXPECT_THAT(status_or.status(), IsCode(absl::StatusCode::kInvalidArgument));
+}
+
+TEST_F(OpStatsExampleStoreTest, TestInvalidCriteria) {
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ *criteria.mutable_start_time() = TimeUtil::MillisecondsToTimestamp(2000L);
+ *criteria.mutable_end_time() = TimeUtil::MillisecondsToTimestamp(1000L);
+ selector.mutable_criteria()->PackFrom(criteria);
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_INVALID_SELECTION_CRITERIA));
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> status_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ EXPECT_THAT(status_or.status(), IsCode(absl::StatusCode::kInvalidArgument));
+}
+
+TEST_F(OpStatsExampleStoreTest, TestReadFromDbFailed) {
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ EXPECT_CALL(mock_db_, Read())
+ .WillOnce(Return(absl::InternalError("Something's wrong.")));
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> status_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ EXPECT_THAT(status_or.status(), IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(OpStatsExampleStoreTest, Success) {
+ // Prepare some data
+ OpStatsSequence opstats_sequence;
+
+ OperationalStats* stats_first = opstats_sequence.add_opstats();
+ std::string session_first = "session_first";
+ std::string population_first = "population_first";
+ stats_first->set_session_name(session_first);
+ stats_first->set_population_name(population_first);
+
+ OperationalStats* stats_last = opstats_sequence.add_opstats();
+ std::string session_last = "session_last";
+ std::string population_last = "population_last";
+ stats_last->set_session_name(session_last);
+ stats_last->set_population_name(population_last);
+
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example_last;
+ ASSERT_TRUE(example_last.ParseFromString(example_or.value()));
+ example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example_first;
+ ASSERT_TRUE(example_first.ParseFromString(example_or.value()));
+
+ // Check if the examples contain the expected data. Opstats examples are
+ // returned in last in, first out order.
+ std::set<std::string> actual_session_names;
+ actual_session_names.insert(ExtractSingleString(example_last, kSessionName));
+ actual_session_names.insert(ExtractSingleString(example_first, kSessionName));
+ std::set<std::string> expected_session_names = {session_last, session_first};
+ EXPECT_EQ(actual_session_names, expected_session_names);
+
+ std::set<std::string> actual_population_names;
+ actual_population_names.insert(
+ ExtractSingleString(example_last, kPopulationName));
+ actual_population_names.insert(
+ ExtractSingleString(example_first, kPopulationName));
+ std::set<std::string> expected_population_names = {population_last,
+ population_first};
+ EXPECT_EQ(actual_population_names, expected_population_names);
+
+ // We should have arrived at the end of the iterator.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+
+ // Subsequent Next() calls should all return OUT_OF_RANGE.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+
+ // Close() should work without exceptions.
+ iterator->Close();
+}
+
+TEST_F(OpStatsExampleStoreTest, EmptyData) {
+ EXPECT_CALL(mock_db_, Read())
+ .WillOnce(Return(OpStatsSequence::default_instance()));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> status_or = iterator->Next();
+ EXPECT_THAT(status_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest, DataIsFilteredBySelectionCriteria) {
+ OperationalStats included;
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 1000L));
+
+ OperationalStats excluded_early;
+ excluded_early.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 500L));
+ excluded_early.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 700L));
+
+ OperationalStats excluded_late;
+ excluded_late.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 800L));
+ excluded_late.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 2001L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(excluded_early);
+ *opstats_sequence.add_opstats() = std::move(included);
+ *opstats_sequence.add_opstats() = std::move(excluded_late);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ *criteria.mutable_start_time() = TimeUtil::MillisecondsToTimestamp(1000L);
+ *criteria.mutable_end_time() = TimeUtil::MillisecondsToTimestamp(2000L);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example;
+ example.ParseFromString(example_or.value());
+ auto event_type_list = ExtractRepeatedInt64(example, kEventsEventType);
+ ASSERT_EQ(event_type_list.at(0),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED);
+ ASSERT_EQ(event_type_list.at(1),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED);
+ auto event_time_ms_list =
+ ExtractRepeatedInt64(example, kEventsTimestampMillis);
+ ASSERT_EQ(event_time_ms_list.at(0), 900);
+ ASSERT_EQ(event_time_ms_list.at(1), 1000);
+
+ // We expect the iterator reaches the end because there's only 1 example.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest, SelectionCriteriaOnlyContainsBeginTime) {
+ OperationalStats included;
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 1000L));
+
+ OperationalStats excluded_early;
+ excluded_early.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 500L));
+ excluded_early.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 700L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(excluded_early);
+ *opstats_sequence.add_opstats() = std::move(included);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ *criteria.mutable_start_time() = TimeUtil::MillisecondsToTimestamp(1000L);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example;
+ example.ParseFromString(example_or.value());
+ auto event_type_list = ExtractRepeatedInt64(example, kEventsEventType);
+ ASSERT_EQ(event_type_list.at(0),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED);
+ ASSERT_EQ(event_type_list.at(1),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED);
+ auto event_time_ms_list =
+ ExtractRepeatedInt64(example, kEventsTimestampMillis);
+ ASSERT_EQ(event_time_ms_list.at(0), 900);
+ ASSERT_EQ(event_time_ms_list.at(1), 1000);
+
+ // We expect the iterator reaches the end because there's only 1 example.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest, SelectionCriteriaOnlyContainsEndTime) {
+ OperationalStats included;
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 1000L));
+
+ OperationalStats excluded_late;
+ excluded_late.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 800L));
+ excluded_late.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED, 2001L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(included);
+ *opstats_sequence.add_opstats() = std::move(excluded_late);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ *criteria.mutable_start_time() = TimeUtil::MillisecondsToTimestamp(1000L);
+ *criteria.mutable_end_time() = TimeUtil::MillisecondsToTimestamp(2000L);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example;
+ example.ParseFromString(example_or.value());
+ auto event_type_list = ExtractRepeatedInt64(example, kEventsEventType);
+ ASSERT_EQ(event_type_list.at(0),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED);
+ ASSERT_EQ(event_type_list.at(1),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED);
+ auto event_time_ms_list =
+ ExtractRepeatedInt64(example, kEventsTimestampMillis);
+ ASSERT_EQ(event_time_ms_list.at(0), 900);
+ ASSERT_EQ(event_time_ms_list.at(1), 1000);
+
+ // We expect the iterator reaches the end because there's only 1 example.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest,
+ SelectionCriteriaLastSuccessfulContributionEnabledAndExists) {
+ OpStatsExampleIteratorFactory iterator_factory =
+ OpStatsExampleIteratorFactory(
+ &mock_opstats_logger_, &mock_log_manager_,
+ /*opstats_last_successful_contribution_criteria=*/true);
+ OperationalStats included;
+ included.set_task_name(kTestTaskName);
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 1000L));
+
+ OperationalStats last_successful_contribution;
+ last_successful_contribution.set_task_name(kTestTaskName);
+ last_successful_contribution.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 1200L));
+ last_successful_contribution.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 2001L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(included);
+ *opstats_sequence.add_opstats() = std::move(last_successful_contribution);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ criteria.set_last_successful_contribution(true);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory.CreateExampleIterator(selector);
+
+ EXPECT_OK(iterator_or);
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ EXPECT_OK(example_or);
+ tensorflow::Example example;
+ example.ParseFromString(example_or.value());
+ auto event_type_list = ExtractRepeatedInt64(example, kEventsEventType);
+ ASSERT_EQ(event_type_list.at(0),
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED);
+ ASSERT_EQ(event_type_list.at(1),
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED);
+ auto event_time_ms_list =
+ ExtractRepeatedInt64(example, kEventsTimestampMillis);
+ ASSERT_EQ(event_time_ms_list.at(0), 1200L);
+ ASSERT_EQ(event_time_ms_list.at(1), 2001L);
+
+ // We expect the iterator reaches the end because there's only 1 example.
+ example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest,
+ SelectionCriteriaLastSuccessfulContributionEnabledAndDoesNotExist) {
+ OpStatsExampleIteratorFactory iterator_factory =
+ OpStatsExampleIteratorFactory(
+ &mock_opstats_logger_, &mock_log_manager_,
+ /*opstats_last_successful_contribution_criteria=*/true);
+ OperationalStats non_matching;
+ non_matching.set_task_name("non_matching_task_name");
+ non_matching.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ non_matching.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 1000L));
+
+ OperationalStats matching_but_no_upload;
+ matching_but_no_upload.set_task_name(kTestTaskName);
+ matching_but_no_upload.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 1200L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(non_matching);
+ *opstats_sequence.add_opstats() = std::move(matching_but_no_upload);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ criteria.set_last_successful_contribution(true);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory.CreateExampleIterator(selector);
+
+ EXPECT_OK(iterator_or);
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ EXPECT_THAT(example_or.status(), IsCode(absl::StatusCode::kOutOfRange));
+}
+
+TEST_F(OpStatsExampleStoreTest,
+ SelectionCriteriaLastSuccessfulContributionDisabled) {
+ // disable the feature but put in some matching entries.
+ OpStatsExampleIteratorFactory iterator_factory =
+ OpStatsExampleIteratorFactory(
+ &mock_opstats_logger_, &mock_log_manager_,
+ /*opstats_last_successful_contribution_criteria=*/false);
+
+ OperationalStats included;
+ included.set_task_name(kTestTaskName);
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 900L));
+ included.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 1000L));
+
+ OperationalStats last_successful_contribution;
+ last_successful_contribution.set_task_name(kTestTaskName);
+ last_successful_contribution.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED, 1200L));
+ last_successful_contribution.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 2001L));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(included);
+ *opstats_sequence.add_opstats() = std::move(last_successful_contribution);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ OpStatsSelectionCriteria criteria;
+ criteria.set_last_successful_contribution(true);
+ selector.mutable_criteria()->PackFrom(criteria);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory.CreateExampleIterator(selector);
+ // Enabling last successful contribution in the criteria when it's not enabled
+ // in the client returns INVALID_ARGUMENT.
+ EXPECT_THAT(iterator_or.status(), IsCode(absl::StatusCode::kInvalidArgument));
+}
+
+TEST_F(OpStatsExampleStoreTest, FullSerialization) {
+ OperationalStats stats;
+ // Set singular fields
+ std::string session = "session";
+ std::string population = "population";
+ std::string task_name = "task";
+ std::string error = "error";
+ int64_t chunking_layer_bytes_downloaded = 200;
+ int64_t chunking_layer_bytes_uploaded = 600;
+ int64_t network_duration_ms = 700;
+ stats.set_session_name(session);
+ stats.set_population_name(population);
+ stats.set_task_name(task_name);
+ stats.set_error_message(error);
+ stats.set_chunking_layer_bytes_downloaded(chunking_layer_bytes_downloaded);
+ stats.set_chunking_layer_bytes_uploaded(chunking_layer_bytes_uploaded);
+ *stats.mutable_network_duration() =
+ TimeUtil::MillisecondsToDuration(network_duration_ms);
+
+ // Set two events
+ OperationalStats::Event::EventKind event_kind_a =
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED;
+ int64_t event_time_ms_a = 1000;
+ OperationalStats::Event::EventKind event_kind_b =
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED;
+ int64_t event_time_ms_b = 1500;
+ stats.mutable_events()->Add(CreateEvent(event_kind_a, event_time_ms_a));
+ stats.mutable_events()->Add(CreateEvent(event_kind_b, event_time_ms_b));
+
+ // Set two dataset stats
+ std::string uri_a = "app:/train";
+ int64_t num_examples_a = 10;
+ int64_t example_bytes_a = 1000;
+ std::string uri_b = "app:/test";
+ int64_t num_examples_b = 5;
+ int64_t example_bytes_b = 500;
+ (*stats.mutable_dataset_stats())[uri_a] =
+ CreateDatasetStats(num_examples_a, example_bytes_a);
+ (*stats.mutable_dataset_stats())[uri_b] =
+ CreateDatasetStats(num_examples_b, example_bytes_b);
+
+ // Set retry window
+ int64_t min_delay_ms = 5000;
+ int64_t max_delay_ms = 9000;
+ RetryWindow retry;
+ retry.set_retry_token("token");
+ *retry.mutable_delay_min() = TimeUtil::MillisecondsToDuration(min_delay_ms);
+ *retry.mutable_delay_max() = TimeUtil::MillisecondsToDuration(max_delay_ms);
+ *stats.mutable_retry_window() = retry;
+
+ OpStatsSequence opstats_sequence;
+ ::google::protobuf::Timestamp currentTime = TimeUtil::GetCurrentTime();
+ *opstats_sequence.mutable_earliest_trustworthy_time() = currentTime;
+ *opstats_sequence.add_opstats() = std::move(stats);
+ EXPECT_CALL(mock_db_, Read()).WillOnce(Return(opstats_sequence));
+
+ ExampleSelector selector;
+ selector.set_collection_uri(kOpStatsCollectionUri);
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> iterator_or =
+ iterator_factory_.CreateExampleIterator(selector);
+
+ ASSERT_TRUE(iterator_or.ok());
+ std::unique_ptr<ExampleIterator> iterator = std::move(iterator_or.value());
+ absl::StatusOr<std::string> example_or = iterator->Next();
+ ASSERT_TRUE(example_or.ok());
+ tensorflow::Example example;
+ example.ParseFromString(example_or.value());
+
+ // Verify the example contains all the correct information.
+ // Singular fields
+ ASSERT_EQ(ExtractSingleString(example, kSessionName), session);
+ ASSERT_EQ(ExtractSingleString(example, kPopulationName), population);
+ ASSERT_EQ(ExtractSingleString(example, kTaskName), task_name);
+ ASSERT_EQ(ExtractSingleString(example, kErrorMessage), error);
+ ASSERT_EQ(ExtractSingleInt64(example, kChunkingLayerBytesDownloaded),
+ chunking_layer_bytes_downloaded);
+ ASSERT_EQ(ExtractSingleInt64(example, kChunkingLayerBytesUploaded),
+ chunking_layer_bytes_uploaded);
+ ASSERT_EQ(ExtractSingleInt64(example, kNetworkDuration), network_duration_ms);
+ ASSERT_EQ(ExtractSingleInt64(example, kEarliestTrustWorthyTimeMillis),
+ TimeUtil::TimestampToMilliseconds(currentTime));
+
+ // Events
+ auto event_types = ExtractRepeatedInt64(example, kEventsEventType);
+ ASSERT_EQ(event_types.at(0), event_kind_a);
+ ASSERT_EQ(event_types.at(1), event_kind_b);
+ auto event_times = ExtractRepeatedInt64(example, kEventsTimestampMillis);
+ ASSERT_EQ(event_times.at(0), event_time_ms_a);
+ ASSERT_EQ(event_times.at(1), event_time_ms_b);
+
+ // Dataset stats
+ auto dataset_urls = ExtractRepeatedString(example, kDatasetStatsUri);
+ // The order of the dataset stats doesn't matter, but should be consistent
+ // across the individual features.
+ int index_a = dataset_urls.at(1) == uri_a;
+ ASSERT_EQ(dataset_urls.at(index_a), uri_a);
+ ASSERT_EQ(dataset_urls.at(1 - index_a), uri_b);
+ auto example_counts =
+ ExtractRepeatedInt64(example, kDatasetStatsNumExamplesRead);
+ ASSERT_EQ(example_counts.at(index_a), num_examples_a);
+ ASSERT_EQ(example_counts.at(1 - index_a), num_examples_b);
+ auto example_bytes = ExtractRepeatedInt64(example, kDatasetStatsNumBytesRead);
+ ASSERT_EQ(example_bytes.at(index_a), example_bytes_a);
+ ASSERT_EQ(example_bytes.at(1 - index_a), example_bytes_b);
+
+ // RetryWindow
+ ASSERT_EQ(ExtractSingleInt64(example, kRetryWindowDelayMinMillis),
+ min_delay_ms);
+ ASSERT_EQ(ExtractSingleInt64(example, kRetryWindowDelayMaxMillis),
+ max_delay_ms);
+}
+
+} // anonymous namespace
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/opstats_logger.h b/fcp/client/opstats/opstats_logger.h
new file mode 100644
index 0000000..7b66c22
--- /dev/null
+++ b/fcp/client/opstats/opstats_logger.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_H_
+#define FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/opstats.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+// Base no-op class for the OpStats logger.
+class OpStatsLogger {
+ public:
+ OpStatsLogger() = default;
+
+ explicit OpStatsLogger(bool opstats_enabled)
+ : opstats_enabled_(opstats_enabled),
+ db_(std::make_unique<OpStatsDb>()),
+ init_status_(absl::OkStatus()) {}
+
+ OpStatsLogger(bool opstats_enabled, absl::Status init_status)
+ : opstats_enabled_(opstats_enabled),
+ db_(std::make_unique<OpStatsDb>()),
+ init_status_(init_status) {}
+
+ virtual ~OpStatsLogger() = default;
+
+ // Log a checkin accepted event and the corresponding task name.
+ virtual void AddEventAndSetTaskName(
+ const std::string& task_name, OperationalStats::Event::EventKind event) {}
+
+ // Log an event.
+ virtual void AddEvent(OperationalStats::Event::EventKind event) {}
+
+ // Log an event and corresponding error message.
+ virtual void AddEventWithErrorMessage(
+ OperationalStats::Event::EventKind event,
+ const std::string& error_message) {}
+
+ // Log info associated with a dataset created for a given collection. If this
+ // is called multiple times for the same collection, the example counts and
+ // sizes should be aggregated.
+ virtual void UpdateDatasetStats(const std::string& collection_uri,
+ int additional_example_count,
+ int64_t additional_example_size_bytes) {}
+
+ // Log network stats, replacing any old stats for the run.
+ virtual void SetNetworkStats(const NetworkStats& network_stats) {}
+
+ // Log the retry window, replacing any old retry window. Ignore any retry
+ // token in the retry window message.
+ virtual void SetRetryWindow(
+ google::internal::federatedml::v2::RetryWindow retry_window) {}
+
+ // Get the underlying opstats database.
+ virtual OpStatsDb* GetOpStatsDb() { return db_.get(); }
+
+ // Whether opstats is enabled.
+ virtual bool IsOpStatsEnabled() const { return opstats_enabled_; }
+
+ // Syncs all logged events to storage.
+ virtual absl::Status CommitToStorage() { return absl::OkStatus(); }
+
+ // Returns a status holding an initialization error if OpStats was enabled but
+ // failed to initialize.
+ absl::Status GetInitStatus() { return init_status_; }
+
+ // Returns the task name of the currently executing task. Only returns a valid
+ // task name if called after `AddEventAndSetTaskName` is called.
+ virtual std::string GetCurrentTaskName() { return ""; }
+
+ private:
+ bool opstats_enabled_;
+ std::unique_ptr<OpStatsDb> db_;
+ // If there was an error initializing the OpStats logger such that the no-op
+ // impl was returned instead, this will hold the status detailing the error.
+ absl::Status init_status_;
+};
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_H_
diff --git a/fcp/client/opstats/opstats_logger_impl.cc b/fcp/client/opstats/opstats_logger_impl.cc
new file mode 100644
index 0000000..3fc1e6b
--- /dev/null
+++ b/fcp/client/opstats/opstats_logger_impl.cc
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/opstats/opstats_logger_impl.h"
+
+#include <string>
+#include <utility>
+
+#include "google/protobuf/util/time_util.h"
+#include "fcp/base/time_util.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/opstats.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+using ::google::internal::federatedml::v2::RetryWindow;
+
+OpStatsLoggerImpl::OpStatsLoggerImpl(std::unique_ptr<OpStatsDb> db,
+ LogManager* log_manager,
+ const Flags* flags,
+ const std::string& session_name,
+ const std::string& population_name)
+ : db_(std::move(db)), log_manager_(log_manager) {
+ log_manager_->LogDiag(DebugDiagCode::TRAINING_OPSTATS_ENABLED);
+ log_manager_->LogDiag(ProdDiagCode::OPSTATS_DB_COMMIT_EXPECTED);
+
+ // Setup the OperationalStats message for the new run.
+ stats_.set_session_name(session_name);
+ stats_.set_population_name(population_name);
+}
+
+OpStatsLoggerImpl::~OpStatsLoggerImpl() {
+ // We're in the dtor, we don't care about what CommitToStorage returns.
+ auto status = CommitToStorage();
+}
+
+void OpStatsLoggerImpl::AddEventAndSetTaskName(
+ const std::string& task_name, OperationalStats::Event::EventKind event) {
+ absl::MutexLock lock(&mutex_);
+ AddNewEventToStats(event);
+ stats_.set_task_name(task_name);
+}
+
+void OpStatsLoggerImpl::AddEvent(OperationalStats::Event::EventKind event) {
+ absl::MutexLock lock(&mutex_);
+ AddNewEventToStats(event);
+}
+
+void OpStatsLoggerImpl::AddEventWithErrorMessage(
+ OperationalStats::Event::EventKind event,
+ const std::string& error_message) {
+ absl::MutexLock lock(&mutex_);
+ AddNewEventToStats(event);
+ // Don't replace an existing error message.
+ if (stats_.error_message().empty()) {
+ stats_.set_error_message(error_message);
+ }
+}
+
+void OpStatsLoggerImpl::UpdateDatasetStats(
+ const std::string& collection_uri, int additional_example_count,
+ int64_t additional_example_size_bytes) {
+ absl::MutexLock lock(&mutex_);
+ auto& dataset_stats = (*stats_.mutable_dataset_stats())[collection_uri];
+ dataset_stats.set_num_examples_read(dataset_stats.num_examples_read() +
+ additional_example_count);
+ dataset_stats.set_num_bytes_read(dataset_stats.num_bytes_read() +
+ additional_example_size_bytes);
+}
+
+void OpStatsLoggerImpl::SetNetworkStats(const NetworkStats& network_stats) {
+ absl::MutexLock lock(&mutex_);
+ stats_.set_chunking_layer_bytes_downloaded(network_stats.bytes_downloaded);
+ stats_.set_chunking_layer_bytes_uploaded(network_stats.bytes_uploaded);
+ *stats_.mutable_network_duration() =
+ TimeUtil::ConvertAbslToProtoDuration(network_stats.network_duration);
+}
+
+void OpStatsLoggerImpl::SetRetryWindow(RetryWindow retry_window) {
+ absl::MutexLock lock(&mutex_);
+ retry_window.clear_retry_token();
+ *stats_.mutable_retry_window() = std::move(retry_window);
+}
+
+void OpStatsLoggerImpl::AddNewEventToStats(
+ OperationalStats::Event::EventKind kind) {
+ auto new_event = stats_.add_events();
+ new_event->set_event_type(kind);
+ *new_event->mutable_timestamp() = google::protobuf::util::TimeUtil::GetCurrentTime();
+}
+
+absl::Status OpStatsLoggerImpl::CommitToStorage() {
+ absl::MutexLock lock(&mutex_);
+ log_manager_->LogDiag(ProdDiagCode::OPSTATS_DB_COMMIT_ATTEMPTED);
+ const absl::Time before_commit_time = absl::Now();
+ auto status = already_committed_
+ ? db_->Transform([stats = &stats_](OpStatsSequence& data) {
+ // Check if opstats on disk somehow got cleared between
+ // the first commit and now, and handle appropriately.
+ // This can happen e.g. if the ttl for the opstats db
+ // is incorrectly configured to have a very low ttl,
+ // causing the entire history to be lost as part of the
+ // update.
+ if (data.opstats_size() == 0) {
+ *data.add_opstats() = *stats;
+ } else {
+ *data.mutable_opstats(data.opstats_size() - 1) =
+ *stats;
+ }
+ })
+ : db_->Transform([stats = &stats_](OpStatsSequence& data) {
+ *data.add_opstats() = *stats;
+ });
+ const absl::Time after_commit_time = absl::Now();
+ log_manager_->LogToLongHistogram(
+ HistogramCounters::TRAINING_OPSTATS_COMMIT_LATENCY,
+ absl::ToInt64Milliseconds(after_commit_time - before_commit_time));
+ already_committed_ = true;
+ return status;
+}
+
+std::string OpStatsLoggerImpl::GetCurrentTaskName() {
+ absl::MutexLock lock(&mutex_);
+ return stats_.task_name();
+}
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/opstats_logger_impl.h b/fcp/client/opstats/opstats_logger_impl.h
new file mode 100644
index 0000000..581406f
--- /dev/null
+++ b/fcp/client/opstats/opstats_logger_impl.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_IMPL_H_
+#define FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_IMPL_H_
+
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/opstats.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+// An implementation of OpStatsLogger backed by a database.
+class OpStatsLoggerImpl : public OpStatsLogger {
+ public:
+ // Creates a logger backed by an actual database. Populates the internal
+ // message with the provided session and population names.
+ OpStatsLoggerImpl(std::unique_ptr<OpStatsDb> db, LogManager* log_manager,
+ const Flags* flags, const std::string& session_name,
+ const std::string& population_name);
+
+ // Commits the cumulative message to the db.
+ ~OpStatsLoggerImpl() override;
+
+ // Adds an event and the given task name to the cumulative internal message,
+ // in a single transaction.
+ void AddEventAndSetTaskName(const std::string& task_name,
+ OperationalStats::Event::EventKind event)
+ ABSL_LOCKS_EXCLUDED(mutex_) override;
+
+ // Adds an event to the cumulative internal message.
+ void AddEvent(OperationalStats::Event::EventKind event)
+ ABSL_LOCKS_EXCLUDED(mutex_) override;
+
+ // Adds an event and corresponding error message to the cumulative internal
+ // message.
+ void AddEventWithErrorMessage(OperationalStats::Event::EventKind event,
+ const std::string& error_message) override;
+
+ // Updates info associated with a dataset created for a given collection in
+ // the cumulative internal message. If this is called multiple times for the
+ // same collection, the example counts and sizes will be aggregated in the
+ // underlying submessage.
+ void UpdateDatasetStats(const std::string& collection_uri,
+ int additional_example_count,
+ int64_t additional_example_size_bytes)
+ ABSL_LOCKS_EXCLUDED(mutex_) override;
+
+ // Adds network stats, replacing any old stats for the run, to the cumulative
+ // internal message.
+ void SetNetworkStats(const NetworkStats& network_stats)
+ ABSL_LOCKS_EXCLUDED(mutex_) override;
+
+ // Sets the retry window, replacing any old retry window, in the cumulative
+ // internal message. Any retry token in the retry window message is dropped.
+ void SetRetryWindow(
+ google::internal::federatedml::v2::RetryWindow retry_window)
+ ABSL_LOCKS_EXCLUDED(mutex_) override;
+
+ // Get the underlying opstats database.
+ OpStatsDb* GetOpStatsDb() override { return db_.get(); }
+
+ // Whether opstats is enabled. An instance of this class should only ever be
+ // created when opstats is enabled.
+ bool IsOpStatsEnabled() const override { return true; }
+
+ // Syncs all logged events to storage.
+ absl::Status CommitToStorage() override;
+
+ // Returns the task name of the currently executing task. Only returns a valid
+ // task name if called after `AddEventAndSetTaskName` is called.
+ std::string GetCurrentTaskName() override;
+
+ private:
+ // Helper for adding a new event of the specified kind to the cumulative
+ // message being stored in this class.
+ void AddNewEventToStats(OperationalStats::Event::EventKind kind)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Cumulative message storing information about this run.
+ OperationalStats stats_ ABSL_GUARDED_BY(mutex_);
+ bool already_committed_ ABSL_GUARDED_BY(mutex_) = false;
+ std::unique_ptr<OpStatsDb> db_;
+ LogManager* log_manager_;
+ absl::Mutex mutex_;
+};
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_OPSTATS_LOGGER_IMPL_H_
diff --git a/fcp/client/opstats/opstats_logger_impl_test.cc b/fcp/client/opstats/opstats_logger_impl_test.cc
new file mode 100644
index 0000000..05f09d2
--- /dev/null
+++ b/fcp/client/opstats/opstats_logger_impl_test.cc
@@ -0,0 +1,574 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/opstats/opstats_logger_impl.h"
+
+#include <filesystem>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/util/time_util.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/histogram_counters.pb.h"
+#include "fcp/client/opstats/pds_backed_opstats_db.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/opstats.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+namespace {
+
+using google::internal::federatedml::v2::RetryWindow;
+using google::protobuf::Timestamp;
+using google::protobuf::util::TimeUtil;
+using testing::Ge;
+using testing::Return;
+using testing::StrictMock;
+
+constexpr char kSessionName[] = "SESSION";
+constexpr char kPopulationName[] = "POPULATION";
+constexpr char kTaskName[] = "TASK";
+
+class OpStatsLoggerImplTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ ON_CALL(mock_flags_, enable_opstats()).WillByDefault(Return(true));
+ ON_CALL(mock_flags_, opstats_ttl_days()).WillByDefault(Return(1));
+ ON_CALL(mock_flags_, opstats_db_size_limit_bytes())
+ .WillByDefault(Return(1 * 1024 * 1024));
+ base_dir_ = testing::TempDir();
+ }
+
+ void TearDown() override {
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogToLongHistogram(OPSTATS_DB_SIZE_BYTES, /*execution_index=*/0,
+ /*epoch_index=*/0, engine::DataSourceType::DATASET,
+ /*value=*/0));
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/0));
+ EXPECT_THAT((*db)->Transform([](OpStatsSequence& data) { data.Clear(); }),
+ IsOk());
+ }
+
+ std::unique_ptr<OpStatsLogger> CreateOpStatsLoggerImpl(
+ const std::string& session_name, const std::string& population_name) {
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ FCP_CHECK(db.ok());
+ return std::make_unique<OpStatsLoggerImpl>(std::move(*db),
+ &mock_log_manager_, &mock_flags_,
+ session_name, population_name);
+ }
+
+ // Checks that the expected and actual protos are equivalent, ignoring the
+ // timestamps in the actual proto, which must also be increasing.
+ void CheckEqualProtosAndIncreasingTimestamps(const Timestamp& start_time,
+ const OpStatsSequence& expected,
+ OpStatsSequence actual) {
+ auto previous_timestamp = start_time;
+ for (auto& opstats : *actual.mutable_opstats()) {
+ for (auto& event : *opstats.mutable_events()) {
+ EXPECT_GE(event.timestamp(), previous_timestamp);
+ previous_timestamp = event.timestamp();
+ // Remove the timestamp
+ event.clear_timestamp();
+ }
+ }
+ actual.clear_earliest_trustworthy_time();
+ EXPECT_THAT(actual, EqualsProto(expected));
+ }
+
+ void ExpectOpstatsEnabledEvents(int num_opstats_loggers) {
+ ExpectOpstatsEnabledEvents(num_opstats_loggers, num_opstats_loggers);
+ }
+
+ void ExpectOpstatsEnabledEvents(int num_opstats_loggers,
+ int num_opstats_commits) {
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(DebugDiagCode::TRAINING_OPSTATS_ENABLED))
+ .Times(num_opstats_loggers);
+ // Logged when the class is initialized.
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_DB_COMMIT_EXPECTED))
+ .Times(num_opstats_loggers);
+ EXPECT_CALL(mock_log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_DB_COMMIT_ATTEMPTED))
+ .Times(num_opstats_commits);
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogToLongHistogram(TRAINING_OPSTATS_COMMIT_LATENCY,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Ge(0)))
+ .Times(num_opstats_commits);
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogToLongHistogram(OPSTATS_DB_SIZE_BYTES, /*execution_index=*/0,
+ /*epoch_index=*/0, engine::DataSourceType::DATASET,
+ /*value=*/Ge(0)))
+ .Times(num_opstats_commits);
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogToLongHistogram(OPSTATS_DB_NUM_ENTRIES, /*execution_index=*/0,
+ /*epoch_index=*/0, engine::DataSourceType::DATASET,
+ /*value=*/Ge(0)))
+ .Times(num_opstats_commits);
+ }
+
+ RetryWindow CreateRetryWindow(const std::string& retry_token,
+ int64_t delay_min_seconds,
+ int64_t delay_max_seconds) {
+ RetryWindow retry_window;
+ retry_window.set_retry_token(retry_token);
+ retry_window.mutable_delay_min()->set_seconds(delay_min_seconds);
+ retry_window.mutable_delay_max()->set_seconds(delay_max_seconds);
+ return retry_window;
+ }
+
+ std::string base_dir_;
+ MockFlags mock_flags_;
+ StrictMock<MockLogManager> mock_log_manager_;
+};
+
+TEST_F(OpStatsLoggerImplTest, SetTaskName) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/3);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEventAndSetTaskName(
+ kTaskName, OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+
+ opstats_logger.reset();
+
+ auto opstats_logger_no_population =
+ CreateOpStatsLoggerImpl(kSessionName,
+ /*population_name=*/"");
+ opstats_logger_no_population->AddEventAndSetTaskName(
+ kTaskName, OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+
+ opstats_logger_no_population.reset();
+
+ auto opstats_logger_no_session =
+ CreateOpStatsLoggerImpl(/*session_name=*/"", kPopulationName);
+ opstats_logger_no_session->AddEventAndSetTaskName(
+ kTaskName, OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+
+ opstats_logger_no_session.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ // Add the first run
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ new_opstats->set_task_name(kTaskName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+ // Add the second run
+ new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_task_name(kTaskName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+ // Add the third run
+ new_opstats = expected.add_opstats();
+ new_opstats->set_population_name(kPopulationName);
+ new_opstats->set_task_name(kTaskName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, NewRunAfterCorruption) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/2);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEventAndSetTaskName(
+ kTaskName, OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+ opstats_logger.reset();
+
+ // Make the db file corrupt
+ {
+ std::filesystem::path db_path(base_dir_);
+ db_path /= PdsBackedOpStatsDb::kParentDir;
+ db_path /= PdsBackedOpStatsDb::kDbFileName;
+ protostore::FileStorage file_storage;
+ std::unique_ptr<protostore::OutputStream> ostream =
+ file_storage.OpenForWrite(db_path).value();
+ EXPECT_THAT(ostream->Append("not a proto"), IsOk());
+ EXPECT_THAT(ostream->Close(), IsOk());
+ }
+
+ EXPECT_CALL(mock_log_manager_, LogDiag(ProdDiagCode::OPSTATS_READ_FAILED));
+ auto opstats_logger_no_population =
+ CreateOpStatsLoggerImpl(kSessionName,
+ /*population_name=*/"");
+ opstats_logger_no_population->AddEventAndSetTaskName(
+ kTaskName, OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+
+ opstats_logger_no_population.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ // Expect only the second run to be represented in the db.
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_task_name(kTaskName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, AddEvent) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/2);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+ opstats_logger.reset();
+
+ auto opstats_logger_no_population =
+ CreateOpStatsLoggerImpl(kSessionName,
+ /*population_name=*/"");
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ opstats_logger_no_population.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ // Add the first run
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+ // Add the second run
+ new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, AddEventAfterTtl) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/2);
+
+ // Set the ttl to 0 so that previous data will be wiped out each time the
+ // logger tries to commit new data.
+ EXPECT_CALL(mock_flags_, opstats_ttl_days()).WillRepeatedly(Return(0));
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+ opstats_logger.reset();
+
+ auto opstats_logger_no_population =
+ CreateOpStatsLoggerImpl(kSessionName,
+ /*population_name=*/"");
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ opstats_logger_no_population.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ // Expect the db to contain only data associated with the second run. The
+ // second run should be complete, however.
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, AddEventWithErrorMessage) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/1);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_ERROR_IO, "first error");
+ opstats_logger->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_ERROR_TENSORFLOW, "second error");
+ opstats_logger.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ERROR_IO);
+ new_opstats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ERROR_TENSORFLOW);
+ new_opstats->set_error_message("first error");
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, UpdateDatasetStats) {
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/1);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ const std::string kCollectionUri = "app:/collection_uri";
+ const std::string kCollectionUriOther = "app:/collection_uri_other";
+ opstats_logger->UpdateDatasetStats(kCollectionUri,
+ /*additional_example_count=*/100,
+ /*additional_example_size_bytes=*/1000);
+ opstats_logger->UpdateDatasetStats(kCollectionUriOther,
+ /*additional_example_count=*/200,
+ /*additional_example_size_bytes=*/2000);
+ opstats_logger->UpdateDatasetStats(kCollectionUri,
+ /*additional_example_count=*/300,
+ /*additional_example_size_bytes=*/3000);
+ opstats_logger.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ OperationalStats::DatasetStats dataset_stats;
+ dataset_stats.set_num_examples_read(400);
+ dataset_stats.set_num_bytes_read(4000);
+ (*new_opstats->mutable_dataset_stats())[kCollectionUri] =
+ std::move(dataset_stats);
+ OperationalStats::DatasetStats dataset_stats_other;
+ dataset_stats_other.set_num_examples_read(200);
+ dataset_stats_other.set_num_bytes_read(2000);
+ (*new_opstats->mutable_dataset_stats())[kCollectionUriOther] =
+ std::move(dataset_stats_other);
+
+ (*data).clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(OpStatsLoggerImplTest, SetNetworkStats) {
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/1);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->SetNetworkStats(
+ {.bytes_downloaded = 102,
+ .bytes_uploaded = 103,
+ .network_duration = absl::Milliseconds(104)});
+ opstats_logger->SetNetworkStats(
+ {.bytes_downloaded = 202,
+ .bytes_uploaded = 203,
+ .network_duration = absl::Milliseconds(204)});
+ opstats_logger.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ // The bytes_downloaded/bytes_uploaded fields should not be set anymore
+ new_opstats->set_chunking_layer_bytes_downloaded(202);
+ new_opstats->set_chunking_layer_bytes_uploaded(203);
+ // The new network_duration field should be set now.
+ new_opstats->mutable_network_duration()->set_nanos(
+ static_cast<int32_t>(absl::ToInt64Nanoseconds(absl::Milliseconds(204))));
+
+ (*data).clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(OpStatsLoggerImplTest, SetRetryWindow) {
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/1);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->SetRetryWindow(CreateRetryWindow("retry_token", 100, 200));
+ opstats_logger->SetRetryWindow(CreateRetryWindow("retry_token", 300, 400));
+ opstats_logger.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ auto new_opstats = expected.add_opstats();
+ new_opstats->set_session_name(kSessionName);
+ new_opstats->set_population_name(kPopulationName);
+ *new_opstats->mutable_retry_window() =
+ CreateRetryWindow(/*retry_token=*/"", 300, 400);
+
+ (*data).clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(OpStatsLoggerImplTest, AddEventCommitAddMoreEvents) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(
+ /*num_opstats_loggers=*/2, /*num_opstats_commits=*/4);
+
+ auto opstats_logger = CreateOpStatsLoggerImpl(kSessionName, kPopulationName);
+ opstats_logger->AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+ opstats_logger.reset();
+
+ auto opstats_logger_no_population =
+ CreateOpStatsLoggerImpl(kSessionName,
+ /*population_name=*/"");
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ ASSERT_OK(opstats_logger_no_population->CommitToStorage());
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ ASSERT_OK(opstats_logger_no_population->CommitToStorage());
+ opstats_logger_no_population->AddEvent(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED);
+ opstats_logger_no_population.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ OpStatsSequence expected;
+ // Add the first run
+ auto second_run = expected.add_opstats();
+ second_run->set_session_name(kSessionName);
+ second_run->set_population_name(kPopulationName);
+ second_run->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+ // Add the second run
+ second_run = expected.add_opstats();
+ second_run->set_session_name(kSessionName);
+ second_run->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ second_run->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ second_run->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED);
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+TEST_F(OpStatsLoggerImplTest, MisconfiguredTtlMultipleCommit) {
+ auto start_time = TimeUtil::GetCurrentTime();
+ ExpectOpstatsEnabledEvents(/*num_opstats_loggers=*/1,
+ /*num_opstats_commits*/ 3);
+ auto db_zero_ttl = PdsBackedOpStatsDb::Create(
+ base_dir_, absl::ZeroDuration(), mock_log_manager_,
+ mock_flags_.opstats_db_size_limit_bytes())
+ .value();
+ auto opstats_logger = std::make_unique<OpStatsLoggerImpl>(
+ std::move(db_zero_ttl), &mock_log_manager_, &mock_flags_, kSessionName,
+ kPopulationName);
+
+ opstats_logger->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ ASSERT_OK(opstats_logger->CommitToStorage());
+ opstats_logger->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ ASSERT_OK(opstats_logger->CommitToStorage());
+ opstats_logger->AddEvent(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED);
+ opstats_logger.reset();
+
+ auto db = PdsBackedOpStatsDb::Create(
+ base_dir_, mock_flags_.opstats_ttl_days() * absl::Hours(24),
+ mock_log_manager_, mock_flags_.opstats_db_size_limit_bytes());
+ ASSERT_OK(db);
+ auto data = (*db)->Read();
+ ASSERT_OK(data);
+
+ // Even though we had corruption in the middle of the run, it should be ok
+ // because we committed the entire history successfully at the end.
+ OpStatsSequence expected;
+ auto expected_stats = expected.add_opstats();
+ expected_stats->set_population_name(kPopulationName);
+ expected_stats->set_session_name(kSessionName);
+ expected_stats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+ expected_stats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ expected_stats->add_events()->set_event_type(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED);
+
+ CheckEqualProtosAndIncreasingTimestamps(start_time, expected, *data);
+}
+
+} // anonymous namespace
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/opstats_utils.cc b/fcp/client/opstats/opstats_utils.cc
new file mode 100644
index 0000000..a782197
--- /dev/null
+++ b/fcp/client/opstats/opstats_utils.cc
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/opstats/opstats_utils.h"
+
+#include <algorithm>
+#include <optional>
+#include <string>
+
+#include "google/protobuf/timestamp.pb.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/protos/opstats.pb.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+std::optional<google::protobuf::Timestamp> GetLastSuccessfulContributionTime(
+ OpStatsSequence& data, const std::string& task_name) {
+ std::optional<OperationalStats> last_successful_entry =
+ GetLastSuccessfulContribution(data, task_name);
+ if (!last_successful_entry.has_value()) {
+ return std::nullopt;
+ }
+
+ auto upload_started = std::find_if(
+ last_successful_entry->events().begin(),
+ last_successful_entry->events().end(),
+ [](const OperationalStats::Event& event) {
+ return event.event_type() ==
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED;
+ });
+ if (upload_started == last_successful_entry->events().end()) {
+ // For last_successful_entry to have a value, it must have had an
+ // EVENT_KIND_RESULT_UPLOAD_STARTED event, so we should never reach this.
+ return std::nullopt;
+ }
+
+ return upload_started->timestamp();
+}
+
+std::optional<OperationalStats> GetLastSuccessfulContribution(
+ OpStatsSequence& data, const std::string& task_name) {
+ for (auto it = data.opstats().rbegin(); it != data.opstats().rend(); ++it) {
+ const OperationalStats& opstats_entry = *it;
+ bool upload_started = false;
+ bool upload_aborted = false;
+ if (opstats_entry.task_name() != task_name) {
+ continue;
+ }
+ for (const auto& event : opstats_entry.events()) {
+ if (event.event_type() ==
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED) {
+ upload_started = true;
+ }
+ if (event.event_type() ==
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED) {
+ upload_aborted = true;
+ }
+ }
+ if (upload_started && !upload_aborted) {
+ return opstats_entry;
+ }
+ }
+ return std::nullopt;
+}
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/opstats_utils.h b/fcp/client/opstats/opstats_utils.h
new file mode 100644
index 0000000..1e22c22
--- /dev/null
+++ b/fcp/client/opstats/opstats_utils.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_OPSTATS_OPSTATS_UTILS_H_
+#define FCP_CLIENT_OPSTATS_OPSTATS_UTILS_H_
+
+#include <optional>
+#include <string>
+
+#include "google/protobuf/timestamp.pb.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/opstats/opstats_db.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+// Returns an optional containing an OperationalStats of the last time the
+// runtime successfully contributed to a task with the given task name,
+// otherwise returns an empty optional.
+std::optional<OperationalStats> GetLastSuccessfulContribution(
+ OpStatsSequence& data, const std::string& task_name);
+
+// Returns an optional containing the timestamp of the last time the runtime
+// successfully contributed to a task with the given task name, otherwise
+// returns an empty optional.
+std::optional<google::protobuf::Timestamp> GetLastSuccessfulContributionTime(
+ OpStatsSequence& data, const std::string& task_name);
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_OPSTATS_UTILS_H_
diff --git a/fcp/client/opstats/opstats_utils_test.cc b/fcp/client/opstats/opstats_utils_test.cc
new file mode 100644
index 0000000..636d11f
--- /dev/null
+++ b/fcp/client/opstats/opstats_utils_test.cc
@@ -0,0 +1,154 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/opstats/opstats_utils.h"
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/client/test_helpers.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+namespace {
+
+constexpr char kTaskName[] = "task";
+OperationalStats::Event::EventKind kUploadStartedEvent =
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED;
+OperationalStats::Event::EventKind kUploadServerAbortedEvent =
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED;
+
+OperationalStats::Event CreateEvent(
+ OperationalStats::Event::EventKind event_kind, int64_t event_time_seconds) {
+ OperationalStats::Event event;
+ event.set_event_type(event_kind);
+ google::protobuf::Timestamp t;
+ t.set_seconds(event_time_seconds);
+ *event.mutable_timestamp() = t;
+ return event;
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnsUploadStartedTimestamp) {
+ OperationalStats stats;
+ stats.set_task_name(kTaskName);
+
+ int64_t upload_started_time_sec = 1000;
+ stats.mutable_events()->Add(
+ CreateEvent(kUploadStartedEvent, upload_started_time_sec));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(stats);
+
+ auto last_time =
+ GetLastSuccessfulContributionTime(opstats_sequence, kTaskName);
+ EXPECT_EQ(last_time->seconds(), upload_started_time_sec);
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnNotFoundForUnknownTask) {
+ OperationalStats stats;
+ stats.set_task_name(kTaskName);
+
+ int64_t upload_started_time_sec = 1000;
+ stats.mutable_events()->Add(
+ CreateEvent(kUploadStartedEvent, upload_started_time_sec));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(stats);
+ EXPECT_FALSE(
+ GetLastSuccessfulContributionTime(opstats_sequence, "task_name_not_found")
+ .has_value());
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnsMostRecentUploadStartedTimestamp) {
+ OpStatsSequence opstats_sequence;
+
+ OperationalStats old_stats;
+ old_stats.set_task_name(kTaskName);
+ old_stats.mutable_events()->Add(CreateEvent(kUploadStartedEvent, 1000));
+ *opstats_sequence.add_opstats() = std::move(old_stats);
+
+ OperationalStats new_stats;
+ new_stats.set_task_name(kTaskName);
+ int64_t new_upload_started_sec = 2000;
+ new_stats.mutable_events()->Add(
+ CreateEvent(kUploadStartedEvent, new_upload_started_sec));
+ *opstats_sequence.add_opstats() = std::move(new_stats);
+
+ auto last_time =
+ GetLastSuccessfulContributionTime(opstats_sequence, kTaskName);
+ EXPECT_EQ(last_time->seconds(), new_upload_started_sec);
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnNotFoundIfAbortedByServer) {
+ OperationalStats stats;
+ stats.set_task_name(kTaskName);
+ stats.mutable_events()->Add(CreateEvent(kUploadStartedEvent, 1000));
+ stats.mutable_events()->Add(CreateEvent(kUploadServerAbortedEvent, 1001));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(stats);
+
+ EXPECT_FALSE(GetLastSuccessfulContributionTime(opstats_sequence, kTaskName)
+ .has_value());
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnOlderIfNewerAbortedByServer) {
+ OpStatsSequence opstats_sequence;
+
+ OperationalStats old_stats;
+ old_stats.set_task_name(kTaskName);
+ int64_t expected_time_sec = 1000;
+ old_stats.mutable_events()->Add(
+ CreateEvent(kUploadStartedEvent, expected_time_sec));
+ *opstats_sequence.add_opstats() = std::move(old_stats);
+
+ OperationalStats new_stats;
+ new_stats.set_task_name(kTaskName);
+ new_stats.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED, 2000));
+ new_stats.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED, 2001));
+ *opstats_sequence.add_opstats() = std::move(new_stats);
+
+ auto last_time =
+ GetLastSuccessfulContributionTime(opstats_sequence, kTaskName);
+ EXPECT_EQ(last_time->seconds(), expected_time_sec);
+}
+
+TEST(OpStatsUtils,
+ GetLastSuccessfulContributionTimeReturnNotFoundIfNoUploadStarted) {
+ OperationalStats stats;
+ stats.set_task_name(kTaskName);
+ stats.mutable_events()->Add(
+ CreateEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED, 1000));
+
+ OpStatsSequence opstats_sequence;
+ *opstats_sequence.add_opstats() = std::move(stats);
+ EXPECT_FALSE(GetLastSuccessfulContributionTime(opstats_sequence, kTaskName)
+ .has_value());
+}
+
+} // namespace
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/pds_backed_opstats_db.cc b/fcp/client/opstats/pds_backed_opstats_db.cc
new file mode 100644
index 0000000..12113e8
--- /dev/null
+++ b/fcp/client/opstats/pds_backed_opstats_db.cc
@@ -0,0 +1,297 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/opstats/pds_backed_opstats_db.h"
+
+#include <fcntl.h>
+#include <sys/file.h>
+
+#include <filesystem>
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/util/time_util.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/diag_codes.pb.h"
+#include "fcp/client/log_manager.h"
+#include "protostore/file-storage.h"
+#include "protostore/proto-data-store.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+namespace {
+
+using ::google::protobuf::util::TimeUtil;
+
+ABSL_CONST_INIT absl::Mutex file_lock_mutex(absl::kConstInit);
+
+absl::flat_hash_set<std::string>* GetFilesInUseSet() {
+ // Create the heap allocated static set only once, never call d'tor.
+ // See: go/totw/110
+ static absl::flat_hash_set<std::string>* files_in_use =
+ new absl::flat_hash_set<std::string>();
+ return files_in_use;
+}
+
+absl::StatusOr<int> AcquireFileLock(const std::string& db_path,
+ LogManager& log_manager) {
+ absl::WriterMutexLock lock(&file_lock_mutex);
+ // If the underlying file is already in the hash set, it means another
+ // instance of OpStatsDb is using it, and we'll return an error.
+ absl::flat_hash_set<std::string>* files_in_use = GetFilesInUseSet();
+ if (!files_in_use->insert(db_path).second) {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_MULTIPLE_DB_INSTANCE_DETECTED);
+ return absl::InternalError(
+ "Another instance is already using the underlying database file.");
+ }
+ // Create a new file descriptor.
+ // Create the file if it doesn't exist, set permission to 0644.
+ int fd = open(db_path.c_str(), O_CREAT | O_RDWR,
+ S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+ if (fd < 0) {
+ files_in_use->erase(db_path);
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_FAILED_TO_OPEN_FILE);
+ return absl::InternalError(absl::StrCat("Failed to open file: ", db_path));
+ }
+ // Acquire exclusive lock on the file in a non-blocking mode.
+ // flock(2) applies lock on the file object in the open file table, so it can
+ // apply lock across different processes. Within a process, flock doesn't
+ // necessarily guarantee synchronization across multiple threads.
+ // See:https://man7.org/linux/man-pages/man2/flock.2.html
+ if (flock(fd, LOCK_EX | LOCK_NB) < 0) {
+ files_in_use->erase(db_path);
+ close(fd);
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_MULTIPLE_DB_INSTANCE_DETECTED);
+ return absl::InternalError(
+ "Failed to acquire file lock on the underlying database file.");
+ }
+ return fd;
+}
+
+void ReleaseFileLock(const std::string& db_path, int fd) {
+ absl::WriterMutexLock lock(&file_lock_mutex);
+ GetFilesInUseSet()->erase(db_path);
+ FCP_CHECK(fd >= 0);
+ // File lock is released when the descriptor is closed.
+ close(fd);
+}
+
+std::unique_ptr<OpStatsSequence> CreateEmptyData() {
+ auto empty_data = std::make_unique<OpStatsSequence>();
+ *(empty_data->mutable_earliest_trustworthy_time()) =
+ google::protobuf::util::TimeUtil::GetCurrentTime();
+ return empty_data;
+}
+
+// Returns the data in the db, or an error from the read operation.
+absl::StatusOr<OpStatsSequence> ReadInternal(
+ protostore::ProtoDataStore<OpStatsSequence>& db, LogManager& log_manager) {
+ absl::StatusOr<const OpStatsSequence*> data = db.Read();
+ if (data.ok()) {
+ return *data.value();
+ } else {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_READ_FAILED);
+ return absl::InternalError(
+ absl::StrCat("Failed to read from database, with error message: ",
+ data.status().message()));
+ }
+}
+
+// Overwrites the db to contain an empty OpStatsSequence message.
+absl::Status ResetInternal(protostore::ProtoDataStore<OpStatsSequence>& db,
+ LogManager& log_manager) {
+ absl::Status reset_status = db.Write(CreateEmptyData());
+ if (!reset_status.ok()) {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_RESET_FAILED);
+ return absl::InternalError(
+ absl::StrCat("Failed to reset the database, with error message: ",
+ reset_status.code()));
+ }
+ return absl::OkStatus();
+}
+
+absl::Time GetLastUpdateTime(const OperationalStats& operational_stats) {
+ if (operational_stats.events().empty()) {
+ return absl::InfinitePast();
+ }
+ return absl::FromUnixSeconds(TimeUtil::TimestampToSeconds(
+ operational_stats.events().rbegin()->timestamp()));
+}
+
+// If there's data, use the timestamp of the first event as the earliest
+// trustworthy time; otherwise, the current time will be used.
+::google::protobuf::Timestamp GetEarliestTrustWorthyTime(
+ const google::protobuf::RepeatedPtrField<OperationalStats>& op_stats) {
+ ::google::protobuf::Timestamp timestamp = TimeUtil::GetCurrentTime();
+ for (const auto& stat : op_stats) {
+ if (!stat.events().empty()) {
+ timestamp = stat.events().begin()->timestamp();
+ break;
+ }
+ }
+ return timestamp;
+}
+
+void RemoveOutdatedData(OpStatsSequence& data, absl::Duration ttl) {
+ absl::Time earliest_accepted_time = absl::Now() - ttl;
+ auto* op_stats = data.mutable_opstats();
+ int64_t original_num_entries = op_stats->size();
+ op_stats->erase(
+ std::remove_if(op_stats->begin(), op_stats->end(),
+ [earliest_accepted_time](const OperationalStats& data) {
+ return GetLastUpdateTime(data) < earliest_accepted_time;
+ }),
+ op_stats->end());
+ int64_t num_entries_after_purging = op_stats->size();
+ if (num_entries_after_purging < original_num_entries) {
+ *(data.mutable_earliest_trustworthy_time()) =
+ TimeUtil::MillisecondsToTimestamp(
+ absl::ToUnixMillis(earliest_accepted_time));
+ }
+}
+
+void PruneOldDataUntilBelowSizeLimit(OpStatsSequence& data,
+ const int64_t max_size_bytes,
+ LogManager& log_manager) {
+ int64_t current_size = data.ByteSizeLong();
+ auto& op_stats = *(data.mutable_opstats());
+ if (current_size > max_size_bytes) {
+ int64_t num_pruned_entries = 0;
+ auto it = op_stats.begin();
+ absl::Time earliest_event_time = absl::InfinitePast();
+ // The OperationalStats are sorted by time from earliest to latest, so we'll
+ // remove from the start.
+ while (current_size > max_size_bytes && it != op_stats.end()) {
+ if (earliest_event_time == absl::InfinitePast()) {
+ earliest_event_time = GetLastUpdateTime(*it);
+ }
+ num_pruned_entries++;
+ // Note that the size of an OperationalStats is smaller than the size
+ // impact it has on the OpStatsSequence. We are being conservative here.
+ current_size -= it->ByteSizeLong();
+ it++;
+ }
+ op_stats.erase(op_stats.begin(), it);
+ *data.mutable_earliest_trustworthy_time() =
+ GetEarliestTrustWorthyTime(op_stats);
+ log_manager.LogToLongHistogram(
+ HistogramCounters::OPSTATS_NUM_PRUNED_ENTRIES, num_pruned_entries);
+ log_manager.LogToLongHistogram(
+ HistogramCounters::OPSTATS_OLDEST_PRUNED_ENTRY_TENURE_HOURS,
+ absl::ToInt64Hours(absl::Now() - earliest_event_time));
+ }
+ log_manager.LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ current_size);
+ log_manager.LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ op_stats.size());
+}
+
+} // anonymous namespace
+
+absl::StatusOr<std::unique_ptr<OpStatsDb>> PdsBackedOpStatsDb::Create(
+ const std::string& base_dir, absl::Duration ttl, LogManager& log_manager,
+ int64_t max_size_bytes) {
+ std::filesystem::path path(base_dir);
+ if (!path.is_absolute()) {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_INVALID_FILE_PATH);
+ return absl::InvalidArgumentError(
+ absl::StrCat("The provided path: ", base_dir,
+ " is invalid. The path must start with \"/\""));
+ }
+ path /= kParentDir;
+ std::error_code error;
+ std::filesystem::create_directories(path, error);
+ if (error.value() != 0) {
+ log_manager.LogDiag(ProdDiagCode::OPSTATS_PARENT_DIR_CREATION_FAILED);
+ return absl::InternalError(
+ absl::StrCat("Failed to create directory ", path.generic_string()));
+ }
+ path /= kDbFileName;
+ std::function<void()> lock_releaser;
+ auto file_storage = std::make_unique<protostore::FileStorage>();
+ std::unique_ptr<protostore::ProtoDataStore<OpStatsSequence>> pds;
+ std::string db_path = path.generic_string();
+ FCP_ASSIGN_OR_RETURN(int fd, AcquireFileLock(db_path, log_manager));
+ lock_releaser = [db_path, fd]() { ReleaseFileLock(db_path, fd); };
+ pds = std::make_unique<protostore::ProtoDataStore<OpStatsSequence>>(
+ *file_storage, db_path);
+ absl::StatusOr<int64_t> file_size = file_storage->GetFileSize(path);
+ if (!file_size.ok()) {
+ lock_releaser();
+ return file_size.status();
+ }
+ // If the size of the underlying file is zero, it means this is the first
+ // time we create the database.
+ bool should_initiate = file_size.value() == 0;
+
+ // If this is the first time we create the OpStatsDb, we want to create an
+ // empty database.
+ if (should_initiate) {
+ absl::Status write_status = pds->Write(CreateEmptyData());
+ if (!write_status.ok()) {
+ lock_releaser();
+ return write_status;
+ }
+ }
+ return absl::WrapUnique(
+ new PdsBackedOpStatsDb(std::move(pds), std::move(file_storage), ttl,
+ log_manager, max_size_bytes, lock_releaser));
+}
+
+PdsBackedOpStatsDb::~PdsBackedOpStatsDb() { lock_releaser_(); }
+
+absl::StatusOr<OpStatsSequence> PdsBackedOpStatsDb::Read() {
+ absl::WriterMutexLock lock(&mutex_);
+ auto data_or = ReadInternal(*db_, log_manager_);
+ if (!data_or.ok()) {
+ // Try resetting after a failed read.
+ auto reset_status = ResetInternal(*db_, log_manager_);
+ }
+ return data_or;
+}
+
+absl::Status PdsBackedOpStatsDb::Transform(
+ std::function<void(OpStatsSequence&)> func) {
+ absl::WriterMutexLock lock(&mutex_);
+ OpStatsSequence data;
+ auto data_or = ReadInternal(*db_, log_manager_);
+ if (!data_or.ok()) {
+ // Try resetting after a failed read.
+ FCP_RETURN_IF_ERROR(ResetInternal(*db_, log_manager_));
+ } else {
+ data = std::move(data_or).value();
+ RemoveOutdatedData(data, ttl_);
+ }
+ func(data);
+ PruneOldDataUntilBelowSizeLimit(data, max_size_bytes_, log_manager_);
+ if (!data.has_earliest_trustworthy_time()) {
+ *data.mutable_earliest_trustworthy_time() =
+ GetEarliestTrustWorthyTime(data.opstats());
+ }
+ absl::Status status =
+ db_->Write(std::make_unique<OpStatsSequence>(std::move(data)));
+ if (!status.ok()) {
+ log_manager_.LogDiag(ProdDiagCode::OPSTATS_WRITE_FAILED);
+ }
+ return status;
+}
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/opstats/pds_backed_opstats_db.h b/fcp/client/opstats/pds_backed_opstats_db.h
new file mode 100644
index 0000000..0665080
--- /dev/null
+++ b/fcp/client/opstats/pds_backed_opstats_db.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_OPSTATS_PDS_BACKED_OPSTATS_DB_H_
+#define FCP_CLIENT_OPSTATS_PDS_BACKED_OPSTATS_DB_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/protos/opstats.pb.h"
+#include "protostore/file-storage.h"
+#include "protostore/proto-data-store.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+
+// An implementation of OpStatsDb based on protodatastore cpp.
+class PdsBackedOpStatsDb : public OpStatsDb {
+ public:
+ static constexpr char kParentDir[] = "fcp/opstats";
+ static constexpr char kDbFileName[] = "opstats.pb";
+
+ // Factory method to create PdsBackedOpStatsDb. The provided path is the
+ // absolute path for the base directory for storing files. OpStatsDb will
+ // attempt to create subdirectories and file, so the directory must grant
+ // read/write access. The ttl is the duration that an OperationalStats message
+ // is kept since its last update time.
+ static absl::StatusOr<std::unique_ptr<OpStatsDb>> Create(
+ const std::string& base_dir, absl::Duration ttl, LogManager& log_manager,
+ int64_t max_size_bytes);
+
+ ~PdsBackedOpStatsDb() override;
+
+ // Returns the data in the db, or an error from the read operation. If the
+ // read fails, will try to reset the db to be empty. The returned data is not
+ // necessarily restricted according to the ttl.
+ absl::StatusOr<OpStatsSequence> Read() override ABSL_LOCKS_EXCLUDED(mutex_);
+
+ // Modifies the data in the db based on the supplied transformation function
+ // and ttl restrictions. If there is an error fetching the existing data, the
+ // db is reset. No transformation is applied if the reset fails.
+ absl::Status Transform(std::function<void(OpStatsSequence&)> func) override
+ ABSL_LOCKS_EXCLUDED(mutex_);
+
+ private:
+ PdsBackedOpStatsDb(
+ std::unique_ptr<protostore::ProtoDataStore<OpStatsSequence>> db,
+ std::unique_ptr<protostore::FileStorage> file_storage, absl::Duration ttl,
+ LogManager& log_manager, int64_t max_size_bytes,
+ std::function<void()> lock_releaser)
+ : ttl_(std::move(ttl)),
+ db_(std::move(db)),
+ storage_(std::move(file_storage)),
+ log_manager_(log_manager),
+ max_size_bytes_(max_size_bytes),
+ lock_releaser_(lock_releaser) {}
+
+ const absl::Duration ttl_;
+ std::unique_ptr<protostore::ProtoDataStore<OpStatsSequence>> db_
+ ABSL_GUARDED_BY(mutex_);
+ std::unique_ptr<protostore::FileStorage> storage_;
+ LogManager& log_manager_;
+ const int64_t max_size_bytes_;
+ std::function<void()> lock_releaser_;
+ absl::Mutex mutex_;
+};
+
+} // namespace opstats
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_OPSTATS_PDS_BACKED_OPSTATS_DB_H_
diff --git a/fcp/client/opstats/pds_backed_opstats_db_test.cc b/fcp/client/opstats/pds_backed_opstats_db_test.cc
new file mode 100644
index 0000000..9768152
--- /dev/null
+++ b/fcp/client/opstats/pds_backed_opstats_db_test.cc
@@ -0,0 +1,519 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/client/opstats/pds_backed_opstats_db.h"
+
+#include <filesystem>
+#include <functional>
+#include <string>
+#include <thread> // NOLINT(build/c++11)
+#include <utility>
+
+#include "google/protobuf/util/time_util.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/protos/opstats.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace opstats {
+namespace {
+
+using ::google::protobuf::util::TimeUtil;
+using ::testing::Ge;
+using ::testing::Gt;
+
+const absl::Duration ttl = absl::Hours(24);
+const absl::Time benchmark_time = absl::Now();
+const int64_t benchmark_time_sec = absl::ToUnixSeconds(absl::Now());
+const int64_t size_limit = 1 * 1024 * 1024;
+
+class BasePdsBackedOpStatsDbTest {
+ protected:
+ void SetUpBaseDir() { base_dir_ = testing::TempDir(); }
+
+ void TearDownBaseDir() {
+ std::filesystem::remove(std::filesystem::path(base_dir_) /
+ PdsBackedOpStatsDb::kParentDir /
+ PdsBackedOpStatsDb::kDbFileName);
+ }
+
+ static OperationalStats_Event CreateEvent(
+ OperationalStats::Event::EventKind kind, int64_t time_sec) {
+ OperationalStats_Event event;
+ event.set_event_type(kind);
+ *event.mutable_timestamp() = TimeUtil::SecondsToTimestamp(time_sec);
+ return event;
+ }
+
+ static OperationalStats CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EventKind kind, int64_t time_sec) {
+ OperationalStats op_stats;
+ op_stats.mutable_events()->Add(CreateEvent(kind, time_sec));
+ return op_stats;
+ }
+
+ std::string base_dir_;
+ testing::StrictMock<MockLogManager> log_manager_;
+ absl::Mutex mu_;
+};
+
+class PdsBackedOpStatsDbTest : public BasePdsBackedOpStatsDbTest,
+ public testing::Test {
+ void SetUp() override { SetUpBaseDir(); }
+
+ void TearDown() override { TearDownBaseDir(); }
+};
+
+TEST_F(PdsBackedOpStatsDbTest, FailToCreateParentDirectory) {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_PARENT_DIR_CREATION_FAILED));
+ ASSERT_THAT(
+ PdsBackedOpStatsDb::Create("/proc/0", ttl, log_manager_, size_limit),
+ IsCode(INTERNAL));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, InvalidRelativePath) {
+ EXPECT_CALL(log_manager_, LogDiag(ProdDiagCode::OPSTATS_INVALID_FILE_PATH));
+ ASSERT_THAT(PdsBackedOpStatsDb::Create("relative/opstats", ttl, log_manager_,
+ size_limit),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, AddOpStats) {
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_THAT(db, IsOk());
+ OperationalStats op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED, benchmark_time_sec);
+ auto func = [op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = op_stats;
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)));
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1));
+ ASSERT_OK((*db)->Transform(func));
+ OpStatsSequence expected;
+ *expected.add_opstats() = op_stats;
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ ASSERT_TRUE(data->has_earliest_trustworthy_time());
+ data->clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, MutateOpStats) {
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_THAT(db, IsOk());
+ auto initialCommit = [](OpStatsSequence& data) {
+ *data.add_opstats() = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ benchmark_time_sec);
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1))
+ .Times(2);
+ ASSERT_OK((*db)->Transform(initialCommit));
+ auto mutate = [](OpStatsSequence& data) {
+ data.mutable_opstats(0)->mutable_events()->Add(
+ CreateEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED,
+ benchmark_time_sec));
+ };
+ ASSERT_OK((*db)->Transform(mutate));
+ OperationalStats expected_op_stats;
+ expected_op_stats.mutable_events()->Add(CreateEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED, benchmark_time_sec));
+ expected_op_stats.mutable_events()->Add(
+ CreateEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED,
+ benchmark_time_sec));
+ OpStatsSequence expected;
+ *expected.add_opstats() = expected_op_stats;
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ ASSERT_TRUE(data->has_earliest_trustworthy_time());
+ data->clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, LastUpdateTimeIsCorrectlyUsed) {
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_THAT(db, IsOk());
+ OperationalStats op_stats;
+ op_stats.mutable_events()->Add(
+ CreateEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(48))));
+ op_stats.mutable_events()->Add(
+ CreateEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED,
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(12))));
+ auto initialCommit = [op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = op_stats;
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1))
+ .Times(2);
+ ASSERT_OK((*db)->Transform(initialCommit));
+
+ // We do a second unity commit to trigger the ttl cleanup.
+ auto unityCommit = [](OpStatsSequence& data) {};
+ ASSERT_OK((*db)->Transform(unityCommit));
+ OpStatsSequence expected;
+ *expected.add_opstats() = op_stats;
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ ASSERT_TRUE(data->has_earliest_trustworthy_time());
+ data->clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, NoEventsOpStatsGotRemoved) {
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_THAT(db, IsOk());
+ OperationalStats op_stats;
+ op_stats.set_population_name("population");
+ auto initialCommit = [op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = op_stats;
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Ge(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1));
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/0));
+ ASSERT_OK((*db)->Transform(initialCommit));
+
+ // We do a second unity commit to trigger the ttl cleanup.
+ auto unityCommit = [](OpStatsSequence& data) {};
+ ASSERT_OK((*db)->Transform(unityCommit));
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ ASSERT_TRUE(data->has_earliest_trustworthy_time());
+ data->clear_earliest_trustworthy_time();
+ EXPECT_THAT(*data, EqualsProto(OpStatsSequence::default_instance()));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, TwoInstanceOnTwoThreadsAccessSameFile) {
+ EXPECT_CALL(log_manager_,
+ LogDiag(ProdDiagCode::OPSTATS_MULTIPLE_DB_INSTANCE_DETECTED));
+ std::vector<absl::StatusOr<std::unique_ptr<OpStatsDb>>> results;
+ std::function<void()> init = [&]() {
+ absl::WriterMutexLock lock(&mu_);
+ results.push_back(
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit));
+ };
+ std::thread first_thread(init);
+ std::thread second_thread(init);
+ first_thread.join();
+ second_thread.join();
+ std::set<absl::StatusCode> expected{absl::StatusCode::kOk,
+ absl::StatusCode::kInternal};
+ std::set<absl::StatusCode> status_codes;
+ for (const auto& result : results) {
+ status_codes.insert(result.status().code());
+ }
+ ASSERT_EQ(status_codes, expected);
+}
+
+TEST_F(PdsBackedOpStatsDbTest, TwoInstanceOnTwoThreadsAccessDifferentFile) {
+ std::vector<absl::StatusOr<std::unique_ptr<OpStatsDb>>> results;
+ std::function<void(std::string)> init = [&](std::string thread_id) {
+ absl::WriterMutexLock lock(&mu_);
+ results.push_back(
+ PdsBackedOpStatsDb::Create(absl::StrCat(base_dir_, "/", thread_id), ttl,
+ log_manager_, size_limit));
+ };
+ std::thread first_thread(init, "1");
+ std::thread second_thread(init, "2");
+ first_thread.join();
+ second_thread.join();
+ for (const auto& result : results) {
+ ASSERT_OK(result.status());
+ }
+}
+
+TEST_F(PdsBackedOpStatsDbTest, BackfillEarliestTrustWorthyTime) {
+ OperationalStats first_op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED,
+ benchmark_time_sec);
+ OperationalStats second_op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED,
+ benchmark_time_sec);
+ {
+ absl::StatusOr<std::unique_ptr<OpStatsDb>> db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_OK(db);
+ auto add = [first_op_stats, second_op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = first_op_stats;
+ *data.add_opstats() = second_op_stats;
+ };
+ auto remove_earliest_trustworthy_time = [](OpStatsSequence& data) {
+ data.clear_earliest_trustworthy_time();
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_, LogToLongHistogram(
+ HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/2))
+ .Times(2);
+ ASSERT_OK((*db)->Transform(add));
+ ASSERT_OK((*db)->Transform(remove_earliest_trustworthy_time));
+ }
+
+ absl::StatusOr<std::unique_ptr<OpStatsDb>> db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_OK(db);
+ OperationalStats third_op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ benchmark_time_sec + 10);
+ auto add_another = [third_op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = third_op_stats;
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)));
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/3));
+ ASSERT_OK((*db)->Transform(add_another));
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_OK(data);
+ OpStatsSequence expected;
+ *expected.mutable_earliest_trustworthy_time() =
+ TimeUtil::SecondsToTimestamp(benchmark_time_sec);
+ *expected.add_opstats() = first_op_stats;
+ *expected.add_opstats() = second_op_stats;
+ *expected.add_opstats() = third_op_stats;
+ EXPECT_THAT((*data), EqualsProto(expected));
+}
+
+TEST_F(PdsBackedOpStatsDbTest, ReadEmpty) {
+ ::google::protobuf::Timestamp before_creation_time =
+ TimeUtil::GetCurrentTime();
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ::google::protobuf::Timestamp after_creation_time =
+ TimeUtil::GetCurrentTime();
+ ASSERT_THAT(db, IsOk());
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ EXPECT_TRUE(data->opstats().empty());
+ EXPECT_TRUE(data->earliest_trustworthy_time() >= before_creation_time);
+ EXPECT_TRUE(data->earliest_trustworthy_time() <= after_creation_time);
+}
+
+TEST_F(PdsBackedOpStatsDbTest, RemoveOpstatsDueToTtl) {
+ auto db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit);
+ ASSERT_THAT(db, IsOk());
+ OperationalStats op_stats_remove = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(25)));
+ OperationalStats op_stats_keep = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(23)));
+ auto initialCommit = [op_stats_remove, op_stats_keep](OpStatsSequence& data) {
+ *data.add_opstats() = op_stats_remove;
+ *data.add_opstats() = op_stats_keep;
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/2));
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1));
+ ASSERT_OK((*db)->Transform(initialCommit));
+
+ // We do a second unity commit to trigger the ttl cleanup.
+ auto unityCommit = [](OpStatsSequence& data) {};
+ ASSERT_OK((*db)->Transform(unityCommit));
+
+ absl::StatusOr<OpStatsSequence> data = (*db)->Read();
+ ASSERT_THAT(data, IsOk());
+ ASSERT_EQ(data->opstats().size(), 1);
+ ASSERT_THAT(data->opstats()[0], EqualsProto(op_stats_keep));
+ // The TTL is 24 hours, the timestamp should be set to the time when the db
+ // got purged - 24 hours. It should be smaller than the kept
+ // OperationalStats, but larger than benchmark time - 24 hours.
+ google::protobuf::Timestamp lower_bound = TimeUtil::SecondsToTimestamp(
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(24)));
+ google::protobuf::Timestamp upper_bound = TimeUtil::SecondsToTimestamp(
+ absl::ToUnixSeconds(benchmark_time - absl::Hours(23)));
+ EXPECT_TRUE(data->earliest_trustworthy_time() >= lower_bound);
+ EXPECT_TRUE(data->earliest_trustworthy_time() <= upper_bound);
+}
+
+TEST_F(PdsBackedOpStatsDbTest, CorruptedFile) {
+ {
+ std::unique_ptr<OpStatsDb> db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit)
+ .value();
+ auto func = [](OpStatsSequence& data) {
+ *data.add_opstats() = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ benchmark_time_sec);
+ };
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)));
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1));
+ ASSERT_OK(db->Transform(func));
+ }
+
+ {
+ std::filesystem::path db_path(base_dir_);
+ db_path /= PdsBackedOpStatsDb::kParentDir;
+ db_path /= PdsBackedOpStatsDb::kDbFileName;
+ protostore::FileStorage file_storage;
+ std::unique_ptr<protostore::OutputStream> ostream =
+ file_storage.OpenForWrite(db_path).value();
+ ASSERT_OK(ostream->Append("not a proto"));
+ ASSERT_OK(ostream->Close());
+ }
+
+ std::unique_ptr<OpStatsDb> db =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, size_limit)
+ .value();
+ EXPECT_CALL(log_manager_, LogDiag(ProdDiagCode::OPSTATS_READ_FAILED));
+ ::google::protobuf::Timestamp before_read_time = TimeUtil::GetCurrentTime();
+ ASSERT_THAT(db->Read(), IsCode(INTERNAL));
+ ::google::protobuf::Timestamp after_read_time = TimeUtil::GetCurrentTime();
+
+ // Second read should succeed, and return empty data.
+ absl::StatusOr<OpStatsSequence> data = db->Read();
+ ASSERT_THAT(data, IsOk());
+ EXPECT_TRUE(data->opstats().empty());
+ EXPECT_TRUE(data->earliest_trustworthy_time() >= before_read_time);
+ EXPECT_TRUE(data->earliest_trustworthy_time() <= after_read_time);
+}
+
+TEST_F(PdsBackedOpStatsDbTest, OpStatsRemovedDueToSizeLimit) {
+ // Set size limit to 18, which allow a single OperationalStats with a single
+ // event (12 bytes for OperationalStats, 14 bytes when it is wrapped inside
+ // an OpStatsSequence). If record_earliest_trustworthy_time is true, we'll
+ // increase the size limit to 30 bytes to accommodate the timestamp.
+ int64_t max_size_bytes = 30;
+ absl::StatusOr<std::unique_ptr<OpStatsDb>> db_status =
+ PdsBackedOpStatsDb::Create(base_dir_, ttl, log_manager_, max_size_bytes);
+ ASSERT_THAT(db_status, IsOk());
+ std::unique_ptr<OpStatsDb> db = std::move(db_status.value());
+ EXPECT_CALL(
+ log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_SIZE_BYTES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Gt(0)))
+ .Times(2);
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_DB_NUM_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1))
+ .Times(2);
+ OperationalStats op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED, benchmark_time_sec);
+ auto initial_commit = [op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = op_stats;
+ };
+ ASSERT_OK(db->Transform(initial_commit));
+
+ // Add the second event, which will pushes the database size over the limit.
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(HistogramCounters::OPSTATS_NUM_PRUNED_ENTRIES,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/1));
+ EXPECT_CALL(log_manager_,
+ LogToLongHistogram(
+ HistogramCounters::OPSTATS_OLDEST_PRUNED_ENTRY_TENURE_HOURS,
+ /*execution_index=*/0, /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, /*value=*/Ge(0)));
+ OperationalStats another_op_stats = CreateOperationalStatsWithSingleEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED,
+ benchmark_time_sec + 5);
+ auto add = [another_op_stats](OpStatsSequence& data) {
+ *data.add_opstats() = another_op_stats;
+ };
+ ASSERT_OK(db->Transform(add));
+
+ // Verify the first event doesn't exist in the database.
+ OpStatsSequence expected;
+ *expected.add_opstats() = another_op_stats;
+ *expected.mutable_earliest_trustworthy_time() =
+ TimeUtil::SecondsToTimestamp(benchmark_time_sec + 5);
+
+ absl::StatusOr<OpStatsSequence> data = db->Read();
+ ASSERT_THAT(data, IsOk());
+ EXPECT_THAT(*data, EqualsProto(expected));
+}
+
+} // anonymous namespace
+} // namespace opstats
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/parsing_utils.h b/fcp/client/parsing_utils.h
new file mode 100644
index 0000000..c2acb40
--- /dev/null
+++ b/fcp/client/parsing_utils.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_PARSING_UTILS_H_
+#define FCP_CLIENT_PARSING_UTILS_H_
+
+#include <string>
+#include <variant>
+
+#include "absl/strings/cord.h"
+
+namespace fcp {
+namespace client {
+
+// Parses a proto from either an std::string or an absl::Cord. This allows the
+// proto data to be provided in either format.
+template <typename MessageT>
+bool ParseFromStringOrCord(MessageT& proto,
+ std::variant<std::string, absl::Cord> data) {
+ if (std::holds_alternative<std::string>(data)) {
+ return proto.ParseFromString(std::get<std::string>(data));
+ } else {
+ return proto.ParseFromString(std::string(std::get<absl::Cord>(data)));
+ }
+}
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_PARSING_UTILS_H_
diff --git a/fcp/client/phase_logger.h b/fcp/client/phase_logger.h
new file mode 100644
index 0000000..70125bb
--- /dev/null
+++ b/fcp/client/phase_logger.h
@@ -0,0 +1,222 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_PHASE_LOGGER_H_
+#define FCP_CLIENT_PHASE_LOGGER_H_
+
+#include "absl/strings/string_view.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/stats.h"
+#include "fcp/protos/federated_api.pb.h"
+
+namespace fcp {
+namespace client {
+
+class PhaseLogger {
+ public:
+ virtual ~PhaseLogger() = default;
+ virtual void UpdateRetryWindowAndNetworkStats(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window,
+ const NetworkStats& network_stats) = 0;
+ virtual void SetModelIdentifier(absl::string_view model_identifier) = 0;
+
+ // Called when a run was started but immediately aborted.
+ virtual void LogTaskNotStarted(absl::string_view error_message) = 0;
+ // Called when a run was started but the runtime failed to initialize a
+ // noncritical component, and execution continue.
+ virtual void LogNonfatalInitializationError(absl::Status error_status) = 0;
+ // Called when a run was started but the runtime failed to initialize a
+ // component, and execution was halted.
+ virtual void LogFatalInitializationError(absl::Status error_status) = 0;
+
+ // Eligibility eval check-in phase.
+ // Called when an eligibility eval check-in starts.
+ virtual void LogEligibilityEvalCheckinStarted() = 0;
+ // Called when an IO error is encountered during eligibility eval check-in.
+ virtual void LogEligibilityEvalCheckinIOError(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) = 0;
+ // Called when an invalid payload is received from the eligibility eval
+ // check-in result.
+ virtual void LogEligibilityEvalCheckinInvalidPayloadError(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) = 0;
+ // Called when the eligibility eval check-in is interrupted by the client.
+ virtual void LogEligibilityEvalCheckinClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) = 0;
+ // Called when the eligibility eval check-in is aborted by the server.
+ virtual void LogEligibilityEvalCheckinServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) = 0;
+ // Called when eligibility eval is not configured.
+ virtual void LogEligibilityEvalNotConfigured(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) = 0;
+ // Called when eligibility eval check-in request is turned away by the server.
+ virtual void LogEligibilityEvalCheckinTurnedAway(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) = 0;
+ virtual void LogEligibilityEvalCheckinPlanUriReceived(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) = 0;
+ // Called when a valid eligibility eval plan is received.
+ virtual void LogEligibilityEvalCheckinCompleted(
+ const NetworkStats& network_stats, absl::Time time_before_checkin,
+ absl::Time time_before_plan_download) = 0;
+
+ // Eligibility eval computation phase.
+ // Called when the eligibility eval computation starts.
+ virtual void LogEligibilityEvalComputationStarted() = 0;
+ // Called when the input parameters for the eligibility eval task are invalid.
+ virtual void LogEligibilityEvalComputationInvalidArgument(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) = 0;
+ // Called when an example store error happened during eligibility eval
+ // computation.
+ virtual void LogEligibilityEvalComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) = 0;
+ // Called when a tensorflow error happened during eligibiliity eval
+ // computation.
+ virtual void LogEligibilityEvalComputationTensorflowError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) = 0;
+ // Called when the eligibility eval computation is interrupted.
+ virtual void LogEligibilityEvalComputationInterrupted(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) = 0;
+ // Called when the eligibility eval computation is completed.
+ virtual void LogEligibilityEvalComputationCompleted(
+ const ExampleStats& example_stats, absl::Time run_plan_start_time,
+ absl::Time reference_time) = 0;
+
+ // Check-in phase.
+ // Called when a regular check-in starts.
+ virtual void LogCheckinStarted() = 0;
+ // Called when an IO error occurred during check-in.
+ virtual void LogCheckinIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) = 0;
+ // Called when an invalid payload is received from the check-in result.
+ virtual void LogCheckinInvalidPayload(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) = 0;
+ // Called when check-in is interrupted by the client.
+ virtual void LogCheckinClientInterrupted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) = 0;
+ // Called when check-in is aborted by the server.
+ virtual void LogCheckinServerAborted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) = 0;
+ // Called when the client's check-in request is turned away by the server.
+ virtual void LogCheckinTurnedAway(const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) = 0;
+ virtual void LogCheckinPlanUriReceived(absl::string_view task_name,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin) = 0;
+ // Called when check-in is completed.
+ virtual void LogCheckinCompleted(absl::string_view task_name,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download,
+ absl::Time reference_time) = 0;
+
+ // Computation phase.
+ // Called when computation started.
+ virtual void LogComputationStarted() = 0;
+ // Called when the input parameters are invalid.
+ virtual void LogComputationInvalidArgument(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time) = 0;
+ // Called when an example store error occurred during computation.
+ virtual void LogComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time) = 0;
+ // Called when an IO error happened during computation
+ virtual void LogComputationIOError(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time) = 0;
+ // Called when a tensorflow error happened during computation.
+ virtual void LogComputationTensorflowError(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) = 0;
+ // Called when computation is interrupted.
+ virtual void LogComputationInterrupted(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) = 0;
+ // Called when computation is completed.
+ virtual void LogComputationCompleted(const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) = 0;
+
+ // Result upload phase. Result upload only happens when all the previous
+ // phases succeed.
+ // Called when result upload started.
+ virtual absl::Status LogResultUploadStarted() = 0;
+ // Called when an IO error occurred during result upload.
+ virtual void LogResultUploadIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) = 0;
+ // Called when the result upload is interrupted by the client.
+ virtual void LogResultUploadClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time) = 0;
+ // Called when the result upload is aborted by the server.
+ virtual void LogResultUploadServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time) = 0;
+ // Called when result upload is completed.
+ virtual void LogResultUploadCompleted(const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) = 0;
+
+ // Failure upload phase. Failure upload only happens when any of the previous
+ // phases failed.
+ // Called when failure upload starts.
+ virtual absl::Status LogFailureUploadStarted() = 0;
+ // Called when an IO error occurred during failure upload.
+ virtual void LogFailureUploadIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) = 0;
+ // Called when the failure upload is interrupted by the client.
+ virtual void LogFailureUploadClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload, absl::Time reference_time) = 0;
+ // Called when the failure upload is aborted by the server.
+ virtual void LogFailureUploadServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload, absl::Time reference_time) = 0;
+ // Called when the failure upload is completed.
+ virtual void LogFailureUploadCompleted(const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) = 0;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_PHASE_LOGGER_H_
diff --git a/fcp/client/phase_logger_impl.cc b/fcp/client/phase_logger_impl.cc
new file mode 100644
index 0000000..ecd7e61
--- /dev/null
+++ b/fcp/client/phase_logger_impl.cc
@@ -0,0 +1,638 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/phase_logger_impl.h"
+
+#include <string>
+
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace client {
+namespace {
+constexpr absl::string_view kInitializationErrorPrefix =
+ "Error during initialization: ";
+constexpr absl::string_view kEligibilityCheckinErrorPrefix =
+ "Error during eligibility check-in: ";
+constexpr absl::string_view kEligibilityComputationErrorPrefix =
+ "Error during eligibility eval computation: ";
+constexpr absl::string_view kCheckinErrorPrefix = "Error during check-in: ";
+constexpr absl::string_view kComputationErrorPrefix =
+ "Error during computation: ";
+constexpr absl::string_view kResultUploadErrorPrefix =
+ "Error reporting results: ";
+constexpr absl::string_view kFailureUploadErrorPrefix =
+ "Error reporting computation failure: ";
+} // anonymous namespace
+
+using ::fcp::client::opstats::OperationalStats;
+using ::google::internal::federatedml::v2::RetryWindow;
+
+void PhaseLoggerImpl::UpdateRetryWindowAndNetworkStats(
+ const RetryWindow& retry_window, const NetworkStats& network_stats) {
+ opstats_logger_->SetRetryWindow(retry_window);
+
+ // Update the network stats.
+ opstats_logger_->SetNetworkStats(network_stats);
+}
+
+void PhaseLoggerImpl::SetModelIdentifier(absl::string_view model_identifier) {
+ event_publisher_->SetModelIdentifier(std::string(model_identifier));
+ log_manager_->SetModelIdentifier(std::string(model_identifier));
+}
+
+void PhaseLoggerImpl::LogTaskNotStarted(absl::string_view error_message) {
+ event_publisher_->PublishTaskNotStarted(error_message);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED,
+ std::string(error_message));
+}
+
+void PhaseLoggerImpl::LogNonfatalInitializationError(
+ absl::Status error_status) {
+ std::string error_message = GetErrorMessage(
+ error_status, kInitializationErrorPrefix, /* keep_error_message= */ true);
+ event_publisher_->PublishNonfatalInitializationError(error_message);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_INITIALIZATION_ERROR_FATAL,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogFatalInitializationError(absl::Status error_status) {
+ std::string error_message = GetErrorMessage(
+ error_status, kInitializationErrorPrefix, /* keep_error_message= */ true);
+ event_publisher_->PublishFatalInitializationError(error_message);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_INITIALIZATION_ERROR_NONFATAL,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinStarted() {
+ event_publisher_->PublishEligibilityEvalCheckin();
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinIOError(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishEligibilityEvalCheckinIoError(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_IO,
+ error_message);
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishEligibilityEvalCheckinClientInterrupted(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_CHECKIN_CLIENT_INTERRUPTED,
+ error_message);
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishEligibilityEvalCheckinServerAborted(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_SERVER_ABORTED,
+ error_message);
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalNotConfigured(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) {
+ event_publisher_->PublishEligibilityEvalNotConfigured(
+ network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_DISABLED);
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinTurnedAway(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) {
+ event_publisher_->PublishEligibilityEvalRejected(
+ network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED);
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinInvalidPayloadError(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) {
+ log_manager_->LogDiag(
+ ProdDiagCode::
+ BACKGROUND_TRAINING_ELIGIBILITY_EVAL_FAILED_CANNOT_PARSE_PLAN);
+ event_publisher_->PublishEligibilityEvalCheckinErrorInvalidPayload(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_INVALID_PAYLOAD,
+ std::string(error_message));
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinPlanUriReceived(
+ const NetworkStats& network_stats, absl::Time time_before_checkin) {
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_PLAN_URI_RECEIVED);
+ event_publisher_->PublishEligibilityEvalPlanUriReceived(
+ network_stats, absl::Now() - time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinCompleted(
+ const NetworkStats& network_stats, absl::Time time_before_checkin,
+ absl::Time time_before_plan_download) {
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_ENABLED);
+ absl::Time before_time = time_before_plan_download;
+ event_publisher_->PublishEligibilityEvalPlanReceived(
+ network_stats, absl::Now() - before_time);
+
+ // The 'EligibilityEvalCheckinLatency' should cover the whole period from
+ // eligibility eval checkin to completion (and not just the period from EET
+ // plan URIs being received to completion).
+ LogEligibilityEvalCheckinLatency(time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationStarted() {
+ event_publisher_->PublishEligibilityEvalComputationStarted();
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationInvalidArgument(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ log_manager_->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_FAILED_PLAN_FAILS_SANITY_CHECK);
+ event_publisher_->PublishEligibilityEvalComputationInvalidArgument(
+ error_message, example_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_INVALID_ARGUMENT,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishEligibilityEvalComputationExampleIteratorError(
+ error_message, example_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_EXAMPLE_ITERATOR,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationTensorflowError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityComputationErrorPrefix,
+ log_tensorflow_error_messages_);
+ event_publisher_->PublishEligibilityEvalComputationTensorflowError(
+ error_message, example_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_TENSORFLOW,
+ error_message);
+ LogEligibilityEvalComputationLatency(run_plan_start_time, reference_time);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationInterrupted(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kEligibilityComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishEligibilityEvalComputationInterrupted(
+ error_message, example_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_CLIENT_INTERRUPTED,
+ error_message);
+ LogEligibilityEvalComputationLatency(run_plan_start_time, reference_time);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationCompleted(
+ const ExampleStats& example_stats, absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ event_publisher_->PublishEligibilityEvalComputationCompleted(
+ example_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_COMPUTATION_FINISHED);
+ log_manager_->LogToLongHistogram(
+ HistogramCounters::TRAINING_OVERALL_EXAMPLE_SIZE,
+ example_stats.example_size_bytes);
+ log_manager_->LogToLongHistogram(
+ HistogramCounters::TRAINING_OVERALL_EXAMPLE_COUNT,
+ example_stats.example_count);
+ LogEligibilityEvalComputationLatency(run_plan_start_time, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinStarted() {
+ // Log that we are about to check in with the server.
+ event_publisher_->PublishCheckin();
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED);
+}
+
+void PhaseLoggerImpl::LogCheckinIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) {
+ std::string error_message = GetErrorMessage(error_status, kCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishCheckinIoError(error_message, network_stats,
+ absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ERROR_IO, error_message);
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time) {
+ std::string error_message = GetErrorMessage(error_status, kCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishCheckinClientInterrupted(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_CLIENT_INTERRUPTED,
+ error_message);
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinServerAborted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) {
+ std::string error_message = GetErrorMessage(error_status, kCheckinErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishCheckinServerAborted(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_SERVER_ABORTED,
+ error_message);
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinTurnedAway(const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) {
+ event_publisher_->PublishRejected(network_stats,
+ absl::Now() - time_before_checkin);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_REJECTED);
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinInvalidPayload(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time) {
+ log_manager_->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_FAILED_CANNOT_PARSE_PLAN);
+ event_publisher_->PublishCheckinInvalidPayload(
+ error_message, network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ERROR_INVALID_PAYLOAD,
+ std::string(error_message));
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinPlanUriReceived(
+ absl::string_view task_name, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) {
+ event_publisher_->PublishCheckinPlanUriReceived(
+ network_stats, absl::Now() - time_before_checkin);
+ opstats_logger_->AddEventAndSetTaskName(
+ std::string(task_name),
+ OperationalStats::Event::EVENT_KIND_CHECKIN_PLAN_URI_RECEIVED);
+}
+
+void PhaseLoggerImpl::LogCheckinCompleted(absl::string_view task_name,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download,
+ absl::Time reference_time) {
+ absl::Duration duration = absl::Now() - time_before_plan_download;
+ event_publisher_->PublishCheckinFinishedV2(network_stats, duration);
+ // We already have set the task name when LogCheckinPlanUriReceived was
+ // called, so we only have to add the event.
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED);
+ // The 'EligibilityEvalCheckinLatency' should cover the whole period from
+ // eligibility eval checkin to completion (and not just the period from EET
+ // plan URIs being received to completion).
+ LogCheckinLatency(time_before_checkin, reference_time);
+}
+
+void PhaseLoggerImpl::LogComputationStarted() {
+ event_publisher_->PublishComputationStarted();
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED);
+}
+
+void PhaseLoggerImpl::LogComputationInvalidArgument(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ log_manager_->LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_FAILED_PLAN_FAILS_SANITY_CHECK);
+ event_publisher_->PublishComputationInvalidArgument(
+ error_message, example_stats, network_stats,
+ absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_INVALID_ARGUMENT,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogComputationIOError(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishComputationIOError(
+ error_message, example_stats, network_stats,
+ absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_IO, error_message);
+}
+
+void PhaseLoggerImpl::LogComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time) {
+ std::string error_message = GetErrorMessage(
+ error_status, kComputationErrorPrefix, /* keep_error_message= */ true);
+ event_publisher_->PublishComputationExampleIteratorError(
+ error_message, example_stats, network_stats,
+ absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_EXAMPLE_ITERATOR,
+ error_message);
+}
+
+void PhaseLoggerImpl::LogComputationTensorflowError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ std::string error_message = GetErrorMessage(
+ error_status, kComputationErrorPrefix, log_tensorflow_error_messages_);
+ event_publisher_->PublishComputationTensorflowError(
+ error_message, example_stats, network_stats,
+ absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_TENSORFLOW,
+ error_message);
+ LogComputationLatency(run_plan_start_time, reference_time);
+}
+
+void PhaseLoggerImpl::LogComputationInterrupted(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats, absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kComputationErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishComputationInterrupted(
+ error_message, example_stats, network_stats,
+ absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_CLIENT_INTERRUPTED,
+ error_message);
+ LogComputationLatency(run_plan_start_time, reference_time);
+}
+
+void PhaseLoggerImpl::LogComputationCompleted(const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ event_publisher_->PublishComputationCompleted(
+ example_stats, network_stats, absl::Now() - run_plan_start_time);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED);
+ log_manager_->LogToLongHistogram(
+ HistogramCounters::TRAINING_OVERALL_EXAMPLE_SIZE,
+ example_stats.example_size_bytes);
+ log_manager_->LogToLongHistogram(
+ HistogramCounters::TRAINING_OVERALL_EXAMPLE_COUNT,
+ example_stats.example_count);
+ LogComputationLatency(run_plan_start_time, reference_time);
+}
+
+absl::Status PhaseLoggerImpl::LogResultUploadStarted() {
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED);
+ // Commit the run data accumulated thus far to Opstats and fail if
+ // something goes wrong.
+ FCP_RETURN_IF_ERROR(opstats_logger_->CommitToStorage());
+ event_publisher_->PublishResultUploadStarted();
+ return absl::OkStatus();
+}
+
+void PhaseLoggerImpl::LogResultUploadIOError(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kResultUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishResultUploadIOError(
+ error_message, network_stats, absl::Now() - time_before_result_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_ERROR_IO,
+ error_message);
+ LogReportLatency(time_before_result_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogResultUploadClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kResultUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishResultUploadClientInterrupted(
+ error_message, network_stats, absl::Now() - time_before_result_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_CLIENT_INTERRUPTED,
+ error_message);
+ LogReportLatency(time_before_result_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogResultUploadServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kResultUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishResultUploadServerAborted(
+ error_message, network_stats, absl::Now() - time_before_result_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED,
+ error_message);
+ LogReportLatency(time_before_result_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogResultUploadCompleted(
+ const NetworkStats& network_stats, absl::Time time_before_result_upload,
+ absl::Time reference_time) {
+ event_publisher_->PublishResultUploadCompleted(
+ network_stats, absl::Now() - time_before_result_upload);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_FINISHED);
+ LogReportLatency(time_before_result_upload, reference_time);
+}
+
+absl::Status PhaseLoggerImpl::LogFailureUploadStarted() {
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_STARTED);
+ // Commit the run data accumulated thus far to Opstats and fail if
+ // something goes wrong.
+ FCP_RETURN_IF_ERROR(opstats_logger_->CommitToStorage());
+ event_publisher_->PublishFailureUploadStarted();
+ return absl::OkStatus();
+}
+
+void PhaseLoggerImpl::LogFailureUploadIOError(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kFailureUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishFailureUploadIOError(
+ error_message, network_stats, absl::Now() - time_before_failure_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_ERROR_IO,
+ error_message);
+ LogReportLatency(time_before_failure_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogFailureUploadClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kFailureUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishFailureUploadClientInterrupted(
+ error_message, network_stats, absl::Now() - time_before_failure_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_CLIENT_INTERRUPTED,
+ error_message);
+ LogReportLatency(time_before_failure_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogFailureUploadServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload, absl::Time reference_time) {
+ std::string error_message =
+ GetErrorMessage(error_status, kFailureUploadErrorPrefix,
+ /* keep_error_message= */ true);
+ event_publisher_->PublishFailureUploadServerAborted(
+ error_message, network_stats, absl::Now() - time_before_failure_upload);
+ opstats_logger_->AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_SERVER_ABORTED,
+ error_message);
+ LogReportLatency(time_before_failure_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogFailureUploadCompleted(
+ const NetworkStats& network_stats, absl::Time time_before_failure_upload,
+ absl::Time reference_time) {
+ event_publisher_->PublishFailureUploadCompleted(
+ network_stats, absl::Now() - time_before_failure_upload);
+ opstats_logger_->AddEvent(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_FINISHED);
+ LogReportLatency(time_before_failure_upload, reference_time);
+}
+
+void PhaseLoggerImpl::LogTimeSince(HistogramCounters histogram_counter,
+ absl::Time reference_time) {
+ absl::Duration duration = absl::Now() - reference_time;
+ log_manager_->LogToLongHistogram(histogram_counter,
+ absl::ToInt64Milliseconds(duration));
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalCheckinLatency(
+ absl::Time time_before_checkin) {
+ LogTimeSince(HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY,
+ time_before_checkin);
+}
+
+void PhaseLoggerImpl::LogEligibilityEvalComputationLatency(
+ absl::Time run_plan_start_time, absl::Time reference_time) {
+ LogTimeSince(HistogramCounters::TRAINING_RUN_PHASE_LATENCY,
+ run_plan_start_time);
+ LogTimeSince(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, reference_time);
+}
+
+void PhaseLoggerImpl::LogCheckinLatency(absl::Time time_before_checkin,
+ absl::Time reference_time) {
+ LogTimeSince(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY,
+ time_before_checkin);
+ LogTimeSince(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, reference_time);
+}
+
+void PhaseLoggerImpl::LogComputationLatency(absl::Time run_plan_start_time,
+ absl::Time reference_time) {
+ LogTimeSince(HistogramCounters::TRAINING_RUN_PHASE_LATENCY,
+ run_plan_start_time);
+ LogTimeSince(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, reference_time);
+}
+
+void PhaseLoggerImpl::LogReportLatency(absl::Time time_before_report,
+ absl::Time reference_time) {
+ LogTimeSince(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ time_before_report);
+ LogTimeSince(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ reference_time);
+}
+
+std::string PhaseLoggerImpl::GetErrorMessage(absl::Status error_status,
+ absl::string_view error_prefix,
+ bool keep_error_message) {
+ return absl::StrCat(error_prefix, "code: ", error_status.code(), ", error: ",
+ keep_error_message ? error_status.message() : "");
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/phase_logger_impl.h b/fcp/client/phase_logger_impl.h
new file mode 100644
index 0000000..3cff41c
--- /dev/null
+++ b/fcp/client/phase_logger_impl.h
@@ -0,0 +1,214 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_PHASE_LOGGER_IMPL_H_
+#define FCP_CLIENT_PHASE_LOGGER_IMPL_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/phase_logger.h"
+#include "fcp/protos/federated_api.pb.h"
+
+namespace fcp {
+namespace client {
+
+class PhaseLoggerImpl : public PhaseLogger {
+ public:
+ PhaseLoggerImpl(EventPublisher* event_publisher,
+ opstats::OpStatsLogger* opstats_logger,
+ LogManager* log_manager, const Flags* flags)
+ : event_publisher_(event_publisher),
+ opstats_logger_(opstats_logger),
+ log_manager_(log_manager),
+ log_tensorflow_error_messages_(flags->log_tensorflow_error_messages()) {
+ }
+
+ void UpdateRetryWindowAndNetworkStats(
+ const ::google::internal::federatedml::v2::RetryWindow& retry_window,
+ const NetworkStats& network_stats) override;
+ void SetModelIdentifier(absl::string_view model_identifier) override;
+ void LogTaskNotStarted(absl::string_view error_message) override;
+ void LogNonfatalInitializationError(absl::Status error_status) override;
+ void LogFatalInitializationError(absl::Status error_status) override;
+
+ // Eligibility eval check-in phase.
+ void LogEligibilityEvalCheckinStarted() override;
+ void LogEligibilityEvalCheckinIOError(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinInvalidPayloadError(
+ absl::string_view error_message, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinClientInterrupted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinServerAborted(
+ absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalNotConfigured(const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinTurnedAway(
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinPlanUriReceived(
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogEligibilityEvalCheckinCompleted(
+ const NetworkStats& network_stats, absl::Time time_before_checkin,
+ absl::Time time_before_plan_download) override;
+
+ // Eligibility eval computation phase.
+ void LogEligibilityEvalComputationStarted() override;
+ void LogEligibilityEvalComputationInvalidArgument(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) override;
+ void LogEligibilityEvalComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time) override;
+ void LogEligibilityEvalComputationTensorflowError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) override;
+ void LogEligibilityEvalComputationInterrupted(
+ absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time) override;
+ void LogEligibilityEvalComputationCompleted(
+ const ExampleStats& example_stats, absl::Time run_plan_start_time,
+ absl::Time reference_time) override;
+
+ // Check-in phase.
+ void LogCheckinStarted() override;
+ void LogCheckinIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) override;
+ void LogCheckinInvalidPayload(absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) override;
+ void LogCheckinClientInterrupted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) override;
+ void LogCheckinServerAborted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) override;
+ void LogCheckinTurnedAway(const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time reference_time) override;
+ void LogCheckinPlanUriReceived(absl::string_view task_name,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin) override;
+ void LogCheckinCompleted(absl::string_view task_name,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download,
+ absl::Time reference_time) override;
+
+ // Computation phase.
+ void LogComputationStarted() override;
+ void LogComputationInvalidArgument(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time) override;
+ void LogComputationExampleIteratorError(
+ absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time) override;
+ void LogComputationIOError(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time) override;
+ void LogComputationTensorflowError(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) override;
+ void LogComputationInterrupted(absl::Status error_status,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) override;
+ void LogComputationCompleted(const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time,
+ absl::Time reference_time) override;
+
+ absl::Status LogResultUploadStarted() override;
+ void LogResultUploadIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) override;
+ void LogResultUploadClientInterrupted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) override;
+ void LogResultUploadServerAborted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) override;
+ void LogResultUploadCompleted(const NetworkStats& network_stats,
+ absl::Time time_before_result_upload,
+ absl::Time reference_time) override;
+
+ // Failure upload phase.
+ absl::Status LogFailureUploadStarted() override;
+ void LogFailureUploadIOError(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) override;
+ void LogFailureUploadClientInterrupted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) override;
+ void LogFailureUploadServerAborted(absl::Status error_status,
+ const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) override;
+ void LogFailureUploadCompleted(const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time) override;
+
+ private:
+ void LogTimeSince(HistogramCounters histogram_counter,
+ absl::Time reference_time);
+ void LogEligibilityEvalCheckinLatency(absl::Time time_before_checkin);
+ void LogEligibilityEvalComputationLatency(absl::Time run_plan_start_time,
+ absl::Time reference_time);
+ void LogCheckinLatency(absl::Time time_before_checkin,
+ absl::Time reference_time);
+ void LogComputationLatency(absl::Time run_plan_start_time,
+ absl::Time reference_time);
+ void LogReportLatency(absl::Time time_before_report,
+ absl::Time reference_time);
+ std::string GetErrorMessage(absl::Status error_status,
+ absl::string_view error_prefix,
+ bool keep_error_message);
+
+ EventPublisher* event_publisher_;
+ opstats::OpStatsLogger* opstats_logger_;
+ LogManager* log_manager_;
+ const bool log_tensorflow_error_messages_;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_PHASE_LOGGER_IMPL_H_
diff --git a/fcp/client/phase_logger_impl_test.cc b/fcp/client/phase_logger_impl_test.cc
new file mode 100644
index 0000000..f4712d0
--- /dev/null
+++ b/fcp/client/phase_logger_impl_test.cc
@@ -0,0 +1,916 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/phase_logger_impl.h"
+
+#include <string>
+#include <tuple>
+
+#include "google/protobuf/util/time_util.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace {
+
+using ::fcp::client::opstats::OperationalStats;
+using ::google::internal::federatedml::v2::RetryWindow;
+using ::google::protobuf::util::TimeUtil;
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::InSequence;
+using ::testing::Lt;
+using ::testing::Return;
+using ::testing::StrictMock;
+
+const int64_t kChunkingLayerBytesReceived = 100;
+const int64_t kChunkingLayerBytesSent = 50;
+const int kTotalExampleCount = 10;
+const int64_t kTotalExampleSizeBytes = 1000;
+
+// Parameterize tests with
+// 1) whether log tf error message;
+// 2) whether separate "plan URI received" events should be logged.
+class PhaseLoggerImplTest : public testing::TestWithParam<bool> {
+ protected:
+ void SetUp() override {
+ log_tensorflow_error_messages_ = GetParam();
+ EXPECT_CALL(mock_flags_, log_tensorflow_error_messages())
+ .WillRepeatedly(Return(log_tensorflow_error_messages_));
+
+ phase_logger_ = std::make_unique<PhaseLoggerImpl>(
+ &mock_event_publisher_, &mock_opstats_logger_, &mock_log_manager_,
+ &mock_flags_);
+ }
+
+ void VerifyCounterLogged(HistogramCounters counter,
+ const testing::Matcher<int64_t>& matcher) {
+ EXPECT_CALL(mock_log_manager_,
+ LogToLongHistogram(counter, /*execution_index=*/0,
+ /*epoch_index=*/0,
+ engine::DataSourceType::DATASET, matcher));
+ }
+
+ StrictMock<MockLogManager> mock_log_manager_;
+ StrictMock<MockEventPublisher> mock_event_publisher_;
+ StrictMock<MockOpStatsLogger> mock_opstats_logger_;
+ MockFlags mock_flags_;
+ bool log_tensorflow_error_messages_ = false;
+ std::unique_ptr<PhaseLoggerImpl> phase_logger_;
+ NetworkStats network_stats_ = {
+ .bytes_downloaded = kChunkingLayerBytesReceived,
+ .bytes_uploaded = kChunkingLayerBytesSent};
+ ExampleStats example_stats_ = {.example_count = kTotalExampleCount,
+ .example_size_bytes = kTotalExampleSizeBytes};
+};
+
+std::string GenerateTestName(
+ const testing::TestParamInfo<PhaseLoggerImplTest::ParamType>& info) {
+ std::string name =
+ absl::StrCat(info.param ? "Log_tf_error" : "No_tf_error", "__");
+ return name;
+}
+
+INSTANTIATE_TEST_SUITE_P(OldVsNewBehavior, PhaseLoggerImplTest, testing::Bool(),
+ GenerateTestName);
+
+TEST_P(PhaseLoggerImplTest, UpdateRetryWindowAndNetworkStats) {
+ RetryWindow retry_window;
+ *retry_window.mutable_retry_token() = "retry_token";
+ *retry_window.mutable_delay_max() = TimeUtil::HoursToDuration(48);
+ *retry_window.mutable_delay_min() = TimeUtil::HoursToDuration(4);
+
+ InSequence seq;
+ EXPECT_CALL(mock_opstats_logger_, SetRetryWindow(EqualsProto(retry_window)));
+ EXPECT_CALL(mock_opstats_logger_,
+ SetNetworkStats(testing::Eq(network_stats_)));
+
+ phase_logger_->UpdateRetryWindowAndNetworkStats(retry_window, network_stats_);
+}
+
+TEST_P(PhaseLoggerImplTest, SetModelIdentifier) {
+ std::string model_identifier = "model_identifier";
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_, SetModelIdentifier(model_identifier));
+ EXPECT_CALL(mock_log_manager_, SetModelIdentifier(model_identifier));
+
+ phase_logger_->SetModelIdentifier(model_identifier);
+}
+
+TEST_P(PhaseLoggerImplTest, LogTaskNotStarted) {
+ std::string error_message = "Client is not ready for training.";
+ EXPECT_CALL(mock_event_publisher_, PublishTaskNotStarted(error_message));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_TRAIN_NOT_STARTED,
+ error_message));
+ phase_logger_->LogTaskNotStarted(error_message);
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinStarted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_, PublishEligibilityEvalCheckin());
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED));
+ phase_logger_->LogEligibilityEvalCheckinStarted();
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinIOError) {
+ std::string error_message = "Network error";
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility check-in: code: 3, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalCheckinIoError(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_IO,
+ expected_error_message));
+
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+ phase_logger_->LogEligibilityEvalCheckinIOError(
+ absl::InvalidArgumentError(error_message), network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinClientInterrupted) {
+ std::string error_message = "Client is not idle";
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility check-in: code: 1, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalCheckinClientInterrupted(
+ expected_error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_CHECKIN_CLIENT_INTERRUPTED,
+ expected_error_message));
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+
+ phase_logger_->LogEligibilityEvalCheckinClientInterrupted(
+ absl::CancelledError(error_message), network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinServerAborted) {
+ std::string error_message = "Connection aborted by the server";
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility check-in: code: 10, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalCheckinServerAborted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_CHECKIN_SERVER_ABORTED,
+ expected_error_message));
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+
+ phase_logger_->LogEligibilityEvalCheckinServerAborted(
+ absl::AbortedError(error_message), network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalNotConfigured) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalNotConfigured(network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_ELIGIBILITY_DISABLED));
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+
+ phase_logger_->LogEligibilityEvalNotConfigured(network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinTurnedAway) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalRejected(network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_ELIGIBILITY_REJECTED));
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+
+ phase_logger_->LogEligibilityEvalCheckinTurnedAway(network_stats_,
+ absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinInvalidPayloadError) {
+ std::string error_message = "Cannot parse eligibility eval plan";
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::
+ BACKGROUND_TRAINING_ELIGIBILITY_EVAL_FAILED_CANNOT_PARSE_PLAN));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalCheckinErrorInvalidPayload(
+ error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_INVALID_PAYLOAD,
+ error_message));
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY, Ge(0));
+
+ phase_logger_->LogEligibilityEvalCheckinInvalidPayloadError(
+ error_message, network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinPlanUriReceived) {
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_PLAN_URI_RECEIVED));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalPlanUriReceived(
+ network_stats_,
+ AllOf(Ge(absl::Minutes(1)),
+ Lt(absl::Minutes(1) + absl::Milliseconds(10)))));
+
+ phase_logger_->LogEligibilityEvalCheckinPlanUriReceived(
+ network_stats_, absl::Now() - absl::Minutes(1));
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalCheckinCompleted) {
+ NetworkStats network_stats{.bytes_downloaded = 100,
+ .bytes_uploaded = 200,
+ .network_duration = absl::Seconds(40)};
+
+ absl::Duration expected_duration = absl::Minutes(1);
+
+ InSequence seq;
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_ELIGIBILITY_ENABLED));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalPlanReceived(
+ network_stats,
+ AllOf(Ge(expected_duration),
+ Lt(expected_duration + absl::Milliseconds(10)))));
+ // The counter should always log the full duration, from before the start of
+ // the checkin (regardless of the enable_plan_uri_received_logs_ flag).
+ VerifyCounterLogged(
+ HistogramCounters::TRAINING_FL_ELIGIBILITY_EVAL_CHECKIN_LATENCY,
+ absl::ToInt64Milliseconds(absl::Minutes(2)));
+
+ phase_logger_->LogEligibilityEvalCheckinCompleted(
+ network_stats,
+ /*time_before_checkin=*/absl::Now() - absl::Minutes(2),
+ /*time_before_plan_downloaded=*/absl::Now() - absl::Minutes(1));
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationStarted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationStarted());
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(
+ OperationalStats::Event::EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED));
+
+ phase_logger_->LogEligibilityEvalComputationStarted();
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationInvalidArgument) {
+ std::string error_message = "Invalid plan.";
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility eval computation: code: 3, error: ",
+ error_message);
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_FAILED_PLAN_FAILS_SANITY_CHECK));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationInvalidArgument(
+ expected_error_message, example_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_INVALID_ARGUMENT,
+ expected_error_message));
+ phase_logger_->LogEligibilityEvalComputationInvalidArgument(
+ absl::InvalidArgumentError(error_message), example_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationExampleIteratorError) {
+ std::string original_message = "Failed to create example iterator";
+ absl::Status error_status = absl::InvalidArgumentError(original_message);
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility eval computation: code: 3, error: ",
+ original_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationExampleIteratorError(
+ expected_error_message, example_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_EXAMPLE_ITERATOR,
+ expected_error_message));
+
+ phase_logger_->LogEligibilityEvalComputationExampleIteratorError(
+ error_status, example_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationTensorflowError) {
+ std::string original_message = "Missing kernel for op NotExist";
+ absl::Status error_status = absl::InvalidArgumentError(original_message);
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility eval computation: code: 3, error: ");
+ if (log_tensorflow_error_messages_) {
+ absl::StrAppend(&expected_error_message, original_message);
+ }
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationTensorflowError(
+ expected_error_message, example_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_TENSORFLOW,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, Ge(0));
+ phase_logger_->LogEligibilityEvalComputationTensorflowError(
+ error_status, example_stats_, absl::Now() - absl::Minutes(2),
+ absl::Now() - absl::Minutes(5));
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationInterrupted) {
+ std::string error_message = "Client is no longer idle";
+ std::string expected_error_message = absl::StrCat(
+ "Error during eligibility eval computation: code: 1, error: ",
+ error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationInterrupted(
+ expected_error_message, example_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_CLIENT_INTERRUPTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, Ge(0));
+
+ phase_logger_->LogEligibilityEvalComputationInterrupted(
+ absl::CancelledError(error_message), example_stats_,
+ absl::Now() - absl::Minutes(2), absl::Now() - absl::Minutes(5));
+}
+
+TEST_P(PhaseLoggerImplTest, LogEligibilityEvalComputationCompleted) {
+ absl::Time run_plan_start_time = absl::Now() - absl::Minutes(8);
+ absl::Time reference_time = absl::Now() - absl::Minutes(9);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishEligibilityEvalComputationCompleted(example_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_FINISHED));
+ VerifyCounterLogged(HistogramCounters::TRAINING_OVERALL_EXAMPLE_SIZE,
+ kTotalExampleSizeBytes);
+ VerifyCounterLogged(HistogramCounters::TRAINING_OVERALL_EXAMPLE_COUNT,
+ kTotalExampleCount);
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, Ge(0));
+
+ phase_logger_->LogEligibilityEvalComputationCompleted(
+ example_stats_, run_plan_start_time, reference_time);
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinStarted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_, PublishCheckin());
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_STARTED));
+ phase_logger_->LogCheckinStarted();
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinIOError) {
+ std::string error_message = "IO error";
+ std::string expected_error_message =
+ absl::StrCat("Error during check-in: code: 14, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishCheckinIoError(expected_error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ERROR_IO,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, Ge(0));
+
+ phase_logger_->LogCheckinIOError(
+ absl::UnavailableError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(2), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinClientInterrupted) {
+ std::string error_message = "The client is no longer idle";
+ std::string expected_error_message =
+ absl::StrCat("Error during check-in: code: 1, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishCheckinClientInterrupted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_CLIENT_INTERRUPTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, Ge(0));
+
+ phase_logger_->LogCheckinClientInterrupted(
+ absl::CancelledError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(2), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinServerAborted) {
+ std::string error_message = "The request is aborted by the server";
+ std::string expected_error_message =
+ absl::StrCat("Error during check-in: code: 10, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(
+ mock_event_publisher_,
+ PublishCheckinServerAborted(expected_error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_SERVER_ABORTED,
+ expected_error_message));
+
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, Ge(0));
+
+ phase_logger_->LogCheckinServerAborted(
+ absl::AbortedError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(2), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinTurnedAway) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_, PublishRejected(network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_REJECTED));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, Ge(0));
+
+ phase_logger_->LogCheckinTurnedAway(network_stats_,
+ absl::Now() - absl::Minutes(2),
+ absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinInvalidPayload) {
+ std::string error_message = "Cannot parse plan";
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(ProdDiagCode::BACKGROUND_TRAINING_FAILED_CANNOT_PARSE_PLAN));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishCheckinInvalidPayload(error_message, network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_CHECKIN_ERROR_INVALID_PAYLOAD,
+ error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME, Ge(0));
+
+ phase_logger_->LogCheckinInvalidPayload(error_message, network_stats_,
+ absl::Now() - absl::Minutes(2),
+ absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinPlanUriReceived) {
+ std::string task_name = "my_task";
+ EXPECT_CALL(mock_event_publisher_,
+ PublishCheckinPlanUriReceived(
+ network_stats_,
+ AllOf(Ge(absl::Minutes(1)),
+ Lt(absl::Minutes(1) + absl::Milliseconds(10)))));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventAndSetTaskName(
+ task_name,
+ OperationalStats::Event::EVENT_KIND_CHECKIN_PLAN_URI_RECEIVED));
+
+ phase_logger_->LogCheckinPlanUriReceived(task_name, network_stats_,
+ absl::Now() - absl::Minutes(1));
+}
+
+TEST_P(PhaseLoggerImplTest, LogCheckinCompleted) {
+ NetworkStats network_stats{.bytes_downloaded = 100,
+ .bytes_uploaded = 200,
+ .network_duration = absl::Seconds(40)};
+
+ absl::Duration expected_duration = absl::Minutes(1);
+
+ std::string task_name = "my_task";
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishCheckinFinishedV2(
+ network_stats,
+ AllOf(Ge(expected_duration),
+ Lt(expected_duration + absl::Milliseconds(10)))));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_CHECKIN_ACCEPTED));
+ // The counter should always log the full duration, from before the start of
+ // the checkin.
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_LATENCY,
+ AllOf(Ge(absl::ToInt64Milliseconds(absl::Minutes(2))),
+ Lt(absl::ToInt64Milliseconds(
+ absl::Minutes(2) + absl::Milliseconds(100)))));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_CHECKIN_END_TIME,
+ AllOf(Ge(absl::ToInt64Milliseconds(absl::Minutes(8))),
+ Lt(absl::ToInt64Milliseconds(
+ absl::Minutes(8) + absl::Milliseconds(100)))));
+
+ phase_logger_->LogCheckinCompleted(
+ task_name, network_stats,
+ /*time_before_checkin=*/absl::Now() - absl::Minutes(2),
+ /*time_before_plan_download=*/absl::Now() - absl::Minutes(1),
+ /*reference_time=*/absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationStarted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_, PublishComputationStarted());
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_COMPUTATION_STARTED));
+ phase_logger_->LogComputationStarted();
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationInvalidArgument) {
+ std::string error_message = "Unexpected input tensor";
+ std::string expected_error_message =
+ absl::StrCat("Error during computation: code: 3, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(
+ mock_log_manager_,
+ LogDiag(
+ ProdDiagCode::BACKGROUND_TRAINING_FAILED_PLAN_FAILS_SANITY_CHECK));
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationInvalidArgument(
+ expected_error_message, example_stats_, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_COMPUTATION_ERROR_INVALID_ARGUMENT,
+ expected_error_message));
+ phase_logger_->LogComputationInvalidArgument(
+ absl::InvalidArgumentError(error_message), example_stats_, network_stats_,
+ absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationIOError) {
+ std::string error_message = "IO error";
+ std::string expected_error_message =
+ absl::StrCat("Error during computation: code: 3, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationIOError(expected_error_message, example_stats_,
+ network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_IO,
+ expected_error_message));
+ phase_logger_->LogComputationIOError(
+ absl::InvalidArgumentError(error_message), example_stats_, network_stats_,
+ absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationExampleIteratorError) {
+ std::string original_message = "Cannot create example iterator";
+ absl::Status error_status = absl::InvalidArgumentError(original_message);
+ std::string expected_error_message = absl::StrCat(
+ "Error during computation: code: 3, error: ", original_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationExampleIteratorError(
+ expected_error_message, example_stats_, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::
+ EVENT_KIND_COMPUTATION_ERROR_EXAMPLE_ITERATOR,
+ expected_error_message));
+ phase_logger_->LogComputationExampleIteratorError(
+ error_status, example_stats_, network_stats_, absl::Now());
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationTensorflowError) {
+ std::string original_message = "Missing op kernel NotExist";
+ absl::Status error_status = absl::InvalidArgumentError(original_message);
+ std::string expected_error_message =
+ absl::StrCat("Error during computation: code: 3, error: ");
+ if (log_tensorflow_error_messages_) {
+ absl::StrAppend(&expected_error_message, original_message);
+ }
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationTensorflowError(
+ expected_error_message, example_stats_, network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_ERROR_TENSORFLOW,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, Ge(0));
+ phase_logger_->LogComputationTensorflowError(
+ error_status, example_stats_, network_stats_,
+ absl::Now() - absl::Minutes(6), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationInterrupted) {
+ absl::Time run_plan_start_time = absl::Now() - absl::Minutes(6);
+ absl::Time reference_time = absl::Now() - absl::Minutes(8);
+ std::string error_message = "Client is no longer idle.";
+ std::string expected_error_message =
+ absl::StrCat("Error during computation: code: 1, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationInterrupted(expected_error_message,
+ example_stats_, network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_COMPUTATION_CLIENT_INTERRUPTED,
+ expected_error_message));
+
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY, Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME, Ge(0));
+
+ phase_logger_->LogComputationInterrupted(absl::CancelledError(error_message),
+ example_stats_, network_stats_,
+ run_plan_start_time, reference_time);
+}
+
+TEST_P(PhaseLoggerImplTest, LogComputationCompleted) {
+ absl::Time run_plan_start_time = absl::Now() - absl::Minutes(6);
+ absl::Time reference_time = absl::Now() - absl::Minutes(8);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishComputationCompleted(example_stats_, network_stats_,
+ Ge(absl::Minutes(6))));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_COMPUTATION_FINISHED));
+ VerifyCounterLogged(HistogramCounters::TRAINING_OVERALL_EXAMPLE_SIZE,
+ kTotalExampleSizeBytes);
+ VerifyCounterLogged(HistogramCounters::TRAINING_OVERALL_EXAMPLE_COUNT,
+ kTotalExampleCount);
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_LATENCY,
+ Ge(absl::ToInt64Milliseconds(absl::Minutes(6))));
+ VerifyCounterLogged(HistogramCounters::TRAINING_RUN_PHASE_END_TIME,
+ Ge(absl::ToInt64Milliseconds(absl::Minutes(8))));
+
+ phase_logger_->LogComputationCompleted(example_stats_, network_stats_,
+ run_plan_start_time, reference_time);
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadStartedOpStatsDbCommitSucceeds) {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED));
+ EXPECT_CALL(mock_opstats_logger_, CommitToStorage)
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(mock_event_publisher_, PublishResultUploadStarted());
+
+ ASSERT_OK(phase_logger_->LogResultUploadStarted());
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadStartedOpStatsDbCommitFails) {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_STARTED));
+ EXPECT_CALL(mock_opstats_logger_, CommitToStorage)
+ .WillOnce(Return(absl::InternalError("")));
+
+ ASSERT_THAT(phase_logger_->LogResultUploadStarted(),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadIOError) {
+ std::string error_message = "Network IO";
+ std::string expected_error_message =
+ absl::StrCat("Error reporting results: code: 14, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(
+ mock_event_publisher_,
+ PublishResultUploadIOError(expected_error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_ERROR_IO,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+ phase_logger_->LogResultUploadIOError(
+ absl::UnavailableError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadClientInterrupted) {
+ std::string error_message = "Client is no longer idle";
+ std::string expected_error_message =
+ absl::StrCat("Error reporting results: code: 1, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishResultUploadClientInterrupted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_CLIENT_INTERRUPTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+
+ phase_logger_->LogResultUploadClientInterrupted(
+ absl::CancelledError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadServerAborted) {
+ std::string error_message = "Request is aborted by the server";
+ std::string expected_error_message =
+ absl::StrCat("Error reporting results: code: 10, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishResultUploadServerAborted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+
+ phase_logger_->LogResultUploadServerAborted(
+ absl::AbortedError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogResultUploadCompleted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishResultUploadCompleted(network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_RESULT_UPLOAD_FINISHED));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+
+ phase_logger_->LogResultUploadCompleted(network_stats_,
+ absl::Now() - absl::Minutes(1),
+ absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadStartedOpstatsDbCommitSucceeds) {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_STARTED));
+ EXPECT_CALL(mock_opstats_logger_, CommitToStorage())
+ .WillOnce(Return(absl::OkStatus()));
+ EXPECT_CALL(mock_event_publisher_, PublishFailureUploadStarted());
+ ASSERT_OK(phase_logger_->LogFailureUploadStarted());
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadStartedOpstatsDbCommitFails) {
+ InSequence seq;
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_STARTED));
+ EXPECT_CALL(mock_opstats_logger_, CommitToStorage())
+ .WillOnce(Return(absl::InternalError("")));
+ ASSERT_THAT(phase_logger_->LogFailureUploadStarted(),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadIOError) {
+ std::string error_message = "Network error.";
+ std::string expected_error_message = absl::StrCat(
+ "Error reporting computation failure: code: 14, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(
+ mock_event_publisher_,
+ PublishFailureUploadIOError(expected_error_message, network_stats_, _));
+ EXPECT_CALL(mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_ERROR_IO,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+ phase_logger_->LogFailureUploadIOError(
+ absl::UnavailableError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadClientInterrupted) {
+ std::string error_message = "The client is no longer idle";
+ std::string expected_error_message = absl::StrCat(
+ "Error reporting computation failure: code: 1, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishFailureUploadClientInterrupted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_CLIENT_INTERRUPTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+ phase_logger_->LogFailureUploadClientInterrupted(
+ absl::CancelledError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadServerAborted) {
+ std::string error_message = "Request is aborted by the server.";
+ std::string expected_error_message = absl::StrCat(
+ "Error reporting computation failure: code: 10, error: ", error_message);
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishFailureUploadServerAborted(expected_error_message,
+ network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEventWithErrorMessage(
+ OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_SERVER_ABORTED,
+ expected_error_message));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+ phase_logger_->LogFailureUploadServerAborted(
+ absl::AbortedError(error_message), network_stats_,
+ absl::Now() - absl::Minutes(1), absl::Now() - absl::Minutes(8));
+}
+
+TEST_P(PhaseLoggerImplTest, LogFailureUploadCompleted) {
+ InSequence seq;
+ EXPECT_CALL(mock_event_publisher_,
+ PublishFailureUploadCompleted(network_stats_, _));
+ EXPECT_CALL(
+ mock_opstats_logger_,
+ AddEvent(OperationalStats::Event::EVENT_KIND_FAILURE_UPLOAD_FINISHED));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_LATENCY,
+ Ge(0));
+ VerifyCounterLogged(HistogramCounters::TRAINING_FL_REPORT_RESULTS_END_TIME,
+ Ge(0));
+ phase_logger_->LogFailureUploadCompleted(network_stats_,
+ absl::Now() - absl::Minutes(1),
+ absl::Now() - absl::Minutes(8));
+}
+
+} // namespace
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/secagg_event_publisher.h b/fcp/client/secagg_event_publisher.h
new file mode 100644
index 0000000..6bdc78d
--- /dev/null
+++ b/fcp/client/secagg_event_publisher.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_SECAGG_EVENT_PUBLISHER_H_
+#define FCP_CLIENT_SECAGG_EVENT_PUBLISHER_H_
+
+#include <cstdint>
+#include <string>
+
+namespace fcp::secagg {
+enum class ClientState : int;
+} // namespace fcp::secagg
+
+namespace fcp {
+namespace client {
+
+// An interface for publishing events that occur during the secure
+// aggregation protocol. All methods in here either succeed with OK, or fail
+// with INVALID_ARGUMENT.
+class SecAggEventPublisher {
+ public:
+ virtual ~SecAggEventPublisher() = default;
+
+ // Publishes that the protocol has left the prior state and entered the
+ // given state, along with the size of the last message sent.
+ virtual void PublishStateTransition(::fcp::secagg::ClientState state,
+ size_t last_sent_message_size,
+ size_t last_received_message_size) = 0;
+ // Publishes a top-level SecAgg client error.
+ virtual void PublishError() = 0;
+ // Publishes a SecAgg client abort.
+ virtual void PublishAbort(bool client_initiated,
+ const std::string& error_message) = 0;
+ // After calling this function, all subsequently published events will be
+ // annotated with the specified execution logging ID, which is set during
+ // protocol execution.
+ virtual void set_execution_session_id(int64_t execution_session_id) = 0;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_SECAGG_EVENT_PUBLISHER_H_
diff --git a/fcp/client/secagg_runner.cc b/fcp/client/secagg_runner.cc
new file mode 100644
index 0000000..bb96718
--- /dev/null
+++ b/fcp/client/secagg_runner.cc
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/secagg_runner.h"
+
+#include <memory>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/crypto_rand_prng.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+
+namespace fcp {
+namespace client {
+
+using ::fcp::secagg::ClientState;
+
+// Implementation of StateTransitionListenerInterface.
+class SecAggStateTransitionListenerImpl
+ : public secagg::StateTransitionListenerInterface {
+ public:
+ SecAggStateTransitionListenerImpl(
+ SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
+ SecAggSendToServerBase& secagg_send_to_server_impl,
+ SecAggProtocolDelegate& secagg_protocol_delegate);
+ void Transition(secagg::ClientState new_state) override;
+
+ void Started(secagg::ClientState state) override;
+
+ void Stopped(secagg::ClientState state) override;
+
+ void set_execution_session_id(int64_t execution_session_id) override;
+
+ private:
+ SecAggEventPublisher& secagg_event_publisher_;
+ LogManager& log_manager_;
+ SecAggSendToServerBase& secagg_send_to_server_;
+ SecAggProtocolDelegate& secagg_protocol_delegate_;
+ secagg::ClientState state_ = secagg::ClientState::INITIAL;
+};
+
+SecAggStateTransitionListenerImpl::SecAggStateTransitionListenerImpl(
+ SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
+ SecAggSendToServerBase& secagg_send_to_server_impl,
+ SecAggProtocolDelegate& secagg_protocol_delegate)
+ : secagg_event_publisher_(secagg_event_publisher),
+ log_manager_(log_manager),
+ secagg_send_to_server_(secagg_send_to_server_impl),
+ secagg_protocol_delegate_(secagg_protocol_delegate) {}
+
+void SecAggStateTransitionListenerImpl::Transition(ClientState new_state) {
+ FCP_LOG(INFO) << "Transitioning from state: " << static_cast<int>(state_)
+ << " to state: " << static_cast<int>(new_state);
+ state_ = new_state;
+ if (state_ == ClientState::ABORTED) {
+ log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
+ }
+ secagg_event_publisher_.PublishStateTransition(
+ new_state, secagg_send_to_server_.last_sent_message_size(),
+ secagg_protocol_delegate_.last_received_message_size());
+}
+
+void SecAggStateTransitionListenerImpl::Started(ClientState state) {
+ // TODO(team): Implement this.
+}
+
+void SecAggStateTransitionListenerImpl::Stopped(ClientState state) {
+ // TODO(team): Implement this.
+}
+
+void SecAggStateTransitionListenerImpl::set_execution_session_id(
+ int64_t execution_session_id) {
+ secagg_event_publisher_.set_execution_session_id(execution_session_id);
+}
+
+SecAggRunnerImpl::SecAggRunnerImpl(
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction)
+ : send_to_server_impl_(std::move(send_to_server_impl)),
+ protocol_delegate_(std::move(protocol_delegate)),
+ secagg_event_publisher_(*secagg_event_publisher),
+ log_manager_(*log_manager),
+ interruptible_runner_(*interruptible_runner),
+ expected_number_of_clients_(expected_number_of_clients),
+ minimum_surviving_clients_for_reconstruction_(
+ minimum_surviving_clients_for_reconstruction) {}
+
+absl::Status SecAggRunnerImpl::Run(ComputationResults results) {
+ auto secagg_state_transition_listener =
+ std::make_unique<SecAggStateTransitionListenerImpl>(
+ secagg_event_publisher_, log_manager_, *send_to_server_impl_,
+ *protocol_delegate_);
+ auto input_map = std::make_unique<secagg::SecAggVectorMap>();
+ std::vector<secagg::InputVectorSpecification> input_vector_specification;
+ for (auto& [k, v] : results) {
+ if (std::holds_alternative<QuantizedTensor>(v)) {
+ FCP_ASSIGN_OR_RETURN(uint64_t modulus, protocol_delegate_->GetModulus(k));
+ // Note: std::move is used below to ensure that each QuantizedTensor
+ // is consumed when converted to SecAggVector and that we don't
+ // continue having both in memory for longer than needed.
+ auto vector = std::get<QuantizedTensor>(std::move(v));
+ if (modulus <= 1 || modulus > secagg::SecAggVector::kMaxModulus) {
+ return absl::InternalError(
+ absl::StrCat("Invalid SecAgg modulus configuration: ", modulus));
+ }
+ if (vector.values.empty())
+ return absl::InternalError(
+ absl::StrCat("Zero sized vector found: ", k));
+ int64_t flattened_length = 1;
+ for (const auto& size : vector.dimensions) flattened_length *= size;
+ auto data_length = vector.values.size();
+ if (flattened_length != data_length)
+ return absl::InternalError(
+ absl::StrCat("Flattened length: ", flattened_length,
+ " does not match vector size: ", data_length));
+ for (const auto& value : vector.values) {
+ if (value >= modulus) {
+ return absl::InternalError(absl::StrCat(
+ "The input SecAgg vector doesn't have the appropriate "
+ "modulus: element with value ",
+ value, " found, max value allowed ", (modulus - 1ULL)));
+ }
+ }
+ input_vector_specification.emplace_back(k, flattened_length, modulus);
+ input_map->try_emplace(
+ k, absl::MakeConstSpan(vector.values.data(), data_length), modulus);
+ }
+ }
+ secagg_client_ = std::make_unique<secagg::SecAggClient>(
+ expected_number_of_clients_,
+ minimum_surviving_clients_for_reconstruction_,
+ std::move(input_vector_specification),
+ std::make_unique<secagg::CryptoRandPrng>(),
+ std::move(send_to_server_impl_),
+ std::move(secagg_state_transition_listener),
+ std::make_unique<secagg::AesCtrPrngFactory>());
+
+ FCP_RETURN_IF_ERROR(interruptible_runner_.Run(
+ [this, &input_map]() -> absl::Status {
+ FCP_RETURN_IF_ERROR(secagg_client_->Start());
+ FCP_RETURN_IF_ERROR(secagg_client_->SetInput(std::move(input_map)));
+ while (!secagg_client_->IsCompletedSuccessfully()) {
+ absl::StatusOr<secagg::ServerToClientWrapperMessage>
+ server_to_client_wrapper_message =
+ this->protocol_delegate_->ReceiveServerMessage();
+ if (!server_to_client_wrapper_message.ok()) {
+ return absl::Status(
+ server_to_client_wrapper_message.status().code(),
+ absl::StrCat(
+ "Error during SecAgg receive: ",
+ server_to_client_wrapper_message.status().message()));
+ }
+ auto result =
+ secagg_client_->ReceiveMessage(*server_to_client_wrapper_message);
+ if (!result.ok()) {
+ this->secagg_event_publisher_.PublishError();
+ return absl::Status(result.status().code(),
+ absl::StrCat("Error receiving SecAgg message: ",
+ result.status().message()));
+ }
+ if (secagg_client_->IsAborted()) {
+ std::string error_message = "error message not found.";
+ if (secagg_client_->ErrorMessage().ok())
+ error_message = secagg_client_->ErrorMessage().value();
+ this->secagg_event_publisher_.PublishAbort(false, error_message);
+ return absl::CancelledError("SecAgg aborted: " + error_message);
+ }
+ }
+ return absl::OkStatus();
+ },
+ [this]() {
+ AbortInternal();
+ this->protocol_delegate_->Abort();
+ }));
+ return absl::OkStatus();
+}
+
+void SecAggRunnerImpl::AbortInternal() {
+ log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
+ auto abort_message = "Client-initiated abort.";
+ auto result = secagg_client_->Abort(abort_message);
+ if (!result.ok()) {
+ FCP_LOG(ERROR) << "Could not initiate client abort, code: " << result.code()
+ << " message: " << result.message();
+ }
+ // Note: the implementation assumes that secagg_event_publisher
+ // cannot hang indefinitely, i.e. does not need its own interruption
+ // trigger.
+ secagg_event_publisher_.PublishAbort(true, abort_message);
+}
+
+std::unique_ptr<SecAggRunner> SecAggRunnerFactoryImpl::CreateSecAggRunner(
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction) {
+ return std::make_unique<SecAggRunnerImpl>(
+ std::move(send_to_server_impl), std::move(protocol_delegate),
+ secagg_event_publisher, log_manager, interruptible_runner,
+ expected_number_of_clients, minimum_surviving_clients_for_reconstruction);
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/secagg_runner.h b/fcp/client/secagg_runner.h
new file mode 100644
index 0000000..074b2b7
--- /dev/null
+++ b/fcp/client/secagg_runner.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_SECAGG_RUNNER_H_
+#define FCP_CLIENT_SECAGG_RUNNER_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/interruptible_runner.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/secagg/client/secagg_client.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace client {
+
+// Base SecAggSendToServer class which provides message size and network
+// bandwidth usage metrics. When the child class inherit from this class, it's
+// up to the child class to record metrics correctly.
+class SecAggSendToServerBase : public secagg::SendToServerInterface {
+ public:
+ size_t last_sent_message_size() const { return last_sent_message_size_; }
+
+ protected:
+ size_t last_sent_message_size_ = 0;
+};
+
+// A delegate class which handles server to client communication protocol
+// specific details (HTTP vs gRPC etc).
+class SecAggProtocolDelegate {
+ public:
+ virtual ~SecAggProtocolDelegate() = default;
+ // Retrieve the modulus for a given SecAgg vector.
+ virtual absl::StatusOr<uint64_t> GetModulus(const std::string& key) = 0;
+ // Receive Server message.
+ virtual absl::StatusOr<secagg::ServerToClientWrapperMessage>
+ ReceiveServerMessage() = 0;
+ // Called when the SecAgg protocol is interrupted.
+ virtual void Abort() = 0;
+ virtual size_t last_received_message_size() = 0;
+};
+
+// A helper class which runs the secure aggregation protocol.
+class SecAggRunner {
+ public:
+ virtual ~SecAggRunner() = default;
+ virtual absl::Status Run(ComputationResults results) = 0;
+};
+
+// Implementation of SecAggRunner.
+class SecAggRunnerImpl : public SecAggRunner {
+ public:
+ SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher,
+ LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction);
+ // Run the secure aggregation protocol.
+ // SecAggProtocolDelegate and SecAggSendToServerBase will only be invoked from
+ // a single thread.
+ absl::Status Run(ComputationResults results) override;
+
+ private:
+ void AbortInternal();
+
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl_;
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate_;
+ std::unique_ptr<secagg::SecAggClient> secagg_client_;
+ SecAggEventPublisher& secagg_event_publisher_;
+ LogManager& log_manager_;
+ InterruptibleRunner& interruptible_runner_;
+ const int64_t expected_number_of_clients_;
+ const int64_t minimum_surviving_clients_for_reconstruction_;
+};
+
+// A factory interface for SecAggRunner.
+class SecAggRunnerFactory {
+ public:
+ virtual ~SecAggRunnerFactory() = default;
+ virtual std::unique_ptr<SecAggRunner> CreateSecAggRunner(
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction) = 0;
+};
+
+class SecAggRunnerFactoryImpl : public SecAggRunnerFactory {
+ std::unique_ptr<SecAggRunner> CreateSecAggRunner(
+ std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction) override;
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_SECAGG_RUNNER_H_
diff --git a/fcp/client/selector_context.proto b/fcp/client/selector_context.proto
new file mode 100644
index 0000000..c2f6ab3
--- /dev/null
+++ b/fcp/client/selector_context.proto
@@ -0,0 +1,127 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client;
+
+import "google/protobuf/timestamp.proto";
+
+option java_package = "com.google.intelligence.fcp.client";
+option java_multiple_files = true;
+
+message SelectorContext {
+ QueryTimeComputationProperties computation_properties = 1;
+}
+
+// Properties about the computation exposed to the iterator.
+message QueryTimeComputationProperties {
+ // Session name, if applicable.
+ string session_name = 1;
+
+ // Different kinds of computation types.
+ oneof computation_type {
+ // Local computation type.
+ LocalComputation local_compute = 2;
+
+ // EligibilityEval computation type.
+ EligibilityEvalComputation eligibility_eval = 3;
+
+ // Federated computation type.
+ FederatedComputation federated = 4;
+ }
+
+ // A unique ID identifying the computation run.
+ int64 run_id = 5;
+
+ // Additional context data.
+ bytes context_data = 6;
+
+ enum ExampleIteratorOutputFormat {
+ // The specific serialization format is left unspecified and up to the
+ // ExampleStore implementation and TensorFlow-based tasks.
+ // In most cases, data is encoded in binary-serialized `tf.train.Example`
+ // protos.
+ EXAMPLE_ITERATOR_OUTPUT_FORMAT_UNSPECIFIED = 0;
+
+ // Data encoded in binary-serialized `fcp.client.ExampleQueryResult` protos.
+ EXAMPLE_QUERY_RESULT = 1;
+ }
+
+ // Expected output format from the example iterator.
+ ExampleIteratorOutputFormat example_iterator_output_format = 7;
+}
+
+// On-device, local computation only. No aggregation.
+message LocalComputation {
+ // The absolute path to the input directory.
+ string input_dir = 1;
+ // The absolute path to the output directory.
+ string output_dir = 2;
+ // The map of input resources where the key is the name of the resource, and
+ // the value is the absolute paths to the resource.
+ map<string, string> input_resource_map = 3;
+}
+
+// ElgibilityEval computation, no aggregation.
+message EligibilityEvalComputation {
+ // Population name.
+ string population_name = 1;
+}
+
+// Federated computation with server aggregation.
+message FederatedComputation {
+ // Population name.
+ string population_name = 1;
+ // Name of the task that was executed.
+ string task_name = 2;
+
+ // Identity representing the computation e.g. its plan hash.
+ bytes computation_id = 5;
+
+ // Details about previous executions for the currently executing task.
+ HistoricalContext historical_context = 6;
+
+ // Types of server aggregation.
+ oneof aggregation_type {
+ // Simple aggregation. At least one value is aggregated with simple
+ // aggregation. This includes the mixed case where some values are
+ // aggregated with simple aggregation while others are aggregated with
+ // secure aggregation.
+ SimpleAggregation simple_aggregation = 3;
+
+ // Secure aggregation. All values are aggregated with secure aggregation.
+ SecureAggregation secure_aggregation = 4;
+ }
+}
+
+// Simple aggregation.
+message SimpleAggregation {}
+
+// Secure aggregation.
+message SecureAggregation {
+ // The minimum number of clients' values that must be aggregated together
+ // before the server can gain access to the aggregate,
+ // even transiently (e.g. in RAM).
+ // This isn't needed by Secure Aggregation protocol on the client side but
+ // shared by the server with clients for transparency and/or policy reasons.
+ // See `federated_api.proto`.
+ int32 minimum_clients_in_server_visible_aggregate = 1;
+}
+
+// Details about previous executions for the currently executing task.
+message HistoricalContext {
+ // Timestamp of when this task was last successfully contributed to.
+ google.protobuf.Timestamp last_successful_contribution_time = 1;
+}
diff --git a/fcp/client/simple_task_environment.cc b/fcp/client/simple_task_environment.cc
new file mode 100644
index 0000000..55b304e
--- /dev/null
+++ b/fcp/client/simple_task_environment.cc
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/simple_task_environment.h"
+
+namespace fcp {
+namespace client {
+
+bool SimpleTaskEnvironment::ShouldAbort(
+ absl::Time current_time, absl::Duration condition_polling_period) {
+ if (current_time - last_training_conditions_fetch_timestamp_ <
+ condition_polling_period) {
+ return false;
+ } else {
+ last_training_conditions_fetch_timestamp_ = current_time;
+ }
+ return !TrainingConditionsSatisfied();
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/simple_task_environment.h b/fcp/client/simple_task_environment.h
new file mode 100644
index 0000000..ca866e8
--- /dev/null
+++ b/fcp/client/simple_task_environment.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_SIMPLE_TASK_ENVIRONMENT_H_
+#define FCP_CLIENT_SIMPLE_TASK_ENVIRONMENT_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/selector_context.pb.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp {
+namespace client {
+
+// An interface used by the plan engine to query for serialized examples. Not
+// required to be thread-safe.
+class ExampleIterator {
+ public:
+ virtual ~ExampleIterator() = default;
+
+ // Returns a serialized example as std::string on success; or, on error:
+ // - CANCELLED if the call got interrupted.
+ // - INVALID_ARGUMENT if some other error occurred, e.g. I/O.
+ // - OUT_OF_RANGE if the end of the iterator was reached.
+ virtual absl::StatusOr<std::string> Next() = 0;
+
+ // Close the iterator to release associated resources.
+ virtual void Close() = 0;
+};
+
+// A simplified task environment that acts as a standalone task environment for
+// TFF-based plans and a delegated task environment for earlier plan types.
+class SimpleTaskEnvironment {
+ public:
+ virtual ~SimpleTaskEnvironment() = default;
+
+ // Returns the path of the directory that will be used to store persistent
+ // files that should not be deleted, such as Opstats.
+ virtual std::string GetBaseDir() = 0;
+
+ // Returns the path of the directory that may be used to store
+ // temporary/cached files. The federated compute runtime will use this
+ // directory to cache data for re-use across multiple invocations, as well as
+ // for creating temporary files that are deleted at the end of each
+ // invocation. Implementers of this interface may also delete files in this
+ // directory (for example, in low storage situations) without adverse effects
+ // to the runtime. GetCacheDir() may return the same path as GetBaseDir().
+ virtual std::string GetCacheDir() = 0;
+
+ // TODO(team): factor out native implementations of this and delete.
+ virtual absl::StatusOr<std::unique_ptr<ExampleIterator>>
+ CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) = 0;
+
+ virtual absl::StatusOr<std::unique_ptr<ExampleIterator>>
+ CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector,
+ const SelectorContext& selector_context) {
+ return CreateExampleIterator(example_selector);
+ }
+
+ // Creates an HttpClient. May return a nullptr if HTTP is not supported
+ // (although support for HTTP will become mandatory in the future).
+ virtual std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() {
+ return nullptr;
+ }
+
+ // Checks whether the caller should abort computation. If less than
+ // condition_polling_period time has elapsed since the last call this function
+ // made to TrainingConditionsSatisfied, returns false.
+ bool ShouldAbort(absl::Time current_time,
+ absl::Duration condition_polling_period);
+
+ private:
+ virtual bool TrainingConditionsSatisfied() = 0;
+
+ absl::Time last_training_conditions_fetch_timestamp_ = absl::InfinitePast();
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_SIMPLE_TASK_ENVIRONMENT_H_
diff --git a/fcp/client/simple_task_environment_test.cc b/fcp/client/simple_task_environment_test.cc
new file mode 100644
index 0000000..3ac990a
--- /dev/null
+++ b/fcp/client/simple_task_environment_test.cc
@@ -0,0 +1,84 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/client/simple_task_environment.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/client/test_helpers.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace client {
+namespace {
+
+using ::testing::Return;
+using ::testing::StrictMock;
+
+TEST(SimpleTaskEnvironmentTest, TestShouldAbort) {
+ StrictMock<MockSimpleTaskEnvironment> mock_task_env;
+ EXPECT_CALL(mock_task_env, TrainingConditionsSatisfied())
+ .WillOnce(Return(false));
+ bool result = mock_task_env.ShouldAbort(
+ /*current_time=*/absl::Now(),
+ /*condition_polling_period=*/absl::ZeroDuration());
+ EXPECT_TRUE(result);
+}
+
+// Assert that with a zero condition_polling_period, no throttling in
+// ShouldAbort takes place, by emulating two calls at the same time.
+// Both calls should return the mock's value.
+TEST(SimpleTaskEnvironmentTest, TestShouldAbortNoThrottling) {
+ StrictMock<MockSimpleTaskEnvironment> mock_task_env;
+ absl::Time now = absl::Now();
+ EXPECT_CALL(mock_task_env, TrainingConditionsSatisfied())
+ .WillRepeatedly(Return(false));
+ bool result = mock_task_env.ShouldAbort(
+ /*current_time=*/now,
+ /*condition_polling_period=*/absl::ZeroDuration());
+ EXPECT_TRUE(result);
+ result = mock_task_env.ShouldAbort(
+ /*current_time=*/now,
+ /*condition_polling_period=*/absl::ZeroDuration());
+ EXPECT_TRUE(result);
+}
+
+// Verify ShouldAbort throttling for non-zero polling periods.
+TEST(SimpleTaskEnvironmentTest, TestShouldAbortThrottling) {
+ StrictMock<MockSimpleTaskEnvironment> mock_task_env;
+ EXPECT_CALL(mock_task_env, TrainingConditionsSatisfied())
+ .WillRepeatedly(Return(false));
+ absl::Time now = absl::Now();
+ // First call should be non-throttled (since it assumes last call happened at
+ // UnixEpoch. Second call after 1s will be throttled because polling period is
+ // 1.5s; third call (after 2s) will be non-throttled again.
+ bool result = mock_task_env.ShouldAbort(
+ /*current_time=*/now,
+ /*condition_polling_period=*/absl::Milliseconds(1500));
+ EXPECT_TRUE(result);
+ result = mock_task_env.ShouldAbort(
+ /*current_time=*/now + absl::Seconds(1),
+ /*condition_polling_period=*/absl::Milliseconds(1500));
+ EXPECT_FALSE(result);
+ result = mock_task_env.ShouldAbort(
+ /*current_time=*/now + absl::Seconds(2),
+ /*condition_polling_period=*/absl::Milliseconds(1500));
+ EXPECT_TRUE(result);
+}
+
+} // anonymous namespace
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/stats.h b/fcp/client/stats.h
new file mode 100644
index 0000000..572bbdd
--- /dev/null
+++ b/fcp/client/stats.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_STATS_H_
+#define FCP_CLIENT_STATS_H_
+
+#include <cstdint>
+
+#include "absl/time/time.h"
+
+namespace fcp {
+namespace client {
+
+struct NetworkStats {
+ // The estimated number of bytes downloaded from the network ("over the
+ // wire").
+ int64_t bytes_downloaded = 0;
+ // The estimated number of bytes uploaded to the network ("over the
+ // wire").
+ int64_t bytes_uploaded = 0;
+ // The best estimate of the duration of wall clock time spent waiting for
+ // network requests to finish (but, for example, excluding any idle time spent
+ // waiting between issuing polling requests).
+ absl::Duration network_duration = absl::ZeroDuration();
+
+ // Returns the difference between two sets of network stats.
+ NetworkStats operator-(const NetworkStats& other) const {
+ return {.bytes_downloaded = bytes_downloaded - other.bytes_downloaded,
+ .bytes_uploaded = bytes_uploaded - other.bytes_uploaded,
+ .network_duration = network_duration - other.network_duration};
+ }
+
+ NetworkStats operator+(const NetworkStats& other) const {
+ return {.bytes_downloaded = bytes_downloaded + other.bytes_downloaded,
+ .bytes_uploaded = bytes_uploaded + other.bytes_uploaded,
+ .network_duration = network_duration + other.network_duration};
+ }
+};
+
+inline bool operator==(const NetworkStats& s1, const NetworkStats& s2) {
+ return
+
+ s1.bytes_downloaded == s2.bytes_downloaded &&
+ s1.bytes_uploaded == s2.bytes_uploaded &&
+ s1.network_duration == s2.network_duration;
+}
+
+struct ExampleStats {
+ int example_count = 0;
+ int64_t example_size_bytes = 0;
+};
+
+inline bool operator==(const ExampleStats& s1, const ExampleStats& s2) {
+ return s1.example_count == s2.example_count &&
+ s1.example_size_bytes == s2.example_size_bytes;
+}
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_STATS_H_
diff --git a/fcp/client/test_helpers.cc b/fcp/client/test_helpers.cc
new file mode 100644
index 0000000..1d232a0
--- /dev/null
+++ b/fcp/client/test_helpers.cc
@@ -0,0 +1,175 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/client/test_helpers.h"
+
+#include <android-base/file.h>
+#include <fcntl.h>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace fcp {
+namespace client {
+
+using ::google::internal::federated::plan::Dataset;
+
+namespace {
+bool LoadFileAsString_(std::string path, std::string* msg) {
+ std::ifstream checkpoint_istream(path);
+ if (!checkpoint_istream) {
+ return false;
+ }
+ std::stringstream checkpoint_stream;
+ checkpoint_stream << checkpoint_istream.rdbuf();
+ *msg = checkpoint_stream.str();
+ return true;
+}
+
+bool LoadMessageLiteFromFile_(std::string path,
+ google::protobuf::MessageLite* msg) {
+ std::string data;
+ if (!LoadFileAsString_(path, &data)) {
+ return false;
+ }
+ if (!msg->ParseFromString(data)) {
+ return false;
+ }
+ return true;
+}
+} // namespace
+
+SimpleExampleIterator::SimpleExampleIterator(
+ std::vector<const char*> examples) {
+ FCP_LOG(INFO) << "***** create example iterator examples";
+ for (const auto& example : examples) {
+ examples_.push_back(std::string(example));
+ }
+ FCP_CHECK(!examples_.empty()) << "No data was loaded";
+}
+
+SimpleExampleIterator::SimpleExampleIterator(Dataset dataset) {
+ FCP_LOG(INFO) << "***** create example iterator dataset";
+ for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
+ FCP_CHECK(client_dataset.selected_example_size() == 0)
+ << "This constructor can only be used for Dataset protos with unnamed "
+ "example data.";
+ for (const auto& example : client_dataset.example()) {
+ FCP_LOG(INFO) << "***** create example iterator";
+ examples_.push_back(example);
+ }
+ }
+ FCP_CHECK(!examples_.empty()) << "No data was loaded";
+}
+
+SimpleExampleIterator::SimpleExampleIterator(Dataset dataset,
+ absl::string_view collection_uri) {
+ FCP_LOG(INFO) << "***** create example iterator dataset uri";
+
+ for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
+ FCP_CHECK(client_dataset.selected_example_size() > 0)
+ << "This constructor can only be used for Dataset protos with named "
+ "example data.";
+ for (const Dataset::ClientDataset::SelectedExample& selected_example :
+ client_dataset.selected_example()) {
+ // Only use those examples whose `ExampleSelector` matches the
+ // `collection_uri` argument. Note that the `ExampleSelector`'s selection
+ // criteria is ignored/not taken into account here.
+ if (selected_example.selector().collection_uri() != collection_uri) {
+ continue;
+ }
+ for (const auto& example : selected_example.example()) {
+ examples_.push_back(example);
+ }
+ }
+ }
+ FCP_CHECK(!examples_.empty()) << "No data was loaded for " << collection_uri;
+}
+
+absl::StatusOr<std::string> SimpleExampleIterator::Next() {
+ if (index_ < examples_.size()) {
+ FCP_LOG(INFO) << "***** return next example " << examples_[index_];
+ return examples_[index_++];
+ }
+ return absl::OutOfRangeError("");
+}
+
+absl::StatusOr<ComputationArtifacts> LoadFlArtifacts() {
+ FCP_LOG(INFO) << "***** LoadFlArtifacts";
+ std::string artifact_path_prefix =
+ absl::StrCat(android::base::GetExecutableDirectory(), "/fcp/testdata");
+ ComputationArtifacts result;
+ result.plan_filepath =
+ absl::StrCat(artifact_path_prefix, "/federation_client_only_plan.pb");
+ std::string plan;
+ // if (!LoadFileAsString_(result.plan_filepath, &plan)) {
+ // return absl::InternalError("Failed to load ClientOnlyPlan as string");
+ // }
+ // // Load the plan data from the file.
+ if (!LoadMessageLiteFromFile_(result.plan_filepath, &result.plan)) {
+ return absl::InternalError("Failed to load ClientOnlyPlan");
+ }
+
+ // Load dataset
+ auto dataset_filepath =
+ absl::StrCat(artifact_path_prefix, "/federation_proxy_train_examples.pb");
+ if (!LoadMessageLiteFromFile_(dataset_filepath, &result.dataset)) {
+ return absl::InternalError("Failed to load example Dataset");
+ }
+
+ result.checkpoint_filepath = absl::StrCat(
+ artifact_path_prefix, "/federation_test_checkpoint.client.ckp");
+ // Load the checkpoint data from the file.
+ if (!LoadFileAsString_(result.checkpoint_filepath, &result.checkpoint)) {
+ return absl::InternalError("Failed to load checkpoint");
+ }
+
+ auto federated_select_slices_filepath = absl::StrCat(
+ artifact_path_prefix, "/federation_test_select_checkpoints.pb");
+ // Load the federated select slices data.
+ if (!LoadMessageLiteFromFile_(federated_select_slices_filepath,
+ &result.federated_select_slices)) {
+ return absl::InternalError("Failed to load federated select slices");
+ }
+ return result;
+}
+
+std::string ExtractSingleString(const tensorflow::Example& example,
+ const char key[]) {
+ return example.features().feature().at(key).bytes_list().value().at(0);
+}
+
+google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
+ const tensorflow::Example& example, const char key[]) {
+ return example.features().feature().at(key).bytes_list().value();
+}
+
+int64_t ExtractSingleInt64(const tensorflow::Example& example,
+ const char key[]) {
+ return example.features().feature().at(key).int64_list().value().at(0);
+}
+
+google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
+ const tensorflow::Example& example, const char key[]) {
+ return example.features().feature().at(key).int64_list().value();
+}
+
+} // namespace client
+} // namespace fcp
diff --git a/fcp/client/test_helpers.h b/fcp/client/test_helpers.h
new file mode 100644
index 0000000..30022cf
--- /dev/null
+++ b/fcp/client/test_helpers.h
@@ -0,0 +1,861 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_TEST_HELPERS_H_
+#define FCP_CLIENT_TEST_HELPERS_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_select.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/phase_logger.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/simple_task_environment.h"
+#include "gmock/gmock.h"
+#include "google/protobuf/duration.pb.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+
+namespace fcp {
+namespace client {
+
+class MockSecAggEventPublisher : public SecAggEventPublisher {
+ public:
+ MOCK_METHOD(void, PublishStateTransition,
+ (::fcp::secagg::ClientState state, size_t last_sent_message_size,
+ size_t last_received_message_size),
+ (override));
+ MOCK_METHOD(void, PublishError, (), (override));
+ MOCK_METHOD(void, PublishAbort,
+ (bool client_initiated, const std::string& error_message),
+ (override));
+ MOCK_METHOD(void, set_execution_session_id, (int64_t execution_session_id),
+ (override));
+};
+
+class MockEventPublisher : public EventPublisher {
+ public:
+ MOCK_METHOD(void, PublishEligibilityEvalCheckin, (), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalPlanReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalNotConfigured,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalRejected,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckin, (), (override));
+ MOCK_METHOD(void, PublishCheckinFinished,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishRejected, (), (override));
+ MOCK_METHOD(void, PublishReportStarted, (int64_t report_size_bytes),
+ (override));
+ MOCK_METHOD(void, PublishReportFinished,
+ (const NetworkStats& network_stats,
+ absl::Duration report_duration),
+ (override));
+ MOCK_METHOD(void, PublishPlanExecutionStarted, (), (override));
+ MOCK_METHOD(void, PublishTensorFlowError,
+ (int example_count, absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishIoError, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, PublishExampleSelectorError,
+ (int example_count, absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishInterruption,
+ (const ExampleStats& example_stats, absl::Time start_time),
+ (override));
+ MOCK_METHOD(void, PublishPlanCompleted,
+ (const ExampleStats& example_stats, absl::Time start_time),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (const std::string& model_identifier),
+ (override));
+ MOCK_METHOD(void, PublishTaskNotStarted, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, PublishNonfatalInitializationError,
+ (absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishFatalInitializationError,
+ (absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinIoError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinErrorInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationStarted, (), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationInvalidArgument,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationExampleIteratorError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationTensorflowError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationInterrupted,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationCompleted,
+ (const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinIoError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishRejected,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinFinishedV2,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationStarted, (), (override));
+ MOCK_METHOD(void, PublishComputationInvalidArgument,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationIOError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationExampleIteratorError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationTensorflowError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationInterrupted,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationCompleted,
+ (const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadStarted, (), (override));
+ MOCK_METHOD(void, PublishResultUploadIOError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadStarted, (), (override));
+ MOCK_METHOD(void, PublishFailureUploadIOError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+
+ SecAggEventPublisher* secagg_event_publisher() override {
+ return &secagg_event_publisher_;
+ }
+
+ private:
+ ::testing::NiceMock<MockSecAggEventPublisher> secagg_event_publisher_;
+};
+
+// A mock FederatedProtocol implementation, which keeps track of the stages in
+// the protocol and returns a different set of network stats and RetryWindow for
+// each stage, making it easier to write accurate assertions in unit tests.
+class MockFederatedProtocol : public FederatedProtocol {
+ public:
+ constexpr static NetworkStats
+ kPostEligibilityCheckinPlanUriReceivedNetworkStats = {
+ .bytes_downloaded = 280,
+ .bytes_uploaded = 380,
+ .network_duration = absl::Milliseconds(25)};
+ constexpr static NetworkStats kPostEligibilityCheckinNetworkStats = {
+ .bytes_downloaded = 300,
+ .bytes_uploaded = 400,
+ .network_duration = absl::Milliseconds(50)};
+ constexpr static NetworkStats kPostReportEligibilityEvalErrorNetworkStats = {
+ .bytes_downloaded = 400,
+ .bytes_uploaded = 500,
+ .network_duration = absl::Milliseconds(150)};
+ constexpr static NetworkStats kPostCheckinPlanUriReceivedNetworkStats = {
+ .bytes_downloaded = 2970,
+ .bytes_uploaded = 3970,
+ .network_duration = absl::Milliseconds(225)};
+ constexpr static NetworkStats kPostCheckinNetworkStats = {
+ .bytes_downloaded = 3000,
+ .bytes_uploaded = 4000,
+ .network_duration = absl::Milliseconds(250)};
+ constexpr static NetworkStats kPostReportCompletedNetworkStats = {
+ .bytes_downloaded = 30000,
+ .bytes_uploaded = 40000,
+ .network_duration = absl::Milliseconds(350)};
+ constexpr static NetworkStats kPostReportNotCompletedNetworkStats = {
+ .bytes_downloaded = 29999,
+ .bytes_uploaded = 39999,
+ .network_duration = absl::Milliseconds(450)};
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetInitialRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(0L);
+ retry_window.mutable_delay_max()->set_seconds(1L);
+ *retry_window.mutable_retry_token() = "INITIAL";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostEligibilityCheckinRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(100L);
+ retry_window.mutable_delay_max()->set_seconds(101L);
+ *retry_window.mutable_retry_token() = "POST_ELIGIBILITY";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostCheckinRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(200L);
+ retry_window.mutable_delay_max()->set_seconds(201L);
+ *retry_window.mutable_retry_token() = "POST_CHECKIN";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostReportCompletedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(300L);
+ retry_window.mutable_delay_max()->set_seconds(301L);
+ *retry_window.mutable_retry_token() = "POST_REPORT_COMPLETED";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostReportNotCompletedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(400L);
+ retry_window.mutable_delay_max()->set_seconds(401L);
+ *retry_window.mutable_retry_token() = "POST_REPORT_NOT_COMPLETED";
+ return retry_window;
+ }
+
+ explicit MockFederatedProtocol() {}
+
+ // We override the real FederatedProtocol methods so that we can intercept the
+ // progression of protocol stages, and expose dedicate gMock-overridable
+ // methods for use in tests.
+ absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin(
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) final {
+ absl::StatusOr<EligibilityEvalCheckinResult> result =
+ MockEligibilityEvalCheckin();
+ if (result.ok() &&
+ std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
+ *result)) {
+ network_stats_ = kPostEligibilityCheckinPlanUriReceivedNetworkStats;
+ payload_uris_received_callback(
+ std::get<FederatedProtocol::EligibilityEvalTask>(*result));
+ }
+ network_stats_ = kPostEligibilityCheckinNetworkStats;
+ retry_window_ = GetPostEligibilityCheckinRetryWindow();
+ return result;
+ };
+ MOCK_METHOD(absl::StatusOr<EligibilityEvalCheckinResult>,
+ MockEligibilityEvalCheckin, ());
+
+ void ReportEligibilityEvalError(absl::Status error_status) final {
+ network_stats_ = kPostReportEligibilityEvalErrorNetworkStats;
+ retry_window_ = GetPostEligibilityCheckinRetryWindow();
+ MockReportEligibilityEvalError(error_status);
+ }
+ MOCK_METHOD(void, MockReportEligibilityEvalError,
+ (absl::Status error_status));
+
+ absl::StatusOr<CheckinResult> Checkin(
+ const std::optional<
+ ::google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info,
+ std::function<void(const FederatedProtocol::TaskAssignment&)>
+ payload_uris_received_callback) final {
+ absl::StatusOr<CheckinResult> result = MockCheckin(task_eligibility_info);
+ if (result.ok() &&
+ std::holds_alternative<FederatedProtocol::TaskAssignment>(*result)) {
+ network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
+ payload_uris_received_callback(
+ std::get<FederatedProtocol::TaskAssignment>(*result));
+ }
+ retry_window_ = GetPostCheckinRetryWindow();
+ network_stats_ = kPostCheckinNetworkStats;
+ return result;
+ };
+ MOCK_METHOD(absl::StatusOr<CheckinResult>, MockCheckin,
+ (const std::optional<
+ ::google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info));
+
+ absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) final {
+ absl::StatusOr<MultipleTaskAssignments> result =
+ MockPerformMultipleTaskAssignments(task_names);
+ retry_window_ = GetPostCheckinRetryWindow();
+ network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
+ return result;
+ };
+
+ MOCK_METHOD(absl::StatusOr<MultipleTaskAssignments>,
+ MockPerformMultipleTaskAssignments,
+ (const std::vector<std::string>& task_names));
+
+ absl::Status ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) final {
+ network_stats_ = kPostReportCompletedNetworkStats;
+ retry_window_ = GetPostReportCompletedRetryWindow();
+ return MockReportCompleted(std::move(results), plan_duration,
+ aggregation_session_id);
+ };
+ MOCK_METHOD(absl::Status, MockReportCompleted,
+ (ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id));
+
+ absl::Status ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) final {
+ network_stats_ = kPostReportNotCompletedNetworkStats;
+ retry_window_ = GetPostReportNotCompletedRetryWindow();
+ return MockReportNotCompleted(phase_outcome, plan_duration,
+ aggregation_session_id);
+ };
+ MOCK_METHOD(absl::Status, MockReportNotCompleted,
+ (engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id));
+
+ ::google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
+ final {
+ return retry_window_;
+ }
+
+ NetworkStats GetNetworkStats() final { return network_stats_; }
+
+ private:
+ NetworkStats network_stats_;
+ ::google::internal::federatedml::v2::RetryWindow retry_window_ =
+ GetInitialRetryWindow();
+};
+
+class MockLogManager : public LogManager {
+ public:
+ MOCK_METHOD(void, LogDiag, (ProdDiagCode), (override));
+ MOCK_METHOD(void, LogDiag, (DebugDiagCode), (override));
+ MOCK_METHOD(void, LogToLongHistogram,
+ (fcp::client::HistogramCounters, int, int,
+ fcp::client::engine::DataSourceType, int64_t),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (const std::string&), (override));
+};
+
+class MockOpStatsLogger : public ::fcp::client::opstats::OpStatsLogger {
+ public:
+ MOCK_METHOD(
+ void, AddEventAndSetTaskName,
+ (const std::string& task_name,
+ ::fcp::client::opstats::OperationalStats::Event::EventKind event),
+ (override));
+ MOCK_METHOD(
+ void, AddEvent,
+ (::fcp::client::opstats::OperationalStats::Event::EventKind event),
+ (override));
+ MOCK_METHOD(void, AddEventWithErrorMessage,
+ (::fcp::client::opstats::OperationalStats::Event::EventKind event,
+ const std::string& error_message),
+ (override));
+ MOCK_METHOD(void, UpdateDatasetStats,
+ (const std::string& collection_uri, int additional_example_count,
+ int64_t additional_example_size_bytes),
+ (override));
+ MOCK_METHOD(void, SetNetworkStats, (const NetworkStats& network_stats),
+ (override));
+ MOCK_METHOD(void, SetRetryWindow,
+ (google::internal::federatedml::v2::RetryWindow retry_window),
+ (override));
+ MOCK_METHOD(::fcp::client::opstats::OpStatsDb*, GetOpStatsDb, (), (override));
+ MOCK_METHOD(bool, IsOpStatsEnabled, (), (const override));
+ MOCK_METHOD(absl::Status, CommitToStorage, (), (override));
+ MOCK_METHOD(std::string, GetCurrentTaskName, (), (override));
+};
+
+class MockSimpleTaskEnvironment : public SimpleTaskEnvironment {
+ public:
+ MOCK_METHOD(std::string, GetBaseDir, (), (override));
+ MOCK_METHOD(std::string, GetCacheDir, (), (override));
+ MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
+ CreateExampleIterator,
+ (const google::internal::federated::plan::ExampleSelector&
+ example_selector),
+ (override));
+ MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
+ CreateExampleIterator,
+ (const google::internal::federated::plan::ExampleSelector&
+ example_selector,
+ const SelectorContext& selector_context),
+ (override));
+ MOCK_METHOD(std::unique_ptr<fcp::client::http::HttpClient>, CreateHttpClient,
+ (), (override));
+ MOCK_METHOD(bool, TrainingConditionsSatisfied, (), (override));
+};
+
+class MockExampleIterator : public ExampleIterator {
+ public:
+ MOCK_METHOD(absl::StatusOr<std::string>, Next, (), (override));
+ MOCK_METHOD(void, Close, (), (override));
+};
+
+// An iterator that passes through each example in the dataset once.
+class SimpleExampleIterator : public ExampleIterator {
+ public:
+ // Uses the given bytes as the examples to return.
+ explicit SimpleExampleIterator(std::vector<const char*> examples);
+ // Passes through each of the examples in the `Dataset.client_data.example`
+ // field.
+ explicit SimpleExampleIterator(
+ google::internal::federated::plan::Dataset dataset);
+ // Passes through each of the examples in the
+ // `Dataset.client_data.selected_example.example` field, whose example
+ // collection URI matches the provided `collection_uri`.
+ SimpleExampleIterator(google::internal::federated::plan::Dataset dataset,
+ absl::string_view collection_uri);
+ absl::StatusOr<std::string> Next() override;
+ void Close() override {}
+
+ protected:
+ std::vector<std::string> examples_;
+ int index_ = 0;
+};
+
+struct ComputationArtifacts {
+ // The path to the file containing the plan data.
+ std::string plan_filepath;
+ // The already-parsed plan data.
+ google::internal::federated::plan::ClientOnlyPlan plan;
+ // The test dataset.
+ google::internal::federated::plan::Dataset dataset;
+ // The path to the file containing the initial checkpoint data (not set for
+ // local compute task artifacts).
+ std::string checkpoint_filepath;
+ // The initial checkpoint data, as a string (not set for local compute task
+ // artifacts).
+ std::string checkpoint;
+ // The Federated Select slice data (not set for local compute task artifacts).
+ google::internal::federated::plan::SlicesTestDataset federated_select_slices;
+};
+
+absl::StatusOr<ComputationArtifacts> LoadFlArtifacts();
+
+class MockFlags : public Flags {
+ public:
+ MOCK_METHOD(int64_t, condition_polling_period_millis, (), (const, override));
+ MOCK_METHOD(int64_t, tf_execution_teardown_grace_period_millis, (),
+ (const, override));
+ MOCK_METHOD(int64_t, tf_execution_teardown_extended_period_millis, (),
+ (const, override));
+ MOCK_METHOD(int64_t, grpc_channel_deadline_seconds, (), (const, override));
+ MOCK_METHOD(bool, log_tensorflow_error_messages, (), (const, override));
+ MOCK_METHOD(bool, enable_opstats, (), (const, override));
+ MOCK_METHOD(int64_t, opstats_ttl_days, (), (const, override));
+ MOCK_METHOD(int64_t, opstats_db_size_limit_bytes, (), (const, override));
+ MOCK_METHOD(int64_t, federated_training_transient_errors_retry_delay_secs, (),
+ (const, override));
+ MOCK_METHOD(float,
+ federated_training_transient_errors_retry_delay_jitter_percent,
+ (), (const, override));
+ MOCK_METHOD(int64_t, federated_training_permanent_errors_retry_delay_secs, (),
+ (const, override));
+ MOCK_METHOD(float,
+ federated_training_permanent_errors_retry_delay_jitter_percent,
+ (), (const, override));
+ MOCK_METHOD(std::vector<int32_t>, federated_training_permanent_error_codes,
+ (), (const, override));
+ MOCK_METHOD(bool, use_tflite_training, (), (const, override));
+ MOCK_METHOD(bool, enable_grpc_with_http_resource_support, (),
+ (const, override));
+ MOCK_METHOD(bool, enable_grpc_with_eligibility_eval_http_resource_support, (),
+ (const, override));
+ MOCK_METHOD(bool, ensure_dynamic_tensors_are_released, (), (const, override));
+ MOCK_METHOD(int32_t, large_tensor_threshold_for_dynamic_allocation, (),
+ (const, override));
+ MOCK_METHOD(bool, disable_http_request_body_compression, (),
+ (const, override));
+ MOCK_METHOD(bool, use_http_federated_compute_protocol, (), (const, override));
+ MOCK_METHOD(bool, enable_computation_id, (), (const, override));
+ MOCK_METHOD(int32_t, waiting_period_sec_for_cancellation, (),
+ (const, override));
+ MOCK_METHOD(bool, enable_federated_select, (), (const, override));
+ MOCK_METHOD(int32_t, num_threads_for_tflite, (), (const, override));
+ MOCK_METHOD(bool, disable_tflite_delegate_clustering, (), (const, override));
+ MOCK_METHOD(bool, enable_example_query_plan_engine, (), (const, override));
+ MOCK_METHOD(bool, support_constant_tf_inputs, (), (const, override));
+ MOCK_METHOD(bool, http_protocol_supports_multiple_task_assignments, (),
+ (const, override));
+};
+
+// Helper methods for extracting opstats fields from TF examples.
+std::string ExtractSingleString(const tensorflow::Example& example,
+ const char key[]);
+google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
+ const tensorflow::Example& example, const char key[]);
+int64_t ExtractSingleInt64(const tensorflow::Example& example,
+ const char key[]);
+google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
+ const tensorflow::Example& example, const char key[]);
+
+class MockOpStatsDb : public ::fcp::client::opstats::OpStatsDb {
+ public:
+ MOCK_METHOD(absl::StatusOr<::fcp::client::opstats::OpStatsSequence>, Read, (),
+ (override));
+ MOCK_METHOD(absl::Status, Transform,
+ (std::function<void(::fcp::client::opstats::OpStatsSequence&)>),
+ (override));
+};
+
+class MockPhaseLogger : public PhaseLogger {
+ public:
+ MOCK_METHOD(
+ void, UpdateRetryWindowAndNetworkStats,
+ (const ::google::internal::federatedml::v2::RetryWindow& retry_window,
+ const NetworkStats& network_stats),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (absl::string_view model_identifier),
+ (override));
+ MOCK_METHOD(void, LogTaskNotStarted, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, LogNonfatalInitializationError, (absl::Status error_status),
+ (override));
+ MOCK_METHOD(void, LogFatalInitializationError, (absl::Status error_status),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinStarted, (), (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinInvalidPayloadError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalNotConfigured,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinTurnedAway,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationStarted, (), (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationInvalidArgument,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationExampleIteratorError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationTensorflowError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationInterrupted,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationCompleted,
+ (const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinStarted, (), (override));
+ MOCK_METHOD(void, LogCheckinIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinTurnedAway,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinPlanUriReceived,
+ (absl::string_view task_name, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogCheckinCompleted,
+ (absl::string_view task_name, const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationStarted, (), (override));
+ MOCK_METHOD(void, LogComputationInvalidArgument,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationExampleIteratorError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationIOError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationTensorflowError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationInterrupted,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationCompleted,
+ (const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(absl::Status, LogResultUploadStarted, (), (override));
+ MOCK_METHOD(void, LogResultUploadIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(absl::Status, LogFailureUploadStarted, (), (override));
+ MOCK_METHOD(void, LogFailureUploadIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+};
+
+class MockFederatedSelectManager : public FederatedSelectManager {
+ public:
+ MOCK_METHOD(std::unique_ptr<engine::ExampleIteratorFactory>,
+ CreateExampleIteratorFactoryForUriTemplate,
+ (absl::string_view uri_template), (override));
+
+ MOCK_METHOD(NetworkStats, GetNetworkStats, (), (override));
+};
+
+class MockFederatedSelectExampleIteratorFactory
+ : public FederatedSelectExampleIteratorFactory {
+ public:
+ MOCK_METHOD(absl::StatusOr<std::unique_ptr<ExampleIterator>>,
+ CreateExampleIterator,
+ (const ::google::internal::federated::plan::ExampleSelector&
+ example_selector),
+ (override));
+};
+
+class MockSecAggRunnerFactory : public SecAggRunnerFactory {
+ public:
+ MOCK_METHOD(std::unique_ptr<SecAggRunner>, CreateSecAggRunner,
+ (std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher,
+ LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction),
+ (override));
+};
+
+class MockSecAggRunner : public SecAggRunner {
+ public:
+ MOCK_METHOD(absl::Status, Run, (ComputationResults results), (override));
+};
+
+class MockSecAggSendToServerBase : public SecAggSendToServerBase {
+ MOCK_METHOD(void, Send, (secagg::ClientToServerWrapperMessage * message),
+ (override));
+};
+
+class MockSecAggProtocolDelegate : public SecAggProtocolDelegate {
+ public:
+ MOCK_METHOD(absl::StatusOr<uint64_t>, GetModulus, (const std::string& key),
+ (override));
+ MOCK_METHOD(absl::StatusOr<secagg::ServerToClientWrapperMessage>,
+ ReceiveServerMessage, (), (override));
+ MOCK_METHOD(void, Abort, (), (override));
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_TEST_HELPERS_H_
diff --git a/fcp/client/testing/BUILD b/fcp/client/testing/BUILD
new file mode 100644
index 0000000..18f53fe
--- /dev/null
+++ b/fcp/client/testing/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "utils",
+ testonly = True,
+ hdrs = ["utils.h"],
+ deps = [
+ "//fcp/base",
+ "//fcp/client:interfaces",
+ "//fcp/client:simple_task_environment",
+ "//fcp/client/engine:engine_cc_proto",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/protos:plan_proto_cc",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["utils_test.cc"],
+ deps = [
+ ":utils",
+ "//fcp/protos:plan_cc_proto",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/client/testing/utils.h b/fcp/client/testing/utils.h
new file mode 100644
index 0000000..736d30b
--- /dev/null
+++ b/fcp/client/testing/utils.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_CLIENT_TESTING_UTILS_H_
+#define FCP_CLIENT_TESTING_UTILS_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/repeated_field.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/client/engine/engine.pb.h"
+#include "fcp/client/files.h"
+#include "fcp/client/simple_task_environment.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp::client::testing {
+
+using google::internal::federated::plan::Dataset;
+using google::internal::federated::plan::ExampleSelector;
+using google::internal::federated::plan::Plan;
+using google::internal::federatedml::v2::RetryWindow;
+
+inline std::string MakeTestFileName(absl::string_view dir,
+ absl::string_view prefix,
+ absl::string_view suffix) {
+ return ConcatPath(StripTrailingPathSeparator(dir),
+ absl::StrCat(prefix, suffix));
+}
+
+// Basic implementation of ExampleIterator for testing purposes.
+// It iterates over examples from a given dataset.
+class TestExampleIterator : public ExampleIterator {
+ public:
+ explicit TestExampleIterator(const Dataset::ClientDataset* dataset)
+ : next_example_(dataset->example().begin()),
+ end_(dataset->example().end()) {}
+
+ absl::StatusOr<std::string> Next() override {
+ if (next_example_ == end_) {
+ return absl::OutOfRangeError("");
+ }
+ return *(next_example_++);
+ }
+
+ void Close() override {}
+
+ private:
+ google::protobuf::RepeatedPtrField<std::string>::const_iterator next_example_;
+ google::protobuf::RepeatedPtrField<std::string>::const_iterator end_;
+};
+
+// Implementation of TaskEnvironment, the interface by which the client plan
+// engine interacts with the environment, that allows tests to provide a dataset
+// as input and consume the output checkpoint.
+class TestTaskEnvironment : public SimpleTaskEnvironment {
+ public:
+ explicit TestTaskEnvironment(const Dataset::ClientDataset* dataset,
+ const std::string& base_dir)
+ : dataset_(dataset), base_dir_(base_dir) {}
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const google::internal::federated::plan::ExampleSelector&
+ example_selector) override {
+ SelectorContext unused;
+ return CreateExampleIterator(example_selector, unused);
+ }
+
+ absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
+ const ExampleSelector& example_selector,
+ const SelectorContext& selector_context) override {
+ std::unique_ptr<ExampleIterator> iter =
+ std::make_unique<TestExampleIterator>(dataset_);
+ return std::move(iter);
+ }
+
+ std::string GetBaseDir() override { return base_dir_; }
+
+ std::string GetCacheDir() override { return base_dir_; }
+
+ private:
+ bool TrainingConditionsSatisfied() override { return true; }
+
+ const Dataset::ClientDataset* dataset_;
+ std::string base_dir_;
+ std::string checkpoint_file_;
+};
+
+// Implementation of client file API that creates files in a temporary test
+// directory.
+class TestFiles : public Files {
+ public:
+ explicit TestFiles(absl::string_view test_dir) : test_dir_(test_dir) {}
+ absl::StatusOr<std::string> CreateTempFile(
+ const std::string& prefix, const std::string& suffix) override {
+ return MakeTestFileName(test_dir_, prefix, suffix);
+ }
+
+ private:
+ std::string test_dir_;
+};
+
+} // namespace fcp::client::testing
+
+#endif // FCP_CLIENT_TESTING_UTILS_H_
diff --git a/fcp/client/testing/utils_test.cc b/fcp/client/testing/utils_test.cc
new file mode 100644
index 0000000..90f8484
--- /dev/null
+++ b/fcp/client/testing/utils_test.cc
@@ -0,0 +1,89 @@
+#include "fcp/client/testing/utils.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/status/statusor.h"
+#include "fcp/protos/plan.pb.h"
+
+namespace fcp::client::testing {
+namespace {
+
+using google::internal::federated::plan::Dataset;
+
+// The fixture for testing class Foo.
+class TestExampleIteratorTest : public ::testing::Test {};
+
+TEST_F(TestExampleIteratorTest, NextCallsForEmptyDatset) {
+ Dataset::ClientDataset client_dataset;
+
+ TestExampleIterator iterator(&client_dataset);
+
+ absl::StatusOr<std::string> element;
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+}
+
+TEST_F(TestExampleIteratorTest, NextCallsForSingleElementDataset) {
+ Dataset::ClientDataset client_dataset;
+ client_dataset.add_example("abc");
+
+ TestExampleIterator iterator(&client_dataset);
+
+ absl::StatusOr<std::string> element;
+
+ element = iterator.Next();
+ ASSERT_TRUE(element.ok());
+ EXPECT_EQ(element.value(), "abc");
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+}
+
+// Tests that the Foo::Bar() method does Abc.
+TEST_F(TestExampleIteratorTest, NextCallsForThreeElementDataset) {
+ Dataset::ClientDataset client_dataset;
+ client_dataset.add_example("a");
+ client_dataset.add_example("b");
+ client_dataset.add_example("c");
+
+ TestExampleIterator iterator(&client_dataset);
+
+ absl::StatusOr<std::string> element;
+
+ element = iterator.Next();
+ ASSERT_TRUE(element.ok());
+ EXPECT_EQ(element.value(), "a");
+
+ element = iterator.Next();
+ ASSERT_TRUE(element.ok());
+ EXPECT_EQ(element.value(), "b");
+
+ element = iterator.Next();
+ ASSERT_TRUE(element.ok());
+ EXPECT_EQ(element.value(), "c");
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+
+ element = iterator.Next();
+ EXPECT_FALSE(element.ok());
+}
+
+} // namespace
+} // namespace fcp::client::testing
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/fcp/config.bzl b/fcp/config.bzl
new file mode 100644
index 0000000..39acdab
--- /dev/null
+++ b/fcp/config.bzl
@@ -0,0 +1,33 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# cc_* rules should include this list in copts. If additional cc_*-wide
+# customization appears, we might want to switch to macros.
+
+"""This is the definition site for things we want to keep consistent, like copts."""
+
+FCP_COPTS = [
+]
+
+FCP_BAREMETAL_COPTS = FCP_COPTS + [
+ "-DFCP_BAREMETAL",
+ "-nostdlib",
+ "-fno-exceptions",
+ "-ffreestanding",
+ "-Wno-unused-parameter",
+]
+
+FCP_NANOLIBC_COPTS = FCP_BAREMETAL_COPTS + [
+ "-DFCP_NANOLIBC",
+]
diff --git a/fcp/demo/BUILD b/fcp/demo/BUILD
new file mode 100644
index 0000000..c59b9fa
--- /dev/null
+++ b/fcp/demo/BUILD
@@ -0,0 +1,338 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+# The public interface for this package, providing the various Federated Program platform
+# components needed to run a Federated Program using Federated Compute clients.
+py_library(
+ name = "demo",
+ srcs = ["__init__.py"],
+ srcs_version = "PY3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":federated_computation",
+ ":federated_context",
+ ":federated_data_source",
+ ],
+)
+
+py_library(
+ name = "aggregations",
+ srcs = ["aggregations.py"],
+ data = ["@pybind11_abseil//pybind11_abseil:status.so"],
+ srcs_version = "PY3",
+ deps = [
+ ":http_actions",
+ ":media",
+ "//fcp/aggregation/protocol:configuration_py_pb2",
+ "//fcp/aggregation/protocol:py_pb2",
+ "//fcp/aggregation/protocol/python:aggregation_protocol",
+ "//fcp/aggregation/tensorflow/python:aggregation_protocols",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "@com_google_googleapis//google/rpc:code_py_proto",
+ "@googleapis_for_longrunning//google/longrunning:longrunning_py_proto",
+ ],
+)
+
+py_test(
+ name = "aggregations_test",
+ srcs = ["aggregations_test.py"],
+ data = ["@pybind11_abseil//pybind11_abseil:status.so"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":aggregations",
+ ":http_actions",
+ ":media",
+ ":test_utils",
+ "//fcp/aggregation/protocol:py_pb2",
+ "//fcp/aggregation/protocol/python:aggregation_protocol",
+ "//fcp/aggregation/tensorflow/python:aggregation_protocols",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_library(
+ name = "checkpoint_tensor_reference",
+ srcs = ["checkpoint_tensor_reference.py"],
+ srcs_version = "PY3",
+)
+
+py_test(
+ name = "checkpoint_tensor_reference_test",
+ srcs = ["checkpoint_tensor_reference_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":checkpoint_tensor_reference",
+ ":test_utils",
+ ],
+)
+
+py_library(
+ name = "eligibility_eval_tasks",
+ srcs = ["eligibility_eval_tasks.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":http_actions",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "@com_google_googleapis//google/rpc:code_py_proto",
+ ],
+)
+
+py_test(
+ name = "eligibility_eval_tasks_test",
+ srcs = ["eligibility_eval_tasks_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":eligibility_eval_tasks",
+ ":http_actions",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "@com_google_googleapis//google/rpc:code_py_proto",
+ ],
+)
+
+py_library(
+ name = "federated_computation",
+ srcs = ["federated_computation.py"],
+ srcs_version = "PY3",
+)
+
+py_test(
+ name = "federated_computation_test",
+ srcs = ["federated_computation_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [":federated_computation"],
+)
+
+py_library(
+ name = "federated_context",
+ srcs = ["federated_context.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":checkpoint_tensor_reference",
+ ":federated_computation",
+ ":federated_data_source",
+ ":server",
+ "//fcp/artifact_building:artifact_constants",
+ "//fcp/artifact_building:checkpoint_utils",
+ "//fcp/artifact_building:data_spec",
+ "//fcp/artifact_building:federated_compute_plan_builder",
+ "//fcp/artifact_building:plan_utils",
+ "//fcp/artifact_building:variable_helpers",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_test(
+ name = "federated_context_test",
+ size = "medium",
+ srcs = ["federated_context_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":federated_computation",
+ ":federated_context",
+ ":federated_data_source",
+ ":server",
+ ":test_utils",
+ "//fcp/artifact_building:artifact_constants",
+ "//fcp/artifact_building:federated_compute_plan_builder",
+ "//fcp/artifact_building:plan_utils",
+ "//fcp/artifact_building:variable_helpers",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "federated_data_source",
+ srcs = ["federated_data_source.py"],
+ srcs_version = "PY3",
+ deps = [
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_test(
+ name = "federated_data_source_test",
+ srcs = ["federated_data_source_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":federated_data_source",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_test(
+ name = "federated_program_test",
+ size = "medium",
+ srcs = ["federated_program_test.py"],
+ data = ["//fcp/client:client_runner_main"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":demo",
+ "//fcp/client:client_runner_example_data_py_pb2",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+py_library(
+ name = "http_actions",
+ srcs = ["http_actions.py"],
+ srcs_version = "PY3",
+ deps = ["@com_google_googleapis//google/api:annotations_py_proto"],
+)
+
+py_test(
+ name = "http_actions_test",
+ size = "medium",
+ srcs = ["http_actions_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":http_actions",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_library(
+ name = "media",
+ srcs = ["media.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":http_actions",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_test(
+ name = "media_test",
+ srcs = ["media_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":http_actions",
+ ":media",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_library(
+ name = "plan_utils",
+ srcs = ["plan_utils.py"],
+ srcs_version = "PY3",
+ deps = [
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/tensorflow:serve_slices_py",
+ ],
+)
+
+py_test(
+ name = "plan_utils_test",
+ size = "medium",
+ srcs = ["plan_utils_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":plan_utils",
+ ":test_utils",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/tensorflow:serve_slices_py",
+ ],
+)
+
+py_library(
+ name = "server",
+ srcs = ["server.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":aggregations",
+ ":eligibility_eval_tasks",
+ ":http_actions",
+ ":media",
+ ":plan_utils",
+ ":task_assignments",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ ],
+)
+
+py_test(
+ name = "server_test",
+ size = "medium",
+ srcs = ["server_test.py"],
+ data = ["//fcp/client:client_runner_main"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":plan_utils",
+ ":server",
+ ":test_utils",
+ "//fcp/protos:plan_py_pb2",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "//fcp/tensorflow:external_dataset_py",
+ "@googleapis_for_longrunning//google/longrunning:longrunning_py_proto",
+ ],
+)
+
+py_library(
+ name = "task_assignments",
+ srcs = ["task_assignments.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":aggregations",
+ ":http_actions",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "@com_google_googleapis//google/rpc:code_py_proto",
+ "@googleapis_for_longrunning//google/longrunning:longrunning_py_proto",
+ ],
+)
+
+py_test(
+ name = "task_assignments_test",
+ srcs = ["task_assignments_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":aggregations",
+ ":http_actions",
+ ":task_assignments",
+ "//fcp/protos/federatedcompute:federated_compute_py_pb2",
+ "@com_google_googleapis//google/rpc:code_py_proto",
+ ],
+)
+
+py_library(
+ name = "test_utils",
+ testonly = True,
+ srcs = ["test_utils.py"],
+)
+
+py_test(
+ name = "test_utils_test",
+ size = "medium",
+ srcs = ["test_utils_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [":test_utils"],
+)
diff --git a/fcp/demo/README.md b/fcp/demo/README.md
new file mode 100644
index 0000000..49cc687
--- /dev/null
+++ b/fcp/demo/README.md
@@ -0,0 +1,242 @@
+# Cross-Device Federated Computations Demo
+
+This directory contains an example
+[Federated Program platform](https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/program/README.md#platform-specific-components)
+implementation that's compatible with the Federated Compute client.
+
+The code in this directory prioritizes understandability over production
+scalability because many of the frameworks used to create robust servers are
+dependent on the intended deployment environment. Comments throughout the code
+and documentation call out where changes should be made for a production
+implementation. Unless otherwise noted, the libraries in other directories are
+production quality.
+
+See
+[Towards Federated Learning at Scale: System Design](https://arxiv.org/abs/1902.01046)
+(TFLaS) for additional information on scaling Federated Learning.
+
+## Example Usage
+
+> 💡 See `federated_program_test.py` for a working example of configuring and
+> running a Federated Program using this package.
+
+The following example program is based on the example in the
+[TFF documentation](https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/program/README.md#program):
+
+```python
+from fcp import demo
+from fcp.protos import plan_pb2
+
+# Parameters set by the customer.
+_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'The output path.')
+_POPULATION_NAME = flags.DEFINE_string(
+ 'population_name', None, 'The identifier for the client population.')
+_COLLECTION_URI = flags.DEFINE_string(
+ 'collection_uri', None,
+ 'A URI identifying the example collection to read from.')
+
+
+def main():
+ # Parameters set by the program.
+ total_rounds = 10
+ num_clients = 3
+
+ # Configure the platform-specific components.
+ with demo.FederatedContext(
+ _POPULATION_NAME.value,
+ base_context=tff.framework.get_context_stack().current) as context:
+ data_source = demo.FederatedDataSource(
+ _POPULATION_NAME.value,
+ plan_pb2.ExampleSelector(collection_uri=_COLLECTION_URI.value))
+
+ # Configure the platform-agnostic components.
+ summary_dir = os.path.join(_OUTPUT_DIR.value, 'summary')
+ output_managers = [
+ tff.program.LoggingReleaseManager(),
+ tensorboard_manager = tff.program.TensorBoardReleaseManager(summary_dir),
+ ]
+ program_state_dir = os.path.join(..., 'program_state')
+ program_state_manager = tff.program.FileProgramStateManager(
+ program_state_dir)
+
+ # Define the computations.
+ initialize = ...
+ train = ...
+
+ # Execute the computations using program logic.
+ tff.framework.set_default_context(context)
+ train_federated_model(
+ initialize=initialize,
+ train=train,
+ data_source=data_source,
+ total_rounds=total_rounds,
+ num_clients=num_clients,
+ output_managers=output_managers,
+ program_state_manager=program_state_manager)
+```
+
+## Code Structure
+
+```mermaid
+flowchart
+ client(Client)
+
+ subgraph FP[Federated Program Process]
+ federated_program(Federated Program)
+ style federated_program color:#333,fill:#bbb,stroke:#666,stroke-width:3px;
+
+ subgraph Server[In-Process Server]
+ server(server.py)
+ http_actions(http_actions.py)
+ plan_utils(plan_utils.py)
+
+ subgraph Handlers[HTTP Handlers]
+ aggregations(aggregations.py)
+ eligibility_eval_tasks(eligibility_eval_tasks.py)
+ media(media.py)
+ task_assignments(task_assignments.py)
+ end
+ end
+
+ subgraph FP_Platform[Federated Program Platform]
+ federated_context(federated_context.py)
+ federated_computation(federated_computation.py)
+ federated_data_source(federated_data_source.py)
+ checkpoint_tensor_reference(checkpoint_tensor_reference.py)
+ end
+ end
+
+ client & server --> Handlers
+ server --> http_actions & plan_utils
+ Handlers --> http_actions
+ federated_program --> federated_context & federated_computation & federated_data_source
+ federated_context --> checkpoint_tensor_reference & server
+```
+
+### Client
+
+The [Federated Computations Client](../client)
+library is used by applications running on end-user devices to run
+server-defined computations over on-device data and report back results (such as
+updated model weights) to be aggregated by the server.
+
+> 💡 See `federated_program_test.py` for command-line flags that should be used
+> when running `//fcp/client:client_runner_main`.
+
+> ⚠️ The client requires TLS when connecting to any host other than `localhost`.
+> The server's public and private keys will need to be provided to the
+> `demo.FederatedContext` constructor, and the corresponding CA certificate will
+> need to be passed to the client library (e.g., via `--test_cert` for
+> `client_runner_main`).
+
+### Federated Program Platform
+
+The demo Federated Computations platform is a
+[Federated Program platform](https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/program/README.md#platform-specific-components)
+implementation that allows TFF computations to be run using Federated
+Computations Clients.
+
+A production implementation could reuse much of this code as-is, though
+`federated_context.py` would need to be updated to communicate with remote
+server(s) instead of an in-process server.
+
+#### `federated_context.py`
+
+Contains a
+[`tff.program.FederatedContext`](https://www.tensorflow.org/federated/api_docs/python/tff/program/FederatedContext)
+implementation for running computations on the demo Federated Computations
+platform.
+
+This module uses libraries in
+[`fcp/artifact_building`](../artifact_building) to
+convert TFF computations to the format expected by the
+[in-process server](#in-process-server) and [client](#client).
+
+#### `federated_computation.py`
+
+Contains a
+[`tff.Computation`](https://www.tensorflow.org/federated/api_docs/python/tff/Computation)
+subclass for computations that will be run by the demo Federated Computations
+platform.
+
+#### `federated_data_source.py`
+
+Contains a
+[`tff.program.FederatedDataSource`](https://www.tensorflow.org/federated/api_docs/python/tff/program/FederatedDataSource)
+implementation for representing on-device data sources.
+
+#### `checkpoint_tensor_reference.py`
+
+Contains a
+[`tff.program.MaterializableValueReference`](https://www.tensorflow.org/federated/api_docs/python/tff/program/MaterializableValueReference)
+implementation that reads values from a TensorFlow checkpoint.
+
+### In-Process Server
+
+An in-process HTTP(S) server that implements the
+[Federated Compute protocol](../protos/federatedcompute).
+This server is responsible for selecting which clients will contribute to each
+computation invocation (**task**), broadcasting computations and state to
+clients, aggregating the results of on-device computation, and incorporating
+that aggregate information back into the model or metrics.
+
+In a production implementation, each Federated Compute protocol service would
+likely be handled by a separate replicated microservice, not a Python module.
+
+#### `server.py`
+
+Provides the interface for setting up and stopping the in-process HTTP(S) server
+and running computations provided by the `FederatedContext`. This module is
+responsible for notifying the various Federated Compute protocol service
+implementations when a new task has been added and then managing the lifecycle
+of that task.
+
+#### `eligibility_eval_tasks.py`
+
+Contains handlers for the Federated Compute protocol's
+[EligibilityEvalTasks](../protos/federatedcompute/eligibility_eval_tasks.proto)
+service. This service is responsible for serving optional pre-task-assignment
+computations that determines to which tasks each client is eligible to
+contribute. The demo platform does not currently support configuring Eligibility
+Eval tasks; clients are considered to be eligible for all tasks.
+
+#### `task_assignments.py`
+
+Contains handlers for the Federated Compute protocol's
+[TaskAssignments](../protos/federatedcompute/task_assignments.proto)
+service. This service is responsible for either assigning each client to a
+task -- or rejecting the client.
+
+#### `aggregations.py`
+
+Contains handlers for the Federated Compute protocol's
+[Aggregations](../protos/federatedcompute/aggregations.proto)
+service. This service is responsible for aggregating client-reported data using
+the
+[simple Aggregation Protocol](../aggregation/protocol/simple_aggregation)
+library.
+
+Note that the demo platform does not currently contain an implementation of the
+[SecureAggregations](../protos/federatedcompute/secure_aggregations.proto)
+service.
+
+#### `media.py`
+
+Contains handlers for HTTP uploads and downloads using `PUT` and `GET` requests.
+
+A production implementation will likely replace this module with a
+deployment-environment-specific download service; a custom upload service
+implementation may be needed since it should not persistently store
+client-uploaded data.
+
+#### `http_actions.py`
+
+Contains helper functions for converting proto-based handlers into HTTP
+handlers. This conversion mimics the Cloud Endpoints
+[HTTP to gRPC transcoding](https://cloud.google.com/endpoints/docs/grpc/transcoding).
+
+#### `plan_utils.py`
+
+Contains helper functions for constructing the TensorFlow graph and input
+checkpoint used by the client and running TensorFlow-based post-processing on
+aggregated results.
diff --git a/fcp/demo/__init__.py b/fcp/demo/__init__.py
new file mode 100644
index 0000000..c52f3e3
--- /dev/null
+++ b/fcp/demo/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Demo Federated Program platform using Federated Compute."""
+
+import sys
+
+from fcp.demo.federated_computation import FederatedComputation
+from fcp.demo.federated_context import FederatedContext
+from fcp.demo.federated_data_source import FederatedDataSource
+
+__version__ = '0.0.1'
+
+if sys.version_info[0] < 3 or sys.version_info[1] < 9:
+ raise Exception('Python version 3.9 or later is required')
diff --git a/fcp/demo/aggregations.py b/fcp/demo/aggregations.py
new file mode 100644
index 0000000..63ec901
--- /dev/null
+++ b/fcp/demo/aggregations.py
@@ -0,0 +1,554 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Action handlers for the Aggregations service."""
+
+import asyncio
+from collections.abc import Callable, Sequence
+import contextlib
+import dataclasses
+import enum
+import functools
+import http
+import queue
+import threading
+from typing import Optional
+import uuid
+
+from absl import logging
+
+from google.longrunning import operations_pb2
+from google.rpc import code_pb2
+from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
+from fcp.aggregation.protocol import configuration_pb2
+from fcp.aggregation.protocol.python import aggregation_protocol
+from fcp.aggregation.tensorflow.python import aggregation_protocols
+from fcp.demo import http_actions
+from fcp.demo import media
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import aggregations_pb2
+from fcp.protos.federatedcompute import common_pb2
+from pybind11_abseil import status as absl_status
+
+
+class AggregationStatus(enum.Enum):
+ COMPLETED = 1
+ PENDING = 2
+ FAILED = 3
+ ABORTED = 4
+
+
+@dataclasses.dataclass
+class SessionStatus:
+ """The status of an aggregation session."""
+ # The current state of the aggregation session.
+ status: AggregationStatus = AggregationStatus.PENDING
+ # Number of clients that successfully started and completed the aggregation
+ # upload protocol.
+ num_clients_completed: int = 0
+ # Number of clients that started the aggregation upload protocol but failed
+ # to complete (e.g dropped out in the middle of the protocol).
+ num_clients_failed: int = 0
+ # Number of clients that started the aggregation upload protocol but have not
+ # yet finished (either successfully or not).
+ num_clients_pending: int = 0
+ # Number of clients that started the aggregation protocol but were aborted by
+ # the server before they could complete (e.g., if progress on the session was
+ # no longer needed).
+ num_clients_aborted: int = 0
+ # Number of inputs that were successfully aggregated and included in the
+ # final aggregate. Note that even if a client successfully completes the
+ # protocol (i.e., it is included in num_clients_completed), it is not
+ # guaranteed that the uploaded report is included in the final aggregate yet.
+ num_inputs_aggregated_and_included: int = 0
+ # Number of inputs that were received by the server and are pending (i.e.,
+ # the inputs have not been included in the final aggregate yet).
+ num_inputs_aggregated_and_pending: int = 0
+ # Number of inputs that were received by the server but discarded.
+ num_inputs_discarded: int = 0
+
+
+@dataclasses.dataclass(frozen=True)
+class AggregationRequirements:
+ # The minimum number of clients required before a result can be released
+ # outside this service. Note that aggregation does not automatically stop if
+ # minimum_clients_in_server_published_aggregate is met. It is up to callers
+ # to stop aggregation when they want.
+ minimum_clients_in_server_published_aggregate: int
+ # The Plan to execute.
+ plan: plan_pb2.Plan
+
+
+@dataclasses.dataclass
+class _ActiveClientData:
+ """Information about an active client."""
+ # The client's identifier in the aggregation protocol.
+ client_id: int
+ # Queue receiving the final status of the client connection (if closed by the
+ # aggregation protocol). At most one value will be written.
+ close_status: queue.SimpleQueue[absl_status.Status]
+ # The name of the resource to which the client should write its update.
+ resource_name: str
+
+
+@dataclasses.dataclass(eq=False)
+class _WaitData:
+ """Information about a pending wait operation."""
+ # The condition under which the wait should complete.
+ num_inputs_aggregated_and_included: Optional[int]
+ # The loop the caller is waiting on.
+ loop: asyncio.AbstractEventLoop = dataclasses.field(
+ default_factory=asyncio.get_running_loop)
+ # The future to which the SessionStatus will be written once the wait is over.
+ status_future: asyncio.Future[SessionStatus] = dataclasses.field(
+ default_factory=asyncio.Future)
+
+
+class _AggregationProtocolCallback(
+ aggregation_protocol.AggregationProtocol.Callback):
+ """AggregationProtocol.Callback that writes events to queues."""
+
+ def __init__(self, on_abort: Callable[[], None]):
+ """Constructs a new _AggregationProtocolCallback..
+
+ Args:
+ on_abort: A callback invoked if/when Abort is called.
+ """
+ super().__init__()
+ # When a client is accepted after calling AggregationProtocol.AddClients,
+ # this queue receives the new client's id as well as a queue that will
+ # provide the diagnostic status when the client is closed. (The status
+ # queue is being used as a future and will only receive one element.)
+ self.accepted_clients: queue.SimpleQueue[tuple[
+ int, queue.SimpleQueue[absl_status.Status]]] = queue.SimpleQueue()
+ # A queue receiving the final result of the aggregation session: either the
+ # aggregated tensors or a failure status. This queue is being used as a
+ # future and will only receive one element.
+ self.result: queue.SimpleQueue[bytes | absl_status.Status] = (
+ queue.SimpleQueue())
+
+ self._on_abort = on_abort
+ self._client_results_lock = threading.Lock()
+ # A map from client id to the queue for each client's close status.
+ self._client_results: dict[int, queue.SimpleQueue[absl_status.Status]] = {}
+
+ def OnAcceptClients(self, start_client_id: int, num_clients: int,
+ message: apm_pb2.AcceptanceMessage) -> None:
+ with self._client_results_lock:
+ for client_id in range(start_client_id, start_client_id + num_clients):
+ q = queue.SimpleQueue()
+ self._client_results[client_id] = q
+ self.accepted_clients.put((client_id, q))
+
+ def OnSendServerMessage(self, client_id: int,
+ message: apm_pb2.ServerMessage) -> None:
+ raise NotImplementedError()
+
+ def OnCloseClient(self, client_id: int,
+ diagnostic_status: absl_status.Status) -> None:
+ with self._client_results_lock:
+ self._client_results.pop(client_id).put(diagnostic_status)
+
+ def OnComplete(self, result: bytes) -> None:
+ self.result.put(result)
+
+ def OnAbort(self, diagnostic_status: absl_status.Status) -> None:
+ self.result.put(diagnostic_status)
+ self._on_abort()
+
+
+@dataclasses.dataclass(eq=False)
+class _AggregationSessionState:
+ """Internal state for an aggregation session."""
+ # The session's aggregation requirements.
+ requirements: AggregationRequirements
+ # The AggregationProtocol.Callback object receiving protocol events.
+ callback: _AggregationProtocolCallback
+ # The protocol performing the aggregation. Service._sessions_lock should not
+ # be held while AggregationProtocol methods are invoked -- both because
+ # methods may be slow and because callbacks may also need to acquire the lock.
+ agg_protocol: aggregation_protocol.AggregationProtocol
+ # The current status of the session.
+ status: AggregationStatus = AggregationStatus.PENDING
+ # Unredeemed client authorization tokens.
+ authorization_tokens: set[str] = dataclasses.field(default_factory=set)
+ # Information about active clients, keyed by authorization token
+ active_clients: dict[str, _ActiveClientData] = dataclasses.field(
+ default_factory=dict)
+ # Information for in-progress wait calls on this session.
+ pending_waits: set[_WaitData] = dataclasses.field(default_factory=set)
+
+
+class Service:
+ """Implements the Aggregations service."""
+
+ def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo],
+ media_service: media.Service):
+ self._forwarding_info = forwarding_info
+ self._media_service = media_service
+ self._sessions: dict[str, _AggregationSessionState] = {}
+ self._sessions_lock = threading.Lock()
+
+ def create_session(self,
+ aggregation_requirements: AggregationRequirements) -> str:
+ """Creates a new aggregation session and returns its id."""
+ session_id = str(uuid.uuid4())
+ callback = _AggregationProtocolCallback(
+ functools.partial(self._handle_protocol_abort, session_id))
+ if (len(aggregation_requirements.plan.phase) != 1 or
+ not aggregation_requirements.plan.phase[0].HasField('server_phase_v2')):
+ raise ValueError('Plan must contain exactly one server_phase_v2.')
+
+ # NOTE: For simplicity, this implementation only creates a single,
+ # in-process aggregation shard. In a production implementation, there should
+ # be multiple shards running on separate servers to enable high rates of
+ # client contributions. Utilities for combining results from separate shards
+ # are still in development as of Jan 2023.
+ agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
+ configuration_pb2.Configuration(aggregation_configs=[
+ self._translate_server_aggregation_config(aggregation_config)
+ for aggregation_config in
+ aggregation_requirements.plan.phase[0].server_phase_v2.aggregations
+ ]), callback)
+ agg_protocol.Start(0)
+
+ with self._sessions_lock:
+ self._sessions[session_id] = _AggregationSessionState(
+ requirements=aggregation_requirements,
+ callback=callback,
+ agg_protocol=agg_protocol)
+ return session_id
+
+ def complete_session(
+ self, session_id: str) -> tuple[SessionStatus, Optional[bytes]]:
+ """Completes the aggregation session and returns its results."""
+ with self._sessions_lock:
+ state = self._sessions.pop(session_id)
+
+ try:
+ # Only complete the AggregationProtocol if it's still pending. The most
+ # likely alternative is that it's ABORTED due to an error generated by the
+ # protocol itself.
+ status = self._get_session_status(state)
+ if status.status != AggregationStatus.PENDING:
+ return self._get_session_status(state), None
+
+ # Ensure privacy requirements have been met.
+ if (state.agg_protocol.GetStatus().num_inputs_aggregated_and_included <
+ state.requirements.minimum_clients_in_server_published_aggregate):
+ state.agg_protocol.Abort()
+ raise ValueError(
+ 'minimum_clients_in_server_published_aggregate has not been met.')
+
+ state.agg_protocol.Complete()
+ result = state.callback.result.get(timeout=1)
+ if isinstance(result, absl_status.Status):
+ raise absl_status.StatusNotOk(result)
+ state.status = AggregationStatus.COMPLETED
+ return self._get_session_status(state), result
+ except (ValueError, absl_status.StatusNotOk, queue.Empty) as e:
+ logging.warning('Failed to complete aggregation session: %s', e)
+ state.status = AggregationStatus.FAILED
+ return self._get_session_status(state), None
+ finally:
+ self._cleanup_session(state)
+
+ def abort_session(self, session_id: str) -> SessionStatus:
+ """Aborts/cancels an aggregation session."""
+ with self._sessions_lock:
+ state = self._sessions.pop(session_id)
+
+ # Only abort the AggregationProtocol if it's still pending. The most likely
+ # alternative is that it's ABORTED due to an error generated by the protocol
+ # itself.
+ if state.status == AggregationStatus.PENDING:
+ state.status = AggregationStatus.ABORTED
+ state.agg_protocol.Abort()
+
+ self._cleanup_session(state)
+ return self._get_session_status(state)
+
+ def get_session_status(self, session_id: str) -> SessionStatus:
+ """Returns the status of an aggregation session."""
+ with self._sessions_lock:
+ return self._get_session_status(self._sessions[session_id])
+
+ async def wait(
+ self,
+ session_id: str,
+ num_inputs_aggregated_and_included: Optional[int] = None
+ ) -> SessionStatus:
+ """Blocks until all conditions are satisfied or the aggregation fails."""
+ with self._sessions_lock:
+ state = self._sessions[session_id]
+ # Check if any of the conditions are already satisfied.
+ status = self._get_session_status(state)
+ if (num_inputs_aggregated_and_included is None or
+ num_inputs_aggregated_and_included <=
+ status.num_inputs_aggregated_and_included):
+ return status
+
+ wait_data = _WaitData(num_inputs_aggregated_and_included)
+ state.pending_waits.add(wait_data)
+ return await wait_data.status_future
+
+ def pre_authorize_clients(self, session_id: str,
+ num_tokens: int) -> Sequence[str]:
+ """Generates tokens authorizing clients to contribute to the session."""
+ tokens = set(str(uuid.uuid4()) for _ in range(num_tokens))
+ with self._sessions_lock:
+ self._sessions[session_id].authorization_tokens |= tokens
+ return list(tokens)
+
+ def _translate_intrinsic_arg(
+ self, intrinsic_arg: plan_pb2.ServerAggregationConfig.IntrinsicArg
+ ) -> configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg:
+ """Transform an aggregation intrinsic arg for the aggregation service."""
+ if intrinsic_arg.HasField('input_tensor'):
+ return configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg(
+ input_tensor=intrinsic_arg.input_tensor)
+ elif intrinsic_arg.HasField('state_tensor'):
+ raise ValueError(
+ 'Non-client intrinsic args are not supported in this demo.'
+ )
+ else:
+ raise AssertionError(
+ 'Cases should have exhausted all possible types of intrinsic args.')
+
+ def _translate_server_aggregation_config(
+ self, plan_aggregation_config: plan_pb2.ServerAggregationConfig
+ ) -> configuration_pb2.Configuration.ServerAggregationConfig:
+ """Transform the aggregation config for use by the aggregation service."""
+ if plan_aggregation_config.inner_aggregations:
+ raise AssertionError('Nested intrinsic structrues are not supported yet.')
+ return configuration_pb2.Configuration.ServerAggregationConfig(
+ intrinsic_uri=plan_aggregation_config.intrinsic_uri,
+ intrinsic_args=[
+ self._translate_intrinsic_arg(intrinsic_arg)
+ for intrinsic_arg in plan_aggregation_config.intrinsic_args
+ ],
+ output_tensors=plan_aggregation_config.output_tensors)
+
+ def _get_session_status(self,
+ state: _AggregationSessionState) -> SessionStatus:
+ """Returns the SessionStatus for an _AggregationSessionState object."""
+ status = state.agg_protocol.GetStatus()
+ return SessionStatus(
+ status=state.status,
+ num_clients_completed=status.num_clients_completed,
+ num_clients_failed=status.num_clients_failed,
+ num_clients_pending=status.num_clients_pending,
+ num_clients_aborted=status.num_clients_aborted,
+ num_inputs_aggregated_and_included=(
+ status.num_inputs_aggregated_and_included),
+ num_inputs_aggregated_and_pending=(
+ status.num_inputs_aggregated_and_pending),
+ num_inputs_discarded=status.num_inputs_discarded)
+
+ def _get_http_status(self, code: absl_status.StatusCode) -> http.HTTPStatus:
+ """Returns the HTTPStatus code for an absl StatusCode."""
+ if (code == absl_status.StatusCode.INVALID_ARGUMENT or
+ code == absl_status.StatusCode.FAILED_PRECONDITION):
+ return http.HTTPStatus.BAD_REQUEST
+ elif code == absl_status.StatusCode.NOT_FOUND:
+ return http.HTTPStatus.NOT_FOUND
+ else:
+ return http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+ def _cleanup_session(self, state: _AggregationSessionState) -> None:
+ """Cleans up the session and releases any resources.
+
+ Args:
+ state: The session state to clean up.
+ """
+ state.authorization_tokens.clear()
+ for client_data in state.active_clients.values():
+ self._media_service.finalize_upload(client_data.resource_name)
+ state.active_clients.clear()
+ # Anyone waiting on the session should be notified that it's finished.
+ if state.pending_waits:
+ status = self._get_session_status(state)
+ for data in state.pending_waits:
+ data.loop.call_soon_threadsafe(
+ functools.partial(data.status_future.set_result, status))
+ state.pending_waits.clear()
+
+ def _handle_protocol_abort(self, session_id: str) -> None:
+ """Notifies waiting clients when the protocol is aborted."""
+ with self._sessions_lock:
+ with contextlib.suppress(KeyError):
+ state = self._sessions[session_id]
+ state.status = AggregationStatus.FAILED
+ # Anyone waiting on the session should be notified it's been aborted.
+ if state.pending_waits:
+ status = self._get_session_status(state)
+ for data in state.pending_waits:
+ data.loop.call_soon_threadsafe(
+ functools.partial(data.status_future.set_result, status))
+ state.pending_waits.clear()
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.Aggregations',
+ method='StartAggregationDataUpload')
+ def start_aggregation_data_upload(
+ self, request: aggregations_pb2.StartAggregationDataUploadRequest
+ ) -> operations_pb2.Operation:
+ """Handles a StartAggregationDataUpload request."""
+ with self._sessions_lock:
+ try:
+ state = self._sessions[request.aggregation_id]
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
+ try:
+ state.authorization_tokens.remove(request.authorization_token)
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
+
+ state.agg_protocol.AddClients(1)
+ client_token = str(uuid.uuid4())
+ client_id, close_status = state.callback.accepted_clients.get(timeout=1)
+ upload_name = self._media_service.register_upload()
+
+ with self._sessions_lock:
+ state.active_clients[client_token] = _ActiveClientData(
+ client_id, close_status, upload_name)
+
+ forwarding_info = self._forwarding_info()
+ response = aggregations_pb2.StartAggregationDataUploadResponse(
+ aggregation_protocol_forwarding_info=forwarding_info,
+ resource=common_pb2.ByteStreamResource(
+ data_upload_forwarding_info=forwarding_info,
+ resource_name=upload_name),
+ client_token=client_token)
+
+ op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
+ op.metadata.Pack(aggregations_pb2.StartAggregationDataUploadMetadata())
+ op.response.Pack(response)
+ return op
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.Aggregations',
+ method='SubmitAggregationResult')
+ def submit_aggregation_result(
+ self, request: aggregations_pb2.SubmitAggregationResultRequest
+ ) -> aggregations_pb2.SubmitAggregationResultResponse:
+ """Handles a SubmitAggregationResult request."""
+ with self._sessions_lock:
+ try:
+ state = self._sessions[request.aggregation_id]
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
+ try:
+ client_data = state.active_clients.pop(request.client_token)
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
+
+ # Ensure the client is using the resource name provided when they called
+ # StartAggregationDataUpload.
+ if request.resource_name != client_data.resource_name:
+ raise http_actions.HttpError(http.HTTPStatus.BAD_REQUEST)
+
+ # The aggregation protocol may have already closed the connect (e.g., if
+ # an error occurred). If so, clean up the upload and return the error.
+ if not client_data.close_status.empty():
+ with contextlib.suppress(KeyError):
+ self._media_service.finalize_upload(request.resource_name)
+ raise http_actions.HttpError(
+ self._get_http_status(client_data.close_status.get().code()))
+
+ # Finalize the upload.
+ try:
+ update = self._media_service.finalize_upload(request.resource_name)
+ if update is None:
+ raise absl_status.StatusNotOk(
+ absl_status.invalid_argument_error(
+ 'Aggregation result never uploaded'))
+ except (KeyError, absl_status.StatusNotOk) as e:
+ if isinstance(e, KeyError):
+ e = absl_status.StatusNotOk(
+ absl_status.internal_error('Failed to finalize upload'))
+ state.agg_protocol.CloseClient(client_data.client_id, e.status)
+ # Since we're initiating the close, it's also necessary to notify the
+ # _AggregationProtocolCallback so it can clean up resources.
+ state.callback.OnCloseClient(client_data.client_id, e.status)
+ raise http_actions.HttpError(self._get_http_status(
+ e.status.code())) from e
+
+ client_message = apm_pb2.ClientMessage(
+ simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
+ input=apm_pb2.ClientResource(inline_bytes=update)))
+ try:
+ state.agg_protocol.ReceiveClientMessage(client_data.client_id,
+ client_message)
+ except absl_status.StatusNotOk as e:
+ # ReceiveClientInput should only fail if the AggregationProtocol is in a
+ # bad state -- likely leading to it being aborted.
+ logging.warning('Failed to receive client input: %s', e)
+ raise http_actions.HttpError(http.HTTPStatus.INTERNAL_SERVER_ERROR) from e
+
+ # Wait for the client input to be processed.
+ close_status = client_data.close_status.get()
+ if not close_status.ok():
+ raise http_actions.HttpError(self._get_http_status(close_status.code()))
+
+ # Check for any newly-satisfied pending wait operations.
+ with self._sessions_lock:
+ if state.pending_waits:
+ completed_waits = set()
+ status = self._get_session_status(state)
+ for data in state.pending_waits:
+ if (data.num_inputs_aggregated_and_included is not None and
+ status.num_inputs_aggregated_and_included >=
+ data.num_inputs_aggregated_and_included):
+ data.loop.call_soon_threadsafe(
+ functools.partial(data.status_future.set_result, status))
+ completed_waits.add(data)
+ state.pending_waits -= completed_waits
+ return aggregations_pb2.SubmitAggregationResultResponse()
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.Aggregations',
+ method='AbortAggregation')
+ def abort_aggregation(
+ self, request: aggregations_pb2.AbortAggregationRequest
+ ) -> aggregations_pb2.AbortAggregationResponse:
+ """Handles an AbortAggregation request."""
+ with self._sessions_lock:
+ try:
+ state = self._sessions[request.aggregation_id]
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
+ try:
+ client_data = state.active_clients.pop(request.client_token)
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
+
+ # Attempt to finalize the in-progress upload to free up resources.
+ with contextlib.suppress(KeyError):
+ self._media_service.finalize_upload(client_data.resource_name)
+
+ # Notify the aggregation protocol that the client has left.
+ if request.status.code == code_pb2.Code.OK:
+ status = absl_status.Status.OkStatus()
+ else:
+ status = absl_status.BuildStatusNotOk(
+ absl_status.StatusCodeFromInt(request.status.code),
+ request.status.message)
+ state.agg_protocol.CloseClient(client_data.client_id, status)
+ # Since we're initiating the close, it's also necessary to notify the
+ # _AggregationProtocolCallback so it can clean up resources.
+ state.callback.OnCloseClient(client_data.client_id, status)
+
+ logging.debug('[%s] AbortAggregation: %s', request.aggregation_id,
+ request.status)
+ return aggregations_pb2.AbortAggregationResponse()
diff --git a/fcp/demo/aggregations_test.py b/fcp/demo/aggregations_test.py
new file mode 100644
index 0000000..69ba5e3
--- /dev/null
+++ b/fcp/demo/aggregations_test.py
@@ -0,0 +1,783 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for aggregations."""
+
+import asyncio
+import http
+import unittest
+from unittest import mock
+
+from absl.testing import absltest
+import tensorflow as tf
+
+from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
+from fcp.aggregation.protocol.python import aggregation_protocol
+from fcp.aggregation.tensorflow.python import aggregation_protocols
+from fcp.demo import aggregations
+from fcp.demo import http_actions
+from fcp.demo import media
+from fcp.demo import test_utils
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import aggregations_pb2
+from fcp.protos.federatedcompute import common_pb2
+from pybind11_abseil import status as absl_status
+
+INPUT_TENSOR = 'in'
+OUTPUT_TENSOR = 'out'
+AGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements(
+ minimum_clients_in_server_published_aggregate=3,
+ plan=plan_pb2.Plan(phase=[
+ plan_pb2.Plan.Phase(
+ server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
+ plan_pb2.ServerAggregationConfig(
+ intrinsic_uri='federated_sum',
+ intrinsic_args=[
+ plan_pb2.ServerAggregationConfig.IntrinsicArg(
+ input_tensor=tf.TensorSpec((
+ ), tf.int32, INPUT_TENSOR).experimental_as_proto())
+ ],
+ output_tensors=[
+ tf.TensorSpec((
+ ), tf.int32, OUTPUT_TENSOR).experimental_as_proto(),
+ ]),
+ ])),
+ ]))
+FORWARDING_INFO = common_pb2.ForwardingInfo(
+ target_uri_prefix='https://forwarding.example/')
+
+
+class NotOkStatus:
+ """Matcher for a not-ok Status."""
+
+ def __eq__(self, other) -> bool:
+ return isinstance(other, absl_status.Status) and not other.ok()
+
+
+class AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.mock_media_service = self.enter_context(
+ mock.patch.object(media, 'Service', autospec=True))
+ self.mock_media_service.register_upload.return_value = 'upload-id'
+ self.mock_media_service.finalize_upload.return_value = (
+ test_utils.create_checkpoint({INPUT_TENSOR: 0}))
+
+ def test_pre_authorize_clients(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 3)
+ self.assertLen(tokens, 3)
+ # The tokens should all be unique.
+ self.assertLen(set(tokens), 3)
+
+ def test_pre_authorize_clients_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ with self.assertRaises(KeyError):
+ service.pre_authorize_clients('does-not-exist', 1)
+
+ def test_pre_authorize_clients_with_bad_count(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ self.assertEmpty(service.pre_authorize_clients(session_id, 0))
+ self.assertEmpty(service.pre_authorize_clients(session_id, -2))
+
+ def test_create_session(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+ def test_complete_session(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ # Upload results from the client.
+ num_clients = (
+ AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
+ for i in range(num_clients):
+ tokens = service.pre_authorize_clients(session_id, 1)
+
+ self.mock_media_service.register_upload.return_value = f'upload-{i}'
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+
+ self.mock_media_service.finalize_upload.return_value = (
+ test_utils.create_checkpoint({INPUT_TENSOR: i}))
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+
+ # Now that all clients have contributed, the aggregation session can be
+ # completed.
+ status, aggregate = service.complete_session(session_id)
+ self.assertEqual(
+ status,
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.COMPLETED,
+ num_clients_completed=num_clients,
+ num_inputs_aggregated_and_included=num_clients))
+ self.assertEqual(
+ test_utils.read_tensor_from_checkpoint(aggregate,
+ OUTPUT_TENSOR, tf.int32),
+ sum(range(num_clients)))
+
+ # get_session_status should no longer return results.
+ with self.assertRaises(KeyError):
+ service.get_session_status(session_id)
+
+ @mock.patch.object(
+ aggregation_protocols,
+ 'create_simple_aggregation_protocol',
+ autospec=True)
+ def test_complete_session_fails(self, mock_create_simple_agg_protocol):
+ # Use a mock since it's not easy to cause
+ # SimpleAggregationProtocol::Complete to fail.
+ mock_agg_protocol = mock.create_autospec(
+ aggregation_protocol.AggregationProtocol, instance=True)
+ mock_create_simple_agg_protocol.return_value = mock_agg_protocol
+
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ required_clients = (
+ AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
+ agg_status = apm_pb2.StatusMessage(
+ num_inputs_aggregated_and_included=required_clients)
+ mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
+
+ def on_complete():
+ agg_status.num_inputs_discarded = (
+ agg_status.num_inputs_aggregated_and_included)
+ agg_status.num_inputs_aggregated_and_included = 0
+ raise absl_status.StatusNotOk(absl_status.unknown_error('message'))
+
+ mock_agg_protocol.Complete.side_effect = on_complete
+
+ status, aggregate = service.complete_session(session_id)
+ self.assertEqual(
+ status,
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.FAILED,
+ num_inputs_discarded=required_clients))
+ self.assertIsNone(aggregate)
+ mock_agg_protocol.Complete.assert_called_once()
+ mock_agg_protocol.Abort.assert_not_called()
+
+ @mock.patch.object(
+ aggregation_protocols,
+ 'create_simple_aggregation_protocol',
+ autospec=True)
+ def test_complete_session_aborts(self, mock_create_simple_agg_protocol):
+ # Use a mock since it's not easy to cause
+ # SimpleAggregationProtocol::Complete to trigger a protocol abort.
+ mock_agg_protocol = mock.create_autospec(
+ aggregation_protocol.AggregationProtocol, instance=True)
+ mock_create_simple_agg_protocol.return_value = mock_agg_protocol
+
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ required_clients = (
+ AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
+ agg_status = apm_pb2.StatusMessage(
+ num_inputs_aggregated_and_included=required_clients)
+ mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
+
+ def on_complete():
+ agg_status.num_inputs_discarded = (
+ agg_status.num_inputs_aggregated_and_included)
+ agg_status.num_inputs_aggregated_and_included = 0
+ callback = mock_create_simple_agg_protocol.call_args.args[1]
+ callback.OnAbort(absl_status.unknown_error('message'))
+
+ mock_agg_protocol.Complete.side_effect = on_complete
+
+ status, aggregate = service.complete_session(session_id)
+ self.assertEqual(
+ status,
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.FAILED,
+ num_inputs_discarded=required_clients))
+ self.assertIsNone(aggregate)
+ mock_agg_protocol.Complete.assert_called_once()
+ mock_agg_protocol.Abort.assert_not_called()
+
+ def test_complete_session_without_enough_inputs(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(
+ aggregations.AggregationRequirements(
+ minimum_clients_in_server_published_aggregate=3,
+ plan=AGGREGATION_REQUIREMENTS.plan))
+ tokens = service.pre_authorize_clients(session_id, 2)
+
+ # Upload results for one client.
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+
+ # Complete the session before there are 2 completed clients.
+ status, aggregate = service.complete_session(session_id)
+ self.assertEqual(
+ status,
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.FAILED,
+ num_clients_completed=1,
+ num_inputs_discarded=1))
+ self.assertIsNone(aggregate)
+
+ def test_complete_session_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ with self.assertRaises(KeyError):
+ service.complete_session('does-not-exist')
+
+ def test_abort_session_with_no_uploads(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ self.assertEqual(
+ service.abort_session(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.ABORTED))
+
+ def test_abort_session_with_uploads(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 3)
+
+ # Upload results for one client.
+ self.mock_media_service.register_upload.return_value = 'upload1'
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+
+ # Start a partial upload from a second client.
+ self.mock_media_service.register_upload.return_value = 'upload2'
+ service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[1]))
+
+ # Abort the session. The pending client should be treated as failed.
+ self.assertEqual(
+ service.abort_session(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.ABORTED,
+ num_clients_completed=1,
+ num_clients_aborted=1,
+ num_inputs_discarded=1))
+ # The registered upload for the second client should have been finalized.
+ self.mock_media_service.finalize_upload.assert_called_with('upload2')
+ # get_session_status should no longer return results.
+ with self.assertRaises(KeyError):
+ service.get_session_status(session_id)
+
+ def test_abort_session_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ with self.assertRaises(KeyError):
+ service.abort_session('does-not-exist')
+
+ def test_get_session_status_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ with self.assertRaises(KeyError):
+ service.get_session_status('does-not-exist')
+
+ async def test_wait(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ task = asyncio.create_task(
+ service.wait(session_id, num_inputs_aggregated_and_included=1))
+ # The awaitable should not be done yet.
+ await asyncio.wait([task], timeout=0.1)
+ self.assertFalse(task.done())
+
+ # Upload results for one client.
+ tokens = service.pre_authorize_clients(session_id, 1)
+ self.mock_media_service.register_upload.return_value = 'upload'
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+
+ # The awaitable should now return.
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(
+ task.result(),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_completed=1,
+ num_inputs_aggregated_and_included=1))
+
+ async def test_wait_already_satisfied(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ # Upload results for one client.
+ tokens = service.pre_authorize_clients(session_id, 1)
+ self.mock_media_service.register_upload.return_value = 'upload'
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+
+ # Since a client has already reported, the condition should already be
+ # satisfied.
+ task = asyncio.create_task(
+ service.wait(session_id, num_inputs_aggregated_and_included=1))
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(
+ task.result(),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_completed=1,
+ num_inputs_aggregated_and_included=1))
+
+ async def test_wait_with_abort(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ task = asyncio.create_task(
+ service.wait(session_id, num_inputs_aggregated_and_included=1))
+ # The awaitable should not be done yet.
+ await asyncio.wait([task], timeout=0.1)
+ self.assertFalse(task.done())
+
+ # The awaitable should return once the session is aborted.
+ status = service.abort_session(session_id)
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(task.result(), status)
+
+ @mock.patch.object(
+ aggregation_protocols,
+ 'create_simple_aggregation_protocol',
+ autospec=True)
+ async def test_wait_with_protocol_abort(self,
+ mock_create_simple_agg_protocol):
+ # Use a mock since it's not easy to cause the AggregationProtocol to abort.
+ mock_agg_protocol = mock.create_autospec(
+ aggregation_protocol.AggregationProtocol, instance=True)
+ mock_create_simple_agg_protocol.return_value = mock_agg_protocol
+ mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage(
+ num_clients_aborted=1234)
+
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ task = asyncio.create_task(
+ service.wait(session_id, num_inputs_aggregated_and_included=1))
+ # The awaitable should not be done yet.
+ await asyncio.wait([task], timeout=0.1)
+ self.assertFalse(task.done())
+
+ # The awaitable should return once the AggregationProtocol aborts.
+ callback = mock_create_simple_agg_protocol.call_args.args[1]
+ callback.OnAbort(absl_status.unknown_error('message'))
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(
+ task.result(),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.FAILED,
+ num_clients_aborted=1234))
+
+ async def test_wait_with_complete(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(
+ aggregations.AggregationRequirements(
+ minimum_clients_in_server_published_aggregate=0,
+ plan=AGGREGATION_REQUIREMENTS.plan))
+ task = asyncio.create_task(
+ service.wait(session_id, num_inputs_aggregated_and_included=1))
+ # The awaitable should not be done yet.
+ await asyncio.wait([task], timeout=0.1)
+ self.assertFalse(task.done())
+
+ # The awaitable should return once the session is completed.
+ status, _ = service.complete_session(session_id)
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(task.result(), status)
+
+ async def test_wait_without_condition(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ task = asyncio.create_task(service.wait(session_id))
+ # If there are no conditions, the wait should be trivially satisfied.
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertEqual(task.result(), service.get_session_status(session_id))
+
+ async def test_wait_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ task = asyncio.create_task(service.wait('does-not-exist'))
+ await asyncio.wait([task], timeout=1)
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), KeyError)
+
+ def test_start_aggregation_data_upload(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ self.mock_media_service.register_upload.return_value = 'upload'
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertNotEmpty(operation.name)
+ self.assertTrue(operation.done)
+
+ metadata = aggregations_pb2.StartAggregationDataUploadMetadata()
+ operation.metadata.Unpack(metadata)
+ self.assertEqual(metadata,
+ aggregations_pb2.StartAggregationDataUploadMetadata())
+
+ response = aggregations_pb2.StartAggregationDataUploadResponse()
+ operation.response.Unpack(response)
+ # The client token should be set and different from the authorization token.
+ self.assertNotEmpty(response.client_token)
+ self.assertNotEqual(response.client_token, tokens[0])
+ self.assertEqual(
+ response,
+ aggregations_pb2.StartAggregationDataUploadResponse(
+ aggregation_protocol_forwarding_info=FORWARDING_INFO,
+ resource=common_pb2.ByteStreamResource(
+ data_upload_forwarding_info=FORWARDING_INFO,
+ resource_name='upload'),
+ client_token=response.client_token))
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_pending=1))
+
+ def test_start_aggregagation_data_upload_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id='does-not-exist', authorization_token=tokens[0]))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+ def test_start_aggregagation_data_upload_with_invalid_token(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token='does-not-exist'))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+ def test_start_aggregagation_data_upload_twice(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+
+ def test_submit_aggregation_result(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ # Upload results from the client.
+ tokens = service.pre_authorize_clients(session_id, 1)
+
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+
+ submit_response = service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+ self.assertEqual(submit_response,
+ aggregations_pb2.SubmitAggregationResultResponse())
+ self.mock_media_service.finalize_upload.assert_called_with(
+ start_upload_response.resource.resource_name)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_completed=1,
+ num_inputs_aggregated_and_included=1))
+
+ def test_submit_aggregation_result_with_invalid_client_input(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+
+ tokens = service.pre_authorize_clients(session_id, 1)
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+
+ self.mock_media_service.finalize_upload.return_value = b'invalid'
+ with self.assertRaises(http_actions.HttpError):
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_failed=1))
+
+ def test_submit_aggregation_result_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id='does-not-exist',
+ client_token='client-token',
+ resource_name='upload-id'))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+ def test_submit_aggregation_result_with_invalid_token(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token='does-not-exist',
+ resource_name='upload-id'))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+ def test_submit_aggregation_result_with_finalize_upload_error(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+
+ # If the resource_name doesn't correspond to a registered upload,
+ # finalize_upload will raise a KeyError.
+ self.mock_media_service.finalize_upload.side_effect = KeyError()
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_failed=1))
+
+ def test_submit_aggregation_result_with_unuploaded_resource(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+
+ # If the resource_name is valid but no resource was uploaded,
+ # finalize_resource will return None.
+ self.mock_media_service.finalize_upload.return_value = None
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.submit_aggregation_result(
+ aggregations_pb2.SubmitAggregationResultRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token,
+ resource_name=start_upload_response.resource.resource_name))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_failed=1))
+
+ def test_abort_aggregation(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ self.assertEqual(
+ service.abort_aggregation(
+ aggregations_pb2.AbortAggregationRequest(
+ aggregation_id=session_id,
+ client_token=start_upload_response.client_token)),
+ aggregations_pb2.AbortAggregationResponse())
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_completed=1,
+ num_inputs_discarded=1))
+
+ def test_abort_aggregation_with_missing_session_id(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ tokens = service.pre_authorize_clients(session_id, 1)
+ operation = service.start_aggregation_data_upload(
+ aggregations_pb2.StartAggregationDataUploadRequest(
+ aggregation_id=session_id, authorization_token=tokens[0]))
+ self.assertTrue(operation.done)
+ start_upload_response = (
+ aggregations_pb2.StartAggregationDataUploadResponse())
+ operation.response.Unpack(start_upload_response)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.abort_aggregation(
+ aggregations_pb2.AbortAggregationRequest(
+ aggregation_id='does-not-exist',
+ client_token=start_upload_response.client_token))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING,
+ num_clients_pending=1))
+
+ def test_abort_aggregation_with_invalid_token(self):
+ service = aggregations.Service(lambda: FORWARDING_INFO,
+ self.mock_media_service)
+ session_id = service.create_session(AGGREGATION_REQUIREMENTS)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.abort_aggregation(
+ aggregations_pb2.AbortAggregationRequest(
+ aggregation_id=session_id, client_token='does-not-exist'))
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+ self.assertEqual(
+ service.get_session_status(session_id),
+ aggregations.SessionStatus(
+ status=aggregations.AggregationStatus.PENDING))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/checkpoint_tensor_reference.py b/fcp/demo/checkpoint_tensor_reference.py
new file mode 100644
index 0000000..372fcfc
--- /dev/null
+++ b/fcp/demo/checkpoint_tensor_reference.py
@@ -0,0 +1,66 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MaterializableValueReference that reads from a TensorFlow checkpoint."""
+
+from typing import Any, Optional
+import uuid
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+
+class CheckpointTensorReference(tff.program.MaterializableValueReference):
+ """A reference to a tensor in a TF checkpoint file."""
+
+ def __init__(self, tensor_name: str, dtype: tf.DType, shape: Any,
+ checkpoint_future: tff.async_utils.SharedAwaitable):
+ """Constructs a new CheckpointTensorReference object.
+
+ Args:
+ tensor_name: The name of the tensor in the TF checkpoint.
+ dtype: The type of the tensor.
+ shape: The shape of the tensor, expressed as a value convertible to
+ `tf.TensorShape`.
+ checkpoint_future: A `tff.async_utils.SharedAwaitable` that resolves to
+ the TF checkpoint bytes once they're available.
+ """
+ self._tensor_name = tensor_name
+ self._type_signature = tff.TensorType(dtype, shape)
+ self._checkpoint_future = checkpoint_future
+ self._tensor: Optional[tf.Tensor] = None
+
+ @property
+ def type_signature(self) -> tff.Type:
+ return self._type_signature
+
+ async def get_value(self) -> tff.program.MaterializedValue:
+ if self._tensor is None:
+ checkpoint = await self._checkpoint_future
+ # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
+ tmpfile = f'ram://{uuid.uuid4()}.ckpt'
+ with tf.io.gfile.GFile(tmpfile, 'wb') as f:
+ f.write(checkpoint)
+ try:
+ self._tensor = tf.raw_ops.RestoreV2(
+ prefix=tmpfile,
+ tensor_names=[self._tensor_name],
+ shape_and_slices=[''],
+ dtypes=[self._type_signature.dtype])[0]
+ finally:
+ tf.io.gfile.remove(tmpfile)
+
+ try:
+ return self._tensor.numpy()
+ except AttributeError as e:
+ raise ValueError('get_value is only supported in eager mode.') from e
diff --git a/fcp/demo/checkpoint_tensor_reference_test.py b/fcp/demo/checkpoint_tensor_reference_test.py
new file mode 100644
index 0000000..706065f
--- /dev/null
+++ b/fcp/demo/checkpoint_tensor_reference_test.py
@@ -0,0 +1,88 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for checkpoint_tensor_reference."""
+
+import unittest
+
+from absl.testing import absltest
+import numpy
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.demo import checkpoint_tensor_reference as ctr
+from fcp.demo import test_utils
+
+TENSOR_NAME = 'test'
+DTYPE = tf.int32
+SHAPE = (2, 3)
+TEST_VALUE = tf.zeros(SHAPE, DTYPE).numpy()
+
+
+async def get_test_checkpoint():
+ return test_utils.create_checkpoint({TENSOR_NAME: TEST_VALUE})
+
+
+class CheckpointTensorReferenceTest(absltest.TestCase,
+ unittest.IsolatedAsyncioTestCase):
+
+ def test_type_signature(self):
+ ref = ctr.CheckpointTensorReference(
+ TENSOR_NAME, DTYPE, SHAPE,
+ tff.async_utils.SharedAwaitable(get_test_checkpoint()))
+ self.assertEqual(ref.type_signature, tff.TensorType(DTYPE, SHAPE))
+
+ async def test_get_value(self):
+
+ async def get_checkpoint():
+ return test_utils.create_checkpoint({TENSOR_NAME: TEST_VALUE})
+
+ ref = ctr.CheckpointTensorReference(
+ TENSOR_NAME, DTYPE, SHAPE,
+ tff.async_utils.SharedAwaitable(get_checkpoint()))
+ self.assertTrue(numpy.array_equiv(await ref.get_value(), TEST_VALUE))
+
+ async def test_get_value_in_graph_mode(self):
+ with tf.compat.v1.Graph().as_default():
+ ref = ctr.CheckpointTensorReference(
+ TENSOR_NAME, DTYPE, SHAPE,
+ tff.async_utils.SharedAwaitable(get_test_checkpoint()))
+ with self.assertRaisesRegex(ValueError,
+ 'get_value is only supported in eager mode'):
+ await ref.get_value()
+
+ async def test_get_value_not_found(self):
+
+ async def get_not_found_checkpoint():
+ return test_utils.create_checkpoint({'other': TEST_VALUE})
+
+ ref = ctr.CheckpointTensorReference(
+ TENSOR_NAME, DTYPE, SHAPE,
+ tff.async_utils.SharedAwaitable(get_not_found_checkpoint()))
+ with self.assertRaises(tf.errors.NotFoundError):
+ await ref.get_value()
+
+ async def test_get_value_with_invalid_checkpoint(self):
+
+ async def get_invalid_checkpoint():
+ return b'invalid'
+
+ ref = ctr.CheckpointTensorReference(
+ TENSOR_NAME, DTYPE, SHAPE,
+ tff.async_utils.SharedAwaitable(get_invalid_checkpoint()))
+ with self.assertRaises(tf.errors.DataLossError):
+ await ref.get_value()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/eligibility_eval_tasks.py b/fcp/demo/eligibility_eval_tasks.py
new file mode 100644
index 0000000..596e881
--- /dev/null
+++ b/fcp/demo/eligibility_eval_tasks.py
@@ -0,0 +1,138 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Action handlers for the EligibilityEvalTasks service.
+
+Eligibility Eval tasks are not currently supported by this demo implementation.
+"""
+
+import dataclasses
+import datetime
+import http
+import threading
+from typing import Callable
+import uuid
+
+from absl import logging
+
+from google.rpc import code_pb2
+from fcp.demo import http_actions
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class _Task:
+ task_name: str
+ task_assignment_mode: _TaskAssignmentMode
+
+
+class Service:
+ """Implements the EligibilityEvalTasks service."""
+
+ def __init__(self, population_name: str,
+ forwarding_info: Callable[[], common_pb2.ForwardingInfo]):
+ self._population_name = population_name
+ self._forwarding_info = forwarding_info
+ self._tasks: dict[str, _Task] = {}
+ self._tasks_lock = threading.Lock()
+
+ def add_task(self, task_name: str, task_assignment_mode: _TaskAssignmentMode):
+ """Informs the service that a new task has been added to the system."""
+ with self._tasks_lock:
+ self._tasks[task_name] = _Task(task_name, task_assignment_mode)
+
+ def remove_task(self, task_name: str):
+ """Informs the service that a task has been removed from the system."""
+ with self._tasks_lock:
+ del self._tasks[task_name]
+
+ @property
+ def _population_eligibility_spec(
+ self,
+ ) -> eligibility_eval_tasks_pb2.PopulationEligibilitySpec:
+ with self._tasks_lock:
+ return eligibility_eval_tasks_pb2.PopulationEligibilitySpec(
+ task_info=[
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo(
+ task_name=task.task_name,
+ task_assignment_mode=task.task_assignment_mode,
+ )
+ for task in self._tasks.values()
+ ]
+ )
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.EligibilityEvalTasks',
+ method='RequestEligibilityEvalTask')
+ def request_eligibility_eval_task(
+ self, request: eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest
+ ) -> eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse:
+ """Handles a RequestEligibilityEvalTask request."""
+ if request.population_name != self._population_name:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
+ # NOTE: A production implementation should use
+ # `request.attestation_measurement` to verify the device is valid, e.g.
+ # using the SafetyNet Attestation API.
+ session_id = str(uuid.uuid4())
+ logging.debug('[%s] RequestEligibilityEvalTask', session_id)
+
+ # NOTE: A production implementation should vary the retry windows based on
+ # the population size and other factors, as described in TFLaS §2.3.
+ retry_window = common_pb2.RetryWindow()
+ retry_window.delay_min.FromTimedelta(datetime.timedelta(seconds=10))
+ retry_window.delay_max.FromTimedelta(datetime.timedelta(seconds=30))
+
+ response = eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse(
+ task_assignment_forwarding_info=self._forwarding_info(),
+ session_id=str(uuid.uuid4()),
+ retry_window_if_accepted=retry_window,
+ retry_window_if_rejected=retry_window,
+ )
+
+ # This implementation does not support Eligibility Eval tasks. However, the
+ # EligibilityEvalTask response is also used to provide the
+ # PopulationEligibilitySpec to clients, so the service returns an
+ # EligibilityEvalTask instead of NoEligibilityEvalConfigured if the client
+ # supports multiple task assignment.
+ capabilities = request.eligibility_eval_task_capabilities
+ if capabilities.supports_multiple_task_assignment:
+ spec_resource = response.eligibility_eval_task.population_eligibility_spec
+ spec_resource.inline_resource.data = (
+ self._population_eligibility_spec.SerializeToString()
+ )
+ else:
+ response.no_eligibility_eval_configured.SetInParent()
+ return response
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.EligibilityEvalTasks',
+ method='ReportEligibilityEvalTaskResult')
+ def report_eligibility_eval_task_result(
+ self,
+ request: eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultRequest
+ ) -> eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultResponse:
+ """Handles a ReportEligibilityEvalTaskResult request."""
+ if request.population_name != self._population_name:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
+ # NOTE: A production implementation should collect and report metrics. In
+ # this implementation, we simply log the result.
+ logging.log(
+ logging.DEBUG if request.status_code == code_pb2.OK else logging.WARN,
+ '[%s] ReportEligibilityEvalTaskResult: %s', request.session_id,
+ code_pb2.Code.Name(request.status_code))
+ return eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultResponse()
diff --git a/fcp/demo/eligibility_eval_tasks_test.py b/fcp/demo/eligibility_eval_tasks_test.py
new file mode 100644
index 0000000..ca6fbd1
--- /dev/null
+++ b/fcp/demo/eligibility_eval_tasks_test.py
@@ -0,0 +1,169 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for eligibility_eval_tasks."""
+
+import datetime
+import http
+from unittest import mock
+import uuid
+
+from absl.testing import absltest
+
+from google.rpc import code_pb2
+from fcp.demo import eligibility_eval_tasks
+from fcp.demo import http_actions
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+POPULATION_NAME = 'test/population'
+FORWARDING_INFO = common_pb2.ForwardingInfo(
+ target_uri_prefix='https://forwarding.example/')
+
+
+class EligibilityEvalTasksTest(absltest.TestCase):
+
+ @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
+ def test_request_eligibility_eval_task(self, mock_uuid):
+ service = eligibility_eval_tasks.Service(POPULATION_NAME,
+ lambda: FORWARDING_INFO)
+ service.add_task('task1', _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE)
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ population_name=POPULATION_NAME)
+ retry_window = common_pb2.RetryWindow()
+ retry_window.delay_min.FromTimedelta(datetime.timedelta(seconds=10))
+ retry_window.delay_max.FromTimedelta(datetime.timedelta(seconds=30))
+ self.assertEqual(
+ service.request_eligibility_eval_task(request),
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse(
+ session_id=str(mock_uuid.return_value),
+ task_assignment_forwarding_info=FORWARDING_INFO,
+ no_eligibility_eval_configured=(
+ eligibility_eval_tasks_pb2.NoEligibilityEvalConfigured()),
+ retry_window_if_accepted=retry_window,
+ retry_window_if_rejected=retry_window))
+
+ def test_request_eligibility_eval_task_with_multiple_assignment(self):
+ service = eligibility_eval_tasks.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO
+ )
+ service.add_task('task1', _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE)
+ service.add_task('task2', _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE)
+
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ population_name=POPULATION_NAME,
+ eligibility_eval_task_capabilities=(
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskCapabilities(
+ supports_multiple_task_assignment=True
+ )
+ ),
+ )
+ response = service.request_eligibility_eval_task(request)
+ self.assertTrue(response.HasField('eligibility_eval_task'))
+ spec_resource = response.eligibility_eval_task.population_eligibility_spec
+ population_eligibility_spec = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.FromString(
+ spec_resource.inline_resource.data
+ )
+ )
+ self.assertEqual(
+ population_eligibility_spec,
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec(
+ task_info=[
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo(
+ task_name='task1',
+ task_assignment_mode=(
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
+ ),
+ ),
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo(
+ task_name='task2',
+ task_assignment_mode=(
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
+ ),
+ ),
+ ],
+ ),
+ )
+
+ def test_request_eligibility_eval_task_with_wrong_population(self):
+ service = eligibility_eval_tasks.Service(POPULATION_NAME,
+ lambda: FORWARDING_INFO)
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ population_name='other/population')
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.request_eligibility_eval_task(request)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_report_eligibility_eval_task_result(self):
+ service = eligibility_eval_tasks.Service(POPULATION_NAME,
+ lambda: FORWARDING_INFO)
+ request = eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultRequest(
+ population_name=POPULATION_NAME,
+ session_id='session-id',
+ status_code=code_pb2.ABORTED)
+ self.assertEqual(
+ service.report_eligibility_eval_task_result(request),
+ eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultResponse())
+
+ def test_report_eligibility_eval_task_result_with_wrong_population(self):
+ service = eligibility_eval_tasks.Service(POPULATION_NAME,
+ lambda: FORWARDING_INFO)
+ request = eligibility_eval_tasks_pb2.ReportEligibilityEvalTaskResultRequest(
+ population_name='other/population',
+ session_id='session-id',
+ status_code=code_pb2.ABORTED)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.report_eligibility_eval_task_result(request)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_remove_task(self):
+ service = eligibility_eval_tasks.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO
+ )
+ service.add_task('task', _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE)
+ service.remove_task('task')
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ population_name=POPULATION_NAME,
+ eligibility_eval_task_capabilities=(
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskCapabilities(
+ supports_multiple_task_assignment=True
+ )
+ ),
+ )
+ response = service.request_eligibility_eval_task(request)
+ spec_resource = response.eligibility_eval_task.population_eligibility_spec
+ population_eligibility_spec = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.FromString(
+ spec_resource.inline_resource.data
+ )
+ )
+ self.assertEqual(
+ population_eligibility_spec,
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec(),
+ )
+
+ def test_remove_missing_task(self):
+ service = eligibility_eval_tasks.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO
+ )
+ with self.assertRaises(KeyError):
+ service.remove_task('does-not-exist')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/federated_computation.py b/fcp/demo/federated_computation.py
new file mode 100644
index 0000000..a8b84b6
--- /dev/null
+++ b/fcp/demo/federated_computation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""tff.Computation subclass for the demo Federated Computation platform."""
+
+import functools
+import re
+
+import tensorflow_federated as tff
+
+COMPUTATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')
+
+
+class FederatedComputation(tff.Computation):
+ """A tff.Computation that should be run in a tff.program.FederatedContext."""
+
+ def __init__(self, comp: tff.Computation, *, name: str):
+ """Constructs a new FederatedComputation object.
+
+ Args:
+ comp: The MapReduceForm- and DistributeAggregateForm- compatible
+ computation that will be run.
+ name: A unique name for the computation.
+ """
+ tff.backends.mapreduce.check_computation_compatible_with_map_reduce_form(
+ comp
+ ) # pytype: disable=wrong-arg-types
+ if not COMPUTATION_NAME_REGEX.fullmatch(name):
+ raise ValueError(f'name must match "{COMPUTATION_NAME_REGEX.pattern}".')
+ self._comp = comp
+ self._name = name
+
+ @functools.cached_property
+ def map_reduce_form(self) -> tff.backends.mapreduce.MapReduceForm:
+ """The underlying MapReduceForm representation."""
+ return tff.backends.mapreduce.get_map_reduce_form_for_computation( # pytype: disable=wrong-arg-types
+ self._comp
+ )
+
+ @functools.cached_property
+ def distribute_aggregate_form(
+ self,
+ ) -> tff.backends.mapreduce.DistributeAggregateForm:
+ """The underlying DistributeAggregateForm representation."""
+ return tff.backends.mapreduce.get_distribute_aggregate_form_for_computation( # pytype: disable=wrong-arg-types
+ self._comp
+ )
+
+ @property
+ def wrapped_computation(self) -> tff.Computation:
+ """The underlying tff.Computation."""
+ return self._comp
+
+ @property
+ def name(self) -> str:
+ """The name of the computation."""
+ return self._name
+
+ @property
+ def type_signature(self) -> tff.Type:
+ return self._comp.type_signature
+
+ def __call__(self, *args, **kwargs) ->...:
+ arg = tff.structure.Struct([(None, arg) for arg in args] +
+ list(kwargs.items()))
+ return tff.framework.get_context_stack().current.invoke(self, arg)
+
+ def __hash__(self) -> int:
+ return hash((self._comp, self._name))
diff --git a/fcp/demo/federated_computation_test.py b/fcp/demo/federated_computation_test.py
new file mode 100644
index 0000000..a802845
--- /dev/null
+++ b/fcp/demo/federated_computation_test.py
@@ -0,0 +1,158 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for federated_computation."""
+
+from unittest import mock
+
+from absl.testing import absltest
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.demo import federated_computation as fc
+
+
+@tff.tf_computation(tf.int32, tf.int32)
+def add_values(x, y):
+ return x + y
+
+
+@tff.federated_computation(
+ tff.type_at_server(tf.int32),
+ tff.type_at_clients(tff.SequenceType(tf.string)))
+def count_clients(state, client_data):
+ """Example TFF computation that counts clients."""
+ del client_data
+ client_value = tff.federated_value(1, tff.CLIENTS)
+ aggregated_count = tff.federated_sum(client_value)
+ metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
+ return tff.federated_map(add_values, (state, aggregated_count)), metrics
+
+
+@tff.federated_computation(
+ tff.type_at_server(tf.int32),
+ tff.type_at_clients(tff.SequenceType(tf.string)))
+def count_examples(state, client_data):
+ """Example TFF computation that counts client examples."""
+
+ @tff.tf_computation
+ def client_work(client_data):
+ return client_data.reduce(0, lambda x, _: x + 1)
+
+ client_counts = tff.federated_map(client_work, client_data)
+ aggregated_count = tff.federated_sum(client_counts)
+ metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
+ return tff.federated_map(add_values, (state, aggregated_count)), metrics
+
+
+class FederatedComputationTest(absltest.TestCase):
+
+ def test_invalid_name(self):
+ with self.assertRaisesRegex(ValueError, r'name must match ".+"'):
+ fc.FederatedComputation(count_clients, name='^invalid^')
+
+ def test_incompatible_computation(self):
+ # This function doesn't have the return value structure required for MRF.
+ @tff.federated_computation(tff.type_at_server(tf.int32))
+ def add_one(value):
+ return value + tff.federated_value(1, tff.SERVER)
+
+ with self.assertRaises(TypeError):
+ fc.FederatedComputation(add_one, name='comp')
+
+ @tff.test.with_context(
+ tff.backends.test.create_sync_test_cpp_execution_context
+ )
+ def test_map_reduce_form(self):
+ comp1 = fc.FederatedComputation(count_clients, name='comp1')
+ comp2 = fc.FederatedComputation(count_examples, name='comp2')
+ self.assertNotEqual(comp1.map_reduce_form, comp2.map_reduce_form)
+
+ # While we treat the MRF contents as an implementation detail, we can verify
+ # the invocation results of the corresponding computation.
+ # comp1 should return the number of clients.
+ self.assertEqual(
+ tff.backends.mapreduce.get_computation_for_map_reduce_form(
+ comp1.map_reduce_form
+ )(0, [['', '']] * 3),
+ (3, ()),
+ )
+ # comp2 should return the number of examples across all clients.
+ self.assertEqual(
+ tff.backends.mapreduce.get_computation_for_map_reduce_form(
+ comp2.map_reduce_form)(0, [['', '']] * 3), (6, ()))
+
+ @tff.test.with_context(
+ tff.backends.native.create_sync_local_cpp_execution_context
+ )
+ def test_distribute_aggregate_form(self):
+ comp1 = fc.FederatedComputation(count_clients, name='comp1')
+ comp2 = fc.FederatedComputation(count_examples, name='comp2')
+ self.assertNotEqual(
+ comp1.distribute_aggregate_form, comp2.distribute_aggregate_form
+ )
+
+ # While we treat the DAF contents as an implementation detail, we can verify
+ # the invocation results of the corresponding computation.
+ # comp1 should return the number of clients.
+ self.assertEqual(
+ tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(
+ comp1.distribute_aggregate_form
+ )(0, [['', '']] * 3),
+ (3, ()),
+ )
+ # comp2 should return the number of examples across all clients.
+ self.assertEqual(
+ tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(
+ comp2.distribute_aggregate_form
+ )(0, [['', '']] * 3),
+ (6, ()),
+ )
+
+ def test_wrapped_computation(self):
+ comp = fc.FederatedComputation(count_clients, name='comp')
+ self.assertEqual(comp.wrapped_computation, count_clients)
+
+ def test_name(self):
+ comp = fc.FederatedComputation(count_clients, name='comp')
+ self.assertEqual(comp.name, 'comp')
+
+ def test_type_signature(self):
+ comp = fc.FederatedComputation(count_clients, name='comp')
+ self.assertEqual(comp.type_signature, count_clients.type_signature)
+
+ def test_call(self):
+ comp = fc.FederatedComputation(count_clients, name='comp')
+ ctx = mock.create_autospec(tff.program.FederatedContext, instance=True)
+ ctx.invoke.return_value = 1234
+ with tff.framework.get_context_stack().install(ctx):
+ self.assertEqual(comp(1, 2, 3, kw1='a', kw2='b'), 1234)
+ ctx.invoke.assert_called_once_with(
+ comp,
+ tff.structure.Struct([(None, 1), (None, 2), (None, 3), ('kw1', 'a'),
+ ('kw2', 'b')]))
+
+ def test_hash(self):
+ comp = fc.FederatedComputation(count_clients, name='comp')
+ # Equivalent objects should have equal hashes.
+ self.assertEqual(
+ hash(comp), hash(fc.FederatedComputation(count_clients, name='comp')))
+ # Different computations or names should produce different hashes.
+ self.assertNotEqual(
+ hash(comp), hash(fc.FederatedComputation(count_clients, name='other')))
+ self.assertNotEqual(
+ hash(comp), hash(fc.FederatedComputation(count_examples, name='comp')))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/federated_context.py b/fcp/demo/federated_context.py
new file mode 100644
index 0000000..e509907
--- /dev/null
+++ b/fcp/demo/federated_context.py
@@ -0,0 +1,314 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TFF FederatedContext subclass for the demo Federated Computation platform."""
+
+from collections.abc import Awaitable
+import socket
+import ssl
+import threading
+from typing import Any, Optional, Union
+import uuid
+
+from absl import logging
+import attr
+import numpy as np
+import tensorflow as tf
+import tensorflow_federated as tff
+import tree
+
+from fcp.artifact_building import artifact_constants
+from fcp.artifact_building import checkpoint_utils
+from fcp.artifact_building import data_spec
+from fcp.artifact_building import federated_compute_plan_builder
+from fcp.artifact_building import plan_utils
+from fcp.artifact_building import variable_helpers
+from fcp.demo import checkpoint_tensor_reference
+from fcp.demo import federated_computation
+from fcp.demo import federated_data_source
+from fcp.demo import server
+from fcp.protos import plan_pb2
+
+
+class FederatedContext(tff.program.FederatedContext):
+ """A FederatedContext for use with the demo platform."""
+
+ def __init__(self,
+ population_name: str,
+ *,
+ base_context: Optional[tff.framework.SyncContext] = None,
+ host: str = 'localhost',
+ port: int = 0,
+ certfile: Optional[str] = None,
+ keyfile: Optional[str] = None,
+ address_family: Optional[socket.AddressFamily] = None):
+ """Initializes a `FederatedContext`.
+
+ Args:
+ population_name: The name of the population to execute computations on.
+ base_context: The context used to run non-federated TFF computations
+ (i.e., computations with a type other than FederatedComputation).
+ host: The hostname the server should bind to.
+ port: The port the server should listen on.
+ certfile: The path to the certificate to use for https.
+ keyfile: The path to the certificate's private key (if separate).
+ address_family: An override for the HTTP server's address family.
+ """
+ # NOTE: The demo server only supports a single population, which must be
+ # specified at startup. An implementation that supports multiple populations
+ # should only use the population name from the PopulationDataSource.
+ if not federated_data_source.POPULATION_NAME_REGEX.fullmatch(
+ population_name):
+ raise ValueError(
+ 'population_name must match '
+ f'"{federated_data_source.POPULATION_NAME_REGEX.pattern}".')
+ self._population_name = population_name
+ self._base_context = base_context
+ self._server = server.InProcessServer(
+ population_name=population_name,
+ host=host,
+ port=port,
+ address_family=address_family)
+ if certfile is not None:
+ context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ context.load_cert_chain(certfile, keyfile)
+ self._server.socket = context.wrap_socket(
+ self._server.socket, server_side=True)
+ self._server_thread = threading.Thread(
+ target=self._server.serve_forever, daemon=True)
+ self._cached_comps: dict[tuple[tff.Computation, int], plan_pb2.Plan] = {}
+
+ @property
+ def server_port(self) -> int:
+ """The port on which the Federated Compute server is running."""
+ return self._server.server_port
+
+ def __enter__(self):
+ self._server_thread.start()
+ logging.log(logging.INFO, 'Federated Compute server running on %s:%s',
+ self._server.server_name, self._server.server_port)
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ self._server.shutdown()
+ self._server_thread.join()
+ logging.log(logging.INFO, 'Federated Compute server stopped')
+
+ def invoke(self, comp: tff.Computation, arg: Any) -> Any:
+ """Invokes a computation.
+
+ Args:
+ comp: The computation being invoked.
+ arg: The arguments of the call encoded in a computation-specific way. For
+ FederatedComputations, this should be a `(state, config)` tuple, where
+ the state is a possibly nested structure and the configuration is
+ provided by a FederatedDataSource.
+
+ Returns:
+ A value reference structure representing the result of the computation.
+ """
+ # Pass other computation types to the next FederatedContext.
+ if not isinstance(comp, federated_computation.FederatedComputation):
+ if not self._base_context:
+ raise TypeError('computation must be a FederatedComputation if no '
+ 'base_context was provided.')
+ return self._base_context.invoke(comp, arg)
+
+ state, config = self._parse_arg(arg)
+ if config.population_name != self._population_name:
+ raise ValueError('FederatedDataSource and FederatedContext '
+ 'population_names must match.')
+
+ # Since building the plan can be slow, we cache the result to speed up
+ # subsequent invocations.
+ cache_key = (comp.wrapped_computation, id(config.example_selector))
+ try:
+ plan = self._cached_comps[cache_key]
+ except KeyError:
+ plan = federated_compute_plan_builder.build_plan(
+ comp.map_reduce_form,
+ comp.distribute_aggregate_form,
+ self._get_nested_data_spec(config.example_selector),
+ grappler_config=tf.compat.v1.ConfigProto(),
+ generate_server_phase_v2=True,
+ )
+ # Add the TF Lite flatbuffer to the plan. If the conversion fails, the
+ # flatbuffer will be silently omitted and the client will use the
+ # TensorFlow graph in `plan.client_graph_bytes` instead.
+ # NOTE: If conversion failures should not be silent, pass
+ # `forgive_tflite_conversion_failure=False`.
+ plan = plan_utils.generate_and_add_flat_buffer_to_plan(plan)
+ self._cached_comps[cache_key] = plan
+
+ checkpoint_future = self._run_computation(comp.name, config, plan,
+ comp.type_signature.parameter[0],
+ state)
+ result_value_ref = self._create_tensor_reference_struct(
+ comp.type_signature.result, checkpoint_future)
+ return tff.types.type_to_py_container(result_value_ref,
+ comp.type_signature.result)
+
+ def _is_state_structure_of_allowed_types(
+ self,
+ structure: Union[
+ tff.structure.Struct,
+ tf.Tensor,
+ tff.program.MaterializableValue,
+ ],
+ ) -> bool:
+ """Checks if each node in `structure` is an allowed type for `state`."""
+ if isinstance(structure, tff.structure.Struct):
+ structure = tff.structure.flatten(structure)
+ else:
+ structure = tree.flatten(structure)
+ for item in structure:
+ if not (
+ tf.is_tensor(item)
+ or isinstance(
+ item,
+ (
+ np.ndarray,
+ np.number,
+ int,
+ float,
+ str,
+ bytes,
+ tff.program.MaterializableValueReference,
+ ),
+ )
+ ):
+ return False
+ return True
+
+ def _parse_arg(
+ self, arg: tff.structure.Struct
+ ) -> tuple[Union[tff.structure.Struct, tf.Tensor,
+ tff.program.MaterializableValueReference],
+ federated_data_source.DataSelectionConfig]:
+ """Parses and validates the invoke arguments."""
+ if len(arg) != 2:
+ raise ValueError(f'The argument structure is unsupported: {arg}.')
+
+ state, config = arg
+ if attr.has(type(state)):
+ state = tff.structure.from_container(state, recursive=True)
+ if not self._is_state_structure_of_allowed_types(state):
+ raise TypeError(
+ 'arg[0] must be a value or structure of values of '
+ '`MaterializableValueReference`s, `tf.Tensor`s, '
+ '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
+ f'{tf.nest.map_structure(type, state)!r})'
+ )
+
+ # Code below assumes single values are always `tf.Tensor`s.
+ if isinstance(state, (int, float, str, bytes, np.ndarray, np.number)):
+ state = tf.convert_to_tensor(state)
+
+ if not isinstance(config, federated_data_source.DataSelectionConfig):
+ raise TypeError('arg[1] must be the result of '
+ 'FederatedDataSource.iterator().select().')
+ return state, config
+
+ def _get_nested_data_spec(self, example_selector) -> data_spec.NestedDataSpec:
+ """Converts a NestedExampleSelector to a NestedDataSpec."""
+ if isinstance(example_selector, dict):
+ return {
+ k: self._get_nested_data_spec(v) for k, v in example_selector.items()
+ }
+ return data_spec.DataSpec(example_selector)
+
+ async def _run_computation(
+ self, name: str, config: federated_data_source.DataSelectionConfig,
+ plan: plan_pb2.Plan, input_type: tff.Type,
+ input_state: Union[tff.structure.Struct, tf.Tensor,
+ tff.program.MaterializableValueReference]
+ ) -> bytes:
+ """Prepares and runs a computation using the demo server."""
+ input_checkpoint = self._state_to_checkpoint(
+ input_type, await self._resolve_value_references(input_state))
+ try:
+ logging.log(logging.INFO, 'Started running %s', name)
+ return await self._server.run_computation(
+ name,
+ plan,
+ input_checkpoint,
+ config.task_assignment_mode,
+ config.num_clients,
+ )
+ finally:
+ logging.log(logging.INFO, 'Finished running %s', name)
+
+ async def _resolve_value_references(
+ self, structure: Union[tff.structure.Struct, tf.Tensor,
+ tff.program.MaterializableValueReference]
+ ) -> Union[tff.structure.Struct, tf.Tensor]:
+ """Dereferences any MaterializableValueReferences in a struct."""
+ if isinstance(structure, tff.program.MaterializableValueReference):
+ return await structure.get_value() # pytype: disable=bad-return-type # numpy-scalars
+ elif tf.is_tensor(structure):
+ return structure
+ elif isinstance(structure, tff.structure.Struct):
+ s = [
+ self._resolve_value_references(x)
+ for x in tff.structure.flatten(structure)
+ ]
+ return tff.structure.pack_sequence_as(structure, s)
+ else:
+ raise ValueError(
+ 'arg[1] must be a struct, Tensor, or MaterializableValueReference.')
+
+ def _state_to_checkpoint(
+ self, state_type: tff.Type, state: Union[tff.structure.Struct,
+ tf.Tensor]) -> bytes:
+ """Converts computation input state to a checkpoint file.
+
+ The checkpoint file format is used to pass the state to
+ InProcessServer.run_computation.
+
+ Args:
+ state_type: The TFF type of the state structure.
+ state: A Tensor or TFF structure with input state for a computation.
+
+ Returns:
+ The state encoded as a checkpoint file.
+ """
+ var_names = variable_helpers.variable_names_from_type(
+ state_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX)
+
+ # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
+ tmpfile = f'ram://{uuid.uuid4()}.ckpt'
+ checkpoint_utils.save_tff_structure_to_checkpoint(state, var_names, tmpfile)
+ try:
+ with tf.io.gfile.GFile(tmpfile, 'rb') as f:
+ return f.read()
+ finally:
+ tf.io.gfile.remove(tmpfile)
+
+ def _create_tensor_reference_struct(
+ self, result_type: tff.Type,
+ checkpoint_future: Awaitable[bytes]) -> tff.structure.Struct:
+ """Creates the CheckpointTensorReference struct for a result type."""
+ shared_checkpoint_future = tff.async_utils.SharedAwaitable(
+ checkpoint_future)
+ tensor_specs = checkpoint_utils.tff_type_to_tensor_spec_list(result_type)
+ var_names = (
+ variable_helpers.variable_names_from_type(
+ result_type[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) +
+ variable_helpers.variable_names_from_type(
+ result_type[1], name=artifact_constants.SERVER_METRICS_VAR_PREFIX))
+ tensor_refs = [
+ checkpoint_tensor_reference.CheckpointTensorReference(
+ var_name, spec.dtype, spec.shape, shared_checkpoint_future)
+ for var_name, spec in zip(var_names, tensor_specs)
+ ]
+ return checkpoint_utils.pack_tff_value(result_type, tensor_refs)
diff --git a/fcp/demo/federated_context_test.py b/fcp/demo/federated_context_test.py
new file mode 100644
index 0000000..4d5a7ac
--- /dev/null
+++ b/fcp/demo/federated_context_test.py
@@ -0,0 +1,438 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for federated_context."""
+
+import http
+import http.client
+import socket
+import threading
+import unittest
+from unittest import mock
+
+from absl.testing import absltest
+import attr
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.artifact_building import artifact_constants
+from fcp.artifact_building import federated_compute_plan_builder
+from fcp.artifact_building import plan_utils
+from fcp.artifact_building import variable_helpers
+from fcp.demo import federated_computation
+from fcp.demo import federated_context
+from fcp.demo import federated_data_source
+from fcp.demo import server
+from fcp.demo import test_utils
+from fcp.protos import plan_pb2
+
+ADDRESS_FAMILY = socket.AddressFamily.AF_INET
+POPULATION_NAME = 'test/population'
+DATA_SOURCE = federated_data_source.FederatedDataSource(
+ POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/test'))
+
+
+@tff.tf_computation(tf.int32)
+def add_one(x):
+ return x + 1
+
+
+@tff.federated_computation(
+ tff.type_at_server(tf.int32),
+ tff.type_at_clients(tff.SequenceType(tf.string)))
+def count_clients(state, client_data):
+ """Example TFF computation that counts clients."""
+ del client_data
+ num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
+ non_state = tff.federated_value((), tff.SERVER)
+ return state + num_clients, non_state
+
+
+@tff.federated_computation(
+ tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])),
+ tff.type_at_clients(tff.SequenceType(tf.string)),
+)
+def irregular_arrays(state, client_data):
+ """Example TFF computation that returns irregular data."""
+ del client_data
+ num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
+ non_state = tff.federated_value(1, tff.SERVER)
+ return state, non_state + num_clients
+
+
+@attr.s(eq=False, frozen=True, slots=True)
+class TestClass:
+ """An attrs class."""
+
+ field_one = attr.ib()
+ field_two = attr.ib()
+
+
+@tff.tf_computation
+def init():
+ return TestClass(field_one=1, field_two=2)
+
+
+attrs_type = init.type_signature.result
+
+
+@tff.federated_computation(
+ tff.type_at_server(attrs_type),
+ tff.type_at_clients(tff.SequenceType(tf.string)),
+)
+def attrs_computation(state, client_data):
+ """Example TFF computation that returns an attrs class."""
+ del client_data
+ num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
+ non_state = tff.federated_value(1, tff.SERVER)
+ return state, non_state + num_clients
+
+
+def build_result_checkpoint(state: int) -> bytes:
+ """Helper function to build a result checkpoint for `count_clients`."""
+ var_names = variable_helpers.variable_names_from_type(
+ count_clients.type_signature.result[0],
+ name=artifact_constants.SERVER_STATE_VAR_PREFIX)
+ return test_utils.create_checkpoint({var_names[0]: state})
+
+
+class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
+
+ def test_invalid_population_name(self):
+ with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'):
+ federated_context.FederatedContext(
+ '^^invalid^^', address_family=ADDRESS_FAMILY)
+
+ @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True)
+ @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True)
+ def test_context_management(self, serve_forever, shutdown):
+ started = threading.Event()
+ serve_forever.side_effect = lambda *args, **kwargs: started.set()
+
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ self.assertFalse(started.is_set())
+ shutdown.assert_not_called()
+ with ctx:
+ self.assertTrue(started.wait(0.5))
+ shutdown.assert_not_called()
+ shutdown.assert_called_once()
+
+ def test_http(self):
+ with federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx:
+ conn = http.client.HTTPConnection('localhost', port=ctx.server_port)
+ conn.request('GET', '/does-not-exist')
+ self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
+
+ def test_invoke_non_federated_with_base_context(self):
+ base_context = tff.backends.native.create_sync_local_cpp_execution_context()
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME,
+ address_family=ADDRESS_FAMILY,
+ base_context=base_context)
+ with tff.framework.get_context_stack().install(ctx):
+ self.assertEqual(add_one(3), 4)
+
+ def test_invoke_non_federated_without_base_context(self):
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ with tff.framework.get_context_stack().install(ctx):
+ with self.assertRaisesRegex(TypeError,
+ 'computation must be a FederatedComputation'):
+ add_one(3)
+
+ def test_invoke_with_invalid_state_type(self):
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ with tff.framework.get_context_stack().install(ctx):
+ with self.assertRaisesRegex(
+ TypeError, r'arg\[0\] must be a value or structure of values'
+ ):
+ comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1))
+
+ def test_invoke_with_invalid_data_source_type(self):
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ with tff.framework.get_context_stack().install(ctx):
+ with self.assertRaisesRegex(
+ TypeError, r'arg\[1\] must be the result of '
+ r'FederatedDataSource.iterator\(\).select\(\)'):
+ comp(0, plan_pb2.Plan())
+
+ def test_invoke_succeeds_with_structure_state_type(self):
+ comp = federated_computation.FederatedComputation(
+ irregular_arrays, name='x'
+ )
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY
+ )
+ with tff.framework.get_context_stack().install(ctx):
+ state = {'foo': (3, 1), 'bar': (4, 5, 6)}
+ comp(state, DATA_SOURCE.iterator().select(1))
+
+ def test_invoke_succeeds_with_attrs_state_type(self):
+ comp = federated_computation.FederatedComputation(
+ attrs_computation, name='x'
+ )
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY
+ )
+ with tff.framework.get_context_stack().install(ctx):
+ state = TestClass(field_one=1, field_two=2)
+ comp(state, DATA_SOURCE.iterator().select(1))
+
+ def test_invoke_with_mismatched_population_names(self):
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ds = federated_data_source.FederatedDataSource('other/name',
+ DATA_SOURCE.example_selector)
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ with tff.framework.get_context_stack().install(ctx):
+ with self.assertRaisesRegex(
+ ValueError, 'FederatedDataSource and FederatedContext '
+ 'population_names must match'):
+ comp(0, ds.iterator().select(1))
+
+ @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
+ async def test_invoke_success(self, run_computation):
+ run_computation.return_value = build_result_checkpoint(7)
+
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ release_manager = tff.program.MemoryReleaseManager()
+ with tff.framework.get_context_stack().install(ctx):
+ state, _ = comp(3, DATA_SOURCE.iterator().select(10))
+ await release_manager.release(
+ state, tff.type_at_server(tf.int32), key='result')
+
+ self.assertEqual(release_manager.values()['result'][0], 7)
+
+ run_computation.assert_called_once_with(
+ mock.ANY,
+ comp.name,
+ mock.ANY,
+ mock.ANY,
+ DATA_SOURCE.task_assignment_mode,
+ 10,
+ )
+ plan = run_computation.call_args.args[2]
+ self.assertIsInstance(plan, plan_pb2.Plan)
+ self.assertNotEmpty(plan.client_tflite_graph_bytes)
+ input_var_names = variable_helpers.variable_names_from_type(
+ count_clients.type_signature.parameter[0],
+ name=artifact_constants.SERVER_STATE_VAR_PREFIX)
+ self.assertLen(input_var_names, 1)
+ self.assertEqual(
+ test_utils.read_tensor_from_checkpoint(
+ run_computation.call_args.args[3], input_var_names[0], tf.int32), 3)
+
+ @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
+ async def test_invoke_with_value_reference(self, run_computation):
+ run_computation.side_effect = [
+ build_result_checkpoint(1234),
+ build_result_checkpoint(5678)
+ ]
+
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ release_manager = tff.program.MemoryReleaseManager()
+ with tff.framework.get_context_stack().install(ctx):
+ state, _ = comp(3, DATA_SOURCE.iterator().select(10))
+ state, _ = comp(state, DATA_SOURCE.iterator().select(10))
+ await release_manager.release(
+ state, tff.type_at_server(tf.int32), key='result')
+
+ self.assertEqual(release_manager.values()['result'][0], 5678)
+
+ input_var_names = variable_helpers.variable_names_from_type(
+ count_clients.type_signature.parameter[0],
+ name=artifact_constants.SERVER_STATE_VAR_PREFIX)
+ self.assertLen(input_var_names, 1)
+ # The second invocation should be passed the value returned by the first
+ # invocation.
+ self.assertEqual(run_computation.call_count, 2)
+ self.assertEqual(
+ test_utils.read_tensor_from_checkpoint(
+ run_computation.call_args.args[3], input_var_names[0], tf.int32),
+ 1234)
+
+ async def test_invoke_without_input_state(self):
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ with tff.framework.get_context_stack().install(ctx):
+ with self.assertRaisesRegex(
+ TypeError, r'arg\[0\] must be a value or structure of values'
+ ):
+ comp(None, DATA_SOURCE.iterator().select(1))
+
+ @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
+ async def test_invoke_with_run_computation_error(self, run_computation):
+ run_computation.side_effect = ValueError('message')
+
+ comp = federated_computation.FederatedComputation(count_clients, name='x')
+ ctx = federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)
+ release_manager = tff.program.MemoryReleaseManager()
+ with tff.framework.get_context_stack().install(ctx):
+ state, _ = comp(0, DATA_SOURCE.iterator().select(10))
+ with self.assertRaisesRegex(ValueError, 'message'):
+ await release_manager.release(
+ state, tff.type_at_server(tf.int32), key='result')
+
+
+class FederatedContextPlanCachingTest(absltest.TestCase,
+ unittest.IsolatedAsyncioTestCase):
+
+ async def asyncSetUp(self):
+ await super().asyncSetUp()
+
+ @tff.federated_computation(
+ tff.type_at_server(tf.int32),
+ tff.type_at_clients(tff.SequenceType(tf.string)))
+ def identity(state, client_data):
+ del client_data
+ return state, tff.federated_value((), tff.SERVER)
+
+ self.count_clients_comp1 = federated_computation.FederatedComputation(
+ count_clients, name='count_clients1')
+ self.count_clients_comp2 = federated_computation.FederatedComputation(
+ count_clients, name='count_clients2')
+ self.identity_comp = federated_computation.FederatedComputation(
+ identity, name='identity')
+
+ self.data_source1 = federated_data_source.FederatedDataSource(
+ POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1'))
+ self.data_source2 = federated_data_source.FederatedDataSource(
+ POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2'))
+
+ self.run_computation = self.enter_context(
+ mock.patch.object(
+ server.InProcessServer, 'run_computation', autospec=True))
+ self.run_computation.return_value = build_result_checkpoint(0)
+ self.build_plan = self.enter_context(
+ mock.patch.object(
+ federated_compute_plan_builder, 'build_plan', autospec=True))
+ self.build_plan.return_value = plan_pb2.Plan()
+ self.generate_and_add_flat_buffer_to_plan = self.enter_context(
+ mock.patch.object(
+ plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True))
+ self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan
+ self.enter_context(tff.framework.get_context_stack().install(
+ federated_context.FederatedContext(
+ POPULATION_NAME, address_family=ADDRESS_FAMILY)))
+ self.release_manager = tff.program.MemoryReleaseManager()
+
+ # Run (and therefore cache) count_clients_comp1 with data_source1.
+ await self.release_manager.release(
+ self.count_clients_comp1(0,
+ self.data_source1.iterator().select(1)),
+ self.count_clients_comp1.type_signature.result,
+ key='result')
+ self.build_plan.assert_called_once()
+ self.assertEqual(self.build_plan.call_args.args[0],
+ self.count_clients_comp1.map_reduce_form)
+ self.assertEqual(
+ self.build_plan.call_args.args[1],
+ self.count_clients_comp1.distribute_aggregate_form,
+ )
+ self.assertEqual(
+ self.build_plan.call_args.args[2].example_selector_proto,
+ self.data_source1.example_selector,
+ )
+ self.run_computation.assert_called_once()
+ self.build_plan.reset_mock()
+ self.run_computation.reset_mock()
+
+ async def test_reuse_with_repeat_computation(self):
+ await self.release_manager.release(
+ self.count_clients_comp1(0,
+ self.data_source1.iterator().select(1)),
+ self.count_clients_comp1.type_signature.result,
+ key='result')
+ self.build_plan.assert_not_called()
+ self.run_computation.assert_called_once()
+
+ async def test_reuse_with_changed_num_clients(self):
+ await self.release_manager.release(
+ self.count_clients_comp1(0,
+ self.data_source1.iterator().select(10)),
+ self.count_clients_comp1.type_signature.result,
+ key='result')
+ self.build_plan.assert_not_called()
+ self.run_computation.assert_called_once()
+
+ async def test_reuse_with_changed_initial_state(self):
+ await self.release_manager.release(
+ self.count_clients_comp1(3,
+ self.data_source1.iterator().select(1)),
+ self.count_clients_comp1.type_signature.result,
+ key='result')
+ self.build_plan.assert_not_called()
+ self.run_computation.assert_called_once()
+
+ async def test_reuse_with_equivalent_map_reduce_form(self):
+ await self.release_manager.release(
+ self.count_clients_comp2(0,
+ self.data_source1.iterator().select(1)),
+ self.count_clients_comp2.type_signature.result,
+ key='result')
+ self.build_plan.assert_not_called()
+ self.run_computation.assert_called_once()
+
+ async def test_rebuild_with_different_computation(self):
+ await self.release_manager.release(
+ self.identity_comp(0,
+ self.data_source1.iterator().select(1)),
+ self.identity_comp.type_signature.result,
+ key='result')
+ self.build_plan.assert_called_once()
+ self.assertEqual(self.build_plan.call_args.args[0],
+ self.identity_comp.map_reduce_form)
+ self.assertEqual(
+ self.build_plan.call_args.args[1],
+ self.identity_comp.distribute_aggregate_form,
+ )
+ self.assertEqual(
+ self.build_plan.call_args.args[2].example_selector_proto,
+ self.data_source1.example_selector,
+ )
+ self.run_computation.assert_called_once()
+
+ async def test_rebuild_with_different_data_source(self):
+ await self.release_manager.release(
+ self.count_clients_comp1(0,
+ self.data_source2.iterator().select(1)),
+ self.count_clients_comp1.type_signature.result,
+ key='result')
+ self.build_plan.assert_called_once()
+ self.assertEqual(self.build_plan.call_args.args[0],
+ self.count_clients_comp1.map_reduce_form)
+ self.assertEqual(
+ self.build_plan.call_args.args[1],
+ self.count_clients_comp1.distribute_aggregate_form,
+ )
+ self.assertEqual(
+ self.build_plan.call_args.args[2].example_selector_proto,
+ self.data_source2.example_selector,
+ )
+ self.run_computation.assert_called_once()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/federated_data_source.py b/fcp/demo/federated_data_source.py
new file mode 100644
index 0000000..71d9cff
--- /dev/null
+++ b/fcp/demo/federated_data_source.py
@@ -0,0 +1,141 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TFF FederatedDataSource for the demo Federated Computation platform."""
+
+import dataclasses
+import functools
+import re
+from typing import Optional, Union
+
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+POPULATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')
+
+_NestedExampleSelector = Union[plan_pb2.ExampleSelector,
+ dict[str, '_NestedExampleSelector']]
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+
+@dataclasses.dataclass
+class DataSelectionConfig:
+ population_name: str
+ example_selector: _NestedExampleSelector
+ task_assignment_mode: _TaskAssignmentMode
+ num_clients: int
+
+
+class FederatedDataSource(tff.program.FederatedDataSource):
+ """A FederatedDataSource for use with the demo platform.
+
+ A FederatedDataSource represents a population of client devices and the set of
+ on-device data over which computations should be invoked.
+ """
+
+ _FEDERATED_TYPE = tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS)
+
+ def __init__(
+ self,
+ population_name: str,
+ example_selector: _NestedExampleSelector,
+ task_assignment_mode: _TaskAssignmentMode = (
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
+ ),
+ ):
+ """Constructs a new FederatedDataSource object.
+
+ Args:
+ population_name: The name of the population to execute computations on.
+ example_selector: A `plan_pb2.ExampleSelector` or a structure of
+ ExampleSelectors indicating CLIENTS-placed data to execute over.
+ task_assignment_mode: The TaskAssignmentMode to use for this computation.
+ """
+ if not POPULATION_NAME_REGEX.fullmatch(population_name):
+ raise ValueError(
+ f'population_name must match "{POPULATION_NAME_REGEX.pattern}".')
+ self._population_name = population_name
+ self._example_selector = example_selector
+ self._task_assignment_mode = task_assignment_mode
+
+ @property
+ def population_name(self) -> str:
+ """The name of the population from which examples will be retrieved."""
+ return self._population_name
+
+ @property
+ def example_selector(self) -> _NestedExampleSelector:
+ """The NestedExampleSelector used to obtain the examples."""
+ return self._example_selector
+
+ @property
+ def task_assignment_mode(self) -> _TaskAssignmentMode:
+ """The TaskAssignmentMode to use for this computation."""
+ return self._task_assignment_mode
+
+ @functools.cached_property
+ def federated_type(self) -> tff.FederatedType:
+
+ def get_struct_type(value):
+ if isinstance(value, dict):
+ return tff.StructType([
+ (k, get_struct_type(v)) for k, v in value.items()
+ ])
+ # ExternalDataset always returns a sequence of tf.strings, which should be
+ # serialized `tf.train.Example` protos.
+ return tff.SequenceType(tf.string)
+
+ return tff.FederatedType(
+ get_struct_type(self._example_selector), tff.CLIENTS)
+
+ @functools.cached_property
+ def capabilities(self) -> list[tff.program.Capability]:
+ return [tff.program.Capability.SUPPORTS_REUSE]
+
+ def iterator(self) -> tff.program.FederatedDataSourceIterator:
+ return _FederatedDataSourceIterator(self)
+
+
+class _FederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):
+ """A `FederatedDataSourceIterator` for use with the demo platform."""
+
+ def __init__(self, data_source: FederatedDataSource):
+ self._data_source = data_source
+
+ @classmethod
+ def from_bytes(cls, data: bytes) -> '_FederatedDataSourceIterator':
+ """Deserializes the object from bytes."""
+ raise NotImplementedError
+
+ def to_bytes(self) -> bytes:
+ """Serializes the object to bytes."""
+ raise NotImplementedError
+
+ @property
+ def federated_type(self):
+ return self._data_source.federated_type
+
+ def select(self, num_clients: Optional[int] = None) -> DataSelectionConfig:
+ if num_clients is None or num_clients <= 0:
+ raise ValueError('num_clients must be positive.')
+ return DataSelectionConfig(
+ self._data_source.population_name,
+ self._data_source.example_selector,
+ self._data_source.task_assignment_mode,
+ num_clients,
+ )
diff --git a/fcp/demo/federated_data_source_test.py b/fcp/demo/federated_data_source_test.py
new file mode 100644
index 0000000..dbeb7ea
--- /dev/null
+++ b/fcp/demo/federated_data_source_test.py
@@ -0,0 +1,128 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for federated_data_source."""
+
+from absl.testing import absltest
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp.demo import federated_data_source as fds
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+POPULATION_NAME = 'test/name'
+EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(collection_uri='app://test')
+
+
+class FederatedDataSourceTest(absltest.TestCase):
+
+ def test_invalid_population_name(self):
+ with self.assertRaisesRegex(ValueError, r'population_name must match ".+"'):
+ fds.FederatedDataSource('^invalid^', EXAMPLE_SELECTOR)
+
+ def test_population_name(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertEqual(ds.population_name, POPULATION_NAME)
+
+ def test_example_selector(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertEqual(ds.example_selector, EXAMPLE_SELECTOR)
+
+ def test_default_task_assignment_mode(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertEqual(
+ ds.task_assignment_mode, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
+ )
+
+ def test_task_assignment_mode(self):
+ ds = fds.FederatedDataSource(
+ POPULATION_NAME,
+ EXAMPLE_SELECTOR,
+ task_assignment_mode=_TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ )
+ self.assertEqual(
+ ds.task_assignment_mode,
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ )
+
+ def test_federated_type(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertEqual(
+ ds.federated_type,
+ tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))
+
+ def test_federated_type_nested(self):
+ nested_example_selector = {
+ 'a': EXAMPLE_SELECTOR,
+ 'b': EXAMPLE_SELECTOR,
+ 'c': {
+ '1': EXAMPLE_SELECTOR,
+ '2': EXAMPLE_SELECTOR
+ },
+ }
+ ds = fds.FederatedDataSource(POPULATION_NAME, nested_example_selector)
+ self.assertEqual(
+ ds.federated_type,
+ tff.FederatedType(
+ tff.StructType([
+ ('a', tff.SequenceType(tf.string)),
+ ('b', tff.SequenceType(tf.string)),
+ ('c',
+ tff.StructType([
+ ('1', tff.SequenceType(tf.string)),
+ ('2', tff.SequenceType(tf.string)),
+ ])),
+ ]), tff.CLIENTS))
+
+ def test_capabilities(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertListEqual(ds.capabilities,
+ [tff.program.Capability.SUPPORTS_REUSE])
+
+ def test_iterator_federated_type(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ self.assertEqual(ds.iterator().federated_type, ds.federated_type)
+
+ def test_iterator_select(self):
+ ds = fds.FederatedDataSource(
+ POPULATION_NAME,
+ EXAMPLE_SELECTOR,
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ )
+ self.assertEqual(
+ ds.iterator().select(10),
+ fds.DataSelectionConfig(
+ POPULATION_NAME,
+ EXAMPLE_SELECTOR,
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ 10,
+ ),
+ )
+
+ def test_iterator_select_with_invalid_num_clients(self):
+ ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
+ with self.assertRaisesRegex(ValueError, 'num_clients must be positive'):
+ ds.iterator().select(num_clients=None)
+ with self.assertRaisesRegex(ValueError, 'num_clients must be positive'):
+ ds.iterator().select(num_clients=-5)
+ with self.assertRaisesRegex(ValueError, 'num_clients must be positive'):
+ ds.iterator().select(num_clients=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/federated_program_test.py b/fcp/demo/federated_program_test.py
new file mode 100644
index 0000000..206d40a
--- /dev/null
+++ b/fcp/demo/federated_program_test.py
@@ -0,0 +1,172 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""End-to-end test running a simple Federated Program."""
+
+import asyncio
+import os
+import tempfile
+import unittest
+
+from absl import flags
+from absl.testing import absltest
+import tensorflow as tf
+import tensorflow_federated as tff
+
+from fcp import demo
+from fcp.client import client_runner_example_data_pb2
+from fcp.protos import plan_pb2
+
+POPULATION_NAME = 'test/population'
+COLLECTION_URI = 'app:/example'
+
+
+@tff.federated_computation()
+def initialize() -> tff.Value:
+ """Returns the initial state."""
+ return tff.federated_value(0, tff.SERVER)
+
+
+@tff.federated_computation(
+ tff.type_at_server(tf.int32),
+ tff.type_at_clients(tff.SequenceType(tf.string)))
+def sum_counts(state, client_data):
+ """Sums the value of all 'count' features across all clients."""
+
+ @tf.function
+ def reduce_counts(s: tf.int32, example: tf.string) -> tf.int32:
+ features = {'count': tf.io.FixedLenFeature((), tf.int64)}
+ count = tf.io.parse_example(example, features=features)['count']
+ return s + tf.cast(count, tf.int32)
+
+ @tff.tf_computation
+ def client_work(client_data):
+ return client_data.reduce(0, reduce_counts)
+
+ client_counts = tff.federated_map(client_work, client_data)
+ aggregated_count = tff.federated_sum(client_counts)
+
+ num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
+ metrics = tff.federated_zip((num_clients,))
+ return state + aggregated_count, metrics
+
+
+async def program_logic(init: tff.Computation, comp: tff.Computation,
+ data_source: tff.program.FederatedDataSource,
+ total_rounds: int, number_of_clients: int,
+ release_manager: tff.program.ReleaseManager) -> None:
+ """Initializes and runs a computation, releasing metrics and final state."""
+ tff.program.check_in_federated_context()
+ data_iterator = data_source.iterator()
+ state = init()
+ for i in range(total_rounds):
+ cohort_config = data_iterator.select(number_of_clients)
+ state, metrics = comp(state, cohort_config)
+ await release_manager.release(
+ metrics, comp.type_signature.result[1], key=f'metrics/{i}')
+ await release_manager.release(
+ state, comp.type_signature.result[0], key='result')
+
+
+async def run_client(population_name: str, server_url: str, num_rounds: int,
+ collection_uri: str,
+ examples: list[tf.train.Example]) -> int:
+ """Runs a client and returns its return code."""
+ client_runner = os.path.join(
+ flags.FLAGS.test_srcdir,
+ 'com_google_fcp',
+ 'fcp',
+ 'client',
+ 'client_runner_main')
+
+ example_data = client_runner_example_data_pb2.ClientRunnerExampleData(
+ examples_by_collection_uri={
+ collection_uri:
+ client_runner_example_data_pb2.ClientRunnerExampleData
+ .ExampleList(examples=[e.SerializeToString() for e in examples])
+ })
+
+ # Unfortunately, since there's no convenient way to tell when the server has
+ # actually started serving the computation, we cannot delay starting the
+ # client until the server's ready to assign it a task. This isn't an issue in
+ # a production setting, where there's a steady stream of clients connecting,
+ # but it is a problem in this unit test, where each client only connects to
+ # the server a fixed number of times. To work around this, we give the server
+ # a little extra time to become ready; this delay doesn't significantly slow
+ # down the test since there are many other time-consuming steps.
+ await asyncio.sleep(1)
+
+ with tempfile.NamedTemporaryFile() as tmpfile:
+ tmpfile.write(example_data.SerializeToString())
+ tmpfile.flush()
+ subprocess = await asyncio.create_subprocess_exec(
+ client_runner, f'--population={population_name}',
+ f'--server={server_url}', f'--example_data_path={tmpfile.name}',
+ f'--num_rounds={num_rounds}', '--sleep_after_round_secs=1',
+ '--use_http_federated_compute_protocol', '--use_tflite_training')
+ return await subprocess.wait()
+
+
+def create_examples(counts: list[int]) -> list[tf.train.Example]:
+ """Creates a list of tf.train.Example with the provided 'count' features."""
+ examples = []
+ for count in counts:
+ example = tf.train.Example()
+ example.features.feature['count'].int64_list.value.append(count)
+ examples.append(example)
+ return examples
+
+
+class FederatedProgramTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
+
+ async def test_multiple_rounds(self):
+ data_source = demo.FederatedDataSource(
+ POPULATION_NAME,
+ plan_pb2.ExampleSelector(collection_uri=COLLECTION_URI))
+ comp = demo.FederatedComputation(sum_counts, name='sum_counts')
+ release_manager = tff.program.MemoryReleaseManager()
+ num_rounds = 2
+ client_counts = [
+ [0, 3, 5, 1],
+ [2, 4],
+ ]
+
+ base_context = tff.backends.native.create_sync_local_cpp_execution_context()
+
+ with demo.FederatedContext(
+ POPULATION_NAME,
+ base_context=base_context) as ctx:
+ clients = [
+ run_client(POPULATION_NAME, f'http://localhost:{ctx.server_port}',
+ num_rounds, COLLECTION_URI, create_examples(counts))
+ for counts in client_counts
+ ]
+ with tff.framework.get_context_stack().install(ctx):
+ program = program_logic(initialize, comp, data_source, num_rounds,
+ len(client_counts), release_manager)
+ return_codes = (await asyncio.gather(program, *clients))[1:]
+ # All clients should complete successfully.
+ self.assertListEqual(return_codes, [0] * len(client_counts))
+
+ self.assertSequenceEqual(release_manager.values()['result'],
+ (num_rounds * sum([sum(l) for l in client_counts]),
+ tff.type_at_server(tf.int32)))
+ for i in range(num_rounds):
+ self.assertSequenceEqual(
+ release_manager.values()[f'metrics/{i}'],
+ ((len(client_counts),),
+ tff.type_at_server(tff.StructWithPythonType([tf.int32], tuple))))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/http_actions.py b/fcp/demo/http_actions.py
new file mode 100644
index 0000000..503da70
--- /dev/null
+++ b/fcp/demo/http_actions.py
@@ -0,0 +1,295 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for creating proto service and HTTP action handlers.
+
+The `@proto_action` function annotates a method as implementing a proto service
+method. The annotated method should have the type
+`Callable[[RequestMessage], ResponseMessage]`. The decorator will take care of
+transcoding to/from a HTTP request, similar to
+https://cloud.google.com/endpoints/docs/grpc/transcoding. The transcoding only
+supports proto-over-http ('?alt=proto').
+
+The `@http_action` function annotates a method as implementing a HTTP action at
+some request path. The annotated method will receive the request body, and
+should return a `HttpResponse`.
+
+The `create_handler` function merges one or more objects with decorated methods
+into a single request handler that's compatible with `http.server`.
+"""
+
+import collections
+import dataclasses
+import enum
+import gzip
+import http
+import http.server
+import re
+from typing import Any, Callable, Mapping, Match, Pattern, Type, TypeVar
+import urllib.parse
+import zlib
+
+from absl import logging
+
+from google.api import annotations_pb2
+from google.protobuf import descriptor_pool
+from google.protobuf import message
+from google.protobuf import message_factory
+
+_CallableT = TypeVar('_CallableT', bound=Callable)
+
+_HTTP_ACTION_ATTR = '_http_action_data'
+_FACTORY = message_factory.MessageFactory(descriptor_pool.Default())
+
+
+@dataclasses.dataclass(frozen=True)
+class HttpError(Exception):
+ """An Exception specifying the HTTP error to return."""
+ code: http.HTTPStatus
+
+
+@dataclasses.dataclass(frozen=True)
+class HttpResponse:
+ """Information for a successful HTTP response."""
+ body: bytes
+ headers: Mapping[str, str] = dataclasses.field(default_factory=lambda: {})
+
+
+def proto_action(*,
+ service=str,
+ method=str) -> Callable[[_CallableT], _CallableT]:
+ """Decorator annotating a method as handling a proto service method.
+
+ The `google.api.http` annotation on the method will determine what requests
+ will be handled by the decorated function. Only a subset of methods and path
+ patterns are currently supported.
+
+ The decorated method will be called with the request message; it should return
+ a response message or or throw an `HttpError`.
+
+ Args:
+ service: The full name of the proto service.
+ method: The name of the method.
+
+ Returns:
+ An annotated function.
+ """
+ try:
+ desc = _FACTORY.pool.FindServiceByName(service).FindMethodByName(method)
+ except KeyError as e:
+ raise ValueError(f'Unable to find /{service}.{method}.') from e
+
+ rule = desc.GetOptions().Extensions[annotations_pb2.http]
+ pattern_kind = rule.WhichOneof('pattern')
+ try:
+ http_method = _HttpMethod[pattern_kind.upper()]
+ except KeyError as e:
+ raise ValueError(
+ f'The google.api.http annotation on /{service}.{method} is invalid '
+ 'or unsupported.') from e
+ path = _convert_pattern(getattr(rule, pattern_kind), alt_proto=True)
+
+ def handler(match: Match[str], body: bytes,
+ fn: Callable[[message.Message], message.Message]) -> HttpResponse:
+ request = _FACTORY.GetPrototype(desc.input_type)()
+ if rule.body == '*':
+ try:
+ request.ParseFromString(body)
+ except message.DecodeError as e:
+ raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
+ elif rule.body:
+ setattr(request, rule.body, body)
+ # Set any fields from the request path.
+ for prop, value in match.groupdict().items():
+ try:
+ unescaped = urllib.parse.unquote(value)
+ except UnicodeError as e:
+ raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
+ setattr(request, prop, unescaped)
+
+ response_body = fn(request).SerializeToString()
+ return HttpResponse(
+ body=response_body,
+ headers={
+ 'Content-Length': len(response_body),
+ 'Content-Type': 'application/x-protobuf'
+ })
+
+ def annotate_method(func: _CallableT) -> _CallableT:
+ setattr(func, _HTTP_ACTION_ATTR,
+ _HttpActionData(method=http_method, path=path, handler=handler))
+ return func
+
+ return annotate_method
+
+
+def http_action(*, method: str,
+ pattern: str) -> Callable[[_CallableT], _CallableT]:
+ """Decorator annotating a method as an HTTP action handler.
+
+ Request matching the method and pattern will be handled by the decorated
+ method. The pattern may contain bracket-enclosed keywords (e.g.,
+ '/data/{path}'), which will be matched against the request and passed
+ to the decorated function as keyword arguments.
+
+ The decorated method will be called with the request body (if any) and any
+ keyword args from the path pattern; it should return a `HttpResponse` or throw
+ an `HttpError`.
+
+ Args:
+ method: The type of HTTP method ('GET' or 'POST').
+ pattern: The url pattern to match.
+
+ Returns:
+ An annotated function.
+ """
+ try:
+ http_method = _HttpMethod[method.upper()]
+ except KeyError as e:
+ raise ValueError(f'unsupported HTTP method `{method}`') from e
+ path = _convert_pattern(pattern)
+
+ def handler(match: Match[str], body: bytes,
+ fn: Callable[[bytes], HttpResponse]) -> HttpResponse:
+ try:
+ args = {k: urllib.parse.unquote(v) for k, v in match.groupdict().items()}
+ except UnicodeError as e:
+ raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
+ return fn(body, **args)
+
+ def annotate_method(func: _CallableT) -> _CallableT:
+ setattr(func, _HTTP_ACTION_ATTR,
+ _HttpActionData(method=http_method, path=path, handler=handler))
+ return func
+
+ return annotate_method
+
+
+def create_handler(*services: Any) -> Type[http.server.BaseHTTPRequestHandler]:
+ """Builds a BaseHTTPRequestHandler that delegates to decorated methods.
+
+ The returned BaseHTTPRequestHandler class will route requests to decorated
+ methods of the provided services, or return 404 if the request path does not
+ match any action handlers. If the request path matches multiple registered
+ action handlers, it's unspecified which will be invoked.
+
+ Args:
+ *services: A list of objects with methods decorated with `@proto_action` or
+ `@http_action`.
+
+ Returns:
+ A BaseHTTPRequestHandler subclass.
+ """
+
+ # Collect all handlers, keyed by HTTP method.
+ handlers = collections.defaultdict(lambda: [])
+ for service in services:
+ for attr_name in dir(service):
+ attr = getattr(service, attr_name)
+ if not callable(attr):
+ continue
+ data = getattr(attr, _HTTP_ACTION_ATTR, None)
+ if isinstance(data, _HttpActionData):
+ handlers[data.method].append((data, attr))
+
+ format_handlers = lambda h: ''.join([f'\n * {e[0].path.pattern}' for e in h])
+ logging.debug(
+ 'Creating HTTP request handler for path patterns:\nGET:%s\nPOST:%s',
+ format_handlers(handlers[_HttpMethod.GET]),
+ format_handlers(handlers[_HttpMethod.POST]))
+
+ class RequestHandler(http.server.BaseHTTPRequestHandler):
+ """Handler that delegates to `handlers`."""
+
+ def do_GET(self) -> None: # pylint:disable=invalid-name (override)
+ self._handle_request(_HttpMethod.GET, read_body=False)
+
+ def do_POST(self) -> None: # pylint:disable=invalid-name (override)
+ self._handle_request(_HttpMethod.POST)
+
+ def _handle_request(self,
+ method: _HttpMethod,
+ read_body: bool = True) -> None:
+ """Reads and delegates an incoming request to a registered handler."""
+ for data, fn in handlers[method]:
+ match = data.path.fullmatch(self.path)
+ if match is None:
+ continue
+
+ try:
+ body = self._read_body() if read_body else b''
+ response = data.handler(match, body, fn)
+ except HttpError as e:
+ logging.debug('%s error: %s', self.path, e)
+ return self.send_error(e.code)
+ return self._send_response(response)
+
+ # If no handler matched the path, return an error.
+ self.send_error(http.HTTPStatus.NOT_FOUND)
+
+ def _read_body(self) -> bytes:
+ """Reads the body of the request."""
+ body = self.rfile.read(int(self.headers['Content-Length']))
+ if self.headers['Content-Encoding'] == 'gzip':
+ try:
+ body = gzip.decompress(body)
+ except (gzip.BadGzipFile, zlib.error) as e:
+ raise HttpError(http.HTTPStatus.BAD_REQUEST) from e
+ elif self.headers['Content-Encoding']:
+ logging.warning('Unsupported content encoding %s',
+ self.headers['Content-Encoding'])
+ raise HttpError(http.HTTPStatus.BAD_REQUEST)
+ return body
+
+ def _send_response(self, response: HttpResponse) -> None:
+ """Sends a successful response message."""
+ self.send_response(http.HTTPStatus.OK)
+ for keyword, value in response.headers.items():
+ self.send_header(keyword, value)
+ self.end_headers()
+ self.wfile.write(response.body)
+
+ return RequestHandler
+
+
+class _HttpMethod(enum.Enum):
+ GET = 1
+ POST = 2
+
+
+@dataclasses.dataclass(frozen=True)
+class _HttpActionData:
+ """Data tracked for HTTP actions.
+
+ Attributes:
+ method: The name of the HTTP method to handle.
+ path: Requests matching this pattern will be handled.
+ handler: The handler function, which receives the path match, request body,
+ and decorated function.
+ """
+ method: _HttpMethod
+ path: Pattern[str]
+ handler: Callable[[Match[str], bytes, Callable[..., Any]], HttpResponse]
+
+
+def _convert_pattern(pattern: str, alt_proto=False) -> Pattern[str]:
+ """Converts a Google API pattern to a regexp with named groups."""
+ # Subfields are not supported and will generate a regexp compilation error.
+ pattern_regexp = re.sub(r'\\\{(.+?)\\\}', r'(?P<\1>[^/?]*)',
+ re.escape(pattern))
+ if alt_proto:
+ pattern_regexp += r'\?%24alt=proto'
+ try:
+ return re.compile(pattern_regexp)
+ except re.error as e:
+ raise ValueError(f'unable to convert `{pattern}` to a regexp') from e
diff --git a/fcp/demo/http_actions_test.py b/fcp/demo/http_actions_test.py
new file mode 100644
index 0000000..f74a775
--- /dev/null
+++ b/fcp/demo/http_actions_test.py
@@ -0,0 +1,216 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for http_actions."""
+
+import gzip
+import http
+import http.client
+import http.server
+import socket
+import threading
+from unittest import mock
+
+from absl.testing import absltest
+
+from fcp.demo import http_actions
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+
+class TestService:
+
+ def __init__(self):
+ self.proto_action = mock.Mock()
+ self.get_action = mock.Mock()
+ self.post_action = mock.Mock()
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.EligibilityEvalTasks',
+ method='RequestEligibilityEvalTask')
+ def handle_proto_action(self, *args, **kwargs):
+ return self.proto_action(*args, **kwargs)
+
+ @http_actions.http_action(method='get', pattern='/get/{arg1}/{arg2}')
+ def handle_get_action(self, *args, **kwargs):
+ return self.get_action(*args, **kwargs)
+
+ @http_actions.http_action(method='post', pattern='/post/{arg1}/{arg2}')
+ def handle_post_action(self, *args, **kwargs):
+ return self.post_action(*args, **kwargs)
+
+
+class TestHttpServer(http.server.HTTPServer):
+ pass
+
+
+class HttpActionsTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.service = TestService()
+ handler = http_actions.create_handler(self.service)
+ self._httpd = TestHttpServer(('localhost', 0), handler)
+ self._server_thread = threading.Thread(
+ target=self._httpd.serve_forever, daemon=True)
+ self._server_thread.start()
+ self.conn = http.client.HTTPConnection(
+ self._httpd.server_name, port=self._httpd.server_port)
+
+ def tearDown(self):
+ self._httpd.shutdown()
+ self._server_thread.join()
+ self._httpd.server_close()
+ super().tearDown()
+
+ def test_not_found(self):
+ self.conn.request('GET', '/no-match')
+ self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
+
+ def test_proto_success(self):
+ expected_response = (
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse(
+ session_id='test'))
+ self.service.proto_action.return_value = expected_response
+
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ client_version=common_pb2.ClientVersion(version_code='test123'))
+ self.conn.request(
+ 'POST',
+ '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
+ request.SerializeToString())
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.OK)
+ response_proto = (
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse.FromString(
+ response.read()))
+ self.assertEqual(response_proto, expected_response)
+ # `population_name` should be set from the URL.
+ request.population_name = 'test/population'
+ self.service.proto_action.assert_called_once_with(request)
+
+ def test_proto_error(self):
+ self.service.proto_action.side_effect = http_actions.HttpError(
+ code=http.HTTPStatus.UNAUTHORIZED)
+
+ self.conn.request(
+ 'POST',
+ '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto', b'')
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.UNAUTHORIZED)
+
+ def test_proto_with_invalid_payload(self):
+ self.conn.request(
+ 'POST',
+ '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
+ b'invalid')
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)
+
+ def test_proto_with_gzip_encoding(self):
+ self.service.proto_action.return_value = (
+ eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse())
+
+ request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
+ client_version=common_pb2.ClientVersion(version_code='test123'))
+ self.conn.request('POST',
+ '/v1/eligibilityevaltasks/test:request?%24alt=proto',
+ gzip.compress(request.SerializeToString()),
+ {'Content-Encoding': 'gzip'})
+ self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
+ request.population_name = 'test'
+ self.service.proto_action.assert_called_once_with(request)
+
+ def test_proto_with_invalid_gzip_encoding(self):
+ self.conn.request('POST',
+ '/v1/eligibilityevaltasks/test:request?%24alt=proto',
+ b'invalid', {'Content-Encoding': 'gzip'})
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)
+
+ def test_proto_with_unsupport_encoding(self):
+ self.conn.request('POST',
+ '/v1/eligibilityevaltasks/test:request?%24alt=proto', b'',
+ {'Content-Encoding': 'compress'})
+ self.assertEqual(self.conn.getresponse().status,
+ http.HTTPStatus.BAD_REQUEST)
+ self.service.proto_action.assert_not_called()
+
+ def test_get_success(self):
+ self.service.get_action.return_value = http_actions.HttpResponse(
+ body=b'body',
+ headers={
+ 'Content-Length': 4,
+ 'Content-Type': 'application/x-test',
+ })
+
+ self.conn.request('GET', '/get/foo/bar')
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.OK)
+ self.assertEqual(response.headers['Content-Length'], '4')
+ self.assertEqual(response.headers['Content-Type'], 'application/x-test')
+ self.assertEqual(response.read(), b'body')
+ self.service.get_action.assert_called_once_with(b'', arg1='foo', arg2='bar')
+
+ def test_get_error(self):
+ self.service.get_action.side_effect = http_actions.HttpError(
+ code=http.HTTPStatus.UNAUTHORIZED)
+
+ self.conn.request('GET', '/get/foo/bar')
+ self.assertEqual(self.conn.getresponse().status,
+ http.HTTPStatus.UNAUTHORIZED)
+
+ def test_post_success(self):
+ self.service.post_action.return_value = http_actions.HttpResponse(
+ body=b'body',
+ headers={
+ 'Content-Length': 4,
+ 'Content-Type': 'application/x-test',
+ })
+
+ self.conn.request('POST', '/post/foo/bar', b'request-body')
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, http.HTTPStatus.OK)
+ self.assertEqual(response.headers['Content-Length'], '4')
+ self.assertEqual(response.headers['Content-Type'], 'application/x-test')
+ self.assertEqual(response.read(), b'body')
+ self.service.post_action.assert_called_once_with(
+ b'request-body', arg1='foo', arg2='bar')
+
+ def test_post_error(self):
+ self.service.post_action.side_effect = http_actions.HttpError(
+ code=http.HTTPStatus.UNAUTHORIZED)
+
+ self.conn.request('POST', '/post/foo/bar', b'request-body')
+ self.assertEqual(self.conn.getresponse().status,
+ http.HTTPStatus.UNAUTHORIZED)
+
+ def test_post_with_gzip_encoding(self):
+ self.service.post_action.return_value = http_actions.HttpResponse(body=b'')
+
+ self.conn.request('POST', '/post/foo/bar', gzip.compress(b'request-body'),
+ {'Content-Encoding': 'gzip'})
+ self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
+ self.service.post_action.assert_called_once_with(
+ b'request-body', arg1='foo', arg2='bar')
+
+ def test_post_with_unsupport_encoding(self):
+ self.conn.request('POST', '/post/foo/bar', b'',
+ {'Content-Encoding': 'compress'})
+ self.assertEqual(self.conn.getresponse().status,
+ http.HTTPStatus.BAD_REQUEST)
+ self.service.post_action.assert_not_called()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/media.py b/fcp/demo/media.py
new file mode 100644
index 0000000..e930339
--- /dev/null
+++ b/fcp/demo/media.py
@@ -0,0 +1,135 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Action handlers for file upload and download.
+
+In a production system, download would likely be handled by an external service;
+it's important that uploads are not handled separately to help ensure that
+unaggregated client data is only held ephemerally.
+"""
+
+import contextlib
+import http
+import threading
+from typing import Callable, Iterator, Optional
+import uuid
+
+from fcp.demo import http_actions
+from fcp.protos.federatedcompute import common_pb2
+
+
+class DownloadGroup:
+ """A group of downloadable files."""
+
+ def __init__(self, prefix: str, add_fn: Callable[[str, bytes, str], None]):
+ self._prefix = prefix
+ self._add_fn = add_fn
+
+ @property
+ def prefix(self) -> str:
+ """The path prefix for all files in this group."""
+ return self._prefix
+
+ def add(self,
+ name: str,
+ data: bytes,
+ content_type: str = 'application/octet-stream') -> str:
+ """Adds a file to the group.
+
+ Args:
+ name: The name of the new file.
+ data: The bytes to make available.
+ content_type: The content type to include in the response.
+
+ Returns:
+ The full path to the new file.
+
+ Raises:
+ KeyError if a file with that name has already been registered.
+ """
+ self._add_fn(name, data, content_type)
+ return self._prefix + name
+
+
+class Service:
+ """Implements a service for uploading and downloading data over HTTP."""
+
+ def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo]):
+ self._forwarding_info = forwarding_info
+ self._lock = threading.Lock()
+ self._downloads: dict[str, dict[str, http_actions.HttpResponse]] = {}
+ self._uploads: dict[str, Optional[bytes]] = {}
+
+ @contextlib.contextmanager
+ def create_download_group(self) -> Iterator[DownloadGroup]:
+ """Creates a new group of downloadable files.
+
+ Files can be be added to this group using `DownloadGroup.add`. All files in
+ the group will be unregistered when the ContextManager goes out of scope.
+
+ Yields:
+ The download group to which files should be added.
+ """
+ group = str(uuid.uuid4())
+
+ def add_file(name: str, data: bytes, content_type: str) -> None:
+ with self._lock:
+ if name in self._downloads[group]:
+ raise KeyError(f'{name} already exists')
+ self._downloads[group][name] = http_actions.HttpResponse(
+ body=data,
+ headers={
+ 'Content-Length': len(data),
+ 'Content-Type': content_type,
+ })
+
+ with self._lock:
+ self._downloads[group] = {}
+ try:
+ yield DownloadGroup(
+ f'{self._forwarding_info().target_uri_prefix}data/{group}/', add_file)
+ finally:
+ with self._lock:
+ del self._downloads[group]
+
+ def register_upload(self) -> str:
+ """Registers a path for single-use upload, returning the resource name."""
+ name = str(uuid.uuid4())
+ with self._lock:
+ self._uploads[name] = None
+ return name
+
+ def finalize_upload(self, name: str) -> Optional[bytes]:
+ """Returns the data from an upload, if any."""
+ with self._lock:
+ return self._uploads.pop(name)
+
+ @http_actions.http_action(method='GET', pattern='/data/{group}/{name}')
+ def download(self, body: bytes, group: str,
+ name: str) -> http_actions.HttpResponse:
+ """Handles a download request."""
+ del body
+ try:
+ with self._lock:
+ return self._downloads[group][name]
+ except KeyError as e:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
+
+ @http_actions.http_action(
+ method='POST', pattern='/upload/v1/media/{name}?upload_protocol=raw')
+ def upload(self, body: bytes, name: str) -> http_actions.HttpResponse:
+ with self._lock:
+ if name not in self._uploads or self._uploads[name] is not None:
+ raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED)
+ self._uploads[name] = body
+ return http_actions.HttpResponse(b'')
diff --git a/fcp/demo/media_test.py b/fcp/demo/media_test.py
new file mode 100644
index 0000000..7f8415c
--- /dev/null
+++ b/fcp/demo/media_test.py
@@ -0,0 +1,188 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for media."""
+
+import http
+from unittest import mock
+import uuid
+
+from absl.testing import absltest
+
+from fcp.demo import http_actions
+from fcp.demo import media
+from fcp.protos.federatedcompute import common_pb2
+
+
+class MediaTest(absltest.TestCase):
+
+ @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
+ def test_create_download_group(self, mock_uuid):
+ forwarding_info = common_pb2.ForwardingInfo(
+ target_uri_prefix='https://media.example/')
+ service = media.Service(lambda: forwarding_info)
+ with service.create_download_group() as group:
+ self.assertEqual(group.prefix,
+ f'https://media.example/data/{mock_uuid.return_value}/')
+ name = 'file-name'
+ self.assertEqual(group.add(name, b'data'), group.prefix + name)
+
+ def test_download(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ data = b'data'
+ url = group.add('name', data)
+ self.assertEqual(
+ service.download(b'',
+ *url.split('/')[-2:]),
+ http_actions.HttpResponse(
+ body=data,
+ headers={
+ 'Content-Length': len(data),
+ 'Content-Type': 'application/octet-stream',
+ }))
+
+ def test_download_with_content_type(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ data = b'data'
+ content_type = 'application/x-test'
+ url = group.add('name', data, content_type=content_type)
+ self.assertEqual(
+ service.download(b'',
+ *url.split('/')[-2:]),
+ http_actions.HttpResponse(
+ body=data,
+ headers={
+ 'Content-Length': len(data),
+ 'Content-Type': content_type,
+ }))
+
+ def test_download_multiple_files(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ data1 = b'data1'
+ data2 = b'data2'
+ url1 = group.add('file1', data1)
+ url2 = group.add('file2', data2)
+ self.assertEqual(service.download(b'', *url1.split('/')[-2:]).body, data1)
+ self.assertEqual(service.download(b'', *url2.split('/')[-2:]).body, data2)
+
+ def test_download_multiple_groups(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group1, (
+ service.create_download_group()) as group2:
+ self.assertNotEqual(group1.prefix, group2.prefix)
+ data1 = b'data1'
+ data2 = b'data2'
+ url1 = group1.add('name', data1)
+ url2 = group2.add('name', data2)
+ self.assertEqual(service.download(b'', *url1.split('/')[-2:]).body, data1)
+ self.assertEqual(service.download(b'', *url2.split('/')[-2:]).body, data2)
+
+ def test_download_unregistered_group(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.download(b'', 'does-not-exist', 'does-not-exist')
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_download_unregistered_file(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ url = group.add('name', b'data')
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.download(b'', url.split('/')[-2], 'does-not-exist')
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_download_no_longer_registered(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ url = group.add('name', b'data')
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.download(b'', *url.split('/')[-2:])
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_register_duplicate_download(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with service.create_download_group() as group:
+ data1 = b'data'
+ url = group.add('name', data1)
+ with self.assertRaises(KeyError):
+ group.add('name', b'data2')
+
+ # The original file should still be downloadable.
+ self.assertEqual(service.download(b'', *url.split('/')[-2:]).body, data1)
+
+ def test_upload(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ name = service.register_upload()
+ data = b'data'
+ self.assertEqual(
+ service.upload(data, name), http_actions.HttpResponse(body=b''))
+ self.assertEqual(service.finalize_upload(name), data)
+
+ def test_upload_without_data(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ name = service.register_upload()
+ self.assertIsNone(service.finalize_upload(name))
+
+ def test_upload_multiple_times(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ name = service.register_upload()
+
+ data = b'data1'
+ self.assertEqual(
+ service.upload(data, name), http_actions.HttpResponse(body=b''))
+
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.upload(b'data2', name)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+
+ self.assertEqual(service.finalize_upload(name), data)
+
+ def test_upload_multiple(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ name1 = service.register_upload()
+ name2 = service.register_upload()
+
+ # Order shouldn't matter.
+ service.upload(b'data2', name2)
+ service.upload(b'data1', name1)
+
+ self.assertEqual(service.finalize_upload(name1), b'data1')
+ self.assertEqual(service.finalize_upload(name2), b'data2')
+
+ def test_upload_unregistered(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.upload(b'data', 'does-not-exist')
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+
+ with self.assertRaises(KeyError):
+ service.finalize_upload('does-not-exist')
+
+ def test_upload_no_longer_registered(self):
+ service = media.Service(common_pb2.ForwardingInfo)
+ name = service.register_upload()
+ self.assertIsNone(service.finalize_upload(name))
+
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.upload(b'data', name)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
+
+ with self.assertRaises(KeyError):
+ service.finalize_upload(name)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/plan_utils.py b/fcp/demo/plan_utils.py
new file mode 100644
index 0000000..409d5fa
--- /dev/null
+++ b/fcp/demo/plan_utils.py
@@ -0,0 +1,203 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for working with Plan protos and TensorFlow.
+
+See the field comments in plan.proto for more information about each operation
+and when it should be run.
+"""
+
+import functools
+import tempfile
+from typing import Any, Optional
+import uuid
+
+import tensorflow as tf
+
+from google.protobuf import message
+from fcp.protos import plan_pb2
+from fcp.tensorflow import serve_slices as serve_slices_registry
+
+
+class Session:
+ """A session for performing L2 Plan operations.
+
+ This class only supports loading a single intermediate update.
+ """
+
+ def __init__(self, plan: plan_pb2.Plan, checkpoint: bytes):
+ if len(plan.phase) != 1:
+ raise ValueError('plan must contain exactly 1 phase.')
+ if not plan.phase[0].HasField('server_phase'):
+ raise ValueError('plan.phase[0] is missing server_phase.')
+
+ graph_def = tf.compat.v1.GraphDef()
+ try:
+ plan.server_graph_bytes.Unpack(graph_def)
+ except message.DecodeError as e:
+ raise ValueError('Unable to parse server graph.') from e
+
+ graph = tf.Graph()
+ with graph.as_default():
+ tf.import_graph_def(graph_def, name='')
+ self._session = tf.compat.v1.Session(graph=graph)
+ self._plan = plan
+ self._restore_state(plan.server_savepoint, checkpoint)
+ self._maybe_run(plan.phase[0].server_phase.phase_init_op)
+
+ serve_slices_calls = []
+
+ def record_serve_slices_call(*args):
+ served_at_id = str(uuid.uuid4())
+ serve_slices_calls.append((served_at_id, args))
+ return served_at_id
+
+ with serve_slices_registry.register_serve_slices_callback(
+ record_serve_slices_call
+ ) as token:
+ self._client_checkpoint = self._save_state(
+ plan.phase[0].server_phase.write_client_init, session_token=token
+ )
+ self._slices = {
+ k: self._build_slices(*args) for k, args in serve_slices_calls
+ }
+
+ def __enter__(self) -> 'Session':
+ self._session.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb) -> None:
+ self._session.__exit__(exc_type, exc_value, tb)
+
+ def close(self) -> None:
+ """Closes the session, releasing resources."""
+ self._session.close()
+
+ def _maybe_run(
+ self, op: str, feed_dict: Optional[dict[str, Any]] = None
+ ) -> None:
+ """Runs an operation if it's non-empty."""
+ if op:
+ self._session.run(op, feed_dict=feed_dict)
+
+ def _restore_state(self, checkpoint_op: plan_pb2.CheckpointOp,
+ checkpoint: bytes) -> None:
+ """Restores state from a TensorFlow checkpoint."""
+ self._maybe_run(checkpoint_op.before_restore_op)
+ if checkpoint_op.HasField('saver_def'):
+ with tempfile.NamedTemporaryFile('wb') as tmpfile:
+ tmpfile.write(checkpoint)
+ tmpfile.flush()
+ self._session.run(
+ checkpoint_op.saver_def.restore_op_name,
+ {checkpoint_op.saver_def.filename_tensor_name: tmpfile.name})
+ self._maybe_run(checkpoint_op.after_restore_op)
+
+ def _save_state(
+ self,
+ checkpoint_op: plan_pb2.CheckpointOp,
+ session_token: Optional[bytes] = None,
+ ) -> bytes:
+ """Saves state to a TensorFlow checkpoint."""
+ before_and_after_inputs = {}
+ if session_token and checkpoint_op.session_token_tensor_name:
+ before_and_after_inputs[checkpoint_op.session_token_tensor_name] = (
+ session_token
+ )
+
+ self._maybe_run(
+ checkpoint_op.before_save_op, feed_dict=before_and_after_inputs
+ )
+ result = b''
+ if checkpoint_op.HasField('saver_def'):
+ with tempfile.NamedTemporaryFile() as tmpfile:
+ save_tensor_inputs = before_and_after_inputs.copy()
+ save_tensor_inputs[checkpoint_op.saver_def.filename_tensor_name] = (
+ tmpfile.name
+ )
+ self._session.run(
+ checkpoint_op.saver_def.save_tensor_name,
+ feed_dict=save_tensor_inputs,
+ )
+ # TensorFlow overwrites (via move) the output file, so the data can't be
+ # read from the filehandle. Deletion still works properly, though.
+ with open(tmpfile.name, 'rb') as f:
+ result = f.read()
+ self._maybe_run(
+ checkpoint_op.after_save_op, feed_dict=before_and_after_inputs
+ )
+ return result
+
+ def _build_slices(
+ self,
+ callback_token: bytes,
+ server_val: list[Any],
+ max_key: int,
+ select_fn_initialize_op: str,
+ select_fn_server_val_input_tensor_names: list[str],
+ select_fn_key_input_tensor_name: str,
+ select_fn_filename_input_tensor_name: str,
+ select_fn_target_tensor_name: str,
+ ):
+ """Builds the slices for a ServeSlices call."""
+ del callback_token
+ slices: list[bytes] = []
+ for i in range(0, max_key + 1):
+ self._maybe_run(select_fn_initialize_op)
+ with tempfile.NamedTemporaryFile() as tmpfile:
+ feed_dict = dict(
+ zip(select_fn_server_val_input_tensor_names, server_val)
+ )
+ feed_dict[select_fn_key_input_tensor_name] = i
+ feed_dict[select_fn_filename_input_tensor_name] = tmpfile.name
+ self._session.run(select_fn_target_tensor_name, feed_dict=feed_dict)
+ # TensorFlow overwrites (via move) the output file, so the data can't be
+ # read from the filehandle. Deletion still works properly, though.
+ with open(tmpfile.name, 'rb') as f:
+ slices.append(f.read())
+ return slices
+
+ @functools.cached_property
+ def client_plan(self) -> bytes:
+ """The serialized ClientOnlyPlan corresponding to the Plan proto."""
+ client_only_plan = plan_pb2.ClientOnlyPlan(
+ phase=self._plan.phase[0].client_phase,
+ graph=self._plan.client_graph_bytes.value,
+ tflite_graph=self._plan.client_tflite_graph_bytes)
+ if self._plan.HasField('tensorflow_config_proto'):
+ client_only_plan.tensorflow_config_proto.CopyFrom(
+ self._plan.tensorflow_config_proto)
+ return client_only_plan.SerializeToString()
+
+ @property
+ def client_checkpoint(self) -> bytes:
+ """The initial checkpoint for use by clients."""
+ return self._client_checkpoint
+
+ def finalize(self, update: bytes) -> bytes:
+ """Loads an intermediate update and return the final result."""
+ self._restore_state(
+ self._plan.phase[0].server_phase.read_intermediate_update, update)
+ self._maybe_run(self._plan.phase[0].server_phase
+ .intermediate_aggregate_into_accumulators_op)
+ # write_accumulators and metrics are not needed by Federated Program
+ # computations because all results are included in the server savepoint.
+ self._maybe_run(
+ self._plan.phase[0].server_phase.apply_aggregrated_updates_op)
+ return self._save_state(self._plan.server_savepoint)
+
+ @property
+ def slices(self) -> dict[str, list[bytes]]:
+ """The Federated Select slices, keyed by served_at_id."""
+ # Return a copy to prevent mutations.
+ return self._slices.copy()
diff --git a/fcp/demo/plan_utils_test.py b/fcp/demo/plan_utils_test.py
new file mode 100644
index 0000000..a3a83da
--- /dev/null
+++ b/fcp/demo/plan_utils_test.py
@@ -0,0 +1,350 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for plan_utils."""
+
+import functools
+import tempfile
+from typing import Any, Optional
+
+from absl.testing import absltest
+import tensorflow as tf
+
+from fcp.demo import plan_utils
+from fcp.demo import test_utils
+from fcp.protos import plan_pb2
+from fcp.tensorflow import serve_slices
+
+DEFAULT_INITIAL_CHECKPOINT = b'initial'
+CHECKPOINT_TENSOR_NAME = 'checkpoint'
+INTERMEDIATE_TENSOR_NAME = 'intermediate_value'
+FINAL_TENSOR_NAME = 'final_value'
+NUM_SLICES = 3
+
+
+def create_plan(log_file: Optional[str] = None) -> plan_pb2.Plan:
+ """Creates a test Plan that sums inputs."""
+
+ def log_op(name: str) -> tf.Operation:
+ """Helper function to log op invocations to a file."""
+ if log_file:
+ return tf.print(name, output_stream=f'file://{log_file}')
+ return tf.raw_ops.NoOp()
+
+ def create_checkpoint_op(
+ name: str,
+ filename_op: Any,
+ save_op: Any = None,
+ restore_op: Any = None,
+ session_token_tensor_name: Optional[str] = None,
+ ) -> plan_pb2.CheckpointOp:
+ before_restore = log_op(f'{name}/before_restore')
+ after_restore = log_op(f'{name}/after_restore')
+ before_save = log_op(f'{name}/before_save')
+ after_save = log_op(f'{name}/after_save')
+ with tf.control_dependencies(
+ [save_op if save_op is not None else tf.raw_ops.NoOp()]):
+ save_op = log_op(f'{name}/save')
+ with tf.control_dependencies(
+ [restore_op if restore_op is not None else tf.raw_ops.NoOp()]):
+ restore_op = log_op(f'{name}/restore')
+ return plan_pb2.CheckpointOp(
+ saver_def=tf.compat.v1.train.SaverDef(
+ filename_tensor_name=filename_op.name,
+ restore_op_name=restore_op.name,
+ save_tensor_name=save_op.name,
+ version=tf.compat.v1.train.SaverDef.V1,
+ ),
+ before_restore_op=before_restore.name,
+ after_restore_op=after_restore.name,
+ before_save_op=before_save.name,
+ after_save_op=after_save.name,
+ session_token_tensor_name=session_token_tensor_name,
+ )
+
+ with tf.compat.v1.Graph().as_default() as client_graph:
+ tf.constant(0)
+
+ with tf.compat.v1.Graph().as_default() as server_graph:
+ # Initialization:
+ last_client_update = tf.Variable(0, dtype=tf.int32)
+ intermediate_acc = tf.Variable(0, dtype=tf.int32)
+ last_intermediate_update = tf.Variable(0, dtype=tf.int32)
+ final_acc = tf.Variable(0, dtype=tf.int32)
+ with tf.control_dependencies([
+ last_client_update.initializer, intermediate_acc.initializer,
+ last_intermediate_update.initializer, final_acc.initializer
+ ]):
+ phase_init_op = log_op('phase_init')
+
+ # Ops for Federated Select:
+ select_fn_initialize_op = log_op('slices/initialize')
+ select_fn_server_vals = [
+ tf.constant(1234),
+ tf.constant('asdf'),
+ tf.constant([1, 2, 3]),
+ ]
+ select_fn_server_val_inputs = [
+ tf.compat.v1.placeholder(v.dtype) for v in select_fn_server_vals
+ ]
+ select_fn_key_input = tf.compat.v1.placeholder(tf.int32, shape=())
+ select_fn_filename_input = tf.compat.v1.placeholder(tf.string, shape=())
+ assertions = [
+ tf.debugging.assert_equal(placeholder, constant)
+ for placeholder, constant in zip(
+ select_fn_server_val_inputs, select_fn_server_vals
+ )
+ ]
+ with tf.control_dependencies([log_op('slices/save_slice')] + assertions):
+ select_fn_save_op = tf.io.write_file(
+ select_fn_filename_input, tf.strings.as_string(select_fn_key_input)
+ )
+ # Some tests disable passing the callback token; set `served_at_id` to '-'
+ # in that case.
+ callback_token = tf.compat.v1.placeholder_with_default('', shape=())
+ served_at_id = tf.cond(
+ tf.equal(callback_token, ''),
+ lambda: '-',
+ functools.partial(
+ serve_slices.serve_slices,
+ callback_token=callback_token,
+ server_val=select_fn_server_vals,
+ max_key=NUM_SLICES - 1,
+ select_fn_initialize_op=select_fn_initialize_op.name,
+ select_fn_server_val_input_tensor_names=[
+ v.name for v in select_fn_server_val_inputs
+ ],
+ select_fn_key_input_tensor_name=select_fn_key_input.name,
+ select_fn_filename_input_tensor_name=select_fn_filename_input.name,
+ select_fn_target_tensor_name=select_fn_save_op.name,
+ ),
+ )
+
+ # Ops for L2 Aggregation:
+ client_checkpoint_data = tf.Variable(
+ DEFAULT_INITIAL_CHECKPOINT, dtype=tf.string)
+
+ write_client_init_filename = tf.compat.v1.placeholder(tf.string, shape=())
+ client_checkpoint_data_value = tf.cond(
+ tf.compat.v1.is_variable_initialized(client_checkpoint_data),
+ client_checkpoint_data.read_value,
+ lambda: client_checkpoint_data.initial_value,
+ )
+ write_client_init_op = create_checkpoint_op(
+ 'write_client_init',
+ write_client_init_filename,
+ save_op=tf.io.write_file(
+ write_client_init_filename,
+ tf.strings.join(
+ [client_checkpoint_data_value, served_at_id], separator=' '
+ ),
+ ),
+ session_token_tensor_name=callback_token.name,
+ )
+
+ read_intermediate_update_filename = tf.compat.v1.placeholder(
+ tf.string, shape=())
+ read_intermediate_update_op = create_checkpoint_op(
+ 'read_intermediate_update',
+ read_intermediate_update_filename,
+ restore_op=last_intermediate_update.assign(
+ tf.raw_ops.Restore(
+ file_pattern=read_intermediate_update_filename,
+ tensor_name=INTERMEDIATE_TENSOR_NAME,
+ dt=tf.int32)))
+
+ with tf.control_dependencies([log_op('apply_aggregated_updates')]):
+ apply_aggregated_updates_op = final_acc.assign_add(
+ last_intermediate_update)
+
+ server_savepoint_filename = tf.compat.v1.placeholder(tf.string, shape=())
+ server_savepoint_op = create_checkpoint_op(
+ 'server_savepoint',
+ server_savepoint_filename,
+ save_op=tf.raw_ops.Save(
+ filename=server_savepoint_filename,
+ tensor_names=[FINAL_TENSOR_NAME],
+ data=[final_acc]),
+ restore_op=client_checkpoint_data.assign(
+ tf.raw_ops.Restore(
+ file_pattern=server_savepoint_filename,
+ tensor_name=CHECKPOINT_TENSOR_NAME,
+ dt=tf.string)))
+
+ config_proto = tf.compat.v1.ConfigProto(operation_timeout_in_ms=1234)
+
+ plan = plan_pb2.Plan(
+ phase=[
+ plan_pb2.Plan.Phase(
+ client_phase=plan_pb2.ClientPhase(name='ClientPhase'),
+ server_phase=plan_pb2.ServerPhase(
+ phase_init_op=phase_init_op.name,
+ write_client_init=write_client_init_op,
+ read_intermediate_update=read_intermediate_update_op,
+ apply_aggregrated_updates_op=(
+ apply_aggregated_updates_op.name)))
+ ],
+ server_savepoint=server_savepoint_op,
+ client_tflite_graph_bytes=b'tflite-graph',
+ version=1)
+ plan.client_graph_bytes.Pack(client_graph.as_graph_def())
+ plan.server_graph_bytes.Pack(server_graph.as_graph_def())
+ plan.tensorflow_config_proto.Pack(config_proto)
+ return plan
+
+
+def create_checkpoint(tensor_name=b'test'):
+ """Creates a test initial checkpoint."""
+ return test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: tensor_name})
+
+
+class PlanUtilsTest(absltest.TestCase):
+
+ def test_session_enter_exit(self):
+ self.assertIsNone(tf.compat.v1.get_default_session())
+ with plan_utils.Session(create_plan(), create_checkpoint()):
+ self.assertIsNotNone(tf.compat.v1.get_default_session())
+ self.assertIsNone(tf.compat.v1.get_default_session())
+
+ def test_session_without_phase(self):
+ plan = create_plan()
+ plan.ClearField('phase')
+ with self.assertRaises(ValueError):
+ plan_utils.Session(plan, create_checkpoint())
+
+ def test_session_without_server_phase(self):
+ plan = create_plan()
+ plan.phase[0].ClearField('server_phase')
+ with self.assertRaises(ValueError):
+ plan_utils.Session(plan, create_checkpoint())
+
+ def test_session_with_multiple_phases(self):
+ plan = create_plan()
+ plan.phase.append(plan.phase[0])
+ with self.assertRaises(ValueError):
+ plan_utils.Session(plan, create_checkpoint())
+
+ def test_session_client_plan(self):
+ plan = create_plan()
+ with plan_utils.Session(plan, create_checkpoint()) as session:
+ self.assertEqual(
+ plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
+ plan_pb2.ClientOnlyPlan(
+ phase=plan.phase[0].client_phase,
+ graph=plan.client_graph_bytes.value,
+ tflite_graph=plan.client_tflite_graph_bytes,
+ tensorflow_config_proto=plan.tensorflow_config_proto))
+
+ def test_session_client_plan_without_tensorflow_config(self):
+ plan = create_plan()
+ plan.ClearField('tensorflow_config_proto')
+ with plan_utils.Session(plan, create_checkpoint()) as session:
+ self.assertEqual(
+ plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
+ plan_pb2.ClientOnlyPlan(
+ phase=plan.phase[0].client_phase,
+ graph=plan.client_graph_bytes.value,
+ tflite_graph=plan.client_tflite_graph_bytes))
+
+ def test_session_client_plan_without_tflite_graph(self):
+ plan = create_plan()
+ plan.ClearField('client_tflite_graph_bytes')
+ with plan_utils.Session(plan, create_checkpoint()) as session:
+ self.assertEqual(
+ plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
+ plan_pb2.ClientOnlyPlan(
+ phase=plan.phase[0].client_phase,
+ graph=plan.client_graph_bytes.value,
+ tensorflow_config_proto=plan.tensorflow_config_proto))
+
+ def test_session_client_checkpoint(self):
+ expected = b'test-client-checkpoint'
+ with plan_utils.Session(
+ create_plan(),
+ test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: expected
+ })) as session:
+ self.assertEqual(
+ session.client_checkpoint,
+ expected + b' ' + next(iter(session.slices)).encode(),
+ )
+
+ def test_session_client_checkpoint_without_server_savepoint(self):
+ plan = create_plan()
+ # If server_savepoint isn't set, the checkpoint shouldn't be loaded.
+ plan.ClearField('server_savepoint')
+ with plan_utils.Session(plan, create_checkpoint()) as session:
+ self.assertStartsWith(
+ session.client_checkpoint, DEFAULT_INITIAL_CHECKPOINT + b' '
+ )
+
+ def test_session_finalize(self):
+ with tempfile.NamedTemporaryFile('r') as tmpfile:
+ with plan_utils.Session(create_plan(tmpfile.name),
+ create_checkpoint()) as session:
+ checkpoint = session.finalize(
+ test_utils.create_checkpoint({INTERMEDIATE_TENSOR_NAME: 3}))
+ self.assertSequenceEqual(
+ tmpfile.read().splitlines(),
+ [
+ 'server_savepoint/before_restore',
+ 'server_savepoint/restore',
+ 'server_savepoint/after_restore',
+ 'phase_init',
+ 'write_client_init/before_save',
+ 'write_client_init/save',
+ 'write_client_init/after_save',
+ ]
+ + ['slices/initialize', 'slices/save_slice'] * NUM_SLICES
+ + [
+ 'read_intermediate_update/before_restore',
+ 'read_intermediate_update/restore',
+ 'read_intermediate_update/after_restore',
+ 'apply_aggregated_updates',
+ 'server_savepoint/before_save',
+ 'server_savepoint/save',
+ 'server_savepoint/after_save',
+ ],
+ )
+
+ result = test_utils.read_tensor_from_checkpoint(checkpoint,
+ FINAL_TENSOR_NAME, tf.int32)
+ # The value should be propagated from the intermediate aggregate.
+ self.assertEqual(result, 3)
+
+ def test_session_with_tensorflow_error(self):
+ plan = create_plan()
+ plan.phase[0].server_phase.phase_init_op = 'does-not-exist'
+ with self.assertRaises(ValueError):
+ plan_utils.Session(plan, create_checkpoint())
+
+ def test_session_slices(self):
+ with plan_utils.Session(create_plan(), create_checkpoint()) as session:
+ # The served_at_id should match the value in the client checkpoint.
+ served_at_id = session.client_checkpoint.split(b' ')[1].decode()
+ self.assertSameElements(session.slices.keys(), [served_at_id])
+ self.assertListEqual(
+ session.slices[served_at_id],
+ [str(i).encode() for i in range(NUM_SLICES)],
+ )
+
+ def test_session_without_slices(self):
+ plan = create_plan()
+ plan.phase[0].server_phase.write_client_init.ClearField(
+ 'session_token_tensor_name'
+ )
+ with plan_utils.Session(plan, create_checkpoint()) as session:
+ self.assertEmpty(session.slices)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/server.py b/fcp/demo/server.py
new file mode 100644
index 0000000..dcd9627
--- /dev/null
+++ b/fcp/demo/server.py
@@ -0,0 +1,164 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An in-process federated compute server."""
+
+import contextlib
+import gzip
+import http.server
+import socket
+import socketserver
+import ssl
+from typing import Optional
+
+from absl import logging
+
+from fcp.demo import aggregations
+from fcp.demo import eligibility_eval_tasks
+from fcp.demo import http_actions
+from fcp.demo import media
+from fcp.demo import plan_utils
+from fcp.demo import task_assignments
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+# Template for file name for federated select slices. See
+# `FederatedSelectUriInfo.uri_template` for the meaning of the "{served_at_id}"
+# and "{key_base10}" substrings.
+_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'
+
+# Content type used for serialized and compressed Plan messages.
+_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'
+
+# Content type used for serialzied and compressed TensorFlow checkpoints.
+_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'
+
+
+class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
+ """An in-process HTTP server implementing the Federated Compute protocol."""
+
+ def __init__(self,
+ *,
+ population_name: str,
+ host: str,
+ port: int,
+ address_family: Optional[socket.AddressFamily] = None):
+ self._media_service = media.Service(self._get_forwarding_info)
+ self._aggregations_service = aggregations.Service(self._get_forwarding_info,
+ self._media_service)
+ self._task_assignments_service = task_assignments.Service(
+ population_name, self._get_forwarding_info, self._aggregations_service)
+ self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
+ population_name, self._get_forwarding_info
+ )
+ handler = http_actions.create_handler(
+ self._media_service,
+ self._aggregations_service,
+ self._task_assignments_service,
+ self._eligibility_eval_tasks_service,
+ )
+ if address_family is not None:
+ self.address_family = address_family
+ http.server.HTTPServer.__init__(self, (host, port), handler)
+
+ async def run_computation(
+ self,
+ task_name: str,
+ plan: plan_pb2.Plan,
+ server_checkpoint: bytes,
+ task_assignment_mode: _TaskAssignmentMode,
+ number_of_clients: int,
+ ) -> bytes:
+ """Runs a computation, returning the resulting checkpoint.
+
+ If there's already a computation in progress, the new computation will
+ not start until the previous one has completed (either successfully or not).
+
+ Args:
+ task_name: The name of the task.
+ plan: The Plan proto containing the client and server computations.
+ server_checkpoint: The starting server checkpoint.
+ task_assignment_mode: The task assignment mode to use for the computation.
+ number_of_clients: The minimum number of clients to include.
+
+ Returns:
+ A TensorFlow checkpoint containing the aggregated results.
+ """
+ requirements = aggregations.AggregationRequirements(
+ minimum_clients_in_server_published_aggregate=number_of_clients,
+ plan=plan)
+ session_id = self._aggregations_service.create_session(requirements)
+ with contextlib.ExitStack() as stack:
+ stack.callback(
+ lambda: self._aggregations_service.abort_session(session_id))
+ with plan_utils.Session(plan, server_checkpoint) as session:
+ with self._media_service.create_download_group() as group:
+ plan_url = group.add(
+ 'plan',
+ gzip.compress(session.client_plan),
+ content_type=_PLAN_CONTENT_TYPE,
+ )
+ checkpoint_url = group.add(
+ 'checkpoint',
+ gzip.compress(session.client_checkpoint),
+ content_type=_CHECKPOINT_CONTENT_TYPE,
+ )
+ for served_at_id, slices in session.slices.items():
+ for i, slice_data in enumerate(slices):
+ group.add(
+ _FEDERATED_SELECT_NAME_TEMPLATE.format(
+ served_at_id=served_at_id, key_base10=str(i)
+ ),
+ gzip.compress(slice_data),
+ content_type=_CHECKPOINT_CONTENT_TYPE,
+ )
+ self._eligibility_eval_tasks_service.add_task(
+ task_name, task_assignment_mode
+ )
+ self._task_assignments_service.add_task(
+ task_name,
+ task_assignment_mode,
+ session_id,
+ common_pb2.Resource(uri=plan_url),
+ common_pb2.Resource(uri=checkpoint_url),
+ group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
+ )
+ try:
+ status = await self._aggregations_service.wait(
+ session_id,
+ num_inputs_aggregated_and_included=number_of_clients)
+ if status.status != aggregations.AggregationStatus.PENDING:
+ raise ValueError('Aggregation failed.')
+ finally:
+ self._task_assignments_service.remove_task(session_id)
+ self._eligibility_eval_tasks_service.remove_task(task_name)
+
+ stack.pop_all()
+ status, intermedia_update = (
+ self._aggregations_service.complete_session(session_id))
+ if (status.status != aggregations.AggregationStatus.COMPLETED or
+ intermedia_update is None):
+ raise ValueError('Aggregation failed.')
+ logging.debug('%s aggregation complete: %s', task_name, status)
+ return session.finalize(intermedia_update)
+
+ def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
+ protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
+ return common_pb2.ForwardingInfo(
+ target_uri_prefix=(
+ f'{protocol}://{self.server_name}:{self.server_port}/'))
diff --git a/fcp/demo/server_test.py b/fcp/demo/server_test.py
new file mode 100644
index 0000000..b268f7e
--- /dev/null
+++ b/fcp/demo/server_test.py
@@ -0,0 +1,284 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for server."""
+
+import asyncio
+import gzip
+import http
+import http.client
+import os
+import threading
+import unittest
+from unittest import mock
+import urllib.parse
+import urllib.request
+
+from absl import flags
+from absl import logging
+from absl.testing import absltest
+import tensorflow as tf
+
+from google.longrunning import operations_pb2
+from fcp.demo import plan_utils
+from fcp.demo import server
+from fcp.demo import test_utils
+from fcp.protos import plan_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+from fcp.protos.federatedcompute import task_assignments_pb2
+from fcp.tensorflow import external_dataset
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+POPULATION_NAME = 'test/population'
+CAP_TENSOR_NAME = 'cap'
+COUNT_TENSOR_NAME = 'count'
+TEST_SLICES = {
+ 'id1': [b'1-1', b'1-2', b'1-3'],
+ 'id2': [b'2-1', b'2-2'],
+}
+
+
+def create_plan() -> plan_pb2.Plan:
+ """Creates a test plan that counts examples, with a per-client cap."""
+
+ with tf.compat.v1.Graph().as_default() as client_graph:
+ dataset_token = tf.compat.v1.placeholder(tf.string, shape=())
+ input_filepath = tf.compat.v1.placeholder(tf.string, shape=())
+ output_filepath = tf.compat.v1.placeholder(tf.string, shape=())
+ ds = external_dataset.ExternalDataset(token=dataset_token, selector=b'')
+ cap = tf.raw_ops.Restore(
+ file_pattern=input_filepath, tensor_name=CAP_TENSOR_NAME, dt=tf.int32)
+ count = ds.take(tf.cast(cap, dtype=tf.int64)).reduce(0, lambda x, _: x + 1)
+ target_node = tf.raw_ops.Save(
+ filename=output_filepath,
+ tensor_names=[COUNT_TENSOR_NAME],
+ data=[count])
+
+ with tf.compat.v1.Graph().as_default() as server_graph:
+ filename = tf.compat.v1.placeholder(tf.string, shape=())
+ contribution_cap = tf.Variable(0, dtype=tf.int32)
+ count = tf.Variable(0, dtype=tf.int32)
+ load_initial_count = count.assign(
+ tf.raw_ops.Restore(
+ file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32),
+ read_value=False)
+ load_contribution_cap = contribution_cap.assign(
+ tf.raw_ops.Restore(
+ file_pattern=filename, tensor_name=CAP_TENSOR_NAME, dt=tf.int32),
+ read_value=False)
+ with tf.control_dependencies([load_initial_count, load_contribution_cap]):
+ restore_server_savepoint = tf.no_op()
+ write_client_init = tf.raw_ops.Save(
+ filename=filename,
+ tensor_names=[CAP_TENSOR_NAME],
+ data=[contribution_cap])
+
+ read_intermediate_update = count.assign_add(
+ tf.raw_ops.Restore(
+ file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32))
+ save_count = tf.raw_ops.Save(
+ filename=filename, tensor_names=[COUNT_TENSOR_NAME], data=[count])
+
+ plan = plan_pb2.Plan(
+ phase=[
+ plan_pb2.Plan.Phase(
+ client_phase=plan_pb2.ClientPhase(
+ tensorflow_spec=plan_pb2.TensorflowSpec(
+ dataset_token_tensor_name=dataset_token.op.name,
+ input_tensor_specs=[
+ tf.TensorSpec.from_tensor(
+ input_filepath).experimental_as_proto(),
+ tf.TensorSpec.from_tensor(
+ output_filepath).experimental_as_proto(),
+ ],
+ target_node_names=[target_node.name]),
+ federated_compute=plan_pb2.FederatedComputeIORouter(
+ input_filepath_tensor_name=input_filepath.op.name,
+ output_filepath_tensor_name=output_filepath.op.name)),
+ server_phase=plan_pb2.ServerPhase(
+ write_client_init=plan_pb2.CheckpointOp(
+ saver_def=tf.compat.v1.train.SaverDef(
+ filename_tensor_name=filename.name,
+ save_tensor_name=write_client_init.name)),
+ read_intermediate_update=plan_pb2.CheckpointOp(
+ saver_def=tf.compat.v1.train.SaverDef(
+ filename_tensor_name=filename.name,
+ restore_op_name=read_intermediate_update.name))),
+ server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
+ plan_pb2.ServerAggregationConfig(
+ intrinsic_uri='federated_sum',
+ intrinsic_args=[
+ plan_pb2.ServerAggregationConfig.IntrinsicArg(
+ input_tensor=tf.TensorSpec(
+ (), tf.int32,
+ COUNT_TENSOR_NAME).experimental_as_proto())
+ ],
+ output_tensors=[
+ tf.TensorSpec((), tf.int32, COUNT_TENSOR_NAME)
+ .experimental_as_proto()
+ ])
+ ]))
+ ],
+ server_savepoint=plan_pb2.CheckpointOp(
+ saver_def=tf.compat.v1.train.SaverDef(
+ filename_tensor_name=filename.name,
+ save_tensor_name=save_count.name,
+ restore_op_name=restore_server_savepoint.name)),
+ version=1)
+ plan.client_graph_bytes.Pack(client_graph.as_graph_def())
+ plan.server_graph_bytes.Pack(server_graph.as_graph_def())
+ return plan
+
+
+class ServerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.server = server.InProcessServer( # pytype: disable=wrong-arg-types
+ population_name=POPULATION_NAME,
+ host='localhost',
+ port=0)
+ self._server_thread = threading.Thread(target=self.server.serve_forever)
+ self._server_thread.start()
+ self.conn = http.client.HTTPConnection(
+ self.server.server_name, port=self.server.server_port)
+
+ def tearDown(self):
+ self.server.shutdown()
+ self._server_thread.join()
+ self.server.server_close()
+ super().tearDown()
+
+ async def wait_for_task(self) -> task_assignments_pb2.TaskAssignment:
+ """Polls the server until a task is being served."""
+ pop = urllib.parse.quote(POPULATION_NAME, safe='')
+ url = f'/v1/populations/{pop}/taskassignments/test:start?%24alt=proto'
+ request = task_assignments_pb2.StartTaskAssignmentRequest()
+ while True:
+ self.conn.request('POST', url, request.SerializeToString())
+ http_response = self.conn.getresponse()
+ if http_response.status == http.HTTPStatus.OK:
+ op = operations_pb2.Operation.FromString(http_response.read())
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ op.response.Unpack(response)
+ if response.HasField('task_assignment'):
+ logging.info('wait_for_task received assignment to %s',
+ response.task_assignment.task_name)
+ return response.task_assignment
+ await asyncio.sleep(0.5)
+
+ async def test_run_computation(self):
+ initial_count = 100
+ cap = 10
+ examples_per_client = [1, 5, 15]
+ checkpoint = test_utils.create_checkpoint({
+ CAP_TENSOR_NAME: cap,
+ COUNT_TENSOR_NAME: initial_count,
+ })
+ run_computation_task = asyncio.create_task(
+ self.server.run_computation(
+ 'task/name',
+ create_plan(),
+ checkpoint,
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ len(examples_per_client),
+ )
+ )
+
+ # Wait for task assignment to return a task.
+ wait_task = asyncio.create_task(self.wait_for_task())
+ await asyncio.wait([run_computation_task, wait_task],
+ timeout=10,
+ return_when=asyncio.FIRST_COMPLETED)
+ self.assertTrue(wait_task.done())
+ # `run_computation` should not be done since no clients have reported.
+ self.assertFalse(run_computation_task.done())
+
+ client_runner = os.path.join(
+ flags.FLAGS.test_srcdir,
+ 'com_google_fcp',
+ 'fcp',
+ 'client',
+ 'client_runner_main')
+ server_url = f'http://{self.server.server_name}:{self.server.server_port}/'
+ clients = []
+ for num_examples in examples_per_client:
+ subprocess = asyncio.create_subprocess_exec(
+ client_runner, f'--server={server_url}',
+ f'--population={POPULATION_NAME}',
+ f'--num_empty_examples={num_examples}', '--sleep_after_round_secs=0',
+ '--use_http_federated_compute_protocol')
+ clients.append(asyncio.create_task((await subprocess).wait()))
+
+ # Wait for the computation to complete.
+ await asyncio.wait([run_computation_task] + clients, timeout=10)
+ self.assertTrue(run_computation_task.done())
+ for client in clients:
+ self.assertTrue(client.done())
+ self.assertEqual(client.result(), 0)
+
+ # Verify the sum in the checkpoint.
+ result = test_utils.read_tensor_from_checkpoint(
+ run_computation_task.result(), COUNT_TENSOR_NAME, tf.int32)
+ self.assertEqual(
+ result, initial_count + sum([min(n, cap) for n in examples_per_client]))
+
+ @mock.patch.object(
+ plan_utils.Session,
+ 'slices',
+ new=property(lambda unused_self: TEST_SLICES),
+ )
+ async def test_federated_select(self):
+ checkpoint = test_utils.create_checkpoint({
+ CAP_TENSOR_NAME: 100,
+ COUNT_TENSOR_NAME: 0,
+ })
+ run_computation_task = asyncio.create_task(
+ self.server.run_computation(
+ 'task/name',
+ create_plan(),
+ checkpoint,
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ 1,
+ )
+ )
+
+ # Wait for task assignment to return a task.
+ wait_task = asyncio.create_task(self.wait_for_task())
+ await asyncio.wait(
+ [run_computation_task, wait_task],
+ timeout=10,
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ self.assertTrue(wait_task.done())
+ uri_template = wait_task.result().federated_select_uri_info.uri_template
+ self.assertNotEmpty(uri_template)
+
+ # Check the contents of the slices.
+ for served_at_id, slices in TEST_SLICES.items():
+ for i, slice_data in enumerate(slices):
+ with urllib.request.urlopen(
+ uri_template.format(served_at_id=served_at_id, key_base10=str(i))
+ ) as response:
+ self.assertEqual(
+ response.getheader('Content-Type'),
+ 'application/octet-stream+gzip',
+ )
+ self.assertEqual(gzip.decompress(response.read()), slice_data)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/task_assignments.py b/fcp/demo/task_assignments.py
new file mode 100644
index 0000000..46d67a6
--- /dev/null
+++ b/fcp/demo/task_assignments.py
@@ -0,0 +1,230 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Action handlers for the TaskAssignments service."""
+
+import collections
+import dataclasses
+import http
+import threading
+from typing import Callable, Optional
+import uuid
+
+from absl import logging
+
+from google.longrunning import operations_pb2
+from google.rpc import code_pb2
+from google.protobuf import text_format
+from fcp.demo import aggregations
+from fcp.demo import http_actions
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+from fcp.protos.federatedcompute import task_assignments_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class _Task:
+ task_name: str
+ aggregation_session_id: str
+ init_checkpoint: common_pb2.Resource
+ plan: common_pb2.Resource
+ federated_select_uri_template: str
+
+
+class Service:
+ """Implements the TaskAssignments service."""
+
+ def __init__(self, population_name: str,
+ forwarding_info: Callable[[], common_pb2.ForwardingInfo],
+ aggregations_service: aggregations.Service):
+ self._population_name = population_name
+ self._forwarding_info = forwarding_info
+ self._aggregations_service = aggregations_service
+ self._single_assignment_tasks = collections.deque()
+ self._multiple_assignment_tasks: list[_Task] = []
+ self._tasks_lock = threading.Lock()
+
+ def add_task(
+ self,
+ task_name: str,
+ task_assignment_mode: _TaskAssignmentMode,
+ aggregation_session_id: str,
+ plan: common_pb2.Resource,
+ init_checkpoint: common_pb2.Resource,
+ federated_select_uri_template: str,
+ ):
+ """Adds a new task to the service."""
+ task = _Task(
+ task_name=task_name,
+ aggregation_session_id=aggregation_session_id,
+ init_checkpoint=init_checkpoint,
+ plan=plan,
+ federated_select_uri_template=federated_select_uri_template,
+ )
+ if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE:
+ with self._tasks_lock:
+ self._single_assignment_tasks.append(task)
+ elif (
+ task_assignment_mode
+ == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
+ ):
+ with self._tasks_lock:
+ self._multiple_assignment_tasks.append(task)
+ else:
+ raise ValueError(f'Unsupport TaskAssignmentMode {task_assignment_mode}.')
+
+ def remove_task(self, aggregation_session_id: str):
+ """Removes a task from the service."""
+ with self._tasks_lock:
+ for task in self._single_assignment_tasks:
+ if task.aggregation_session_id == aggregation_session_id:
+ self._single_assignment_tasks.remove(task)
+ return
+ for task in self._multiple_assignment_tasks:
+ if task.aggregation_session_id == aggregation_session_id:
+ self._multiple_assignment_tasks.remove(task)
+ return
+ raise KeyError(aggregation_session_id)
+
+ @property
+ def _current_task(self) -> Optional[_Task]:
+ with self._tasks_lock:
+ return (
+ self._single_assignment_tasks[0]
+ if self._single_assignment_tasks
+ else None
+ )
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.TaskAssignments',
+ method='StartTaskAssignment')
+ def start_task_assignment(
+ self, request: task_assignments_pb2.StartTaskAssignmentRequest
+ ) -> operations_pb2.Operation:
+ """Handles a StartTaskAssignment request."""
+ if request.population_name != self._population_name:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
+
+ # NOTE: A production implementation should consider whether the current task
+ # supports `request.client_version` before assigning the client. Given that
+ # all clients may not be eligible for all tasks, consider more sophisticated
+ # assignment than a FIFO queue.
+ task = self._current_task
+ if task:
+ logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id,
+ task.task_name)
+ # NOTE: If a production implementation of the Aggregations service cannot
+ # always pre-authorize clients (e.g., due to rate-limiting incoming
+ # clients), this code should either retry the operation or return a
+ # non-permanent error to the client (e.g., UNAVAILABLE).
+ authorization_token = self._aggregations_service.pre_authorize_clients(
+ task.aggregation_session_id, num_tokens=1)[0]
+ response = task_assignments_pb2.StartTaskAssignmentResponse(
+ task_assignment=task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=self._forwarding_info(),
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id=task.aggregation_session_id,
+ authorization_token=authorization_token,
+ task_name=task.task_name,
+ init_checkpoint=task.init_checkpoint,
+ plan=task.plan,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task.federated_select_uri_template
+ )
+ ),
+ )
+ )
+ else:
+ # NOTE: Instead of immediately rejecting clients, a production
+ # implementation may keep around some number of clients to be assigned to
+ # queued tasks or even future rounds of the current task (depending on how
+ # quickly rounds complete).
+ logging.debug('[%s] StartTaskAssignment: rejected', request.session_id)
+ response = task_assignments_pb2.StartTaskAssignmentResponse(
+ rejection_info=common_pb2.RejectionInfo())
+
+ # If task assignment took significant time, we return a longrunning
+ # Operation; since this implementation makes assignment decisions right
+ # away, we can return an already-completed operation.
+ op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
+ op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata())
+ op.response.Pack(response)
+ return op
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.TaskAssignments',
+ method='PerformMultipleTaskAssignments')
+ def perform_multiple_task_assignments(
+ self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest
+ ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse:
+ """Handles a PerformMultipleTaskAssignments request."""
+ if request.population_name != self._population_name:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
+
+ task_assignments = []
+ with self._tasks_lock:
+ for task in self._multiple_assignment_tasks:
+ if task.task_name not in request.task_names:
+ continue
+
+ # NOTE: A production implementation should consider whether the task
+ # supports `request.client_version` before assigning the client.
+
+ authorization_token = self._aggregations_service.pre_authorize_clients(
+ task.aggregation_session_id, num_tokens=1)[0]
+ task_assignments.append(
+ task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=self._forwarding_info(),
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id=task.aggregation_session_id,
+ authorization_token=authorization_token,
+ task_name=task.task_name,
+ init_checkpoint=task.init_checkpoint,
+ plan=task.plan,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task.federated_select_uri_template
+ )
+ ),
+ )
+ )
+
+ return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse(
+ task_assignments=task_assignments)
+
+ @http_actions.proto_action(
+ service='google.internal.federatedcompute.v1.TaskAssignments',
+ method='ReportTaskResult')
+ def report_task_result(
+ self, request: task_assignments_pb2.ReportTaskResultRequest
+ ) -> task_assignments_pb2.ReportTaskResultResponse:
+ """Handles a ReportTaskResult request."""
+ if request.population_name != self._population_name:
+ raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
+ logging.log(
+ (logging.DEBUG if request.computation_status_code == code_pb2.OK else
+ logging.WARN), '[%s] ReportTaskResult: %s (%s)', request.session_id,
+ code_pb2.Code.Name(request.computation_status_code),
+ text_format.MessageToString(request.client_stats, as_one_line=True))
+ return task_assignments_pb2.ReportTaskResultResponse()
diff --git a/fcp/demo/task_assignments_test.py b/fcp/demo/task_assignments_test.py
new file mode 100644
index 0000000..4eeab3f
--- /dev/null
+++ b/fcp/demo/task_assignments_test.py
@@ -0,0 +1,453 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for task_assignments."""
+
+import http
+from unittest import mock
+import uuid
+
+from absl.testing import absltest
+
+from google.rpc import code_pb2
+from fcp.demo import aggregations
+from fcp.demo import http_actions
+from fcp.demo import task_assignments
+from fcp.protos.federatedcompute import common_pb2
+from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
+from fcp.protos.federatedcompute import task_assignments_pb2
+
+_TaskAssignmentMode = (
+ eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
+)
+
+POPULATION_NAME = 'test/population'
+FORWARDING_INFO = common_pb2.ForwardingInfo(
+ target_uri_prefix='https://forwarding.example/')
+
+
+class TaskAssignmentsTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.mock_aggregations_service = self.enter_context(
+ mock.patch.object(aggregations, 'Service', autospec=True))
+ self.mock_aggregations_service.pre_authorize_clients.return_value = ['']
+
+ def test_start_task_assignment_with_wrong_population(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ request = task_assignments_pb2.StartTaskAssignmentRequest(
+ population_name='other/population', session_id='session-id')
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.start_task_assignment(request)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
+ def test_start_task_assignment_with_no_tasks(self, mock_uuid):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ request = task_assignments_pb2.StartTaskAssignmentRequest(
+ population_name=POPULATION_NAME, session_id='session-id')
+ operation = service.start_task_assignment(request)
+ self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
+ self.assertTrue(operation.done)
+
+ metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
+ operation.metadata.Unpack(metadata)
+ self.assertEqual(metadata,
+ task_assignments_pb2.StartTaskAssignmentMetadata())
+
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ rejection_info=common_pb2.RejectionInfo()
+ ),
+ )
+
+ def test_start_task_assignment_with_multiple_assignment_task(self):
+ service = task_assignments.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
+ )
+ service.add_task(
+ 'task',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ 'aggregation-session',
+ common_pb2.Resource(uri='https://task.example/plan'),
+ common_pb2.Resource(uri='https://task.example/checkpoint'),
+ 'https://task.example/{key_base10}',
+ )
+
+ request = task_assignments_pb2.StartTaskAssignmentRequest(
+ population_name=POPULATION_NAME, session_id='session-id'
+ )
+ operation = service.start_task_assignment(request)
+ self.assertTrue(operation.done)
+
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ rejection_info=common_pb2.RejectionInfo()))
+
+ @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
+ def test_start_task_assignment_with_one_task(self, mock_uuid):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+
+ self.mock_aggregations_service.pre_authorize_clients.return_value = [
+ 'token'
+ ]
+
+ task_plan = common_pb2.Resource(uri='https://task.example/plan')
+ task_checkpoint = common_pb2.Resource(uri='https://task.example/checkpoint')
+ task_federated_select_uri_template = 'https://task.example/{key_base10}'
+ service.add_task(
+ 'task',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ 'aggregation-session',
+ task_plan,
+ task_checkpoint,
+ task_federated_select_uri_template,
+ )
+
+ request = task_assignments_pb2.StartTaskAssignmentRequest(
+ population_name=POPULATION_NAME, session_id='session-id')
+ operation = service.start_task_assignment(request)
+ self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
+ self.assertTrue(operation.done)
+
+ metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
+ operation.metadata.Unpack(metadata)
+ self.assertEqual(metadata,
+ task_assignments_pb2.StartTaskAssignmentMetadata())
+
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ task_assignment=task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=FORWARDING_INFO,
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id='aggregation-session',
+ authorization_token='token',
+ task_name='task',
+ plan=task_plan,
+ init_checkpoint=task_checkpoint,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task_federated_select_uri_template
+ )
+ ),
+ )
+ ),
+ )
+
+ self.mock_aggregations_service.pre_authorize_clients.assert_called_once_with(
+ 'aggregation-session', num_tokens=1)
+
+ def test_start_task_assignment_with_multiple_tasks(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+
+ self.mock_aggregations_service.pre_authorize_clients.return_value = [
+ 'token'
+ ]
+
+ task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
+ task1_checkpoint = common_pb2.Resource(
+ uri='https://task1.example/checkpoint')
+ task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
+ service.add_task(
+ 'task1',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ 'aggregation-session1',
+ task1_plan,
+ task1_checkpoint,
+ task1_federated_select_uri_template,
+ )
+ task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
+ task2_checkpoint = common_pb2.Resource(
+ uri='https://task2.example/checkpoint')
+ task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
+ service.add_task(
+ 'task2',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ 'aggregation-session2',
+ task2_plan,
+ task2_checkpoint,
+ task2_federated_select_uri_template,
+ )
+
+ request = task_assignments_pb2.StartTaskAssignmentRequest(
+ population_name=POPULATION_NAME, session_id='session-id')
+
+ # Initially, task1 should be used.
+ operation = service.start_task_assignment(request)
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ task_assignment=task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=FORWARDING_INFO,
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id='aggregation-session1',
+ authorization_token='token',
+ task_name='task1',
+ plan=task1_plan,
+ init_checkpoint=task1_checkpoint,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task1_federated_select_uri_template
+ )
+ ),
+ )
+ ),
+ )
+ self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
+ 'aggregation-session1', num_tokens=1)
+
+ # After task1 is removed, task2 should be used.
+ service.remove_task('aggregation-session1')
+ operation = service.start_task_assignment(request)
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ task_assignment=task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=FORWARDING_INFO,
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id='aggregation-session2',
+ authorization_token='token',
+ task_name='task2',
+ plan=task2_plan,
+ init_checkpoint=task2_checkpoint,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task2_federated_select_uri_template
+ )
+ ),
+ )
+ ),
+ )
+ self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
+ 'aggregation-session2', num_tokens=1)
+
+ # After task2 is removed, the client should be rejected.
+ service.remove_task('aggregation-session2')
+ operation = service.start_task_assignment(request)
+ response = task_assignments_pb2.StartTaskAssignmentResponse()
+ operation.response.Unpack(response)
+ self.assertEqual(
+ response,
+ task_assignments_pb2.StartTaskAssignmentResponse(
+ rejection_info=common_pb2.RejectionInfo()))
+
+ def test_perform_multiple_task_assignments_with_wrong_population(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
+ population_name='other/population',
+ session_id='session-id',
+ task_names=['task1', 'task2', 'task3'])
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.perform_multiple_task_assignments(request)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+ def test_perform_multiple_task_assignments_without_tasks(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+
+ request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
+ population_name=POPULATION_NAME,
+ session_id='session-id',
+ task_names=['task1', 'task2', 'task3'])
+ self.assertEqual(
+ service.perform_multiple_task_assignments(request),
+ task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())
+
+ def test_perform_multiple_task_assignments_with_multiple_tasks(self):
+ self.mock_aggregations_service.pre_authorize_clients.side_effect = (
+ lambda session_id, num_tokens=1: [f'token-for-{session_id}'])
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+
+ task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
+ task1_checkpoint = common_pb2.Resource(
+ uri='https://task1.example/checkpoint')
+ task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
+ service.add_task(
+ 'task1',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ 'aggregation-session1',
+ task1_plan,
+ task1_checkpoint,
+ task1_federated_select_uri_template,
+ )
+ task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
+ task2_checkpoint = common_pb2.Resource(
+ uri='https://task2.example/checkpoint')
+ task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
+ service.add_task(
+ 'task2',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ 'aggregation-session2',
+ task2_plan,
+ task2_checkpoint,
+ task2_federated_select_uri_template,
+ )
+ # Tasks using other TaskAssignmentModes should be skipped.
+ task3_plan = common_pb2.Resource(uri='https://task3.example/plan')
+ task3_checkpoint = common_pb2.Resource(
+ uri='https://task3.example/checkpoint'
+ )
+ task3_federated_select_uri_template = 'https://task3.example/{key_base10}'
+ service.add_task(
+ 'task3',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
+ 'aggregation-session3',
+ task3_plan,
+ task3_checkpoint,
+ task3_federated_select_uri_template,
+ )
+
+ request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
+ population_name=POPULATION_NAME,
+ session_id='session-id',
+ task_names=['task1', 'task2', 'task3'])
+ self.assertCountEqual(
+ service.perform_multiple_task_assignments(request).task_assignments,
+ [
+ task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=FORWARDING_INFO,
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id='aggregation-session1',
+ authorization_token='token-for-aggregation-session1',
+ task_name='task1',
+ plan=task1_plan,
+ init_checkpoint=task1_checkpoint,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task1_federated_select_uri_template
+ )
+ ),
+ ),
+ task_assignments_pb2.TaskAssignment(
+ aggregation_data_forwarding_info=FORWARDING_INFO,
+ aggregation_info=(
+ task_assignments_pb2.TaskAssignment.AggregationInfo()
+ ),
+ session_id=request.session_id,
+ aggregation_id='aggregation-session2',
+ authorization_token='token-for-aggregation-session2',
+ task_name='task2',
+ plan=task2_plan,
+ init_checkpoint=task2_checkpoint,
+ federated_select_uri_info=(
+ task_assignments_pb2.FederatedSelectUriInfo(
+ uri_template=task2_federated_select_uri_template
+ )
+ ),
+ ),
+ # 'task3' should be omitted since there isn't a corresponding task.
+ ],
+ )
+
+ def test_add_task_with_invalid_task_assignment_mode(self):
+ service = task_assignments.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
+ )
+ with self.assertRaises(ValueError):
+ service.add_task(
+ 'task',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
+ 'aggregation-session',
+ common_pb2.Resource(uri='https://task.example/plan'),
+ common_pb2.Resource(uri='https://task.example/checkpoint'),
+ 'https://task.example/{key_base10}',
+ )
+
+ def test_remove_multiple_assignment_task(self):
+ service = task_assignments.Service(
+ POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
+ )
+ service.add_task(
+ 'task',
+ _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
+ 'aggregation-session',
+ common_pb2.Resource(uri='https://task.example/plan'),
+ common_pb2.Resource(uri='https://task.example/checkpoint'),
+ 'https://task.example/{key_base10}',
+ )
+ service.remove_task('aggregation-session')
+
+ request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
+ population_name=POPULATION_NAME,
+ session_id='session-id',
+ task_names=['task'],
+ )
+ self.assertEmpty(
+ service.perform_multiple_task_assignments(request).task_assignments
+ )
+
+ def test_remove_missing_task(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ with self.assertRaises(KeyError):
+ service.remove_task('does-not-exist')
+
+ def test_report_task_result(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ request = task_assignments_pb2.ReportTaskResultRequest(
+ population_name=POPULATION_NAME,
+ session_id='session-id',
+ aggregation_id='aggregation-id',
+ computation_status_code=code_pb2.ABORTED)
+ self.assertEqual(
+ service.report_task_result(request),
+ task_assignments_pb2.ReportTaskResultResponse())
+
+ def test_report_task_result_with_wrong_population(self):
+ service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
+ self.mock_aggregations_service)
+ request = task_assignments_pb2.ReportTaskResultRequest(
+ population_name='other/population',
+ session_id='session-id',
+ aggregation_id='aggregation-id',
+ computation_status_code=code_pb2.ABORTED)
+ with self.assertRaises(http_actions.HttpError) as cm:
+ service.report_task_result(request)
+ self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/demo/test_utils.py b/fcp/demo/test_utils.py
new file mode 100644
index 0000000..641da1b
--- /dev/null
+++ b/fcp/demo/test_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper functions for writing tests."""
+
+import tempfile
+from typing import Any, Mapping
+
+import tensorflow as tf
+
+
+def create_checkpoint(data: Mapping[str, Any]) -> bytes:
+ """Creates a TensorFlow checkpoint."""
+ with tempfile.NamedTemporaryFile() as tmpfile:
+ with tf.compat.v1.Session() as session:
+ session.run(
+ tf.raw_ops.Save(
+ filename=tmpfile.name,
+ tensor_names=list(data.keys()),
+ data=list(data.values())))
+ with open(tmpfile.name, 'rb') as f:
+ return f.read()
+
+
+def read_tensor_from_checkpoint(checkpoint: bytes, tensor_name: str,
+ dt: tf.DType) -> Any:
+ """Reads a single tensor from a checkpoint."""
+ with tempfile.NamedTemporaryFile('wb') as tmpfile:
+ tmpfile.write(checkpoint)
+ tmpfile.flush()
+ with tf.compat.v1.Session() as session:
+ return session.run(
+ tf.raw_ops.Restore(
+ file_pattern=tmpfile.name, tensor_name=tensor_name, dt=dt))
diff --git a/fcp/demo/test_utils_test.py b/fcp/demo/test_utils_test.py
new file mode 100644
index 0000000..85cbd95
--- /dev/null
+++ b/fcp/demo/test_utils_test.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for test_utils."""
+
+from absl.testing import absltest
+
+import tensorflow as tf
+
+from fcp.demo import test_utils
+
+
+class TestUtilsTest(absltest.TestCase):
+
+ def test_create_checkpoint(self):
+ checkpoint = test_utils.create_checkpoint({
+ 'int': 3,
+ 'str': 'test',
+ 'list': [1, 2, 3],
+ })
+ self.assertEqual(
+ test_utils.read_tensor_from_checkpoint(checkpoint, 'int', tf.int32), 3)
+ self.assertEqual(
+ test_utils.read_tensor_from_checkpoint(checkpoint, 'str', tf.string),
+ b'test')
+ self.assertListEqual(
+ test_utils.read_tensor_from_checkpoint(checkpoint, 'list',
+ tf.int32).tolist(), [1, 2, 3])
+
+ def test_read_from_checkpoint_not_found(self):
+ checkpoint = test_utils.create_checkpoint({'int': 3})
+ with self.assertRaises(Exception):
+ test_utils.read_tensor_from_checkpoint(checkpoint, 'str', tf.string)
+
+ def test_read_from_checkpoint_wrong_type(self):
+ checkpoint = test_utils.create_checkpoint({'int': 3})
+ with self.assertRaises(Exception):
+ test_utils.read_tensor_from_checkpoint(checkpoint, 'int', tf.string)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/dictionary/BUILD b/fcp/dictionary/BUILD
new file mode 100644
index 0000000..d803c51
--- /dev/null
+++ b/fcp/dictionary/BUILD
@@ -0,0 +1,67 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+proto_library(
+ name = "dictionary_proto",
+ srcs = ["dictionary.proto"],
+)
+
+cc_proto_library(
+ name = "dictionary_cc_proto",
+ deps = [":dictionary_proto"],
+)
+
+py_proto_library(
+ name = "dictionary_py_pb2",
+ deps = [":dictionary_proto"],
+)
+
+cc_library(
+ name = "dictionary_lib",
+ srcs = ["dictionary.cc"],
+ hdrs = ["dictionary.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":dictionary_cc_proto",
+ "//fcp/base",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "dictionary_test",
+ srcs = ["dictionary_test.cc"],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":dictionary_cc_proto",
+ ":dictionary_lib",
+ "//fcp/base",
+ "//fcp/testing:parse_text_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/dictionary/dictionary.cc b/fcp/dictionary/dictionary.cc
new file mode 100644
index 0000000..1ba374e
--- /dev/null
+++ b/fcp/dictionary/dictionary.cc
@@ -0,0 +1,184 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/dictionary/dictionary.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/dictionary/dictionary.pb.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace fcp {
+namespace dictionary {
+
+// Bidirectional map defined as hash_map from strings to int32_t paired with
+// a vector of those keys for reverse lookup.
+typedef std::pair<absl::node_hash_map<std::string, int32_t>,
+ std::vector<std::string>>
+ HashVectorBimap;
+
+namespace {
+
+// Map a string to an ID, using a bidirectional map (an std::pair containing
+// two data structures for string -> int and for int -> string lookups).
+int32_t MapLookup(const HashVectorBimap& bimap, const std::string& tag) {
+ auto map_idx = bimap.first.find(tag);
+ return map_idx == bimap.first.end() ? Dictionary::kNotFound : map_idx->second;
+}
+// Lookup a token given its ID.
+std::string MapReverseLookup(const HashVectorBimap& bimap, int32_t id) {
+ if (id < 0 || id >= bimap.second.size()) {
+ return "";
+ }
+ return bimap.second[id];
+}
+
+// Return the size of an stl-like data structure.
+int32_t GetSize(const HashVectorBimap& bimap) {
+ return static_cast<int32_t>(bimap.first.size());
+}
+
+int32_t GetMaxSpecialId(const DictionaryDescription::SpecialIds& special_ids) {
+ int32_t max_special_id = -1;
+ max_special_id = std::max(max_special_id, special_ids.bos());
+ max_special_id = std::max(max_special_id, special_ids.eos());
+ max_special_id = std::max(max_special_id, special_ids.unk());
+ return max_special_id;
+}
+
+// Dictionary implementation powered by templated utility functions above.
+template <typename Bimap>
+class DictionaryImpl : public Dictionary {
+ public:
+ DictionaryImpl(
+ std::unique_ptr<Bimap> bimap,
+ const DictionaryDescription::SpecialIds& special_ids,
+ const DictionaryDescription::OutputBlocklistIds& output_blocklist_ids)
+ : bimap_(std::move(bimap)),
+ special_ids_(special_ids),
+ max_special_id_(GetMaxSpecialId(special_ids)) {
+ // Validate special ids.
+ FCP_CHECK(special_ids.has_bos() == (special_ids.bos() >= 0));
+ FCP_CHECK(special_ids.has_eos() == (special_ids.eos() >= 0));
+ FCP_CHECK(special_ids.has_unk() == (special_ids.unk() >= 0));
+
+ // Token numbering starts at max(special_ids) + 1.
+ output_blocklist_ids_.reserve(max_special_id_ + 1 +
+ output_blocklist_ids.id_size());
+ for (int32_t id = 0; id <= max_special_id_; ++id) {
+ output_blocklist_ids_.push_back(id);
+ }
+ for (int32_t id : output_blocklist_ids.id()) {
+ output_blocklist_ids_.push_back(id);
+ }
+ }
+
+ int32_t Size() const override {
+ return GetSize(*bimap_) + max_special_id_ + 1;
+ }
+
+ int32_t TokenToId(const std::string& tag) const override {
+ int32_t id = MapLookup(*bimap_, tag);
+ if (id == kNotFound) {
+ return special_ids_.unk();
+ } else {
+ return id + max_special_id_ + 1;
+ }
+ }
+
+ std::string IdToToken(int32_t id) const override {
+ return MapReverseLookup(*bimap_, id - (max_special_id_ + 1));
+ }
+
+ bool IsSpecialId(int32_t token_id) const override {
+ return token_id <= max_special_id_;
+ }
+
+ const std::vector<int32_t>& GetSortedOutputBlocklistIds() const override {
+ return output_blocklist_ids_;
+ }
+
+ const DictionaryDescription::SpecialIds& GetSpecialIds() const override {
+ return special_ids_;
+ }
+
+ private:
+ const std::unique_ptr<Bimap> bimap_;
+ const DictionaryDescription::SpecialIds special_ids_;
+ int32_t max_special_id_;
+ std::vector<int32_t> output_blocklist_ids_;
+};
+
+absl::Status IsOutputBlocklistIdsSortedAndUnique(
+ const DictionaryDescription& description) {
+ // All blocklist ids must be greater than max_special_id.
+ const int32_t max_special_id = GetMaxSpecialId(description.special_ids());
+
+ // Make sure output blocklist IDs are sorted in ascending order and unique.
+ if (description.has_output_blocklist_ids()) {
+ for (int i = 0; i < description.output_blocklist_ids().id_size(); i++) {
+ if (description.output_blocklist_ids().id(i) <= max_special_id) {
+ return absl::InvalidArgumentError(
+ "output_blocklist_ids should not overlap with special ids");
+ }
+ if (!(i == 0 || description.output_blocklist_ids().id(i) >
+ description.output_blocklist_ids().id(i - 1))) {
+ return absl::InvalidArgumentError(
+ "output_blocklist_ids not unique or sorted");
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // anonymous namespace
+
+absl::StatusOr<std::unique_ptr<Dictionary>> Dictionary::Create(
+ const DictionaryDescription& description) {
+ if (!description.has_vocabulary()) {
+ return absl::InvalidArgumentError(
+ "Cannot create a dictionary that does not have vocabulary set");
+ }
+ // Make sure output blocklist IDs are sorted in ascending order and unique.
+ FCP_RETURN_IF_ERROR(IsOutputBlocklistIdsSortedAndUnique(description));
+
+ if (description.vocabulary().has_index()) {
+ auto bimap = std::make_unique<HashVectorBimap>();
+ int i = 0;
+ bimap->second.reserve(description.vocabulary().index().token_size());
+ for (const std::string& token : description.vocabulary().index().token()) {
+ FCP_CHECK(!token.empty());
+ bimap->first[token] = i++;
+ bimap->second.push_back(token);
+ }
+ return std::unique_ptr<Dictionary>(new DictionaryImpl<HashVectorBimap>(
+ std::move(bimap), description.special_ids(),
+ description.output_blocklist_ids()));
+ } else {
+ return absl::InvalidArgumentError(
+ "Invalid DictionaryDescription: no vocabulary specified.");
+ }
+}
+} // namespace dictionary
+} // namespace fcp
diff --git a/fcp/dictionary/dictionary.h b/fcp/dictionary/dictionary.h
new file mode 100644
index 0000000..6be61ad
--- /dev/null
+++ b/fcp/dictionary/dictionary.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_DICTIONARY_DICTIONARY_H_
+#define FCP_DICTIONARY_DICTIONARY_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "fcp/dictionary/dictionary.pb.h"
+
+namespace fcp {
+namespace dictionary {
+
+// Interface for mapping tokens (usually words) to indices.
+class Dictionary {
+ public:
+ virtual ~Dictionary() {}
+
+ // Returns the number of elements in the dictionary.
+ virtual int32_t Size() const = 0;
+
+ // Returns the index of token in the dictionary or kNotFound if not found.
+ virtual int32_t TokenToId(const std::string& token) const = 0;
+
+ // Maps an ID to a string if the ID represents a valid token.
+ // Returns "" on error.
+ virtual std::string IdToToken(int32_t id) const = 0;
+
+ // Returns true if the given id is set via DictionaryDescription.SpecialIds.
+ virtual bool IsSpecialId(int32_t id) const = 0;
+
+ // Returns a sorted (ascending) list of ids to filter from the predictions.
+ // Can be used for e.g. punctuation. Includes special ids.
+ virtual const std::vector<int32_t>& GetSortedOutputBlocklistIds() const = 0;
+
+ // Returns the special ids used in this dictionary.
+ virtual const DictionaryDescription::SpecialIds& GetSpecialIds() const = 0;
+
+ // Id returned when an element is not found. This is distinct from the id
+ // of the unknown_token (if one is configured).
+ static constexpr int32_t kNotFound = -1;
+
+ //
+ // Static constructors
+ //
+
+ // Creates a dictionary from a self-describing DictionaryDescription proto.
+ static absl::StatusOr<std::unique_ptr<Dictionary>> Create(
+ const DictionaryDescription& description);
+};
+
+} // namespace dictionary
+} // namespace fcp
+
+#endif // FCP_DICTIONARY_DICTIONARY_H_
diff --git a/fcp/dictionary/dictionary.proto b/fcp/dictionary/dictionary.proto
new file mode 100644
index 0000000..b126d49
--- /dev/null
+++ b/fcp/dictionary/dictionary.proto
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto2";
+
+package fcp.dictionary;
+
+// Describes a mapping of strings to (usually consecutive) integer ids.
+message DictionaryDescription {
+ // Vocabulary ids with special meaning.
+ message SpecialIds {
+ // If set and non-negative, id used for an unknown token.
+ optional int32 unk = 1 [default = -1];
+
+ // If set and non-negative, id used for the beginning of a sequence.
+ optional int32 bos = 2 [default = -1];
+
+ // If set and non-negative, id used for the end of a sequence.
+ optional int32 eos = 3 [default = -1];
+ }
+
+ // Vocabulary ids that should be filtered from the predictions (e.g.,
+ // punctuation, bad words etc.).
+ message OutputBlocklistIds {
+ repeated int32 id = 1 [packed = true];
+ }
+
+ // Optional persistent storage format for the token to id map.
+ message Vocabulary {
+ message TokenIndex {
+ repeated string token = 1;
+ }
+
+ reserved 1;
+
+ oneof vocabulary {
+ // Repeated strings stored in-order (index begins at 0).
+ TokenIndex index = 2;
+ }
+ }
+
+ optional SpecialIds special_ids = 1;
+
+ optional Vocabulary vocabulary = 2;
+
+ optional OutputBlocklistIds output_blocklist_ids = 3;
+
+ reserved 4;
+}
diff --git a/fcp/dictionary/dictionary_test.cc b/fcp/dictionary/dictionary_test.cc
new file mode 100644
index 0000000..62e1f39
--- /dev/null
+++ b/fcp/dictionary/dictionary_test.cc
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/dictionary/dictionary.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/dictionary/dictionary.pb.h"
+#include "fcp/testing/parse_text_proto.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+
+namespace fcp {
+namespace dictionary {
+
+using ::testing::ElementsAre;
+
+class DictionaryTest : public ::testing::Test {};
+
+TEST_F(DictionaryTest, TestMapDictionaryLookup) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(PARSE_TEXT_PROTO(
+ "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));
+
+ EXPECT_EQ(0, dictionary->TokenToId("a"));
+ EXPECT_EQ(1, dictionary->TokenToId("b"));
+ EXPECT_EQ(2, dictionary->TokenToId("c"));
+ EXPECT_EQ(Dictionary::kNotFound, dictionary->TokenToId("d"));
+}
+
+TEST_F(DictionaryTest, TestMapDictionaryLookupWithUnk) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(
+ PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
+ "vocabulary: < index: <"
+ " token: 'a' token: 'b' token: 'c' > >"));
+ EXPECT_EQ(2, dictionary->TokenToId("a"));
+ EXPECT_EQ(3, dictionary->TokenToId("b"));
+ EXPECT_EQ(4, dictionary->TokenToId("c"));
+ EXPECT_EQ(0, dictionary->TokenToId("d"));
+ EXPECT_EQ(0, dictionary->TokenToId("e"));
+ EXPECT_EQ(0, dictionary->TokenToId("<UNK>"));
+ EXPECT_EQ(0, dictionary->TokenToId("<BOS>"));
+}
+
+TEST_F(DictionaryTest, TestMapDictionaryLookupWithSpecialTokenHoles) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(
+ PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
+ "vocabulary: < index: <"
+ " token: 'a' token: 'b' token: 'c' > >"));
+
+ // Make sure dictionary doesn't use the "holes" in IDs - 0, 2 and 3 - for
+ // tokens, but starts numbering tokens with max(special_ids) + 1.
+ EXPECT_EQ(5, dictionary->TokenToId("a"));
+ EXPECT_EQ(6, dictionary->TokenToId("b"));
+ EXPECT_EQ(7, dictionary->TokenToId("c"));
+ EXPECT_EQ(1, dictionary->TokenToId("d"));
+ EXPECT_EQ(1, dictionary->TokenToId("e"));
+ EXPECT_EQ(1, dictionary->TokenToId("<UNK>"));
+ EXPECT_EQ(1, dictionary->TokenToId("<BOS>"));
+}
+
+TEST_F(DictionaryTest, TestMapDictionaryReverseLookup) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(PARSE_TEXT_PROTO(
+ "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
+ EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
+ EXPECT_EQ(dictionary->IdToToken(1337), "");
+}
+
+TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithUnk) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(
+ PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
+ "vocabulary: < index: <"
+ " token: 'a' token: 'b' token: 'c' > >"));
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
+ EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
+ EXPECT_EQ(dictionary->IdToToken(1337), "");
+}
+
+TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithSpecialTokenHoles) {
+ std::unique_ptr<Dictionary> dictionary = *Dictionary::Create(
+ PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
+ "vocabulary: < index: <"
+ " token: 'a' token: 'b' token: 'c' > >"));
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
+ EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
+ EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
+ EXPECT_EQ(dictionary->IdToToken(1337), "");
+}
+} // namespace dictionary
+} // namespace fcp
diff --git a/fcp/java_src/main/java/com/google/fcp/client/BUILD b/fcp/java_src/main/java/com/google/fcp/client/BUILD
new file mode 100644
index 0000000..23a7362
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+java_library(
+ name = "client",
+ srcs = [
+ "CallFromNativeWrapper.java",
+ ],
+)
diff --git a/fcp/java_src/main/java/com/google/fcp/client/CallFromNativeWrapper.java b/fcp/java_src/main/java/com/google/fcp/client/CallFromNativeWrapper.java
new file mode 100644
index 0000000..aa1ee52
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/CallFromNativeWrapper.java
@@ -0,0 +1,76 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.fcp.client;
+
+import java.lang.Thread.UncaughtExceptionHandler;
+import java.util.concurrent.Callable;
+
+/**
+ * Utility class to wrap java method calls that originate from the native layer over JNI, which
+ * ensures that any uncaught {@link Throwable} is passed to the given {@link
+ * UncaughtExceptionHandler} (which is expected to generate a crash report and terminate the
+ * process), as opposed to being passed back to the native layer.
+ */
+public class CallFromNativeWrapper {
+
+ private final UncaughtExceptionHandler uncaughtExceptionHandler;
+
+ /** A {@link Callable} that does not throw checked exceptions. */
+ public interface NativeToJavaCallable<T> extends Callable<T> {
+ @Override
+ T call();
+ }
+
+ public CallFromNativeWrapper(UncaughtExceptionHandler uncaughtExceptionHandler) {
+ this.uncaughtExceptionHandler = uncaughtExceptionHandler;
+ }
+
+ /**
+ * Wraps a java method call from native code on an arbitrary thread (i.e. one created by
+ * TensorFlow). If a {@link Throwable} is thrown the exception will be passed to the {@code
+ * uncaughtExceptionHandler}.
+ */
+ public <T> T wrapCall(NativeToJavaCallable<T> callable) {
+ try {
+ return callable.call();
+ } catch (Throwable t) {
+ uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t);
+ // The uncaught exception handler generally will have killed us by here.
+ //
+ // On Android, the system should see our thread crash and kill the process before reaching
+ // here. Just in case we make it this far, we wrap the exception in a runtime exception and
+ // let it pass to the native layer (which will generally then abort the process upon detecting
+ // the exception).
+ throw new CallFromNativeRuntimeException(t);
+ }
+ }
+
+ public void wrapVoidCall(Runnable runnable) {
+ wrapCall(
+ () -> {
+ runnable.run();
+ return null;
+ });
+ }
+
+ /**
+ * A {@link RuntimeException} signifying there was an unchecked exception when calling from native
+ * to java.
+ */
+ public static class CallFromNativeRuntimeException extends RuntimeException {
+ CallFromNativeRuntimeException(Throwable t) {
+ super(t);
+ }
+ }
+}
diff --git a/fcp/java_src/main/java/com/google/fcp/client/http/BUILD b/fcp/java_src/main/java/com/google/fcp/client/http/BUILD
new file mode 100644
index 0000000..0031765
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/http/BUILD
@@ -0,0 +1,48 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+java_library(
+ name = "jni_interface",
+ srcs = [
+ "HttpClientForNative.java",
+ ],
+ deps = [
+ "@fcp_maven//:com_google_errorprone_error_prone_annotations",
+ ],
+)
+
+java_library(
+ name = "http",
+ srcs = [
+ "HttpClientForNativeImpl.java",
+ "HttpRequestHandleImpl.java",
+ ],
+ deps = [
+ ":jni_interface",
+ "//fcp/client/http/java:java_http_client",
+ "//fcp/client/http/java:jni_java_proto",
+ "//fcp/java_src/main/java/com/google/fcp/client",
+ "@com_google_googleapis//google/rpc:rpc_java_proto",
+ "@com_google_protobuf//:protobuf_java",
+ "@fcp_maven//:com_google_code_findbugs_jsr305",
+ "@fcp_maven//:com_google_guava_guava",
+ ],
+)
diff --git a/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNative.java b/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNative.java
new file mode 100644
index 0000000..ab40d96
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNative.java
@@ -0,0 +1,233 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.fcp.client.http;
+
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import java.io.Closeable;
+
+/**
+ * A base class for building a Java/JNI-based implementation of the C++ {@code HttpClient}
+ * interface.
+ *
+ * <p>This class is defined in conjunction with the {@code java_http_client.cc/h} C++ code that
+ * invokes it via JNI.
+ *
+ * <p>A note on thread safety:
+ *
+ * <ol>
+ * <li>Incoming calls from the native layer can generally come from any thread, and hence
+ * implementations of these classes must be thread safe.
+ * <li>Outgoing calls to the native layer (e.g. {@link #readRequestBody}, {@link
+ * #onResponseStarted}, etc.) may also be made from any thread, but for a single {@link
+ * HttpRequestHandle} there must never be any concurrent outgoing calls from more than one
+ * thread (hence they are {@code @GuardedBy("this")}).
+ * <li>Outgoing calls to the native layer must only be made once an {@link #performRequests} has
+ * been called on a given {@link HttpRequestHandle}, and not before.
+ * </ol>
+ */
+public abstract class HttpClientForNative implements Closeable {
+ /**
+ * A base class for building a Java/JNI-based implementation of the C++ {@code HttpRequestHandle}
+ * and related interfaces.
+ *
+ * <p>This class is defined in conjunction with the {@code java_http_client.cc/h} C++ code that
+ * invokes it via JNI.
+ */
+ public abstract static class HttpRequestHandle implements Closeable {
+ /**
+ * Called by the native layer to get the request's latest total sent/received bytes stats. May
+ * be called multiple times, and from any thread.
+ *
+ * <p>See C++'s {@code HttpRequestHandle::TotalSentReceivedBytes}.
+ *
+ * @return a serialized {@link JniHttpSentReceivedBytes} proto.
+ */
+ public abstract byte[] getTotalSentReceivedBytes();
+
+ /**
+ * Called by the native layer when the request isn't needed anymore. May be called multiple
+ * times, and from any thread.
+ */
+ @Override
+ public abstract void close();
+
+ /**
+ * Reads up to {@code requestedBytes} of request body data into {@code buffer}, via the native
+ * layer. If the end of the data is reached, then -1 will be placed in the mutable
+ * single-element {@code actualBytesRead} array (this corresponds to C++'s {@code
+ * HttpRequest::ReadBody} returning {@code OUT_OF_RANGE}). Otherwise, at least 1 byte of data
+ * will have been read, and the actual amount of bytes that were read will be placed in the
+ * {@code actualBytesRead} array.
+ *
+ * <p>If the return value is false, then {@link HttpClientForNative} implementation must not
+ * call {@link #onResponseError} anymore, as the native layer will already have called the
+ * corresponding C++ callback.
+ *
+ * <p>See C++'s {@code HttpRequest::ReadBody}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ *
+ * @return true if the read succeeded (incl. if the end of data was reached), false if the read
+ * failed (in which case the request should be aborted without calling any more callback
+ * methods).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected boolean readRequestBody(byte[] buffer, long requestedBytes, int[] actualBytesRead) {
+ return HttpClientForNative.readRequestBody(
+ nativeHandle, buffer, requestedBytes, actualBytesRead);
+ }
+
+ /**
+ * Signals to the native layer that the response headers (provided as a serialized {@link
+ * JniHttpResponse}) have been received.
+ *
+ * <p>See C++'s {@code HttpRequestCallback::OnResponseStarted}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ *
+ * @return true if the response headers were successfully processed, false if not (in which case
+ * the request should be aborted without calling any more callback methods).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected boolean onResponseStarted(byte[] responseProto) {
+ return HttpClientForNative.onResponseStarted(nativeHandle, responseProto);
+ }
+
+ /**
+ * Signals to the native layer that an error (provided as a serialized {@link
+ * com.google.rpc.Status} proto) occurred before the response headers were received.
+ *
+ * <p>See C++'s {@code HttpRequestCallback::OnResponseError}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected void onResponseError(byte[] statusProto) {
+ HttpClientForNative.onResponseError(nativeHandle, statusProto);
+ }
+
+ /**
+ * Provides {@code bytesAvailable} bytes of the response body to the native layer, via {@code
+ * data}.
+ *
+ * <p>See C++'s {@code HttpRequestCallback::OnResponseBody}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ *
+ * @return true if the data was successfully processed, or false if not (in which case the
+ * request should be aborted without calling any more callback methods).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected boolean onResponseBody(byte[] data, int bytesAvailable) {
+ return HttpClientForNative.onResponseBody(nativeHandle, data, bytesAvailable);
+ }
+
+ /**
+ * Signals to the native layer that an error (provided as a serialized {@link
+ * com.google.rpc.Status} proto) occurred while reading the response body.
+ *
+ * <p>See C++'s {@code HttpRequestCallback::OnResponseBodyError}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected void onResponseBodyError(byte[] statusProto) {
+ HttpClientForNative.onResponseBodyError(nativeHandle, statusProto);
+ }
+
+ /**
+ * Signals to the native layer that the request completed successfully.
+ *
+ * <p>See C++'s {@code HttpRequestCallback::OnResponseBodyCompleted}.
+ *
+ * <p>Must only be called <strong>after</strong> {@link #performRequests} is called on this
+ * handle. Only one of the callback methods on this handle may be called at any given time (but
+ * they may be called from any thread).
+ */
+ // Note: can be overridden in unit tests, to intercept/mock out calls to the native layer.
+ @GuardedBy("this")
+ protected void onResponseCompleted() {
+ HttpClientForNative.onResponseCompleted(nativeHandle);
+ }
+
+ /**
+ * A field that native code uses to associate a native pointer with this object. This field must
+ * never be modified by Java code.
+ */
+ // Note: this field is volatile to ensure that if it is read from a different thread than the
+ // one that wrote to it earlier, the second thread will see the updated value.
+ private volatile long nativeHandle = 0;
+ }
+
+ /**
+ * Creates an {@link HttpRequestHandle} for use with {@link #performRequests}.
+ *
+ * <p>May be called from any thread.
+ *
+ * @param requestProto a serialized {@link JniHttpRequest} proto.
+ */
+ public abstract HttpRequestHandle enqueueRequest(byte[] requestProto);
+
+ /**
+ * Performs the requests corresponding to the given objects, which must be {@link
+ * HttpRequestHandle} instances previously returned by {@link #enqueueRequest}.
+ *
+ * <p>May be called from any thread.
+ *
+ * @return a serialized {@link com.google.rpc.Status} proto indicating success or failure.
+ */
+ // NOTE: The parameter type is an 'Object[]' array, because this makes it easier for the native
+ // code calling this over JNI to construct the array (it can simply look up the 'Object') class.
+ // The Java implementation is expected to downcast the objects in the array to its RequestHandle
+ // implementation class.
+ public abstract byte[] performRequests(Object[] requests);
+
+ /**
+ * Called by native when the client is no longer used and all resources can be released. May be
+ * called multiple times, and from any thread.
+ */
+ @Override
+ public abstract void close();
+
+ // The actual native callback methods, which the HttpRequestHandle class provides wrappers for.
+ // See that class's docs for more info.
+ private static native boolean readRequestBody(
+ long nativeRequestHandle, byte[] buffer, long requestedBytes, int[] actualBytesRead);
+
+ private static native boolean onResponseStarted(long nativeRequestHandle, byte[] responseProto);
+
+ private static native void onResponseError(long nativeRequestHandle, byte[] statusProto);
+
+ private static native boolean onResponseBody(
+ long nativeRequestHandle, byte[] data, int bytesAvailable);
+
+ private static native void onResponseBodyError(long nativeRequestHandle, byte[] statusProto);
+
+ private static native void onResponseCompleted(long nativeRequestHandle);
+}
diff --git a/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNativeImpl.java b/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNativeImpl.java
new file mode 100644
index 0000000..a235d7f
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/http/HttpClientForNativeImpl.java
@@ -0,0 +1,114 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.fcp.client.http;
+
+import com.google.fcp.client.CallFromNativeWrapper;
+import com.google.protobuf.ExtensionRegistryLite;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.rpc.Code;
+import com.google.rpc.Status;
+import java.net.HttpURLConnection;
+import java.util.ArrayList;
+
+/**
+ * An implementation of {@link HttpClientForNativeImpl} that uses {@link HttpURLConnection} for
+ * issuing network requests.
+ */
+public final class HttpClientForNativeImpl extends HttpClientForNative {
+
+ /** Used to bubble up unexpected errors and exceptions. */
+ static final class UncheckedHttpClientForNativeException extends RuntimeException {
+ UncheckedHttpClientForNativeException(String message) {
+ super(message);
+ }
+
+ UncheckedHttpClientForNativeException(String message, Throwable cause) {
+ super(message, cause);
+ }
+ }
+
+ /** A factory for creating an {@link HttpRequestHandleImpl} for a given {@link JniHttpRequest}. */
+ public interface HttpRequestHandleImplFactory {
+ /**
+ * Creates a new request handle, which must be an instance of {@link HttpRequestHandleImpl} or
+ * one of its subclasses. This indirection is used to provide a different subclass in unit
+ * tests.
+ */
+ HttpRequestHandleImpl create(JniHttpRequest request);
+ }
+
+ private final CallFromNativeWrapper callFromNativeWrapper;
+ private final HttpRequestHandleImplFactory requestHandleFactory;
+
+ /**
+ * Creates a new instance, configured with the provided parameters.
+ *
+ * @param callFromNativeWrapper the wrapper to use for all calls that arrive over JNI, to ensure
+ * uncaught exceptions are handled correctly.
+ * @param requestHandleFactory the factory to use to create new {@link HttpRequestHandleImpl} for
+ * a given {@link JniHttpRequest}.
+ */
+ public HttpClientForNativeImpl(
+ CallFromNativeWrapper callFromNativeWrapper,
+ HttpRequestHandleImplFactory requestHandleFactory) {
+ this.callFromNativeWrapper = callFromNativeWrapper;
+ this.requestHandleFactory = requestHandleFactory;
+ }
+
+ @Override
+ public HttpRequestHandleImpl enqueueRequest(byte[] requestProto) {
+ return callFromNativeWrapper.wrapCall(
+ () -> {
+ // Parse the request given to us over JNI.
+ JniHttpRequest request;
+ try {
+ request =
+ JniHttpRequest.parseFrom(requestProto, ExtensionRegistryLite.getEmptyRegistry());
+ } catch (InvalidProtocolBufferException e) {
+ // If parsing failed then the native code did something horribly wrong, just let the
+ // exception bubble up to the unchecked exception handler.
+ throw new UncheckedHttpClientForNativeException("invalid JniHttpRequest", e);
+ }
+ return requestHandleFactory.create(request);
+ });
+ }
+
+ @Override
+ public byte[] performRequests(Object[] requestsParam) {
+ return callFromNativeWrapper.wrapCall(
+ () -> {
+ ArrayList<HttpRequestHandleImpl> handles = new ArrayList<>(requestsParam.length);
+ for (Object requestHandle : requestsParam) {
+ // Note: if this cast fails, then it means that the native layer has somehow passed us a
+ // different object than we returned from enqueueRequest, which would indicate a bug. In
+ // those cases we just let the exception bubble up to create a crash report.
+ HttpRequestHandleImpl handle = (HttpRequestHandleImpl) requestHandle;
+ handles.add(handle);
+ // Handle each request on the ExecutorService (i.e. on background threads).
+ handle.performRequest();
+ }
+ // Wait for each request to finish.
+ for (HttpRequestHandleImpl handle : handles) {
+ handle.waitForRequestCompletion();
+ }
+
+ return Status.newBuilder().setCode(Code.OK_VALUE).build().toByteArray();
+ });
+ }
+
+ @Override
+ public void close() {
+ // Nothing to do here.
+ }
+}
diff --git a/fcp/java_src/main/java/com/google/fcp/client/http/HttpRequestHandleImpl.java b/fcp/java_src/main/java/com/google/fcp/client/http/HttpRequestHandleImpl.java
new file mode 100644
index 0000000..bf12598
--- /dev/null
+++ b/fcp/java_src/main/java/com/google/fcp/client/http/HttpRequestHandleImpl.java
@@ -0,0 +1,1052 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.fcp.client.http;
+
+import static com.google.common.base.Strings.nullToEmpty;
+
+import com.google.common.base.Ascii;
+import com.google.common.io.CountingInputStream;
+import com.google.fcp.client.CallFromNativeWrapper;
+import com.google.fcp.client.http.HttpClientForNative.HttpRequestHandle;
+import com.google.fcp.client.http.HttpClientForNativeImpl.UncheckedHttpClientForNativeException;
+import com.google.rpc.Code;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.CookieHandler;
+import java.net.HttpURLConnection;
+import java.net.ProtocolException;
+import java.net.SocketTimeoutException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.zip.GZIPInputStream;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+
+/**
+ * An implementation of {@link HttpRequestHandle} that uses {@link HttpURLConnection} (the
+ * implementation of which is provided via the {@link HttpURLConnectionFactory} indirection) for
+ * issuing network requests.
+ *
+ * <p>Note: this class is non-final to allow the native callback methods defined in {@link
+ * HttpRequestHandle} to be overridden in unit tests. This is class is otherwise not meant to be
+ * extended and should hence be considered effectively 'final'.
+ */
+public class HttpRequestHandleImpl extends HttpRequestHandle {
+
+ /**
+ * A factory for creating an {@link HttpURLConnection} for a given URI.
+ *
+ * <p>To use the system's default {@link HttpURLConnection} implementation one can simply use
+ * {@code new URL(uri).openConnection()} as the factory implementation.
+ */
+ public interface HttpURLConnectionFactory {
+ HttpURLConnection createUrlConnection(String uri) throws IOException;
+ }
+
+ // The Content-Length header name.
+ private static final String CONTENT_LENGTH_HEADER = "Content-Length";
+ // The Accept-Encoding header name.
+ private static final String ACCEPT_ENCODING_HEADER = "Accept-Encoding";
+ // The Content-Encoding response header name.
+ private static final String CONTENT_ENCODING_HEADER = "Content-Encoding";
+ // The Accept-Encoding and Content-Encoding value to indicate gzip-based compression.
+ private static final String GZIP_ENCODING = "gzip";
+ // The Transfer-Encoding header name.
+ private static final String TRANSFER_ENCODING_HEADER = "Transfer-Encoding";
+ // The Transfer-Encoding value indicating "chunked" encoding.
+ private static final String CHUNKED_TRANSFER_ENCODING = "chunked";
+
+ /** Used to indicate that the request was invalid in some way. */
+ private static final class InvalidHttpRequestException extends Exception {
+ private InvalidHttpRequestException(String message) {
+ super(message);
+ }
+
+ private InvalidHttpRequestException(String message, Throwable cause) {
+ super(message, cause);
+ }
+ }
+
+ /**
+ * Used to indicate that the request was cancelled or encountered an unrecoverable error in the
+ * middle of an operation. The request should be aborted without invoking any further callbacks to
+ * the native layer.
+ */
+ private static final class AbortRequestException extends Exception {}
+
+ private enum State {
+ /**
+ * The state when this object is created, but before it has been passed to {@link
+ * HttpClientForNative#performRequests}.
+ */
+ NOT_STARTED,
+ /**
+ * The state before any response headers have been received. Errors should go to the {@link
+ * #onResponseError} callback.
+ */
+ BEFORE_RESPONSE_HEADERS,
+ /**
+ * The state after any response headers have been received. Errors should go to the {@link
+ * #onResponseBodyError} callback.
+ */
+ AFTER_RESPONSE_HEADERS,
+ /**
+ * The state after the request was finished (either successfully, with an error, or via
+ * cancellation), and no more callbacks should be invoked.
+ */
+ CLOSED
+ }
+
+ private final JniHttpRequest request;
+ private final CallFromNativeWrapper callFromNativeWrapper;
+ private final ExecutorService executorService;
+ private final HttpURLConnectionFactory urlConnectionFactory;
+ private final int connectTimeoutMs;
+ private final int readTimeoutMs;
+ private final int requestBodyChunkSizeBytes;
+ private final int responseBodyChunkSizeBytes;
+ private final int responseBodyGzipBufferSizeBytes;
+ private final boolean callDisconnectWhenCancelled;
+ private final boolean supportAcceptEncodingHeader;
+ private final double estimatedHttp2HeaderCompressionRatio;
+
+ // Until we have an actual connection, this is a no-op.
+ @GuardedBy("this")
+ private Runnable disconnectRunnable = () -> {};
+
+ @GuardedBy("this")
+ private State state = State.NOT_STARTED;
+
+ @GuardedBy("this")
+ @Nullable
+ private Future<?> ongoingWork;
+
+ // These are "volatile" and not synchronized so that they can be read easily from any thread even
+ // if the lock is currently held. They're only incremented from a single thread, so their being
+ // volatile is sufficient to safely increment/update them.
+ private volatile long sentHeaderBytes = 0;
+ private volatile long sentBodyBytes = 0;
+ private volatile long receivedHeaderBytes = 0;
+ private volatile long receivedBodyBytes = 0;
+ private volatile boolean requestUsedHttp2Heuristic = false;
+
+ /**
+ * Creates a new handle representing a single request. See {@link HttpClientForNativeImpl} for a
+ * description of the parameters.
+ *
+ * @param request the {@link JniHttpRequest} the handle is being created for.
+ * @param callFromNativeWrapper the wrapper to use for all calls that arrive over JNI, to ensure
+ * uncaught exceptions are handled correctly.
+ * @param executorService the {@link ExecutorService} to use for background work.
+ * @param urlConnectionFactory the factory to use to instance new {@link HttpURLConnection}s.
+ * @param connectTimeoutMs the value to use with {@link HttpURLConnection#setConnectTimeout(int)}.
+ * If this is -1 then {@code setConnectTimeout} will not be called at all.
+ * @param readTimeoutMs the value to use with {@link HttpURLConnection#setReadTimeout(int)}. If
+ * this is -1 then {@code setReadTimeout} will not be called at all.
+ * <p>If {@code getInputStream().read(...)} or other methods like {@code getResponseCode()}
+ * take longer than this amount of time, they will throw a {@link
+ * java.net.SocketTimeoutException} and request will fail. Setting it to -1 will result in an
+ * infinite timeout being used.
+ * <p>Note that this only affects the reads of the response body, and does not affect the
+ * writes of the request body.
+ * @param requestBodyChunkSizeBytes the value to use with {@link
+ * HttpURLConnection#setChunkedStreamingMode(int)}, when chunked transfer encoding is used to
+ * upload request bodies. This also determines the amount of request body data we'll read from
+ * the native layer before pushing it onto the network's {@link java.io.OutputStream}.
+ * @param responseBodyChunkSizeBytes determines the amount of response body data we'll try to read
+ * from the network's {@link java.io.InputStream} (or from the {@link
+ * java.util.zip.GZIPInputStream} wrapping the network's {@code InputStream}) before pushing
+ * it to the native layer.
+ * @param responseBodyGzipBufferSizeBytes determines the amount of response body data the {@link
+ * java.util.zip.GZIPInputStream} wrapping the network's {@link java.io.InputStream} will try
+ * to read before starting another round of decompression (in case we receive a compressed
+ * response body that we need to decompress on the fly).
+ * @param callDisconnectWhenCancelled whether to call {@link HttpURLConnection#disconnect()} (from
+ * a different thread than the request is being run on) when a request gets cancelled. See
+ * note in {@link HttpRequestHandleImpl#close()}.
+ * @param supportAcceptEncodingHeader whether to set the "Accept-Encoding" request header by
+ * default. Some {@link HttpURLConnection} implementations don't allow setting it, and this
+ * flag allows turning that behavior off. When this setting is false, the assumption is that
+ * the implementation at the very least sets "Accept-Encoding: gzip" (as required by the C++
+ * `HttpClient` contract).
+ * @param estimatedHttp2HeaderCompressionRatio the compression ratio to account for in the
+ * calculation of sent/received bytes estimates for the header data, in case HTTP/2 is used
+ * for the request. HTTP/2 supports HPACK, and hence counting the header data in uncompressed
+ * form likely results in over-estimates. This only affects requests that are determined to
+ * have used HTTP/2, which is based on the somewhat fragile heuristic of whether {@link
+ * HttpURLConnection#getResponseMessage()} is empty (since HTTP/2 does not support status line
+ * 'reason phrases').
+ */
+ public HttpRequestHandleImpl(
+ JniHttpRequest request,
+ CallFromNativeWrapper callFromNativeWrapper,
+ ExecutorService executorService,
+ HttpURLConnectionFactory urlConnectionFactory,
+ int connectTimeoutMs,
+ int readTimeoutMs,
+ int requestBodyChunkSizeBytes,
+ int responseBodyChunkSizeBytes,
+ int responseBodyGzipBufferSizeBytes,
+ boolean callDisconnectWhenCancelled,
+ boolean supportAcceptEncodingHeader,
+ double estimatedHttp2HeaderCompressionRatio) {
+ this.request = request;
+ this.callFromNativeWrapper = callFromNativeWrapper;
+ this.executorService = executorService;
+ this.urlConnectionFactory = urlConnectionFactory;
+ this.connectTimeoutMs = connectTimeoutMs;
+ this.readTimeoutMs = readTimeoutMs;
+ this.requestBodyChunkSizeBytes = requestBodyChunkSizeBytes;
+ this.responseBodyChunkSizeBytes = responseBodyChunkSizeBytes;
+ this.responseBodyGzipBufferSizeBytes = responseBodyGzipBufferSizeBytes;
+ this.callDisconnectWhenCancelled = callDisconnectWhenCancelled;
+ this.supportAcceptEncodingHeader = supportAcceptEncodingHeader;
+ this.estimatedHttp2HeaderCompressionRatio = estimatedHttp2HeaderCompressionRatio;
+ }
+
+ @Override
+ public final void close() {
+ // This method is called when the request should be cancelled and/or is otherwise not
+ // needed anymore. It may be called from any thread.
+ callFromNativeWrapper.wrapVoidCall(
+ () -> {
+ synchronized (this) {
+ // If the request was already closed, then this means that the request was either
+ // already interrupted before, or that the request completed successfully. In both
+ // cases there's nothing left to do for us.
+ if (state == State.CLOSED) {
+ return;
+ }
+ // Otherwise, this indicates that the request is being *cancelled* while it was still
+ // running.
+
+ // We mark the connection closed, to prevent any further callbacks to the native layer
+ // from being issued. We do this *before* invoking the callback, just in case our
+ // invoking the callback causes this close() method to be invoked again by the native
+ // layer (we wouldn't want to enter an infinite loop)
+ State oldState = state;
+ state = State.CLOSED;
+ // We signal the closure/cancellation to the native layer right away, using the
+ // appropriate callback for the state we were in.
+ doError(Code.CANCELLED, "request cancelled via close()", oldState);
+ // We unblock the blocked thread on which HttpClientForNativeImpl#performRequests was
+ // called (although that thread may be blocked on other, still-pending requests).
+ if (ongoingWork != null) {
+ ongoingWork.cancel(/* mayInterruptIfRunning=*/ true);
+ }
+
+ // Note that HttpURLConnection isn't documented to be thread safe, and hence it isn't
+ // 100% clear that calling its #disconnect() method from a different thread (as we are
+ // about to do here) will correctly either. However, it seems to be the only way to
+ // interrupt an ongoing request when it is blocked writing to or reading from the
+ // network socket.
+ //
+ // At least on Android the OkHttp-based implementation does seem to be thread safe (it
+ // uses OkHttp's HttpEngine.cancel() method, which is thread safe). The JDK
+ // implementation seems to not be thread safe (but behaves well enough?). The
+ // callDisconnectWhenCancelled parameter can be used to control this behavior.
+ if (callDisconnectWhenCancelled) {
+ disconnectRunnable.run();
+ }
+
+ // Handling cancellations/closures this way ensures that the native code is unblocked
+ // even before the network requests have been fully aborted. Any still-pending HTTP
+ // connections will be cleaned up in their corresponding background threads.
+ }
+ });
+ }
+
+ @Override
+ public byte[] getTotalSentReceivedBytes() {
+ double headerCompressionRatio =
+ requestUsedHttp2Heuristic ? estimatedHttp2HeaderCompressionRatio : 1.0;
+ // Note that this estimate of sent/received bytes is not necessarily monotonically increasing:
+ // - We'll initially estimate the amount of received response body bytes based on the bytes we
+ // observe in the response InputStream (which may count the uncompressed response bytes). This
+ // will account, as best as possible, for how much has data been received so far (incl. in
+ // case the request gets cancelled mid-flight), although it may be an over-estimate due to not
+ // accounting for response body compression (depending on the HttpURLConnection
+ // implementation, e.g. in case of Cronet's).
+ // - Once the request has completed successfully, we'll estimate the received response body
+ // bytes based on the Content-Length response header, if there was one. This gives us a chance
+ // to revise our estimate down to a more accurate value, if the HttpURLConnection
+ // implementation exposes the original Content-Length header to us (e.g. in the case of
+ // Cronet).
+ // - Once we know from the response headers that the request used HTTP/2, we'll apply the header
+ // compression ratio. But before we know that, we don't apply it.
+ //
+ // Note that the estimates we provide here also won't take into account various other sources of
+ // network usage: the bytes transmitted to establish TLS channels, request/responses for
+ // followed HTTP redirects, HTTP/1.1-to-HTTP/2 upgrades etc.
+ return JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes((long) (sentHeaderBytes * headerCompressionRatio) + sentBodyBytes)
+ .setReceivedBytes((long) (receivedHeaderBytes * headerCompressionRatio) + receivedBodyBytes)
+ .build()
+ .toByteArray();
+ }
+
+ final synchronized void performRequest() {
+ if (state != State.NOT_STARTED) {
+ throw new IllegalStateException("must not call perform() more than once");
+ }
+ state = State.BEFORE_RESPONSE_HEADERS;
+ ongoingWork = executorService.submit(this::runRequestToCompletion);
+ }
+
+ final void waitForRequestCompletion() {
+ // Get a copy of the Future, if it is set. Then call .get() without holding the lock.
+ Future<?> localOngoingWork;
+ synchronized (this) {
+ if (ongoingWork == null) {
+ throw new IllegalStateException("must not call waitForCompletion() before perform()");
+ }
+ localOngoingWork = ongoingWork;
+ }
+ try {
+ localOngoingWork.get();
+ } catch (ExecutionException e) {
+ // This shouldn't happen, since the run(...) method shouldn't throw any exceptions. If one
+ // does get thrown, it is a RuntimeException or Error, in which case we'll just let it bubble
+ // up to the uncaught exception handler.
+ throw new UncheckedHttpClientForNativeException("unexpected exception", e);
+ } catch (InterruptedException e) {
+ // This shouldn't happen, since no one should be interrupting the calling thread.
+ throw new UncheckedHttpClientForNativeException("unexpected interruption", e);
+ } catch (CancellationException e) {
+ // Do nothing. This will happen when a request gets cancelled in the middle of execution, but
+ // in those cases there's nothing left for us to do, and we should just gracefully return.
+ // This will allow #performRequests(...) to be unblocked, while the background thread may
+ // still be cleaning up some resources.
+ }
+ }
+
+ /** Convenience method for checking for the closed state, in a synchronized fashion. */
+ private synchronized boolean isClosed() {
+ return state == State.CLOSED;
+ }
+
+ /**
+ * Convenience method for checking for the closed state in a synchronized fashion, throwing an
+ * {@link AbortRequestException} if the request is closed.
+ */
+ private synchronized void checkClosed() throws AbortRequestException {
+ if (state == State.CLOSED) {
+ throw new AbortRequestException();
+ }
+ }
+
+ /**
+ * Calls either the {@link #onResponseError} or {@link #onResponseBodyError} callback, including
+ * the originating Java exception description in the status message. Which callback is used
+ * depends on the current {@link #state}.
+ */
+ private synchronized void doError(String message, Exception e) {
+ // We mark the state as CLOSED, since no more callbacks should be invoked after signaling an
+ // error. We do this before issuing the callback to the native layer, to ensure that if that
+ // call results in another call to the Java layer, we don't emit any callbacks anymore.
+ State oldState = state;
+ state = State.CLOSED;
+ Code code = Code.UNAVAILABLE;
+ if (e instanceof SocketTimeoutException) {
+ code = Code.DEADLINE_EXCEEDED;
+ } else if (e instanceof InvalidHttpRequestException) {
+ code = Code.INVALID_ARGUMENT;
+ }
+ doError(code, String.format("%s (%s)", message, e), oldState);
+ }
+
+ @GuardedBy("this")
+ private void doError(Code code, String message, State state) {
+ byte[] error =
+ Status.newBuilder().setCode(code.getNumber()).setMessage(message).build().toByteArray();
+ switch (state) {
+ case BEFORE_RESPONSE_HEADERS:
+ onResponseError(error);
+ break;
+ case AFTER_RESPONSE_HEADERS:
+ onResponseBodyError(error);
+ break;
+ case NOT_STARTED:
+ case CLOSED:
+ // If the request had already been closed, or if it hadn't been passed to {@link
+ // HttpClientForNative#performRequests} yet, then we shouldn't issue any (further)
+ // callbacks.
+ break;
+ }
+ }
+
+ /** Calls the {@link #readRequestBody} callback, but only if the request isn't closed yet. */
+ private synchronized void doReadRequestBody(
+ byte[] buffer, long requestedBytes, int[] actualBytesRead) throws AbortRequestException {
+ // If the request has already been closed, then we shouldn't issue any further callbacks.
+ checkClosed();
+ checkCallToNativeResult(readRequestBody(buffer, requestedBytes, actualBytesRead));
+ }
+
+ /** Calls the {@link #onResponseStarted} callback, but only if the request isn't closed yet. */
+ private synchronized void doOnResponseStarted(byte[] responseProto) throws AbortRequestException {
+ // Ensure that we call the onResponseStarted callback *and* update the object state as a
+ // single atomic transaction, so that any errors/cancellations occurring before or after
+ // this block result in the correct error callback being called.
+
+ // If the request has already been closed, then we shouldn't issue any further callbacks.
+ checkClosed();
+ // After this point, any errors we signal to the native layer should go through
+ // 'onResponseBodyError', so we update the object state. We do this before invoking the
+ // callback, to ensure that if our call into the native layer causes a call back into Java
+ // that then triggers an error callback, we invoke the right one.
+ state = State.AFTER_RESPONSE_HEADERS;
+ checkCallToNativeResult(onResponseStarted(responseProto));
+ }
+
+ /** Calls the {@link #onResponseBody} callback, but only if the request isn't closed yet. */
+ private synchronized void doOnResponseBody(byte[] buffer, int bytesAvailable)
+ throws AbortRequestException {
+ // If the request has already been closed, then we shouldn't issue any further callbacks.
+ checkClosed();
+ checkCallToNativeResult(onResponseBody(buffer, bytesAvailable));
+ }
+
+ /** Calls the {@link #onResponseCompleted} callback, but only if the request isn't closed yet. */
+ private synchronized void doOnResponseCompleted(long originalContentLengthHeader) {
+ // If the request has already been closed, then we shouldn't issue any further callbacks.
+ if (state == State.CLOSED) {
+ return;
+ }
+ // If we did receive a Content-Length header, then once we've fully completed the request, we
+ // can use it to estimate the total received bytes (and it will be the most accurate estimate
+ // available to us).
+ //
+ // E.g. the Cronet HttpURLConnection implementation will return the original Content-Length
+ // header, even though it decompresses any response body Content-Encoding for us and doesn't let
+ // use see the original compressed bytes.
+ //
+ // If there was no Content-Length header at all, then we must go by our own calculation of the
+ // number of received bytes (i.e. based on the bytes we observed in the response InputStream).
+ if (originalContentLengthHeader > -1) {
+ receivedBodyBytes = originalContentLengthHeader;
+ }
+ // If the request hadn't already been closed, it should be considered closed now (since we're
+ // about to call the final callback).
+ state = State.CLOSED;
+ onResponseCompleted();
+ }
+
+ /**
+ * Transitions to the CLOSED {@link #state} and throws an AbortRequestException, if the given
+ * result from a call to the native layer is false.
+ */
+ @GuardedBy("this")
+ private void checkCallToNativeResult(boolean result) throws AbortRequestException {
+ if (!result) {
+ // If any call to the native layer fails, then we shouldn't invoke any more callbacks.
+ state = State.CLOSED;
+ throw new AbortRequestException();
+ }
+ }
+
+ private void runRequestToCompletion() {
+ // If we're already closed by the time the background thread started executing this method,
+ // there's nothing left to do for us.
+ if (isClosed()) {
+ return;
+ }
+
+ // Create the HttpURLConnection instance (this usually doesn't do any real work yet, even
+ // though it is declared to throw IOException).
+ HttpURLConnection connection;
+ try {
+ connection = urlConnectionFactory.createUrlConnection(request.getUri());
+ } catch (IOException e) {
+ doError("failure during connection creation", e);
+ return;
+ }
+
+ // Register a runnable that will allow us to cancel an ongoing request from a different
+ // thread.
+ synchronized (this) {
+ disconnectRunnable = connection::disconnect;
+ }
+
+ // From this point on we should call connection.disconnect() at the end of this method
+ // invocation, *except* when the request reaches a successful end (see comment below).
+ boolean doDisconnect = true;
+ try {
+ // Set and validate connection parameters (timeouts, HTTP method, request body, etc.).
+ String acceptEncodingHeader = findRequestHeader(ACCEPT_ENCODING_HEADER);
+ long requestContentLength;
+ try {
+ requestContentLength = parseContentLengthHeader(findRequestHeader(CONTENT_LENGTH_HEADER));
+ configureConnection(connection, requestContentLength, acceptEncodingHeader);
+ } catch (InvalidHttpRequestException e) {
+ doError("invalid request", e);
+ return;
+ }
+
+ // If there is a request body then start sending it. This is usually when the actual network
+ // connection is first established (subject to the #getRequestConnectTimeoutMs).
+ if (request.getHasBody()) {
+ try {
+ sendRequestBody(connection, requestContentLength);
+ } catch (IOException e) {
+ doError("failure during request body send", e);
+ return;
+ }
+ }
+
+ // Check one more time, before waiting on the response headers, if the request has already
+ // been cancelled (to avoid starting any blocking network IO we can't easily interrupt).
+ checkClosed();
+
+ // If there was no request body, then this will establish the connection (subject to the
+ // #getRequestConnectTimeoutMs). If there was a request body, this will be a noop.
+ try {
+ connection.connect();
+ } catch (IOException e) {
+ doError("failure during connect", e);
+ return;
+ }
+
+ // Wait for the request headers to be received (subject to #getRequestReadTimeOutMs).
+ ResponseHeadersWithMetadata response;
+ try {
+ response = receiveResponseHeaders(connection, acceptEncodingHeader);
+ } catch (IOException e) {
+ doError("failure during response header receive", e);
+ return;
+ }
+ doOnResponseStarted(response.responseProto.toByteArray());
+
+ try {
+ receiveResponseBody(connection, response.shouldDecodeGzip);
+ } catch (IOException e) {
+ doError("failure during response body receive", e);
+ return;
+ }
+ // Note that we purposely don't call connection.disconnect() once we reach this point, since
+ // we will have gracefully finished the request (e.g. by having readall of its response data),
+ // and this means that the underlying socket/connection may be reused for other connections to
+ // the same endpoint. Calling connection.disconnect() would prevent such connection reuse,
+ // which can be detrimental to the overall throughput. The underlying HttpURLConnection
+ // implementation will eventually reap the socket if doesn't end up being reused within a set
+ // amount of time.
+ doDisconnect = false;
+ doOnResponseCompleted(response.originalContentLengthHeader);
+ } catch (AbortRequestException e) {
+ // Nothing left for us to do.
+ } finally {
+ if (doDisconnect) {
+ connection.disconnect();
+ }
+ // At this point we will either have reached the end of the request successfully (in which
+ // case doOnResponseCompleted will have updated the object state to CLOSED), or we will have
+ // hit a AbortRequestException (in which case the state will already have been set to CLOSED),
+ // or we will have signaled an error (which will have set the state to CLOSED as well).
+ // Hence we don't have to modify the object state here anymore.
+ }
+ }
+
+ /** Returns the HTTP request method we should use, as a string. */
+ private String getRequestMethod() {
+ switch (request.getMethod()) {
+ case HTTP_METHOD_HEAD:
+ return "HEAD";
+ case HTTP_METHOD_GET:
+ return "GET";
+ case HTTP_METHOD_POST:
+ return "POST";
+ case HTTP_METHOD_PUT:
+ return "PUT";
+ case HTTP_METHOD_PATCH:
+ return "PATCH";
+ case HTTP_METHOD_DELETE:
+ return "DELETE";
+ default:
+ // This shouldn't happen, as it would indicate a bug in either this code or the native C++
+ // code calling us.
+ throw new UncheckedHttpClientForNativeException(
+ String.format("unexpected method: %s", request.getMethod().getNumber()));
+ }
+ }
+
+ /**
+ * Finds the given header (case-insensitively) and returns its value, if there is one. Otherwise
+ * returns null.
+ */
+ @Nullable
+ private String findRequestHeader(String name) {
+ for (JniHttpHeader header : request.getExtraHeadersList()) {
+ if (Ascii.equalsIgnoreCase(name, header.getName())) {
+ return header.getValue();
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Tries to parse a "Content-Length" header value and returns it as a long, if it isn't null.
+ * Otherwise returns -1.
+ */
+ private long parseContentLengthHeader(String contentLengthHeader)
+ throws InvalidHttpRequestException {
+ if (contentLengthHeader == null) {
+ return -1;
+ }
+ try {
+ return Long.parseLong(contentLengthHeader);
+ } catch (NumberFormatException e) {
+ throw new InvalidHttpRequestException(
+ String.format("invalid Content-Length request header value: %s", contentLengthHeader), e);
+ }
+ }
+
+ /**
+ * Configures the {@link HttpURLConnection} object before it is used to establish the actual
+ * network connection.
+ */
+ @SuppressWarnings("NonAtomicVolatileUpdate")
+ private void configureConnection(
+ HttpURLConnection connection,
+ long requestContentLength,
+ @Nullable String acceptEncodingHeader)
+ throws InvalidHttpRequestException {
+ String requestMethod = getRequestMethod();
+ try {
+ connection.setRequestMethod(requestMethod);
+ } catch (ProtocolException e) {
+ // This should never happen, as we take care to only call this method with appropriate
+ // parameters.
+ throw new UncheckedHttpClientForNativeException("unexpected ProtocolException", e);
+ }
+ for (JniHttpHeader header : request.getExtraHeadersList()) {
+ // Note that we use addRequestProperty rather than setRequestProperty, to ensure that
+ // request headers that occur multiple times are properly specified (rather than just the
+ // last value being specified).
+ connection.addRequestProperty(header.getName(), header.getValue());
+ }
+ // The C++ `HttpClient` contract requires us to set the Accept-Encoding header, if there isn't
+ // one provided by the native layer. Note that on Android the HttpURLConnection implementation
+ // does this by default, but the JDK's implementation does not. Note that by setting this header
+ // we must also handle the response InputStream data correctly (by inflating it, if the
+ // Content-Encoding indicates the data is compressed).
+ // Some HttpURLConnection implementations (such as Cronet's) don't allow setting this header,
+ // and print out a warning if you do. The supportAcceptEncodingHeader allows turning this
+ // behavior off (thereby avoiding the warning being logged).
+ if (supportAcceptEncodingHeader && acceptEncodingHeader == null) {
+ connection.setRequestProperty(ACCEPT_ENCODING_HEADER, GZIP_ENCODING);
+ } else if (!supportAcceptEncodingHeader && acceptEncodingHeader != null) {
+ throw new InvalidHttpRequestException("cannot support Accept-Encoding header");
+ }
+
+ if (connectTimeoutMs >= 0) {
+ connection.setConnectTimeout(connectTimeoutMs);
+ }
+ if (readTimeoutMs >= 0) {
+ connection.setReadTimeout(readTimeoutMs);
+ }
+
+ connection.setDoInput(true);
+ if (request.getHasBody()) {
+ connection.setDoOutput(true);
+ if (requestContentLength >= 0) {
+ // If the Content-Length header is set then we don't have to use Transfer-Encoding, since
+ // we know the size of the request body ahead of time.
+ connection.setFixedLengthStreamingMode(requestContentLength);
+ } else {
+ // If we don't know the size of the request body ahead of time, we should turn on
+ // "Transfer-Encoding: chunked" using the following method.
+ connection.setChunkedStreamingMode(requestBodyChunkSizeBytes);
+ }
+ } else if (requestContentLength > 0) {
+ // If getHasBody() is false but a non-zero Content-Length header is set, then something went
+ // wrong in the native layer.
+ throw new InvalidHttpRequestException("Content-Length > 0 but no request body available");
+ }
+
+ // As per the interface contract in C++'s http_client.h, we should not use any caches.
+ connection.setUseCaches(false);
+ // As per the interface contract in C++'s http_client.h, we should follow redirects.
+ connection.setInstanceFollowRedirects(true);
+
+ // Ensure that no system-wide CookieHandler was installed, since we must not store any cookies.
+ if (CookieHandler.getDefault() != null) {
+ throw new IllegalStateException("must not set a CookieHandler");
+ }
+
+ // Count the request headers as part of the sent bytes. We do this before we actually open the
+ // connection, so that if the connection fails to be established we still account for the
+ // possibly already-transmitted data.
+ //
+ // Note that if the implementation uses HTTP2 with HPACK header compression this could lead to
+ // an overestimation of the total bytes sent. The estimatedHttp2HeaderCompressionRatio parameter
+ // can be used to account for this heuristically.
+ //
+ // If HTTP/2 is used, then some of our estimates will also be overestimates since we assume that
+ // headers are terminated by \r\n lines etc., while HTTP/2 generally represents headers more
+ // compactly. To avoid complicating things too much, we don't account for that.
+ //
+ // Aside from not accounting for HTTP/2 and header compression, some request headers may also be
+ // set by the HttpUrlConnection implementation which we cannot observe here, and hence we won't
+ // be counting those either. Hence, this number could be both an over or under-estimate, and
+ // should really be considered a best-effort estimate.
+ //
+ // Note that while it might seem we could use getRequestProperties() to get the actual request
+ // headers (incl. implementation-specified ones), this isn't actually the case for most
+ // HttpURLConnection implementations (and some implementations don't return anything from
+ // getRequestProperties(), even if we've already called addRequestProperty()).
+ // First, account for the HTTP request status line.
+ sentHeaderBytes +=
+ requestMethod.length()
+ + " ".length()
+ + request.getUri().length()
+ + " HTTP/1.1\r\n".length();
+ // Then account for each header we know is will be included.
+ for (JniHttpHeader header : request.getExtraHeadersList()) {
+ // Each entry should count the lengths of the header name + header value (rather than only
+ // counting the header name length once), since duplicated headers are likely to be sent in
+ // separate header lines (rather than being coalesced into a single header line by the
+ // HttpURLConnection implementation).
+ sentHeaderBytes +=
+ header.getName().length() + ": ".length() + header.getValue().length() + "\r\n".length();
+ }
+ // Account for the \r\n characters at the end of the request headers.
+ sentHeaderBytes += "\r\n".length();
+ }
+
+ /**
+ * Sends the request body (received from the native layer via the JNI callbacks) to the server
+ * after establishing a connection (blocking until all request body data has been written to the
+ * network, or an error occurs).
+ *
+ * @param connection the HttpURLConnection to send the request body for.
+ * @param requestContentLength the length of the request body if it is known ahead of time, or -1
+ * if the request body's length is not known ahead of time.
+ */
+ @SuppressWarnings("NonAtomicVolatileUpdate")
+ private void sendRequestBody(HttpURLConnection connection, long requestContentLength)
+ throws IOException, AbortRequestException {
+ // Check one more time, before issuing the request, if it's already been cancelled (to avoid
+ // starting any blocking network IO we can't easily interrupt).
+ checkClosed();
+
+ // Note that we don't wrap the OutputStream in a BufferedOutputStream, since we already write
+ // data to the unbuffered OutputStream in fairly large chunks at a time, so adding another
+ // buffering layer in between isn't helpful.
+ //
+ // The call to getOutputStream or OutputStream.write() is what will establish the actual
+ // network connection.
+ try (OutputStream outputStream = connection.getOutputStream()) {
+ // Allocate a buffer for reading the request body data into via JNI.
+ byte[] buffer = new byte[calculateRequestBodyBufferSize(requestContentLength)];
+ // Allocate an array for the native layer to write the number of actually read bytes into.
+ // Because arrays are mutable, this effectively serves as an 'output parameter', allowing the
+ // native code to return this bit of information in addition to its primary success/failure
+ // return value.
+ int[] actualBytesRead = new int[1];
+ while (true) {
+ // Read data from native. This may be very fast, but may also block on disk IO and/or
+ // on-the-fly payload compression.
+ doReadRequestBody(buffer, buffer.length, actualBytesRead);
+ // The native layer signals the end of the request body data by returning -1 as the
+ // "actually read bytes" value (this corresponds to C++'s `HttpRequest::ReadBody` returning
+ // `OUT_OF_RANGE`).
+ if (actualBytesRead[0] == -1) {
+ // End of data reached (successfully).
+ break;
+ }
+ // Otherwise, the native layer is required to have read at least 1 byte into our buffer at
+ // this point (and hence actualBytesRead[0] will be >= 1).
+
+ // Account for the data we're about to send in our 'sent bytes' stats. We do this before we
+ // write it to the output stream (so that this over rather than under-estimates the number,
+ // in case we get interrupted mid-write).
+ sentBodyBytes += actualBytesRead[0];
+
+ // Write the data from the native layer to the network socket.
+ outputStream.write(buffer, 0, actualBytesRead[0]);
+
+ // Before trying to read another chunk of data, make sure that the request hasn't been
+ // aborted yet.
+ checkClosed();
+ }
+ // Flush the stream before we close it, for good measure.
+ outputStream.flush();
+ }
+ // We're done uploading.
+ }
+
+ private int calculateRequestBodyBufferSize(long requestContentLength) {
+ // If the request body size is known ahead of time, and is smaller than the chunk size we
+ // otherwise would use, then we allocate a buffer of just the exact size we need. If the
+ // request body size is unknown or too large, then we use a set chunk buffer size to read one
+ // chunk at a time.
+ if (requestContentLength > 0 && requestContentLength < requestBodyChunkSizeBytes) {
+ // This cast from long to int is safe, because we know requestContentLength is smaller than
+ // the int bufferSize at this point.
+ return (int) requestContentLength;
+ }
+ return requestBodyChunkSizeBytes;
+ }
+
+ private static final class ResponseHeadersWithMetadata {
+ private final JniHttpResponse responseProto;
+ private final boolean shouldDecodeGzip;
+ private final long originalContentLengthHeader;
+
+ ResponseHeadersWithMetadata(
+ JniHttpResponse responseProto, boolean shouldDecodeGzip, long originalContentLengthHeader) {
+ this.responseProto = responseProto;
+ this.shouldDecodeGzip = shouldDecodeGzip;
+ this.originalContentLengthHeader = originalContentLengthHeader;
+ }
+ }
+
+ /**
+ * Receives the response headers from the server (blocking until that data is available, or an
+ * error occurs), and passes it to the native layer via the JNI callbacks.
+ */
+ @SuppressWarnings("NonAtomicVolatileUpdate")
+ private ResponseHeadersWithMetadata receiveResponseHeaders(
+ HttpURLConnection connection, String originalAcceptEncodingHeader) throws IOException {
+ // This call will block until the response headers are received (or throw if an error occurred
+ // before headers were received, or if no response header data is received before
+ // #getRequestReadTimeOutMs).
+ int responseCode = connection.getResponseCode();
+
+ // If the original headers we received from the native layer did not include an Accept-Encoding
+ // header, then *if we specified an "Accept-Encoding" header ourselves and subsequently received
+ // an encoded response body* we should a) remove the Content-Encoding header (since they refer
+ // to the encoded data, not the decoded data we will return to the native layer), and b) decode
+ // the response body data before returning it to the native layer. Note that if we did receive
+ // an "Accept-Encoding" header (even if it specified "gzip"), we must not auto-decode the
+ // response body and we should also leave the headers alone.
+ boolean shouldDecodeGzip = false;
+ if (supportAcceptEncodingHeader && originalAcceptEncodingHeader == null) {
+ // We need to strip the headers, if the body is encoded. Determine if it is encoded first.
+ for (Map.Entry<String, List<String>> header : connection.getHeaderFields().entrySet()) {
+ List<String> headerValues = header.getValue();
+ if (Ascii.equalsIgnoreCase(CONTENT_ENCODING_HEADER, nullToEmpty(header.getKey()))
+ && !headerValues.isEmpty()
+ && Ascii.equalsIgnoreCase(GZIP_ENCODING, nullToEmpty(headerValues.get(0)))) {
+ shouldDecodeGzip = true;
+ break;
+ }
+ }
+ }
+
+ JniHttpResponse.Builder response = JniHttpResponse.newBuilder();
+ response.setCode(responseCode);
+
+ // Account for the response status line in the 'received bytes' stats.
+ String responseMessage = connection.getResponseMessage();
+ // Note that responseMessage could be null or empty if an HTTP/2 implementation is used (since
+ // HTTP/2 doesn't have 'reason phrases' in the status line anymore, only codes).
+ responseMessage = nullToEmpty(responseMessage);
+ receivedHeaderBytes += "HTTP/1.1 XXX ".length() + responseMessage.length() + "\r\n".length();
+ // Add two bytes to account for the \r\n at the end of the response headers.
+ receivedHeaderBytes += "\r\n".length();
+
+ // If the response message was empty, then we assume that the request used HTTP/2. This is a
+ // flawed heuristic, but the best we have available.
+ requestUsedHttp2Heuristic = responseMessage.isEmpty();
+
+ // Now let's process the response headers.
+ long receivedContentLength = -1;
+ for (Map.Entry<String, List<String>> header : connection.getHeaderFields().entrySet()) {
+ // First, let's account for the received headers in our 'received bytes' stats. See note about
+ // counting bytes for request headers above, which applies similarly to response
+ // headers.
+ //
+ // Note that for some HttpURLConnection implementations the HTTP response status line may be
+ // included in the getHeadersField() result under the null header key, while others don't
+ // include it at all. We just skip counting the status line from getHeaderFields() sinec we
+ // already accounted for it above.
+ if (header.getKey() == null) {
+ continue;
+ }
+ // Count the bytes for all the headers (including accounting for the colon, space, and
+ // newlines that would've been sent over the wire).
+ for (String headerValue : header.getValue()) {
+ receivedHeaderBytes +=
+ header.getKey() == null ? 0 : (header.getKey().length() + ": ".length());
+ receivedHeaderBytes += headerValue == null ? 0 : headerValue.length();
+ // Account for the \r\n chars at the end of the header.
+ receivedHeaderBytes += "\r\n".length();
+ }
+
+ // Now let's skip headers we shouldn't return to the C++ layer.
+ //
+ // The HttpURLConnection implementation generally unchunks response bodies that used
+ // "Transfer-Encoding: chunked". However, while Android's implementation also then removes the
+ // "Transfer-Encoding" header, the JDK implementation does not. Since the HttpClient contract
+ // requires us to remove that header, we explicitly filter it out here.
+ //
+ // Finally, if the response will automatically be gzip-decoded by us, then we must redact any
+ // Content-Encoding header too.
+ if ((Ascii.equalsIgnoreCase(TRANSFER_ENCODING_HEADER, header.getKey())
+ && header.getValue().size() == 1
+ && Ascii.equalsIgnoreCase(
+ CHUNKED_TRANSFER_ENCODING, nullToEmpty(header.getValue().get(0))))
+ || (shouldDecodeGzip
+ && Ascii.equalsIgnoreCase(CONTENT_ENCODING_HEADER, header.getKey()))) {
+ continue;
+ }
+ // Also, the "Content-Length" value returned by HttpURLConnection may or may not correspond to
+ // the response body data we will see via + " - " + receivedBodyBytesgetInputStream() (e.g.
+ // it may reflect the length of
+ // the previously compressed data, even if the data is already decompressed for us when we
+ // read it from the InputStream). Hence, we ignore it as well. We do so even though the C++
+ // `HttpClient` asks us to leave it unredacted, because its value cannot be interpreted
+ // consistently. However, if the "Content-Length" header *is* available, then we do use it to
+ // estimate the network bytes we've received (but only once the request has completed
+ // successfully).
+ if (Ascii.equalsIgnoreCase(CONTENT_LENGTH_HEADER, header.getKey())) {
+ if (header.getValue().size() == 1) {
+ try {
+ receivedContentLength = Long.parseLong(header.getValue().get(0));
+ } catch (NumberFormatException e) {
+ // ignore
+ }
+ }
+ continue;
+ }
+
+ // Pass the remaining headers to the C++ layer.
+ for (String headerValue : header.getValue()) {
+ response.addHeaders(
+ JniHttpHeader.newBuilder().setName(header.getKey()).setValue(headerValue));
+ }
+ }
+
+ // If we receive a positive cache hit (i.e. HTTP_NOT_MODIFIED), then the response will not have
+ // a body even though the "Content-Encoding" header may still be set. In such cases we shouldn't
+ // try pass the InputStream to a GZIPInputStream (in the receiveResponseBody function below),
+ // since GZIPInputStream would crash on the 0-byte stream. Note that while we disable any
+ // HttpURLConnection-level cache explicitly in this file, it's still possible that the native
+ // layer itself implements a cache, which could result in us receiving HTTP_NOT_MODIFIED
+ // responses after all, and we should handle those correctly.
+ shouldDecodeGzip =
+ shouldDecodeGzip && connection.getResponseCode() != HttpURLConnection.HTTP_NOT_MODIFIED;
+ return new ResponseHeadersWithMetadata(
+ response.build(), shouldDecodeGzip, receivedContentLength);
+ }
+
+ /**
+ * Receives the response body from the server and passes it to the native layer via the JNI
+ * callbacks (blocking until all response body data has been received, or an error occurs).
+ */
+ @SuppressWarnings("NonAtomicVolatileUpdate")
+ private void receiveResponseBody(HttpURLConnection connection, boolean shouldDecodeGzip)
+ throws IOException, AbortRequestException {
+ // Check one more time, before blocking on the InputStream, if it request has already been
+ // cancelled (to avoid starting any blocking network IO we can't easily interrupt).
+ checkClosed();
+
+ try (CountingInputStream networkStream = getResponseBodyStream(connection);
+ InputStream inputStream = getDecodedResponseBodyStream(networkStream, shouldDecodeGzip)) {
+ long networkReceivedBytes = 0;
+ // Allocate a buffer for reading the response body data into memory and passing it to JNI.
+ int bufferSize = responseBodyChunkSizeBytes;
+ byte[] buffer = new byte[bufferSize];
+ // This outer loop runs until we reach the end of the response body stream (or hit an
+ // error).
+ int actualBytesRead = -1;
+ do {
+ int cursor = 0;
+ // Read data from the network stream (or from the decompressing input stream wrapping the
+ // network stream), filling up the buffer that we will pass to the native layer. It's likely
+ // that each read returns less data than we request. Hence, this inner loop runs until our
+ // buffer is full, the end of the data is reached, or we hit an error.
+ while (cursor < buffer.length) {
+ actualBytesRead = inputStream.read(buffer, cursor, buffer.length - cursor);
+
+ // Update the number of received bytes (at the network level, as best as we can measure).
+ // We must do this before we break out of the loop.
+ //
+ // Note that for some implementations like Cronet's, this would count uncompressed bytes
+ // even if the original response was compressed using a Content-Encoding. Hence, this
+ // would be an over-estimate of actual network data usage. We will, however, try to
+ // provide a more accurate value once the request is completed successfully, if a
+ // Content-Length response header was available. See doOnResponseCompleted.
+ long newNetworkReceivedBytes = networkStream.getCount();
+ receivedBodyBytes += (newNetworkReceivedBytes - networkReceivedBytes);
+ networkReceivedBytes = newNetworkReceivedBytes;
+
+ if (actualBytesRead == -1) {
+ // End of data reached (successfully). Break out of inner loop.
+ break;
+ }
+ // Some data was read.
+ cursor += actualBytesRead;
+ }
+ // If our buffer is still empty, then we must've hit the end of the data right away. No need
+ // to call back into the native layer anymore.
+ if (cursor == 0) {
+ break;
+ }
+ // If our buffer now has some data in it, we must pass it to the native layer via the JNI
+ // callback. This may be very fast, but may also block on disk IO and/or on-the-fly
+ // payload decompression.
+ doOnResponseBody(buffer, cursor);
+
+ // Before trying to read another chunk of data, make sure that the request hasn't been
+ // aborted yet.
+ checkClosed();
+ } while (actualBytesRead != -1);
+ }
+ // We're done downloading. The InputStream will be closed, letting the network layer reclaim
+ // the socket and possibly return it to a connection pool for later reuse (as long as we don't
+ // call #disconnect() on it, which would prevent the socket from being reused).
+ }
+
+ /** Returns the {@link java.io.InputStream} that will return the response body data. */
+ private static CountingInputStream getResponseBodyStream(HttpURLConnection connection)
+ throws IOException {
+ // If the response was an error, then we need to call getErrorStream() to get the response
+ // body. Otherwise we need to use getInputStream().
+ //
+ // Note that we don't wrap the InputStream in a BufferedInputStream, since we already read data
+ // from the unbuffered InputStream in large chunks at a time, so adding another buffering layer
+ // in between isn't helpful.
+ InputStream errorStream = connection.getErrorStream();
+ if (errorStream == null) {
+ return new CountingInputStream(connection.getInputStream());
+ }
+ return new CountingInputStream(errorStream);
+ }
+
+ /**
+ * Returns an {@link java.io.InputStream} that, if we should automatically decode/decompress the
+ * response body, will do so.
+ *
+ * <p>Note that if we should not automatically decode the response body, then this will simply
+ * return {@code inputStream}.
+ */
+ private InputStream getDecodedResponseBodyStream(
+ InputStream inputStream, boolean shouldDecodeGzip) throws IOException {
+ if (shouldDecodeGzip) {
+ // Note that GZIPInputStream's default internal buffer size is quite small (512 bytes). We
+ // therefore specify a buffer size explicitly, to ensure that we read in large enough chunks
+ // from the network stream (which in turn can improve overall throughput).
+ return new GZIPInputStream(inputStream, responseBodyGzipBufferSizeBytes);
+ }
+ return inputStream;
+ }
+}
diff --git a/fcp/java_src/test/java/com/google/fcp/client/http/BUILD b/fcp/java_src/test/java/com/google/fcp/client/http/BUILD
new file mode 100644
index 0000000..cb57f4d
--- /dev/null
+++ b/fcp/java_src/test/java/com/google/fcp/client/http/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+java_test(
+ name = "HttpClientForNativeImplTest",
+ size = "small",
+ srcs = ["HttpClientForNativeImplTest.java"],
+ deps = [
+ "//fcp/client/http/java:jni_java_proto",
+ "//fcp/java_src/main/java/com/google/fcp/client",
+ "//fcp/java_src/main/java/com/google/fcp/client/http",
+ "@com_google_googleapis//google/rpc:rpc_java_proto",
+ "@com_google_protobuf//:protobuf_java",
+ "@fcp_maven//:com_google_guava_guava",
+ "@fcp_maven//:com_google_truth_truth",
+ "@fcp_maven//:junit_junit",
+ "@fcp_maven//:org_mockito_mockito_core",
+ ],
+)
diff --git a/fcp/java_src/test/java/com/google/fcp/client/http/HttpClientForNativeImplTest.java b/fcp/java_src/test/java/com/google/fcp/client/http/HttpClientForNativeImplTest.java
new file mode 100644
index 0000000..f7b45b9
--- /dev/null
+++ b/fcp/java_src/test/java/com/google/fcp/client/http/HttpClientForNativeImplTest.java
@@ -0,0 +1,1723 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.fcp.client.http;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.truth.Truth.assertWithMessage;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.fcp.client.CallFromNativeWrapper;
+import com.google.fcp.client.CallFromNativeWrapper.CallFromNativeRuntimeException;
+import com.google.fcp.client.http.HttpRequestHandleImpl.HttpURLConnectionFactory;
+import com.google.protobuf.ExtensionRegistryLite;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.rpc.Code;
+import com.google.rpc.Status;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.net.SocketTimeoutException;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.zip.GZIPOutputStream;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.InOrder;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+/**
+ * Unit tests for {@link HttpClientForNativeImpl}.
+ *
+ * <p>This test doesn't actually call into any native code/JNI (instead the JNI callback methods are
+ * faked out and replaced with Java-only code, which makes the code a lot easier to unit test). This
+ * test also does <strong>not</strong> exercise all of the concurrency-related edge cases that could
+ * arise (since these are difficult to conclusively test in general).
+ */
+@RunWith(JUnit4.class)
+public final class HttpClientForNativeImplTest {
+
+ private static final int DEFAULT_TEST_CHUNK_BUFFER_SIZE = 5;
+ private static final double ESTIMATED_HTTP2_HEADER_COMPRESSION_RATIO = 0.5;
+ // We use an executor with real background threads, just to exercise a bit more of the code and
+ // possibly spot any concurrency issues. The use of background threads is conveniently hidden
+ // behind the performRequests interface anyway.
+ private static final ExecutorService TEST_EXECUTOR_SERVICE = Executors.newFixedThreadPool(2);
+ // Do nothing in the UncaughtExceptionHandler, letting the exception bubble up instead.
+ private static final CallFromNativeWrapper TEST_CALL_FROM_NATIVE_WRAPPER =
+ new CallFromNativeWrapper((t, e) -> {});
+
+ /**
+ * A fake {@link HttpRequestHandleImpl} implementation which never actually calls into the native
+ * layer over JNI, and instead uses a fake pure Java implementation that emulates how the native
+ * layer would behave. This makes unit testing the Java layer possible.
+ */
+ static class TestHttpRequestHandleImpl extends HttpRequestHandleImpl {
+ TestHttpRequestHandleImpl(
+ JniHttpRequest request,
+ HttpURLConnectionFactory urlConnectionFactory,
+ boolean supportAcceptEncodingHeader,
+ boolean disableTimeouts) {
+ super(
+ request,
+ TEST_CALL_FROM_NATIVE_WRAPPER,
+ TEST_EXECUTOR_SERVICE,
+ urlConnectionFactory,
+ /*connectTimeoutMs=*/ disableTimeouts ? -1 : 123,
+ /*readTimeoutMs=*/ disableTimeouts ? -1 : 456,
+ // Force the implementation to read 5 bytes at a time, to exercise the chunking logic.
+ /*requestBodyChunkSizeBytes=*/ DEFAULT_TEST_CHUNK_BUFFER_SIZE,
+ /*responseBodyChunkSizeBytes=*/ DEFAULT_TEST_CHUNK_BUFFER_SIZE,
+ /*responseBodyGzipBufferSizeBytes=*/ DEFAULT_TEST_CHUNK_BUFFER_SIZE,
+ /*callDisconnectWhenCancelled=*/ true,
+ /*supportAcceptEncodingHeader=*/ supportAcceptEncodingHeader,
+ /*estimatedHttp2HeaderCompressionRatio=*/ ESTIMATED_HTTP2_HEADER_COMPRESSION_RATIO);
+ }
+
+ // There should be no need for us to synchronize around these mutable fields, since the
+ // implementation itself should already implement the necessary synchronization to ensure that
+ // only one JNI callback method is called a time.
+ ByteArrayInputStream fakeRequestBody = null;
+ boolean readRequestBodyResult = true;
+ boolean onResponseStartedResult = true;
+ boolean onResponseBodyResult = true;
+
+ JniHttpResponse responseProto = null;
+ Status responseError = null;
+ Status responseBodyError = null;
+ ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
+ boolean completedSuccessfully = false;
+
+ @Override
+ protected boolean readRequestBody(byte[] buffer, long requestedBytes, int[] actualBytesRead) {
+ if (!readRequestBodyResult) {
+ return false;
+ }
+ int cursor;
+ // Always return up to two bytes only. That way we ensure the implementation properly handles
+ // the case when it gets less than data back than requested.
+ for (cursor = 0; cursor < Long.min(2, requestedBytes); cursor++) {
+ int newByte = fakeRequestBody.read();
+ if (newByte == -1) {
+ break;
+ }
+ buffer[cursor] = (byte) newByte;
+ }
+ actualBytesRead[0] = cursor == 0 ? -1 : cursor;
+ return true;
+ }
+
+ @Override
+ protected boolean onResponseStarted(byte[] responseProto) {
+ if (!onResponseStartedResult) {
+ return false;
+ }
+ try {
+ this.responseProto =
+ JniHttpResponse.parseFrom(responseProto, ExtensionRegistryLite.getEmptyRegistry());
+ } catch (InvalidProtocolBufferException e) {
+ throw new AssertionError("invalid responseProto", e);
+ }
+ return true;
+ }
+
+ @Override
+ protected void onResponseError(byte[] statusProto) {
+ try {
+ responseError = Status.parseFrom(statusProto, ExtensionRegistryLite.getEmptyRegistry());
+ } catch (InvalidProtocolBufferException e) {
+ throw new AssertionError("invalid statusProto", e);
+ }
+ }
+
+ @Override
+ protected boolean onResponseBody(byte[] data, int bytesAvailable) {
+ if (!onResponseBodyResult) {
+ return false;
+ }
+ responseBody.write(data, 0, bytesAvailable);
+ return true;
+ }
+
+ @Override
+ protected void onResponseBodyError(byte[] statusProto) {
+ try {
+ responseBodyError = Status.parseFrom(statusProto, ExtensionRegistryLite.getEmptyRegistry());
+ } catch (InvalidProtocolBufferException e) {
+ throw new AssertionError("invalid statusProto", e);
+ }
+ }
+
+ @Override
+ protected void onResponseCompleted() {
+ completedSuccessfully = true;
+ }
+
+ /**
+ * Checks that the request succeeded, based on which native callback methods were/were not
+ * invoked.
+ */
+ void assertSuccessfulCompletion() {
+ assertWithMessage("onResponseError was called").that(responseError).isNull();
+ assertWithMessage("onResponseBodyError was called").that(responseBodyError).isNull();
+ assertWithMessage("onResponseStarted was not called").that(responseProto).isNotNull();
+ assertWithMessage("onResponseCompleted was not called").that(completedSuccessfully).isTrue();
+ }
+ }
+
+ @Rule public final MockitoRule mockito = MockitoJUnit.rule();
+
+ @Mock HttpURLConnectionFactory urlConnectionFactory;
+
+ HttpClientForNativeImpl httpClient;
+
+ @Before
+ public void setUp() throws Exception {
+ httpClient =
+ new HttpClientForNativeImpl(
+ TEST_CALL_FROM_NATIVE_WRAPPER,
+ (request) ->
+ new TestHttpRequestHandleImpl(
+ request,
+ urlConnectionFactory,
+ /*supportAcceptEncodingHeader=*/ true,
+ /*disableTimeouts=*/ false));
+ }
+
+ @Test
+ public void testSingleRequestWithoutRequestBodySucceeds() throws Exception {
+ doTestSingleRequestWithoutRequestBodySucceeds(
+ /*supportAcceptEncodingHeader=*/ true, /*expectTimeoutsToBeSet=*/ true);
+ }
+
+ @Test
+ public void testSingleRequestWithoutRequestBodyAndDisableAcceptEncodingHeaderSupportSucceeds()
+ throws Exception {
+ httpClient =
+ new HttpClientForNativeImpl(
+ TEST_CALL_FROM_NATIVE_WRAPPER,
+ (request) ->
+ new TestHttpRequestHandleImpl(
+ request,
+ urlConnectionFactory,
+ /*supportAcceptEncodingHeader=*/ false,
+ /*disableTimeouts=*/ false));
+ doTestSingleRequestWithoutRequestBodySucceeds(
+ /*supportAcceptEncodingHeader=*/ false, /*expectTimeoutsToBeSet=*/ true);
+ }
+
+ @Test
+ public void testSingleRequestWithoutRequestBodyAndDisableTimeoutsSucceeds() throws Exception {
+ httpClient =
+ new HttpClientForNativeImpl(
+ TEST_CALL_FROM_NATIVE_WRAPPER,
+ (request) ->
+ new TestHttpRequestHandleImpl(
+ request,
+ urlConnectionFactory,
+ /*supportAcceptEncodingHeader=*/ false,
+ /*disableTimeouts=*/ true));
+ doTestSingleRequestWithoutRequestBodySucceeds(
+ /*supportAcceptEncodingHeader=*/ false, /*expectTimeoutsToBeSet=*/ false);
+ }
+
+ private void doTestSingleRequestWithoutRequestBodySucceeds(
+ boolean supportAcceptEncodingHeader, boolean expectTimeoutsToBeSet) throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header2")
+ .setValue("Bar")
+ .build())
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ // Create a fake set of response headers. We use a LinkedHashMap rather than the less verbose
+ // ImmutableMap.of utility to allow us to add an entry for the HTTP status line, which in
+ // HttpURLConnection has a null key, which ImmutableMap disallows (to check that it gets
+ // properly handled/filtered out before passing on to JNI). Because the map still has a defined
+ // iteration order, we can still easily compare the whole response proto in one go (since we
+ // know the order the header fields will be in).
+ LinkedHashMap<String, List<String>> headerFields = new LinkedHashMap<>();
+ headerFields.put("Response-Header1", ImmutableList.of("Bar", "Baz"));
+ headerFields.put("Response-Header2", ImmutableList.of("Barbaz"));
+ // And add a Content-Length and 'null' header (to check whether they are correctly redacted &
+ // ignored.
+ headerFields.put("Content-Length", ImmutableList.of("9999")); // Should be ignored.
+ headerFields.put(null, ImmutableList.of("200 OK")); // Should be ignored.
+ when(mockConnection.getHeaderFields()).thenReturn(headerFields);
+
+ // Fake some response body data.
+ String expectedResponseBody = "test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results..
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ // The Content-Length and 'null' headers should have been redacted.
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Baz").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Response-Header2")
+ .setValue("Barbaz")
+ .build())
+ .build());
+
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify various important request properties.
+ verify(mockConnection).setRequestMethod("GET");
+ InOrder requestHeadersOrder = inOrder(mockConnection);
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header1", "Foo");
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header2", "Bar");
+ verify(mockConnection, supportAcceptEncodingHeader ? times(1) : never())
+ .setRequestProperty("Accept-Encoding", "gzip");
+ verify(mockConnection, expectTimeoutsToBeSet ? times(1) : never()).setConnectTimeout(123);
+ verify(mockConnection, expectTimeoutsToBeSet ? times(1) : never()).setReadTimeout(456);
+ verify(mockConnection, never()).setDoOutput(anyBoolean());
+ verify(mockConnection, never()).getOutputStream();
+ verify(mockConnection).setDoInput(true);
+ verify(mockConnection).setUseCaches(false);
+ verify(mockConnection).setInstanceFollowRedirects(true);
+ }
+
+ @Test
+ public void testSingleRequestWithRequestBodySucceeds() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_POST)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foobar")
+ .build())
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header2")
+ .setValue("Bar")
+ .build())
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ // Gather request body data sent to the HttpURLConnection in a stream we can inspect later on.
+ ByteArrayOutputStream actualRequestBody = new ByteArrayOutputStream();
+ when(mockConnection.getOutputStream()).thenReturn(actualRequestBody);
+
+ // Fake some request body data.
+ String expectedRequestBody = "test_request_body";
+ requestHandle.fakeRequestBody = new ByteArrayInputStream(expectedRequestBody.getBytes(UTF_8));
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ LinkedHashMap<String, List<String>> headerFields = new LinkedHashMap<>();
+ headerFields.put("Response-Header1", ImmutableList.of("Bar", "Baz"));
+ headerFields.put("Response-Header2", ImmutableList.of("Barbaz"));
+ headerFields.put(null, ImmutableList.of("HTTP/1.1 200 OK")); // Should be ignored.
+ when(mockConnection.getHeaderFields()).thenReturn(headerFields);
+
+ // Add the response message ("OK"), so that it gets included in the received bytes stats.
+ when(mockConnection.getResponseMessage()).thenReturn("OK");
+
+ // Fake some response body data.
+ String expectedResponseBody = "test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results.
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Baz").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Response-Header2")
+ .setValue("Barbaz")
+ .build())
+ .build());
+
+ assertThat(actualRequestBody.toString(UTF_8.name())).isEqualTo(expectedRequestBody);
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify the network stats are accurate (they should count the request headers, URL, request
+ // method, request body, response headers and response body).
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes(
+ ("POST https://foo.com HTTP/1.1\r\n"
+ + "Request-Header1: Foo\r\n"
+ + "Request-Header1: Foobar\r\n"
+ + "Request-Header2: Bar\r\n"
+ + "\r\n")
+ .length()
+ + expectedRequestBody.length())
+ .setReceivedBytes(
+ ("HTTP/1.1 200 OK\r\n"
+ + "Response-Header1: Bar\r\n"
+ + "Response-Header1: Baz\r\n"
+ + "Response-Header2: Barbaz\r\n"
+ + "\r\n")
+ .length()
+ + requestHandle.responseBody.size())
+ .build());
+
+ // Verify various important request properties.
+ verify(mockConnection).setRequestMethod("POST");
+ InOrder requestHeadersOrder = inOrder(mockConnection);
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header1", "Foo");
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header1", "Foobar");
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header2", "Bar");
+ verify(mockConnection).setConnectTimeout(123);
+ verify(mockConnection).setReadTimeout(456);
+ verify(mockConnection).setDoOutput(true);
+ // Since the request body content length wasn't known ahead of time, the
+ // 'Transfer-Encoding: chunked' streaming mode should've been enabled.
+ verify(mockConnection).setChunkedStreamingMode(5);
+ verify(mockConnection, never()).setFixedLengthStreamingMode(anyInt());
+ verify(mockConnection).setDoInput(true);
+ verify(mockConnection).setUseCaches(false);
+ verify(mockConnection).setInstanceFollowRedirects(true);
+ }
+
+ /**
+ * Tests whether a single request with a <strong>known-ahead-of-time</strong> request body content
+ * length is processed correctly.
+ */
+ @Test
+ public void testSingleRequestWithKnownRequestContentLengthSucceeds() throws Exception {
+ String expectedRequestBody = "another_test_request_body";
+ String requestBodyLength = "25"; // the length of the above string.
+ long requestBodyLengthLong = 25L;
+ assertThat(expectedRequestBody).hasLength((int) requestBodyLengthLong);
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_PUT)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .addExtraHeaders(
+ // Add a Content-Length request header, which should result in 'fixed
+ // length'
+ // request body streaming mode.
+ JniHttpHeader.newBuilder()
+ // We purposely use a mixed-case header name to ensure header matching
+ // is
+ // case insensitive.
+ .setName("Content-length")
+ .setValue(requestBodyLength))
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ // Gather request body data sent to the HttpURLConnection in a stream we can inspect later on.
+ ByteArrayOutputStream actualRequestBody = new ByteArrayOutputStream();
+ when(mockConnection.getOutputStream()).thenReturn(actualRequestBody);
+
+ // Fake some request body data.
+ requestHandle.fakeRequestBody = new ByteArrayInputStream(expectedRequestBody.getBytes(UTF_8));
+
+ int expectedResponseCode = 201;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+
+ // Fake some response body data.
+ String expectedResponseBody = "another_test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results..
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(JniHttpResponse.newBuilder().setCode(expectedResponseCode).build());
+
+ assertThat(actualRequestBody.toString(UTF_8.name())).isEqualTo(expectedRequestBody);
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ verify(mockConnection).setRequestMethod("PUT");
+ verify(mockConnection).setDoOutput(true);
+ InOrder requestHeadersOrder = inOrder(mockConnection);
+ requestHeadersOrder.verify(mockConnection).addRequestProperty("Request-Header1", "Foo");
+ requestHeadersOrder
+ .verify(mockConnection)
+ .addRequestProperty("Content-length", requestBodyLength);
+ // Since the request body content length *was* known ahead of time, the fixed length streaming
+ // mode should have been enabled.
+ verify(mockConnection).setFixedLengthStreamingMode(requestBodyLengthLong);
+ verify(mockConnection, never()).setChunkedStreamingMode(anyInt());
+ }
+
+ /**
+ * Tests whether a single request with a request body that is smaller than our read buffer size is
+ * processed correctly.
+ */
+ @Test
+ public void testSingleRequestWithKnownRequestContentLengthThatFitsInSingleBufferSucceeds()
+ throws Exception {
+ String expectedRequestBody = "1234";
+ String requestBodyLength =
+ "4"; // the length of the above string, which is smaller than the buffer size.
+ long requestBodyLengthLong = 4L;
+ assertThat(expectedRequestBody).hasLength((int) requestBodyLengthLong);
+ assertThat(requestBodyLengthLong).isLessThan(DEFAULT_TEST_CHUNK_BUFFER_SIZE);
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_PUT)
+ .addExtraHeaders(
+ // Add a Content-Length request header, which should result in 'fixed
+ // length'
+ // request body streaming mode.
+ JniHttpHeader.newBuilder()
+ // We purposely use a mixed-case header name to ensure header matching
+ // is
+ // case insensitive.
+ .setName("content-Length")
+ .setValue(requestBodyLength))
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ // Gather request body data sent to the HttpURLConnection in a stream we can inspect later on.
+ ByteArrayOutputStream actualRequestBody = new ByteArrayOutputStream();
+ when(mockConnection.getOutputStream()).thenReturn(actualRequestBody);
+
+ // Fake some request body data.
+ requestHandle.fakeRequestBody = new ByteArrayInputStream(expectedRequestBody.getBytes(UTF_8));
+
+ int expectedResponseCode = 503;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+
+ // Fake some response body data (via the error stream this time).
+ String expectedResponseBody = "abc";
+ when(mockConnection.getErrorStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results..
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(JniHttpResponse.newBuilder().setCode(expectedResponseCode).build());
+
+ assertThat(actualRequestBody.toString(UTF_8.name())).isEqualTo(expectedRequestBody);
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ verify(mockConnection).setRequestMethod("PUT");
+ verify(mockConnection).setDoOutput(true);
+ verify(mockConnection).addRequestProperty("content-Length", requestBodyLength);
+ // Since the request body content length *was* known ahead of time, the fixed length streaming
+ // mode should have been enabled.
+ verify(mockConnection).setFixedLengthStreamingMode(requestBodyLengthLong);
+ verify(mockConnection, never()).setChunkedStreamingMode(anyInt());
+ }
+
+ /** Tests whether issuing multiple concurrent requests is handled correctly. */
+ @Test
+ public void testMultipleRequestsWithRequestBodiesSucceeds() throws Exception {
+ TestHttpRequestHandleImpl requestHandle1 =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_POST)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header2")
+ .setValue("Bar")
+ .build())
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ TestHttpRequestHandleImpl requestHandle2 =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo2.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_PATCH)
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection1 = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection1);
+ HttpURLConnection mockConnection2 = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo2.com")).thenReturn(mockConnection2);
+
+ // Gather request body data sent to the HttpURLConnection in a stream we can inspect later on.
+ ByteArrayOutputStream actualRequestBody1 = new ByteArrayOutputStream();
+ when(mockConnection1.getOutputStream()).thenReturn(actualRequestBody1);
+ ByteArrayOutputStream actualRequestBody2 = new ByteArrayOutputStream();
+ when(mockConnection2.getOutputStream()).thenReturn(actualRequestBody2);
+
+ // Fake some request body data.
+ String expectedRequestBody1 = "test_request_body1";
+ requestHandle1.fakeRequestBody = new ByteArrayInputStream(expectedRequestBody1.getBytes(UTF_8));
+ String expectedRequestBody2 = "another_request_body2";
+ requestHandle2.fakeRequestBody = new ByteArrayInputStream(expectedRequestBody2.getBytes(UTF_8));
+
+ int expectedResponseCode1 = 200;
+ int expectedResponseCode2 = 300;
+ when(mockConnection1.getResponseCode()).thenReturn(expectedResponseCode1);
+ when(mockConnection2.getResponseCode()).thenReturn(expectedResponseCode2);
+ when(mockConnection1.getHeaderFields()).thenReturn(ImmutableMap.of());
+ when(mockConnection2.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1",
+ ImmutableList.of("Bar"),
+ "Response-Header2",
+ ImmutableList.of("Barbaz")));
+
+ // Fake some response body data.
+ String expectedResponseBody1 = "test_response_body";
+ when(mockConnection1.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody1.getBytes(UTF_8)));
+
+ String expectedResponseBody2 = "test_response_body";
+ when(mockConnection2.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody2.getBytes(UTF_8)));
+
+ // Run both requests (we provide them in opposite order to how they were created, just to try
+ // to exercise more edge conditions).
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle2, requestHandle1});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle1.close();
+ requestHandle2.close();
+
+ // Verify the results..
+ requestHandle1.assertSuccessfulCompletion();
+ assertThat(requestHandle1.responseProto)
+ .isEqualTo(JniHttpResponse.newBuilder().setCode(expectedResponseCode1).build());
+
+ requestHandle2.assertSuccessfulCompletion();
+ assertThat(requestHandle2.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode2)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Response-Header2")
+ .setValue("Barbaz")
+ .build())
+ .build());
+
+ assertThat(actualRequestBody1.toString(UTF_8.name())).isEqualTo(expectedRequestBody1);
+ assertThat(requestHandle1.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody1);
+ assertThat(actualRequestBody2.toString(UTF_8.name())).isEqualTo(expectedRequestBody2);
+ assertThat(requestHandle2.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody2);
+
+ // Verify various important request properties.
+ verify(mockConnection1).setRequestMethod("POST");
+ verify(mockConnection2).setRequestMethod("PATCH");
+ InOrder requestHeadersOrder = inOrder(mockConnection1);
+ requestHeadersOrder.verify(mockConnection1).addRequestProperty("Request-Header1", "Foo");
+ requestHeadersOrder.verify(mockConnection1).addRequestProperty("Request-Header2", "Bar");
+ verify(mockConnection2, never()).addRequestProperty(any(), any());
+ }
+
+ @Test
+ public void testGzipResponseBodyDecompressionSucceeds() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ when(mockConnection.getResponseMessage()).thenReturn("OK");
+
+ // Fake some response body data.
+ String expectedResponseBody = "test_response_body";
+ ByteArrayOutputStream compressedResponseBody = new ByteArrayOutputStream();
+ GZIPOutputStream compressedResponseBodyGzipStream =
+ new GZIPOutputStream(compressedResponseBody);
+ compressedResponseBodyGzipStream.write(expectedResponseBody.getBytes(UTF_8));
+ compressedResponseBodyGzipStream.finish();
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(compressedResponseBody.toByteArray()));
+ // And add Content-Encoding and Transfer-Encoding headers (to check whether they are correctly
+ // redacted).
+ when(mockConnection.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1",
+ ImmutableList.of("Bar"),
+ "Content-Encoding",
+ ImmutableList.of("gzip"),
+ "Transfer-Encoding",
+ ImmutableList.of("chunked")));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results..
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ // The Content-Encoding and Transfer-Encoding headers should have been redacted.
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify the network stats are accurate (they should count the request headers, URL, request
+ // method, request body, response headers and *compressed* response body, since decompression
+ // was performed by us and hence we were able to observe and count the compressed bytes).
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes(
+ ("GET https://foo.com HTTP/1.1\r\n" + "Request-Header1: Foo\r\n" + "\r\n")
+ .length())
+ .setReceivedBytes(
+ ("HTTP/1.1 200 OK\r\n"
+ + "Response-Header1: Bar\r\n"
+ + "Content-Encoding: gzip\r\n"
+ + "Transfer-Encoding: chunked\r\n"
+ + "\r\n")
+ .length()
+ + compressedResponseBody.size())
+ .build());
+
+ // Verify various important request properties.
+ verify(mockConnection).setRequestMethod("GET");
+ verify(mockConnection).addRequestProperty("Request-Header1", "Foo");
+ verify(mockConnection).setRequestProperty("Accept-Encoding", "gzip");
+ }
+
+ @Test
+ public void testGzipResponseBodyWithAcceptEncodingRequestHeaderShouldNotAutoDecompress()
+ throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder()
+ .setName("Request-Header1")
+ .setValue("Foo")
+ .build())
+ .addExtraHeaders(
+ // We purposely use mixed-case, to ensure case-insensitive matching is used.
+ JniHttpHeader.newBuilder()
+ .setName("Accept-encoding")
+ .setValue("gzip,foobar")
+ .build())
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+
+ // Fake some response body data.
+ String expectedResponseBody = "i_should_not_be_decompressed";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+ // And add Content-Encoding and Content-Length headers (to check whether the first header is
+ // correctly left *un*redacted, and the second is still redacted).
+ when(mockConnection.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1",
+ ImmutableList.of("Bar"),
+ "Content-Encoding",
+ ImmutableList.of("gzip"),
+ "Content-Length",
+ ImmutableList.of("9999")));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results..
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ // The Content-Length header should have been redacted.
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Content-Encoding").setValue("gzip").build())
+ .build());
+
+ // The response body should have been returned without trying to decompress it.
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify various important request properties.
+ verify(mockConnection).setRequestMethod("GET");
+ verify(mockConnection).addRequestProperty("Request-Header1", "Foo");
+ // The Accept-Encoding header provided by the native layer should have been used, verbatim.
+ verify(mockConnection).addRequestProperty("Accept-encoding", "gzip,foobar");
+ verify(mockConnection, never()).setRequestProperty(eq("Accept-Encoding"), any());
+ }
+
+ @Test
+ public void testChunkedTransferEncodingResponseHeaderShouldBeRemoved() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+
+ // Fake some response body data.
+ String expectedResponseBody = "another_test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // And make the response headers include a "Transfer-Encoding: chunked" header, simulating the
+ // case when HttpClientForNativeImpl is used with the JDK, which will un-chunk response data but
+ // which will not remove the Transfer-Encoding header afterwards (contrary to Android's
+ // HttpURLConnection implementation which *does* remove the header in this case).
+ when(mockConnection.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1",
+ ImmutableList.of("Bar"),
+ "Transfer-Encoding",
+ ImmutableList.of("chunked")));
+ // Make the response body length *not* be known ahead of time (in accordance with the "chunked"
+ // transfer encoding having been used.
+ when(mockConnection.getContentLength()).thenReturn(-1);
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results.
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ // The Transfer-Encoding header should have been redacted.
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+ }
+
+ @Test
+ public void testContentLengthResponseHeaderShouldDetermineReceivedBytesEstimate()
+ throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ when(mockConnection.getResponseMessage()).thenReturn("OK");
+
+ // Fake some response body data.
+ String expectedResponseBody = "another_test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // And make the response headers include a "Content-Length" header. The header should be ignored
+ // for the most part, *but* it should be used to produce the final estimated 'received bytes'
+ // statistic, if the request completes successfully.
+ int expectedContentLength = 5;
+ when(mockConnection.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1",
+ ImmutableList.of("Bar"),
+ // Simulate a Content-Length header that has value that is smaller than the length
+ // of the response body we actually observe (e.g. a Cronet-based implementation has
+ // decompressed the content for us, but still told us the original length).
+ "Content-Length",
+ ImmutableList.of(Integer.toString(expectedContentLength))));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results.
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ // The Content-Length header should have been redacted.
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify the network stats are accurate (they should count the request headers, URL, request
+ // method, request body, response headers and the *content length* rather than the observed
+ // response body).
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes("GET https://foo.com HTTP/1.1\r\n\r\n".length())
+ .setReceivedBytes(
+ ("HTTP/1.1 200 OK\r\n"
+ + "Response-Header1: Bar\r\n"
+ + "Content-Length: 5\r\n"
+ + "\r\n")
+ .length()
+ + expectedContentLength)
+ .build());
+ }
+
+ @Test
+ public void testHttp2RequestsShouldUseEstimatedHeaderCompressionRatio() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection("https://foo.com")).thenReturn(mockConnection);
+
+ int expectedResponseCode = 200;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ // Return an empty response message, which is the heuristic that indicates HTTP/2 was likely
+ // used to service the request.
+ when(mockConnection.getResponseMessage()).thenReturn("");
+
+ // Fake some response body data.
+ String expectedResponseBody = "another_test_response_body";
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream(expectedResponseBody.getBytes(UTF_8)));
+
+ // And make the response headers include a "Content-Length" header. The header should be ignored
+ // for the most part, *but* it should be used to produce the final estimated 'received bytes'
+ // statistic, if the request completes successfully.
+ when(mockConnection.getHeaderFields())
+ .thenReturn(ImmutableMap.of("Response-Header1", ImmutableList.of("Bar")));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Verify the results.
+ requestHandle.assertSuccessfulCompletion();
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(expectedResponseCode)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isEqualTo(expectedResponseBody);
+
+ // Verify the network stats are accurate (they should count the request headers, URL, request
+ // method, request body, response headers and the response body). Since HTTP/2 was used
+ // (according to the heuristic), the request/response headers should have a compression factor
+ // applied to them.
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ // Even though HTTP/2 was used, our sent/received bytes estimates hardcode an
+ // assumption that HTTP/1.1-style status lines and CRLF-terminated headers were sent
+ // received (and then simply applies a compression factor over the length of those
+ // strings).
+ .setSentBytes(
+ (long)
+ ("GET https://foo.com HTTP/1.1\r\n\r\n".length()
+ * ESTIMATED_HTTP2_HEADER_COMPRESSION_RATIO))
+ .setReceivedBytes(
+ (long)
+ (("HTTP/1.1 200 \r\n" + "Response-Header1: Bar\r\n" + "\r\n").length()
+ * ESTIMATED_HTTP2_HEADER_COMPRESSION_RATIO)
+ + expectedResponseBody.length())
+ .build());
+ }
+
+ @Test
+ public void testPerformOnClosedRequestShouldThrow() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .build()
+ .toByteArray());
+
+ // Close the request before we issue it.
+ requestHandle.close();
+
+ // Since performRequests wasn't called yet, no callbacks should've been invoked as a result of
+ // the call to close().
+ assertThat(requestHandle.responseError).isNull();
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+
+ // Try to perform the request, it should fail.
+ CallFromNativeRuntimeException thrown =
+ assertThrows(
+ CallFromNativeRuntimeException.class,
+ () -> httpClient.performRequests(new Object[] {requestHandle}));
+ assertThat(thrown).hasCauseThat().isInstanceOf(IllegalStateException.class);
+ }
+
+ @Test
+ public void testRequestWithAcceptEncodingHeaderIfNotSupportedShouldResultInError()
+ throws Exception {
+ // Disable support for the Accept-Encoding header.
+ httpClient =
+ new HttpClientForNativeImpl(
+ TEST_CALL_FROM_NATIVE_WRAPPER,
+ (request) ->
+ new TestHttpRequestHandleImpl(
+ request,
+ urlConnectionFactory,
+ /*supportAcceptEncodingHeader=*/ false,
+ /*disableTimeouts=*/ false));
+
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder().setName("Content-Length").setValue("1"))
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder().setName("Accept-Encoding").setValue("gzip"))
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.INVALID_ARGUMENT_VALUE);
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ @Test
+ public void testNoBodyButHasRequestContentLengthShouldResultInError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .addExtraHeaders(
+ JniHttpHeader.newBuilder().setName("Content-Length").setValue("1").build())
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.INVALID_ARGUMENT_VALUE);
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something about the network OutputStream throws an exception during request body upload,
+ * then we should return an error to native.
+ */
+ @Test
+ public void testSendRequestBodyExceptionShouldResultInResponseError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_POST)
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getOutputStream()).thenThrow(new IOException("my error"));
+
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.UNAVAILABLE_VALUE);
+ assertThat(requestHandle.responseError.getMessage()).contains("IOException");
+ assertThat(requestHandle.responseError.getMessage()).contains("my error");
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If the request got cancelled during request body upload, then we should return a CANCELLED
+ * error to native.
+ */
+ @Test
+ public void testCancellationDuringSendRequestBodyShouldResultInResponseError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_POST)
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getOutputStream())
+ .thenAnswer(
+ invocation -> {
+ // Trigger the request cancellationF
+ requestHandle.close();
+ return new ByteArrayOutputStream();
+ });
+
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.CANCELLED_VALUE);
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something fails when reading request body data from JNI, then we should *not* call any more
+ * JNI callbacks again, since the native layer will already have handled the error.
+ */
+ @Test
+ public void testReadRequestBodyFromNativeFailureShouldNotCallJNICallback() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_POST)
+ .setHasBody(true)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getOutputStream()).thenReturn(new ByteArrayOutputStream());
+ // Make the fake readRequestBody JNI method return an error.
+ requestHandle.readRequestBodyResult = false;
+
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ // No callbacks should have been invoked.
+ assertThat(requestHandle.responseError).isNull();
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /** If establishing the connections fails, then we should return an error to native. */
+ @Test
+ public void testConnectExceptionShouldResultInResponseError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ doThrow(new IOException("my error")).when(mockConnection).connect();
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.UNAVAILABLE_VALUE);
+ assertThat(requestHandle.responseError.getMessage()).contains("IOException");
+ assertThat(requestHandle.responseError.getMessage()).contains("my error");
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+
+ // Verify that the request headers are counted in the network stats, since we can't really know
+ // whether the connect() method failed before any network connection was established, or whether
+ // it failed after we did already send our request onto the wire (each HttpURLConnection
+ // implementation can have slightly different behavior in this regard).
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes("GET https://foo.com HTTP/1.1\r\n\r\n".length())
+ .setReceivedBytes(0)
+ .build());
+ }
+
+ /**
+ * If something about the network InputStream throws an exception during response headers
+ * receiving, then we should return an error to native.
+ */
+ @Test
+ public void testReceiveResponseHeadersExceptionShouldResultInResponseError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getResponseCode()).thenThrow(new IOException("my error"));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseError).isNotNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.UNAVAILABLE_VALUE);
+ assertThat(requestHandle.responseError.getMessage()).contains("IOException");
+ assertThat(requestHandle.responseError.getMessage()).contains("my error");
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something about the network InputStream throws a {@link java.net.SocketTimeoutException}
+ * during response headers receiving, then we should return a specific DEADLINE_EXCEEDED error to
+ * native.
+ */
+ @Test
+ public void testReceiveResponseHeadersTimeoutExceptionShouldResultInResponseError()
+ throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getResponseCode()).thenThrow(new SocketTimeoutException("my error"));
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED_VALUE);
+ assertThat(requestHandle.responseError.getMessage()).contains("SocketTimeoutException");
+ assertThat(requestHandle.responseError.getMessage()).contains("my error");
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If the request gets cancelled during response headers receiving, then we should return a
+ * CANCELLED error to native.
+ */
+ @Test
+ public void testCancellationDuringReceiveResponseHeadersShouldResultInResponseError()
+ throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ when(mockConnection.getResponseCode())
+ .thenAnswer(
+ invocation -> {
+ // Trigger a cancellation of the request.
+ requestHandle.close();
+ return 200;
+ });
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseError.getCode()).isEqualTo(Code.CANCELLED_VALUE);
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something fails when writing response header data to JNI, then we should *not* call any more
+ * JNI callbacks again, since the native layer will already have handled the error.
+ */
+ @Test
+ public void testWriteResponseHeadersToNativeFailureShouldNotCallJNICallback() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ int expectedResponseCode = 300;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ // Make the onResponseStarted JNI method fail when it receives the data.
+ requestHandle.onResponseStartedResult = false;
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // No callbacks should have been invoked.
+ assertThat(requestHandle.responseError).isNull();
+ assertThat(requestHandle.responseProto).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something about the network InputStream throws an exception during response body download,
+ * then we should return an error to native.
+ */
+ @Test
+ public void testReceiveResponseBodyExceptionShouldResultInResponseBodyError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ int expectedResponseCode = 300;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ when(mockConnection.getHeaderFields())
+ .thenReturn(ImmutableMap.of("Response-Header1", ImmutableList.of("Bar")));
+
+ // Make the response body input stream throw an exception.
+ when(mockConnection.getInputStream()).thenThrow(new IOException("my error"));
+ when(mockConnection.getContentLength()).thenReturn(-1);
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Despite having hit an IOException during the response body download, we should still first
+ // have passed the response headers to native.
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(300)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ assertThat(requestHandle.responseError).isNull();
+ assertThat(requestHandle.responseBodyError).isNotNull();
+ assertThat(requestHandle.responseBodyError.getCode()).isEqualTo(Code.UNAVAILABLE_VALUE);
+ assertThat(requestHandle.responseBodyError.getMessage()).contains("IOException");
+ assertThat(requestHandle.responseBodyError.getMessage()).contains("my error");
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ /**
+ * If something fails when writing response body data to JNI, then we should *not* call any more
+ * JNI callbacks again, since the native layer will already have handled the error.
+ */
+ @Test
+ public void testWriteResponseBodyToNativeFailureShouldNotCallJNICallback() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ int expectedResponseCode = 300;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ when(mockConnection.getHeaderFields())
+ .thenReturn(ImmutableMap.of("Response-Header1", ImmutableList.of("Bar")));
+
+ // Make the response body contain some data.
+ when(mockConnection.getInputStream())
+ .thenReturn(new ByteArrayInputStream("test_response".getBytes(UTF_8)));
+ when(mockConnection.getContentLength()).thenReturn(-1);
+ // But make the onResponseBody JNI method fail when it receives the data.
+ requestHandle.onResponseBodyResult = false;
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Despite having hit an IOException during the response body download, we should still first
+ // have passed the response headers to native.
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(300)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ // No callbacks should have been invoked after we received the response headers.
+ assertThat(requestHandle.responseError).isNull();
+ assertThat(requestHandle.responseBodyError).isNull();
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+ }
+
+ @Test
+ public void testCancellationDuringReceiveResponseBodyShouldResultInError() throws Exception {
+ TestHttpRequestHandleImpl requestHandle =
+ (TestHttpRequestHandleImpl)
+ httpClient.enqueueRequest(
+ JniHttpRequest.newBuilder()
+ .setUri("https://foo.com")
+ .setMethod(JniHttpMethod.HTTP_METHOD_GET)
+ .setHasBody(false)
+ .build()
+ .toByteArray());
+
+ HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+ when(urlConnectionFactory.createUrlConnection(any())).thenReturn(mockConnection);
+
+ int expectedResponseCode = 300;
+ when(mockConnection.getResponseCode()).thenReturn(expectedResponseCode);
+ when(mockConnection.getResponseMessage()).thenReturn("Multiple Choices");
+ when(mockConnection.getHeaderFields())
+ .thenReturn(
+ ImmutableMap.of(
+ "Response-Header1", ImmutableList.of("Bar"),
+ // The Content-Length header should be ignored, and should *not* be used to estimate
+ // the 'received bytes', since the request will not complete successfully.
+ "Content-Length", ImmutableList.of("9999")));
+
+ // Make the response body contain some data. But when the data gets read, the request gets
+ // cancelled.
+ String fakeResponseBody = "test_response";
+ when(mockConnection.getInputStream())
+ .thenAnswer(
+ invocation -> {
+ requestHandle.close();
+ return new ByteArrayInputStream(fakeResponseBody.getBytes(UTF_8));
+ });
+ when(mockConnection.getContentLength()).thenReturn(-1);
+
+ // Run the request.
+ byte[] result = httpClient.performRequests(new Object[] {requestHandle});
+ assertThat(Status.parseFrom(result, ExtensionRegistryLite.getEmptyRegistry()).getCode())
+ .isEqualTo(Code.OK_VALUE);
+
+ requestHandle.close();
+
+ // Despite having hit a cancellation during the response body download, we should still first
+ // have passed the response headers to native.
+ assertThat(requestHandle.responseProto)
+ .isEqualTo(
+ JniHttpResponse.newBuilder()
+ .setCode(300)
+ .addHeaders(
+ JniHttpHeader.newBuilder().setName("Response-Header1").setValue("Bar").build())
+ .build());
+ assertThat(requestHandle.responseError).isNull();
+ // The response body should not have been read to completion, since the request got cancelled
+ // in the middle of the read.
+ assertThat(requestHandle.responseBody.toString(UTF_8.name())).isNotEqualTo(fakeResponseBody);
+ assertThat(requestHandle.responseBodyError).isNotNull();
+ assertThat(requestHandle.responseBodyError.getCode()).isEqualTo(Code.CANCELLED_VALUE);
+ assertThat(requestHandle.completedSuccessfully).isFalse();
+
+ // Verify the network stats are accurate (they should count the request headers, URL, request
+ // method, request body, response headers and the *content length* rather than the observed
+ // response body).
+ assertThat(
+ JniHttpSentReceivedBytes.parseFrom(
+ requestHandle.getTotalSentReceivedBytes(),
+ ExtensionRegistryLite.getEmptyRegistry()))
+ .isEqualTo(
+ JniHttpSentReceivedBytes.newBuilder()
+ .setSentBytes("GET https://foo.com HTTP/1.1\r\n\r\n".length())
+ // The Content-Length response header value should not be taken into account in the
+ // estimated 'received bytes' stat, since the request did not succeed. Instead,
+ // by the time the HttpRequestHandleImpl#close() method is called we will be in the
+ // process of having read a single buffer's worth of response body data, and hence
+ // that's the amount of response body data that should be accounted for. This
+ // ensures that we try as best as possible to only count bytes we actually received
+ // up until the point of cancellation.
+ .setReceivedBytes(
+ ("HTTP/1.1 300 Multiple Choices\r\n"
+ + "Response-Header1: Bar\r\n"
+ + "Content-Length: 9999\r\n"
+ + "\r\n")
+ .length()
+ + DEFAULT_TEST_CHUNK_BUFFER_SIZE)
+ .build());
+ }
+}
diff --git a/fcp/jni/BUILD b/fcp/jni/BUILD
new file mode 100644
index 0000000..737b499
--- /dev/null
+++ b/fcp/jni/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+default_visibility = [
+ "//fcp:internal",
+]
+
+package(
+ default_visibility = default_visibility,
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "jni_util",
+ hdrs = ["jni_util.h"],
+ deps = [
+ "//fcp/base",
+ "@bazel_tools//tools/jdk:jni",
+ "@com_google_absl//absl/cleanup",
+ "@com_google_absl//absl/container:fixed_array",
+ ],
+)
diff --git a/fcp/jni/jni_util.h b/fcp/jni/jni_util.h
new file mode 100644
index 0000000..fada6f1
--- /dev/null
+++ b/fcp/jni/jni_util.h
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_JNI_JNI_UTIL_H_
+#define FCP_JNI_JNI_UTIL_H_
+
+#include <jni.h>
+
+#include "absl/cleanup/cleanup.h"
+#include "absl/container/fixed_array.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace jni {
+
+// Creates a JNIEnv via the passed JavaVM*, attaching the current thread if it
+// is not already. If an attach was needed, detaches when this is destroyed.
+//
+// ScopedJniEnv must not be shared among threads and destructs on the same
+// thread.
+class ScopedJniEnv final {
+ public:
+ explicit ScopedJniEnv(JavaVM* jvm)
+ : jvm_(jvm), env_(nullptr), is_attached_(false) {
+ // We don't make any assumptions about the state of the current thread, and
+ // we want to leave it in the state we received it with respect to the
+ // JavaVm. So we only attach and detach when needed, and we always delete
+ // local references.
+ jint error = jvm_->GetEnv(reinterpret_cast<void**>(&env_), JNI_VERSION_1_2);
+ if (error != JNI_OK) {
+ error = AttachCurrentThread(jvm_, &env_);
+ FCP_CHECK(error == JNI_OK);
+ is_attached_ = true;
+ }
+ }
+
+ virtual ~ScopedJniEnv() {
+ if (is_attached_) {
+ (void)jvm_->DetachCurrentThread();
+ }
+ }
+
+ JNIEnv* env() { return env_; }
+
+ private:
+ template <typename JNIEnvArgType>
+ static jint AttachCurrentThreadImpl(JavaVM* vm,
+ jint (JavaVM::*fn)(JNIEnvArgType, void*),
+ JNIEnv** env) {
+ static_assert(std::is_same_v<JNIEnvArgType, void**> ||
+ std::is_same_v<JNIEnvArgType, JNIEnv**>);
+ return (vm->*fn)(reinterpret_cast<JNIEnvArgType>(env), nullptr);
+ }
+
+ static jint AttachCurrentThread(JavaVM* vm, JNIEnv** env) {
+ // The NDK and JDK versions of jni.h disagree on the signatures for the
+ // JavaVM::AttachCurrentThread member function (the former uses 'JavaVM*'
+ // and the latter uses 'void**'). To avoid causing linker errors when the
+ // JDK's jni.h is accidentally put on the include path during an Android
+ // build, we use the indirection below when calling the function. It's not
+ // sufficient to #ifdef around __ANDROID__, because whatever is including
+ // this header file might put the JDK jni.h version on the include path.
+ return AttachCurrentThreadImpl(vm, &JavaVM::AttachCurrentThread, env);
+ }
+
+ ScopedJniEnv(const ScopedJniEnv&) = delete;
+ void operator=(const ScopedJniEnv&) = delete;
+
+ JavaVM* jvm_;
+ JNIEnv* env_;
+ bool is_attached_;
+};
+
+// Parses a proto from a Java byte array.
+//
+// If any JNI calls fail, or if the parsing of the proto fails, then this
+// FCP_CHECK-fails.
+//
+// This method does not call `JNIEnv::DeleteLocalRef` on the given `jbyteArray`.
+//
+// This is meant to be used as a convenient way to use serialized protobufs as
+// part of a JNI API contract, since in such cases we can safely assume that the
+// input argument will always be a valid proto (and anything else would be a
+// programmer error).
+template <typename MessageT>
+static MessageT ParseProtoFromJByteArray(JNIEnv* env, jbyteArray byte_array) {
+ MessageT out_message;
+
+ jsize length = env->GetArrayLength(byte_array);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ if (length == 0) {
+ return std::move(out_message);
+ }
+ // This will make a copy of the data into buffer, but generally the proto data
+ // will small enough that this shouldn't matter.
+ absl::FixedArray<jbyte> buffer(length);
+ env->GetByteArrayRegion(byte_array, 0, length, buffer.data());
+ FCP_CHECK(!env->ExceptionCheck());
+
+ FCP_CHECK(out_message.ParseFromArray(buffer.data(), length));
+
+ return std::move(out_message);
+}
+
+// Serializes a proto to a `jbyteArray`.
+//
+// The caller must call `JNIEnv::DeleteLocalRef` on the returned `jbyteArray`
+// once it is done with it.
+//
+// If any JNI calls fail, then this FCP_CHECK-fails.
+template <typename MessageT>
+static jbyteArray SerializeProtoToJByteArray(JNIEnv* env,
+ const MessageT& proto) {
+ int length = static_cast<int>(proto.ByteSizeLong());
+
+ jbyteArray byte_array = env->NewByteArray(length);
+ FCP_CHECK(byte_array != nullptr);
+ FCP_CHECK(!env->ExceptionCheck());
+
+ // This serializes into a buffer and then copies that buffer to the Java byte
+ // array. The proto data is generally small enough that this extra copy
+ // shouldn't matter.
+ absl::FixedArray<jbyte> buffer(length);
+ proto.SerializeToArray(buffer.data(), length);
+
+ env->SetByteArrayRegion(byte_array, 0, length, buffer.data());
+ FCP_CHECK(!env->ExceptionCheck());
+
+ return byte_array;
+}
+
+// Describes the method name and JNI method signature of a Java callback.
+struct JavaMethodSig {
+ char const* name;
+ char const* signature;
+};
+// Describes the field name and JNI type signature of a Java field.
+struct JavaFieldSig {
+ char const* name;
+ char const* signature;
+};
+
+// A utility for ensuring that a local JNI reference is deleted once the object
+// goes out of scope. This class is only intended to be used inside a function
+// body (and not to be returned or passed as an argument).
+class LocalRefDeleter {
+ public:
+ LocalRefDeleter(JNIEnv* env, jobject local_ref)
+ : env_(env), local_ref_(local_ref) {}
+ // Prevent copies & moves, to make it harder to accidentally have this object
+ // be passed as a parameter or return type.
+ LocalRefDeleter(LocalRefDeleter& other) = delete;
+ LocalRefDeleter(LocalRefDeleter&& other) = delete;
+ ~LocalRefDeleter() { env_->DeleteLocalRef(local_ref_); }
+
+ private:
+ JNIEnv* env_;
+ jobject local_ref_;
+};
+
+} // namespace jni
+} // namespace fcp
+
+#endif // FCP_JNI_JNI_UTIL_H_
diff --git a/fcp/patches/BUILD b/fcp/patches/BUILD
new file mode 100644
index 0000000..82bab3f
--- /dev/null
+++ b/fcp/patches/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/fcp/patches/googleapis_longrunning.patch b/fcp/patches/googleapis_longrunning.patch
new file mode 100644
index 0000000..122ff28
--- /dev/null
+++ b/fcp/patches/googleapis_longrunning.patch
@@ -0,0 +1,16 @@
+--- google/longrunning/BUILD.bazel
++++ google/longrunning/BUILD.bazel
+@@ -209,3 +209,13 @@
+ ":longrunning_php_proto",
+ ],
+ )
++
++##############################################################################
++# Python
++##############################################################################
++load("@com_google_googleapis_imports//:imports.bzl", "py_proto_library")
++
++py_proto_library(
++ name = "longrunning_py_proto",
++ deps = [":operations_proto"],
++)
diff --git a/fcp/patches/googletest.patch b/fcp/patches/googletest.patch
new file mode 100644
index 0000000..c937e92
--- /dev/null
+++ b/fcp/patches/googletest.patch
@@ -0,0 +1,55 @@
+diff -Naur googletest-5a509dbd2e5a6c694116e329c5a20dc190653724/BUILD.bazel googletest.new/BUILD.bazel
+--- BUILD.bazel
++++ BUILD.bazel
+@@ -159,11 +159,14 @@
+ cc_library(
+ name = "gtest_main",
+ srcs = ["googlemock/src/gmock_main.cc"],
+ features = select({
+ ":windows": ["windows_export_all_symbols"],
+ "//conditions:default": [],
+ }),
+- deps = [":gtest"],
++ deps = select({
++ ":has_absl": ["@com_google_absl//absl/flags:parse"],
++ "//conditions:default": [],
++ }) + [":gtest"],
+ )
+
+ # The following rules build samples of how to use gTest.
+diff -Naur googletest-5a509dbd2e5a6c694116e329c5a20dc190653724/googlemock/src/gmock_main.cc googletest.new/googlemock/src/gmock_main.cc
+--- googlemock/src/gmock_main.cc
++++ googlemock/src/gmock_main.cc
+@@ -32,6 +32,9 @@
+
+ #include <iostream>
+
++#if GTEST_HAS_ABSL
++#include "absl/flags/parse.h"
++#endif // GTEST_HAS_ABSL
+ #include "gmock/gmock.h"
+ #include "gtest/gtest.h"
+
+@@ -70,6 +73,9 @@
+ // also responsible for initializing Google Test. Therefore there's
+ // no need for calling testing::InitGoogleTest() separately.
+ testing::InitGoogleMock(&argc, argv);
++#if GTEST_HAS_ABSL
++ absl::ParseCommandLine(argc, argv);
++#endif // GTEST_HAS_ABSL
+ return RUN_ALL_TESTS();
+ }
+ #endif
+diff -Naur googletest-5a509dbd2e5a6c694116e329c5a20dc190653724/googletest/src/gtest_main.cc googletest.new/googletest/src/gtest_main.cc
+--- googletest/src/gtest_main.cc
++++ googletest/src/gtest_main.cc
+@@ -50,6 +50,9 @@
+ GTEST_API_ int main(int argc, char **argv) {
+ printf("Running main() from %s\n", __FILE__);
+ testing::InitGoogleTest(&argc, argv);
++#if GTEST_HAS_ABSL
++ absl::ParseCommandLine(argc, argv);
++#endif // GTEST_HAS_ABSL
+ return RUN_ALL_TESTS();
+ }
+ #endif
diff --git a/fcp/patches/tensorflow_googleapis_proto_rules.patch b/fcp/patches/tensorflow_googleapis_proto_rules.patch
new file mode 100644
index 0000000..70903d0
--- /dev/null
+++ b/fcp/patches/tensorflow_googleapis_proto_rules.patch
@@ -0,0 +1,11 @@
+--- third_party/googleapis/repository_rules.bzl
++++ third_party/googleapis/repository_rules.patched.bzl
+@@ -34,6 +34,8 @@
+ switched_rules_by_language(
+ name = "com_google_googleapis_imports",
+ cc = True,
+ grpc = True,
++ java = True,
++ python = True,
+ rules_override = {
+ "cc_proto_library": [
diff --git a/fcp/patches/tensorflow_llvm_url.patch b/fcp/patches/tensorflow_llvm_url.patch
new file mode 100644
index 0000000..52386a1
--- /dev/null
+++ b/fcp/patches/tensorflow_llvm_url.patch
@@ -0,0 +1,23 @@
+diff --git third_party/llvm/workspace.bzl third_party/llvm/workspace.bzl
+index 038e0ee5fe5..4693f5cfadc 100644
+--- third_party/llvm/workspace.bzl
++++ third_party/llvm/workspace.bzl
+@@ -5,15 +5,15 @@ load("//third_party:repo.bzl", "tf_http_archive")
+ def repo(name):
+ """Imports LLVM."""
+ LLVM_COMMIT = "10939d1d580b9d3c9c2f3539c6bdb39f408179c0"
+- LLVM_SHA256 = "4adce5ef34c2062be0d7c5eb2a11606fa70690342e7e93327457ee2b6ad7ac72"
++ LLVM_SHA256 = "8f5201fc907d4faeb6f05eaa61a24504d8685eb92e5fec143bf981783711363f"
+
+ tf_http_archive(
+ name = name,
+ sha256 = LLVM_SHA256,
+- strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
++ strip_prefix = "llvm-llvm-project-{commit_partial}".format(commit_partial = LLVM_COMMIT[:7]),
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+- "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
++ "https://api.github.com/repos/llvm/llvm-project/tarball/{commit}".format(commit = LLVM_COMMIT),
+ ],
+ build_file = "//third_party/llvm:llvm.BUILD",
+ patch_file = [
diff --git a/fcp/patches/tensorflow_pybind11_osx.patch b/fcp/patches/tensorflow_pybind11_osx.patch
new file mode 100644
index 0000000..4b02f65
--- /dev/null
+++ b/fcp/patches/tensorflow_pybind11_osx.patch
@@ -0,0 +1,11 @@
+--- third_party/pybind11.BUILD
++++ third_party/pybind11.BUILD
+@@ -23,3 +23,8 @@
+ "@org_tensorflow//third_party/python_runtime:headers",
+ ],
+ )
++
++config_setting(
++ name = "osx",
++ constraint_values = ["@platforms//os:osx"],
++)
diff --git a/fcp/patches/tensorflow_serving.patch b/fcp/patches/tensorflow_serving.patch
new file mode 100644
index 0000000..9808ef8
--- /dev/null
+++ b/fcp/patches/tensorflow_serving.patch
@@ -0,0 +1,25 @@
+diff --git a/tensorflow_serving/util/net_http/server/public/BUILD b/tensorflow_serving/util/net_http/server/public/BUILD
+index e7f96d98..2ae0530a 100644
+--- tensorflow_serving/util/net_http/server/public/BUILD
++++ tensorflow_serving/util/net_http/server/public/BUILD
+@@ -34,6 +34,7 @@ cc_library(
+ hdrs = [
+ "httpserver.h",
+ ],
++ visibility = ["//visibility:public"],
+ deps = [
+ ":http_server_api",
+ "//tensorflow_serving/util/net_http/server/internal:evhttp_server",
+diff --git a/tensorflow_serving/workspace.bzl b/tensorflow_serving/workspace.bzl
+index 08c3cc28..0803cdf3 100644
+--- tensorflow_serving/workspace.bzl
++++ tensorflow_serving/workspace.bzl
+@@ -31,7 +31,7 @@ def tf_serving_workspace():
+ url = "https://github.com/libevent/libevent/archive/release-2.1.8-stable.zip",
+ sha256 = "70158101eab7ed44fd9cc34e7f247b3cae91a8e4490745d9d6eb7edc184e4d96",
+ strip_prefix = "libevent-release-2.1.8-stable",
+- build_file = "@//third_party/libevent:BUILD",
++ build_file = "@//third_party:event.BUILD.bzl",
+ )
+
+ # ===== ICU dependency =====
diff --git a/fcp/patches/tensorflow_tf_custom_op_py_library.patch b/fcp/patches/tensorflow_tf_custom_op_py_library.patch
new file mode 100644
index 0000000..50cfb63
--- /dev/null
+++ b/fcp/patches/tensorflow_tf_custom_op_py_library.patch
@@ -0,0 +1,10 @@
+--- tensorflow/tensorflow.bzl
++++ tensorflow/tensorflow.bzl
+@@ -1320,7 +1320,6 @@
+ srcs_version = "PY3",
+ visibility = visibility,
+ deps = [
+- clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
+ ],
+ # Instruct build_cleaner to try to avoid using this rule; typically ops
+ # creators will provide their own tf_custom_op_py_library based target
diff --git a/fcp/patches/tensorflow_zlib.patch b/fcp/patches/tensorflow_zlib.patch
new file mode 100644
index 0000000..c510bee
--- /dev/null
+++ b/fcp/patches/tensorflow_zlib.patch
@@ -0,0 +1,11 @@
+--- third_party/zlib.BUILD
++++ third_party/zlib.BUILD
+@@ -31,7 +31,7 @@
+ "zutil.c",
+ "zutil.h",
+ ],
+- hdrs = ["zlib.h"],
++ hdrs = ["zconf.h", "zlib.h"],
+ copts = select({
+ "@org_tensorflow//tensorflow:windows": [],
+ "//conditions:default": [
diff --git a/fcp/protocol/BUILD b/fcp/protocol/BUILD
new file mode 100644
index 0000000..aec7ff0
--- /dev/null
+++ b/fcp/protocol/BUILD
@@ -0,0 +1,55 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "grpc_chunked_bidi_stream",
+ hdrs = ["grpc_chunked_bidi_stream.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base",
+ "//fcp/protos:cc_grpc",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/status",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "grpc_chunked_bidi_stream_test",
+ timeout = "long",
+ srcs = ["grpc_chunked_bidi_stream_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":grpc_chunked_bidi_stream",
+ "//fcp/base",
+ "//fcp/client:fake_server",
+ "//fcp/client:grpc_bidi_stream",
+ "//fcp/protos:federated_api_cc_proto",
+ "//fcp/testing",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/protocol/grpc_chunked_bidi_stream.h b/fcp/protocol/grpc_chunked_bidi_stream.h
new file mode 100644
index 0000000..8485126
--- /dev/null
+++ b/fcp/protocol/grpc_chunked_bidi_stream.h
@@ -0,0 +1,484 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
+#define FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <deque>
+#include <memory>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/protos/federated_api.grpc.pb.h"
+#include "grpcpp/impl/codegen/call_op_set.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
+#include "google/protobuf/io/gzip_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace fcp {
+namespace client {
+
+/**
+ * A class which implements the chunking protocol for the federated learning
+ * API.
+ *
+ * Can be used by both client and server.
+ *
+ * @tparam Outgoing The type of the outgoing protocol buffer message.
+ * @tparam Incoming The type of the incoming protocol buffer message.
+ */
+template <typename Outgoing, typename Incoming>
+class GrpcChunkedBidiStream {
+ public:
+ struct GrpcChunkedBidiStreamOptions {
+ int32_t chunk_size_for_upload = -1;
+ int32_t max_pending_chunks = -1;
+ google::internal::federatedml::v2::CompressionLevel compression_level{};
+ };
+ GrpcChunkedBidiStream(
+ grpc::internal::WriterInterface<Outgoing>* writer_interface,
+ grpc::internal::ReaderInterface<Incoming>* reader_interface);
+ GrpcChunkedBidiStream(
+ grpc::internal::WriterInterface<Outgoing>* writer_interface,
+ grpc::internal::ReaderInterface<Incoming>* reader_interface,
+ GrpcChunkedBidiStreamOptions options);
+ virtual ~GrpcChunkedBidiStream() = default;
+
+ // GrpcChunkedBidiStream is neither copyable nor movable.
+ GrpcChunkedBidiStream(const GrpcChunkedBidiStream&) = delete;
+ GrpcChunkedBidiStream& operator=(const GrpcChunkedBidiStream&) = delete;
+
+ ABSL_MUST_USE_RESULT absl::Status Send(Outgoing* message);
+ ABSL_MUST_USE_RESULT absl::Status Receive(Incoming* message);
+ void Close();
+ int64_t ChunkingLayerBytesSent();
+ int64_t ChunkingLayerBytesReceived();
+
+ private:
+ ABSL_MUST_USE_RESULT absl::Status TryDecorateCheckinRequest(
+ Outgoing* message);
+ ABSL_MUST_USE_RESULT absl::Status ChunkMessage(const Outgoing& message);
+ ABSL_MUST_USE_RESULT absl::Status TrySendPending();
+ ABSL_MUST_USE_RESULT absl::Status TrySend(const Outgoing& message);
+ ABSL_MUST_USE_RESULT absl::Status SendAck(int32_t chunk_index);
+ ABSL_MUST_USE_RESULT absl::Status SendRaw(const Outgoing& message,
+ bool disable_compression = false);
+ ABSL_MUST_USE_RESULT absl::Status TrySnoopCheckinResponse(Incoming* message);
+ ABSL_MUST_USE_RESULT absl::Status TryAssemblePending(Incoming* message,
+ bool* message_assembled);
+ ABSL_MUST_USE_RESULT absl::Status AssemblePending(Incoming* message,
+ bool* message_assembled);
+ ABSL_MUST_USE_RESULT absl::Status ReceiveRaw(Incoming* message);
+
+ grpc::internal::WriterInterface<Outgoing>* writer_interface_;
+ grpc::internal::ReaderInterface<Incoming>* reader_interface_;
+
+ struct {
+ int32_t uncompressed_size = -1;
+ google::internal::federatedml::v2::CompressionLevel compression_level{};
+ int32_t blob_size_bytes = -1;
+ std::deque<std::string> deque;
+ std::string composite;
+ int64_t total_bytes_downloaded = 0;
+ } incoming_;
+
+ struct {
+ int32_t chunk_size_for_upload = 0;
+ int32_t max_pending_chunks = 0;
+ int32_t pending_chunks = 0;
+ google::internal::federatedml::v2::CompressionLevel compression_level{};
+ std::deque<std::unique_ptr<Outgoing>> deque;
+ int64_t total_bytes_uploaded = 0;
+
+ google::internal::federatedml::v2::ChunkedTransferMessage* Add() {
+ deque.push_back(std::make_unique<Outgoing>());
+ return deque.back()->mutable_chunked_transfer();
+ }
+ } outgoing_;
+};
+
+#define COMMON_USING_DIRECTIVES \
+ using google::internal::federatedml::v2::ChunkedTransferMessage; \
+ using google::internal::federatedml::v2::ClientStreamMessage; \
+ using google::internal::federatedml::v2::CompressionLevel; \
+ using google::internal::federatedml::v2::ServerStreamMessage; \
+ using google::protobuf::io::ArrayInputStream; \
+ using google::protobuf::io::StringOutputStream; \
+ using google::protobuf::io::GzipInputStream; \
+ using google::protobuf::io::GzipOutputStream; \
+ using google::protobuf::io::ZeroCopyOutputStream;
+
+template <typename Outgoing, typename Incoming>
+GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
+ grpc::internal::WriterInterface<Outgoing>* writer_interface,
+ grpc::internal::ReaderInterface<Incoming>* reader_interface)
+ : GrpcChunkedBidiStream(writer_interface, reader_interface,
+ GrpcChunkedBidiStreamOptions()) {}
+
+template <typename Outgoing, typename Incoming>
+GrpcChunkedBidiStream<Outgoing, Incoming>::GrpcChunkedBidiStream(
+ grpc::internal::WriterInterface<Outgoing>* writer_interface,
+ grpc::internal::ReaderInterface<Incoming>* reader_interface,
+ GrpcChunkedBidiStreamOptions options)
+ : writer_interface_(writer_interface), reader_interface_(reader_interface) {
+ outgoing_.chunk_size_for_upload = options.chunk_size_for_upload;
+ outgoing_.max_pending_chunks = options.max_pending_chunks;
+ outgoing_.compression_level = options.compression_level;
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Send(
+ Outgoing* message) {
+ COMMON_USING_DIRECTIVES;
+ FCP_RETURN_IF_ERROR(TryDecorateCheckinRequest(message));
+ switch (message->kind_case()) {
+ case Outgoing::KindCase::kChunkedTransfer:
+ Close();
+ return absl::InvalidArgumentError(
+ absl::StrCat("Message is pre-chunked: ", message->DebugString()));
+ default:
+ break;
+ }
+
+ return TrySend(*message);
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::Receive(
+ Incoming* message) {
+ COMMON_USING_DIRECTIVES;
+ Status status;
+ bool message_assembled = false;
+
+ do {
+ FCP_RETURN_IF_ERROR(status = ReceiveRaw(message));
+ switch (message->kind_case()) {
+ case Incoming::KindCase::kChunkedTransfer:
+ if (message->chunked_transfer().kind_case() ==
+ ChunkedTransferMessage::kAck) {
+ --outgoing_.pending_chunks;
+ FCP_RETURN_IF_ERROR(status = TrySendPending());
+ } else {
+ FCP_RETURN_IF_ERROR(
+ status = TryAssemblePending(message, &message_assembled));
+ }
+ break;
+ default:
+ if (incoming_.uncompressed_size != -1)
+ return absl::InvalidArgumentError("Chunk reassembly in progress.");
+ message_assembled = true;
+ break;
+ }
+ } while (!message_assembled);
+
+ FCP_RETURN_IF_ERROR(status = TrySnoopCheckinResponse(message));
+ return status;
+}
+
+template <>
+inline absl::Status
+GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
+ google::internal::federatedml::v2::ServerStreamMessage>::
+ TryDecorateCheckinRequest(
+ google::internal::federatedml::v2::ClientStreamMessage* message) {
+ COMMON_USING_DIRECTIVES;
+ if (message->kind_case() !=
+ ClientStreamMessage::kEligibilityEvalCheckinRequest &&
+ message->kind_case() != ClientStreamMessage::kCheckinRequest)
+ return absl::OkStatus();
+ // Both an EligibilityEvalCheckinRequest or a CheckinRequest message need to
+ // specify a ProtocolOptionsRequest message.
+ auto options = (message->has_eligibility_eval_checkin_request()
+ ? message->mutable_eligibility_eval_checkin_request()
+ ->mutable_protocol_options_request()
+ : message->mutable_checkin_request()
+ ->mutable_protocol_options_request());
+ options->set_supports_chunked_blob_transfer(true);
+ options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
+ options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
+ options->add_supported_compression_levels(
+ CompressionLevel::ZLIB_BEST_COMPRESSION);
+ options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status
+GrpcChunkedBidiStream<Outgoing, Incoming>::TryDecorateCheckinRequest(
+ Outgoing*) {
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkMessage(
+ const Outgoing& message) {
+ COMMON_USING_DIRECTIVES;
+
+ auto start = outgoing_.Add()->mutable_start();
+ start->set_compression_level(outgoing_.compression_level);
+
+ // TODO(team): Replace with a more efficient serialization mechanism.
+ std::string output;
+ if (outgoing_.compression_level == CompressionLevel::UNCOMPRESSED) {
+ if (!message.AppendToString(&output))
+ return absl::InternalError("Could not append to string.");
+ } else {
+ StringOutputStream string_output_stream(&output);
+ GzipOutputStream::Options options;
+ options.format = GzipOutputStream::ZLIB;
+ switch (outgoing_.compression_level) {
+ case CompressionLevel::ZLIB_DEFAULT:
+ options.compression_level = Z_DEFAULT_COMPRESSION;
+ break;
+ case CompressionLevel::ZLIB_BEST_COMPRESSION:
+ options.compression_level = Z_BEST_COMPRESSION;
+ break;
+ case CompressionLevel::ZLIB_BEST_SPEED:
+ options.compression_level = Z_BEST_SPEED;
+ break;
+ default:
+ Close();
+ return absl::InternalError("Unsupported compression level.");
+ }
+ GzipOutputStream compressed_stream(&string_output_stream, options);
+ if (!message.SerializeToZeroCopyStream(&compressed_stream) ||
+ !compressed_stream.Close())
+ return absl::InvalidArgumentError(
+ absl::StrCat("Failed to serialize message: ",
+ compressed_stream.ZlibErrorMessage()));
+ }
+
+ auto blob_size_bytes = static_cast<int32_t>(output.size());
+ int32_t chunk_index = 0;
+ if (!blob_size_bytes) blob_size_bytes = 1; // Force one empty packet.
+ for (size_t offset = 0; offset < blob_size_bytes;
+ offset += std::min(blob_size_bytes, outgoing_.chunk_size_for_upload),
+ ++chunk_index) {
+ auto data = outgoing_.Add()->mutable_data();
+ data->set_chunk_index(chunk_index);
+ data->set_chunk_bytes(output.substr(
+ offset, static_cast<size_t>(outgoing_.chunk_size_for_upload)));
+ }
+
+ start->set_uncompressed_size(static_cast<int32_t>(message.ByteSizeLong()));
+ start->set_blob_size_bytes(blob_size_bytes);
+
+ auto end = outgoing_.Add()->mutable_end();
+ end->set_chunk_count(chunk_index);
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySendPending() {
+ COMMON_USING_DIRECTIVES;
+ auto status = absl::OkStatus();
+ while (!outgoing_.deque.empty() &&
+ outgoing_.pending_chunks < outgoing_.max_pending_chunks) {
+ auto& front = outgoing_.deque.front();
+ FCP_RETURN_IF_ERROR(status =
+ SendRaw(*front, outgoing_.compression_level > 0));
+ if (front->chunked_transfer().kind_case() == ChunkedTransferMessage::kData)
+ ++outgoing_.pending_chunks;
+ outgoing_.deque.pop_front();
+ }
+ return status;
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySend(
+ const Outgoing& message) {
+ COMMON_USING_DIRECTIVES;
+ if (outgoing_.chunk_size_for_upload <= 0 || outgoing_.max_pending_chunks <= 0)
+ return SendRaw(message); // No chunking.
+ absl::Status status;
+ if (!(status = ChunkMessage(message)).ok()) {
+ Close();
+ return status;
+ }
+ return TrySendPending();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendAck(
+ int32_t chunk_index) {
+ Outgoing ack;
+ ack.mutable_chunked_transfer()->mutable_ack()->set_chunk_index(chunk_index);
+ return SendRaw(ack);
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::SendRaw(
+ const Outgoing& message, bool disable_compression) {
+ if (!writer_interface_)
+ return absl::FailedPreconditionError("Send on closed stream.");
+ grpc::WriteOptions write_options;
+ if (disable_compression) write_options.set_no_compression();
+ outgoing_.total_bytes_uploaded += message.ByteSizeLong();
+ if (!writer_interface_->Write(message, write_options)) {
+ Close();
+ return absl::AbortedError("End of stream.");
+ }
+ return absl::OkStatus();
+}
+
+// If this class is used on the client side, we need to break the abstraction
+// that messages are opaque in order to read the chunking parameters sent by the
+// server to determine how to carry out the remainder of the protocol.
+// Inspect the checkin response to record these chunking options.
+template <>
+inline absl::Status
+GrpcChunkedBidiStream<google::internal::federatedml::v2::ClientStreamMessage,
+ google::internal::federatedml::v2::ServerStreamMessage>::
+ TrySnoopCheckinResponse(
+ google::internal::federatedml::v2::ServerStreamMessage* message) {
+ COMMON_USING_DIRECTIVES;
+ if (message->kind_case() !=
+ ServerStreamMessage::kEligibilityEvalCheckinResponse &&
+ message->kind_case() != ServerStreamMessage::kCheckinResponse)
+ return absl::OkStatus();
+ if (incoming_.uncompressed_size != -1)
+ return absl::InvalidArgumentError("Chunk reassembly in progress.");
+ // We adopt any new protocol options we may receive, even if we previously
+ // received some options already. I.e. a ProtocolOptionsResponse received in a
+ // CheckinResponse will overwrite any ProtocolOptionsResponse that was
+ // previously received in a EligibilityEvalCheckinResponse.
+ // OTOH, we also don't require that every EligibilityEvalCheckinResponse or
+ // CheckinResponse message actually has a ProtocolOptionsResponse message set
+ // (e.g. CheckinResponse may not have a ProtocolOptionsResponse if one was
+ // already returned inside a prior EligibilityEvalCheckinResponse).
+ if (message->eligibility_eval_checkin_response()
+ .has_protocol_options_response() ||
+ message->checkin_response().has_protocol_options_response()) {
+ auto options =
+ (message->has_eligibility_eval_checkin_response()
+ ? message->eligibility_eval_checkin_response()
+ .protocol_options_response()
+ : message->checkin_response().protocol_options_response());
+ outgoing_.chunk_size_for_upload = options.chunk_size_for_upload();
+ outgoing_.max_pending_chunks = options.max_pending_chunks();
+ outgoing_.compression_level = options.compression_level();
+ }
+ return absl::OkStatus();
+}
+
+// If this class is being used by the server, this is a no-op as the server
+// determines the chunking options.
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TrySnoopCheckinResponse(
+ Incoming*) {
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::TryAssemblePending(
+ Incoming* message, bool* message_assembled) {
+ COMMON_USING_DIRECTIVES;
+ *message_assembled = false;
+ auto chunk = message->chunked_transfer();
+ switch (chunk.kind_case()) {
+ case ChunkedTransferMessage::kStart:
+ if (!incoming_.deque.empty() || incoming_.uncompressed_size != -1)
+ return absl::InternalError("Unexpected Start.");
+ incoming_.uncompressed_size = chunk.start().uncompressed_size();
+ incoming_.compression_level = chunk.start().compression_level();
+ incoming_.blob_size_bytes = chunk.start().blob_size_bytes();
+ break;
+ case ChunkedTransferMessage::kData:
+ if (chunk.data().chunk_index() != incoming_.deque.size())
+ return absl::InternalError("Unexpected Data.");
+ incoming_.deque.emplace_back(chunk.data().chunk_bytes());
+ incoming_.composite.append(incoming_.deque.back());
+ return SendAck(static_cast<int32_t>(incoming_.deque.size() - 1));
+ case ChunkedTransferMessage::kEnd:
+ if (incoming_.deque.empty() ||
+ chunk.end().chunk_count() != incoming_.deque.size())
+ return absl::InternalError("Unexpected End.");
+ return AssemblePending(message, message_assembled);
+ case ChunkedTransferMessage::kAck:
+ return absl::InternalError("Unexpected Ack.");
+ default:
+ return absl::InternalError(
+ absl::StrCat("Unexpected message subtype: ",
+ message->chunked_transfer().kind_case()));
+ }
+
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::AssemblePending(
+ Incoming* message, bool* message_assembled) {
+ COMMON_USING_DIRECTIVES;
+ // TODO(team): Replace with a more efficient deserialization mechanism.
+ if (incoming_.compression_level == CompressionLevel::UNCOMPRESSED) {
+ if (!message->ParseFromString(incoming_.composite))
+ return absl::InternalError(absl::StrCat("Could not parse from string. ",
+ incoming_.composite.size()));
+ } else {
+ ArrayInputStream string_input_stream(
+ incoming_.composite.c_str(),
+ static_cast<int>(incoming_.composite.size()));
+ GzipInputStream compressed_stream(&string_input_stream);
+ if (!message->ParseFromZeroCopyStream(&compressed_stream))
+ return absl::InternalError("Could not parse proto from input stream.");
+ }
+ *message_assembled = true;
+ incoming_.uncompressed_size = -1;
+ incoming_.blob_size_bytes = -1;
+ incoming_.deque.clear();
+ incoming_.composite.clear();
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+absl::Status GrpcChunkedBidiStream<Outgoing, Incoming>::ReceiveRaw(
+ Incoming* message) {
+ if (!reader_interface_)
+ return absl::FailedPreconditionError("Receive on closed stream.");
+ if (!reader_interface_->Read(message)) {
+ Close();
+ return absl::AbortedError("End of stream.");
+ }
+ incoming_.total_bytes_downloaded += message->ByteSizeLong();
+ return absl::OkStatus();
+}
+
+template <typename Outgoing, typename Incoming>
+void GrpcChunkedBidiStream<Outgoing, Incoming>::Close() {
+ writer_interface_ = nullptr;
+ reader_interface_ = nullptr;
+}
+
+template <typename Outgoing, typename Incoming>
+int64_t
+GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesReceived() {
+ return incoming_.total_bytes_downloaded;
+}
+
+template <typename Outgoing, typename Incoming>
+int64_t GrpcChunkedBidiStream<Outgoing, Incoming>::ChunkingLayerBytesSent() {
+ return outgoing_.total_bytes_uploaded;
+}
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_PROTOCOL_GRPC_CHUNKED_BIDI_STREAM_H_
diff --git a/fcp/protocol/grpc_chunked_bidi_stream_test.cc b/fcp/protocol/grpc_chunked_bidi_stream_test.cc
new file mode 100644
index 0000000..e3dd968
--- /dev/null
+++ b/fcp/protocol/grpc_chunked_bidi_stream_test.cc
@@ -0,0 +1,330 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/protocol/grpc_chunked_bidi_stream.h"
+
+#include <cctype>
+#include <string>
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/fake_server.h"
+#include "fcp/client/grpc_bidi_stream.h"
+#include "fcp/protos/federated_api.pb.h"
+#include "fcp/testing/testing.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
+
+namespace fcp {
+namespace client {
+namespace test {
+namespace {
+
+using google::internal::federatedml::v2::ClientStreamMessage;
+using google::internal::federatedml::v2::CompressionLevel;
+using google::internal::federatedml::v2::ServerStreamMessage;
+using ::testing::Gt;
+using ::testing::Le;
+using ::testing::Not;
+
+std::string SimpleSelfVerifyingString(size_t n) {
+ std::string str;
+ str.reserve(n);
+ for (auto i = 0; i < n; ++i) str.push_back(static_cast<char>(i % 128));
+ return str;
+}
+
+bool VerifyString(const std::string str) {
+ auto n = str.length();
+ for (auto i = 0; i < n; ++i) {
+ if (str[i] != static_cast<char>(i % 128)) return false;
+ }
+ return true;
+}
+
+Status PerformInitialCheckin(GrpcBidiStream* stream) {
+ ClientStreamMessage request_;
+ ServerStreamMessage reply;
+ Status status;
+
+ auto options =
+ request_.mutable_checkin_request()->mutable_protocol_options_request();
+ options->set_supports_chunked_blob_transfer(true);
+ options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
+ options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
+ options->add_supported_compression_levels(
+ CompressionLevel::ZLIB_BEST_COMPRESSION);
+ options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
+
+ EXPECT_THAT((status = stream->Send(&request_)), IsOk());
+ EXPECT_THAT((status = stream->Receive(&reply)), IsOk());
+ EXPECT_TRUE(reply.has_checkin_response()) << reply.DebugString();
+
+ return status;
+}
+
+using ChunkingParameters =
+ std::tuple<int32_t, /* chunk_size_for_upload */
+ int32_t, /* max_pending_chunks */
+ google::internal::federatedml::v2::CompressionLevel,
+ int32_t, /* request size */
+ size_t, /* request count */
+ int32_t, /* reply size */
+ size_t /* replies per request */
+ >;
+
+class ByteVerifyingFakeServer : public FakeServer {
+ public:
+ explicit ByteVerifyingFakeServer(const ChunkingParameters& params)
+ : FakeServer(std::get<0>(params), std::get<1>(params),
+ std::get<2>(params)),
+ reply_size_(std::get<5>(params)),
+ replies_per_request_(std::get<6>(params)) {}
+
+ Status Handle(const ClientStreamMessage& request,
+ ServerStreamMessage* first_reply,
+ GrpcChunkedBidiStream<ServerStreamMessage, ClientStreamMessage>*
+ stream) override {
+ Status status;
+ if (request.has_checkin_request()) {
+ EXPECT_THAT((status = stream->Send(first_reply)), IsOk());
+ return status;
+ }
+ EXPECT_TRUE(
+ VerifyString(request.report_request().report().update_checkpoint()));
+ ServerStreamMessage reply;
+ reply.mutable_report_response()->mutable_retry_window()->set_retry_token(
+ SimpleSelfVerifyingString(reply_size_));
+ for (auto i = 0; i < replies_per_request_; ++i)
+ EXPECT_THAT((status = stream->Send(&reply)), IsOk());
+ return status;
+ }
+
+ private:
+ int32_t reply_size_;
+ size_t replies_per_request_;
+};
+
+class GrpcChunkedMessageStreamTest
+ : public ::testing::TestWithParam<ChunkingParameters> {
+ public:
+ GrpcChunkedMessageStreamTest() : server_impl_(GetParam()) {
+ auto params = GetParam();
+ request_size_ = std::get<3>(params);
+ request_count_ = std::get<4>(params);
+ reply_size_ = std::get<5>(params);
+ replies_per_request_ = std::get<6>(params);
+
+ grpc::ServerBuilder builder;
+ builder.AddListeningPort("dns:///localhost:0",
+ grpc::InsecureServerCredentials(), &port_);
+ builder.RegisterService(&server_impl_);
+ grpc_server_ = builder.BuildAndStart();
+ client_stream_ =
+ std::make_unique<GrpcBidiStream>(addr_uri(), "none", "",
+ /*grpc_channel_deadline_seconds=*/600);
+ EXPECT_THAT(PerformInitialCheckin(client_stream_.get()), IsOk());
+
+ request_.mutable_report_request()->mutable_report()->set_update_checkpoint(
+ SimpleSelfVerifyingString(request_size_));
+ }
+
+ std::string addr_uri() { return absl::StrCat(kAddrUri, ":", port_); }
+
+ int32_t request_size_;
+ size_t request_count_;
+ int32_t reply_size_;
+ size_t replies_per_request_;
+
+ static constexpr char kAddrUri[] = "dns:///localhost";
+ ByteVerifyingFakeServer server_impl_;
+ int port_ = -1;
+ std::unique_ptr<grpc::Server> grpc_server_;
+ std::unique_ptr<GrpcBidiStream> client_stream_;
+
+ ClientStreamMessage request_;
+ ServerStreamMessage reply_;
+};
+
+TEST_P(GrpcChunkedMessageStreamTest, RequestReply) {
+ for (size_t i = 0; i < request_count_; ++i) {
+ EXPECT_THAT(client_stream_->Send(&request_), IsOk());
+ for (size_t i = 0; i < replies_per_request_; ++i) {
+ EXPECT_THAT(client_stream_->Receive(&reply_), IsOk());
+ EXPECT_TRUE(
+ VerifyString(reply_.report_response().retry_window().retry_token()));
+ }
+ }
+ client_stream_->Close();
+ EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
+}
+
+TEST_P(GrpcChunkedMessageStreamTest, RequestReplyChunkingLayerBandwidth) {
+ int64_t bytes_sent_so_far = client_stream_->ChunkingLayerBytesSent();
+ int64_t bytes_received_so_far = client_stream_->ChunkingLayerBytesReceived();
+ for (size_t i = 0; i < request_count_; ++i) {
+ EXPECT_THAT(client_stream_->Send(&request_), IsOk());
+ int64_t request_message_size = request_.ByteSizeLong();
+ // Sends may be deferred if flow control has paused the stream; in this
+ // case, they will not be recorded in statistics until they are sent as part
+ // of the next Receive(). Therefore, we assert sizes after the receives.
+
+ for (size_t i = 0; i < replies_per_request_; ++i) {
+ EXPECT_THAT(client_stream_->Receive(&reply_), IsOk());
+ int64_t bytes_received_delta =
+ client_stream_->ChunkingLayerBytesReceived() - bytes_received_so_far;
+ EXPECT_THAT(bytes_received_delta, Gt(0));
+ int64_t receive_message_size = reply_.ByteSizeLong();
+ // Small messages may actually be expanded due to compression overhead.
+ if (receive_message_size > 64) {
+ EXPECT_THAT(bytes_received_delta, Le(receive_message_size));
+ }
+ bytes_received_so_far += bytes_received_delta;
+ }
+
+ int64_t bytes_sent_delta =
+ client_stream_->ChunkingLayerBytesSent() - bytes_sent_so_far;
+ EXPECT_THAT(client_stream_->ChunkingLayerBytesSent(), Gt(0));
+ EXPECT_THAT(bytes_sent_delta, Gt(0));
+ // Small messages may actually be expanded due to compression overhead.
+ if (request_message_size > 64) {
+ EXPECT_THAT(bytes_sent_delta, Le(request_message_size));
+ }
+ bytes_sent_so_far += bytes_sent_delta;
+ }
+ client_stream_->Close();
+ EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
+}
+
+// #define GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS
+#if defined(GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS)
+// Ideally we would generate a covering array rather than a Cartesian product.
+INSTANTIATE_TEST_SUITE_P(
+ CartesianProductExpensive, GrpcChunkedMessageStreamTest,
+ testing::Combine(
+ /* chunk_size_for_upload */
+ testing::ValuesIn({0, 1, 129}),
+ /* max_pending_chunks */
+ testing::ValuesIn({0, 1, 129}),
+ /* compression_level */
+ testing::ValuesIn({CompressionLevel::ZLIB_DEFAULT}),
+ /* request size */
+ testing::ValuesIn({0, 1, 129}),
+ /* request count */
+ testing::ValuesIn({1ul, 129ul}),
+ /* reply size */
+ testing::ValuesIn({0, 1, 129}),
+ /* replies per request */
+ testing::ValuesIn({1ul, 129ul})),
+ [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
+ info) {
+ // clang-format off
+ std::string name = absl::StrCat(
+ std::get<0>(info.param), "csfu" "_",
+ std::get<1>(info.param), "mpc", "_",
+ std::get<2>(info.param), "cl", "_",
+ std::get<3>(info.param), "rqs", "_",
+ std::get<4>(info.param), "rqc", "_",
+ std::get<5>(info.param), "rps", "_",
+ std::get<6>(info.param), "rppr");
+ absl::c_replace_if(
+ name, [](char c) { return !std::isalnum(c); }, '_');
+ // clang-format on
+ return name;
+ });
+#endif
+
+// #define GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS
+#if defined(GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS)
+// Ideally we would generate a covering array rather than a Cartesian product.
+INSTANTIATE_TEST_SUITE_P(
+ CartesianProductUncompressed, GrpcChunkedMessageStreamTest,
+ testing::Combine(
+ /* chunk_size_for_upload */
+ testing::ValuesIn({0, 1, 129}),
+ /* max_pending_chunks */
+ testing::ValuesIn({0, 1, 129}),
+ /* compression_level */
+ testing::ValuesIn({CompressionLevel::UNCOMPRESSED}),
+ /* request size */
+ testing::ValuesIn({0, 1, 129}),
+ /* request count */
+ testing::ValuesIn({1ul, 129ul}),
+ /* reply size */
+ testing::ValuesIn({0, 1, 129}),
+ /* replies per request */
+ testing::ValuesIn({1ul, 129ul})),
+ [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
+ info) {
+ // clang-format off
+ std::string name = absl::StrCat(
+ std::get<0>(info.param), "csfu" "_",
+ std::get<1>(info.param), "mpc", "_",
+ std::get<2>(info.param), "cl", "_",
+ std::get<3>(info.param), "rqs", "_",
+ std::get<4>(info.param), "rqc", "_",
+ std::get<5>(info.param), "rps", "_",
+ std::get<6>(info.param), "rppr");
+ absl::c_replace_if(
+ name, [](char c) { return !std::isalnum(c); }, '_');
+ // clang-format on
+ return name;
+ });
+#endif
+
+INSTANTIATE_TEST_SUITE_P(
+ CartesianProductLargeChunks, GrpcChunkedMessageStreamTest,
+ testing::Combine(
+ /* chunk_size_for_upload */
+ testing::ValuesIn({8192}),
+ /* max_pending_chunks */
+ testing::ValuesIn({2}),
+ /* compression_level */
+ testing::ValuesIn({CompressionLevel::ZLIB_BEST_SPEED}),
+ /* request size */
+ testing::ValuesIn({1024 * 1024 * 10}),
+ /* request count */
+ testing::ValuesIn({2ul}),
+ /* reply size */
+ testing::ValuesIn({1024 * 1024 * 10}),
+ /* replies per request */
+ testing::ValuesIn({2ul})),
+ [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
+ info) {
+ // clang-format off
+ std::string name = absl::StrCat(
+ std::get<0>(info.param), "csfu" "_",
+ std::get<1>(info.param), "mpc", "_",
+ std::get<2>(info.param), "cl", "_",
+ std::get<3>(info.param), "rqs", "_",
+ std::get<4>(info.param), "rqc", "_",
+ std::get<5>(info.param), "rps", "_",
+ std::get<6>(info.param), "rppr");
+ absl::c_replace_if(
+ name, [](char c) { return !std::isalnum(c); }, '_');
+ // clang-format on
+ return name;
+ });
+
+} // namespace
+} // namespace test
+} // namespace client
+} // namespace fcp
diff --git a/fcp/protos/BUILD b/fcp/protos/BUILD
new file mode 100644
index 0000000..8d4952e
--- /dev/null
+++ b/fcp/protos/BUILD
@@ -0,0 +1,126 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
+load("@org_tensorflow//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+default_visibility = [
+ "//visibility:public",
+]
+
+package(
+ default_visibility = default_visibility,
+ licenses = ["notice"], # Apache 2.0
+)
+
+# --------------------------------------------------------------------
+# federated_api.proto
+
+# The api protos.
+proto_library(
+ name = "federated_api_proto",
+ srcs = ["federated_api.proto"],
+ deps = [
+ "//fcp/secagg/shared:proto",
+ "@com_google_googleapis//google/rpc:code_proto",
+ "@com_google_protobuf//:any_proto",
+ "@com_google_protobuf//:duration_proto",
+ ],
+)
+
+py_proto_library(
+ name = "federated_api_py_pb2",
+ deps = [":federated_api_proto"],
+)
+
+java_proto_library(
+ name = "federated_api_java_proto",
+ deps = [":federated_api_proto"],
+)
+
+cc_proto_library(
+ name = "federated_api_cc_proto",
+ deps = [":federated_api_proto"],
+)
+
+cc_grpc_library(
+ name = "cc_grpc",
+ srcs = [":federated_api_proto"],
+ generate_mocks = True,
+ grpc_only = True,
+ deps = [":federated_api_cc_proto"],
+)
+
+# --------------------------------------------------------------------
+# plan.proto
+
+# Using tf_proto_library to get dependencies to TF protos built correctly.
+tf_proto_library(
+ name = "plan_proto",
+ srcs = ["plan.proto"],
+ protodeps = [
+ "@org_tensorflow//tensorflow/core:protos_all",
+ ],
+ visibility = default_visibility,
+)
+
+alias(
+ name = "plan_py_pb2",
+ actual = "plan_proto_py",
+)
+
+java_proto_library(
+ name = "plan_java_proto",
+ deps = [":plan_proto"],
+)
+
+# Allowing to refer to the cc library generated by the rule above in usual way:
+alias(
+ name = "plan_cc_proto",
+ actual = "plan_proto_cc",
+ visibility = default_visibility + [
+ ],
+)
+
+# --------------------------------------------------------------------
+# opstats.proto
+
+proto_library(
+ name = "opstats_proto",
+ srcs = ["opstats.proto"],
+ deps = [
+ ":federated_api_proto",
+ "@com_google_protobuf//:duration_proto",
+ "@com_google_protobuf//:timestamp_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "opstats_cc_proto",
+ deps = [":opstats_proto"],
+)
+
+# --------------------------------------------------------------------
+# task_eligibility_context.proto
+
+proto_library(
+ name = "task_eligibility_context_proto",
+ srcs = ["task_eligibility_context.proto"],
+)
+
+java_proto_library(
+ name = "task_eligibility_context_java_proto",
+ deps = [":task_eligibility_context_proto"],
+)
diff --git a/fcp/protos/federated_api.proto b/fcp/protos/federated_api.proto
new file mode 100644
index 0000000..f862748
--- /dev/null
+++ b/fcp/protos/federated_api.proto
@@ -0,0 +1,809 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package google.internal.federatedml.v2;
+
+import "google/protobuf/any.proto";
+import "google/protobuf/duration.proto";
+import "fcp/secagg/shared/secagg_messages.proto";
+
+option java_package = "com.google.internal.federatedml.v2";
+option java_multiple_files = true;
+option java_outer_classname = "FederatedProto";
+
+// `FederatedTrainingApi` provides the protocol between a device and
+// cloud servers for federated computations.
+//
+// This API works as follows:
+//
+// 1. Clients are associated with *populations*. A population is
+// uniquely identified by an opaque string which indicates the
+// application together with additional information like a specific
+// flavor within this application.
+//
+// 2. Periodically, clients issue an `EligibilityEvalCheckinRequest` to indicate
+// they are interested in participating in a computation, followed by a
+// `CheckinRequest` to actually be assigned one of the available
+// computations they are compatible with. These requests specify the
+// population for which the requests are issued.
+//
+// 3. The server decides whether the client shall participate in a computation,
+// and if so returns the data (a computation description and a checkpoint)
+// needed for the client to start executing the computation.
+//
+// 4. If the client is selected, it performs local execution. When done, the
+// client issues a `ReportRequest`, specifying both population and phase id,
+// and the updated checkpoint.
+service FederatedTrainingApi {
+ // Initiates client-server communication session (bidirectional stream).
+ rpc Session(stream ClientStreamMessage) returns (stream ServerStreamMessage) {
+ }
+}
+
+// Message sent from a client to a server in a bidirectional stream.
+// The message is sent either directly or (for large messages when chunking
+// is supported) split into chunk encapsulated into smaller instances of
+// ClientStreamMessage.
+message ClientStreamMessage {
+ // Different kinds of messages.
+ oneof kind {
+ // Checkin request.
+ CheckinRequest checkin_request = 1;
+
+ // Report request.
+ ReportRequest report_request = 2;
+
+ // Transfer of ClientStreamMessage in multiple chunks
+ ChunkedTransferMessage chunked_transfer = 6;
+
+ // Eligibility evaluation checkin request.
+ EligibilityEvalCheckinRequest eligibility_eval_checkin_request = 7;
+ }
+
+ // Secure Aggregation messages. These form a parallel stream of
+ // messages, and so are outside the 'kind' oneof.
+ fcp.secagg.ClientToServerWrapperMessage secure_aggregation_client_message = 4;
+
+ repeated google.protobuf.Any serialized_side_channel_event = 5;
+
+ // Internal identifier of the connection to the client. This value isn't set
+ // by the device (ignored if set), but used by the server when forwarding
+ // requests internally within the system.
+ string internal_client_id = 3 [deprecated = true];
+}
+
+// Message sent from a server to a client in a bidirectional stream.
+// The message is sent either directly or (for large messages when chunking
+// is supported) split into chunks encapsulated into smaller instances of
+// ServerStreamMessage.
+message ServerStreamMessage {
+ // Different kinds of messages.
+ oneof kind {
+ // Checkin response.
+ CheckinResponse checkin_response = 1;
+
+ // Report response.
+ ReportResponse report_response = 2;
+
+ // Transfer of ServerStreamMessage in multiple chunks
+ ChunkedTransferMessage chunked_transfer = 4;
+
+ // Ack of a CheckinRequest or an EligibilityEvalCheckinRequest. See {@link
+ // CheckinRequest} for details.
+ CheckinRequestAck checkin_request_ack = 5;
+
+ // Eligibility evaluation checkin response.
+ EligibilityEvalCheckinResponse eligibility_eval_checkin_response = 6;
+
+ // This is expected to be extended in upcoming versions of the protocol
+ }
+
+ // Secure Aggregation messages. These form a parallel stream of
+ // messages, and so are outside the 'kind' oneof.
+ fcp.secagg.ServerToClientWrapperMessage secure_aggregation_server_message = 3;
+}
+
+// Supported levels of compression for chunked blob transfer.
+enum CompressionLevel {
+ // Compression disabled
+ UNCOMPRESSED = 0;
+ // zlib compression with default settings; this level uses gzip-6
+ // (currently).
+ // 'Z_DEFAULT_COMPRESSION requests a default compromise between speed and
+ // compression (currently equivalent to level 6).'
+ // Source - // http://www.zlib.net/manual.html
+ ZLIB_DEFAULT = 1;
+ // zlib compression optimized for most compression; this level uses gzip-9.
+ // '9 gives best compression' Source - http://www.zlib.net/manual.html
+ ZLIB_BEST_COMPRESSION = 2;
+ // zlib compression optimized for speed; this level uses gzip-1.
+ // '1 gives best speed' Source - http://www.zlib.net/manual.html
+ ZLIB_BEST_SPEED = 3;
+}
+
+// Supported compressed file formats for HTTP downloads.
+enum HttpCompressionFormat {
+ HTTP_COMPRESSION_FORMAT_UNSPECIFIED = 0;
+ // Gzip-compressed data. If data is compressed in this way, then the
+ // "Content-Type" HTTP response header will have a "+gzip" suffix.
+ HTTP_COMPRESSION_FORMAT_GZIP = 1;
+}
+
+// A request, sent by the device to check if it should participate
+// in the current phase.
+message CheckinRequest {
+ // The name of the population this client belongs to.
+ string population_name = 1;
+
+ // Optional. Retry token (opaque to the client) passed by the server when last
+ // participated in the training or rejected. If clients have such a token
+ // available, they should provide it. If not, things are still expected to
+ // work, but providing this gives server better control on organizing
+ // participation.
+ //
+ // Note that an `EligibilityEvalCheckinRequest` and its subsequent
+ // `CheckinRequest` request will both use the same value for this field,
+ // since both requests are considered part of the same logical protocol
+ // session.
+ string retry_token = 2;
+
+ reserved 3;
+
+ // The attestation measurement providing evidence of integrity for this
+ // client. The measurement is bound to the population_name and retry_token
+ // values in this CheckinRequest.
+ //
+ // Note that an `EligibilityEvalCheckinRequest` and its subsequent
+ // `CheckinRequest` request will both use the same value for this field,
+ // since both requests are considered part of the same logical protocol
+ // session.
+ string attestation_measurement = 4;
+
+ // Protocol options supported by the client.
+ ProtocolOptionsRequest protocol_options_request = 5;
+
+ string client_version = 6 ;
+
+ // The client computes this message using the plan returned by a previous
+ // `EligibilityEvalCheckinResponse`.
+ //
+ // If this field is set, it describes to the server which tasks the client is
+ // (in)eligible. The server must take this information into account when
+ // deciding task to serve in response to this request.
+ //
+ // If this field is unset, it may indicate that the client previously received
+ // an `EligibilityEvalCheckinResponse` without an
+ // `EligibilityEvalPayload` message (i.e. the population did not
+ // have an eligibility-computing task configured at the time of the request).
+ // It may also indicate a client for which the eligibility-computing task
+ // feature has been disabled, or an old client that does not support this
+ // feature yet.
+ //
+ // If this field is unset but the population has an eligibility-computing task
+ // configured, then the server must reject this client, since the server has
+ // no way to determine which tasks the client is (in)eligible for.
+ //
+ // If this field is unset and the population does not have an
+ // eligibility-computing task configured, then the server may serve this
+ // client any task.
+ //
+ TaskEligibilityInfo task_eligibility_info = 7;
+}
+
+// Describes to the server which tasks a client is eligible for.
+message TaskEligibilityInfo {
+ // A semantic version describing how the set of eligibility descriptors should
+ // be interpreted. This fields enables assigning different semantics for how
+ // the server should interpret the descriptors, without having to change the
+ // wire format (e.g. different ways of interpreting `TaskWeight.weight`).
+ int64 version = 1;
+
+ // A list of task weights, which the server may use when assigning the client
+ // a task in response to the current request.
+ //
+ // If none of the `TaskWeight` messages match a given task, then the client
+ // must be considered ineligible for that task, and the server must not serve
+ // the client that task.
+ //
+ // Therefore, if a `TaskEligibilityInfo` message is provided but this field is
+ // empty then the client should be considered ineligible for all tasks in the
+ // population (although in practice the client will simply close the
+ // connection in that case, rather than issue a `CheckinRequest` with such an
+ // empty list of weights).
+ repeated TaskWeight task_weights = 2;
+}
+
+// Describes a weight that should be assigned to a specific task.
+message TaskWeight {
+ // Name of the task this weight applies to.
+ string task_name = 1;
+
+ // The weight that should be applied to the specified task.
+ //
+ // Must be >0.
+ //
+ // This weight may (or may not) be used by the server to implement some form
+ // of task or client prioritization.
+ float weight = 2;
+}
+
+// Response to the checkin request, sent to the device.
+message CheckinResponse {
+ // One of two outcomes, depending on server's decision on participation of the
+ // client.
+ oneof checkin_result {
+ // If the client joined the phase with this call, information how
+ // to proceed.
+ AcceptanceInfo acceptance_info = 1;
+
+ // If the client was not accepted, information how to proceed.
+ RejectionInfo rejection_info = 2;
+ }
+
+ // Instructions from server to the client how to execute protocol.
+ // While, conceptually, chunked transfer is a symmetric protocol with respect
+ // to peers (both server and client act like senders and receivers), the
+ // protocol handshake and configuration part is intentionally skewed towards
+ // the server driving the decisions on how chunking are performed on the wire,
+ // so we have a centralized way of controlling the feature.
+ //
+ // Note that if a client receives more than one `ProtocolOptionsResponse` over
+ // the life of a protocol session (e.g. in
+ // `EligibilityEvalCheckinResponse` as well as `CheckinResponse`)
+ // then the client will use the most recently-received value for further
+ // communications with the server.
+ ProtocolOptionsResponse protocol_options_response = 4;
+
+ reserved 3;
+}
+
+// Acknowledgement for a `CheckinRequest` or an `EligibilityEvalCheckinRequest`
+// from the client. This happens almost instantenously for all clients (that
+// request an ack using ProtocolRequestOptions#should_ack_checkin) as soon as
+// they issue either request, and happens *before* either a `CheckinResponse`
+// or a `EligibilityEvalCheckinResponse` is returned to the client.
+message CheckinRequestAck {
+ // Retry window to use for the next checkin attempt if this attempt ends up
+ // being subsequently accepted by the server, as in the client received a
+ // CheckinResponse with an AcceptanceInfo.
+ RetryWindow retry_window_if_accepted = 1;
+
+ // Retry window to use if this checkin attempt is not accepted by the server,
+ // as in the client doesn't receive a CheckinResponse with an AcceptanceInfo.
+ RetryWindow retry_window_if_rejected = 2;
+}
+
+// A request, sent by the device to request the eligibility-computing plan for
+// the population. This plan is run by the client to generate a
+// `TaskEligibilityInfo` proto result, which is then included with a subsequent
+// `CheckinRequest` (within the same protocol session) to inform the server
+// which tasks the client is eligible for.
+//
+// The use of an `EligibilityEvalCheckinRequest` is optional (i.e. clients
+// may simply issue a `CheckinRequest` without a preceding
+// `EligibilityEvalCheckinRequest`, in which case the
+// `CheckinRequest.task_eligibility_info` field will be left unset).
+message EligibilityEvalCheckinRequest {
+ // The name of the population this client belongs to.
+ string population_name = 1;
+
+ // Optional. This field has the same semantics as
+ // `CheckinRequest.retry_token`, see that field for details.
+ string retry_token = 2;
+
+ // This field has the same semantics as
+ // `CheckinRequest.attestation_measurement`.
+ // See that field for details.
+ string attestation_measurement = 4;
+
+ // Protocol options supported by the client.
+ ProtocolOptionsRequest protocol_options_request = 5;
+
+ // This field has the same semantics as `CheckinRequest.client_version`. See
+ // that field for details.
+ string client_version = 6 ;
+
+ // The client's capabilities when downloading and running Eligibility Eval
+ // tasks.
+ EligibilityEvalTaskCapabilities eligibility_eval_task_capabilities = 7;
+}
+
+// The client's capabilities for determining task eligibility.
+message EligibilityEvalTaskCapabilities {
+ // Whether the client supports multiple task assignment
+ // (/TaskAssignments.PerformMultipleTaskAssignments). If false, the client
+ // will not be provided information about tasks that require multiple task
+ // assignment.
+ bool supports_multiple_task_assignment = 1;
+}
+
+// Response to the `EligibilityEvalCheckinRequest`, sent to the
+// device.
+message EligibilityEvalCheckinResponse {
+ // Each response will contain one of the following results.
+ oneof checkin_result {
+ // If the population has an eligibility-computing plan configured, and if
+ // the client is compatible with that plan, then this field will be set,
+ // containing the plan's payload. The client should run the plan and include
+ // its `TaskEligibilityInfo` result in the subsequent `CheckinRequest`.
+ EligibilityEvalPayload eligibility_eval_payload = 1;
+
+ // If the population does not have an eligibility-computing plan configured,
+ // then this field will be set. The client should continue by issuing a
+ // `CheckinRequest` without the `task_eligibility_info` field set.
+ NoEligibilityEvalConfigured no_eligibility_eval_configured = 2;
+
+ // If the population has an eligibility-computing plan configured, but the
+ // client is incompatible with that plan, then this field will be set.
+ RejectionInfo rejection_info = 3;
+ }
+
+ // This field has the same semantics as
+ // `CheckinResponse.protocol_options_response`. See that field for details.
+ ProtocolOptionsResponse protocol_options_response = 4;
+}
+
+// Contains the eligibility evaluation plan payload.
+message EligibilityEvalPayload {
+ oneof init_checkpoint_type {
+ // A blob representing the checkpoint to start execution from.
+ bytes init_checkpoint = 1;
+
+ // A URI and other metadata of the checkpoint to start execution from.
+ UriResource init_checkpoint_resource = 4;
+ }
+
+ oneof plan_type {
+ // A blob representing the plan to be used for execution.
+ bytes plan = 2;
+
+ // A URI and other metadata of the plan to be used for execution.
+ UriResource plan_resource = 5;
+ }
+
+ oneof population_eligibility_spec_type {
+ // A serialized PopulationEligibilitySpec describing the eligibility
+ // criteria for tasks in the population.
+ bytes population_eligibility_spec = 6;
+
+ // A URI and other metadata of the population eligibility spec to be used.
+ UriResource population_eligibility_spec_resource = 7;
+ }
+
+ // The opaque id of the eligibility evaluation plan payload the client is
+ // being given. This is a string generated by the server and used by the
+ // client for logging purposes. This id MUST NOT contain any information that
+ // could be used to identify a specific device.
+ // Also see the similar `AcceptanceInfo.execution_phase_id`.
+ string execution_id = 3;
+}
+
+// Currently-empty message describing the case where a population does not have
+// an eligibility-computing plan configured.
+message NoEligibilityEvalConfigured {}
+
+// Per-aggregand dynamic configuration information for side channel protocols.
+message SideChannelExecutionInfo {
+ // Dynamic configuration for SecureAggregation side channels.
+ message SecureAggregandExecutionInfo {
+ // Bitwidth for secure-aggregation. This must be wide enough to
+ // encode the sum of all client inputs, given the current number of clients
+ // participating in the protocol.
+ //
+ // This field is deprecated; use modulus instead.
+ int32 output_bitwidth = 1 [deprecated = true];
+
+ // Modulus for secure aggregation.
+ //
+ // The secure aggregation protocol will compute the sum modulo this modulus.
+ //
+ // To achieve equivalence with non-modular summation, the modulus must be
+ // larger than the sum of all client inputs, given the number of clients
+ // participating in the aggregation shard.
+ //
+ // If modulus is missing but output_bitwidth is specified, modulus will
+ // will be taken to be 2**output_bitwidth (for backwards compatibility).
+ uint64 modulus = 2;
+ }
+
+ // What type of side channel is used.
+ oneof type {
+ // Dynamic configuration for Secure Aggregation side channels.
+ SecureAggregandExecutionInfo secure_aggregand = 1;
+ }
+}
+
+// Per-protocol options information for side channel protocols.
+message SideChannelProtocolOptionsRequest {
+ // Options for SecureAggregation side channels.
+ message SecureAggregation {
+ // Protocol versions available.
+ repeated fcp.secagg.ClientVariant client_variant = 2;
+ }
+
+ // Protocol options for Secure Aggregation side channels.
+ SecureAggregation secure_aggregation = 1;
+}
+
+// Negotiated settings for side channel protocols. Side channel options may not
+// be set for channels which will not be used in the ReportRequest.
+message SideChannelProtocolOptionsResponse {
+ // Server-directed secure aggregation options to apply.
+ message SecureAggregation {
+ // The client variant to use.
+ fcp.secagg.ClientVariant client_variant = 1;
+ }
+
+ // SecureAggregation protocol options. Only set if a SecureAggregation plan
+ // will be used.
+ SecureAggregation secure_aggregation = 1;
+}
+
+// Per-protocol dynamic configuration information for side channel protocols.
+message SideChannelProtocolExecutionInfo {
+ // Dynamic configuration for SecureAggregation side channels.
+ message SecureAggregationProtocolExecutionInfo {
+ // Number of clients that a client may exchange data with while running
+ // Secure Aggregation protocol. In the case of a full graph SecAgg protocol
+ // this is a total number of clients that started the protocol.
+ // In the case of subgraph SecAgg protocol this is a number of neighbours
+ // that each client has.
+ int32 expected_number_of_clients = 1;
+
+ // Secure Aggregation client completion threshold. This is a parameter
+ // communicated by the server side of Secure Aggregation protocol to each
+ // client to establish Shamir sharing of secrets.
+ // Additionally, at least `minimum_surviving_clients_for_reconstruction` out
+ // of the initial `expected_number_of_clients` must 'survive' in order for
+ // the protocol to continue on the client side; otherwise the client will
+ // abort its connection.
+ int32 minimum_surviving_clients_for_reconstruction = 2;
+
+ // The minimum number of clients' values that must be aggregated together
+ // before the server can gain access to the aggregate,
+ // even transiently (e.g. in RAM).
+ // This isn't needed by Secure Aggregation protocol on the client side but
+ // shared by the server with clients for transparency or policy reasons.
+ int32 minimum_clients_in_server_visible_aggregate = 3;
+ }
+
+ // Dynamic configuration for Secure Aggregation side channels.
+ SecureAggregationProtocolExecutionInfo secure_aggregation = 1;
+}
+
+// When client (device) supports HTTP download of resources, this message
+// is used to carry information about each individual downloadable resource.
+message UriResource {
+ // Resource URI e.g. fully qualified URL.
+ string uri = 1;
+
+ // Stable identifier for this resource, used by the client cache
+ // implementation. If this field is not set, the client should not attempt to
+ // cache the resource referenced by `uri`.
+ string client_cache_id = 2;
+
+ // The maximum duration for how long the resource should be cached by the
+ // client. Not set if `client_cache_id` is not set.
+ google.protobuf.Duration max_age = 3;
+}
+
+// When client (device) is accepted for the current phase, this
+// data structure carries information necessary to begin training.
+message AcceptanceInfo {
+ // The opaque id of the phase the client has joined. This is a string
+ // generated by the server and used by the client for logging purposes.
+ // This id MUST NOT contain any information that could be used to identify
+ // a specific device.
+ string execution_phase_id = 1;
+
+ // The name identifying the task for which the client was accepted.
+ string task_name = 10;
+
+ oneof init_checkpoint_type {
+ // A blob representing the checkpoint to start execution from.
+ bytes init_checkpoint = 2;
+
+ // A URI and other metadata of the checkpoint to start execution from.
+ UriResource init_checkpoint_resource = 8;
+ }
+
+ // Note: Plan fields below should be unset when included in a JoinResponse
+ // from the aggregator to the selector, and then set by the selector before
+ // sending on to the client.
+ oneof plan_type {
+ // A blob representing the plan to be used for execution.
+ bytes plan = 3;
+
+ // A URI and other metadata of the plan to be used for execution.
+ UriResource plan_resource = 9;
+ }
+
+ // Per-aggregand dynamic configuration information for side channel protocols.
+ // The keys in this map are the names of side channels, aligning with
+ // CheckpointOp.side_channel_tensors in plan.proto.
+ map<string, SideChannelExecutionInfo> side_channels = 4;
+
+ // Per-protocol dynamic configuration information for side channel protocols.
+ // This configuration applies to all aggregands configured to use each
+ // protocol.
+ SideChannelProtocolExecutionInfo side_channel_protocol_execution_info = 5;
+
+ reserved 6, 7;
+
+ // Info for how to generate URIs for fetching slices at runtime.
+ FederatedSelectUriInfo federated_select_uri_info = 11;
+
+ reserved 12, 13;
+}
+
+// Info for how to generate URIs for fetching slices that the task might request
+// to be downloaded at runtime.
+//
+// When one or more slices are requested by the task, the template specified
+// here should be used to form a URI from which the client can download the
+// slice data, by replacing the "{served_at_id}" and "{key_base10}" substrings
+// with the `google.internal.federated.plan.SlicesSelector.served_at_id` and the
+// base-10 representation of the `SlicesSelector.keys` value. The client must
+// not perform any URI escaping to the values that the substrings are replaced
+// with.
+message FederatedSelectUriInfo {
+ // The URI template to use for fetching slices.
+ //
+ // This template must always start with "https://".
+ //
+ // This template must contain the following substrings: "{served_at_id}" and
+ // "{key_base10}", as per the above documentation.
+ string uri_template = 1;
+}
+
+// This is sent when client (device) is rejected for the participation.
+message RejectionInfo {
+ // Optional. A suggestion to the client when to retry next connection to the
+ // service
+ // Deprecated in favor of `CheckinRequestAck`. If a client supports
+ // `CheckinRequestAck` then this value is ignored.
+ RetryWindow retry_window = 4 [deprecated = true];
+
+ reserved 1, 2, 3;
+}
+
+// This is sent by the client after the client finishes local (on-device)
+// training.
+//
+// If secure aggregation side channel is used, this must accompany the
+// secure aggregation commit message in the same ClientStreamMessage.
+message ReportRequest {
+ // The name of the population this client belongs to.
+ string population_name = 1;
+
+ // The id of the execution phase this client participates in.
+ string execution_phase_id = 2;
+
+ // The report.
+ Report report = 3;
+}
+
+// This is sent by the server as the final message in the reporting protocol.
+message ReportResponse {
+ // Optional. A suggestion to the client when to retry next connection to the
+ // service
+ // Deprecated in favor of `CheckinRequestAck`. If a client supports
+ // `CheckinRequestAck` then this value is ignored.
+ RetryWindow retry_window = 1 [deprecated = true];
+}
+
+// A report with results of local (on-device) training.
+message Report {
+ // A blob representing the updated checkpoint, if any. The content
+ // is dependent of the execution method.
+ bytes update_checkpoint = 1;
+
+ // Status code reported by client.
+ // Code.OK indicates that client execution completed successfully and produced
+ // report. Any other code indicates unsuccessful execution and train events
+ // below might contain detailed diagnostic information.
+ int32 status_code = 5;
+
+ reserved 3;
+
+ // A serialized ClientExecutionStats field about stats produced during a
+ // client side execution of the plan.
+ repeated google.protobuf.Any serialized_train_event = 4;
+
+ reserved 2;
+
+ reserved 6;
+}
+
+// This message is used to report duration to the server.
+message ClientExecutionStats {
+ // The time spent on running the plan (includes I/O such as reading examples,
+ // but does not include time spent on the network for retrieving the plan
+ // or uploading results).
+ google.protobuf.Duration duration = 2;
+
+ reserved 1;
+}
+
+// A suggestion to the client when to retry the connection to the service next
+// time
+message RetryWindow {
+ // Optional. If set, the server offers the client to call back in
+ // an interval [delay_min .. delay_max].
+ // The client must pass this token back to the server to identify he is
+ // retrying. If this is not set, the client can retry for another phase at a
+ // time of his choosing.
+ string retry_token = 1;
+
+ // Required (if retry_token is set).
+ // The suggested minimal duration after which the client should
+ // retry. If the client retries earlier, it is likely he will be rejected
+ // again.
+ google.protobuf.Duration delay_min = 2;
+
+ // Required. The suggested maximal duration after which the client should
+ // retry, provided scheduling conditions allow. The client is supposed to make
+ // a best effort to callback in the min..max window, and should avoid
+ // calling before min. If he calls after max, the likelihood to be rejected
+ // again is higher.
+ google.protobuf.Duration delay_max = 3;
+}
+
+// Intermediate Representation for checkpoints after side channels and
+// quantization have been applied. This is the input and post-aggregation
+// output of the transport-and-aggregate protocols.
+message Checkpoint {
+ // An aggregand is a (flattened) collection of checkpoint variables
+ // that will be aggregated using the same transport-and-aggregate protocol.
+ message Aggregand {
+ repeated uint64 values = 1 [packed = true];
+ int32 bitwidth = 2;
+ }
+
+ // Checkpoint variables are partitioned into multiple aggregands,
+ // each of which can use a different transport-and-aggregate protocol.
+ //
+ map<string, Aggregand> aggregands = 1;
+}
+
+// Protocol options sent from client to server inside
+// `EligibilityEvalCheckinRequest` and `CheckinRequest`.
+message ProtocolOptionsRequest {
+ // True if client supports chunked blob transfer protocol.
+ bool supports_chunked_blob_transfer = 1;
+
+ // Chunked blob transfer compression levels supported by this client.
+ repeated CompressionLevel supported_compression_levels = 2;
+
+ // Per-protocol configuration option information for side channel protocols.
+ // These options apply to all aggregands configured to use each protocol.
+ SideChannelProtocolOptionsRequest side_channels = 3;
+
+ // When this is set in a {@link CheckinRequest} or {@link
+ // EligibilityEvalCheckinRequest} message, the server should ack using
+ // a {@link CheckinRequestAck} sent back to the client.
+ //
+ // Note that if a client previously issued a
+ // `EligibilityEvalCheckinRequest` (for which this field is always set
+ // to true), then this field will not be set to true for the subsequent
+ // `CheckinRequest`.
+ bool should_ack_checkin = 4;
+
+ // True if client supports download of resources via HTTP.
+ bool supports_http_download = 5;
+
+ // True if client supports download of Eligibility Eval resources via HTTP.
+ bool supports_eligibility_eval_http_download = 6;
+
+ // HTTP download compression formats supported by this client. All clients
+ // that support HTTP downloads are assumed to support uncompressed payloads.
+ repeated HttpCompressionFormat supported_http_compression_formats = 7;
+}
+
+// Protocol options sent from server to client inside
+// `EligibilityEvalCheckinResponse` and `CheckinResponse`.
+message ProtocolOptionsResponse {
+ // Tells client what chunk size to use for uploading data, default 8192.
+ int32 chunk_size_for_upload = 1;
+
+ // Tells client how many chunks to send ahead of receiving ack, default: 2
+ int32 max_pending_chunks = 2;
+
+ // Indicates desired compression level
+ CompressionLevel compression_level = 4;
+
+ // Negotiated side channel protocol options; side channel options may not
+ // be set for channels which will not be used in the ReportRequest.
+ SideChannelProtocolOptionsResponse side_channels = 5;
+}
+
+// Allows to transmit large ClientStreamMessage or ServerStreamMessage
+// (depending on a direction of a transfer) as a stream of multiple small chunks
+// with optional compression flow control.
+message ChunkedTransferMessage {
+ // Supported types of compression scheme. Deprecated, replaced by
+ // CompressionLevel. Do not add new values.
+ enum CompressionType {
+ // Compression disabled
+ UNCOMPRESSED = 0;
+ }
+
+ // Initiation of the chunked transfer. Transmitting party starts sending
+ // chunks (Data messages) right after sending Start.
+ message Start {
+ // Uncompressed size of the blob, in bytes.
+ int32 uncompressed_size = 2;
+
+ // Level of compression.
+ CompressionLevel compression_level = 3;
+
+ // Size of the blob transferred over the wire, in bytes.
+ // This field may not be set by older clients, so readers should check the
+ // uncompressed_size field if this value is zero.
+ int32 blob_size_bytes = 4;
+
+ reserved 1;
+ }
+
+ // Carries a chunk of data. Receiving party assembles all the chunks to get
+ // a serialized form of a large ClientStreamMessage or ServerStreamMessage
+ // depending on a direction of a transfer.
+ message Data {
+ // 0-based index of the chunk
+ // All chunk messages must be ordered, the index is included for diagnostics
+ // purposes.
+ int32 chunk_index = 1;
+
+ // Next chunk of the blob.
+ bytes chunk_bytes = 2;
+ }
+
+ // Acknowledgement of receiving of Data message.
+ // Must be sent in a response of a successful receiving of a chunk. It is used
+ // for the purposes of the flow control: transmitting party limits number of
+ // of chunks speculatively sent
+ message Ack {
+ // 0-based index of received chunk
+ // All ack messages must be ordered, the index is included for diagnostic
+ // purposes.
+ int32 chunk_index = 1;
+ }
+
+ // Completion of the chunked transfer.
+ message End {
+ // Total number of chunks transferred.
+ int32 chunk_count = 1;
+
+ // TODO(team): Add checksum.
+ }
+
+ // Kind of chunked message.
+ oneof kind {
+ // Start message, sent by transmitter
+ Start start = 1;
+
+ // Data message (a single chunk), sent by transmitter.
+ Data data = 2;
+
+ // Ack of receiving of Data message, sent by receiver.
+ Ack ack = 3;
+
+ // End of the transmission, sent by transmitter.
+ End end = 4;
+ }
+}
diff --git a/fcp/protos/federatedcompute/BUILD b/fcp/protos/federatedcompute/BUILD
new file mode 100644
index 0000000..6d38c4b
--- /dev/null
+++ b/fcp/protos/federatedcompute/BUILD
@@ -0,0 +1,60 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# --------------------------------------------------------------------
+
+# The api protos.
+proto_library(
+ name = "federated_compute_proto",
+ srcs = [
+ "aggregations.proto",
+ "common.proto",
+ "eligibility_eval_tasks.proto",
+ "secure_aggregations.proto",
+ "task_assignments.proto",
+ ],
+ deps = [
+ "//fcp/protos:federated_api_proto",
+ "//fcp/secagg/shared:proto",
+ "@com_google_googleapis//google/api:annotations_proto",
+ "@com_google_googleapis//google/longrunning:operations_proto",
+ "@com_google_googleapis//google/rpc:code_proto",
+ "@com_google_googleapis//google/rpc:status_proto",
+ "@com_google_protobuf//:duration_proto",
+ ],
+)
+
+java_proto_library(
+ name = "federated_compute_java_proto",
+ deps = [":federated_compute_proto"],
+)
+
+cc_proto_library(
+ name = "federated_compute_cc_proto",
+ deps = [":federated_compute_proto"],
+)
+
+py_proto_library(
+ name = "federated_compute_py_pb2",
+ deps = [":federated_compute_proto"],
+)
diff --git a/fcp/protos/federatedcompute/aggregations.proto b/fcp/protos/federatedcompute/aggregations.proto
new file mode 100644
index 0000000..f309ba8
--- /dev/null
+++ b/fcp/protos/federatedcompute/aggregations.proto
@@ -0,0 +1,144 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federatedcompute.v1;
+
+import "fcp/protos/federatedcompute/common.proto";
+
+option java_package = "com.google.internal.federatedcompute.v1";
+option java_multiple_files = true;
+
+service Aggregations {
+ // A request sent by the client after completing local (on-device) task
+ // execution to notify the server that it has Aggregation data to upload. The
+ // server responds with the location at which to upload the data. If a
+ // client's result is no longer needed (e.g. the reporting goal was already
+ // reached for the task), the server will respond with an ABORTED error in the
+ // operation status.
+ rpc StartAggregationDataUpload(StartAggregationDataUploadRequest)
+ returns (StartAggregationDataUploadResponse) {
+ }
+
+ // A request sent by the client indicating the successful completion of the
+ // client's aggregation session. If a client's result is not needed for the
+ // aggregation (e.g. the reporting goal was already reached for the task), the
+ // server will respond with an ABORTED error.
+ //
+ // Clients should use the `ForwardingInfo` from the
+ // `StartAggregationDataUploadResponse.aggregation_protocol_forwarding_info`
+ // response field to construct the URI for this request.
+ rpc SubmitAggregationResult(SubmitAggregationResultRequest)
+ returns (SubmitAggregationResultResponse) {
+ }
+
+ // A request sent by the client indicating the client's aggregation session
+ // should be aborted.
+ //
+ // Clients must only call this if they've previously called
+ // `StartAggregationDataUpload`.
+ //
+ // Clients should not call this if one of the requests returned an Aborted
+ // status.
+ //
+ // If clients have already received a `StartAggregationDataUploadResponse`
+ // they should use the `ForwardingInfo` from the
+ // `StartAggregationDataUploadResponse.aggregation_protocol_forwarding_info`
+ // response field to construct the URI for this request. Otherwise, clients
+ // should use the same `ForwardingInfo` as was used to construct the
+ // `StartAggregationDataUpload` request URI.
+ rpc AbortAggregation(AbortAggregationRequest)
+ returns (AbortAggregationResponse) {
+ }
+}
+
+message StartAggregationDataUploadRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The authorization token returned by the server when the client was assigned
+ // a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string authorization_token = 2
+ ;
+}
+
+message StartAggregationDataUploadMetadata {}
+
+message StartAggregationDataUploadResponse {
+ // Information to construct the URI to use for continuing the aggregation
+ // protocol after the data is uploaded.
+ ForwardingInfo aggregation_protocol_forwarding_info = 1;
+
+ // Information about where to upload aggregation result data.
+ ByteStreamResource resource = 2;
+
+ // Unique token that the client must include in the subsequent protocol
+ // requests.
+ string client_token = 3;
+}
+
+message SubmitAggregationResultRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Name of the resource to which the aggregration result was uploaded.
+ string resource_name = 3;
+}
+
+message SubmitAggregationResultResponse {}
+
+message AbortAggregationRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Status code and optional message for why the aggregation was aborted.
+ Status status = 3;
+}
+
+message AbortAggregationResponse {}
diff --git a/fcp/protos/federatedcompute/common.proto b/fcp/protos/federatedcompute/common.proto
new file mode 100644
index 0000000..e67bec9
--- /dev/null
+++ b/fcp/protos/federatedcompute/common.proto
@@ -0,0 +1,313 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federatedcompute.v1;
+
+import "google/protobuf/duration.proto";
+
+option java_package = "com.google.internal.federatedcompute.v1";
+option java_multiple_files = true;
+option java_outer_classname = "FederatedComputeApi";
+
+// Information that tells the client where to send the request for the next
+// protocol phase (the immediately following phase only, not any additional
+// subsequent phases). For example, this may point to the frontend to which
+// a StartTaskAssignmentRequest should be sent, but it should not then be used
+// for uploading aggregation results. A ForwardingInfo will always be returned
+// to the client unless the client was not selected to continue with the
+// protocol.
+message ForwardingInfo {
+ // A URI prefix for the next service to send the request for the next protocol
+ // phase to.
+ //
+ // The URI prefix must always start with "https://".
+ //
+ // The URI prefix may end with a trailing '/', but is not required to. During
+ // the construction of the next protocol request, a slash will always be
+ // inserted by the client between this prefix and the request's URI suffix.
+ //
+ // For example, if some protocol response's ForwardingInfo contains the prefix
+ // "https://foo.bar.com" or "https://foo.bar.com/", and if the subsequent
+ // protocol request's URI suffix is "/baz", then the subsequent request's full
+ // URI would be "https://foo.bar.com/baz".
+ string target_uri_prefix = 1;
+ // Request headers that should be included with the next request for the next
+ // protocol phase. Note that these headers should only be applied to protocol
+ // requests (incl. requests to the long running `Operations` service), but not
+ // to any `Resource` fetch requests.
+ map<string, string> extra_request_headers = 2;
+}
+
+// The attestation measurement providing evidence of integrity for a client.
+message AttestationMeasurement {
+ string value = 1;
+}
+
+message ClientVersion {
+ // Version code identifying the client release.
+ string version_code = 1;
+}
+
+message Resource {
+ // A resource can either be downloaded via a URI, or has its data inlined in
+ // in this message itself.
+ oneof resource {
+ // The URI the resource can be downloaded from. Note that
+ // `ForwardingInfo.target_uri_prefix` field generally don't apply to these
+ // URIs.
+ string uri = 1;
+
+ // The inlined data for the resource. This will eventually replace `data`.
+ InlineResource inline_resource = 3;
+ }
+
+ message InlineResource {
+ // The inlined data for the resource.
+ bytes data = 1
+ ;
+
+ // The compression used for the inlined data, or unset if the data is
+ // uncompressed.
+ optional ResourceCompressionFormat compression_format = 2;
+ }
+
+ // Stable identifier for this resource, used by the client cache
+ // implementation. If this field is not set, the client should not attempt to
+ // cache the resource referenced by `uri`. Not set for inline_resources.
+ string client_cache_id = 4;
+
+ // The maximum duration for how long the resource should be cached by the
+ // client. Not set if `client_cache_id` is not set.
+ google.protobuf.Duration max_age = 5;
+
+ reserved 2;
+}
+
+// The client's capabilities for processing Resource messages, such as the
+// compressed file formats supported.
+message ResourceCapabilities {
+ // Compression formats supported for resources downloaded via `Resource.uri`.
+ // All clients are assumed to support uncompressed payloads.
+ repeated ResourceCompressionFormat supported_compression_formats = 1;
+}
+
+// Different file formats that may be used to compress resources.
+enum ResourceCompressionFormat {
+ RESOURCE_COMPRESSION_FORMAT_UNSPECIFIED = 0;
+ // Gzip-compressed data. If data is compressed in this way, then the
+ // "Content-Type" HTTP response header will have a "+gzip" suffix.
+ RESOURCE_COMPRESSION_FORMAT_GZIP = 1;
+}
+
+// Currently empty message which is sent when client (device) is rejected for
+// participation and is not assigned a task.
+message RejectionInfo {}
+
+// A suggestion to the client when to retry the connection to the service next
+// time
+message RetryWindow {
+ // The suggested minimal duration after which the client should
+ // retry. If the client retries earlier, it is likely it will be rejected
+ // again.
+ google.protobuf.Duration delay_min = 1;
+
+ // Required. The suggested maximal duration after which the client should
+ // retry, provided scheduling conditions allow. The client is supposed to make
+ // a best effort to callback in the min..max window, and should avoid
+ // calling before min. If the client calls after max, the likelihood to be
+ // rejected again is higher.
+ google.protobuf.Duration delay_max = 2;
+}
+
+// Information about where to upload data (e.g. aggregation results, client
+// stats).
+message ByteStreamResource {
+ // Information to construct the URI to use for uploading the data.
+ ForwardingInfo data_upload_forwarding_info = 1;
+ // Resource name to which the data should be uploaded.
+ // Clients should use this field as well as the
+ // `ForwardingInfo.target_uri_prefix` to create the upload URL:
+ // {target_uri_prefix}/upload/v1/media/{resource_name} (where
+ // `{resource_name}` should be encoded as a multipath segment, as described
+ // in
+ // https://github.com/googleapis/googleapis/blob/master/google/api/http.proto).
+ string resource_name = 2;
+}
+
+// Copied from //google/rpc/status.proto.
+message Status {
+ // The status code, which should be an enum value of [google.rpc.Code][].
+ int32 code = 1;
+
+ string message = 2;
+}
+
+enum Code {
+ // Not an error; returned on success.
+ //
+ // HTTP Mapping: 200 OK
+ OK = 0;
+
+ // The operation was cancelled, typically by the caller.
+ //
+ // HTTP Mapping: 499 Client Closed Request
+ CANCELLED = 1;
+
+ // Unknown error. For example, this error may be returned when
+ // a `Status` value received from another address space belongs to
+ // an error space that is not known in this address space. Also
+ // errors raised by APIs that do not return enough error information
+ // may be converted to this error.
+ //
+ // HTTP Mapping: 500 Internal Server Error
+ UNKNOWN = 2;
+
+ // The client specified an invalid argument. Note that this differs
+ // from `FAILED_PRECONDITION`. `INVALID_ARGUMENT` indicates arguments
+ // that are problematic regardless of the state of the system
+ // (e.g., a malformed file name).
+ //
+ // HTTP Mapping: 400 Bad Request
+ INVALID_ARGUMENT = 3;
+
+ // The deadline expired before the operation could complete. For operations
+ // that change the state of the system, this error may be returned
+ // even if the operation has completed successfully. For example, a
+ // successful response from a server could have been delayed long
+ // enough for the deadline to expire.
+ //
+ // HTTP Mapping: 504 Gateway Timeout
+ DEADLINE_EXCEEDED = 4;
+
+ // Some requested entity (e.g., file or directory) was not found.
+ //
+ // Note to server developers: if a request is denied for an entire class
+ // of users, such as gradual feature rollout or undocumented allowlist,
+ // `NOT_FOUND` may be used. If a request is denied for some users within
+ // a class of users, such as user-based access control, `PERMISSION_DENIED`
+ // must be used.
+ //
+ // HTTP Mapping: 404 Not Found
+ NOT_FOUND = 5;
+
+ // The entity that a client attempted to create (e.g., file or directory)
+ // already exists.
+ //
+ // HTTP Mapping: 409 Conflict
+ ALREADY_EXISTS = 6;
+
+ // The caller does not have permission to execute the specified
+ // operation. `PERMISSION_DENIED` must not be used for rejections
+ // caused by exhausting some resource (use `RESOURCE_EXHAUSTED`
+ // instead for those errors). `PERMISSION_DENIED` must not be
+ // used if the caller can not be identified (use `UNAUTHENTICATED`
+ // instead for those errors). This error code does not imply the
+ // request is valid or the requested entity exists or satisfies
+ // other pre-conditions.
+ //
+ // HTTP Mapping: 403 Forbidden
+ PERMISSION_DENIED = 7;
+
+ // The request does not have valid authentication credentials for the
+ // operation.
+ //
+ // HTTP Mapping: 401 Unauthorized
+ UNAUTHENTICATED = 16;
+
+ // Some resource has been exhausted, perhaps a per-user quota, or
+ // perhaps the entire file system is out of space.
+ //
+ // HTTP Mapping: 429 Too Many Requests
+ RESOURCE_EXHAUSTED = 8;
+
+ // The operation was rejected because the system is not in a state
+ // required for the operation's execution. For example, the directory
+ // to be deleted is non-empty, an rmdir operation is applied to
+ // a non-directory, etc.
+ //
+ // Service implementors can use the following guidelines to decide
+ // between `FAILED_PRECONDITION`, `ABORTED`, and `UNAVAILABLE`:
+ // (a) Use `UNAVAILABLE` if the client can retry just the failing call.
+ // (b) Use `ABORTED` if the client should retry at a higher level. For
+ // example, when a client-specified test-and-set fails, indicating the
+ // client should restart a read-modify-write sequence.
+ // (c) Use `FAILED_PRECONDITION` if the client should not retry until
+ // the system state has been explicitly fixed. For example, if an "rmdir"
+ // fails because the directory is non-empty, `FAILED_PRECONDITION`
+ // should be returned since the client should not retry unless
+ // the files are deleted from the directory.
+ //
+ // HTTP Mapping: 400 Bad Request
+ FAILED_PRECONDITION = 9;
+
+ // The operation was aborted, typically due to a concurrency issue such as
+ // a sequencer check failure or transaction abort.
+ //
+ // See the guidelines above for deciding between `FAILED_PRECONDITION`,
+ // `ABORTED`, and `UNAVAILABLE`.
+ //
+ // HTTP Mapping: 409 Conflict
+ ABORTED = 10;
+
+ // The operation was attempted past the valid range. E.g., seeking or
+ // reading past end-of-file.
+ //
+ // Unlike `INVALID_ARGUMENT`, this error indicates a problem that may
+ // be fixed if the system state changes. For example, a 32-bit file
+ // system will generate `INVALID_ARGUMENT` if asked to read at an
+ // offset that is not in the range [0,2^32-1], but it will generate
+ // `OUT_OF_RANGE` if asked to read from an offset past the current
+ // file size.
+ //
+ // There is a fair bit of overlap between `FAILED_PRECONDITION` and
+ // `OUT_OF_RANGE`. We recommend using `OUT_OF_RANGE` (the more specific
+ // error) when it applies so that callers who are iterating through
+ // a space can easily look for an `OUT_OF_RANGE` error to detect when
+ // they are done.
+ //
+ // HTTP Mapping: 400 Bad Request
+ OUT_OF_RANGE = 11;
+
+ // The operation is not implemented or is not supported/enabled in this
+ // service.
+ //
+ // HTTP Mapping: 501 Not Implemented
+ UNIMPLEMENTED = 12;
+
+ // Internal errors. This means that some invariants expected by the
+ // underlying system have been broken. This error code is reserved
+ // for serious errors.
+ //
+ // HTTP Mapping: 500 Internal Server Error
+ INTERNAL = 13;
+
+ // The service is currently unavailable. This is most likely a
+ // transient condition, which can be corrected by retrying with
+ // a backoff. Note that it is not always safe to retry
+ // non-idempotent operations.
+ //
+ // See the guidelines above for deciding between `FAILED_PRECONDITION`,
+ // `ABORTED`, and `UNAVAILABLE`.
+ //
+ // HTTP Mapping: 503 Service Unavailable
+ UNAVAILABLE = 14;
+
+ // Unrecoverable data loss or corruption.
+ //
+ // HTTP Mapping: 500 Internal Server Error
+ DATA_LOSS = 15;
+}
+
diff --git a/fcp/protos/federatedcompute/eligibility_eval_tasks.proto b/fcp/protos/federatedcompute/eligibility_eval_tasks.proto
new file mode 100644
index 0000000..49dc8d8
--- /dev/null
+++ b/fcp/protos/federatedcompute/eligibility_eval_tasks.proto
@@ -0,0 +1,208 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federatedcompute.v1;
+
+import "fcp/protos/federatedcompute/common.proto";
+
+option java_package = "com.google.internal.federatedcompute.v1";
+option java_multiple_files = true;
+
+service EligibilityEvalTasks {
+ // A request, sent by the device to request the eligibility-computing task for
+ // the population. This task is run by the client to generate a
+ // `TaskEligibilityInfo` proto result, which is then included with a
+ // subsequent `StartTaskAssignmentRequest` to inform the server which tasks
+ // the client is eligible for.
+ //
+ // Returns NOT_FOUND if the population does not exist.
+ rpc RequestEligibilityEvalTask(EligibilityEvalTaskRequest)
+ returns (EligibilityEvalTaskResponse) {
+ }
+
+ // A request sent by the device to report the result of running the
+ // EligibilityEval task provided by `EligibilityEvalTaskResponse`.
+ //
+ // A result with a status code other than Code.OK indicates client session
+ // termination. The client may not send any future requests with the given
+ // session_id.
+ //
+ // Clients should use the same `ForwardingInfo` as used in the
+ // `RequestEligibilityEvalTask` request to construct the URI for this request.
+ rpc ReportEligibilityEvalTaskResult(ReportEligibilityEvalTaskResultRequest)
+ returns (ReportEligibilityEvalTaskResultResponse) {
+ }
+}
+
+message EligibilityEvalTaskRequest {
+ // The name of the population this client belongs to.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string population_name = 1
+ ;
+
+ // The attestation measurement providing evidence of integrity for this
+ // client. The measurement is bound to the population_name value in this
+ // request.
+ //
+ // Note that the subsequent `StartTaskAssignmentRequest` will use the same
+ // value for this field, since it is considered part of the same logical
+ // protocol session as this request.
+ AttestationMeasurement attestation_measurement = 2;
+
+ ClientVersion client_version = 3;
+
+ // The client's capabilities when downloading and processing resources.
+ ResourceCapabilities resource_capabilities = 4;
+
+ // The client's capabilities when downloading and running Eligibility Eval
+ // tasks.
+ EligibilityEvalTaskCapabilities eligibility_eval_task_capabilities = 5;
+}
+
+// The client's capabilities for determining task eligibility.
+message EligibilityEvalTaskCapabilities {
+ // Whether the client supports multiple task assignment
+ // (/TaskAssignments.PerformMultipleTaskAssignments). If false, the client
+ // will not be provided information about tasks that require multiple task
+ // assignment.
+ bool supports_multiple_task_assignment = 1;
+}
+
+message EligibilityEvalTaskResponse {
+ // Information to construct the URI to use when calling StartTaskAssignment.
+ // This will not be populated if the result below contains a RejectionInfo
+ // message since the client should not call StartTaskAssignment in that case.
+ //
+ // Note that this forwarding info does not apply to
+ // `ReportEligibilityEvalTaskResult` which should instead be sent to the same
+ // endpoint as `RequestEligibilityEvalTask`.
+ ForwardingInfo task_assignment_forwarding_info = 1;
+
+ // Unique identifier for the protocol session. This field will not be set if
+ // the result below contains a RejectionInfo.
+ string session_id = 2;
+
+ oneof result {
+ // If the population has an eligibility-computing task configured, and if
+ // the client is compatible with that task, then this field will be set,
+ // containing the task's information. The client should run the task and
+ // include its `TaskEligibilityInfo` result in the subsequent
+ // `StartTaskAssignmentRequest`.
+ EligibilityEvalTask eligibility_eval_task = 3;
+
+ // If the population does not have an eligibility-computing task configured,
+ // then this field will be set. The client should continue by issuing a
+ // `StartTaskAssignmentRequest` without the `task_eligibility_info` field
+ // set.
+ NoEligibilityEvalConfigured no_eligibility_eval_configured = 4;
+
+ // If the population has an eligibility-computing task configured, but the
+ // client is incompatible with that task or if the server is unable to
+ // service the request at the moment, then this field will be set.
+ RejectionInfo rejection_info = 5;
+ }
+
+ // Retry window to use for the next RequestEligibilityEvalTask attempt if
+ // the following StartTaskAssignment attempt ends up being subsequently
+ // accepted by the server, as in the client received a
+ // StartTaskAssignmentResponse with a TaskAssignment. This will not be set if
+ // the result above contains a RejectionInfo.
+ RetryWindow retry_window_if_accepted = 6;
+
+ // Retry window to use if this request was rejected or if the following
+ // StartTaskAssignment attempt is not accepted by the server, as in the client
+ // receives a StartTaskAssignmentResponse without a TaskAssignment.
+ RetryWindow retry_window_if_rejected = 7;
+}
+
+message EligibilityEvalTask {
+ // The checkpoint from which to start execution (if any).
+ // Optional: This field and `plan` may both be unset if the client supports
+ // multiple task assignment but the population does not have an Eligibility
+ // Eval task configured.
+ Resource init_checkpoint = 1;
+
+ // The task to be used for execution.
+ // Optional: This field and `init_checkpoint` may both be unset if the client
+ // supports multiple task assignment but the population does not have an
+ // Eligibility Eval task configured.
+ Resource plan = 2;
+
+ // A serialized PopulationEligibilitySpec describing the eligibility criteria
+ // for tasks in the population.
+ Resource population_eligibility_spec = 4;
+
+ // The opaque id of the eligibility evaluation task payload the client is
+ // being given. This is a string generated by the server and used by the
+ // client for logging purposes. This id MUST NOT contain any information that
+ // could be used to identify a specific device.
+ // Also see the similar `TaskAssignment.execution_phase_id`.
+ // Optional: If `plan` is absent, this field may also be absent.
+ string execution_id = 3;
+}
+
+// Currently-empty message describing the case where a population does not have
+// an eligibility-computing task configured.
+message NoEligibilityEvalConfigured {}
+
+// Provides the information needed to determine eligibility for tasks in a
+// population.
+message PopulationEligibilitySpec {
+ // Eligibility-related information about each task in the population.
+ repeated TaskInfo task_info = 1;
+
+ message TaskInfo {
+ // The name of the task.
+ string task_name = 1;
+
+ // The TaskAssignments method to use for the task.
+ TaskAssignmentMode task_assignment_mode = 2;
+
+ enum TaskAssignmentMode {
+ TASK_ASSIGNMENT_MODE_UNSPECIFIED = 0;
+ // Task assignment uses /TaskAssignments.StartTaskAssignment.
+ TASK_ASSIGNMENT_MODE_SINGLE = 1;
+ // Task assignment uses /TaskAssignments.PerformMultipleTaskAssignments.
+ TASK_ASSIGNMENT_MODE_MULTIPLE = 2;
+ }
+ }
+}
+
+message ReportEligibilityEvalTaskResultRequest {
+ // The name of the population this client belongs to.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string population_name = 1
+ ;
+
+ // The session id returned by the server.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string session_id = 2
+ ;
+
+ // Status code reported by client.
+ // Code.OK indicates that client execution completed successfully. Any other
+ // code indicates unsuccessful execution and termination of the protocol
+ // session.
+ int32 status_code = 3;
+}
+
+message ReportEligibilityEvalTaskResultResponse {}
diff --git a/fcp/protos/federatedcompute/secure_aggregations.proto b/fcp/protos/federatedcompute/secure_aggregations.proto
new file mode 100644
index 0000000..a221e80
--- /dev/null
+++ b/fcp/protos/federatedcompute/secure_aggregations.proto
@@ -0,0 +1,343 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federatedcompute.v1;
+
+import "fcp/protos/federatedcompute/common.proto";
+import "fcp/secagg/shared/secagg_messages.proto";
+import "google/protobuf/duration.proto";
+
+option java_package = "com.google.internal.federatedcompute.v1";
+option java_multiple_files = true;
+
+service SecureAggregations {
+ // A request sent by the client after completing local (on-device) task
+ // computation to notify the server that the client has Aggregation data to
+ // upload via the Secure Aggregation protocol. If a client's result is no
+ // longer needed (e.g. the reporting goal was already reached for the task),
+ // the server will respond with an ABORTED error in the operation status. The
+ // operation is completed successfully when the Secure Aggregation protocol is
+ // ready to begin.
+ rpc StartSecureAggregation(StartSecureAggregationRequest)
+ returns (StartSecureAggregationResponse) {}
+
+ // A request sent by the client indicating the client is ending its
+ // participation in the Secure Aggregation protocol.
+ //
+ // Clients must only call this if they've previously called
+ // `StartSecureAggregation`.
+ //
+ // If clients have already received a `StartSecureAggregationResponse`
+ // they should use the `ForwardingInfo` from the
+ // `StartSecureAggregationResponse.secagg_protocol_forwarding_info`
+ // response field to construct the URI for this request. Otherwise, clients
+ // should use the same `ForwardingInfo` as was used to construct the
+ // `StartSecureAggregation` request URI.
+ rpc AbortSecureAggregation(AbortSecureAggregationRequest)
+ returns (AbortSecureAggregationResponse) {
+ }
+
+ // A request sent by the client to advertise its pair of public keys. The
+ // server responds with a list of the (pairs of) public keys of all other
+ // participating clients.
+ //
+ // Clients should use the `ForwardingInfo` from the
+ // `StartSecureAggregationResponse.secagg_protocol_forwarding_info`
+ // response field to construct the URI for this request.
+ //
+ // If the returned operation is not complete, clients should poll for status
+ // at the rate specified in the AdvertiseKeysMetadata.
+ rpc AdvertiseKeys(AdvertiseKeysRequest)
+ returns (AdvertiseKeysResponse) {
+ }
+
+ // A request sent by the client to secret-share its
+ // own noise_sk and prf_sk with all the other clients (encrypting shares for
+ // client j with their own enc_pk). The server responds with the client's
+ // shares of the keys of each other client that sent a ShareKeysRequest.
+ //
+ // Clients should use the `ForwardingInfo` from the
+ // `StartSecureAggregationResponse.secagg_protocol_forwarding_info`
+ // response field to construct the URI for this request.
+ //
+ // If the returned operation is not complete, clients should poll for status
+ // at the rate specified in the ShareKeysMetadata.
+ rpc ShareKeys(ShareKeysRequest) returns (ShareKeysResponse) {
+ }
+
+ // A request sent by the client indicating the successful upload of the
+ // client's masked and unmasked results. The server responds with a list of
+ // clients that did not successfully upload their results (and therefore are
+ // considered dead).
+ //
+ // Clients should use the `ForwardingInfo` from the
+ // `StartSecureAggregationResponse.secagg_protocol_forwarding_info`
+ // response field to construct the URI for this request.
+ //
+ // If the returned operation is not complete, clients should poll for status
+ // at the rate specified in the SubmitSecureAggregationResultMetadata.
+ rpc SubmitSecureAggregationResult(SubmitSecureAggregationResultRequest)
+ returns (SubmitSecureAggregationResultResponse) {
+ }
+
+ // A request sent by the client containing information for each other client
+ // j. For each client j, The client provides either the share of noise_sk (if
+ // client j is dead) OR share of prf_sk (if client j is still alive).
+ //
+ // If a client's aggregation result is no longer needed for the aggregation
+ // (e.g. the reporting goal was already reached for the task), the server will
+ // respond with an ABORTED error.
+ //
+ // Clients should use the `ForwardingInfo` from the
+ // `StartSecureAggregationResponse.secagg_protocol_forwarding_info`
+ // response field to construct the URI for this request.
+ rpc Unmask(UnmaskRequest) returns (UnmaskResponse) {
+ }
+}
+
+// --- StartSecureAggregation ---
+message StartSecureAggregationRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The authorization token returned by the server when the client was assigned
+ // a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string authorization_token = 2
+ ;
+}
+
+message StartSecureAggregationMetadata {}
+
+message StartSecureAggregationResponse {
+ // Information to construct the URI to use for continuing the secure
+ // aggregation protocol.
+ ForwardingInfo secagg_protocol_forwarding_info = 1;
+
+ // Per-aggregand information.
+ map<string, SecureAggregandExecutionInfo> secure_aggregands = 2;
+
+ // Protocol execution information.
+ SecureAggregationProtocolExecutionInfo protocol_execution_info = 3;
+
+ // Information about where to upload masked result.
+ ByteStreamResource masked_result_resource = 4;
+
+ // Information about where to upload unmasked result (e.g. metrics).
+ ByteStreamResource nonmasked_result_resource = 5;
+
+ // Unique token that the client must include in the subsequent protocol
+ // requests.
+ string client_token = 6;
+}
+
+// Per-aggregand configuration information.
+message SecureAggregandExecutionInfo {
+ // Modulus for secure aggregation.
+ //
+ // The secure aggregation protocol will compute the sum modulo this modulus.
+ //
+ // To achieve equivalence with non-modular summation, the modulus must be
+ // larger than the sum of all client inputs, given the number of clients
+ // participating in the aggregation shard.
+ uint64 modulus = 1;
+}
+
+// Dynamic configuration of the Secure Aggregation protocol.
+message SecureAggregationProtocolExecutionInfo {
+ // Number of clients that a client may exchange data with while running
+ // Secure Aggregation protocol. In the case of a full graph SecAgg protocol
+ // this is a total number of clients that started the protocol.
+ // In the case of subgraph SecAgg protocol this is a number of neighbours
+ // that each client has.
+ int32 expected_number_of_clients = 1;
+
+ // Secure Aggregation client completion threshold. This is a parameter
+ // communicated by the server side of Secure Aggregation protocol to each
+ // client to establish Shamir sharing of secrets.
+ // Additionally, at least `minimum_surviving_clients_for_reconstruction` out
+ // of the initial `expected_number_of_clients` must 'survive' in order for
+ // the protocol to continue on the client side; otherwise the client will
+ // abort its connection.
+ int32 minimum_surviving_clients_for_reconstruction = 2;
+}
+
+// --- AbortSecureAggregation ---
+message AbortSecureAggregationRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Status code and optional message for why the secure aggregation protocol
+ // was aborted.
+ Status status = 3;
+}
+
+message AbortSecureAggregationResponse {}
+
+// --- AdvertiseKeys ---
+message AdvertiseKeysRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // A pair of public keys for this client.
+ fcp.secagg.AdvertiseKeys advertise_keys = 3;
+}
+
+message AdvertiseKeysMetadata {
+ // The suggested duration between instances of polling the AdvertiseKeys LRO.
+ google.protobuf.Duration polling_interval = 1;
+}
+
+message AdvertiseKeysResponse {
+ // Information from the server so that the client can participate in the
+ // ShareKeys protocol stage. Contains a list of pairs of public keys, as well
+ // as the logging ID for the SecAgg execution.
+ fcp.secagg.ShareKeysRequest share_keys_server_request = 1;
+}
+
+// --- ShareKeys ---
+message ShareKeysRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Information about the client's participation in the ShareKeys protocol
+ // stage. Contains a list of encrypted pairs of key shares (one for each other
+ // client).
+ fcp.secagg.ShareKeysResponse share_keys_client_response = 3;
+}
+
+message ShareKeysMetadata {
+ // The suggested duration between instances of polling the ShareKeys LRO.
+ google.protobuf.Duration polling_interval = 1;
+}
+
+message ShareKeysResponse {
+ // Information from the server so that the client can submit its masked
+ // result. Contains a list of shares of other clients' keys encrypted and
+ // intended for the client who receives this message.
+ fcp.secagg.MaskedInputCollectionRequest
+ masked_input_collection_server_request = 1;
+}
+
+// --- SubmitSecureAggregationResult ---
+message SubmitSecureAggregationResultRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Name of the resource to which the masked result was uploaded. The blob
+ // uploaded at masked_input_resource name must contain a serialized
+ // fcp.secagg.MaskedInputCollectionResponse message.
+ string masked_result_resource_name = 3;
+
+ // Name of the resource to which the nonmasked result was uploaded.
+ string nonmasked_result_resource_name = 4;
+}
+
+message SubmitSecureAggregationResultMetadata {
+ // The suggested duration between instances of polling the
+ // SubmitSecureAggregationResult LRO.
+ google.protobuf.Duration polling_interval = 1;
+}
+
+message SubmitSecureAggregationResultResponse {
+ // Information from the server so that the client can participate in the
+ // Unmasking protocol stage.
+ fcp.secagg.UnmaskingRequest unmasking_server_request = 1;
+}
+
+// --- Unmask ---
+message UnmaskRequest {
+ // The id of the aggregation session this client participates in. This value
+ // was returned by the server when the client was assigned a task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string aggregation_id = 1
+ ;
+
+ // The client token returned by the server when the client was assigned a
+ // task.
+ //
+ // Note that HTTP clients set this value in the request URL instead of the
+ // request body.
+ string client_token = 2
+ ;
+
+ // Information about the client's participation in the Unmasking protocol
+ // stage.
+ fcp.secagg.UnmaskingResponse unmasking_client_response = 3;
+}
+
+message UnmaskResponse {}
diff --git a/fcp/protos/federatedcompute/task_assignments.proto b/fcp/protos/federatedcompute/task_assignments.proto
new file mode 100644
index 0000000..1aabf45
--- /dev/null
+++ b/fcp/protos/federatedcompute/task_assignments.proto
@@ -0,0 +1,282 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federatedcompute.v1;
+
+import "google/protobuf/duration.proto";
+import "fcp/protos/federated_api.proto";
+import "fcp/protos/federatedcompute/common.proto";
+
+option java_package = "com.google.internal.federatedcompute.v1";
+option java_multiple_files = true;
+
+service TaskAssignments {
+ // A request sent by the device to check if it should participate in the
+ // current task.
+ //
+ // Clients should use the same `ForwardingInfo` (incl. the headers) as used in
+ // this request to construct the URI to poll the operation service to check
+ // for task assignment completion.
+ //
+ // When the task assignment is complete, the result of the operation will
+ // either contain an error or the resulting `StartTaskAssignmentResponse` in
+ // the response.
+ //
+ // If the client no longer needs a task assignment because it is interrupted
+ // or drops out or an error occurs during polling the long running operation,
+ // the client should make a best effort to call CancelOperation.
+ //
+ // If the returned operation is not complete, clients should poll for status
+ // at the rate specified in the StartTaskAssignmentMetadata.
+ rpc StartTaskAssignment(StartTaskAssignmentRequest)
+ returns (StartTaskAssignmentResponse) {}
+
+ // A request sent by the device to participate in multiple tasks
+ // simultaneously.
+ //
+ // Unlike StartTaskAssignment, which returns at most one task assignment of
+ // the server's choice, this RPC attempts to return assignments for *all*
+ // tasks requested by the client.
+ rpc PerformMultipleTaskAssignments(PerformMultipleTaskAssignmentsRequest)
+ returns (PerformMultipleTaskAssignmentsResponse) {}
+
+ // A request sent by the device to report the result of running the task
+ // provided by `StartTaskAssignmentResponse`.
+ //
+ // Clients should use the same `ForwardingInfo` as used in the
+ // `StartTaskAssignment` request to construct the URI for this request.
+ //
+ // A result with a status code other than Code.OK indicates client session
+ // termination. The client may not send any future requests with the given
+ // session_id.
+ rpc ReportTaskResult(ReportTaskResultRequest)
+ returns (ReportTaskResultResponse) {}
+}
+
+message StartTaskAssignmentRequest {
+ // The name of the population this client belongs to.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string population_name = 1
+ ;
+
+ // The session id returned by the server.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string session_id = 2
+ ;
+
+ reserved 3;
+
+ ClientVersion client_version = 4;
+
+ // The client computes this message using the task returned by a previous
+ // `RequestEligibilityEvalTask` call.
+ //
+ // If this field is set, it describes to the server which tasks the client is
+ // (in)eligible. The server must take this information into account when
+ // deciding task to serve in response to this request.
+ //
+ // If this field is unset, it may indicate that the client previously received
+ // an `EligibilityEvalTask` without an `EligibilityEvalPayload` message (i.e.
+ // the population did not have an eligibility-computing task configured at the
+ // time of the request). It may also indicate a client for which the
+ // eligibility-computing task feature has been disabled, or an old client that
+ // does not support this feature yet.
+ //
+ // If this field is unset but the population has an eligibility-computing task
+ // configured, then the server must reject this client, since the server has
+ // no way to determine which tasks the client is (in)eligible for.
+ //
+ // If this field is unset and the population does not have an
+ // eligibility-computing task configured, then the server may serve this
+ // client any task.
+ //
+ google.internal.federatedml.v2.TaskEligibilityInfo task_eligibility_info = 5;
+
+ // The client's capabilities when downloading and processing resources.
+ ResourceCapabilities resource_capabilities = 6;
+}
+
+message StartTaskAssignmentMetadata {
+ // The suggested duration between instances of polling the StartTaskAssignment
+ // LRO.
+ google.protobuf.Duration polling_interval = 1;
+}
+
+message StartTaskAssignmentResponse {
+ // One of two outcomes, depending on server's decision on participation of the
+ // client.
+ oneof result {
+ // If the client joined the task with this call, information on how to
+ // proceed.
+ TaskAssignment task_assignment = 1;
+
+ // If the client was not accepted, information how to proceed.
+ RejectionInfo rejection_info = 2;
+ }
+}
+
+message PerformMultipleTaskAssignmentsRequest {
+ // The name of the population this client belongs to.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string population_name = 1
+ ;
+
+ // The session id returned by the server.
+ string session_id = 2
+ ;
+
+ // The client's version information.
+ ClientVersion client_version = 3;
+
+ // The client's capabilities when downloading and processing resources.
+ ResourceCapabilities resource_capabilities = 4;
+
+ // The list of tasks for which the client would like TaskAssignments. These
+ // names are produced by running the population's Eligibility Eval task (see
+ // /EligibilityEvalTasks.RequestEligibilityEvalTask).
+ repeated string task_names = 5
+ ;
+}
+
+message PerformMultipleTaskAssignmentsResponse {
+ // The TaskAssignments requested by the client. The server may omit requested
+ // TaskAssignments, e.g. for any tasks that no longer exist or are not
+ // currently accepting client contributions; these cases should be infrequent.
+ repeated TaskAssignment task_assignments = 1;
+}
+
+// When client (device) is accepted for the current task, this data structure
+// carries information necessary to begin task execution.
+message TaskAssignment {
+ // Information to construct the URI to use for providing resulting aggregation
+ // data after task execution.
+ ForwardingInfo aggregation_data_forwarding_info = 1;
+
+ message AggregationInfo {}
+
+ message SecureAggregationInfo {
+ // The minimum number of clients' values that must be aggregated together
+ // before the server can gain access to the aggregate,
+ // even transiently (e.g. in RAM).
+ // This isn't needed by Secure Aggregation protocol on the client side but
+ // shared by the server with clients for transparency or policy reasons.
+ int32 minimum_clients_in_server_visible_aggregate = 1;
+ }
+
+ oneof aggregation_type {
+ // If set, indicates that the Aggregations service (see
+ // `aggregations.proto`) should be used to perform aggregation.
+ AggregationInfo aggregation_info = 9;
+
+ // If set, indicates that the SecureAggregations service (see
+ // `secure_aggregations.proto`) should be used to perform aggregation.
+ SecureAggregationInfo secure_aggregation_info = 10;
+ }
+
+ // Unique identifier for the client's protocol session.
+ string session_id = 5;
+
+ // The opaque id of the aggregation session the client has joined. This is a
+ // string generated by the server and MUST NOT contain any information that
+ // could be used to identify a specific device.
+ string aggregation_id = 2;
+
+ // Unique identifier for the client's participation in an aggregation session.
+ string authorization_token = 6;
+
+ // The name identifying the task that was assigned.
+ string task_name = 7;
+
+ // The checkpoint from which to start execution.
+ Resource init_checkpoint = 3;
+
+ // The plan to be used for execution.
+ Resource plan = 4;
+
+ // Info for how to generate URIs for fetching slices at runtime.
+ FederatedSelectUriInfo federated_select_uri_info = 8;
+}
+
+// Info for how to generate URIs for fetching slices that the task might request
+// to be downloaded at runtime.
+//
+// When one or more slices are requested by the task, the template specified
+// here should be used to form a URI from which the client can download the
+// slice data, by replacing the "{served_at_id}" and "{key_base10}" substrings
+// with the `google.internal.federated.plan.SlicesSelector.served_at_id` and the
+// base-10 representation of the `SlicesSelector.keys` value. The client must
+// not perform any URI escaping to the values that the substrings are replaced
+// with.
+message FederatedSelectUriInfo {
+ // The URI template to use for fetching slices.
+ //
+ // This template must always start with "https://".
+ //
+ // This template must contain the following substrings: "{served_at_id}" and
+ // "{key_base10}", as per the above documentation.
+ string uri_template = 1;
+}
+
+message ReportTaskResultRequest {
+ // The name of the population this client belongs to.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string population_name = 1
+ ;
+
+ // The session id returned by the server.
+ //
+ // Note that http clients set this value in the request URL instead of the
+ // request body.
+ string session_id = 2
+ ;
+
+ // The opaque id of the aggregation session the client has joined. This is a
+ // string generated by the server and MUST NOT contain any information that
+ // could be used to identify a specific device.
+ string aggregation_id = 5
+ ;
+
+ // The name identifying the task that was assigned.
+ string task_name = 6;
+
+ // Computation status code reported by client.
+ // Code.OK indicates that the client computation completed successfully. Any
+ // other code indicates unsuccessful computation execution and termination of
+ // the protocol session.
+ int32 computation_status_code = 3;
+
+ // Stats produced during the client-side execution of the task.
+ ClientStats client_stats = 4;
+}
+
+// This message is used to report client stats and execution duration to the
+// server.
+message ClientStats {
+ // The time spent on running the task (includes I/O such as reading examples,
+ // but does not include time spent on the network for retrieving the task
+ // or uploading results).
+ google.protobuf.Duration computation_execution_duration = 1;
+}
+
+message ReportTaskResultResponse {}
diff --git a/fcp/protos/opstats.proto b/fcp/protos/opstats.proto
new file mode 100644
index 0000000..54c5041
--- /dev/null
+++ b/fcp/protos/opstats.proto
@@ -0,0 +1,346 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.client.opstats;
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+import "fcp/protos/federated_api.proto";
+
+// Operational stats per run.
+message OperationalStats {
+ // Population name.
+ string population_name = 1;
+
+ // Session name, if applicable.
+ string session_name = 2;
+
+ // Name of the task that was executed.
+ string task_name = 3;
+
+ // Timestamped training stages and error types.
+ message Event {
+ // Key training stages and error types.
+ enum EventKind {
+ EVENT_KIND_UNRECOGNIZED = 0;
+
+ // An eligibility task checkin attempt started. This does not
+ // indicate whether the eligibility checkin request was actually sent.
+ EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED = 1;
+
+ // An eligibility task checkin response indicated that the client was
+ // rejected because the client was incompatible with the population's
+ // eligibility task plan.
+ EVENT_KIND_ELIGIBILITY_REJECTED = 2;
+
+ // An eligibility task checkin response indicated that eligibility task
+ // plans are not configured.
+ EVENT_KIND_ELIGIBILITY_DISABLED = 3;
+
+ // An eligibility task checkin response return an eligibility task plan
+ // URI, but the client hasn't downloaded the plan and checkpoint yet. Also
+ // logged when the plan/checkpoint resources were actually supplied inline
+ // in the protocol response message and no actual HTTP fetch needs to
+ // happen anymore. This ensures that this event can always be compared
+ // against EVENT_KIND_ELIGIBILITY_ENABLED.
+ EVENT_KIND_ELIGIBILITY_PLAN_URI_RECEIVED = 48;
+
+ // An eligibility task checkin response returned an eligibility task plan,
+ // and the received plan was parseable.
+ EVENT_KIND_ELIGIBILITY_ENABLED = 4;
+
+ // A plan execution started for an eligibility task.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED = 5;
+
+ // A plan execution completed successfully for an eligibility task.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_FINISHED = 6;
+
+ // A checkin attempt started. This does not indicate whether the checkin
+ // request was actually sent.
+ EVENT_KIND_CHECKIN_STARTED = 7;
+
+ // A checkin response indicated that the client was rejected.
+ EVENT_KIND_CHECKIN_REJECTED = 8;
+
+ // A checkin response indicated that the client was accepted for a task,
+ // but the client hasn't downloaded the plan and checkpoint yet. Also
+ // logged when the plan/checkpoint resources were actually supplied inline
+ // in the protocol response message and no actual HTTP fetch needs to
+ // happen anymore. This ensures that this event can always be compared
+ // against EVENT_KIND_CHECKIN_ACCEPTED.
+ EVENT_KIND_CHECKIN_PLAN_URI_RECEIVED = 49;
+
+ // A checkin response indicated that the client was accepted for a task,
+ // and the received plan was parseable.
+ EVENT_KIND_CHECKIN_ACCEPTED = 9;
+
+ // A plan execution started for a normal task.
+ EVENT_KIND_COMPUTATION_STARTED = 10;
+
+ // A plan execution completed successfully for a normal task.
+ EVENT_KIND_COMPUTATION_FINISHED = 11;
+
+ // An upload attempt started. This does not indicate whether the upload
+ // was actually sent.
+ // Deprecated: split into EVENT_KIND_RESULT_UPLOAD_STARTED and
+ // EVENT_KIND_FAILURE_UPLOAD_STARTED.
+ EVENT_KIND_UPLOAD_STARTED = 12 [deprecated = true];
+
+ // An upload response indicated that the server successfully received the
+ // client's upload. This does not guarantee that the client's results are
+ // included in a round update.
+ // Deprecated: split into EVENT_KIND_RESULT_UPLOAD_FINISHED and
+ // EVENT_KIND_FAILURE_UPLOAD_FINISHED.
+ EVENT_KIND_UPLOAD_FINISHED = 13 [deprecated = true];
+
+ // The client interrupted training due to unmet training conditions. This
+ // may occur during checkin, training, or upload.
+ // Deprecated: split into EVENT_KIND_{phase}_INTERRUPTED, where phase is
+ // one of ELIGIBILITY_CHECKIN, ELIGIBILITY_COMPUTATION, CHECKIN,
+ // COMPUTATION, RESULT_UPLOAD, FAILURE_UPLOAD.
+ EVENT_KIND_CLIENT_INTERRUPTED = 14 [deprecated = true];
+
+ // The server aborted the client's connection. This may occur during
+ // checkin or upload.
+ // Deprecated: split into EVENT_KIND_{phase}_SERVER_ABORTED, where phase
+ // is one of ELIGIBILITY_CHECKIN, CHECKIN, RESULT_UPLOAD, FAILURE_UPLOAD.
+ EVENT_KIND_SERVER_ABORTED = 15 [deprecated = true];
+
+ // An error occurred that was related to local storage access,
+ // communication with the server, or an invalid plan.
+ // Deprecated: split into EVENT_KIND_{phase}_ERROR_IO,
+ // EVENT_KIND_{phase}_ERROR_INVALID_ARGUMENT and
+ // EVENT_KIND_{phase}_ERROR_INVALID_PAYLOAD, where phase is one of
+ // ELIGIBILITY_CHECKIN, CHECKIN, RESULT_UPLOAD, FAILURE_UPLOAD,
+ // ELIGIBILITY_COMPUTATION, or COMPUTATION.
+ EVENT_KIND_ERROR_IO = 16 [deprecated = true];
+
+ // The TensorFlow library reported an error.
+ // Deprecated: split into EVENT_KIND_{phase}_ERROR_TENSORFLOW, where phase
+ // is one of ELIGIBILITY_COMPUTATION, COMPUTATION.
+ EVENT_KIND_ERROR_TENSORFLOW = 17 [deprecated = true];
+
+ // An error occurred when processing the example selector.
+ // Deprecated: split into EVENT_KIND_{phase}_ERROR_EXAMPLE_ITERATOR, where
+ // phase is one of ELIGIBILITY_EVAL_COMPUTATION, COMPUTATION.
+ EVENT_KIND_ERROR_EXAMPLE_SELECTOR = 18 [deprecated = true];
+
+ // Indicates that training was scheduled but did not start due to runtime
+ // checks (e.g. insufficient battery levels).
+ EVENT_KIND_TRAIN_NOT_STARTED = 19;
+
+ // Client issued an eligibility eval checkin request, but an IO error was
+ // encountered.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED.
+ EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_IO = 20;
+
+ // Client issued an eligibility eval checkin request, but an invalid
+ // payload was received.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED.
+ EVENT_KIND_ELIGIBILITY_CHECKIN_ERROR_INVALID_PAYLOAD = 21;
+
+ // Client issued an eligibility eval checkin request, but got interrupted
+ // on the client. Always preceded by
+ // EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED.
+ EVENT_KIND_ELIGIBILITY_CHECKIN_CLIENT_INTERRUPTED = 22;
+
+ // Client issued an eligibility eval checkin request, but server aborted.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_CHECKIN_STARTED.
+ EVENT_KIND_ELIGIBILITY_CHECKIN_SERVER_ABORTED = 23;
+
+ // Client issued a regular checkin request, but got an IO error.
+ // Always preceded by EVENT_KIND_CHECKIN_STARTED.
+ EVENT_KIND_CHECKIN_ERROR_IO = 24;
+
+ // Client issued a regular checkin request, but the server returned an
+ // invalid payload.
+ // Always preceded by EVENT_KIND_CHECKIN_STARTED.
+ EVENT_KIND_CHECKIN_ERROR_INVALID_PAYLOAD = 25;
+
+ // Client issued a regular checin request, but got interrupted on the
+ // client. Always preceded by EVENT_KIND_CHECKIN_STARTED.
+ EVENT_KIND_CHECKIN_CLIENT_INTERRUPTED = 26;
+
+ // Client issued a regular checin request, but got aborted by the server.
+ // Always preceded by EVENT_KIND_CHECKIN_STARTED.
+ EVENT_KIND_CHECKIN_SERVER_ABORTED = 27;
+
+ // Client encountered a TensorFlow error during eligibility eval task
+ // computation.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_TENSORFLOW = 28;
+
+ // Reading from disk failed during eligibility eval task computation.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_IO = 29;
+
+ // Input parameters are invalid for eligibility eval task computation.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_INVALID_ARGUMENT = 30;
+
+ // Client encountered an example selector error during eligibility eval
+ // task computation. Always preceded by
+ // EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_ERROR_EXAMPLE_ITERATOR = 31;
+
+ // Eligibility eval computation was interrupted by the client.
+ // Always preceded by EVENT_KIND_ELIGIBILITY_COMPUTATION_STARTED.
+ EVENT_KIND_ELIGIBILITY_COMPUTATION_CLIENT_INTERRUPTED = 32;
+
+ // A TensorFlow error was encountered during computation, or the output
+ // from the computation was missing or of an unexpected type. Always
+ // preceded by EVENT_KIND_COMPUTATION_STARTED.
+ EVENT_KIND_COMPUTATION_ERROR_TENSORFLOW = 33;
+
+ // Reading from disk failed during computation.
+ // Always preceded by EVENT_KIND_COMPUTATION_STARTED.
+ EVENT_KIND_COMPUTATION_ERROR_IO = 34;
+
+ // Input parameters are invalid for the given computation.
+ // Always preceded by EVENT_KIND_COMPUTATION_STARTED.
+ EVENT_KIND_COMPUTATION_ERROR_INVALID_ARGUMENT = 35;
+
+ // An error occurred when processing the example selector.
+ // Always preceded by EVENT_KIND_COMPUTATION_STARTED.
+ EVENT_KIND_COMPUTATION_ERROR_EXAMPLE_ITERATOR = 36;
+
+ // Client got interrupted during computation.
+ // Always preceded by EVENT_KIND_COMPUTATION_STARTED.
+ EVENT_KIND_COMPUTATION_CLIENT_INTERRUPTED = 37;
+
+ // Client starts to upload successfully computed results.
+ EVENT_KIND_RESULT_UPLOAD_STARTED = 38;
+
+ // An error occurred during upload.
+ // Always preceded by EVENT_KIND_RESULT_UPLOAD_STARTED.
+ EVENT_KIND_RESULT_UPLOAD_ERROR_IO = 39;
+
+ // Upload was interrupted by the client.
+ // Always preceded by EVENT_KIND_RESULT_UPLOAD_STARTED.
+ EVENT_KIND_RESULT_UPLOAD_CLIENT_INTERRUPTED = 40;
+
+ // Upload was aborted by the server.
+ // Always preceded by EVENT_KIND_RESULT_UPLOAD_STARTED.
+ EVENT_KIND_RESULT_UPLOAD_SERVER_ABORTED = 41;
+
+ // Client uploaded training results to the server
+ // Always preceded by EVENT_KIND_RESULT_UPLOAD_STARTED.
+ EVENT_KIND_RESULT_UPLOAD_FINISHED = 42;
+
+ // Client starts to upload failure report.
+ EVENT_KIND_FAILURE_UPLOAD_STARTED = 43;
+
+ // An error occurred during upload.
+ // Always preceded by EVENT_KIND_FAILURE_UPLOAD_STARTED.
+ EVENT_KIND_FAILURE_UPLOAD_ERROR_IO = 44;
+
+ // Upload was interrupted.
+ // Always preceded by EVENT_KIND_FAILURE_UPLOAD_STARTED.
+ EVENT_KIND_FAILURE_UPLOAD_CLIENT_INTERRUPTED = 45;
+
+ // Upload was interrupted.
+ // Always preceded by EVENT_KIND_FAILURE_UPLOAD_STARTED.
+ EVENT_KIND_FAILURE_UPLOAD_SERVER_ABORTED = 46;
+
+ // Client uploaded failure report to the server
+ // Always preceded by EVENT_KIND_FAILURE_UPLOAD_STARTED.
+ EVENT_KIND_FAILURE_UPLOAD_FINISHED = 47;
+
+ // Client failed to initialize a component, but execution was not halted.
+ EVENT_KIND_INITIALIZATION_ERROR_NONFATAL = 50;
+
+ // Client failed to initialize a component, and execution was halted.
+ EVENT_KIND_INITIALIZATION_ERROR_FATAL = 51;
+ }
+
+ EventKind event_type = 1;
+
+ // Event time.
+ google.protobuf.Timestamp timestamp = 2;
+ }
+
+ // History of key training stages and errors encountered during a run. The
+ // events are stored in sequential order, with the earliest event first.
+ repeated Event events = 4;
+
+ // Stats about the examples read from a given collection, potentially
+ // aggregated over multiple iterators.
+ message DatasetStats {
+ // Total number of examples read.
+ int64 num_examples_read = 1;
+
+ // Total number of bytes read.
+ int64 num_bytes_read = 2;
+ }
+
+ // Map of dataset stats keyed on the collection URI.
+ map<string, DatasetStats> dataset_stats = 5;
+
+ // If this execution failed with an error, the message of that error.
+ string error_message = 6;
+
+ // The retry window returned by the fl runner.
+ google.internal.federatedml.v2.RetryWindow retry_window = 7;
+
+ // The number of bytes downloaded (payload size via the chunking layer, which
+ // may be compressed) from the server while executing the task thus far.
+ int64 chunking_layer_bytes_downloaded = 10;
+
+ // The number of bytes uploaded (payload size via the chunking layer, which
+ // may be compressed) from the server while executing the task thus far.
+ int64 chunking_layer_bytes_uploaded = 11;
+
+ // The duration of time spent waiting on the network (but excluding idle time
+ // like the time between polling the server).
+ google.protobuf.Duration network_duration = 12;
+
+ reserved 8, 9;
+}
+
+// Top level op stats message.
+message OpStatsSequence {
+ // The OperationalStats messages are stored in sequential order, with the
+ // earliest message first.
+ repeated OperationalStats opstats = 1;
+ // A timestamp that marks when we can start to trust the data in the
+ // OpStatsDb. Any event happens before this time is missing or removed.
+ google.protobuf.Timestamp earliest_trustworthy_time = 2;
+}
+
+// Selection criteria for op stats data.
+// If this selection criteria not set, all data will be used.
+// If start_time is not set but end_time is set, all examples up to end_time
+// will be used.
+// If end_time is not set, all examples after start_time will be used.
+// If neither start_time nor end_time are set, all examples will be used.
+// If both start_time and end_time are set, the examples within the time range
+// will be used.
+// If last_successful_contribution is set, start_time and end_time are ignored,
+// and opstats returns a single example containing the entry of the last
+// successful contribution (if it exists) of the runtime to the current task. If
+// there are no previous successful contributions, returns an empty iterator.
+message OpStatsSelectionCriteria {
+ // The lower bound (inclusive) of the last updated time for a OperationalStats
+ // message.
+ google.protobuf.Timestamp start_time = 1;
+ // The upper bound (inclusive) of the last updated time for a OperationalStats
+ // message.
+ google.protobuf.Timestamp end_time = 2;
+ // If set, returns the entry of the last successful contribution to the
+ // current task, or no entries if there was no successful contribution.
+ // `start_time` and `end_time are ignored.
+ bool last_successful_contribution = 3;
+}
diff --git a/fcp/protos/plan.proto b/fcp/protos/plan.proto
new file mode 100644
index 0000000..5c606c9
--- /dev/null
+++ b/fcp/protos/plan.proto
@@ -0,0 +1,1380 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federated.plan;
+
+import "google/protobuf/any.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/protobuf/saver.proto";
+import "tensorflow/core/protobuf/struct.proto";
+
+option java_package = "com.google.internal.federated.plan";
+option java_multiple_files = true;
+option java_outer_classname = "PlanProto";
+
+// Primitives
+// ===========
+
+// Represents an operation to save or restore from a checkpoint. Some
+// instances of this message may only be used either for restore or for
+// save, others for both directions. This is documented together with
+// their usage.
+//
+// This op has four essential uses:
+// 1. read and apply a checkpoint.
+// 2. write a checkpoint.
+// 3. read and apply from an aggregated side channel.
+// 4. write to a side channel (grouped with write a checkpoint).
+// We should consider splitting this into four separate messages.
+message CheckpointOp {
+ // An optional standard saver def. If not provided, only the
+ // op(s) below will be executed. This must be a version 1 SaverDef.
+ tensorflow.SaverDef saver_def = 1;
+
+ // An optional operation to run before the saver_def is executed for
+ // restore.
+ string before_restore_op = 2;
+
+ // An optional operation to run after the saver_def has been
+ // executed for restore. If side_channel_tensors are provided, then
+ // they should be provided in a feed_dict to this op.
+ string after_restore_op = 3;
+
+ // An optional operation to run before the saver_def will be
+ // executed for save.
+ string before_save_op = 4;
+
+ // An optional operation to run after the saver_def has been
+ // executed for save. If there are side_channel_tensors, this op
+ // should be run after the side_channel_tensors have been fetched.
+ string after_save_op = 5;
+
+ // In addition to being saved and restored from a checkpoint, one can
+ // also save and restore via a side channel. The keys in this map are
+ // the names of the tensors transmitted by the side channel. These (key)
+ // tensors should be read off just before saving a SaveDef and used
+ // by the code that handles the side channel. Any variables provided this
+ // way should NOT be saved in the SaveDef.
+ //
+ // For restoring, the variables that are provided by the side channel
+ // are restored differently than those for a checkpoint. For those from
+ // the side channel, these should be restored by calling the before_restore_op
+ // with a feed dict whose keys are the restore_names in the SideChannel and
+ // whose values are the values to be restored.
+ map<string, SideChannel> side_channel_tensors = 6;
+
+ // An optional name of a tensor in to which a unique token for the current
+ // session should be written.
+ //
+ // This session identifier allows TensorFlow ops such as `ServeSlices` or
+ // `ExternalDataset` to refer to callbacks and other session-global objects
+ // registered before running the session.
+ string session_token_tensor_name = 7;
+}
+
+message SideChannel {
+ // A side channel whose variables are processed via SecureAggregation.
+ // This side channel implements aggregation via sum over a set of
+ // clients, so the restored tensor will be a sum of multiple clients
+ // inputs into the side channel. Hence this will restore during the
+ // read_aggregate_update restore, not the per-client read_update restore.
+ message SecureAggregand {
+ message Dimension {
+ int64 size = 1;
+ }
+
+ // Dimensions of the aggregand. This is used by the secure aggregation
+ // protocol in its early rounds, not as redundant info which could be
+ // obtained by reading the dimensions of the tensor itself.
+ repeated Dimension dimension = 3;
+
+ // The data type anticipated by the server-side graph.
+ tensorflow.DataType dtype = 4;
+
+ // SecureAggregation will compute sum modulo this modulus.
+ message FixedModulus {
+ uint64 modulus = 1;
+ }
+
+ // SecureAggregation will for each shard compute sum modulo m with m at
+ // least (1 + shard_size * (base_modulus - 1)), then aggregate
+ // shard results with non-modular addition. Here, shard_size is the number
+ // of clients in the shard.
+ //
+ // Note that the modulus for each shard will be greater than the largest
+ // possible (non-modular) sum of the inputs to that shard. That is,
+ // assuming each client has input on range [0, base_modulus), the result
+ // will be identical to non-modular addition (i.e. federated_sum).
+ //
+ // While any m >= (1 + shard_size * (base_modulus - 1)), the current
+ // implementation takes
+ // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
+ // smallest possible value of m that is also a power of 2. This choice is
+ // made because (a) it uses the same number of bits per vector entry as
+ // valid smaller m, using the current on-the-wire encoding scheme, and (b)
+ // it enables the underlying mask-generation PRNG to run in its most
+ // computationally efficient mode, which can be up to 2x faster.
+ message ModulusTimesShardSize {
+ uint64 base_modulus = 1;
+ }
+
+ oneof modulus_scheme {
+ // Bitwidth of the aggregand.
+ //
+ // This is the bitwidth of an input value (i.e. the bitwidth that
+ // quantization should target). The Secure Aggregation bitwidth (i.e.,
+ // the bitwidth of the *sum* of the input values) will be a function of
+ // this bitwidth and the number of participating clients, as negotiated
+ // with the server when the protocol is initiated.
+ //
+ // Deprecated; prefer fixed_modulus instead.
+ int32 quantized_input_bitwidth = 2 [deprecated = true];
+
+ FixedModulus fixed_modulus = 5;
+ ModulusTimesShardSize modulus_times_shard_size = 6;
+ }
+
+ reserved 1;
+ }
+
+ // What type of side channel is used.
+ oneof type {
+ SecureAggregand secure_aggregand = 1;
+ }
+
+ // When restoring the name of the tensor to restore to. This is the name
+ // (key) supplied in the feed_dict in the before_restore_op in order to
+ // restore the tensor provided by the side channel (which will be the
+ // value in the feed_dict).
+ string restore_name = 2;
+}
+
+// Container for a metric used by the internal toolkit.
+message Metric {
+ // Name of an Op to run to read the value.
+ string variable_name = 1;
+
+ // A human-readable name for the statistic. Metric names are usually
+ // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
+ // Must be 7-bit ASCII and under 122 characters.
+ string stat_name = 2;
+
+ // The human-readable name of another metric by which this metric should be
+ // normalized, if any. If empty, this Metric should be aggregated with simple
+ // summation. If not empty, the Metric is aggregated according to
+ // weighted_metric_sum = sum_i (metric_i * weight_i)
+ // weight_sum = sum_i weight_i
+ // average_metric_value = weighted_metric_sum / weight_sum
+ string weight_name = 3;
+}
+
+// Controls the format of output metrics users receive. Represents instructions
+// for how metrics are to be output to users, controlling the end format of
+// the metric users receive.
+message OutputMetric {
+ // Metric name.
+ string name = 1;
+
+ oneof value_source {
+ // A metric representing one stat with aggregation type sum.
+ SumOptions sum = 2;
+
+ // A metric representing a ratio between metrics with aggregation
+ // type average.
+ AverageOptions average = 3;
+
+ // A metric that is not aggregated by the MetricReportAggregator or
+ // metrics_loader. This includes metrics like 'num_server_updates' that are
+ // aggregated in TensorFlow.
+ NoneOptions none = 4;
+
+ // A metric representing one stat with aggregation type only sample.
+ // Samples at most 101 clients' values.
+ OnlySampleOptions only_sample = 5;
+ }
+ // Iff True, the metric will be plotted in the default view of the
+ // task level Colab automatically.
+ oneof visualization_info {
+ bool auto_plot = 6 [deprecated = true];
+ VisualizationSpec plot_spec = 7;
+ }
+}
+
+message VisualizationSpec {
+ // Different allowable plot types.
+ enum VisualizationType {
+ NONE = 0;
+ DEFAULT_PLOT_FOR_TASK_TYPE = 1;
+ LINE_PLOT = 2;
+ LINE_PLOT_WITH_PERCENTILES = 3;
+ HISTOGRAM = 4;
+ }
+
+ // Defines the plot type to provide downstream.
+ VisualizationType plot_type = 1;
+
+ // The x-axis which to provide for the given metric. Must be the name of a
+ // metric or counter. Recommended x_axis options are source_round, round,
+ // or time.
+ string x_axis = 2;
+
+ // Iff True, metric will be displayed on a population level dashboard.
+ bool plot_on_population_dashboard = 3;
+}
+
+// A metric representing one stat with aggregation type sum.
+message SumOptions {
+ // Name for corresponding Metric stat_name field.
+ string stat_name = 1;
+
+ // Iff True, a cumulative sum over rounds will be provided in addition to a
+ // sum per round for the value metric.
+ bool include_cumulative_sum = 2;
+
+ // Iff True, sample of at most 101 clients' values.
+ // Used to calculate quantiles in downstream visualization pipeline.
+ bool include_client_samples = 3;
+}
+
+// A metric representing a ratio between metrics with aggregation type average.
+// Represents: numerator stat / denominator stat.
+message AverageOptions {
+ // Numerator stat name pointing to corresponding Metric stat_name.
+ string numerator_stat_name = 1;
+
+ // Denominator stat name pointing to corresponding Metric stat_name.
+ string denominator_stat_name = 2;
+
+ // Name for corresponding Metric stat_name that is the ratio of the
+ // numerator stat / denominator stat.
+ string average_stat_name = 3;
+
+ // Iff True, sample of at most 101 client's values.
+ // Used to calculate quantiles in downstream visualization pipeline.
+ bool include_client_samples = 4;
+}
+
+// A metric representing one stat with aggregation type none.
+message NoneOptions {
+ // Name for corresponding Metric stat_name field.
+ string stat_name = 1;
+}
+
+// A metric representing one stat with aggregation type only sample.
+message OnlySampleOptions {
+ // Name for corresponding Metric stat_name field.
+ string stat_name = 1;
+}
+
+// Represents a data set. This is used for testing.
+message Dataset {
+ // Represents the data set for one client.
+ message ClientDataset {
+ // A string identifying the client.
+ string client_id = 1;
+
+ // A list of serialized tf.Example protos.
+ repeated bytes example = 2;
+
+ // Represents a dataset whose examples are selected by an ExampleSelector.
+ message SelectedExample {
+ ExampleSelector selector = 1;
+ repeated bytes example = 2;
+ }
+
+ // A list of (selector, dataset) pairs. Used in testing some *TFF-based
+ // tasks* that require multiple datasets as client input, e.g., a TFF-based
+ // personalization eval task requires each client to provide at least two
+ // datasets: one for train, and the other for test.
+ repeated SelectedExample selected_example = 3;
+ }
+
+ // A list of client data.
+ repeated ClientDataset client_data = 1;
+}
+
+// Represents predicates over metrics - i.e., expectations. This is used in
+// training/eval tests to encode metric names and values expected to be reported
+// by a client execution.
+message MetricTestPredicates {
+ // The value must lie in [lower_bound; upper_bound]. Can also be used for
+ // approximate matching (lower == value - epsilon; upper = value + epsilon).
+ message Interval {
+ double lower_bound = 1;
+ double upper_bound = 2;
+ }
+
+ // The value must be a real value as long as the value of the weight_name
+ // metric is non-zero. If the weight metric is zero, then it is acceptable for
+ // the value to be non-real.
+ message RealIfNonzeroWeight {
+ string weight_name = 1;
+ }
+
+ message MetricCriterion {
+ // Name of the metric.
+ string name = 1;
+
+ // FL training round this metric is expected to appear in.
+ int32 training_round_index = 2;
+
+ // If none of the following is set, no matching is performed; but the
+ // metric is still expected to be present (with whatever value).
+ oneof Criterion {
+ // The reported metric must be < lt.
+ float lt = 3;
+ // The reported metric must be > gt.
+ float gt = 4;
+ // The reported metric must be <= le.
+ float le = 5;
+ // The reported metric must be >= ge.
+ float ge = 6;
+ // The reported metric must be == eq.
+ float eq = 7;
+ // The reported metric must lie in the interval.
+ Interval interval = 8;
+ // The reported metric is not NaN or +/- infinity.
+ bool real = 9;
+ // The reported metric is real (i.e., not NaN or +/- infinity) if the
+ // value of an associated weight is not 0.
+ RealIfNonzeroWeight real_if_nonzero_weight = 10;
+ }
+ }
+
+ repeated MetricCriterion metric_criterion = 1;
+
+ reserved 2;
+}
+
+// Client Phase
+// ============
+
+// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
+// In federated optimization, this will correspond to one `ServerPhase`.
+message ClientPhase {
+ // A short CamelCase name for the ClientPhase.
+ string name = 2;
+
+ // Minimum number of clients in aggregation.
+ // In secure aggregation mode this is used to configure the protocol instance
+ // in a way that server can't learn aggregated values with number of
+ // participants lower than this number.
+ // Without secure aggregation server still respects this parameter,
+ // ensuring that aggregated values never leave server RAM unless they include
+ // data from (at least) specified number of participants.
+ int32 minimum_number_of_participants = 3;
+
+ // If populated, `io_router` must be specified.
+ oneof spec {
+ // A functional interface for the TensorFlow logic the client should
+ // perform.
+ TensorflowSpec tensorflow_spec = 4 [lazy = true];
+ // Spec for client plans that issue example queries and send the query
+ // results directly to an aggregator with no or little additional
+ // processing.
+ ExampleQuerySpec example_query_spec = 9 [lazy = true];
+ }
+
+ // The specification of the inputs coming either from customer apps
+ // (Local Compute) or the federated protocol (Federated Compute).
+ oneof io_router {
+ FederatedComputeIORouter federated_compute = 5 [lazy = true];
+ LocalComputeIORouter local_compute = 6 [lazy = true];
+ FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
+ [lazy = true];
+ FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
+ }
+
+ reserved 1;
+}
+
+// TensorflowSpec message describes a single call into TensorFlow, including the
+// expected input tensors that must be fed when making that call, which
+// output tensors to be fetched, and any operations that have no output but must
+// be run. The TensorFlow session will then use the input tensors to do some
+// computation, generally reading from one or more datasets, and provide some
+// outputs.
+//
+// Conceptually, client or server code uses this proto along with an IORouter
+// to build maps of names to input tensors, vectors of output tensor names,
+// and vectors of target nodes:
+//
+// CreateTensorflowArguments(
+// TensorflowSpec& spec,
+// IORouter& io_router,
+// const vector<pair<string, Tensor>>* input_tensors,
+// const vector<string>* output_tensor_names,
+// const vector<string>* target_node_names);
+//
+// Where `input_tensor`, `output_tensor_names` and `target_node_names`
+// correspond to the arguments of TensorFlow C++ API for
+// `tensorflow::Session:Run()`, and the client executes only a single
+// invocation.
+//
+// Note: the execution engine never sees any concepts related to the federated
+// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
+// in, tensors out" interface. New aggregation methods can be added without
+// having to modify the execution engine / TensorflowSpec message, instead they
+// should modify the IORouter messages.
+//
+// Note: both `input_tensor_specs` and `output_tensor_specs` are full
+// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
+// only requires the names to feed the values into the session. The additional
+// dtypes/shape information must always be included in case the runtime
+// executing this TensorflowSpec wants to perform additional, optional static
+// assertions. The runtimes however are free to ignore the dtype/shapes and only
+// rely on the names if so desired.
+//
+// Assertions:
+// - all names in `input_tensor_specs`, `output_tensor_specs`, and
+// `target_node_names` must appear in the serialized GraphDef where
+// the TF execution will be invoked.
+// - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
+// there is nothing to execute in the graph.
+message TensorflowSpec {
+ // The name of a tensor into which a unique token for the current session
+ // should be written. The corresponding tensor is a scalar string tensor and
+ // is separate from `input_tensors` as there is only one.
+ //
+ // A session token allows TensorFlow ops such as `ServeSlices` or
+ // `ExternalDataset` to refer to callbacks and other session-global objects
+ // registered before running the session. In the `ExternalDataset` case, a
+ // single dataset_token is valid for multiple `tf.data.Dataset` objects as
+ // the token can be thought of as a handle to a dataset factory.
+ string dataset_token_tensor_name = 1;
+
+ // TensorSpecs of inputs which will be passed to TF.
+ //
+ // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
+ // TensorFlow's Python API, excluding the dataset_token listed above.
+ //
+ // Assertions:
+ // - All the tensor names designated as inputs in the corresponding IORouter
+ // must be listed (otherwise the IORouter input work is unused).
+ // - All placeholders in the TF graph must be listed here, with the
+ // exception of the dataset_token which is explicitly set above (otherwise
+ // TensorFlow will fail to execute).
+ repeated tensorflow.TensorSpecProto input_tensor_specs = 2;
+
+ // TensorSpecs that should be fetched from TF after execution.
+ //
+ // Corresponds to the `fetches` parameter of `tf.Session.run()` in
+ // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
+ // API.
+ //
+ // Assertions:
+ // - The set of tensor names here must strictly match the tensor names
+ // designated as outputs in the corresponding IORouter (if any exist).
+ repeated tensorflow.TensorSpecProto output_tensor_specs = 3;
+
+ // Node names in the graph that should be executed, but the output not
+ // returned.
+ //
+ // Corresponds to the `fetches` parameter of `tf.Session.run()` in
+ // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
+ // API.
+ //
+ // This is intended for use with operations that do not produce tensors, but
+ // nonetheless are required to run (e.g. serializing checkpoints).
+ repeated string target_node_names = 4;
+
+ // Map of Tensor names to constant inputs.
+ // Note: tensors specified via this message should not be included in
+ // input_tensor_specs.
+ map<string, tensorflow.TensorProto> constant_inputs = 5;
+}
+
+// ExampleQuerySpec message describes client execution that issues example
+// queries and sends the query results directly to an aggregator with no or
+// little additional processing.
+// This message describes one or more example store queries that perform the
+// client side analytics computation in C++. The corresponding output vectors
+// will be converted into the expected federated protocol output format.
+// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
+message ExampleQuerySpec {
+ message OutputVectorSpec {
+ // The output vector name.
+ string vector_name = 1;
+
+ // Supported data types for the vector of information.
+ enum DataType {
+ UNSPECIFIED = 0;
+ INT32 = 1;
+ INT64 = 2;
+ BOOL = 3;
+ FLOAT = 4;
+ DOUBLE = 5;
+ BYTES = 6;
+ STRING = 7;
+ }
+
+ // The data type for each entry in the vector.
+ DataType data_type = 2;
+ }
+
+ message ExampleQuery {
+ // The `ExampleSelector` to issue the query with.
+ ExampleSelector example_selector = 1;
+
+ // Indicates that the query returns vector data and must return a single
+ // ExampleQueryResult result containing a VectorData entry matching each
+ // OutputVectorSpec.vector_name.
+ //
+ // If the query instead returns no result, then it will be treated as is if
+ // an error was returned. In that case, or if the query explicitly returns
+ // an error, then the client will abort its session.
+ //
+ // The keys in the map are the names the vectors should be aggregated under,
+ // and must match the keys in FederatedExampleQueryIORouter.aggregations.
+ map<string, OutputVectorSpec> output_vector_specs = 2;
+ }
+
+ // The queries to run.
+ repeated ExampleQuery example_queries = 1;
+}
+
+// The input and output router for Federated Compute plans.
+//
+// This proto is the glue between the federated protocol and the TensorFlow
+// execution engine. This message describes how to prepare data coming from the
+// incoming `CheckinResponse` (defined in
+// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
+// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
+// the server).
+//
+// TODO(team) we could replace `input_checkpoint_file_tensor_name` with
+// an `input_tensors` field, which would then be a tensor that contains the
+// input TensorProtos directly and skipping disk I/O, rather than referring to a
+// checkpoint file path.
+message FederatedComputeIORouter {
+ // ===========================================================================
+ // Inputs
+ // ===========================================================================
+ // The name of the scalar string tensor that is fed the file path to the
+ // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
+ //
+ // The federated protocol code would copy the `CheckinResponse`'s initial
+ // checkpoint to a temporary file and then pass that file path through this
+ // tensor.
+ //
+ // Ops may be added to the client graph that take this tensor as input and
+ // reads the path.
+ //
+ // This field is optional. It may be omitted if the client graph does not use
+ // an initial checkpoint.
+ string input_filepath_tensor_name = 1;
+
+ // The name of the scalar string tensor that is fed the file path to which
+ // client work should serialize the bytes to send back to the server.
+ //
+ // The federated protocol code generates a temporary file and passes the file
+ // path through this tensor.
+ //
+ // Ops may be be added to the client graph that use this tensor as an argument
+ // to write files (e.g. writing checkpoints to disk).
+ //
+ // This field is optional. It must be omitted if the client graph does not
+ // generate any output files (e.g. when all output tensors of `TensorflowSpec`
+ // use Secure Aggregation). If this field is not set, then the `ReportRequest`
+ // message in the federated protocol will not have the
+ // `Report.update_checkpoint` field set. This absence of a value here can be
+ // used to validate that the plan only uses Secure Aggregation.
+ //
+ // Conversely, if this field is set and executing the associated
+ // TensorflowSpec does not write to the path is indication of an internal
+ // framework error. The runtime should notify the caller that the computation
+ // was setup incorrectly.
+ string output_filepath_tensor_name = 2;
+
+ // ===========================================================================
+ // Outputs
+ // ===========================================================================
+ // Describes which output tensors should be aggregated using an aggregation
+ // protocol, and the configuration for those protocols.
+ //
+ // Assertions:
+ // - All keys must exist in the associated `TensorflowSpec` as
+ // `output_tensor_specs.name` values.
+ map<string, AggregationConfig> aggregations = 3;
+}
+
+// The input and output router for client plans that do not use TensorFlow.
+//
+// This proto is the glue between the federated protocol and the example query
+// execution engine, describing how the query results should ultimately be
+// aggregated.
+message FederatedExampleQueryIORouter {
+ // Describes how each output vector should be aggregated using an aggregation
+ // protocol, and the configuration for those protocols.
+ // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
+ // Note that currently only the TFV1CheckpointAggregation config is supported.
+ map<string, AggregationConfig> aggregations = 1;
+}
+
+// The specification for how to aggregate the associated tensor across clients
+// on the server.
+message AggregationConfig {
+ oneof protocol_config {
+ // Indicates that the given output tensor should be processed using Secure
+ // Aggregation, using the specified config options.
+ SecureAggregationConfig secure_aggregation = 2;
+
+ // Note: in the future we could add a `SimpleAggregationConfig` to add
+ // support for simple aggregation without writing to an intermediate
+ // checkpoint file first.
+
+ // Indicates that the given output tensor or vector (e.g. as produced by an
+ // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
+ //
+ // Currently only ExampleQuerySpec output vectors are supported by this
+ // aggregation type (i.e. it cannot be used with TensorflowSpec output
+ // tensors). The vectors will be stored in the checkpoint as a 1-D Tensor of
+ // its corresponding data type.
+ TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
+ }
+}
+
+// Parameters for the SecAgg protocol (go/secagg).
+//
+// Currently only the server uses the SecAgg parameters, so we only use this
+// message to signify usage of SecAgg.
+message SecureAggregationConfig {}
+
+// Parameters for the TFV1 Checkpoint Aggregation protocol.
+//
+// Currently only ExampleQuerySpec output vectors are supported by this
+// aggregation type (i.e. it cannot be used with TensorflowSpec output
+// tensors). The vectors will be stored in the checkpoint as a 1-D Tensor of
+// its corresponding data type.
+message TFV1CheckpointAggregation {}
+
+// The input and output router for eligibility-computing plans. These plans
+// compute which other plans a client is eligible to run, and are returned by
+// clients via a `EligibilityEvalCheckinResponse` (defined in
+// fcp/protos/federated_api.proto).
+message FederatedComputeEligibilityIORouter {
+ // The name of the scalar string tensor that is fed the file path to the
+ // initial checkpoint (e.g. as provided via
+ // `EligibilityEvalPayload.init_checkpoint`).
+ //
+ // For more detail see the
+ // `FederatedComputeIoRouter.input_filepath_tensor_name`, which has the same
+ // semantics.
+ //
+ // This field is optional. It may be omitted if the client graph does not use
+ // an initial checkpoint.
+ //
+ // This tensor name must exist in the associated
+ // `TensorflowSpec.input_tensor_specs` list.
+ string input_filepath_tensor_name = 1;
+
+ // Name of the output tensor (a string scalar) containing the serialized
+ // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
+ // client code will parse this proto and place it in the
+ // `task_eligibility_info` field of the subsequent `CheckinRequest`.
+ //
+ // This tensor name must exist in the associated
+ // `TensorflowSpec.output_tensor_specs` list.
+ string task_eligibility_info_tensor_name = 2;
+}
+
+// The input and output router for Local Compute plans.
+//
+// This proto is the glue between the customers app and the TensorFlow
+// execution engine. This message describes how to prepare data coming from the
+// customer app (e.g. the input directory the app setup), and the temporary,
+// scratch output directory that will be notified to the customer app upon
+// completion of `TensorflowSpec`.
+message LocalComputeIORouter {
+ // ===========================================================================
+ // Inputs
+ // ===========================================================================
+ // The name of the placeholder tensor representing the input resource path(s).
+ // It can be a single input directory or file path (in this case the
+ // `input_dir_tensor_name` is populated) or multiple input resources
+ // represented as a map from names to input directories or file paths (in this
+ // case the `multiple_input_resources` is populated).
+ //
+ // In the multiple input resources case, the placeholder tensors are
+ // represented as a map: the keys are the input resource names defined by the
+ // users when constructing the `LocalComputation` Python object, and the
+ // values are the corresponding placeholder tensor names created by the local
+ // computation plan builder.
+ //
+ // Apps will have the ability to create contracts between their Android code
+ // and `LocalComputation` toolkit code to place files inside the input
+ // resource paths with known names (Android code) and create graphs with ops
+ // to read from these paths (file names can be specified in toolkit code).
+ oneof input_resource {
+ string input_dir_tensor_name = 1;
+ // Directly using the `map` field is not allowed in `oneof`, so we have to
+ // wrap it in a new message.
+ MultipleInputResources multiple_input_resources = 3;
+ }
+
+ // Scalar string tensor name that will contain the output directory path.
+ //
+ // The provided directory should be considered temporary scratch that will be
+ // deleted, not persisted. It is the responsibility of the calling app to
+ // move the desired files to a permanent location once the client returns this
+ // directory back to the calling app.
+ string output_dir_tensor_name = 2;
+
+ // ===========================================================================
+ // Outputs
+ // ===========================================================================
+ // NOTE: LocalCompute has no outputs other than what the client graph writes
+ // to `output_dir` specified above.
+}
+
+// Describes the multiple input resources in `LocalComputeIORouter`.
+message MultipleInputResources {
+ // The keys are the input resource names (defined by the users when
+ // constructing the `LocalComputation` Python object), and the values are the
+ // corresponding placeholder tensor names created by the local computation
+ // plan builder.
+ map<string, string> input_resource_tensor_name_map = 1;
+}
+
+// Describes a queue to which input is fed.
+message AsyncInputFeed {
+ // The op for enqueuing an example input.
+ string enqueue_op = 1;
+
+ // The input placeholders for the enqueue op.
+ repeated string enqueue_params = 2;
+
+ // The op for closing the input queue.
+ string close_op = 3;
+
+ // Whether the work that should be fed asynchronously is the data itself
+ // or a description of where that data lives.
+ bool feed_values_are_data = 4;
+}
+
+message DatasetInput {
+ // Initializer of iterator corresponding to tf.data.Dataset object which
+ // handles the input data. Stores name of an op in the graph.
+ string initializer = 1;
+
+ // Placeholders necessary to initialize the dataset.
+ DatasetInputPlaceholders placeholders = 2;
+
+ // Batch size to be used in tf.data.Dataset.
+ int32 batch_size = 3;
+}
+
+message DatasetInputPlaceholders {
+ // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
+ // from.
+ string filename = 1;
+
+ // Name of placeholder corresponding to key_prefix initializing the
+ // SSTableDataset. Note the value fed should be unique user id, not a prefix.
+ string key_prefix = 2;
+
+ // Name of placeholder corresponding to number of rounds the local training
+ // should be run for.
+ string num_epochs = 3;
+
+ // Name of placeholder corresponding to batch size.
+ string batch_size = 4;
+}
+
+// Specifies an example selection procedure.
+message ExampleSelector {
+ // Selection criteria following a contract agreed upon between client and
+ // model designers.
+ google.protobuf.Any criteria = 1;
+
+ // A URI identifying the example collection to read from. Format should adhere
+ // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments
+ // should adhere to the following rules:
+ // - The scheme ${COLLECTION} should be one of:
+ // - "app" for app-hosted example
+ // - "simulation" for collections not connected to an app (e.g., if used
+ // purely for simulation)
+ // - The authority ${APP_NAME} identifies the owner of the example
+ // collection and should be either the app's package name, or be left empty
+ // (which means "the current app package name").
+ // - The path ${COLLECTION_NAME} can be any valid URI path. NB It starts with
+ // a forward slash ("/").
+ // - The query and fragment are currently not used, but they may become used
+ // for something in the future. To keep open that possibility they must
+ // currently be left empty.
+ //
+ // Example: "app://com.google.some.app/someCollection/name"
+ // identifies the collection "/someCollection/name" owned and hosted by the
+ // app with package name "com.google.some.app".
+ //
+ // Example: "app:/someCollection/name" or "app:///someCollection/name"
+ // both identify the collection "/someCollection/name" owned and hosted by the
+ // app associated with the training job in which this URI appears.
+ //
+ // The path will not be interpreted by the runtime, and will be passed to the
+ // example collection implementation for interpretation. Thus, in the case of
+ // app-hosted example stores, the path segment's interpretation is a contract
+ // between the app's example store developers, and the app's model designers.
+ //
+ // If an `app://` URI is set, then the `TrainerOptions` collection name must
+ // not be set.
+ string collection_uri = 2;
+
+ // Resumption token following a contract agreed upon between client and
+ // model designers.
+ google.protobuf.Any resumption_token = 3;
+}
+
+// Selector for slices to fetch as part of a `federated_select` operation.
+message SlicesSelector {
+ // The string ID under which the slices are served.
+ //
+ // This value must have been returned by a previous call to the `serve_slices`
+ // op run during the `write_client_init` operation.
+ string served_at_id = 1;
+
+ // The indices of slices to fetch.
+ repeated int32 keys = 2;
+}
+
+// Represents slice data to be served as part of a `federated_select` operation.
+// This is used for testing.
+message SlicesTestDataset {
+ // The test data to use. The keys map to the `SlicesSelector.served_at_id`
+ // field. E.g. test slice data for a slice with `served_at_id`="foo" and
+ // `keys`=2 would be store in `dataset["foo"].slice_data[2]`.
+ map<string, SlicesTestData> dataset = 1;
+}
+message SlicesTestData {
+ // The test slice data to serve. Each entry's index corresponds to the slice
+ // key it is the test data for.
+ repeated bytes slice_data = 2;
+}
+
+// Server Phase V2
+// ===============
+
+// Represents a server phase with three distinct components: pre-broadcast,
+// aggregation, and post-aggregation.
+//
+// The pre-broadcast and post-aggregation components are described with
+// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
+// messages, respectively. These messages in combination with the server
+// IORouter messages specify how to set up a single TF sess.run call for each
+// component.
+//
+// The pre-broadcast logic is obtained by transforming the server_prepare TFF
+// computation in the DistributeAggregateForm. It takes the server state as
+// input, and it generates the checkpoint to broadcast to the clients and
+// potentially an intermediate server state. The intermediate server state may
+// be used by the aggregation and post-aggregation logic.
+//
+// The aggregation logic represents the aggregation of client results at the
+// server and is described using a list of ServerAggregationConfig messages.
+// Each ServerAggregationConfig message describes a single aggregation operation
+// on a set of input/output tensors. The input tensors may represent parts of
+// either the client results or the intermediate server state. These messages
+// are obtained by transforming the client_to_server_aggregation TFF computation
+// in the DistributeAggregateForm.
+//
+// The post-aggregation logic is obtained by transforming the server_result TFF
+// computation in the DistributeAggregateForm. It takes the intermediate server
+// state and the aggregated client results as input, and it generates the new
+// server state and potentially other server-side output.
+//
+// Note that while a ServerPhaseV2 message can be generated for all types of
+// intrinsics, it is currently only compatible with the ClientPhase message if
+// the aggregations being used are exclusively federated_sum (not SecAgg). If
+// this compatibility requirement is satisfied, it is also valid to run the
+// aggregation portion of this ServerPhaseV2 message alongside the pre- and
+// post-aggregation logic from the original ServerPhase message. Ultimately,
+// we expect the full ServerPhaseV2 message to be run and the ServerPhase
+// message to be deprecated.
+message ServerPhaseV2 {
+ // A short CamelCase name for the ServerPhaseV2.
+ string name = 1;
+
+ // A functional interface for the TensorFlow logic the server should perform
+ // prior to the server-to-client broadcast. This should be used with the
+ // TensorFlow graph defined in server_graph_prepare_bytes.
+ TensorflowSpec tensorflow_spec_prepare = 3;
+
+ // The specification of inputs needed by the server_prepare TF logic.
+ oneof server_prepare_io_router {
+ ServerPrepareIORouter prepare_router = 4;
+ }
+
+ // A list of client-to-server aggregations to perform.
+ repeated ServerAggregationConfig aggregations = 2;
+
+ // A functional interface for the TensorFlow logic the server should perform
+ // post-aggregation. This should be used with the TensorFlow graph defined
+ // in server_graph_result_bytes.
+ TensorflowSpec tensorflow_spec_result = 5;
+
+ // The specification of inputs and outputs needed by the server_result TF
+ // logic.
+ oneof server_result_io_router {
+ ServerResultIORouter result_router = 6;
+ }
+}
+
+// Routing for server_prepare graph
+message ServerPrepareIORouter {
+ // The name of the scalar string tensor in the server_prepare TF graph that
+ // is fed the filepath to the initial server state checkpoint. The
+ // server_prepare logic reads from this filepath.
+ string prepare_server_state_input_filepath_tensor_name = 1;
+
+ // The name of the scalar string tensor in the server_prepare TF graph that
+ // is fed the filepath where the client checkpoint should be stored. The
+ // server_prepare logic writes to this filepath.
+ string prepare_output_filepath_tensor_name = 2;
+
+ // The name of the scalar string tensor in the server_prepare TF graph that
+ // is fed the filepath where the intermediate state checkpoint should be
+ // stored. The server_prepare logic writes to this filepath. The intermediate
+ // state checkpoint will be consumed by both the logic used to set parameters
+ // for aggregation and the post-aggregation logic.
+ string prepare_intermediate_state_output_filepath_tensor_name = 3;
+}
+
+// Routing for server_result graph
+message ServerResultIORouter {
+ // The name of the scalar string tensor in the server_result TF graph that is
+ // fed the filepath to the intermediate state checkpoint. The server_result
+ // logic reads from this filepath.
+ string result_intermediate_state_input_filepath_tensor_name = 1;
+
+ // The name of the scalar string tensor in the server_result TF graph that is
+ // fed the filepath to the aggregated client result checkpoint. The
+ // server_result logic reads from this filepath.
+ string result_aggregate_result_input_filepath_tensor_name = 2;
+
+ // The name of the scalar string tensor in the server_result TF graph that is
+ // fed the filepath where the updated server state should be stored. The
+ // server_result logic writes to this filepath.
+ string result_server_state_output_filepath_tensor_name = 3;
+}
+
+// Represents a single aggregation operation, combining one or more input
+// tensors from a collection of clients into one or more output tensors on the
+// server.
+message ServerAggregationConfig {
+ // The uri of the aggregation intrinsic (e.g. 'federated_sum').
+ string intrinsic_uri = 1;
+
+ // Describes an argument to the aggregation operation.
+ message IntrinsicArg {
+ oneof arg {
+ // Refers to a tensor within the checkpoint provided by each client.
+ tensorflow.TensorSpecProto input_tensor = 2;
+
+ // Refers to a tensor within the intermediate server state checkpoint.
+ tensorflow.TensorSpecProto state_tensor = 3;
+ }
+ }
+
+ // List of arguments for the aggregation operation. The arguments can be
+ // dependent on client data (in which case they must be retrieved from
+ // clients) or they can be independent of client data (in which case they
+ // can be configured server-side). For now we assume all client-independent
+ // arguments are constants. The arguments must be in the order expected by
+ // the server.
+ repeated IntrinsicArg intrinsic_args = 4;
+
+ // List of server-side outputs produced by the aggregation operation.
+ repeated tensorflow.TensorSpecProto output_tensors = 5;
+
+ // List of inner aggregation intrinsics. This can be used to delegate parts
+ // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
+ // a sum operation to a sum intrinsic).
+ repeated ServerAggregationConfig inner_aggregations = 6;
+}
+
+// Server Phase
+// ============
+
+// Represents a server phase which implements TF-based aggregation of multiple
+// client updates.
+//
+// There are two different modes of aggregation that are described
+// by the values in this message. The first is aggregation that is
+// coming from coordinated sets of clients. This includes aggregation
+// done via checkpoints from clients or aggregation done over a set
+// of clients by a process like secure aggregation. The results of
+// this first aggregation are saved to intermediate aggregation
+// checkpoints. The second aggregation then comes from taking
+// these intermediate checkpoints and aggregating over them.
+//
+// These two different modes of aggregation are done on different
+// servers, the first in the 'L1' servers and the second in the
+// 'L2' servers, so we use this nomenclature to describe these
+// phases below.
+//
+// The ServerPhase message is currently in the process of being replaced by the
+// ServerPhaseV2 message as we switch the plan building pipeline to use
+// DistributeAggregateForm instead of MapReduceForm. During the migration
+// process, we may generate both messages and use components from either
+// message during execution.
+//
+message ServerPhase {
+ // A short CamelCase name for the ServerPhase.
+ string name = 8;
+
+ // ===========================================================================
+ // L1 "Intermediate" Aggregation.
+ //
+ // This is the initial aggregation that creates partial aggregates from client
+ // results. L1 Aggregation may be run on many different instances.
+ //
+ // Pre-condition:
+ // The execution environment has loaded the graph from `server_graph_bytes`.
+
+ // 1. Initialize the phase.
+ //
+ // Operation to run before the first aggregation happens.
+ // For instance, clears the accumulators so that a new aggregation can begin.
+ string phase_init_op = 1;
+
+ // 2. For each client in set of clients:
+ // a. Restore variables from the client checkpoint.
+ //
+ // Loads a checkpoint from a single client written via
+ // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
+ // for every client checkpoint in a round.
+ CheckpointOp read_update = 3;
+ // b. Aggregate the data coming from the client checkpoint.
+ //
+ // An operation that aggregates the data from read_update.
+ // Generally this will add to accumulators and it may leverage internal data
+ // inside the graph to adjust the weights of the Tensors.
+ //
+ // Executed once for each `read_update`, to (for example) update accumulator
+ // variables using the values loaded during `read_update`.
+ string aggregate_into_accumulators_op = 4;
+
+ // 3. After all clients have been aggregated, possibly restore
+ // variables that have been aggregated via a separate process.
+ //
+ // Optionally restores variables where aggregation is done across
+ // an entire round of client data updates. In contrast to `read_update`,
+ // which restores once per client, this occurs after all clients
+ // in a round have been processed. This allows, for example, side
+ // channels where aggregation is done by a separate process (such
+ // as in secure aggregation), in which the side channel aggregated
+ // tensor is passed to the `before_restore_op` which ensure the
+ // variables are restored properly. The `after_restore_op` will then
+ // be responsible for performing the accumulation.
+ //
+ // Note that in current use this should not have a SaverDef, but
+ // should only be used for side channels.
+ CheckpointOp read_aggregated_update = 10;
+
+ // 4. Write the aggregated variables to an intermediate checkpoint.
+ //
+ // We require that `aggregate_into_accumulators_op` is associative and
+ // commutative, so that the aggregates can be computed across
+ // multiple TensorFlow sessions.
+ // As an example, say we are computing the sum of 5 client updates:
+ // A = X1 + X2 + X3 + X4 + X5
+ // We can always do this in one session by calling `read_update`j and
+ // `aggregate_into_accumulators_op` once for each client checkpoint.
+ //
+ // Alternatively, we could compute:
+ // A1 = X1 + X2 in one TensorFlow session, and
+ // A2 = X3 + X4 + X5 in a different session.
+ // Each of these sessions can then write their accumulator state
+ // with the `write_intermediate_update` CheckpointOp, and a yet another third
+ // session can then call `read_intermediate_update` and
+ // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
+ // A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
+ CheckpointOp write_intermediate_update = 7;
+ // End L1 "Intermediate" Aggregation.
+ // ===========================================================================
+
+ // ===========================================================================
+ // L2 Aggregation and Coordinator.
+ //
+ // This aggregates intermediate checkpoints from L1 Aggregation and performs
+ // the finalizing of the update. Unlike L1 there will only be one instance
+ // that does this aggregation.
+
+ // Pre-condition:
+ // The execution environment has loaded the graph from `server_graph_bytes`
+ // and restored the global model using `server_savepoint` from the parent
+ // `Plan` message.
+
+ // 1. Initialize the phase.
+ //
+ // This currently re-uses the `phase_init_op` from L1 aggregation above.
+
+ // 2. Write a checkpoint that can be sent to the client.
+ //
+ // Generates a checkpoint to be sent to the client, to be read by
+ // `FederatedComputeIORouter.input_filepath_tensor_name`.
+
+ CheckpointOp write_client_init = 2;
+
+ // 3. For each intermediate checkpoint:
+ // a. Restore variables from the intermediate checkpoint.
+ //
+ // The corresponding read checkpoint op to the write_intermediate_update.
+ // This is used instead of read_update for intermediate checkpoints because
+ // the format of these updates may be different than those used in updates
+ // from clients (which may, for example, be compressed).
+ CheckpointOp read_intermediate_update = 9;
+ // b. Aggregate the data coming from the intermediate checkpoint.
+ //
+ // An operation that aggregates the data from `read_intermediate_update`.
+ // Generally this will add to accumulators and it may leverage internal data
+ // inside the graph to adjust the weights of the Tensors.
+ string intermediate_aggregate_into_accumulators_op = 11;
+
+ // 4. Write the aggregated intermediate variables to a checkpoint.
+ //
+ // This is used for downstream, cross-round aggregation of metrics.
+ // These variables will be read back into a session with
+ // read_intermediate_update.
+ //
+ // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
+ // to disable writing accumulator checkpoints.
+ CheckpointOp write_accumulators = 12;
+
+ // 5. Finalize the round.
+ //
+ // This can include:
+ // - Applying the update aggregated from the intermediate checkpoints to the
+ // global model and other updates to cross-round state variables.
+ // - Computing final round metric values (e.g. the `report` of a
+ // `tff.federated_aggregate`).
+ string apply_aggregrated_updates_op = 5;
+
+ // 5. Fetch the server aggregated metrics.
+ //
+ // A list of names of metric variables to fetch from the TensorFlow session.
+ repeated Metric metrics = 6;
+
+ // 6. Serialize the updated server state (e.g. the coefficients of the global
+ // model in FL) using `server_savepoint` in the parent `Plan` message.
+
+ // End L2 Aggregation.
+ // ===========================================================================
+}
+
+// Represents the server phase in an eligibility computation.
+//
+// This phase produces a checkpoint to be sent to clients. This checkpoint is
+// then used as an input to the clients' task eligibility computations.
+// This phase *does not include any aggregation.*
+message ServerEligibilityComputationPhase {
+ // A short CamelCase name for the ServerEligibilityComputationPhase.
+ string name = 1;
+
+ // The names of the TensorFlow nodes to run in order to produce output.
+ repeated string target_node_names = 2;
+
+ // The specification of inputs and outputs to the TensorFlow graph.
+ oneof server_eligibility_io_router {
+ TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
+ }
+}
+
+// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
+// which takes a single `TaskEligibilityContext` as input.
+message TEContextServerEligibilityIORouter {
+ // The name of the scalar string tensor that must be fed a serialized
+ // `TaskEligibilityContext`.
+ string context_proto_input_tensor_name = 1;
+
+ // The name of the scalar string tensor that must be fed the path to which
+ // the server graph should write the checkpoint file to be sent to the client.
+ string output_filepath_tensor_name = 2;
+}
+
+// Plan
+// =====
+
+// Represents the overall plan for performing federated optimization or
+// personalization, as handed over to the production system. This will
+// typically be split down into individual pieces for different production
+// parts, e.g. server and client side.
+// NEXT_TAG: 15
+message Plan {
+ reserved 1, 3, 5;
+
+ // The actual type of the server_*_graph_bytes fields below is expected to be
+ // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
+ // for two reasons.
+ // 1) We may use execution engines other than TensorFlow.
+ // 2) We wish to avoid the cost of deserialized and re-serializing large
+ // graphs, in the Federated Learning service.
+
+ // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
+ // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
+ // If we're using a MapReduceForm-based server implementation, only
+ // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
+ // based server implementation, only server_graph_prepare_bytes and
+ // server_graph_result_bytes will be used.
+
+ // Optional. The TensorFlow graph used for all server processing described by
+ // ServerPhase. For personalization, this will not be set.
+ google.protobuf.Any server_graph_bytes = 7;
+
+ // Optional. The TensorFlow graph used for all server processing described by
+ // ServerPhaseV2.tensorflow_spec_prepare.
+ google.protobuf.Any server_graph_prepare_bytes = 13;
+
+ // Optional. The TensorFlow graph used for all server processing described by
+ // ServerPhaseV2.tensorflow_spec_result.
+ google.protobuf.Any server_graph_result_bytes = 14;
+
+ // A savepoint to sync the server checkpoint with a persistent
+ // storage system. The storage initially holds a seeded checkpoint
+ // which can subsequently read and updated by this savepoint.
+ // Optional-- not present in eligibility computation plans (those with a
+ // ServerEligibilityComputationPhase). This is used in conjunction with
+ // ServerPhase only.
+ CheckpointOp server_savepoint = 2;
+
+ // Required. The TensorFlow graph that describes the TensorFlow logic a client
+ // should perform. It should be consistent with the `TensorflowSpec` field in
+ // the `client_phase`. The actual type is expected to be tensorflow.GraphDef.
+ // The TensorFlow graph is stored in serialized form for two reasons.
+ // 1) We may use execution engines other than TensorFlow.
+ // 2) We wish to avoid the cost of deserialized and re-serializing large
+ // graphs, in the Federated Learning service.
+ google.protobuf.Any client_graph_bytes = 8;
+
+ // Optional. The FlatBuffer used for TFLite training.
+ // It contains the same model information as the client_graph_bytes, but with
+ // a different format.
+ bytes client_tflite_graph_bytes = 12;
+
+ // A pair of client phase and server phase which are processed in
+ // sync. The server execution defines how the results of a client
+ // phase are aggregated, and how the checkpoints for clients are
+ // generated.
+ message Phase {
+ // Required. The client phase.
+ ClientPhase client_phase = 1;
+
+ // Optional. Server phase for TF-based aggregation; not provided for
+ // personalization or eligibility tasks.
+ ServerPhase server_phase = 2;
+
+ // Optional. Server phase for native aggregation; only provided for tasks
+ // that have enabled the corresponding flag.
+ ServerPhaseV2 server_phase_v2 = 4;
+
+ // Optional. Only provided for eligibility tasks.
+ ServerEligibilityComputationPhase server_eligibility_phase = 3;
+ }
+
+ // A pair of client and server computations to run.
+ repeated Phase phase = 4;
+
+ // Metrics that are persistent across different phases. This
+ // includes, for example, counters that track how much work of
+ // different kinds has been done.
+ repeated Metric metrics = 6;
+
+ // Describes how metrics in both the client and server phases should be
+ // aggregated.
+ repeated OutputMetric output_metrics = 10;
+
+ // Version of the plan:
+ // version == 0 - Old plan without version field, containing b/65131070
+ // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
+ int32 version = 9;
+
+ // A TensorFlow ConfigProto packed in an Any.
+ //
+ // If this field is unset, if the Any proto is set but empty, or if the Any
+ // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
+ // set, but the `value` field is empty) then the client implementation may
+ // choose a set of configuration parameters to provide to TensorFlow by
+ // default.
+ //
+ // In all other cases this field must contain a valid packed ConfigProto
+ // (invalid values will result in an error at execution time), and in this
+ // case the client will not provide any other configuration parameters by
+ // default.
+ google.protobuf.Any tensorflow_config_proto = 11;
+}
+
+// Represents a client part of the plan of federated optimization.
+// This also used to describe a client-only plan for standalone on-device
+// training, known as personalization.
+// NEXT_TAG: 6
+message ClientOnlyPlan {
+ reserved 3;
+
+ // The graph to use for training, in binary form.
+ bytes graph = 1;
+
+ // Optional. The flatbuffer used for TFLite training.
+ // Whether "graph" or "tflite_graph" is used for training is up to the client
+ // code to allow for a flag-controlled a/b rollout.
+ bytes tflite_graph = 5;
+
+ // The client phase to execute.
+ ClientPhase phase = 2;
+
+ // A TensorFlow ConfigProto.
+ google.protobuf.Any tensorflow_config_proto = 4;
+}
+
+// Represents the cross round aggregation portion for user defined measurements.
+// This is used by tools that process / analyze accumulator checkpoints
+// after a round of computation, to achieve aggregation beyond a round.
+message CrossRoundAggregationExecution {
+ // Operation to run before reading accumulator checkpoint.
+ string init_op = 1;
+
+ // Reads accumulator checkpoint.
+ CheckpointOp read_aggregated_update = 2;
+
+ // Operation to merge loaded checkpoint into accumulator.
+ string merge_op = 3;
+
+ // Reads and writes the final aggregated accumulator vars.
+ CheckpointOp read_write_final_accumulators = 6;
+
+ // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
+ // to the user defined name of the signal.
+ repeated Measurement measurements = 4;
+
+ // The `tf.Graph` used for aggregating accumulator checkpoints when
+ // loading metrics.
+ google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
+}
+
+message Measurement {
+ // Name of a TensorFlow op to run to read/fetch the value of this measurement.
+ string read_op_name = 1;
+
+ // A human-readable name for the measurement. Names are usually
+ // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
+ string name = 2;
+
+ reserved 3;
+
+ // A serialized `tff.Type` for the measurement.
+ bytes tff_type = 4;
+}
diff --git a/fcp/protos/task_eligibility_context.proto b/fcp/protos/task_eligibility_context.proto
new file mode 100644
index 0000000..9a76520
--- /dev/null
+++ b/fcp/protos/task_eligibility_context.proto
@@ -0,0 +1,51 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package google.internal.federated.plan;
+
+option java_package = "com.google.internal.federated.plan";
+option java_multiple_files = true;
+option java_outer_classname = "TaskEligibilityContextProto";
+
+// Context provided to the server task eligibility computation.
+//
+// This context is provided to the server and is used to produce the checkpoint
+// that will be sent to the clients to aid the clients in computing
+// `TaskEligibilityInfo`.
+message TaskEligibilityContext {
+ // A list of information for each task currently being considered.
+ repeated SingleTaskEligibilityContext tasks = 1;
+}
+
+// Per-task context provided to the server eligibility computation.
+message SingleTaskEligibilityContext {
+ // Name of the task.
+ string task_name = 1;
+
+ // Information about a policy that should be applied to the task to determine
+ // if it's eligible to be run. For example, a "didnt_run_recently" policy
+ // could instruct the server task eligibility computation that the task should
+ // not be run again by a client that ran it recently, where "recently" is
+ // implementation defined.
+ message EligibilityPolicy {
+ // The name of the policy. The set of possible values and their
+ // interpretation is implementation defined.
+ string name = 1;
+ }
+
+ // The list of eligibility policies that should be applied to the task.
+ repeated EligibilityPolicy policies = 2;
+}
diff --git a/fcp/secagg/client/BUILD b/fcp/secagg/client/BUILD
new file mode 100644
index 0000000..b2e22ee
--- /dev/null
+++ b/fcp/secagg/client/BUILD
@@ -0,0 +1,113 @@
+# Description:
+# SecAgg client-specific components.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+cc_library(
+ name = "state_transition_listener",
+ hdrs = ["state_transition_listener_interface.h"],
+)
+
+cc_library(
+ name = "client",
+ srcs = [
+ "secagg_client.cc",
+ "secagg_client_aborted_state.cc",
+ "secagg_client_alive_base_state.cc",
+ "secagg_client_completed_state.cc",
+ "secagg_client_r0_advertise_keys_input_not_set_state.cc",
+ "secagg_client_r0_advertise_keys_input_set_state.cc",
+ "secagg_client_r1_share_keys_base_state.cc",
+ "secagg_client_r1_share_keys_input_not_set_state.cc",
+ "secagg_client_r1_share_keys_input_set_state.cc",
+ "secagg_client_r2_masked_input_coll_base_state.cc",
+ "secagg_client_r2_masked_input_coll_input_not_set_state.cc",
+ "secagg_client_r2_masked_input_coll_input_set_state.cc",
+ "secagg_client_r2_masked_input_coll_waiting_for_input_state.cc",
+ "secagg_client_r3_unmasking_state.cc",
+ "secagg_client_state.cc",
+ ],
+ hdrs = [
+ "other_client_state.h",
+ "secagg_client.h",
+ "secagg_client_aborted_state.h",
+ "secagg_client_alive_base_state.h",
+ "secagg_client_completed_state.h",
+ "secagg_client_r0_advertise_keys_input_not_set_state.h",
+ "secagg_client_r0_advertise_keys_input_set_state.h",
+ "secagg_client_r1_share_keys_base_state.h",
+ "secagg_client_r1_share_keys_input_not_set_state.h",
+ "secagg_client_r1_share_keys_input_set_state.h",
+ "secagg_client_r2_masked_input_coll_base_state.h",
+ "secagg_client_r2_masked_input_coll_input_not_set_state.h",
+ "secagg_client_r2_masked_input_coll_input_set_state.h",
+ "secagg_client_r2_masked_input_coll_waiting_for_input_state.h",
+ "secagg_client_r3_unmasking_state.h",
+ "secagg_client_state.h",
+ "send_to_server_interface.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":state_transition_listener",
+ "//fcp/base",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "client-test",
+ size = "small",
+ srcs = [
+ "secagg_client_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":client",
+ ":state_transition_listener",
+ "//fcp/base",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/secagg/testing:client_mocks",
+ "//fcp/secagg/testing:common_mocks",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "state-test",
+ size = "small",
+ srcs = [
+ "secagg_client_aborted_state_test.cc",
+ "secagg_client_completed_state_test.cc",
+ "secagg_client_r0_advertise_keys_input_not_set_state_test.cc",
+ "secagg_client_r0_advertise_keys_input_set_state_test.cc",
+ "secagg_client_r1_share_keys_input_not_set_state_test.cc",
+ "secagg_client_r1_share_keys_input_set_state_test.cc",
+ "secagg_client_r2_masked_input_coll_input_not_set_state_test.cc",
+ "secagg_client_r2_masked_input_coll_input_set_state_test.cc",
+ "secagg_client_r2_masked_input_coll_waiting_for_input_state_test.cc",
+ "secagg_client_r3_unmasking_state_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":client",
+ ":state_transition_listener",
+ "//fcp/base",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/secagg/testing:client_mocks",
+ "//fcp/secagg/testing:common_mocks",
+ "//fcp/testing",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/secagg/client/other_client_state.h b/fcp/secagg/client/other_client_state.h
new file mode 100644
index 0000000..86c7c70
--- /dev/null
+++ b/fcp/secagg/client/other_client_state.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_OTHER_CLIENT_STATE_H_
+#define FCP_SECAGG_CLIENT_OTHER_CLIENT_STATE_H_
+
+namespace fcp {
+namespace secagg {
+
+// Used by descendants of {@link SecAggClientState} to track the state of other
+// clients, from the perspective of this client.
+
+enum class OtherClientState {
+ kAlive,
+ kDeadAtRound1,
+ kDeadAtRound2,
+ kDeadAtRound3
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_OTHER_CLIENT_STATE_H_
diff --git a/fcp/secagg/client/secagg_client.cc b/fcp/secagg/client/secagg_client.cc
new file mode 100644
index 0000000..7032867
--- /dev/null
+++ b/fcp/secagg/client/secagg_client.cc
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/async_abort.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClient::SecAggClient(
+ int max_neighbors_expected,
+ int minimum_surviving_neighbors_for_reconstruction,
+ std::vector<InputVectorSpecification> input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ std::atomic<std::string*>* abort_signal_for_test)
+ : mu_(),
+ abort_signal_(nullptr),
+ async_abort_(abort_signal_for_test ? abort_signal_for_test
+ : &abort_signal_),
+ state_(std::make_unique<SecAggClientR0AdvertiseKeysInputNotSetState>(
+ max_neighbors_expected,
+ minimum_surviving_neighbors_for_reconstruction,
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ std::move(input_vector_specs)),
+ std::move(prng), std::move(sender), std::move(transition_listener),
+ std::move(prng_factory), &async_abort_)) {}
+
+Status SecAggClient::Start() {
+ absl::WriterMutexLock _(&mu_);
+ auto state_or_error = state_->Start();
+ if (state_or_error.ok()) {
+ state_ = std::move(state_or_error.value());
+ }
+ return state_or_error.status();
+}
+
+Status SecAggClient::Abort() { return Abort("unknown reason"); }
+
+Status SecAggClient::Abort(const std::string& reason) {
+ async_abort_.Abort(reason);
+ absl::WriterMutexLock _(&mu_);
+ if (state_->IsAborted() || state_->IsCompletedSuccessfully())
+ return FCP_STATUS(OK);
+
+ auto state_or_error = state_->Abort(reason);
+ if (state_or_error.ok()) {
+ state_ = std::move(state_or_error.value());
+ }
+ return state_or_error.status();
+}
+
+Status SecAggClient::SetInput(std::unique_ptr<SecAggVectorMap> input_map) {
+ absl::WriterMutexLock _(&mu_);
+ auto state_or_error = state_->SetInput(std::move(input_map));
+ if (state_or_error.ok()) {
+ state_ = std::move(state_or_error.value());
+ }
+ return state_or_error.status();
+}
+
+StatusOr<bool> SecAggClient::ReceiveMessage(
+ const ServerToClientWrapperMessage& incoming) {
+ absl::WriterMutexLock _(&mu_);
+ auto state_or_error = state_->HandleMessage(incoming);
+ if (state_or_error.ok()) {
+ state_ = std::move(state_or_error.value());
+ // Return true iff neither aborted nor completed.
+ return !(state_->IsAborted() || state_->IsCompletedSuccessfully());
+ } else {
+ return state_or_error.status();
+ }
+}
+
+StatusOr<std::string> SecAggClient::ErrorMessage() const {
+ absl::ReaderMutexLock _(&mu_);
+ return state_->ErrorMessage();
+}
+
+bool SecAggClient::IsAborted() const {
+ absl::ReaderMutexLock _(&mu_);
+ return state_->IsAborted();
+}
+
+bool SecAggClient::IsCompletedSuccessfully() const {
+ absl::ReaderMutexLock _(&mu_);
+ return state_->IsCompletedSuccessfully();
+}
+
+std::string SecAggClient::State() const {
+ absl::ReaderMutexLock _(&mu_);
+ return state_->StateName();
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client.h b/fcp/secagg/client/secagg_client.h
new file mode 100644
index 0000000..654b6ff
--- /dev/null
+++ b/fcp/secagg/client/secagg_client.h
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/async_abort.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+// Represents a client for the secure aggregation protocol. Each instance of
+// this class performs just *one* session of the protocol.
+//
+// To create a new instance, use the public constructor. The Start method can be
+// used to produce the first message of the protocol, the SetInput method sets
+// the input for this client and the ReceiveMessage method is used to process
+// incoming messages from the server.
+//
+// The class is thread-safe, but will deadlock if accessed reentrantly from
+// the SendToServerInterface callback.
+//
+// Functions are marked virtual for mockability. Additional virtual attributes
+// should be added as needed by tests.
+
+class SecAggClient {
+ public:
+ // Creates a new instance of the client.
+ //
+ // max_neighbors_expected is the upper bound on the total number of neighbors
+ // this client may interact with. If the server tries to start a protocol
+ // session with more than this many neighbors, this client will abort.
+ //
+ // minimum_surviving_neighbors_for_reconstruction is the threshold lower bound
+ // on the number of neighbors participating. If there are ever fewer than this
+ // number of remaining neighbors in the protocol, this client will abort.
+ //
+ // input_vector_specs must contain one InputVectorSpecification for each input
+ // vector which the protocol will aggregate. This may optionally be moved
+ // from using std::move(caller_input_vector_specs).
+ //
+ // prng should always be an instance of CryptoRandPrng, except as needed for
+ // testing purposes. The client will consume prng, taking ownership of it.
+ //
+ // sender is used by the client to send messages to the server. The client
+ // will consume sender, taking ownership of it.
+ //
+ // transition_listener is used to trigger state transition events, used for
+ // logging.
+ //
+ // prng_factory is a pointer to an instance of a subclass of AesPrngFactory.
+ // The type of prng_factory must be consistent with the one used on the
+ // server.
+ //
+ // async_abort_for_test, optionally, allows the caller to reset the abort
+ // signal. This is used to exhaustively test all abort paths, and should not
+ // be used in production; specifically, if this paramter is not nullptr,
+ // Abort() will no longer abort a state-in-progress; it will only abort across
+ // state transitions.
+ SecAggClient(
+ int max_neighbors_expected,
+ int minimum_surviving_neighbors_for_reconstruction,
+ std::vector<InputVectorSpecification> input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ std::atomic<std::string*>* abort_signal_for_test = nullptr);
+ virtual ~SecAggClient() = default;
+
+ // Disallow copy and move.
+ SecAggClient(const SecAggClient&) = delete;
+ SecAggClient& operator=(const SecAggClient&) = delete;
+
+ // Initiates the protocol by computing its first message and sending it to
+ // the server. This method should only be called once. The output will be OK
+ // unless it is called more than once.
+ virtual Status Start();
+
+ // Makes this client abort the protocol and sends a message to notify the
+ // server. All the state is erased. A new instance of SecAggClient will have
+ // to be created to restart the protocol.
+ //
+ // The status will be OK unless the protocol was already completed or aborted.
+ Status Abort();
+
+ // Makes this client abort the protocol and sends a message to notify the
+ // server. All the state is erased. A new instance of SecAggClient will have
+ // to be created to restart the protocol.
+ //
+ // The specified reason for aborting will be sent to the server and logged.
+ //
+ // The status will be OK unless the protocol was already completed or aborted.
+ Status Abort(const std::string& reason);
+
+ // Sets the input of this client for this protocol session. This method should
+ // only be called once.
+ //
+ // If the input does not match the format laid out in input_vector_specs,
+ // this will return INVALID_ARGUMENT. If SetInput has already been called or
+ // if the client is in an aborted or completed state, this will return
+ // FAILED_PRECONDITION. Otherwise returns OK.
+ Status SetInput(std::unique_ptr<SecAggVectorMap> input_map);
+
+ // Returns a string uniquely describing the current state of the client's FSM.
+ ABSL_MUST_USE_RESULT std::string State() const;
+
+ // Returns true if the client has aborted the protocol, false else.
+ ABSL_MUST_USE_RESULT bool IsAborted() const;
+
+ // Returns true if the client has successfully completed the protocol,
+ // false else.
+ ABSL_MUST_USE_RESULT bool IsCompletedSuccessfully() const;
+
+ // Returns a string describing the reason that the client aborted.
+ // If the client has not actually aborted, returns an error Status with code
+ // PRECONDITION_FAILED.
+ ABSL_MUST_USE_RESULT StatusOr<std::string> ErrorMessage() const;
+
+ // Used to process an incoming message from the server. This method uses the
+ // SendToServerInterface passed to the constructor to send the response
+ // directly to the server.
+ //
+ // The output will be true if the client is still active, or false if the
+ // client is now in a terminal state. The output will be a failure status if
+ // the client did not process the message because it was in a terminal state,
+ // or because the message was the wrong type.
+ StatusOr<bool> ReceiveMessage(const ServerToClientWrapperMessage& incoming);
+
+ private:
+ mutable absl::Mutex mu_;
+
+ std::atomic<std::string*> abort_signal_;
+ AsyncAbort async_abort_;
+
+ // The internal State object, containing details about this client's current
+ // state.
+ std::unique_ptr<SecAggClientState> state_ ABSL_GUARDED_BY(mu_);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_H_
diff --git a/fcp/secagg/client/secagg_client_aborted_state.cc b/fcp/secagg/client/secagg_client_aborted_state.cc
new file mode 100644
index 0000000..0019467
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_aborted_state.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientAbortedState::SecAggClientAbortedState(
+ const std::string& reason, std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener)
+ : SecAggClientState(std::move(sender), std::move(transition_listener),
+ ClientState::ABORTED),
+ reason_(reason) {
+ FCP_LOG(WARNING) << "Aborting for reason: " << reason_;
+}
+
+SecAggClientAbortedState::~SecAggClientAbortedState() = default;
+
+bool SecAggClientAbortedState::IsAborted() const { return true; }
+
+StatusOr<std::string> SecAggClientAbortedState::ErrorMessage() const {
+ return reason_;
+}
+
+std::string SecAggClientAbortedState::StateName() const { return "ABORTED"; }
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_aborted_state.h b/fcp/secagg/client/secagg_client_aborted_state.h
new file mode 100644
index 0000000..1967e29
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_aborted_state.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_ABORTED_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_ABORTED_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the abort state for a client. There are no transitions
+// out of this state. A new SecAggClient object will be needed to start a new
+// run of the protocol.
+
+class SecAggClientAbortedState : public SecAggClientState {
+ public:
+ SecAggClientAbortedState(
+ const std::string& reason, std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener);
+
+ ~SecAggClientAbortedState() override;
+
+ // Returns true from this state.
+ bool IsAborted() const override;
+
+ // Returns the error message with which the client aborted.
+ StatusOr<std::string> ErrorMessage() const override;
+
+ // Returns the name of this state, "ABORTED".
+ std::string StateName() const override;
+
+ private:
+ const std::string reason_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_ABORTED_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_aborted_state_test.cc b/fcp/secagg/client/secagg_client_aborted_state_test.cc
new file mode 100644
index 0000000..454af20
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_aborted_state_test.cc
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::StrEq;
+
+TEST(SecAggClientAbortedStateTest, IsAbortedReturnsTrue) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(aborted_state.IsAborted(), Eq(true));
+}
+
+TEST(SecAggClientAbortedStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(aborted_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientAbortedStateTest, ErrorMessageReturnsSelectedMessage) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ ASSERT_THAT(aborted_state.ErrorMessage().ok(), Eq(true));
+ EXPECT_THAT(aborted_state.ErrorMessage().value(), StrEq(test_reason));
+}
+
+TEST(SecAggClientAbortedStateTest, StartRaisesErrorStatus) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(aborted_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientAbortedStateTest, HandleMessageRaisesErrorStatus) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(
+ aborted_state
+ .HandleMessage(ServerToClientWrapperMessage::default_instance())
+ .ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientAbortedStateTest, SetInputRaisesErrorStatus) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(aborted_state.SetInput(std::make_unique<SecAggVectorMap>()).ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientAbortedStateTest, AbortRaisesErrorStatus) {
+ std::string test_reason = "test reason";
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientAbortedState aborted_state(
+ test_reason, std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<StateTransitionListenerInterface>{transition_listener});
+ EXPECT_THAT(aborted_state.Abort(test_reason).ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_alive_base_state.cc b/fcp/secagg/client/secagg_client_alive_base_state.cc
new file mode 100644
index 0000000..a383f03
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_alive_base_state.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, softwar
+ * e
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+
+#include <string>
+#include <utility>
+
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientAliveBaseState::SecAggClientAliveBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ ClientState state, AsyncAbort* async_abort)
+ : SecAggClientState(std::move(sender), std::move(transition_listener),
+ state),
+ async_abort_(async_abort) {}
+
+StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientAliveBaseState::Abort(
+ const std::string& reason) {
+ return AbortAndNotifyServer("Abort upon external request for reason <" +
+ reason + ">.");
+}
+
+std::unique_ptr<SecAggClientState>
+SecAggClientAliveBaseState::AbortAndNotifyServer(const std::string& reason) {
+ ClientToServerWrapperMessage message_to_server;
+ message_to_server.mutable_abort()->set_diagnostic_info(reason);
+ sender_->Send(&message_to_server);
+ return std::make_unique<SecAggClientAbortedState>(
+ reason, std::move(sender_), std::move(transition_listener_));
+}
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_alive_base_state.h b/fcp/secagg/client/secagg_client_alive_base_state.h
new file mode 100644
index 0000000..89579de
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_alive_base_state.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_ALIVE_BASE_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_ALIVE_BASE_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/async_abort.h"
+
+namespace fcp {
+namespace secagg {
+
+// Abstract base class containing code used by all SecAggClientStates where the
+// client is still alive and online, i.e. non-terminal states.
+
+class SecAggClientAliveBaseState : public SecAggClientState {
+ public:
+ ~SecAggClientAliveBaseState() override = default;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > Abort(
+ const std::string& reason) override;
+
+ protected:
+ // SecAggClientAliveBaseState should never be instantiated directly.
+ explicit SecAggClientAliveBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ ClientState state, AsyncAbort* async_abort = nullptr);
+
+ // Method to be used internally by child SecAggClient*State classes, called
+ // when an abort is required by the protocol. Sends an abort message to the
+ // server, then constructs and returns an abort state.
+ std::unique_ptr<SecAggClientState> AbortAndNotifyServer(
+ const std::string& reason);
+
+ AsyncAbort* async_abort_; // Owned by state owner.
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_ALIVE_BASE_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_completed_state.cc b/fcp/secagg/client/secagg_client_completed_state.cc
new file mode 100644
index 0000000..cdc12cb
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_completed_state.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientCompletedState::SecAggClientCompletedState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener)
+ : SecAggClientState(std::move(sender), std::move(transition_listener),
+ ClientState::COMPLETED) {}
+
+SecAggClientCompletedState::~SecAggClientCompletedState() = default;
+
+bool SecAggClientCompletedState::IsCompletedSuccessfully() const {
+ return true;
+}
+
+std::string SecAggClientCompletedState::StateName() const {
+ return "COMPLETED";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_completed_state.h b/fcp/secagg/client/secagg_client_completed_state.h
new file mode 100644
index 0000000..04e14d9
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_completed_state.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_COMPLETED_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_COMPLETED_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the Completed state for a client, meaning the
+// client has either sent its final message or received a message indicating
+// success from the server.
+//
+// There are no transitions out of this state.
+
+class SecAggClientCompletedState : public SecAggClientState {
+ public:
+ // As a terminal state, this State does not need to store any specific
+ // information except the sender (to ensure it does not go out of scope
+ // unexpectedly).
+ explicit SecAggClientCompletedState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener);
+
+ ~SecAggClientCompletedState() override;
+
+ // Returns true from this state.
+ bool IsCompletedSuccessfully() const override;
+
+ // Returns the name of this state, "COMPLETED".
+ std::string StateName() const override;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_COMPLETED_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_completed_state_test.cc b/fcp/secagg/client/secagg_client_completed_state_test.cc
new file mode 100644
index 0000000..8a9785f
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_completed_state_test.cc
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+TEST(SecAggClientCompletedStateTest, IsAbortedReturnsFalse) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(completed_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientCompletedStateTest, IsCompletedSuccessfullyReturnsTrue) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(completed_state.IsCompletedSuccessfully(), Eq(true));
+}
+
+TEST(SecAggClientCompletedStateTest, ErrorMessageRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(completed_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientCompletedStateTest, StartRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(completed_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientCompletedStateTest, HandleMessageRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(
+ completed_state
+ .HandleMessage(ServerToClientWrapperMessage::default_instance())
+ .ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientCompletedStateTest, SetInputRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(
+ completed_state.SetInput(std::make_unique<SecAggVectorMap>()).ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientCompletedStateTest, AbortRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientCompletedState completed_state(
+ std::unique_ptr<SendToServerInterface>{sender},
+ std::unique_ptr<MockStateTransitionListener>{transition_listener});
+ EXPECT_THAT(completed_state.Abort("incorrect abort").ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.cc b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.cc
new file mode 100644
index 0000000..bdd33d1
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.cc
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h"
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR0AdvertiseKeysInputNotSetState::
+ SecAggClientR0AdvertiseKeysInputNotSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientAliveBaseState(std::move(sender),
+ std::move(transition_listener),
+ ClientState::R0_ADVERTISE_KEYS, async_abort),
+ max_neighbors_expected_(max_neighbors_expected),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ input_vector_specs_(std::move(input_vector_specs)),
+ prng_(std::move(prng)),
+ prng_factory_(std::move(prng_factory)) {}
+
+SecAggClientR0AdvertiseKeysInputNotSetState::
+ ~SecAggClientR0AdvertiseKeysInputNotSetState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR0AdvertiseKeysInputNotSetState::Start() {
+ auto enc_key_agreement = EcdhKeyAgreement::CreateFromRandomKeys().value();
+ auto prng_key_agreement = EcdhKeyAgreement::CreateFromRandomKeys().value();
+
+ ClientToServerWrapperMessage message;
+ PairOfPublicKeys* public_keys =
+ message.mutable_advertise_keys()->mutable_pair_of_public_keys();
+ public_keys->set_enc_pk(enc_key_agreement->PublicKey().AsString());
+ public_keys->set_noise_pk(prng_key_agreement->PublicKey().AsString());
+
+ sender_->Send(&message);
+ return {std::make_unique<SecAggClientR1ShareKeysInputNotSetState>(
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ std::move(enc_key_agreement), std::move(input_vector_specs_),
+ std::move(prng_), std::move(prng_key_agreement), std::move(sender_),
+ std::move(transition_listener_), std::move(prng_factory_), async_abort_)};
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR0AdvertiseKeysInputNotSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR0AdvertiseKeysInputNotSetState::SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) {
+ if (!ValidateInput(*input_map, *input_vector_specs_)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "The input to SetInput does not match the "
+ "InputVectorSpecification.";
+ }
+
+ return {std::make_unique<SecAggClientR0AdvertiseKeysInputSetState>(
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ std::move(input_map), std::move(input_vector_specs_), std::move(prng_),
+ std::move(sender_), std::move(transition_listener_),
+ std::move(prng_factory_), async_abort_)};
+}
+
+std::string SecAggClientR0AdvertiseKeysInputNotSetState::StateName() const {
+ return "R0_ADVERTISE_KEYS_INPUT_NOT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h
new file mode 100644
index 0000000..faa2554
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_NOT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_NOT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 0: Advertise Keys state with the
+// input not yet set. This state should transition to either the
+// Round 1: Share Keys (Input Not Set) state or the
+// Round 0: Advertise Keys (Input Set) states. It can also transition directly
+// to the Completed or Aborted states.
+
+class SecAggClientR0AdvertiseKeysInputNotSetState
+ : public SecAggClientAliveBaseState {
+ public:
+ SecAggClientR0AdvertiseKeysInputNotSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR0AdvertiseKeysInputNotSetState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > Start() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) override;
+
+ // Returns the name of this state, "R0_ADVERTISE_KEYS_INPUT_NOT_SET".
+ std::string StateName() const override;
+
+ private:
+ const uint32_t max_neighbors_expected_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<SecurePrng> prng_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_NOT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state_test.cc b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state_test.cc
new file mode 100644
index 0000000..351a9c6
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state_test.cc
@@ -0,0 +1,404 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest, IsAbortedReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.IsAborted(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+// A gMock matcher that just checks that the message is a valid AdvertiseKeys
+// message, with the right fields set to the right lengths.
+MATCHER(IsValidAdvertiseKeysMessage, "") {
+ return (arg->advertise_keys().pair_of_public_keys().enc_pk().size() ==
+ EcdhPublicKey::kSize) &&
+ (arg->advertise_keys().pair_of_public_keys().noise_pk().size() ==
+ EcdhPublicKey::kSize);
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ StartSendsCorrectMessageAndTransitionsState) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(IsValidAdvertiseKeysMessage())).Times(1);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state = r0_state.Start();
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R1_SHARE_KEYS_INPUT_NOT_SET"));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputTransitionsToInputSetStateWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R0_ADVERTISE_KEYS_INPUT_SET"));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfVectorIsWrongSize) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has too many elements.
+ input_map->insert(
+ std::make_pair("test", SecAggVector({5, 8, 22, 30, 7}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorIsTooLargeForBitWidth) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector's bit_width does not match the specified modulus of 32.
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 64));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorHasWrongName) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has the wrong name.
+ input_map->insert(
+ std::make_pair("incorret", SecAggVector({5, 8, 22, 30}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooManyVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+ // This vector is extra.
+ input_map->emplace("test2", SecAggVector({4, 7, 21, 29}, 32));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooFewVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ // Expects two vectors.
+ input_vector_specs->push_back(InputVectorSpecification("test2", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+ // Missing second vector.
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputNotSetStateTest,
+ HandleNonAbortMessageRaisesError) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputNotSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+
+ EXPECT_THAT(r0_state.HandleMessage(message).ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.cc b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.cc
new file mode 100644
index 0000000..b8c4168
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.cc
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR0AdvertiseKeysInputSetState::
+ SecAggClientR0AdvertiseKeysInputSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientAliveBaseState(std::move(sender),
+ std::move(transition_listener),
+ ClientState::R0_ADVERTISE_KEYS, async_abort),
+ max_neighbors_expected_(max_neighbors_expected),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ input_map_(std::move(input_map)),
+ input_vector_specs_(std::move(input_vector_specs)),
+ prng_(std::move(prng)),
+ prng_factory_(std::move(prng_factory)) {}
+
+SecAggClientR0AdvertiseKeysInputSetState::
+ ~SecAggClientR0AdvertiseKeysInputSetState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR0AdvertiseKeysInputSetState::Start() {
+ auto enc_key_agreement = EcdhKeyAgreement::CreateFromRandomKeys().value();
+ auto prng_key_agreement = EcdhKeyAgreement::CreateFromRandomKeys().value();
+
+ ClientToServerWrapperMessage message;
+ PairOfPublicKeys* public_keys =
+ message.mutable_advertise_keys()->mutable_pair_of_public_keys();
+ public_keys->set_enc_pk(enc_key_agreement->PublicKey().AsString());
+ public_keys->set_noise_pk(prng_key_agreement->PublicKey().AsString());
+
+ sender_->Send(&message);
+ return {std::make_unique<SecAggClientR1ShareKeysInputSetState>(
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ std::move(enc_key_agreement), std::move(input_map_),
+ std::move(input_vector_specs_), std::move(prng_),
+ std::move(prng_key_agreement), std::move(sender_),
+ std::move(transition_listener_), std::move(prng_factory_), async_abort_)};
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR0AdvertiseKeysInputSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+}
+
+std::string SecAggClientR0AdvertiseKeysInputSetState::StateName() const {
+ return "R0_ADVERTISE_KEYS_INPUT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h
new file mode 100644
index 0000000..67d3321
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 0: Advertise Keys state with the
+// input already set. This state should transition to the Round 1: Share Keys
+// (Input Set) state, but can also transition directly to the Completed or
+// Aborted states.
+
+class SecAggClientR0AdvertiseKeysInputSetState
+ : public SecAggClientAliveBaseState {
+ public:
+ SecAggClientR0AdvertiseKeysInputSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR0AdvertiseKeysInputSetState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > Start() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ // Returns the name of this state, "R0_ADVERTISE_KEYS_INPUT_SET".
+ std::string StateName() const override;
+
+ private:
+ const uint32_t max_neighbors_expected_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ std::unique_ptr<SecAggVectorMap> input_map_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<SecurePrng> prng_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R0_ADVERTISE_KEYS_INPUT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state_test.cc b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state_test.cc
new file mode 100644
index 0000000..391b86c
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state_test.cc
@@ -0,0 +1,290 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r0_advertise_keys_input_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest, IsAbortedReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.IsAborted(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+// A gMock matcher that just checks that the message is a valid AdvertiseKeys
+// message, with the right fields set to the right lengths.
+MATCHER(IsValidAdvertiseKeysMessage, "") {
+ return (arg->advertise_keys().pair_of_public_keys().enc_pk().size() ==
+ EcdhPublicKey::kSize) &&
+ (arg->advertise_keys().pair_of_public_keys().noise_pk().size() ==
+ EcdhPublicKey::kSize);
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ StartSendsCorrectMessageAndTransitionsState) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(IsValidAdvertiseKeysMessage())).Times(1);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state = r0_state.Start();
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("R1_SHARE_KEYS_INPUT_SET"));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest, SetInputRaisesErrorStatus) {
+ std::vector<uint64_t> vec = {2, 4, 6, 8};
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector(vec, 32));
+ auto copy_of_input_map = std::make_unique<SecAggVectorMap>();
+ copy_of_input_map->emplace("test", SecAggVector(vec, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.SetInput(std::move(copy_of_input_map)).ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r0_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r0_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+TEST(SecaggClientR0AdvertiseKeysInputSetStateTest,
+ HandleNonAbortMessageRaisesError) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR0AdvertiseKeysInputSetState r0_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+
+ EXPECT_THAT(r0_state.HandleMessage(message).ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_base_state.cc b/fcp/secagg/client/secagg_client_r1_share_keys_base_state.cc
new file mode 100644
index 0000000..4eb15e0
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_base_state.cc
@@ -0,0 +1,202 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r1_share_keys_base_state.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR1ShareKeysBaseState::SecAggClientR1ShareKeysBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ AsyncAbort* async_abort)
+ : SecAggClientAliveBaseState(std::move(sender),
+ std::move(transition_listener),
+ ClientState::R1_SHARE_KEYS, async_abort) {}
+
+void SecAggClientR1ShareKeysBaseState::SetUpShares(
+ int threshold, int n, const Key& agreement_key, const Key& self_prng_key,
+ std::vector<ShamirShare>* self_prng_key_shares,
+ std::vector<ShamirShare>* pairwise_prng_key_shares) {
+ // This could be made into an assertion, but that would complicate the tests
+ // that call this method to get a "preview" of the shares.
+ if (pairwise_prng_key_shares->empty() && self_prng_key_shares->empty()) {
+ ShamirSecretSharing sharer;
+ *pairwise_prng_key_shares = sharer.Share(threshold, n, agreement_key);
+ *self_prng_key_shares = sharer.Share(threshold, n, self_prng_key);
+ }
+}
+
+bool SecAggClientR1ShareKeysBaseState::HandleShareKeysRequest(
+ const ShareKeysRequest& request, const EcdhKeyAgreement& enc_key_agreement,
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ const EcdhKeyAgreement& prng_key_agreement, const AesKey& self_prng_key,
+ SecurePrng* prng, uint32_t* client_id, std::string* error_message,
+ uint32_t* number_of_alive_clients, uint32_t* number_of_clients,
+ std::vector<AesKey>* other_client_enc_keys,
+ std::vector<AesKey>* other_client_prng_keys,
+ std::vector<OtherClientState>* other_client_states,
+ std::vector<ShamirShare>* self_prng_key_shares,
+ std::vector<ShamirShare>* pairwise_prng_key_shares, SessionId* session_id) {
+ transition_listener_->set_execution_session_id(
+ request.sec_agg_execution_logging_id());
+ if (request.pairs_of_public_keys().size() <
+ static_cast<int>(minimum_surviving_neighbors_for_reconstruction)) {
+ *error_message =
+ "The ShareKeysRequest received does not contain enough participants.";
+ return false;
+ } else if (request.pairs_of_public_keys().size() >
+ static_cast<int>(max_neighbors_expected)) {
+ *error_message =
+ "The ShareKeysRequest received contains too many participants.";
+ return false;
+ }
+
+ *number_of_alive_clients = request.pairs_of_public_keys().size();
+ *number_of_clients = request.pairs_of_public_keys().size();
+ bool client_id_set = false;
+
+ SetUpShares(minimum_surviving_neighbors_for_reconstruction,
+ *number_of_clients, prng_key_agreement.PrivateKey(),
+ self_prng_key, self_prng_key_shares, pairwise_prng_key_shares);
+
+ if (request.session_id().size() != kSha256Length) {
+ *error_message =
+ "Session ID is absent in ShareKeysRequest or has an unexpected length.";
+ return false;
+ }
+ session_id->data = request.session_id();
+
+ other_client_states->resize(*number_of_clients, OtherClientState::kAlive);
+ other_client_enc_keys->reserve(*number_of_clients);
+ other_client_prng_keys->reserve(*number_of_clients);
+
+ EcdhPublicKey self_enc_public_key = enc_key_agreement.PublicKey();
+ EcdhPublicKey self_prng_public_key = prng_key_agreement.PublicKey();
+
+ for (uint32_t i = 0; i < *number_of_clients; ++i) {
+ if (async_abort_ && async_abort_->Signalled()) {
+ *error_message = async_abort_->Message();
+ return false;
+ }
+ const PairOfPublicKeys& keys = request.pairs_of_public_keys(i);
+ if (keys.enc_pk().empty() || keys.noise_pk().empty()) {
+ // This is an aborted client, or it sent invalid keys.
+ other_client_states->at(i) = OtherClientState::kDeadAtRound1;
+ --(*number_of_alive_clients);
+ other_client_enc_keys->push_back(AesKey());
+ other_client_prng_keys->push_back(AesKey());
+ } else if (keys.enc_pk().size() != EcdhPublicKey::kSize ||
+ keys.noise_pk().size() != EcdhPublicKey::kSize) {
+ // The server forwarded an invalid public key.
+ *error_message = "Invalid public key in request from server.";
+ return false;
+ } else {
+ EcdhPublicKey enc_pk(
+ reinterpret_cast<const uint8_t*>(keys.enc_pk().data()));
+ EcdhPublicKey prng_pk(
+ reinterpret_cast<const uint8_t*>(keys.noise_pk().data()));
+ if (enc_pk == self_enc_public_key && prng_pk == self_prng_public_key) {
+ // This is this client.
+ if (client_id_set) {
+ *error_message =
+ "Found this client's keys in the ShareKeysRequest twice somehow.";
+ return false;
+ }
+ *client_id = i;
+ client_id_set = true;
+ // Add empty entries for own id.
+ other_client_enc_keys->push_back(AesKey());
+ other_client_prng_keys->push_back(AesKey());
+ } else {
+ auto shared_enc_key = enc_key_agreement.ComputeSharedSecret(enc_pk);
+ auto shared_prng_key = prng_key_agreement.ComputeSharedSecret(prng_pk);
+ if (!shared_enc_key.ok() || !shared_prng_key.ok()) {
+ // The server forwarded an invalid public key.
+ *error_message = "Invalid public key in request from server.";
+ return false;
+ }
+ other_client_enc_keys->push_back(shared_enc_key.value());
+ other_client_prng_keys->push_back(shared_prng_key.value());
+ }
+ }
+ }
+
+ if (*number_of_alive_clients <
+ minimum_surviving_neighbors_for_reconstruction) {
+ *error_message =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ return false;
+ }
+ if (!client_id_set) {
+ *error_message =
+ "The ShareKeysRequest sent by the server doesn't contain this client's "
+ "public keys.";
+ return false;
+ }
+ *error_message = "";
+ return true;
+}
+
+bool SecAggClientR1ShareKeysBaseState::EncryptAndSendResponse(
+ const std::vector<AesKey>& other_client_enc_keys,
+ const std::vector<ShamirShare>& pairwise_prng_key_shares,
+ const std::vector<ShamirShare>& self_prng_key_shares,
+ SendToServerInterface* sender) {
+ ClientToServerWrapperMessage message;
+ ShareKeysResponse* response = message.mutable_share_keys_response();
+ AesGcmEncryption encryptor;
+
+ for (uint32_t i = 0; i < other_client_enc_keys.size(); ++i) {
+ if (async_abort_ && async_abort_->Signalled()) return false;
+ if (other_client_enc_keys[i].size() == 0) {
+ // Add a blank for dropped-out clients and for this client.
+ response->add_encrypted_key_shares("");
+ } else {
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(pairwise_prng_key_shares[i].data);
+ key_shares_pair.set_prf_sk_share(self_prng_key_shares[i].data);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ response->add_encrypted_key_shares(
+ encryptor.Encrypt(other_client_enc_keys[i], serialized_pair));
+ }
+ }
+
+ sender->Send(&message);
+ return true;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_base_state.h b/fcp/secagg/client/secagg_client_r1_share_keys_base_state.h
new file mode 100644
index 0000000..3ecc2eb
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_base_state.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_BASE_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_BASE_STATE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+// This is an abstract class which is the parent of two possible state classes
+// representing the states that the client may be in at Round 1: Share Keys.
+// It should never be instantiated directly, but contains code that will be used
+// by both concrete Round 1 client classes.
+
+class SecAggClientR1ShareKeysBaseState : public SecAggClientAliveBaseState {
+ public:
+ ~SecAggClientR1ShareKeysBaseState() override = default;
+
+ protected:
+ // SecAggClientR1ShareKeysBaseState should never be instantiated directly.
+ explicit SecAggClientR1ShareKeysBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ AsyncAbort* async_abort = nullptr);
+
+ void SetUpShares(int threshold, int n, const Key& agreement_key,
+ const Key& self_prng_key,
+ std::vector<ShamirShare>* self_prng_key_shares,
+ std::vector<ShamirShare>* pairwise_prng_key_shares);
+
+ // Handles the logic associated with receiving a ShareKeysRequest. Uses the
+ // ECDH public keys of other clients to compute shared secrets with other
+ // clients, and shares its own private keys to send to the server.
+ //
+ // The arguments following prng are outputs. The vectors should be empty prior
+ // to calling this method.
+ //
+ // The output will be false if an error was detected; this error will be
+ // stored in *error_message. If the protocol should proceed, the output will
+ // be true and *error_message will be an empty string.
+ bool HandleShareKeysRequest(
+ const ShareKeysRequest& request,
+ const EcdhKeyAgreement& enc_key_agreement,
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ const EcdhKeyAgreement& prng_key_agreement, const AesKey& self_prng_key,
+ SecurePrng* prng, uint32_t* client_id, std::string* error_message,
+ uint32_t* number_of_alive_clients, uint32_t* number_of_clients,
+ std::vector<AesKey>* other_client_enc_keys,
+ std::vector<AesKey>* other_client_prng_keys,
+ std::vector<OtherClientState>* other_client_states,
+ std::vector<ShamirShare>* self_prng_key_shares,
+ std::vector<ShamirShare>* pairwise_prng_key_shares,
+ SessionId* session_id);
+
+ // Individually encrypts each pair of key shares with the agreed-upon key for
+ // the client that share is for, and then sends the encrypted keys to the
+ // server. Dropped-out clients and this client are represented by empty
+ // strings. Returns true if successful, false if aborted by client.
+ bool EncryptAndSendResponse(
+ const std::vector<AesKey>& other_client_enc_keys,
+ const std::vector<ShamirShare>& pairwise_prng_key_shares,
+ const std::vector<ShamirShare>& self_prng_key_shares,
+ SendToServerInterface* sender);
+};
+
+} // namespace secagg
+} // namespace fcp
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_BASE_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.cc b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.cc
new file mode 100644
index 0000000..3bca400
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.cc
@@ -0,0 +1,154 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+SecAggClientR1ShareKeysInputNotSetState::
+ SecAggClientR1ShareKeysInputNotSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientR1ShareKeysBaseState(
+ std::move(sender), std::move(transition_listener), async_abort),
+ max_neighbors_expected_(max_neighbors_expected),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ enc_key_agreement_(std::move(enc_key_agreement)),
+ input_vector_specs_(std::move(input_vector_specs)),
+ prng_(std::move(prng)),
+ prng_key_agreement_(std::move(prng_key_agreement)),
+ prng_factory_(std::move(prng_factory)) {}
+
+SecAggClientR1ShareKeysInputNotSetState::
+ ~SecAggClientR1ShareKeysInputNotSetState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR1ShareKeysInputNotSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages or share keys requests only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else if (!message.has_share_keys_request()) {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+ uint32_t client_id;
+ uint32_t number_of_alive_clients;
+ uint32_t number_of_clients;
+ std::string error_message;
+ auto other_client_enc_keys = std::make_unique<std::vector<AesKey> >();
+ auto other_client_prng_keys = std::make_unique<std::vector<AesKey> >();
+ auto other_client_states = std::make_unique<std::vector<OtherClientState> >();
+ auto own_self_key_share = std::make_unique<ShamirShare>();
+ auto session_id = std::make_unique<SessionId>();
+
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (int i = 0; i < AesKey::kSize; ++i) {
+ self_prng_key_buffer[i] = prng_->Rand8();
+ }
+ auto self_prng_key = std::make_unique<AesKey>(self_prng_key_buffer);
+
+ bool success = HandleShareKeysRequest(
+ message.share_keys_request(), *enc_key_agreement_,
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ *prng_key_agreement_, *self_prng_key, prng_.get(), &client_id,
+ &error_message, &number_of_alive_clients, &number_of_clients,
+ other_client_enc_keys.get(), other_client_prng_keys.get(),
+ other_client_states.get(), &self_prng_key_shares_,
+ &pairwise_prng_key_shares_, session_id.get());
+
+ if (!success) {
+ return AbortAndNotifyServer(error_message);
+ }
+
+ if (!EncryptAndSendResponse(*other_client_enc_keys, pairwise_prng_key_shares_,
+ self_prng_key_shares_, sender_.get())) {
+ return AbortAndNotifyServer(async_abort_->Message());
+ }
+
+ *own_self_key_share = self_prng_key_shares_[client_id];
+ return {std::make_unique<SecAggClientR2MaskedInputCollInputNotSetState>(
+ client_id, minimum_surviving_neighbors_for_reconstruction_,
+ number_of_alive_clients, number_of_clients,
+ std::move(input_vector_specs_), std::move(other_client_states),
+ std::move(other_client_enc_keys), std::move(other_client_prng_keys),
+ std::move(own_self_key_share), std::move(self_prng_key),
+ std::move(sender_), std::move(transition_listener_),
+ std::move(session_id), std::move(prng_factory_), async_abort_)};
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR1ShareKeysInputNotSetState::SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) {
+ if (!ValidateInput(*input_map, *input_vector_specs_)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "The input to SetInput does not match the "
+ "InputVectorSpecification.";
+ }
+
+ return {std::make_unique<SecAggClientR1ShareKeysInputSetState>(
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ std::move(enc_key_agreement_), std::move(input_map),
+ std::move(input_vector_specs_), std::move(prng_),
+ std::move(prng_key_agreement_), std::move(sender_),
+ std::move(transition_listener_), std::move(prng_factory_), async_abort_)};
+}
+
+std::string SecAggClientR1ShareKeysInputNotSetState::StateName() const {
+ return "R1_SHARE_KEYS_INPUT_NOT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h
new file mode 100644
index 0000000..c235d06
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_NOT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_NOT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_r1_share_keys_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 1: Share Keys state with the input
+// not set yet. This state should transition either to the Round 2: Masked Input
+// Collection (Input Not Set) state, or the Round 1: Share Keys (Input Set)
+// state. It can also transition directly to the Completed or Aborted states.
+
+class SecAggClientR1ShareKeysInputNotSetState
+ : public SecAggClientR1ShareKeysBaseState {
+ public:
+ SecAggClientR1ShareKeysInputNotSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR1ShareKeysInputNotSetState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) override;
+
+ // Returns the name of this state, "R1_SHARE_KEYS_INPUT_NOT_SET".
+ ABSL_MUST_USE_RESULT std::string StateName() const override;
+
+ private:
+ friend class SecAggClientR1ShareKeysInputNotSetStateTest_ShareKeysRequestIsHandledCorrectlyWithDeadClient_Test; // NOLINT
+
+ const uint32_t max_neighbors_expected_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<SecurePrng> prng_;
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+ std::vector<ShamirShare> self_prng_key_shares_;
+ std::vector<ShamirShare> pairwise_prng_key_shares_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_NOT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state_test.cc b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state_test.cc
new file mode 100644
index 0000000..6c09e33
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state_test.cc
@@ -0,0 +1,1004 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest, IsAbortedReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest, StartRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputTransitionsToInputSetStateWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("R1_SHARE_KEYS_INPUT_SET"));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfVectorIsWrongSize) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has too many elements.
+ input_map->insert(
+ std::make_pair("test", SecAggVector({5, 8, 22, 30, 7}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorIsTooLargeForBitWidth) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector's bit_width does not match the specified modulus of 32.
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 64)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorHasWrongName) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has the wrong name.
+ input_map->insert(
+ std::make_pair("incorret", SecAggVector({5, 8, 22, 30}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooManyVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+ // This vector is extra.
+ input_map->insert(std::make_pair("test2", SecAggVector({4, 7, 21, 29}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooFewVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ // Expects two vectors.
+ input_vector_specs->push_back(InputVectorSpecification("test2", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+ // Missing second vector.
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+// A gMock matcher to see if the message sent by the client contains shares that
+// reconstruct to the right private keys.
+
+// pairwise_key is the expected pairwise PRNG key.
+// self_key is the expected self PRNG key.
+// enc_keys must be a std::vector<AesKey>, the vector of encryption keys.
+// Assumes 3-of-4 secret sharing.
+MATCHER_P3(ReconstructsCorrectly, pairwise_key, self_key, enc_keys, "") {
+ AesGcmEncryption decryptor;
+ std::vector<ShamirShare> pairwise_shares;
+ std::vector<ShamirShare> self_shares;
+ for (int i = 0; i < enc_keys.size(); ++i) {
+ // Blank shares must be blank in both places
+ if (arg->share_keys_response().encrypted_key_shares(i).empty()) {
+ pairwise_shares.push_back({""});
+ self_shares.push_back({""});
+ continue;
+ }
+ auto decrypted = decryptor.Decrypt(
+ enc_keys[i], arg->share_keys_response().encrypted_key_shares(i));
+ if (!decrypted.ok()) {
+ return false;
+ }
+ PairOfKeyShares key_shares;
+ if (!key_shares.ParseFromString(decrypted.value())) {
+ return false;
+ }
+ pairwise_shares.push_back({key_shares.noise_sk_share()});
+ self_shares.push_back({key_shares.prf_sk_share()});
+ }
+ // Reconstruct keys to see if they match
+ ShamirSecretSharing reconstructor;
+ std::string reconstructed_pairwise_key_string =
+ reconstructor.Reconstruct(3, pairwise_shares, EcdhPrivateKey::kSize)
+ .value();
+ std::string reconstructed_self_key_string =
+ reconstructor.Reconstruct(3, self_shares, AesKey::kSize).value();
+ EcdhPrivateKey reconstructed_pairwise_key(reinterpret_cast<const uint8_t*>(
+ reconstructed_pairwise_key_string.c_str()));
+ AesKey reconstructed_self_key(
+ reinterpret_cast<const uint8_t*>(reconstructed_self_key_string.c_str()));
+ return pairwise_key == reconstructed_pairwise_key &&
+ self_key == reconstructed_self_key;
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestIsHandledCorrectlyWhenNoClientsDie) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Make a copy of the encryption keys for testing
+ std::vector<AesKey> enc_keys;
+ auto enc_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(2))
+ .value();
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+
+ if (i == 1) {
+ enc_keys.push_back(AesKey());
+ } else {
+ enc_keys.push_back(
+ enc_key_agreement->ComputeSharedSecret(ecdh_keys.GetPublicKey(2 * i))
+ .value());
+ }
+ }
+
+ // Make a copy of the self PRNG key.
+ FakePrng prng;
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (int i = 0; i < AesKey::kSize; ++i) {
+ self_prng_key_buffer[i] = prng.Rand8();
+ }
+ auto self_prng_key = AesKey(self_prng_key_buffer);
+
+ EXPECT_CALL(*sender, Send(ReconstructsCorrectly(ecdh_keys.GetPrivateKey(3),
+ self_prng_key, enc_keys)))
+ .Times(1);
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_INPUT_NOT_SET"));
+}
+
+// A gMock matcher to see if the message sent by the client contains shares that
+// reconstruct to the right private keys.
+//
+// This version also takes into account the client's own shares, to ensure that
+// it works even when exactly threshold clients survive.
+//
+// pairwise_key is the expected pairwise PRNG key, an AesKey.
+// self_key is the expected self PRNG key, an AesKey.
+// own_pairwise_key_share is client 1's ShamirShare of its own pairwise PRNG
+// key.
+// own_self_key_share is client 1's ShamirShare of its own pairwise PRNG
+// key.
+// enc_keys must be a std::vector<AesKey>, the vector of encryption keys.
+// Assumes 3-of-4 secret sharing.
+MATCHER_P5(ReconstructsCorrectlyWithOwnKeys, pairwise_key, self_key,
+ own_pairwise_key_share, own_self_key_share, enc_keys, "") {
+ AesGcmEncryption decryptor;
+ std::vector<ShamirShare> pairwise_shares;
+ std::vector<ShamirShare> self_shares;
+ for (int i = 0; i < enc_keys.size(); ++i) {
+ // Blank shares must be blank in both places
+ if (arg->share_keys_response().encrypted_key_shares(i).empty()) {
+ if (i == 1) {
+ pairwise_shares.push_back(own_pairwise_key_share);
+ self_shares.push_back(own_self_key_share);
+ } else {
+ pairwise_shares.push_back({""});
+ self_shares.push_back({""});
+ }
+ continue;
+ }
+ auto decrypted = decryptor.Decrypt(
+ enc_keys[i], arg->share_keys_response().encrypted_key_shares(i));
+ if (!decrypted.ok()) {
+ return false;
+ }
+ PairOfKeyShares key_shares;
+ if (!key_shares.ParseFromString(decrypted.value())) {
+ return false;
+ }
+ pairwise_shares.push_back({key_shares.noise_sk_share()});
+ self_shares.push_back({key_shares.prf_sk_share()});
+ }
+ // Reconstruct keys to see if they match
+ ShamirSecretSharing reconstructor;
+ std::string reconstructed_pairwise_key_string =
+ reconstructor.Reconstruct(3, pairwise_shares, EcdhPrivateKey::kSize)
+ .value();
+ std::string reconstructed_self_key_string =
+ reconstructor.Reconstruct(3, self_shares, AesKey::kSize).value();
+ EcdhPrivateKey reconstructed_pairwise_key(reinterpret_cast<const uint8_t*>(
+ reconstructed_pairwise_key_string.c_str()));
+ AesKey reconstructed_self_key(
+ reinterpret_cast<const uint8_t*>(reconstructed_self_key_string.c_str()));
+ EXPECT_THAT(reconstructed_pairwise_key_string,
+ Eq(std::string(reinterpret_cast<const char*>(pairwise_key.data()),
+ EcdhPrivateKey::kSize)));
+ return pairwise_key == reconstructed_pairwise_key;
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestIsHandledCorrectlyWithDeadClient) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Client 3 has died in this round.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Make a copy of the encryption keys for testing
+ std::vector<AesKey> enc_keys;
+ auto enc_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(2))
+ .value();
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // exclude client 3
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+
+ if (i == 1) {
+ enc_keys.push_back(AesKey());
+ } else {
+ enc_keys.push_back(
+ enc_key_agreement->ComputeSharedSecret(ecdh_keys.GetPublicKey(2 * i))
+ .value());
+ }
+ }
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ enc_keys.push_back(AesKey());
+
+ // Make a copy of the self PRNG key.
+ FakePrng prng;
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (int i = 0; i < AesKey::kSize; ++i) {
+ self_prng_key_buffer[i] = prng.Rand8();
+ }
+ auto self_prng_key = AesKey(self_prng_key_buffer);
+
+ // Make a copy of the secret shares, in order to get this client's own shares.
+ std::vector<ShamirShare> self_key_shares, pairwise_key_shares;
+ r1_state.SetUpShares(3, 4, ecdh_keys.GetPrivateKey(3), self_prng_key,
+ &r1_state.self_prng_key_shares_,
+ &r1_state.pairwise_prng_key_shares_);
+ ShamirShare own_pairwise_key_share = r1_state.pairwise_prng_key_shares_.at(1);
+ ShamirShare own_self_key_share = r1_state.self_prng_key_shares_.at(1);
+
+ EXPECT_CALL(*sender,
+ Send(ReconstructsCorrectlyWithOwnKeys(
+ ecdh_keys.GetPrivateKey(3), self_prng_key,
+ own_pairwise_key_share, own_self_key_share, enc_keys)))
+ .Times(1);
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_INPUT_NOT_SET"));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfTooManyDeadClients) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Clients 2 and 3 died, and we need 3 clients to continue, so we should
+ // abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 2; ++i) { // exclude clients 2 and 3.
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+
+ std::string error_string =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsWrongSizeKey) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // One of client 3's keys is a string of the wrong length, so this should
+ // cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(6));
+ bad_keypair->set_noise_pk("there's no way this is a valid key");
+
+ std::string error_string = "Invalid public key in request from server.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsInvalidKey) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // One of client 3's keys is a not a valid ECDH key, so this should cause an
+ // abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(6));
+ bad_keypair->set_noise_pk("Right size, but not an ECDH point");
+
+ std::string error_string = "Invalid public key in request from server.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsTooManyKeys) {
+ // In this test, the client under test is id 1, and it expects there to be no
+ // more than 3 clients. However, the server sends 4 keypairs. This should
+ // cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 3, // max_neighbors_expected
+ 2, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest received contains too many participants.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsTooFewKeys) {
+ // In this test, the client under test is id 1, and the threshold is 3
+ // clients. However, the server sends only 2 keys. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 2; ++i) { // exclude clients 2 and 3
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest received does not contain enough participants.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerOmitsClientsKey) {
+ // In this test, the client under test is not represented at all in the
+ // server's message. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ continue; // skipping this client
+ }
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest sent by the server doesn't contain this client's "
+ "public keys.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputNotSetStateTest,
+ ShareKeysRequestCausesAbortIfServerDuplicatesClientsKey) {
+ // In this test, the client under test is included twice in the server's
+ // message. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputNotSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_vector_specs), std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2));
+ bad_keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(3));
+
+ std::string error_string =
+ "Found this client's keys in the ShareKeysRequest twice somehow.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.cc b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.cc
new file mode 100644
index 0000000..2cd0823
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.cc
@@ -0,0 +1,133 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+SecAggClientR1ShareKeysInputSetState::SecAggClientR1ShareKeysInputSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientR1ShareKeysBaseState(
+ std::move(sender), std::move(transition_listener), async_abort),
+ max_neighbors_expected_(max_neighbors_expected),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ enc_key_agreement_(std::move(enc_key_agreement)),
+ input_map_(std::move(input_map)),
+ input_vector_specs_(std::move(input_vector_specs)),
+ prng_(std::move(prng)),
+ prng_key_agreement_(std::move(prng_key_agreement)),
+ prng_factory_(std::move(prng_factory)) {}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR1ShareKeysInputSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages or share keys requests only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else if (!message.has_share_keys_request()) {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+ uint32_t client_id;
+ uint32_t number_of_alive_clients;
+ uint32_t number_of_clients;
+ std::string error_message;
+ auto other_client_enc_keys = std::make_unique<std::vector<AesKey> >();
+ auto other_client_prng_keys = std::make_unique<std::vector<AesKey> >();
+ auto other_client_states = std::make_unique<std::vector<OtherClientState> >();
+ auto own_self_key_share = std::make_unique<ShamirShare>();
+ auto session_id = std::make_unique<SessionId>();
+
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (uint8_t& i : self_prng_key_buffer) {
+ i = prng_->Rand8();
+ }
+ auto self_prng_key = std::make_unique<AesKey>(self_prng_key_buffer);
+
+ bool success = HandleShareKeysRequest(
+ message.share_keys_request(), *enc_key_agreement_,
+ max_neighbors_expected_, minimum_surviving_neighbors_for_reconstruction_,
+ *prng_key_agreement_, *self_prng_key, prng_.get(), &client_id,
+ &error_message, &number_of_alive_clients, &number_of_clients,
+ other_client_enc_keys.get(), other_client_prng_keys.get(),
+ other_client_states.get(), &self_prng_key_shares_,
+ &pairwise_prng_key_shares_, session_id.get());
+
+ if (!success) {
+ return AbortAndNotifyServer(error_message);
+ }
+
+ if (async_abort_ && async_abort_->Signalled()) {
+ return AbortAndNotifyServer(async_abort_->Message());
+ }
+
+ if (!EncryptAndSendResponse(*other_client_enc_keys, pairwise_prng_key_shares_,
+ self_prng_key_shares_, sender_.get())) {
+ return AbortAndNotifyServer(async_abort_->Message());
+ }
+
+ *own_self_key_share = self_prng_key_shares_[client_id];
+ return {std::make_unique<SecAggClientR2MaskedInputCollInputSetState>(
+ client_id, minimum_surviving_neighbors_for_reconstruction_,
+ number_of_alive_clients, number_of_clients, std::move(input_map_),
+ std::move(input_vector_specs_), std::move(other_client_states),
+ std::move(other_client_enc_keys), std::move(other_client_prng_keys),
+ std::move(own_self_key_share), std::move(self_prng_key),
+ std::move(sender_), std::move(transition_listener_),
+ std::move(session_id), std::move(prng_factory_), async_abort_)};
+}
+
+std::string SecAggClientR1ShareKeysInputSetState::StateName() const {
+ return "R1_SHARE_KEYS_INPUT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h
new file mode 100644
index 0000000..ebaf6c4
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_r1_share_keys_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 1: Share Keys state with the input
+// already set. This state should transition to the Round 2: Masked Input
+// Collection (Input Set) state, but can also transition directly to the
+// Completed or Aborted states.
+
+class SecAggClientR1ShareKeysInputSetState
+ : public SecAggClientR1ShareKeysBaseState {
+ public:
+ SecAggClientR1ShareKeysInputSetState(
+ uint32_t max_neighbors_expected,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecurePrng> prng,
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR1ShareKeysInputSetState() override = default;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ // Returns the name of this state, "R1_SHARE_KEYS_INPUT_SET".
+ std::string StateName() const override;
+
+ private:
+ friend class SecAggClientR1ShareKeysInputSetStateTest_ShareKeysRequestIsHandledCorrectlyWithDeadClient_Test; // NOLINT
+
+ const uint32_t max_neighbors_expected_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ std::unique_ptr<EcdhKeyAgreement> enc_key_agreement_;
+ std::unique_ptr<SecAggVectorMap> input_map_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<SecurePrng> prng_;
+ std::unique_ptr<EcdhKeyAgreement> prng_key_agreement_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+ std::vector<ShamirShare> self_prng_key_shares_;
+ std::vector<ShamirShare> pairwise_prng_key_shares_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R1_SHARE_KEYS_INPUT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state_test.cc b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state_test.cc
new file mode 100644
index 0000000..5943672
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r1_share_keys_input_set_state_test.cc
@@ -0,0 +1,901 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r1_share_keys_input_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest, IsAbortedReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest, StartRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest, SetInputRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.SetInput(std::make_unique<SecAggVectorMap>()).ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest, ErrorMessageRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r1_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromRandomKeys().value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+// A gMock matcher to see if the message sent by the client contains shares that
+// reconstruct to the right private keys.
+
+// pairwise_key is the expected pairwise PRNG key.
+// self_key is the expected self PRNG key.
+// enc_keys must be a std::vector<AesKey>, the vector of encryption keys.
+// Assumes 3-of-4 secret sharing.
+MATCHER_P3(ReconstructsCorrectly, pairwise_key, self_key, enc_keys, "") {
+ AesGcmEncryption decryptor;
+ std::vector<ShamirShare> pairwise_shares;
+ std::vector<ShamirShare> self_shares;
+ for (int i = 0; i < enc_keys.size(); ++i) {
+ // Blank shares must be blank in both places
+ if (arg->share_keys_response().encrypted_key_shares(i).empty()) {
+ pairwise_shares.push_back({""});
+ self_shares.push_back({""});
+ continue;
+ }
+ auto decrypted = decryptor.Decrypt(
+ enc_keys[i], arg->share_keys_response().encrypted_key_shares(i));
+ if (!decrypted.ok()) {
+ return false;
+ }
+ PairOfKeyShares key_shares;
+ if (!key_shares.ParseFromString(decrypted.value())) {
+ return false;
+ }
+ pairwise_shares.push_back({key_shares.noise_sk_share()});
+ self_shares.push_back({key_shares.prf_sk_share()});
+ }
+ // Reconstruct keys to see if they match
+ ShamirSecretSharing reconstructor;
+ std::string reconstructed_pairwise_key_string =
+ reconstructor.Reconstruct(3, pairwise_shares, EcdhPrivateKey::kSize)
+ .value();
+ std::string reconstructed_self_key_string =
+ reconstructor.Reconstruct(3, self_shares, AesKey::kSize).value();
+ EcdhPrivateKey reconstructed_pairwise_key(reinterpret_cast<const uint8_t*>(
+ reconstructed_pairwise_key_string.c_str()));
+ AesKey reconstructed_self_key(
+ reinterpret_cast<const uint8_t*>(reconstructed_self_key_string.c_str()));
+ return pairwise_key == reconstructed_pairwise_key &&
+ self_key == reconstructed_self_key;
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestIsHandledCorrectlyWhenNoClientsDie) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Make a copy of the encryption keys for testing
+ std::vector<AesKey> enc_keys;
+ auto enc_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(2))
+ .value();
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+
+ if (i == 1) {
+ enc_keys.push_back(AesKey());
+ } else {
+ enc_keys.push_back(
+ enc_key_agreement->ComputeSharedSecret(ecdh_keys.GetPublicKey(2 * i))
+ .value());
+ }
+ }
+
+ // Make a copy of the self PRNG key.
+ FakePrng prng;
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (int i = 0; i < AesKey::kSize; ++i) {
+ self_prng_key_buffer[i] = prng.Rand8();
+ }
+ auto self_prng_key = AesKey(self_prng_key_buffer);
+
+ EXPECT_CALL(*sender, Send(ReconstructsCorrectly(ecdh_keys.GetPrivateKey(3),
+ self_prng_key, enc_keys)))
+ .Times(1);
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_INPUT_SET"));
+}
+
+// A gMock matcher to see if the message sent by the client contains shares that
+// reconstruct to the right private keys.
+//
+// This version also takes into account the client's own shares, to ensure that
+// it works even when exactly threshold clients survive.
+//
+// pairwise_key is the expected pairwise PRNG key, an AesKey.
+// self_key is the expected self PRNG key, an AesKey.
+// own_pairwise_key_share is client 1's ShamirShare of its own pairwise PRNG
+// key.
+// own_self_key_share is client 1's ShamirShare of its own pairwise PRNG
+// key.
+// enc_keys must be a std::vector<AesKey>, the vector of encryption keys.
+// Assumes 3-of-4 secret sharing.
+MATCHER_P5(ReconstructsCorrectlyWithOwnKeys, pairwise_key, self_key,
+ own_pairwise_key_share, own_self_key_share, enc_keys, "") {
+ AesGcmEncryption decryptor;
+ std::vector<ShamirShare> pairwise_shares;
+ std::vector<ShamirShare> self_shares;
+ for (int i = 0; i < enc_keys.size(); ++i) {
+ // Blank shares must be blank in both places
+ if (arg->share_keys_response().encrypted_key_shares(i).empty()) {
+ if (i == 1) {
+ pairwise_shares.push_back(own_pairwise_key_share);
+ self_shares.push_back(own_self_key_share);
+ } else {
+ pairwise_shares.push_back({""});
+ self_shares.push_back({""});
+ }
+ continue;
+ }
+ auto decrypted = decryptor.Decrypt(
+ enc_keys[i], arg->share_keys_response().encrypted_key_shares(i));
+ if (!decrypted.ok()) {
+ return false;
+ }
+ PairOfKeyShares key_shares;
+ if (!key_shares.ParseFromString(decrypted.value())) {
+ return false;
+ }
+ pairwise_shares.push_back({key_shares.noise_sk_share()});
+ self_shares.push_back({key_shares.prf_sk_share()});
+ }
+ // Reconstruct keys to see if they match
+ ShamirSecretSharing reconstructor;
+ std::string reconstructed_pairwise_key_string =
+ reconstructor.Reconstruct(3, pairwise_shares, EcdhPrivateKey::kSize)
+ .value();
+ std::string reconstructed_self_key_string =
+ reconstructor.Reconstruct(3, self_shares, AesKey::kSize).value();
+ EcdhPrivateKey reconstructed_pairwise_key(reinterpret_cast<const uint8_t*>(
+ reconstructed_pairwise_key_string.c_str()));
+ AesKey reconstructed_self_key(
+ reinterpret_cast<const uint8_t*>(reconstructed_self_key_string.c_str()));
+ EXPECT_THAT(reconstructed_pairwise_key_string,
+ Eq(std::string(reinterpret_cast<const char*>(pairwise_key.data()),
+ EcdhPrivateKey::kSize)));
+ return pairwise_key == reconstructed_pairwise_key;
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestIsHandledCorrectlyWithDeadClient) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Client 3 has died in this round.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Make a copy of the encryption keys for testing
+ std::vector<AesKey> enc_keys;
+ auto enc_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(2))
+ .value();
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // exclude client 3
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+
+ if (i == 1) {
+ enc_keys.push_back(AesKey());
+ } else {
+ enc_keys.push_back(
+ enc_key_agreement->ComputeSharedSecret(ecdh_keys.GetPublicKey(2 * i))
+ .value());
+ }
+ }
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ enc_keys.push_back(AesKey());
+
+ // Make a copy of the self PRNG key.
+ FakePrng prng;
+ uint8_t self_prng_key_buffer[AesKey::kSize];
+ for (int i = 0; i < AesKey::kSize; ++i) {
+ self_prng_key_buffer[i] = prng.Rand8();
+ }
+ auto self_prng_key = AesKey(self_prng_key_buffer);
+
+ r1_state.SetUpShares(3, 4, ecdh_keys.GetPrivateKey(3), self_prng_key,
+ &r1_state.self_prng_key_shares_,
+ &r1_state.pairwise_prng_key_shares_);
+ ShamirShare own_pairwise_key_share = r1_state.pairwise_prng_key_shares_.at(1);
+ ShamirShare own_self_key_share = r1_state.self_prng_key_shares_.at(1);
+
+ EXPECT_CALL(*sender,
+ Send(ReconstructsCorrectlyWithOwnKeys(
+ ecdh_keys.GetPrivateKey(3), self_prng_key,
+ own_pairwise_key_share, own_self_key_share, enc_keys)))
+ .Times(1);
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_INPUT_SET"));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfTooManyDeadClients) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Clients 2 and 3 died, and we need 3 clients to continue, so we should
+ // abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 2; ++i) { // exclude clients 2 and 3.
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+
+ std::string error_string =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsWrongSizeKey) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // One of client 3's keys is a string of the wrong length, so this should
+ // cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(6));
+ bad_keypair->set_noise_pk("there's no way this is a valid key");
+
+ std::string error_string = "Invalid public key in request from server.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsInvalidKey) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // One of client 3's keys is a not a valid ECDH key, so this should cause an
+ // abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(6));
+ bad_keypair->set_noise_pk("Right size, but not an ECDH point");
+
+ std::string error_string = "Invalid public key in request from server.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsTooManyKeys) {
+ // In this test, the client under test is id 1, and it expects there to be no
+ // more than 3 clients. However, the server sends 4 keypairs. This should
+ // cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 3, // max_neighbors_expected
+ 2, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest received contains too many participants.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerSendsTooFewKeys) {
+ // In this test, the client under test is id 1, and the threshold is 3
+ // clients. However, the server sends only 2 keys. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 2; ++i) { // exclude clients 2 and 3
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest received does not contain enough participants.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerOmitsClientsKey) {
+ // In this test, the client under test is not represented at all in the
+ // server's message. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ continue; // skipping this client
+ }
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+
+ std::string error_string =
+ "The ShareKeysRequest sent by the server doesn't contain this client's "
+ "public keys.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR1ShareKeysInputSetStateTest,
+ ShareKeysRequestCausesAbortIfServerDuplicatesClientsKey) {
+ // In this test, the client under test is included twice in the server's
+ // message. This should cause an abort.
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR1ShareKeysInputSetState r1_state(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(2),
+ ecdh_keys.GetPublicKey(2))
+ .value()),
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<FakePrng>(),
+ std::move(EcdhKeyAgreement::CreateFromKeypair(ecdh_keys.GetPrivateKey(3),
+ ecdh_keys.GetPublicKey(3))
+ .value()),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ ServerToClientWrapperMessage message;
+ for (int i = 0; i < 3; ++i) { // handle client 3 separately
+ PairOfPublicKeys* keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ PairOfPublicKeys* bad_keypair =
+ message.mutable_share_keys_request()->add_pairs_of_public_keys();
+ bad_keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2));
+ bad_keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(3));
+
+ std::string error_string =
+ "Found this client's keys in the ShareKeysRequest twice somehow.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(message.share_keys_request()).data);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r1_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.cc
new file mode 100644
index 0000000..25b3950
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.cc
@@ -0,0 +1,218 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR2MaskedInputCollBaseState::SecAggClientR2MaskedInputCollBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ AsyncAbort* async_abort)
+ : SecAggClientAliveBaseState(std::move(sender),
+ std::move(transition_listener),
+ ClientState::R2_MASKED_INPUT, async_abort) {}
+
+SecAggClientR2MaskedInputCollBaseState::
+ ~SecAggClientR2MaskedInputCollBaseState() = default;
+
+std::unique_ptr<SecAggVectorMap>
+SecAggClientR2MaskedInputCollBaseState::HandleMaskedInputCollectionRequest(
+ const MaskedInputCollectionRequest& request, uint32_t client_id,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_clients,
+ const std::vector<AesKey>& other_client_enc_keys,
+ const std::vector<AesKey>& other_client_prng_keys,
+ const ShamirShare& own_self_key_share, const AesKey& self_prng_key,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ uint32_t* number_of_alive_clients,
+ std::vector<OtherClientState>* other_client_states,
+ std::vector<ShamirShare>* pairwise_key_shares,
+ std::vector<ShamirShare>* self_key_shares, std::string* error_message) {
+ if (request.encrypted_key_shares_size() !=
+ static_cast<int>(number_of_clients)) {
+ *error_message =
+ "The number of encrypted shares sent by the server does not match "
+ "the number of clients.";
+ return nullptr;
+ }
+
+ // Parse the request, decrypt and store the key shares from other clients.
+ AesGcmEncryption decryptor;
+ std::string plaintext;
+
+ for (int i = 0; i < static_cast<int>(number_of_clients); ++i) {
+ if (async_abort_ && async_abort_->Signalled()) {
+ *error_message = async_abort_->Message();
+ return nullptr;
+ }
+ if (i == static_cast<int>(client_id)) {
+ // this client
+ pairwise_key_shares->push_back({""}); // this will never be needed
+ self_key_shares->push_back(own_self_key_share);
+ } else if ((*other_client_states)[i] != OtherClientState::kAlive) {
+ if (request.encrypted_key_shares(i).length() > 0) {
+ // A client who was considered aborted sent key shares.
+ *error_message =
+ "Received encrypted key shares from an aborted client.";
+ return nullptr;
+ } else {
+ pairwise_key_shares->push_back({""});
+ self_key_shares->push_back({""});
+ }
+ } else if (request.encrypted_key_shares(i).length() == 0) {
+ // A client who was considered alive dropped out. Mark it as dead.
+ (*other_client_states)[i] = OtherClientState::kDeadAtRound2;
+ pairwise_key_shares->push_back({""});
+ self_key_shares->push_back({""});
+ --(*number_of_alive_clients);
+ } else {
+ // A living client sent encrypted key shares, so we decrypt and store
+ // them.
+ auto decrypted = decryptor.Decrypt(other_client_enc_keys[i],
+ request.encrypted_key_shares(i));
+ if (!decrypted.ok()) {
+ *error_message = "Authentication of encrypted data failed.";
+ return nullptr;
+ } else {
+ plaintext = decrypted.value();
+ }
+
+ PairOfKeyShares pairwise_and_self_key_shares;
+ if (!pairwise_and_self_key_shares.ParseFromString(plaintext)) {
+ *error_message = "Unable to parse decrypted pair of key shares.";
+ return nullptr;
+ }
+ pairwise_key_shares->push_back(
+ {pairwise_and_self_key_shares.noise_sk_share()});
+ self_key_shares->push_back({pairwise_and_self_key_shares.prf_sk_share()});
+ }
+ }
+
+ if (*number_of_alive_clients <
+ minimum_surviving_neighbors_for_reconstruction) {
+ *error_message =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ return nullptr;
+ }
+
+ // Compute the map of masks using the other clients' keys.
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+
+ prng_keys_to_add.push_back(self_prng_key);
+
+ for (int i = 0; i < static_cast<int>(number_of_clients); ++i) {
+ if (async_abort_ && async_abort_->Signalled()) {
+ *error_message = async_abort_->Message();
+ return nullptr;
+ }
+ if (i == static_cast<int>(client_id) ||
+ (*other_client_states)[i] != OtherClientState::kAlive) {
+ continue;
+ } else if (i < static_cast<int>(client_id)) {
+ prng_keys_to_add.push_back(other_client_prng_keys[i]);
+ } else {
+ prng_keys_to_subtract.push_back(other_client_prng_keys[i]);
+ }
+ }
+
+ std::unique_ptr<SecAggVectorMap> map =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, prng_factory, async_abort_);
+ if (!map) {
+ *error_message = async_abort_->Message();
+ return nullptr;
+ }
+ return map;
+}
+
+// TODO(team): Add two SecAggVector values more efficiently, without
+// having to unpack both vectors and convert the result back into the
+// packed form.
+SecAggVector AddSecAggVectors(SecAggVector v1, SecAggVector v2) {
+ FCP_CHECK(v1.modulus() == v2.modulus());
+ uint64_t modulus = v1.modulus();
+
+ // The code below moves v1 and v2 to temp instances to "consume" and destroy
+ // the original vectors as soon as possible in order to minimize the number of
+ // concurrent copies of the data in memory.
+ std::vector<uint64_t> vec1 = SecAggVector(std::move(v1)).GetAsUint64Vector();
+
+ {
+ // Keep vec2 scoped so that it is destroyed as soon as it is no longer used
+ // and before creating the SecAggVector instance below.
+ std::vector<uint64_t> vec2 =
+ SecAggVector(std::move(v2)).GetAsUint64Vector();
+
+ // Add the two vectors in place assigning the values back into vec1.
+ FCP_CHECK(vec1.size() == vec2.size());
+ for (int i = 0; i < static_cast<int>(vec1.size()); ++i) {
+ vec1[i] = ((vec1[i] + vec2[i]) % modulus);
+ }
+ }
+
+ return SecAggVector(vec1, modulus);
+}
+
+void SecAggClientR2MaskedInputCollBaseState::SendMaskedInput(
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<SecAggVectorMap> map_of_masks) {
+ ClientToServerWrapperMessage to_send;
+ for (auto& pair : *input_map) {
+ // SetInput should already have guaranteed these
+ FCP_CHECK(map_of_masks->find(pair.first) != map_of_masks->end());
+ SecAggVector& mask = map_of_masks->at(pair.first);
+ SecAggVector sum =
+ AddSecAggVectors(std::move(pair.second), std::move(mask));
+ MaskedInputVector sum_vec_proto;
+ sum_vec_proto.set_encoded_vector(std::move(sum).TakePackedBytes());
+ (*to_send.mutable_masked_input_response()->mutable_vectors())[pair.first] =
+ std::move(sum_vec_proto);
+ }
+ sender_->Send(&to_send);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h b/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h
new file mode 100644
index 0000000..9fefd35
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_BASE_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_BASE_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+// This is an abstract class which is the parent of three possible state classes
+// representing the states that the client may be in at Round 2: Masked Input
+// Collection. It should never be instantiated directly, but defines code that
+// will be used by multiple concrete Round 2 classes.
+
+class SecAggClientR2MaskedInputCollBaseState
+ : public SecAggClientAliveBaseState {
+ public:
+ ~SecAggClientR2MaskedInputCollBaseState() override;
+
+ protected:
+ // SecAggClientR2MaskedInputCollBaseState should never be instantiated
+ // directly.
+ explicit SecAggClientR2MaskedInputCollBaseState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ AsyncAbort* async_abort = nullptr);
+
+ // Handles the logic associated with receiving a MaskedInputCollectionRequest.
+ // Adds the recovered key shares to pairwise_key_shares and self_key_shares.
+ // Updates the other_client_states and number_of_alive_clients based on
+ // dropouts recorded in the request.
+ //
+ // The return value is computed map of masks if everything succeeed.
+ // If there was a failure, the return value is nullptr, and error_message is
+ // set to a non-empty string.
+ std::unique_ptr<SecAggVectorMap> HandleMaskedInputCollectionRequest(
+ const MaskedInputCollectionRequest& request, uint32_t client_id,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_clients,
+ const std::vector<AesKey>& other_client_enc_keys,
+ const std::vector<AesKey>& other_client_prng_keys,
+ const ShamirShare& own_self_key_share, const AesKey& self_prng_key,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ uint32_t* number_of_alive_clients,
+ std::vector<OtherClientState>* other_client_states,
+ std::vector<ShamirShare>* pairwise_key_shares,
+ std::vector<ShamirShare>* self_key_shares, std::string* error_message);
+
+ // Consumes a map of masks to the input map and sends the result of adding
+ // the two to the server.
+ void SendMaskedInput(std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<SecAggVectorMap> map_of_masks);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_BASE_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.cc
new file mode 100644
index 0000000..a3accf2
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.cc
@@ -0,0 +1,155 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR2MaskedInputCollInputNotSetState::
+ SecAggClientR2MaskedInputCollInputNotSetState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
+ std::unique_ptr<ShamirShare> own_self_key_share,
+ std::unique_ptr<AesKey> self_prng_key,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<SessionId> session_id,
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientR2MaskedInputCollBaseState(
+ std::move(sender), std::move(transition_listener), async_abort),
+ client_id_(client_id),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ number_of_alive_neighbors_(number_of_alive_neighbors),
+ number_of_neighbors_(number_of_neighbors),
+ input_vector_specs_(std::move(input_vector_specs)),
+ other_client_states_(std::move(other_client_states)),
+ other_client_enc_keys_(std::move(other_client_enc_keys)),
+ other_client_prng_keys_(std::move(other_client_prng_keys)),
+ own_self_key_share_(std::move(own_self_key_share)),
+ self_prng_key_(std::move(self_prng_key)),
+ session_id_(std::move(session_id)),
+ prng_factory_(std::move(prng_factory)) {
+ FCP_CHECK(client_id_ >= 0)
+ << "Client id must not be negative but was " << client_id_;
+}
+
+SecAggClientR2MaskedInputCollInputNotSetState::
+ ~SecAggClientR2MaskedInputCollInputNotSetState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR2MaskedInputCollInputNotSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages or masked input requests only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else if (!message.has_masked_input_request()) {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+
+ const MaskedInputCollectionRequest& request = message.masked_input_request();
+ std::string error_message;
+ auto pairwise_key_shares = std::make_unique<std::vector<ShamirShare> >();
+ auto self_key_shares = std::make_unique<std::vector<ShamirShare> >();
+
+ std::unique_ptr<SecAggVectorMap> map_of_masks =
+ HandleMaskedInputCollectionRequest(
+ request, client_id_, *input_vector_specs_,
+ minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
+ *other_client_enc_keys_, *other_client_prng_keys_,
+ *own_self_key_share_, *self_prng_key_, *session_id_, *prng_factory_,
+ &number_of_alive_neighbors_, other_client_states_.get(),
+ pairwise_key_shares.get(), self_key_shares.get(), &error_message);
+
+ if (!map_of_masks) {
+ return AbortAndNotifyServer(error_message);
+ }
+
+ return {std::make_unique<SecAggClientR2MaskedInputCollWaitingForInputState>(
+ client_id_, minimum_surviving_neighbors_for_reconstruction_,
+ number_of_alive_neighbors_, number_of_neighbors_,
+ std::move(input_vector_specs_), std::move(map_of_masks),
+ std::move(other_client_states_), std::move(pairwise_key_shares),
+ std::move(self_key_shares), std::move(sender_),
+ std::move(transition_listener_), async_abort_)};
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR2MaskedInputCollInputNotSetState::SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) {
+ if (!ValidateInput(*input_map, *input_vector_specs_)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "The input to SetInput does not match the "
+ "InputVectorSpecification.";
+ }
+
+ return {std::make_unique<SecAggClientR2MaskedInputCollInputSetState>(
+ client_id_, minimum_surviving_neighbors_for_reconstruction_,
+ number_of_alive_neighbors_, number_of_neighbors_, std::move(input_map),
+ std::move(input_vector_specs_), std::move(other_client_states_),
+ std::move(other_client_enc_keys_), std::move(other_client_prng_keys_),
+ std::move(own_self_key_share_), std::move(self_prng_key_),
+ std::move(sender_), std::move(transition_listener_),
+ std::move(session_id_), std::move(prng_factory_), async_abort_)};
+}
+
+std::string SecAggClientR2MaskedInputCollInputNotSetState::StateName() const {
+ return "R2_MASKED_INPUT_COLL_INPUT_NOT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h
new file mode 100644
index 0000000..7ecf6bf
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_NOT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_NOT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 2: Masked Input Collection state
+// with the input not yet set, and the message from the server not yet received.
+// This state should transition to one of the other two Round 2 states:
+// "Input Set" if SetInput is called first, or "Waiting For Input" if the
+// server's message is received first. It can also transition directly to the
+// Completed or Aborted states.
+
+class SecAggClientR2MaskedInputCollInputNotSetState
+ : public SecAggClientR2MaskedInputCollBaseState {
+ public:
+ SecAggClientR2MaskedInputCollInputNotSetState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
+ std::unique_ptr<ShamirShare> own_self_key_share,
+ std::unique_ptr<AesKey> self_prng_key,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<SessionId> session_id,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR2MaskedInputCollInputNotSetState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) override;
+
+ // Returns the name of this state, "R2_MASKED_INPUT_COLL_INPUT_NOT_SET".
+ std::string StateName() const override;
+
+ private:
+ const uint32_t client_id_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ uint32_t number_of_alive_neighbors_;
+ const uint32_t number_of_neighbors_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states_;
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys_;
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys_;
+ std::unique_ptr<ShamirShare> own_self_key_share_;
+ std::unique_ptr<AesKey> self_prng_key_;
+ std::unique_ptr<SessionId> session_id_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_NOT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state_test.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state_test.cc
new file mode 100644
index 0000000..3e8e8b4
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state_test.cc
@@ -0,0 +1,735 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+// For testing purposes, make an AesKey out of a string.
+AesKey MakeAesKey(const std::string& key) {
+ EXPECT_THAT(key.size(), Eq(AesKey::kSize));
+ return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
+}
+
+// Default test session_id.
+SessionId session_id = {"session id number, 32 bytes long."};
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest, IsAbortedReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ StartRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputTransitionsToInputSetStateWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_INPUT_SET"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfVectorIsWrongSize) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has too many elements.
+ input_map->insert(
+ std::make_pair("test", SecAggVector({5, 8, 22, 30, 7}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorIsTooLargeForBitWidth) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector's bit_width does not match the specified modulus of 32.
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 64)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputVectorHasWrongName) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has the wrong name.
+ input_map->insert(
+ std::make_pair("incorret", SecAggVector({5, 8, 22, 30}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooManyVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+ // This vector is extra.
+ input_map->insert(std::make_pair("test2", SecAggVector({4, 7, 21, 29}, 32)));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooFewVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ // Expects two vectors.
+ input_vector_specs->push_back(InputVectorSpecification("test2", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({5, 8, 22, 30}, 32)));
+ // Missing second vector.
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ MaskedInputCollectionRequestIsHandledCorrectlyWhenNoClientsDie) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ SecAggVectorMap input_map;
+ input_map.insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // These are strings because they're for putting into protocol buffers.
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 4; ++i) {
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ MaskedInputCollectionRequestIsHandledCorrectlyWithDeadClient) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ // Client 3 has died in this round.
+ SecAggVectorMap input_map;
+ input_map.insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 3; ++i) { // Exclude client 3.
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(),
+ Eq("R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputNotSetStateTest,
+ MaskedInputCollectionRequestCausesAbortIfTooManyDeadClients) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Clients 2 and 3 died, and we need 3 clients to continue, so we should
+ // abort.
+ SecAggVectorMap input_map;
+ input_map.insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputNotSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 2; ++i) { // Exclude clients 3 & 4
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+
+ std::string error_string =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.cc
new file mode 100644
index 0000000..73cdec0
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.cc
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR2MaskedInputCollInputSetState::
+ SecAggClientR2MaskedInputCollInputSetState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
+ std::unique_ptr<ShamirShare> own_self_key_share,
+ std::unique_ptr<AesKey> self_prng_key,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ std::unique_ptr<SessionId> session_id,
+ std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
+ : SecAggClientR2MaskedInputCollBaseState(
+ std::move(sender), std::move(transition_listener), async_abort),
+ client_id_(client_id),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ number_of_alive_neighbors_(number_of_alive_neighbors),
+ number_of_neighbors_(number_of_neighbors),
+ input_map_(std::move(input_map)),
+ input_vector_specs_(std::move(input_vector_specs)),
+ other_client_states_(std::move(other_client_states)),
+ other_client_enc_keys_(std::move(other_client_enc_keys)),
+ other_client_prng_keys_(std::move(other_client_prng_keys)),
+ own_self_key_share_(std::move(own_self_key_share)),
+ self_prng_key_(std::move(self_prng_key)),
+ session_id_(std::move(session_id)),
+ prng_factory_(std::move(prng_factory)) {
+ FCP_CHECK(client_id_ >= 0)
+ << "Client id must not be negative but was " << client_id_;
+}
+
+SecAggClientR2MaskedInputCollInputSetState::
+ ~SecAggClientR2MaskedInputCollInputSetState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR2MaskedInputCollInputSetState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages or masked input requests only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else if (!message.has_masked_input_request()) {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+
+ const MaskedInputCollectionRequest& request = message.masked_input_request();
+ std::string error_message;
+ auto pairwise_key_shares = std::make_unique<std::vector<ShamirShare> >();
+ auto self_key_shares = std::make_unique<std::vector<ShamirShare> >();
+
+ std::unique_ptr<SecAggVectorMap> map_of_masks =
+ HandleMaskedInputCollectionRequest(
+ request, client_id_, *input_vector_specs_,
+ minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
+ *other_client_enc_keys_, *other_client_prng_keys_,
+ *own_self_key_share_, *self_prng_key_, *session_id_, *prng_factory_,
+ &number_of_alive_neighbors_, other_client_states_.get(),
+ pairwise_key_shares.get(), self_key_shares.get(), &error_message);
+
+ if (!map_of_masks) {
+ return AbortAndNotifyServer(error_message);
+ }
+
+ SendMaskedInput(std::move(input_map_), std::move(map_of_masks));
+
+ return {std::make_unique<SecAggClientR3UnmaskingState>(
+ client_id_, number_of_alive_neighbors_,
+ minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
+ std::move(other_client_states_), std::move(pairwise_key_shares),
+ std::move(self_key_shares), std::move(sender_),
+ std::move(transition_listener_), async_abort_)};
+}
+
+std::string SecAggClientR2MaskedInputCollInputSetState::StateName() const {
+ return "R2_MASKED_INPUT_COLL_INPUT_SET";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h
new file mode 100644
index 0000000..93860ca
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_SET_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_SET_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 2: Masked Input Collection state
+// with the input already set. This state should transition to the
+// Round 3: Unmasking state, but can also transition directly to the Completed
+// or Aborted states.
+
+class SecAggClientR2MaskedInputCollInputSetState
+ : public SecAggClientR2MaskedInputCollBaseState {
+ public:
+ SecAggClientR2MaskedInputCollInputSetState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<SecAggVectorMap> input_map,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
+ std::unique_ptr<ShamirShare> own_self_key_share,
+ std::unique_ptr<AesKey> self_prng_key,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ std::unique_ptr<SessionId> session_id,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR2MaskedInputCollInputSetState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ // Returns the name of this state, "R2_MASKED_INPUT_COLL_INPUT_SET".
+ std::string StateName() const override;
+
+ private:
+ const uint32_t client_id_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ uint32_t number_of_alive_neighbors_;
+ const uint32_t number_of_neighbors_;
+ std::unique_ptr<SecAggVectorMap> input_map_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states_;
+ std::unique_ptr<std::vector<AesKey> > other_client_enc_keys_;
+ std::unique_ptr<std::vector<AesKey> > other_client_prng_keys_;
+ std::unique_ptr<ShamirShare> own_self_key_share_;
+ std::unique_ptr<AesKey> self_prng_key_;
+ std::unique_ptr<SessionId> session_id_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_INPUT_SET_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state_test.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state_test.cc
new file mode 100644
index 0000000..c59e917
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state_test.cc
@@ -0,0 +1,612 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+// For testing purposes, make an AesKey out of a string.
+AesKey MakeAesKey(const std::string& key) {
+ EXPECT_THAT(key.size(), Eq(AesKey::kSize));
+ return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
+}
+
+// Default test session_id.
+SessionId session_id = {"session id number, 32 bytes long."};
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest, IsAbortedReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest, StartRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ SetInputRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.SetInput(std::make_unique<SecAggVectorMap>()).ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(r2_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_map), std::move(input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("test 32 byte AES encryption key.")),
+ std::make_unique<std::vector<AesKey> >(
+ 6, MakeAesKey("other test 32 byte AES prng key.")),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ MaskedInputCollectionRequestIsHandledCorrectlyWhenNoClientsDie) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({2, 4, 6, 8}, 32));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_map),
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ // These are strings because they're for putting into protocol buffers.
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 4; ++i) {
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+
+ std::vector<AesKey> prng_keys_to_add = {
+ MakeAesKey("test 32 byte AES self prng key. "),
+ other_client_prng_keys[0]};
+ std::vector<AesKey> prng_keys_to_subtract = {other_client_prng_keys[2],
+ other_client_prng_keys[3]};
+
+ auto map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ std::vector<uint64_t> mask_vec = map_of_masks->at("test").GetAsUint64Vector();
+ std::vector<uint64_t> input_vec = {2, 4, 6, 8};
+ std::vector<uint64_t> sum_vec;
+ uint64_t bit_width_bound = 32;
+
+ for (int i = 0; i < 4; ++i) {
+ sum_vec.push_back((mask_vec[i] + input_vec[i]) % bit_width_bound);
+ }
+ MaskedInputVector sum_vec_proto;
+ sum_vec_proto.set_encoded_vector(
+ SecAggVector(sum_vec, 32).GetAsPackedBytes());
+ ClientToServerWrapperMessage expected_message;
+ (*expected_message.mutable_masked_input_response()
+ ->mutable_vectors())["test"] = sum_vec_proto;
+
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("R3_UNMASKING"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ MaskedInputCollectionRequestIsHandledCorrectlyWithDeadClient) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Client 3 has died in this round.
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_map),
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 3; ++i) { // Exclude client 3.
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+
+ std::vector<AesKey> prng_keys_to_add = {
+ MakeAesKey("test 32 byte AES self prng key. "),
+ other_client_prng_keys[0]};
+ std::vector<AesKey> prng_keys_to_subtract = {other_client_prng_keys[2]};
+
+ auto map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ std::vector<uint64_t> mask_vec = map_of_masks->at("test").GetAsUint64Vector();
+ std::vector<uint64_t> input_vec = {2, 4, 6, 8};
+ std::vector<uint64_t> sum_vec;
+ uint64_t bit_width_bound = 32;
+
+ for (int i = 0; i < 4; ++i) {
+ sum_vec.push_back((mask_vec[i] + input_vec[i]) % bit_width_bound);
+ }
+ MaskedInputVector sum_vec_proto;
+ sum_vec_proto.set_encoded_vector(
+ SecAggVector(sum_vec, 32).GetAsPackedBytes());
+ ClientToServerWrapperMessage expected_message;
+ (*expected_message.mutable_masked_input_response()
+ ->mutable_vectors())["test"] = sum_vec_proto;
+
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("R3_UNMASKING"));
+}
+
+TEST(SecAggClientR2MaskedInputCollInputSetStateTest,
+ MaskedInputCollectionRequestCausesAbortIfTooManyDeadClients) {
+ // In this test, the client under test is id 1, and there are 4 clients.
+ // Clients 3 and 4 died, and we need 3 clients to continue, so we should
+ // abort.
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->insert(std::make_pair("test", SecAggVector({2, 4, 6, 8}, 32)));
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<AesKey> enc_keys = {
+ MakeAesKey("other client encryption key 0000"),
+ MakeAesKey("other client encryption key 1111"),
+ MakeAesKey("other client encryption key 2222"),
+ MakeAesKey("other client encryption key 3333")};
+ std::vector<AesKey> other_client_prng_keys = {
+ MakeAesKey("other client pairwise prng key 0"), AesKey(),
+ MakeAesKey("other client pairwise prng key 2"),
+ MakeAesKey("other client pairwise prng key 3")};
+ SecAggClientR2MaskedInputCollInputSetState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_map),
+ std::make_unique<std::vector<InputVectorSpecification> >(
+ input_vector_specs),
+ std::make_unique<std::vector<OtherClientState> >(
+ 4, OtherClientState::kAlive),
+ std::make_unique<std::vector<AesKey> >(enc_keys),
+ std::make_unique<std::vector<AesKey> >(other_client_prng_keys),
+ std::make_unique<ShamirShare>(),
+ std::make_unique<AesKey>(MakeAesKey("test 32 byte AES self prng key. ")),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<SessionId>(session_id),
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::vector<std::string> expected_self_key_shares = {
+ "shared self prng key for client #000", "",
+ "shared self prng key for client #222",
+ "shared self prng key for client #333"};
+ std::vector<std::string> expected_pairwise_key_shares = {
+ "shared pairwise prng key for client0", "",
+ "shared pairwise prng key for client2",
+ "shared pairwise prng key for client3"};
+
+ ServerToClientWrapperMessage message;
+ AesGcmEncryption encryptor;
+ for (int i = 0; i < 2; ++i) { // Exclude clients 3 & 4
+ PairOfKeyShares key_shares_pair;
+ key_shares_pair.set_noise_sk_share(expected_pairwise_key_shares[i]);
+ key_shares_pair.set_prf_sk_share(expected_self_key_shares[i]);
+ std::string serialized_pair = key_shares_pair.SerializeAsString();
+ std::string ciphertext =
+ encryptor.Encrypt(enc_keys[i], serialized_pair);
+ message.mutable_masked_input_request()->add_encrypted_key_shares(
+ ciphertext);
+ }
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+ message.mutable_masked_input_request()->add_encrypted_key_shares("");
+
+ std::string error_string =
+ "There are not enough clients to complete this protocol session. "
+ "Aborting.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(message);
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.cc
new file mode 100644
index 0000000..6103fc6
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.cc
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR2MaskedInputCollWaitingForInputState::
+ SecAggClientR2MaskedInputCollWaitingForInputState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecAggVectorMap> map_of_masks,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares,
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ AsyncAbort* async_abort)
+ : SecAggClientR2MaskedInputCollBaseState(
+ std::move(sender), std::move(transition_listener), async_abort),
+ client_id_(client_id),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ number_of_alive_neighbors_(number_of_alive_neighbors),
+ number_of_neighbors_(number_of_neighbors),
+ input_vector_specs_(std::move(input_vector_specs)),
+ map_of_masks_(std::move(map_of_masks)),
+ other_client_states_(std::move(other_client_states)),
+ pairwise_key_shares_(std::move(pairwise_key_shares)),
+ self_key_shares_(std::move(self_key_shares)) {
+ FCP_CHECK(client_id_ >= 0)
+ << "Client id must not be negative but was " << client_id_;
+}
+
+SecAggClientR2MaskedInputCollWaitingForInputState::
+ ~SecAggClientR2MaskedInputCollWaitingForInputState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR2MaskedInputCollWaitingForInputState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR2MaskedInputCollWaitingForInputState::SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) {
+ // Only need to do 3 things: Validate input, send message to server, and
+ // return the new state.
+ if (!ValidateInput(*input_map, *input_vector_specs_)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "The input to SetInput does not match the "
+ "InputVectorSpecification.";
+ }
+
+ SendMaskedInput(std::move(input_map), std::move(map_of_masks_));
+
+ return {std::make_unique<SecAggClientR3UnmaskingState>(
+ client_id_, number_of_alive_neighbors_,
+ minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
+ std::move(other_client_states_), std::move(pairwise_key_shares_),
+ std::move(self_key_shares_), std::move(sender_),
+ std::move(transition_listener_), async_abort_)};
+}
+
+std::string SecAggClientR2MaskedInputCollWaitingForInputState::StateName()
+ const {
+ return "R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h
new file mode 100644
index 0000000..c167d60
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "absl/container/node_hash_map.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 2: Masked Input Collection state
+// where the client has already received a message from the server, but is
+// waiting for the external protocol to set the client's input before replying.
+
+// This state should transition to the Round 3: Unmasking state, but can also
+// transition directly to the Completed or Aborted states.
+
+class SecAggClientR2MaskedInputCollWaitingForInputState
+ : public SecAggClientR2MaskedInputCollBaseState {
+ public:
+ SecAggClientR2MaskedInputCollWaitingForInputState(
+ uint32_t client_id,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<InputVectorSpecification> >
+ input_vector_specs,
+ std::unique_ptr<SecAggVectorMap> map_of_masks,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares,
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR2MaskedInputCollWaitingForInputState() override;
+
+ // This state handles only abort/early success messages. All others raise an
+ // error status.
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) override;
+
+ // Returns the name of this state, "R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT".
+ ABSL_MUST_USE_RESULT std::string StateName() const override;
+
+ private:
+ const uint32_t client_id_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ uint32_t number_of_alive_neighbors_;
+ const uint32_t number_of_neighbors_;
+ std::unique_ptr<std::vector<InputVectorSpecification> > input_vector_specs_;
+ std::unique_ptr<SecAggVectorMap> map_of_masks_;
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states_;
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares_;
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R2_MASKED_INPUT_COLL_WAITING_FOR_INPUT_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state_test.cc b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state_test.cc
new file mode 100644
index 0000000..4c686e8
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state_test.cc
@@ -0,0 +1,482 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+
+static const ShamirShare test_pairwise_share = {"test pairwise share"};
+static const ShamirShare test_self_share = {"test self share"};
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ IsAbortedReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_THAT(r2_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_THAT(r2_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ StartRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_THAT(r2_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ ErrorMessageRaisesErrorStatus) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_THAT(r2_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ Eq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({1, 2, 3, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_alive_neighbors
+ 6, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("COMPLETED"));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputCausesClientResponseAndRound3Transition) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+
+ std::vector<uint64_t> sum_vec = {1, 16, 0, 2};
+
+ MaskedInputVector sum_vec_proto;
+ sum_vec_proto.set_encoded_vector(
+ SecAggVector(sum_vec, 32).GetAsPackedBytes());
+ ClientToServerWrapperMessage expected_message;
+ (*expected_message.mutable_masked_input_response()
+ ->mutable_vectors())["test"] = sum_vec_proto;
+
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), Eq("R3_UNMASKING"));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputRaisesErrorStatusIfInputVectorIsWrongSize) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has too many elements.
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30, 7}, 32));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputRaisesErrorStatusIfInputVectorIsTooLargeForBitWidth) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector's bit_width does not match the specified modulus of 32.
+ input_map->emplace("test", SecAggVector({5, 8, 22, 40}, 64));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputRaisesErrorStatusIfInputVectorHasWrongName) {
+ // In this test, the client under test is id 1, and there are 4 clients, all
+ // alive.
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ // This vector has the wrong name.
+ input_map->emplace("incorrect", SecAggVector({5, 8, 22, 30}, 32));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooManyVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+ // This vector is extra.
+ input_map->emplace("test2", SecAggVector({4, 7, 21, 29}, 32));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+TEST(SecAggClientR2MaskedInputCollWaitingForInputStateTest,
+ SetInputRaisesErrorStatusIfInputHasTooFewVectors) {
+ auto input_vector_specs =
+ std::make_unique<std::vector<InputVectorSpecification> >();
+ input_vector_specs->push_back(InputVectorSpecification("test", 4, 32));
+ // Expects two vectors.
+ input_vector_specs->push_back(InputVectorSpecification("test2", 4, 32));
+ auto map_of_masks = std::make_unique<SecAggVectorMap>();
+ map_of_masks->emplace("test", SecAggVector({28, 8, 10, 4}, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR2MaskedInputCollWaitingForInputState r2_state(
+ 1, // client_id
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ 4, // number_of_alive_neighbors
+ 4, // number_of_neighbors
+ std::move(input_vector_specs), std::move(map_of_masks),
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+ // Missing second vector.
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r2_state.SetInput(std::move(input_map));
+ EXPECT_THAT(new_state.ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r3_unmasking_state.cc b/fcp/secagg/client/secagg_client_r3_unmasking_state.cc
new file mode 100644
index 0000000..e461b21
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r3_unmasking_state.cc
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientR3UnmaskingState::SecAggClientR3UnmaskingState(
+ uint32_t client_id, uint32_t number_of_alive_neighbors,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares,
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ AsyncAbort* async_abort)
+ : SecAggClientAliveBaseState(std::move(sender),
+ std::move(transition_listener),
+ ClientState::R3_UNMASKING, async_abort),
+ client_id_(client_id),
+ number_of_alive_neighbors_(number_of_alive_neighbors),
+ minimum_surviving_neighbors_for_reconstruction_(
+ minimum_surviving_neighbors_for_reconstruction),
+ number_of_neighbors_(number_of_neighbors),
+ other_client_states_(std::move(other_client_states)),
+ pairwise_key_shares_(std::move(pairwise_key_shares)),
+ self_key_shares_(std::move(self_key_shares)) {
+ FCP_CHECK(client_id_ >= 0)
+ << "Client id must not be negative but was " << client_id_;
+}
+
+SecAggClientR3UnmaskingState::~SecAggClientR3UnmaskingState() = default;
+
+StatusOr<std::unique_ptr<SecAggClientState> >
+SecAggClientR3UnmaskingState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ // Handle abort messages or unmasking requests only.
+ if (message.has_abort()) {
+ if (message.abort().early_success()) {
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+ } else {
+ return {std::make_unique<SecAggClientAbortedState>(
+ "Aborting because of abort message from the server.",
+ std::move(sender_), std::move(transition_listener_))};
+ }
+ } else if (!message.has_unmasking_request()) {
+ // Returns an error indicating that the message is of invalid type.
+ return SecAggClientState::HandleMessage(message);
+ }
+ if (async_abort_ && async_abort_->Signalled())
+ return AbortAndNotifyServer(async_abort_->Message());
+
+ const UnmaskingRequest& request = message.unmasking_request();
+ std::set<uint32_t> dead_at_round_3_client_ids;
+
+ // Parse incoming request and mark dead clients as dead.
+ for (uint32_t i : request.dead_3_client_ids()) {
+ // TODO(team): Remove this once backwards compatibility not needed.
+ uint32_t id = i - 1;
+ if (id == client_id_) {
+ return AbortAndNotifyServer(
+ "The received UnmaskingRequest states this client has aborted, but "
+ "this client had not yet aborted.");
+ } else if (id >= number_of_neighbors_) {
+ return AbortAndNotifyServer(
+ "The received UnmaskingRequest contains a client id that does not "
+ "correspond to any client.");
+ }
+ switch ((*other_client_states_)[id]) {
+ case OtherClientState::kAlive:
+ (*other_client_states_)[id] = OtherClientState::kDeadAtRound3;
+ --number_of_alive_neighbors_;
+ break;
+ case OtherClientState::kDeadAtRound3:
+ return AbortAndNotifyServer(
+ "The received UnmaskingRequest repeated a client more than once "
+ "as a dead client.");
+ break;
+ case OtherClientState::kDeadAtRound1:
+ case OtherClientState::kDeadAtRound2:
+ default:
+ return AbortAndNotifyServer(
+ "The received UnmaskingRequest considers a client dead in round 3 "
+ "that was already considered dead.");
+ break;
+ }
+ }
+
+ if (number_of_alive_neighbors_ <
+ minimum_surviving_neighbors_for_reconstruction_) {
+ return AbortAndNotifyServer(
+ "Not enough clients survived. The server should not have sent this "
+ "UnmaskingRequest.");
+ }
+
+ /*
+ * Construct a response for the server by choosing the appropriate shares for
+ * each client (i.e. the pairwise share if the client died at round 3, the
+ * self share if the client is alive, or no shares at all if the client died
+ * at or before round 2.
+ */
+ ClientToServerWrapperMessage message_to_server;
+ UnmaskingResponse* unmasking_response =
+ message_to_server.mutable_unmasking_response();
+ for (uint32_t i = 0; i < number_of_neighbors_; ++i) {
+ if (async_abort_ && async_abort_->Signalled())
+ return AbortAndNotifyServer(async_abort_->Message());
+ switch ((*other_client_states_)[i]) {
+ case OtherClientState::kAlive:
+ unmasking_response->add_noise_or_prf_key_shares()->set_prf_sk_share(
+ (*self_key_shares_)[i].data);
+ break;
+ case OtherClientState::kDeadAtRound3:
+ unmasking_response->add_noise_or_prf_key_shares()->set_noise_sk_share(
+ (*pairwise_key_shares_)[i].data);
+ break;
+ case OtherClientState::kDeadAtRound1:
+ case OtherClientState::kDeadAtRound2:
+ default:
+ unmasking_response->add_noise_or_prf_key_shares();
+ break;
+ }
+ }
+
+ // Send this final message to the server, then enter Completed state.
+ sender_->Send(&message_to_server);
+ return {std::make_unique<SecAggClientCompletedState>(
+ std::move(sender_), std::move(transition_listener_))};
+}
+
+std::string SecAggClientR3UnmaskingState::StateName() const {
+ return "R3_UNMASKING";
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_r3_unmasking_state.h b/fcp/secagg/client/secagg_client_r3_unmasking_state.h
new file mode 100644
index 0000000..2ba848a
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r3_unmasking_state.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_R3_UNMASKING_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_R3_UNMASKING_STATE_H_
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/other_client_state.h"
+#include "fcp/secagg/client/secagg_client_alive_base_state.h"
+#include "fcp/secagg/client/secagg_client_state.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents the client's Round 3: Unmasking state. This state
+// should transition to the Completed state, but can also transition to the
+// Aborted state.
+
+class SecAggClientR3UnmaskingState : public SecAggClientAliveBaseState {
+ public:
+ SecAggClientR3UnmaskingState(
+ uint32_t client_id, uint32_t number_of_alive_neighbors,
+ uint32_t minimum_surviving_neighbors_for_reconstruction,
+ uint32_t number_of_neighbors,
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states,
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares,
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares,
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+
+ AsyncAbort* async_abort = nullptr);
+
+ ~SecAggClientR3UnmaskingState() override;
+
+ StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message) override;
+
+ // Returns the name of this state, "R3_UNMASKING".
+ std::string StateName() const override;
+
+ private:
+ const uint32_t client_id_;
+ uint32_t number_of_alive_neighbors_;
+ const uint32_t minimum_surviving_neighbors_for_reconstruction_;
+ const uint32_t number_of_neighbors_;
+ std::unique_ptr<std::vector<OtherClientState> > other_client_states_;
+ std::unique_ptr<std::vector<ShamirShare> > pairwise_key_shares_;
+ std::unique_ptr<std::vector<ShamirShare> > self_key_shares_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_R3_UNMASKING_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_r3_unmasking_state_test.cc b/fcp/secagg/client/secagg_client_r3_unmasking_state_test.cc
new file mode 100644
index 0000000..2369890
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_r3_unmasking_state_test.cc
@@ -0,0 +1,480 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/secagg_client_aborted_state.h"
+#include "fcp/secagg/client/secagg_client_completed_state.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Pointee;
+using ::testing::StrEq;
+
+static const ShamirShare test_pairwise_share = {"test pairwise share"};
+static const ShamirShare test_self_share = {"test self share"};
+
+TEST(SecAggClientR3UnmaskingStateTest, IsAbortedReturnsFalse) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ EXPECT_THAT(r3_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ EXPECT_THAT(r3_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest, StartRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ EXPECT_THAT(r3_state.Start().ok(), Eq(false));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest, SetInputRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ EXPECT_THAT(r3_state.SetInput(std::make_unique<SecAggVectorMap>()).ok(),
+ Eq(false));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest, ErrorMessageRaisesErrorStatus) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ EXPECT_THAT(r3_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ AbortReturnsValidAbortStateAndNotifiesServer) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ SecAggClientR3UnmaskingState r3_state(
+ 0, 4, 4, 4, std::make_unique<std::vector<OtherClientState> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::make_unique<std::vector<ShamirShare> >(4),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.Abort("Abort reason");
+ ASSERT_THAT(new_state.ok(), Eq(true));
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ AbortFailureMessageCausesAbortWithoutNotifyingServer) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(),
+ StrEq("Aborting because of abort message from the server."));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ EarlySuccessMessageCausesTransitionToCompletedState) {
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ EXPECT_CALL(*sender, Send(::testing::_)).Times(0);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(abort_message);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("COMPLETED"));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestIsCorrectlyHandledWhenNoClientsDie) {
+ // In this test, this is client id 1. There are 6 clients, and none of them
+ // drop out.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ ClientToServerWrapperMessage expected_message;
+ for (int i = 0; i < 6; i++) {
+ expected_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares()
+ ->set_prf_sk_share("test self share");
+ }
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ request.mutable_unmasking_request();
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("COMPLETED"));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestIsCorrectlyHandledWhenFewClientsDie) {
+ // In this test, this is client id 1. Client 3 already died at round 2, and
+ // client 5 dies in round 3.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<OtherClientState> other_clients_states{
+ OtherClientState::kAlive, OtherClientState::kAlive,
+ OtherClientState::kAlive, OtherClientState::kDeadAtRound2,
+ OtherClientState::kAlive, OtherClientState::kAlive};
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 5, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(other_clients_states),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ ClientToServerWrapperMessage expected_message;
+ for (int i = 0; i < 6; i++) {
+ if (i == 3) {
+ expected_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ } else if (i == 5) {
+ expected_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares()
+ ->set_noise_sk_share("test pairwise share");
+ } else {
+ expected_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares()
+ ->set_prf_sk_share("test self share");
+ }
+ }
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 6 -> 5 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(6);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("COMPLETED"));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestCausesAbortWhenTooManyClientsDie) {
+ // In this test, this is client id 1. Client 3 already died at round 2, and
+ // clients 4 and 5 die in round 3. This should cause a transition to an abort
+ // state and an abort message to be sent to the server.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<OtherClientState> other_clients_states{
+ OtherClientState::kAlive, OtherClientState::kAlive,
+ OtherClientState::kAlive, OtherClientState::kDeadAtRound2,
+ OtherClientState::kAlive, OtherClientState::kAlive};
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 5, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(other_clients_states),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "Not enough clients survived. The server should not have sent this "
+ "UnmaskingRequest.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 5 -> 4 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(5);
+ // TODO(team): 6 -> 5 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(6);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestCausesAbortIfServerListsThisClientAsDead) {
+ // In this test, this is client id 1, but the server lists client 1 as dead.
+ // This should cause a transition to an abort state and an abort message to be
+ // sent to the server.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "The received UnmaskingRequest states this client has aborted, but this "
+ "client had not yet aborted.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 2 -> 1 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(2);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestCausesAbortIfServerListsNonexistentClientAsDead) {
+ // In this test, there are 6 clients (labeled 0-5), but the server lists
+ // client 6 as dead. This should cause a transition to an abort state and an
+ // abort message to be sent to the server.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "The received UnmaskingRequest contains a client id that does not "
+ "correspond to any client.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 7 -> 6 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(7);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestCausesAbortIfServerListsClientThatAlreadyDied) {
+ // In this test, client 3 died at round 1, but the server lists client 3 as
+ // dead at round 3. This should cause a transition to an abort state and an
+ // abort message to be sent to the server.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ std::vector<OtherClientState> other_clients_states{
+ OtherClientState::kAlive, OtherClientState::kAlive,
+ OtherClientState::kAlive, OtherClientState::kDeadAtRound1,
+ OtherClientState::kAlive, OtherClientState::kAlive};
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 5, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(other_clients_states),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "The received UnmaskingRequest considers a client dead in round 3 "
+ "that was already considered dead.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(4);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+TEST(SecAggClientR3UnmaskingStateTest,
+ UnmaskingRequestCausesAbortIfServerListsSameClientTwice) {
+ // In this test, the server lists client 5 as dead at round 3 twice. This
+ // should cause a transition to an abort state and an abort message to be sent
+ // to the server.
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClientR3UnmaskingState r3_state(
+ 1, // client_id
+ 6, // number_of_alive_neighbors
+ 4, // minimum_surviving_neighbors_for_reconstruction
+ 6, // number_of_neighbors
+ std::make_unique<std::vector<OtherClientState> >(
+ 6, OtherClientState::kAlive),
+ std::make_unique<std::vector<ShamirShare> >(6, test_pairwise_share),
+ std::make_unique<std::vector<ShamirShare> >(6, test_self_share),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener));
+
+ std::string error_string =
+ "The received UnmaskingRequest repeated a client more than once as a "
+ "dead client.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ ServerToClientWrapperMessage request;
+ // TODO(team): 6 -> 5 below, once backwards compatibility not needed.
+ request.mutable_unmasking_request()->add_dead_3_client_ids(6);
+ request.mutable_unmasking_request()->add_dead_3_client_ids(6);
+ StatusOr<std::unique_ptr<SecAggClientState> > new_state =
+ r3_state.HandleMessage(request);
+ ASSERT_TRUE(new_state.ok());
+ EXPECT_THAT(new_state.value()->StateName(), StrEq("ABORTED"));
+ EXPECT_THAT(new_state.value()->ErrorMessage().value(), StrEq(error_string));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_state.cc b/fcp/secagg/client/secagg_client_state.cc
new file mode 100644
index 0000000..310e803
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_state.cc
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client_state.h"
+
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+// The methods implemented here should be overridden by state classes from
+// which these transitions are valid, and inherited by state classes from which
+// they are invalid. For example, only round 0 classes should override the Start
+// method.
+//
+// Classes that return booleans should only be overridden by state classes for
+// which they will return true.
+
+namespace fcp {
+namespace secagg {
+
+SecAggClientState::SecAggClientState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ ClientState state)
+ : sender_(std::move(sender)),
+ transition_listener_(std::move(transition_listener)),
+ state_(state) {
+ transition_listener_->Transition(state_);
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::Start() {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "An illegal start transition was attempted from state "
+ << StateName();
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::HandleMessage(
+ const ServerToClientWrapperMessage& message) {
+ if (message.message_content_case() ==
+ ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Client received a message of unknown type but was in state "
+ << StateName();
+ } else {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Client received a message of type "
+ << message.message_content_case() << " but was in state "
+ << StateName();
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "An illegal input transition was attempted from state "
+ << StateName();
+}
+
+StatusOr<std::unique_ptr<SecAggClientState> > SecAggClientState::Abort(
+ const std::string& reason) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The client was already in terminal state " << StateName()
+ << " but received an abort with message: " << reason;
+}
+
+bool SecAggClientState::IsAborted() const { return false; }
+
+bool SecAggClientState::IsCompletedSuccessfully() const { return false; }
+
+StatusOr<std::string> SecAggClientState::ErrorMessage() const {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Error message requested, but client is in state " << StateName();
+}
+
+bool SecAggClientState::ValidateInput(
+ const SecAggVectorMap& input_map,
+ const std::vector<InputVectorSpecification>& input_vector_specs) {
+ if (input_map.size() != input_vector_specs.size()) {
+ return false;
+ }
+ for (const auto& vector_spec : input_vector_specs) {
+ auto input_vec = input_map.find(vector_spec.name());
+ if (input_vec == input_map.end() ||
+ input_vec->second.modulus() != vector_spec.modulus() ||
+ input_vec->second.num_elements() != vector_spec.length()) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/secagg_client_state.h b/fcp/secagg/client/secagg_client_state.h
new file mode 100644
index 0000000..1a1515a
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_state.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SECAGG_CLIENT_STATE_H_
+#define FCP_SECAGG_CLIENT_SECAGG_CLIENT_STATE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+// This is an abstract class which is the parent of the other SecAggClient*State
+// classes. It should not be instantiated directly. Default versions of all the
+// methods declared here are provided for use by states which do not expect, and
+// therefore do not implement, those methods.
+
+class SecAggClientState {
+ public:
+ // Initiates the protocol by computing its first message and sending it to the
+ // server. If called in a valid state, returns the new State. Otherwise,
+ // returns an error Status with code PRECONDITION_FAILED.
+ virtual StatusOr<std::unique_ptr<SecAggClientState> > Start();
+
+ // Handles the received message in a way consistent with the current state.
+ // If called from a state expecting a message, returns the new State. If the
+ // message was of the right type but had invalid contents, the new State will
+ // be a SecAggClientAbortState.
+ // If the state was not expecting a message of this type at all, returns an
+ // error Status with code PRECONDITION_FAILED.
+ virtual StatusOr<std::unique_ptr<SecAggClientState> > HandleMessage(
+ const ServerToClientWrapperMessage& message);
+
+ // Sets the input of this client for this protocol session. If successful,
+ // returns the new state. If the input does not match the specification,
+ // returns an error Status with code INVALID_ARGUMENT.
+ // If the client's state was not ready for an input to be set, returns an
+ // error Status with code PRECONDITION_FAILED.
+ virtual StatusOr<std::unique_ptr<SecAggClientState> > SetInput(
+ std::unique_ptr<SecAggVectorMap> input_map);
+
+ // Aborts the protocol for the specified reason. Returns the new state. If the
+ // protocol was already aborted or completed, instead returns an error Status
+ // with code PRECONDITION_FAILED.
+ virtual StatusOr<std::unique_ptr<SecAggClientState> > Abort(
+ const std::string& reason);
+
+ // Returns true if the current state is Abort, false else.
+ ABSL_MUST_USE_RESULT virtual bool IsAborted() const;
+
+ // Returns true if the current state is ProtocolCompleted, false else.
+ ABSL_MUST_USE_RESULT virtual bool IsCompletedSuccessfully() const;
+
+ // Returns the error message, if the current state is an abort state. If not,
+ // returns an error Status with code PRECONDITION_FAILED.
+ ABSL_MUST_USE_RESULT virtual StatusOr<std::string> ErrorMessage() const;
+
+ // Returns the name of the current state, as a string.
+ ABSL_MUST_USE_RESULT virtual std::string StateName() const = 0;
+
+ virtual ~SecAggClientState() = default;
+
+ protected:
+ // The object that sends messages to the server.
+ std::unique_ptr<SendToServerInterface> sender_;
+ // A listener for state transitions.
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener_;
+ // State type.
+ ClientState state_;
+
+ // SecAggClientState should never be instantiated directly.
+ SecAggClientState(
+ std::unique_ptr<SendToServerInterface> sender,
+ std::unique_ptr<StateTransitionListenerInterface> transition_listener,
+ ClientState state);
+
+ // Validates an input map by returning true if all SecAggVectors match their
+ // corresponding InputVectorSpecifications, and false otherwise.
+ bool ValidateInput(
+ const SecAggVectorMap& input_map,
+ const std::vector<InputVectorSpecification>& input_vector_specs);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SECAGG_CLIENT_STATE_H_
diff --git a/fcp/secagg/client/secagg_client_test.cc b/fcp/secagg/client/secagg_client_test.cc
new file mode 100644
index 0000000..4956212
--- /dev/null
+++ b/fcp/secagg/client/secagg_client_test.cc
@@ -0,0 +1,262 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/client/secagg_client.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/mock_send_to_server_interface.h"
+#include "fcp/secagg/testing/mock_state_transition_listener.h"
+#include "fcp/testing/testing.h"
+
+// All of the actual client functionality is contained within the
+// SecAggClient*State classes. This class only tests very basic functionality
+// of the containing SecAggClient class.
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Pointee;
+
+TEST(SecAggClientTest, ConstructedWithCorrectState) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(client.IsAborted(), Eq(false));
+ EXPECT_THAT(client.IsCompletedSuccessfully(), Eq(false));
+ EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_NOT_SET"));
+}
+
+TEST(SecAggClientTest, StartCausesStateTransition) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Message correctness is checked in the tests for the Round 0 classes.
+ EXPECT_CALL(*sender, Send(::testing::_));
+ Status result = client.Start();
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(client.IsAborted(), Eq(false));
+ EXPECT_THAT(client.IsCompletedSuccessfully(), Eq(false));
+ EXPECT_THAT(client.State(), Eq("R1_SHARE_KEYS_INPUT_NOT_SET"));
+}
+
+TEST(SecAggClientTest, ReceiveMessageReturnValuesAreCorrect) {
+ // The actual behavior of the client upon receipt of messages is tested in the
+ // state class test files; here we just check that ReceiveMessage returns
+ // values correctly.
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ // Get the client into a state where it can receive a message.
+ ClientToServerWrapperMessage round_0_client_message;
+ EXPECT_CALL(*sender, Send(_))
+ .WillOnce(::testing::SaveArgPointee<0>(&round_0_client_message));
+ EXPECT_THAT(client.Start(), IsOk());
+
+ ServerToClientWrapperMessage round_1_message;
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* keypair = round_1_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ if (i == 1) {
+ *keypair = round_0_client_message.advertise_keys().pair_of_public_keys();
+ } else {
+ keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
+ keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
+ }
+ }
+ round_1_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(round_1_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, Send(_));
+
+ // A valid message from the server should return true if it can continue.
+ StatusOr<bool> result = client.ReceiveMessage(round_1_message);
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(result.value(), Eq(true));
+
+ // An abort message from the server should return false.
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ result = client.ReceiveMessage(abort_message);
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(result.value(), Eq(false));
+
+ // Any other message after abort should raise an error.
+ result = client.ReceiveMessage(abort_message);
+ EXPECT_THAT(result.ok(), Eq(false));
+}
+
+TEST(SecAggClientTest, AbortMovesToCorrectStateAndSendsMessageToServer) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <Abort reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ Status result = client.Abort("Abort reason");
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(client.State(), Eq("ABORTED"));
+ EXPECT_THAT(client.ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientTest,
+ AbortWithNoMessageMovesToCorrectStateAndSendsMessageToServer) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ std::string error_string =
+ "Abort upon external request for reason <unknown reason>.";
+ ClientToServerWrapperMessage expected_message;
+ expected_message.mutable_abort()->set_diagnostic_info(error_string);
+ EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_message))));
+
+ Status result = client.Abort();
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(client.State(), Eq("ABORTED"));
+ EXPECT_THAT(client.ErrorMessage().value(), Eq(error_string));
+}
+
+TEST(SecAggClientTest, ErrorMessageRaisesErrorStatusIfNotAborted) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ EXPECT_THAT(client.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggClientTest, SetInputChangesStateOnlyOnce) {
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
+ MockSendToServerInterface* sender = new MockSendToServerInterface();
+ MockStateTransitionListener* transition_listener =
+ new MockStateTransitionListener();
+
+ SecAggClient client(
+ 4, // max_neighbors_expected
+ 3, // minimum_surviving_neighbors_for_reconstruction
+ input_vector_specs, std::make_unique<FakePrng>(),
+ std::unique_ptr<SendToServerInterface>(sender),
+ std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
+
+ std::make_unique<AesCtrPrngFactory>());
+
+ auto input_map = std::make_unique<SecAggVectorMap>();
+ input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+
+ Status result = client.SetInput(std::move(input_map));
+ EXPECT_THAT(result.code(), Eq(OK));
+
+ auto input_map2 = std::make_unique<SecAggVectorMap>();
+ input_map2->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
+ result = client.SetInput(std::move(input_map));
+ EXPECT_THAT(result.code(), Eq(FAILED_PRECONDITION));
+ EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_SET"));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/client/send_to_server_interface.h b/fcp/secagg/client/send_to_server_interface.h
new file mode 100644
index 0000000..e9cf935
--- /dev/null
+++ b/fcp/secagg/client/send_to_server_interface.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_SEND_TO_SERVER_INTERFACE_H_
+#define FCP_SECAGG_CLIENT_SEND_TO_SERVER_INTERFACE_H_
+
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// Used to provide a SecAggClient with a private and authenticated channel with
+// the server, which can be used to send protocol buffer messages.
+
+class SendToServerInterface {
+ public:
+ // Note: For efficiency, contents may be Swap()'d to default values. In other
+ // words, consider message to have been "moved from" when Send returns.
+ virtual void Send(ClientToServerWrapperMessage* message) = 0;
+
+ virtual ~SendToServerInterface() = default;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_SEND_TO_SERVER_INTERFACE_H_
diff --git a/fcp/secagg/client/state_transition_listener_interface.h b/fcp/secagg/client/state_transition_listener_interface.h
new file mode 100644
index 0000000..143638a
--- /dev/null
+++ b/fcp/secagg/client/state_transition_listener_interface.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_CLIENT_STATE_TRANSITION_LISTENER_INTERFACE_H_
+#define FCP_SECAGG_CLIENT_STATE_TRANSITION_LISTENER_INTERFACE_H_
+
+#include <cstdint>
+
+namespace fcp {
+namespace secagg {
+
+enum class ClientState : int {
+ INITIAL = 0,
+ R0_ADVERTISE_KEYS = 1,
+ R1_SHARE_KEYS = 2,
+ R2_MASKED_INPUT = 3,
+ R3_UNMASKING = 4,
+ COMPLETED = 5,
+ ABORTED = 6,
+};
+
+// Listens for state transition messages.
+//
+// The expected call pattern in the successful case is the following:
+// - Transition(R0_ADVERTISE_KEYS)
+// - Started(R0_ADVERTISE_KEYS)
+// - Stopped(R0_ADVERTISE_KEYS)
+// - Transition(R1_SHARE_KEYS)
+// - Started(R1_SHARE_KEYS)
+// - Stopped(R1_SHARE_KEYS)
+// - Transition(R2_MASKED_INPUT)
+// ...
+// - Transition(COMPLETED)
+//
+// It is also possible to have more than one pair of Started and Stopped calls
+// for any given state.
+//
+// If the protocol gets aborted at any point, Transition(ABORTED) would be
+// called and any remaining Started and Stopped calls would be skipped.
+class StateTransitionListenerInterface {
+ public:
+ // Called on transition to a new state.
+ virtual void Transition(ClientState new_state) = 0;
+ // Called just before a state starts computation, excluding any idle or
+ // waiting time.
+ virtual void Started(ClientState state) = 0;
+ // Called just after a state stops computation, excluding any idle or
+ // waiting time, or sending messages to to the server.
+ virtual void Stopped(ClientState state) = 0;
+ virtual void set_execution_session_id(int64_t execution_session_id) = 0;
+
+ virtual ~StateTransitionListenerInterface() = default;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_CLIENT_STATE_TRANSITION_LISTENER_INTERFACE_H_
diff --git a/fcp/secagg/server/BUILD b/fcp/secagg/server/BUILD
new file mode 100644
index 0000000..e72ec05
--- /dev/null
+++ b/fcp/secagg/server/BUILD
@@ -0,0 +1,360 @@
+# Description:
+# SecAgg server-specific components.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("//fcp/tracing:build_defs.bzl", "tracing_schema_cc_library")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+tracing_schema_cc_library(
+ name = "tracing_schema",
+ srcs = ["tracing_schema.fbs"],
+)
+
+cc_library(
+ name = "experiments_interface",
+ hdrs = ["experiments_interface.h"],
+ copts = FCP_COPTS,
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+cc_library(
+ name = "experiments_names",
+ hdrs = ["experiments_names.h"],
+ copts = FCP_COPTS,
+)
+
+cc_library(
+ name = "secagg_server_protocol_impl",
+ srcs = ["secagg_server_protocol_impl.cc"],
+ hdrs = ["secagg_server_protocol_impl.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":experiments_interface",
+ ":secagg_scheduler",
+ ":secagg_server_metrics_listener",
+ ":secret_sharing_graph",
+ ":send_to_clients_interface",
+ ":server_cc_proto",
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/tracing",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "secagg_server_metrics_listener",
+ hdrs = [
+ "secagg_server_metrics_listener.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":server_cc_proto",
+ "//fcp/secagg/shared:cc_proto",
+ ],
+)
+
+cc_library(
+ name = "send_to_clients_interface",
+ hdrs = [
+ "send_to_clients_interface.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/secagg/shared:cc_proto",
+ ],
+)
+
+cc_library(
+ name = "server",
+ srcs = [
+ "secagg_server.cc",
+ "secagg_server_aborted_state.cc",
+ "secagg_server_completed_state.cc",
+ "secagg_server_prng_running_state.cc",
+ "secagg_server_r0_advertise_keys_state.cc",
+ "secagg_server_r1_share_keys_state.cc",
+ "secagg_server_r2_masked_input_coll_state.cc",
+ "secagg_server_r3_unmasking_state.cc",
+ "secagg_server_state.cc",
+ "secagg_trace_utility.cc",
+ ],
+ hdrs = [
+ "secagg_server.h",
+ "secagg_server_aborted_state.h",
+ "secagg_server_completed_state.h",
+ "secagg_server_prng_running_state.h",
+ "secagg_server_r0_advertise_keys_state.h",
+ "secagg_server_r1_share_keys_state.h",
+ "secagg_server_r2_masked_input_coll_state.h",
+ "secagg_server_r3_unmasking_state.h",
+ "secagg_server_state.h",
+ "secagg_trace_utility.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":experiments_interface",
+ ":experiments_names",
+ ":graph_parameter_finder",
+ ":secagg_scheduler",
+ ":secagg_server_metrics_listener",
+ ":secagg_server_protocol_impl",
+ ":secret_sharing_graph",
+ ":secret_sharing_graph_factory",
+ ":send_to_clients_interface",
+ ":server_cc_proto",
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/secagg/server/aes",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/tracing",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:node_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "secret_sharing_graph_factory",
+ hdrs = ["secret_sharing_graph_factory.h"],
+ deps = [
+ ":secret_sharing_complete_graph",
+ ":secret_sharing_graph",
+ ":secret_sharing_harary_graph",
+ "//fcp/base",
+ ],
+)
+
+cc_library(
+ name = "secret_sharing_graph",
+ hdrs = ["secret_sharing_graph.h"],
+ deps = ["//fcp/base"],
+)
+
+cc_library(
+ name = "secret_sharing_complete_graph",
+ hdrs = ["secret_sharing_complete_graph.h"],
+ deps = [
+ ":secret_sharing_graph",
+ "//fcp/base",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "secret_sharing_harary_graph",
+ srcs = [
+ "secret_sharing_harary_graph.cc",
+ "ssl_bit_gen.cc",
+ ],
+ hdrs = [
+ "secret_sharing_harary_graph.h",
+ "ssl_bit_gen.h",
+ ],
+ deps = [
+ ":secret_sharing_graph",
+ "//fcp/base",
+ "@boringssl//:crypto",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "graph_parameter_finder",
+ srcs = ["graph_parameter_finder.cc"],
+ hdrs = [
+ "graph_parameter_finder.h",
+ ],
+ deps = [
+ ":distribution_utilities",
+ ":server_cc_proto",
+ "//fcp/base",
+ ],
+)
+
+cc_library(
+ name = "distribution_utilities",
+ srcs = ["distribution_utilities.cc"],
+ hdrs = [
+ "distribution_utilities.h",
+ ],
+ deps = ["//fcp/base"],
+)
+
+cc_test(
+ name = "distribution_utilities_test",
+ srcs = ["distribution_utilities_test.cc"],
+ deps = [
+ ":distribution_utilities",
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "server-test",
+ size = "small",
+ srcs = [
+ "secagg_server_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":server",
+ ":server_cc_proto",
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/secagg/testing:common_mocks",
+ "//fcp/secagg/testing/server:experiments",
+ "//fcp/secagg/testing/server:server_mocks",
+ "//fcp/testing",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "state-test",
+ size = "small",
+ srcs = [
+ "secagg_server_aborted_state_test.cc",
+ "secagg_server_completed_state_test.cc",
+ "secagg_server_prng_running_state_test.cc",
+ "secagg_server_r0_advertise_keys_state_test.cc",
+ "secagg_server_r1_share_keys_state_test.cc",
+ "secagg_server_r2_masked_input_coll_state_test.cc",
+ "secagg_server_r3_unmasking_state_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":experiments_interface",
+ ":experiments_names",
+ ":secagg_scheduler",
+ ":secret_sharing_graph_factory",
+ ":send_to_clients_interface",
+ ":server",
+ ":server_cc_proto",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/secagg/server/aes",
+ "//fcp/secagg/shared",
+ "//fcp/secagg/shared:cc_proto",
+ "//fcp/secagg/testing",
+ "//fcp/secagg/testing:common_mocks",
+ "//fcp/secagg/testing/server:async_runner",
+ "//fcp/secagg/testing/server:experiments",
+ "//fcp/secagg/testing/server:server_mocks",
+ "//fcp/testing",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/container:node_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "secret_sharing_harary_graph_test",
+ srcs = ["secret_sharing_harary_graph_test.cc"],
+ deps = [
+ ":secret_sharing_graph",
+ ":secret_sharing_graph_factory",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "secret_sharing_complete_graph_test",
+ srcs = ["secret_sharing_complete_graph_test.cc"],
+ deps = [
+ ":secret_sharing_graph",
+ ":secret_sharing_graph_factory",
+ "//fcp/testing",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "graph_parameter_finder_test",
+ srcs = ["graph_parameter_finder_test.cc"],
+ deps = [
+ ":graph_parameter_finder",
+ ":server_cc_proto",
+ "//fcp/testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "secagg_scheduler",
+ srcs = [
+ "secagg_scheduler.cc",
+ ],
+ hdrs = [
+ "secagg_scheduler.h",
+ ],
+ deps = [
+ "//fcp/base",
+ "//fcp/base:clock",
+ "//fcp/base:reentrancy_guard",
+ "//fcp/base:scheduler",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_test(
+ name = "secagg_scheduler_test",
+ srcs = ["secagg_scheduler_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":secagg_scheduler",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/base:simulated_clock",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+proto_library(
+ name = "server_proto",
+ srcs = [
+ "secagg_server_enums.proto",
+ "secagg_server_messages.proto",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+java_proto_library(
+ name = "server_java_proto",
+ deps = [":server_proto"],
+)
+
+cc_proto_library(
+ name = "server_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":server_proto"],
+)
diff --git a/fcp/secagg/server/aes/BUILD b/fcp/secagg/server/aes/BUILD
new file mode 100644
index 0000000..099434e
--- /dev/null
+++ b/fcp/secagg/server/aes/BUILD
@@ -0,0 +1,32 @@
+# Description:
+# AES SecAgg server protocol implementation.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+cc_library(
+ name = "aes",
+ srcs = [
+ "aes_secagg_server_protocol_impl.cc",
+ ],
+ hdrs = [
+ "aes_secagg_server_protocol_impl.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base",
+ "//fcp/secagg/server:experiments_names",
+ "//fcp/secagg/server:secagg_scheduler",
+ "//fcp/secagg/server:secagg_server_protocol_impl",
+ "//fcp/secagg/server:server_cc_proto",
+ "//fcp/secagg/server:tracing_schema",
+ "//fcp/secagg/shared",
+ "//fcp/tracing",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/status",
+ ],
+)
diff --git a/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.cc b/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.cc
new file mode 100644
index 0000000..50c28e1
--- /dev/null
+++ b/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.cc
@@ -0,0 +1,223 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/status/status.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/experiments_names.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace {
+
+std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> AddReduce(
+ std::vector<std::unique_ptr<fcp::secagg::SecAggVectorMap>> vector_of_maps) {
+ FCP_CHECK(!vector_of_maps.empty());
+ // Initialize result
+ auto result = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>(
+ *vector_of_maps[0]);
+ // Reduce vector of maps
+ for (int i = 1; i < vector_of_maps.size(); ++i) {
+ result->Add(*vector_of_maps[i]);
+ }
+ return result;
+}
+
+// Initializes a SecAggUnpackedVectorMap object according to a provided input
+// vector specification
+std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> InitializeVectorMap(
+ const std::vector<fcp::secagg::InputVectorSpecification>&
+ input_vector_specs) {
+ auto vector_map = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>();
+ for (const fcp::secagg::InputVectorSpecification& vector_spec :
+ input_vector_specs) {
+ vector_map->emplace(vector_spec.name(),
+ fcp::secagg::SecAggUnpackedVector(
+ vector_spec.length(), vector_spec.modulus()));
+ }
+ return vector_map;
+}
+
+} // namespace
+
+namespace fcp {
+namespace secagg {
+
+// The number of keys included in a single PRNG job.
+static constexpr int kPrngBatchSize = 32;
+
+std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
+AesSecAggServerProtocolImpl::SetupMaskedInputCollection() {
+ if (!experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
+ // Prepare the sum of masked input vectors with all zeroes.
+ masked_input_ = InitializeVectorMap(input_vector_specs());
+ } else {
+ auto initial_value = InitializeVectorMap(input_vector_specs());
+ masked_input_accumulator_ =
+ scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
+ std::move(initial_value), SecAggUnpackedVectorMap::AddMaps);
+ }
+ return masked_input_accumulator_;
+}
+
+std::vector<std::unique_ptr<SecAggVectorMap>>
+AesSecAggServerProtocolImpl::TakeMaskedInputQueue() {
+ absl::MutexLock lock(&mutex_);
+ return std::move(masked_input_queue_);
+}
+
+Status AesSecAggServerProtocolImpl::HandleMaskedInputCollectionResponse(
+ std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) {
+ FCP_CHECK(masked_input_response);
+ // Make sure the received vectors match the specification.
+ if (masked_input_response->vectors().size() != input_vector_specs().size()) {
+ return ::absl::InvalidArgumentError(
+ "Masked input does not match input vector specification - "
+ "wrong number of vectors.");
+ }
+ auto& input_vectors = *masked_input_response->mutable_vectors();
+ auto checked_masked_vectors = std::make_unique<SecAggVectorMap>();
+ for (const InputVectorSpecification& vector_spec : input_vector_specs()) {
+ auto masked_vector = input_vectors.find(vector_spec.name());
+ if (masked_vector == input_vectors.end()) {
+ return ::absl::InvalidArgumentError(
+ "Masked input does not match input vector specification - wrong "
+ "vector names.");
+ }
+ // TODO(team): This does not appear to be properly covered by unit
+ // tests.
+ int bit_width = SecAggVector::GetBitWidth(vector_spec.modulus());
+ if (masked_vector->second.encoded_vector().size() !=
+ DivideRoundUp(vector_spec.length() * bit_width, 8)) {
+ return ::absl::InvalidArgumentError(
+ "Masked input does not match input vector specification - vector is "
+ "wrong size.");
+ }
+ checked_masked_vectors->emplace(
+ vector_spec.name(),
+ SecAggVector(std::move(*masked_vector->second.mutable_encoded_vector()),
+ vector_spec.modulus(), vector_spec.length()));
+ }
+
+ if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
+ // If async processing is enabled we queue the client message. Moreover, if
+ // the queue we found was empty this means that it has been taken by an
+ // asynchronous aggregation task. In that case, we schedule an aggregation
+ // task to process the queue that we just initiated, which will happen
+ // eventually.
+ size_t is_queue_empty;
+ {
+ absl::MutexLock lock(&mutex_);
+ is_queue_empty = masked_input_queue_.empty();
+ masked_input_queue_.emplace_back(std::move(checked_masked_vectors));
+ }
+ if (is_queue_empty) {
+ // TODO(team): Abort should handle the situation where `this` has
+ // been destructed while the schedule task is still not running, and
+ // message_queue_ can't be moved.
+ Trace<Round2AsyncWorkScheduled>();
+ masked_input_accumulator_->Schedule([&] {
+ auto queue = TakeMaskedInputQueue();
+ Trace<Round2MessageQueueTaken>(queue.size());
+ return AddReduce(std::move(queue));
+ });
+ }
+ } else {
+ // Sequential processing
+ FCP_CHECK(masked_input_);
+ masked_input_->Add(*checked_masked_vectors);
+ }
+
+ return ::absl::OkStatus();
+}
+
+void AesSecAggServerProtocolImpl::FinalizeMaskedInputCollection() {
+ if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
+ FCP_CHECK(masked_input_accumulator_->IsIdle());
+ masked_input_ = masked_input_accumulator_->GetResultAndCancel();
+ }
+}
+
+CancellationToken AesSecAggServerProtocolImpl::StartPrng(
+ const PrngWorkItems& work_items,
+ std::function<void(Status)> done_callback) {
+ FCP_CHECK(done_callback);
+ FCP_CHECK(masked_input_);
+ auto generators =
+ std::vector<std::function<std::unique_ptr<SecAggUnpackedVectorMap>()>>();
+
+ // Break the keys to add or subtract into vectors of size kPrngBatchSize (or
+ // less for the last one) and schedule them as tasks.
+ for (auto it = work_items.prng_keys_to_add.begin();
+ it < work_items.prng_keys_to_add.end(); it += kPrngBatchSize) {
+ std::vector<AesKey> batch_prng_keys_to_add;
+ std::copy(it,
+ std::min(it + kPrngBatchSize, work_items.prng_keys_to_add.end()),
+ std::back_inserter(batch_prng_keys_to_add));
+ generators.emplace_back([=]() {
+ return UnpackedMapOfMasks(batch_prng_keys_to_add, std::vector<AesKey>(),
+ input_vector_specs(), session_id(),
+ *prng_factory());
+ });
+ }
+
+ for (auto it = work_items.prng_keys_to_subtract.begin();
+ it < work_items.prng_keys_to_subtract.end(); it += kPrngBatchSize) {
+ std::vector<AesKey> batch_prng_keys_to_subtract;
+ std::copy(
+ it,
+ std::min(it + kPrngBatchSize, work_items.prng_keys_to_subtract.end()),
+ std::back_inserter(batch_prng_keys_to_subtract));
+ generators.emplace_back([=]() {
+ return UnpackedMapOfMasks(
+ std::vector<AesKey>(), batch_prng_keys_to_subtract,
+ input_vector_specs(), session_id(), *prng_factory());
+ });
+ }
+
+ auto accumulator = scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
+ std::move(masked_input_), SecAggUnpackedVectorMap::AddMaps);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ accumulator->SetAsyncObserver([=, accumulator = accumulator.get()]() {
+ auto unpacked_map = accumulator->GetResultAndCancel();
+ auto packed_map = std::make_unique<SecAggVectorMap>();
+ for (auto& entry : *unpacked_map) {
+ uint64_t modulus = entry.second.modulus();
+ packed_map->emplace(entry.first,
+ SecAggVector(std::move(entry.second), modulus));
+ }
+ SetResult(std::move(packed_map));
+ done_callback(absl::OkStatus());
+ });
+ return accumulator;
+}
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h b/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h
new file mode 100644
index 0000000..feaf7fd
--- /dev/null
+++ b/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
+#define FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+class AesSecAggServerProtocolImpl
+ : public SecAggServerProtocolImpl,
+ public std::enable_shared_from_this<AesSecAggServerProtocolImpl> {
+ public:
+ AesSecAggServerProtocolImpl(
+ std::unique_ptr<SecretSharingGraph> graph,
+ int minimum_number_of_clients_to_proceed,
+ std::vector<InputVectorSpecification> input_vector_specs,
+ std::unique_ptr<SecAggServerMetricsListener> metrics,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ SendToClientsInterface* sender,
+ std::unique_ptr<SecAggScheduler> scheduler,
+ std::vector<ClientStatus> client_statuses, ServerVariant server_variant,
+ std::unique_ptr<ExperimentsInterface> experiments = nullptr)
+ : SecAggServerProtocolImpl(
+ std::move(graph), minimum_number_of_clients_to_proceed,
+ std::move(metrics), std::move(prng_factory), sender,
+ std::move(scheduler), std::move(client_statuses),
+ std::move(experiments)),
+ server_variant_(server_variant),
+ input_vector_specs_(std::move(input_vector_specs)) {}
+
+ ServerVariant server_variant() const override { return server_variant_; }
+
+ // Returns one InputVectorSpecification for each input vector which the
+ // protocol will aggregate.
+ inline const std::vector<InputVectorSpecification>& input_vector_specs()
+ const {
+ return input_vector_specs_;
+ }
+
+ Status InitializeShareKeysRequest(ShareKeysRequest* request) const override {
+ return ::absl::OkStatus();
+ }
+
+ // TODO(team): Remove this method. This field must be set from
+ // inside the protocol implementation.
+ void set_masked_input(std::unique_ptr<SecAggUnpackedVectorMap> masked_input) {
+ masked_input_ = std::move(masked_input);
+ }
+
+ // Takes out ownership the accumulated queue of masked inputs and empties
+ // the current queue.
+ std::vector<std::unique_ptr<SecAggVectorMap>> TakeMaskedInputQueue();
+
+ std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
+ SetupMaskedInputCollection() override;
+
+ void FinalizeMaskedInputCollection() override;
+
+ Status HandleMaskedInputCollectionResponse(
+ std::unique_ptr<MaskedInputCollectionResponse> masked_input_response)
+ override;
+
+ CancellationToken StartPrng(
+ const PrngWorkItems& work_items,
+ std::function<void(Status)> done_callback) override;
+
+ private:
+ std::unique_ptr<SecAggUnpackedVectorMap> masked_input_;
+ // Protects masked_input_queue_.
+ absl::Mutex mutex_;
+ std::vector<std::unique_ptr<SecAggVectorMap>> masked_input_queue_
+ ABSL_GUARDED_BY(mutex_);
+ std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
+ masked_input_accumulator_;
+ ServerVariant server_variant_;
+ std::vector<InputVectorSpecification> input_vector_specs_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
diff --git a/fcp/secagg/server/distribution_utilities.cc b/fcp/secagg/server/distribution_utilities.cc
new file mode 100644
index 0000000..2b46457
--- /dev/null
+++ b/fcp/secagg/server/distribution_utilities.cc
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/distribution_utilities.h"
+
+#include <cmath>
+#include <iostream>
+#include <memory>
+
+namespace fcp {
+namespace secagg {
+
+StatusOr<std::unique_ptr<HypergeometricDistribution>>
+HypergeometricDistribution::Create(int total, int marked, int sampled) {
+ if (total < 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The population should be at least zero. Value provided = "
+ << total;
+ }
+ if (marked < 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The marked population should have size at least zero. Value "
+ "provided = "
+ << marked;
+ }
+ if (sampled < 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The sample size should be at least zero. Value provided = "
+ << sampled;
+ }
+ if (marked > total) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The marked population " << marked
+ << " should not exceed the total population " << total;
+ }
+ if (sampled > total) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The sample size " << sampled
+ << " should not exceed the total population " << total;
+ }
+ return std::unique_ptr<HypergeometricDistribution>(
+ new HypergeometricDistribution(total, marked, sampled));
+}
+
+double HypergeometricDistribution::PMF(double x) { return PMFImpl(x, marked_); }
+
+double HypergeometricDistribution::PMFImpl(double x, int counted) {
+ if (x < 0 || x > sampled_ || x > counted) return 0;
+ if (total_ + x < counted + sampled_) return 0;
+ double lpmf = std::lgamma(sampled_ + 1) + std::lgamma(counted + 1) +
+ std::lgamma(total_ - counted + 1) +
+ std::lgamma(total_ - sampled_ + 1) - std::lgamma(x + 1) -
+ std::lgamma(sampled_ - x + 1) - std::lgamma(counted - x + 1) -
+ std::lgamma(total_ + 1) -
+ std::lgamma(total_ - sampled_ - counted + x + 1);
+ return std::exp(lpmf);
+}
+
+double HypergeometricDistribution::CDF(double x) {
+ x = std::floor(x);
+ double mean = marked_ * static_cast<double>(sampled_) / total_;
+ if (x > mean) {
+ return 1 - CDFImpl(sampled_ - x - 1, total_ - marked_);
+ } else {
+ return CDFImpl(x, marked_);
+ }
+}
+
+double HypergeometricDistribution::CDFImpl(double x, int counted) {
+ double current_pmf = PMFImpl(x, counted);
+ double result = 0;
+ while (current_pmf > result * 1e-16) {
+ result += current_pmf;
+ current_pmf *= x;
+ current_pmf *= total_ - counted - sampled_ + x;
+ current_pmf /= counted - x + 1;
+ current_pmf /= sampled_ - x + 1;
+ --x;
+ }
+ return result;
+}
+
+double HypergeometricDistribution::FindQuantile(double quantile,
+ bool complement) {
+ if (quantile > 0.5) {
+ quantile = 1 - quantile;
+ complement = !complement;
+ }
+ if (complement) {
+ return sampled_ - FindQuantileImpl(quantile, total_ - marked_) - 1;
+ } else {
+ return FindQuantileImpl(quantile, marked_);
+ }
+}
+
+double HypergeometricDistribution::FindQuantileImpl(double quantile,
+ int counted) {
+ double basic_bound = counted + sampled_ - total_ - 1;
+ // An inverted tail bound gives a lower bound on the result
+ double fancy_bound =
+ sampled_ * (static_cast<double>(counted) / total_ -
+ std::sqrt(-std::log(quantile) / (2 * sampled_)));
+ double result = -1;
+ if (basic_bound > result) result = basic_bound;
+ if (fancy_bound > result) result = fancy_bound;
+ result = std::floor(result);
+
+ double current_cdf = CDFImpl(result, counted);
+ double current_pmf = PMFImpl(result, counted);
+ while (current_cdf < quantile && result < sampled_) {
+ if (current_pmf > 0) {
+ current_pmf /= result + 1;
+ current_pmf /= total_ - counted - sampled_ + result + 1;
+ current_pmf *= counted - result;
+ current_pmf *= sampled_ - result;
+ } else {
+ current_pmf = PMFImpl(result + 1, counted);
+ }
+ current_cdf += current_pmf;
+ ++result;
+ }
+ --result;
+ return result;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/distribution_utilities.h b/fcp/secagg/server/distribution_utilities.h
new file mode 100644
index 0000000..ca128af
--- /dev/null
+++ b/fcp/secagg/server/distribution_utilities.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
+#define FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
+
+#include <memory>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace secagg {
+
+// Represents a Hypergeometric distribution with parameters fixed at creation of
+// the object. Allows to query certain distribution functions.
+class HypergeometricDistribution {
+ public:
+ static StatusOr<std::unique_ptr<HypergeometricDistribution>> Create(
+ int total, int marked, int sampled);
+
+ // Evaluates the probability mass funciton of the random variable at x.
+ double PMF(double x);
+
+ // Evaluates the cumulative distribution function of the random variable at x.
+ double CDF(double x);
+
+ // Finds the value whose cdf is quantile rounded outwards to an integer.
+ // Setting complement to true is equivalent to setting quantile = 1 - quantile
+ // but can avoid numerical error in the extreme upper tail.
+ double FindQuantile(double quantile, bool complement = false);
+
+ private:
+ const int total_;
+ const int marked_;
+ const int sampled_;
+
+ HypergeometricDistribution(int total, int marked, int sampled)
+ : total_(total), marked_(marked), sampled_(sampled) {}
+
+ double PMFImpl(double x, int counted);
+
+ double CDFImpl(double x, int counted);
+
+ double FindQuantileImpl(double quantile, int counted);
+};
+
+} // namespace secagg
+} // namespace fcp
+#endif // FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
diff --git a/fcp/secagg/server/distribution_utilities_test.cc b/fcp/secagg/server/distribution_utilities_test.cc
new file mode 100644
index 0000000..5a1fa25
--- /dev/null
+++ b/fcp/secagg/server/distribution_utilities_test.cc
@@ -0,0 +1,166 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/distribution_utilities.h"
+
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+struct HypergeometricCDForPMFInstance {
+ const double x;
+ const int total;
+ const int marked;
+ const int sampled;
+ const double probability;
+};
+
+struct HypergeometricQuantileInstance {
+ const double probability;
+ const int total;
+ const int marked;
+ const int sampled;
+ const int lower;
+ const int upper;
+};
+
+class HypergeometricPMF
+ : public ::testing::TestWithParam<HypergeometricCDForPMFInstance> {};
+
+class HypergeometricCDF
+ : public ::testing::TestWithParam<HypergeometricCDForPMFInstance> {};
+
+class HypergeometricQuantile
+ : public ::testing::TestWithParam<HypergeometricQuantileInstance> {};
+
+TEST(HypergeometricDistributionCreate, RejectsInvalidInputs) {
+ ASSERT_FALSE(HypergeometricDistribution::Create(10, 11, 5).ok());
+ ASSERT_FALSE(HypergeometricDistribution::Create(10, 5, 11).ok());
+ ASSERT_FALSE(HypergeometricDistribution::Create(10, -1, 5).ok());
+ ASSERT_FALSE(HypergeometricDistribution::Create(10, 5, -1).ok());
+ ASSERT_FALSE(HypergeometricDistribution::Create(-10, 5, 5).ok());
+ ASSERT_FALSE(HypergeometricDistribution::Create(-10, -5, -5).ok());
+}
+
+TEST_P(HypergeometricPMF, ReturnsPrecomputedValues) {
+ const HypergeometricCDForPMFInstance& test_params = GetParam();
+ FCP_LOG(INFO) << "Testing hypergeometric pmf with x = " << test_params.x
+ << " total = " << test_params.total
+ << " marked = " << test_params.marked
+ << " sampled = " << test_params.sampled << ".";
+ auto p = HypergeometricDistribution::Create(
+ test_params.total, test_params.marked, test_params.sampled);
+ ASSERT_THAT(p, IsOk());
+ double result = p.value()->PMF(test_params.x);
+ double relative_error =
+ abs(result - test_params.probability) / (test_params.probability + 1e-30);
+ EXPECT_LT(relative_error, 1e-9);
+ FCP_LOG(INFO) << "result = " << result
+ << " expected_result = " << test_params.probability
+ << " relative_error" << relative_error;
+}
+
+INSTANTIATE_TEST_SUITE_P(HypergeometricPMFTests, HypergeometricPMF,
+ ::testing::ValuesIn<HypergeometricCDForPMFInstance>(
+ {{-5, 9, 3, 3, 0.0},
+ {17, 9, 3, 3, 0.0},
+ {0, 10, 0, 5, 1.0},
+ {3, 10, 10, 5, 0.0},
+ {4, 15, 6, 12, 0.2967032967032967},
+ {38, 98, 63, 17, 0.0},
+ {2, 187, 105, 43, 5.423847289689941e-16},
+ {40, 980, 392, 103, 0.08225792329713294},
+ {89, 1489, 312, 370, 0.014089199026838601},
+ {100000, 1000000, 200000, 500000,
+ 0.0019947087839501726}}));
+
+TEST_P(HypergeometricCDF, ReturnsPrecomputedValues) {
+ const HypergeometricCDForPMFInstance& test_params = GetParam();
+ FCP_LOG(INFO) << "Testing hypergeometric cdf with x = " << test_params.x
+ << " total = " << test_params.total
+ << " marked = " << test_params.marked
+ << " sampled = " << test_params.sampled << ".";
+ auto p = HypergeometricDistribution::Create(
+ test_params.total, test_params.marked, test_params.sampled);
+ ASSERT_THAT(p, IsOk());
+ double result = p.value()->CDF(test_params.x);
+ double relative_error =
+ abs(result - test_params.probability) / (test_params.probability + 1e-30);
+ EXPECT_LT(relative_error, 1e-9);
+ FCP_LOG(INFO) << "result = " << result
+ << " expected_result = " << test_params.probability
+ << " relative_error" << relative_error;
+}
+
+INSTANTIATE_TEST_SUITE_P(HypergeometricCDFTests, HypergeometricCDF,
+ ::testing::ValuesIn<HypergeometricCDForPMFInstance>(
+ {{-5, 9, 3, 3, 0.0},
+ {17, 9, 3, 3, 1.0},
+ {0, 10, 0, 5, 1.0},
+ {3, 10, 10, 5, 0.0},
+ {4.5, 15, 6, 12, 0.34065934065934067},
+ {38, 98, 63, 17, 1.0},
+ {2, 187, 105, 43, 5.526570670097338e-16},
+ {40, 980, 392, 103, 0.4430562850817352},
+ {89, 1489, 312, 370, 0.9599670222722507},
+ {100000, 1000000, 200000, 500000,
+ 0.5009973543919738}}));
+
+TEST_P(HypergeometricQuantile, ReturnsPrecomputedValues) {
+ const HypergeometricQuantileInstance& test_params = GetParam();
+ FCP_LOG(INFO) << "Testing hypergeometric quantile with probability = "
+ << test_params.probability << " total = " << test_params.total
+ << " marked = " << test_params.marked
+ << " sampled = " << test_params.sampled << ".";
+ auto p = HypergeometricDistribution::Create(
+ test_params.total, test_params.marked, test_params.sampled);
+ ASSERT_THAT(p, IsOk());
+ double result_lower = p.value()->FindQuantile(test_params.probability);
+ EXPECT_GE(result_lower, test_params.lower);
+ EXPECT_LE(result_lower, test_params.lower + 1);
+ FCP_LOG(INFO) << "Lower result = " << result_lower
+ << " which should be between " << test_params.lower << " and "
+ << test_params.lower + 1 << ".";
+ double result_upper = p.value()->FindQuantile(test_params.probability, true);
+ EXPECT_LE(result_upper, test_params.upper);
+ EXPECT_GE(result_upper, test_params.upper - 1);
+ FCP_LOG(INFO) << "Upper result = " << result_upper
+ << " which should be between " << test_params.upper - 1
+ << " and " << test_params.upper << ".";
+}
+
+INSTANTIATE_TEST_SUITE_P(HypergeometricQuantileTests, HypergeometricQuantile,
+ ::testing::ValuesIn<HypergeometricQuantileInstance>(
+ {{0.5, 10, 0, 5, -1, 0},
+ {0.2, 10, 10, 5, 4, 5},
+ {0.97, 15, 6, 12, 5, 3},
+ {0.0001, 98, 63, 17, 3, 17},
+ {1e-05, 187, 105, 43, 11, 36},
+ {3e-08, 980, 392, 103, 16, 67},
+ {1.1e-09, 1489, 312, 370, 38, 119},
+ {1e-18, 1000000, 200000, 500000, 98248,
+ 101751}}));
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/experiments_interface.h b/fcp/secagg/server/experiments_interface.h
new file mode 100644
index 0000000..48a455d
--- /dev/null
+++ b/fcp/secagg/server/experiments_interface.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_EXPERIMENTS_INTERFACE_H_
+#define FCP_SECAGG_SERVER_EXPERIMENTS_INTERFACE_H_
+
+#include "absl/strings/string_view.h"
+
+namespace fcp {
+namespace secagg {
+
+// Used to query named flags for experimental features.
+class ExperimentsInterface {
+ public:
+ // Returns true if the specified experiment is enabled;
+ // otherwise returns false.
+ virtual bool IsEnabled(absl::string_view experiment_name) = 0;
+
+ virtual ~ExperimentsInterface() = default;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_EXPERIMENTS_INTERFACE_H_
diff --git a/fcp/secagg/server/experiments_names.h b/fcp/secagg/server/experiments_names.h
new file mode 100644
index 0000000..eb35bdc
--- /dev/null
+++ b/fcp/secagg/server/experiments_names.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_EXPERIMENTS_NAMES_H_
+#define FCP_SECAGG_SERVER_EXPERIMENTS_NAMES_H_
+
+namespace fcp {
+namespace secagg {
+
+// Names of predefined experiments
+static constexpr char kFullgraphSecAggExperiment[] = "FULLGRAPH_SECAGG";
+static constexpr char kForceSubgraphSecAggExperiment[] =
+ "FORCE_SUBGRAPH_SECAGG_FOR_TEST";
+static constexpr char kSubgraphSecAggCuriousServerExperiment[] =
+ "SUBGRAPH_SECAGG_CURIOUS_SERVER";
+static constexpr char kSecAggAsyncRound2Experiment[] = "secagg_async_round_2";
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_EXPERIMENTS_NAMES_H_
diff --git a/fcp/secagg/server/graph_parameter_finder.cc b/fcp/secagg/server/graph_parameter_finder.cc
new file mode 100644
index 0000000..45d8b6a
--- /dev/null
+++ b/fcp/secagg/server/graph_parameter_finder.cc
@@ -0,0 +1,335 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/graph_parameter_finder.h"
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <optional>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/distribution_utilities.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+class HararyGraphParameterFinder {
+ public:
+ // Checks that parameters have valid values and returns an instance of the
+ // class
+ static StatusOr<std::unique_ptr<HararyGraphParameterFinder>> Create(
+ int number_of_clients, double adversarial_rate, double dropout_rate,
+ AdversaryClass adversary_class) {
+ if (number_of_clients <= 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The number of clients should greater than zero. Value "
+ "provided = "
+ << number_of_clients;
+ }
+ if (number_of_clients > kMaxNumberOfClients) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The valid number of clients is upper bounded by 1M. There is "
+ "no "
+ "fundamental reason for that, and this "
+ "parameter finder should work for that setting. Just add the "
+ "corresponding tests.";
+ }
+ if (adversarial_rate < 0 || adversarial_rate >= 1) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The adversarial rate should be in [0,1). Value provided = "
+ << adversarial_rate;
+ }
+ if (dropout_rate < 0 || dropout_rate >= 1) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The dropout rate should be in [0,1). Value provided = "
+ << dropout_rate;
+ }
+ FCP_CHECK(adversary_class == AdversaryClass::CURIOUS_SERVER ||
+ adversary_class == AdversaryClass::SEMI_MALICIOUS_SERVER ||
+ adversary_class == AdversaryClass::NONE)
+ << "CURIOUS_SERVER, SEMI_MALICIOUS_SERVER, and NONE are the only "
+ "supported "
+ "adversary classes.";
+ if (adversary_class == AdversaryClass::NONE && adversarial_rate > 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The no-adversary setting expects that adversarial_rate = 0. "
+ "Value provided = "
+ << adversarial_rate;
+ }
+ if ((adversary_class == AdversaryClass::CURIOUS_SERVER ||
+ adversary_class == AdversaryClass::NONE) &&
+ adversarial_rate + dropout_rate > .9) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "In the semi-honest and no-adversary settings, "
+ "adversarial_rate + dropout_rate "
+ "<= 0.9 must hold for the instance to be feasible. Values "
+ "provided = "
+ << adversarial_rate << " and " << dropout_rate;
+ }
+ if (adversary_class == AdversaryClass::SEMI_MALICIOUS_SERVER &&
+ adversarial_rate + 2 * dropout_rate > .9) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "In the semi-malicious setting, adversarial_rate + "
+ "2*dropout_rate <= 0.9 must hold for the instance to be "
+ "feasible. Values provided = "
+ << adversarial_rate << " and " << dropout_rate;
+ }
+ return absl::WrapUnique(new HararyGraphParameterFinder(
+ number_of_clients, adversarial_rate, dropout_rate, adversary_class));
+ }
+
+ // Returns the degree of a Harary graph and threshold of a Shamir secret
+ // sharing scheme that result in an instance of subgraph-secagg with
+ // statistical security [kSecurityParameter] and failure probability less
+ // that 2**(-[kCorrectnessParameter]), assuming [number_of_clients_]
+ // participants and a fraction of [adversarial_rate_] (resp. [dropout_rate_])
+ // adversarial clients (resp. dropouts).
+ StatusOr<HararyGraphParameters> ComputeDegreeAndThreshold() {
+ for (int number_of_neighbors = 2;
+ number_of_neighbors < number_of_clients_ - 1;
+ number_of_neighbors += 2) {
+ auto threshold = CheckNumberOfNeighbors(number_of_neighbors);
+ if (threshold.has_value()) {
+ HararyGraphParameters params = {number_of_clients_, number_of_neighbors,
+ threshold.value()};
+ return params;
+ }
+ }
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Parameters consistent with the provided aggregation "
+ "requirements were not found. Unless the number of clients is "
+ "small (<500) and adversarial rate plus dropout rate are large "
+ "(close to 1), this is an unlikely error that should be "
+ "investigated as a bug.";
+ }
+
+ private:
+ static constexpr int kMaxNumberOfClients = 1000000;
+
+ // Statistical security parameter. Parameters found by
+ // HararyGraphParameterFinder guarantee statistical security with probability
+ // > 1-2^{-40}
+ static constexpr double kSecurityParameter = 40;
+ // Statistical correctness parameter. Parameters found by
+ // HararyGraphParameterFinder guarantee a failure probability < 2^{-20}
+ static constexpr double kCorrectnessParameter = 20;
+
+ // Adding kSmall stops the floating point rounding error being magnified in
+ // an insecure direction by the rounding. Being small it is unlikely to
+ // introduce an error the other way either but that would only slightly hit
+ // performance anyway.
+ static constexpr double kSmall = 1e-14;
+
+ int number_of_clients_;
+ double adversarial_rate_;
+ double dropout_rate_;
+ AdversaryClass adversary_class_;
+
+ // Returns a bound on the probability p of a random harary graph of
+ // number_of_clients_ nodes and degree number_of_neighbors being disconnected
+ // after adversarial and dropout nodes are removed.
+ double LogProbOfDisconnectedRandomHarary(int number_of_neighbors) {
+ // Note that any set of removed nodes disconnecting a Harary graph needs to
+ // include number_of_neighbors/2 successive clients in two disjoint places
+ // (in the ring forming the Harary graph). In our setting that means that
+ // these clients need to be adversarial or dropouts. There are
+ // number_of_clients_ ways to pick the first break and
+ // number_of_clients_-number_of_neighbors-1 ways to pick the other, and this
+ // double counts the pairs so dividing by two gives the number of gap pairs.
+ // For each pair this probability is given by the ratio of the given
+ // factorials. Multiplying by number_of_gap_pairs bounds the total
+ // probability by a union bound.
+ double log_number_of_gap_pairs =
+ std::log(number_of_clients_) +
+ std::log(number_of_clients_ - number_of_neighbors - 1) - std::log(2);
+ int max_adversarial_clients = static_cast<int>(
+ std::floor((adversarial_rate_ + kSmall) * number_of_clients_));
+ int max_dropout_clients = static_cast<int>(
+ std::floor((dropout_rate_ + kSmall) * number_of_clients_));
+ int max_bad_clients = max_adversarial_clients + max_dropout_clients;
+
+ if (number_of_neighbors > max_bad_clients) return -HUGE_VAL;
+ double ret = log_number_of_gap_pairs + std::lgamma(max_bad_clients + 1) +
+ std::lgamma(number_of_clients_ - number_of_neighbors + 1) -
+ std::lgamma(number_of_clients_ + 1) -
+ std::lgamma(max_bad_clients - number_of_neighbors + 1);
+ return ret;
+ }
+
+ // Checks if degree number_of_neighbors_ results in a secure and correct
+ // protocol, and returns an appropriate threshold if so.
+ std::optional<int> CheckNumberOfNeighbors(int number_of_neighbors) {
+ // We split the security parameter evenly across the two bad events
+ constexpr double kSecurityParameterPerEvent = kSecurityParameter + 1;
+ const double kLogProbSecurityParameterPerEvent =
+ -kSecurityParameterPerEvent * std::log(2);
+ // We first check that the graph of honest surviving nodes is connected with
+ // large enough probability
+ if (LogProbOfDisconnectedRandomHarary(number_of_neighbors) >=
+ kLogProbSecurityParameterPerEvent) {
+ // The probability of the graph getting disconnected is not small enough
+ return std::nullopt;
+ }
+
+ // We now check find threshold t such that (a) for every client i, the
+ // number of adversarial neighbors of i is greater than t-1 with small
+ // enough probability, and (b) that the number of surviving nodes of a
+ // client is greater than t-1 with large enough probability.
+ int upper_bound_adversarial_clients = static_cast<int>(
+ std::floor((adversarial_rate_ + kSmall) * number_of_clients_));
+ int lower_bound_surviving_clients = static_cast<int>(
+ std::ceil((1 - dropout_rate_ - kSmall) * number_of_clients_));
+ // Distribution of the number of adversarial neighbors of a client
+ auto num_adversarial_neighbors = HypergeometricDistribution::Create(
+ number_of_clients_, upper_bound_adversarial_clients,
+ number_of_neighbors);
+ // Distribution of the number of dropout neighbors of a client
+ auto num_surviving_neighbors = HypergeometricDistribution::Create(
+ number_of_clients_, lower_bound_surviving_clients, number_of_neighbors);
+
+ // t1 is such that Pr[# of adversarial neighbors of a client > t1] <=
+ // 2^{-security_parameter_per_event} / number_of_clients_
+
+ // t2 is such that Pr[# of adversarial neighbors of a client <= t2] >=
+ // 2^{-correctness_parameter_} / number_of_clients_
+
+ // The result of the below quantile functions is an integer result rounded
+ // outwards i.e. away from the median of the distribution.
+ auto t1 = num_adversarial_neighbors.value()->FindQuantile(
+ std::pow(2, -kSecurityParameterPerEvent) / number_of_clients_, true);
+ auto t2 = num_surviving_neighbors.value()->FindQuantile(
+ std::pow(2, -kCorrectnessParameter) / number_of_clients_);
+ if (num_surviving_neighbors.value()->CDF(t2) <
+ std::pow(2, -kCorrectnessParameter) / number_of_clients_) {
+ t2++;
+ }
+
+ // In the semihonest case, the returned threshold must satisfy that t \in
+ // (t1, t2]. To save computation when reconstructing shamir shares we simply
+ // choose t1+1.
+
+ // In the semi-malicious case t should be such that Pr[2t -
+ // number_of_neighbors] < 2^{-security_parameter_ + 1} / number_of_clients_,
+ // and thus we set t1 + 1 to be at least t1/2 + number_of_neighbors/2 + 1/2,
+ // so that also in this case (t1, t2] defines the range of acceptable values
+ // for t
+
+ if (adversary_class_ == AdversaryClass::SEMI_MALICIOUS_SERVER) {
+ t1 = std::ceil((t1 + number_of_neighbors - 1.25) / 2);
+ }
+ if (t2 <= t1) {
+ return std::nullopt;
+ }
+ // The Shamir secret sharing implementation requires that threshold is >= 2
+ return std::max(t1 + 1, 2.);
+ }
+
+ HararyGraphParameterFinder(int number_of_clients, double adversarial_rate,
+ double dropout_rate,
+ AdversaryClass adversary_class)
+ : number_of_clients_(number_of_clients),
+ adversarial_rate_(adversarial_rate),
+ dropout_rate_(dropout_rate),
+ adversary_class_(adversary_class) {}
+};
+
+StatusOr<HararyGraphParameters> ComputeHararyGraphParameters(
+ int number_of_clients, SecureAggregationRequirements threat_model) {
+ FCP_ASSIGN_OR_RETURN(
+ auto pf, HararyGraphParameterFinder::Create(
+ number_of_clients, threat_model.adversarial_client_rate(),
+ threat_model.estimated_dropout_rate(),
+ threat_model.adversary_class()));
+ return pf->ComputeDegreeAndThreshold();
+}
+
+Status CheckFullGraphParameters(int number_of_clients, int threshold,
+ SecureAggregationRequirements threat_model) {
+ if (number_of_clients <= 0) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The number of clients should greater than zero. Value "
+ "provided = "
+ << number_of_clients;
+ }
+ if (threshold <= 1 || threshold > number_of_clients) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The threshold should be > 1 and <= number_of_clients. Value "
+ "provided = "
+ << threshold;
+ }
+ if (threat_model.adversarial_client_rate() < 0 ||
+ threat_model.adversarial_client_rate() >= 1) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The adversarial rate should be in [0,1). Value provided = "
+ << threat_model.adversarial_client_rate();
+ }
+ if (threat_model.estimated_dropout_rate() < 0 ||
+ threat_model.estimated_dropout_rate() >= 1) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The dropout rate should be in [0,1). Value provided = "
+ << threat_model.estimated_dropout_rate();
+ }
+ FCP_CHECK(threat_model.adversary_class() == AdversaryClass::CURIOUS_SERVER ||
+ threat_model.adversary_class() ==
+ AdversaryClass::SEMI_MALICIOUS_SERVER ||
+ threat_model.adversary_class() == AdversaryClass::NONE)
+ << "CURIOUS_SERVER, SEMI_MALICIOUS_SERVER, and NONE are the only "
+ "supported "
+ "adversary classes.";
+ if (threshold < std::ceil((1 - threat_model.estimated_dropout_rate()) *
+ number_of_clients)) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The threshold should be at least ceil(1 - dropout_rate * "
+ "number_of_clients). Values provided are "
+ << threshold << ", "
+ << std::floor((1 - threat_model.estimated_dropout_rate()) *
+ number_of_clients);
+ }
+ if (threat_model.adversary_class() == AdversaryClass::CURIOUS_SERVER &&
+ threshold <= std::ceil(threat_model.adversarial_client_rate() *
+ number_of_clients)) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "In the full-graph variant and CURIOUS_SERVER threat model, the "
+ "threshold should be at least ceil(adversarial_client_rate * "
+ "number_of_clients). Values provided are "
+ << threshold << ", "
+ << std::ceil(threat_model.adversarial_client_rate() *
+ number_of_clients);
+ } else if (threat_model.adversary_class() ==
+ AdversaryClass::SEMI_MALICIOUS_SERVER &&
+ threshold <= std::ceil((number_of_clients +
+ threat_model.adversarial_client_rate() *
+ number_of_clients) /
+ 2)) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "In the full-graph variant and SEMI_MALICIOUS_SERVER threat "
+ "model, the threshold should be at least "
+ "ceil((total_number_of_clients "
+ "+ adversarial_client_rate * number_of_clients) / 2). "
+ "Values provided are "
+ << threshold << ", "
+ << (number_of_clients +
+ std::ceil(threat_model.adversarial_client_rate() *
+ number_of_clients) /
+ 2);
+ }
+ return FCP_STATUS(OK);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/graph_parameter_finder.h b/fcp/secagg/server/graph_parameter_finder.h
new file mode 100644
index 0000000..4438a71
--- /dev/null
+++ b/fcp/secagg/server/graph_parameter_finder.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_H_
+#define FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_H_
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// Represents the parameters that define a SecretSharingHararyGraph: its
+// size, its degree (as Harary graphs are regular), and the threshold for
+// reconstruction
+struct HararyGraphParameters {
+ int number_of_nodes;
+ int degree;
+ int threshold;
+};
+
+// Returns the HararyGraphParameters that result in an instance of
+// subgraph-secagg with statistical security [kSecurityParameter] and failure
+// probability less that 2**(-[kCorrectnessParameter]), assuming
+// [number_of_clients_] participants and the threat model (adversarial rate,
+// dropout rate, and adversary class) defined in [threat_model].
+StatusOr<HararyGraphParameters> ComputeHararyGraphParameters(
+ int number_of_clients, SecureAggregationRequirements threat_model);
+
+// Check if the provided threshold [threshold] results in a secure protocol with
+// [number_of_clients] clients and the parameters and adversary specified in
+// [threat_model]
+Status CheckFullGraphParameters(int number_of_clients, int threshold,
+ SecureAggregationRequirements threat_model);
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_H_
diff --git a/fcp/secagg/server/graph_parameter_finder_test.cc b/fcp/secagg/server/graph_parameter_finder_test.cc
new file mode 100644
index 0000000..3c180bb
--- /dev/null
+++ b/fcp/secagg/server/graph_parameter_finder_test.cc
@@ -0,0 +1,572 @@
+#ifndef THIRD_PARTY_FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_TEST_CC_
+#define THIRD_PARTY_FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_TEST_CC_
+
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/graph_parameter_finder.h"
+
+#include <algorithm>
+#include <iostream>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_messages.pb.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+struct HararyGraphParameterFinderParams {
+ const std::string test_name;
+ const int kNumClients;
+ const double kAdversarialRate;
+ const double kDropoutRate;
+ const AdversaryClass kAdversaryClass;
+ const int kExpectedDegree;
+ const int kExpectedThreshold;
+};
+
+class HararyGraphParameterFinderTest_Feasible
+ : public ::testing::TestWithParam<HararyGraphParameterFinderParams> {};
+
+TEST_P(HararyGraphParameterFinderTest_Feasible,
+ ComputesParametersThatMatchPrecomputedValues) {
+ // This test computes parameters for feasible instances. It checks that the
+ // obtained degree (number of neighbors) and threshold match a precomputed
+ // value.
+ const HararyGraphParameterFinderParams& test_params = GetParam();
+ SecureAggregationRequirements threat_model;
+ int security_parameter = 40;
+ int correctness_parameter = 20;
+ FCP_LOG(INFO)
+ << "Running HararyGraphParameterFinder on instance with num. clients = "
+ << test_params.kNumClients
+ << ", adversarial rate = " << test_params.kAdversarialRate
+ << ", dropout rate = " << test_params.kDropoutRate
+ << ", security parameter = " << security_parameter
+ << ", correctness parameter = " << correctness_parameter
+ << ", adversary class = "
+ << (test_params.kAdversaryClass == AdversaryClass::CURIOUS_SERVER
+ ? "CURIOUS_SERVER"
+ : "SEMI_MALICIOUS_SERVER");
+ threat_model.set_adversarial_client_rate(test_params.kAdversarialRate);
+ threat_model.set_estimated_dropout_rate(test_params.kDropoutRate);
+ threat_model.set_adversary_class(test_params.kAdversaryClass);
+ auto computed_params =
+ ComputeHararyGraphParameters(test_params.kNumClients, threat_model);
+ EXPECT_EQ(computed_params.ok(), true);
+ int degree = computed_params.value().degree;
+ int threshold = computed_params.value().threshold;
+ FCP_LOG(INFO) << "Secure parameters were found: degree = " << degree
+ << ", threshold = " << threshold;
+ int expected_degree = test_params.kExpectedDegree;
+ int expected_threshold = test_params.kExpectedThreshold;
+ EXPECT_EQ(degree, expected_degree);
+ FCP_LOG(INFO) << "degree = " << degree
+ << " expected_degree = " << expected_degree;
+ EXPECT_LE(threshold, expected_threshold);
+ FCP_LOG(INFO) << "threshold = " << threshold
+ << " expected_threshold = " << expected_threshold;
+}
+
+TEST_P(HararyGraphParameterFinderTest_Feasible,
+ ComputesParametersWithinExpectedRange) {
+ // This test computes parameters for feasible instances. It checks that the
+ // obtained degree (number of neighbors) is in between the analytical lower
+ // and upper bounds.
+ const HararyGraphParameterFinderParams& test_params = GetParam();
+ SecureAggregationRequirements threat_model;
+ int security_parameter = 40;
+ int correctness_parameter = 20;
+ FCP_LOG(INFO) << "Running HararyGraphParameterFinder on instance with num. "
+ "clients = "
+ << test_params.kNumClients
+ << ", adversarial rate = " << test_params.kAdversarialRate
+ << ", dropout rate = " << test_params.kDropoutRate
+ << ", security parameter = " << security_parameter
+ << ", correctness parameter = " << correctness_parameter
+ << ", adversary class = "
+ << (test_params.kAdversaryClass ==
+ AdversaryClass::CURIOUS_SERVER
+ ? "CURIOUS_SERVER"
+ : "SEMI_MALICIOUS_SERVER");
+ threat_model.set_adversarial_client_rate(test_params.kAdversarialRate);
+ threat_model.set_estimated_dropout_rate(test_params.kDropoutRate);
+ threat_model.set_adversary_class(test_params.kAdversaryClass);
+ auto computed_params =
+ ComputeHararyGraphParameters(test_params.kNumClients, threat_model);
+ EXPECT_EQ(computed_params.ok(), true);
+ int degree = computed_params.value().degree;
+ int threshold = computed_params.value().threshold;
+ FCP_LOG(INFO) << "Secure parameters were found: degree = " << degree
+ << ", threshold = " << threshold;
+
+ bool unconstrained_instance =
+ test_params.kAdversarialRate == 0 && test_params.kDropoutRate == 0;
+ bool small_instance = test_params.kNumClients < 20;
+ // The degree lower bound this enforces doesn't fit with the security
+ // guarantee that the rest of this code is designed to provide. Clearing up
+ // what the security guarantee should be is b/260400215 and this test should
+ // be cleared up in addressing that. Until then it is switched off for small
+ // variables where it behaves at odds with the model used elsewhere.
+ double degree_lower_bound =
+ unconstrained_instance || small_instance
+ ? 1
+ : log(test_params.kNumClients) + security_parameter * log(2) / 5.;
+
+ double beta = static_cast<double>(threshold) / degree;
+ double alpha = test_params.kNumClients / (test_params.kNumClients - 1);
+ double a = log(test_params.kNumClients) + correctness_parameter * log(2);
+ double b = 2 * pow(alpha * (1 - test_params.kDropoutRate) - beta, 2);
+ double c =
+ std::min(2 * pow(beta - alpha * test_params.kAdversarialRate, 2),
+ -log(test_params.kAdversarialRate + test_params.kDropoutRate));
+ double degree_upper_bound =
+ unconstrained_instance ? 3 : std::max(degree_lower_bound / c + 1, a / b);
+ // We increase the upper bound slightly for the semi-malicious variant
+ if (test_params.kAdversaryClass == AdversaryClass::SEMI_MALICIOUS_SERVER) {
+ degree_upper_bound += degree_upper_bound * 1. / 5;
+ }
+ EXPECT_GT(degree, degree_lower_bound);
+ EXPECT_LT(degree, degree_upper_bound);
+ EXPECT_GE(degree, threshold);
+ EXPECT_GT(threshold, 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ HararyGraphParameterFinderTests, HararyGraphParameterFinderTest_Feasible,
+ testing::ValuesIn<HararyGraphParameterFinderParams>({
+ // adversarial_rate = 0.45, dropout_rate = 0.45, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {2,3,4,5,6}}
+ {"100_clients__security_40__correctness_20__adversaryrate_045__"
+ "dropoutrate_045__adversary_class__semihonest",
+ 100, 0.45, 0.45, AdversaryClass::CURIOUS_SERVER, 92, 46},
+ {"1000_clients__security_40__correctness_20__adversaryrate_045__"
+ "dropoutrate_045__adversary_class__semihonest",
+ 1000, 0.45, 0.45, AdversaryClass::CURIOUS_SERVER, 822, 417},
+ {"10000_clients__security_40__correctness_20__adversaryrate_045__"
+ "dropoutrate_045__adversary_class__semihonest",
+ 10000, 0.45, 0.45, AdversaryClass::CURIOUS_SERVER, 3494, 1771},
+ {"100000_clients__security_40__correctness_20__adversaryrate_045__"
+ "dropoutrate_045__adversary_class__semihonest",
+ 100000, 0.45, 0.45, AdversaryClass::CURIOUS_SERVER, 5508, 2788},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_045__"
+ "dropoutrate_045__adversary_class__semihonest",
+ 1000000, 0.45, 0.45, AdversaryClass::CURIOUS_SERVER, 6240, 3156},
+ // adversarial_rate = 0.33, dropout_rate = 0.33, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 8, 4},
+ {"100_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 68, 34},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 286, 150},
+ {"10000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10000, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 422, 221},
+ {"100000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100000, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 480, 251},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000000, 0.33, 0.33, AdversaryClass::CURIOUS_SERVER, 518, 270},
+ // adversarial_rate = 0.3, dropout_rate = 0.3, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {2,3,4,5,6}}
+ {"100_clients__security_40__correctness_20__adversaryrate_03__"
+ "dropoutrate_03__adversary_class__semimalicious",
+ 100, 0.3, 0.3, AdversaryClass::SEMI_MALICIOUS_SERVER, 92, 62},
+ {"1000_clients__security_40__correctness_20__adversaryrate_03__"
+ "dropoutrate_03__adversary_class__semimalicious",
+ 1000, 0.3, 0.3, AdversaryClass::SEMI_MALICIOUS_SERVER, 872, 584},
+ {"10000_clients__security_40__correctness_20__adversaryrate_03__"
+ "dropoutrate_03__adversary_class__semimalicious",
+ 10000, 0.3, 0.3, AdversaryClass::SEMI_MALICIOUS_SERVER, 4822, 3230},
+ {"100000_clients__security_40__correctness_20__adversaryrate_03__"
+ "dropoutrate_03__adversary_class__semimalicious",
+ 100000, 0.3, 0.3, AdversaryClass::SEMI_MALICIOUS_SERVER, 9388, 6286},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_03__"
+ "dropoutrate_03__adversary_class__semimalicious",
+ 1000000, 0.3, 0.3, AdversaryClass::SEMI_MALICIOUS_SERVER, 11136, 7454},
+ // adversarial_rate = 0.05, dropout_rate = 0.33, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 4, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 28, 6},
+ {"1000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 72, 24},
+ {"10000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10000, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 86, 29},
+ {"100000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100000, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 96, 32},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000000, 0.05, 0.33, AdversaryClass::CURIOUS_SERVER, 102, 34},
+ // adversarial_rate = 0.05, dropout_rate = 0.33, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 10, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 8, 5},
+ {"100_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 100, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 72, 39},
+ {"1000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 1000, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 398, 223},
+ {"10000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 10000, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 746, 420},
+ {"100000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 100000, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 882, 496},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 1000000, 0.05, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 970, 545},
+ // adversarial_rate = 0.33, dropout_rate = 0.05, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 10, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 4, 4},
+ {"100_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 100, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 34, 29},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 1000, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 78, 60},
+ {"10000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 10000, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 92, 70},
+ {"100000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 100000, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 102, 77},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 1000000, 0.33, 0.05, AdversaryClass::CURIOUS_SERVER, 110, 83},
+ // adversarial_rate = 0.33, dropout_rate = 0.05, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 10, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 4, 4},
+ {"100_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 100, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 42, 37},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 1000, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 132, 109},
+ {"10000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 10000, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 178, 146},
+ {"100000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 100000, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 196, 160},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 1000000, 0.33, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 212, 173},
+ // adversarial_rate = 0.05, dropout_rate = 0.05, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 10, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 100, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 12, 6},
+ {"1000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 1000, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 30, 17},
+ {"10000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 10000, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 34, 20},
+ {"100000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 100000, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 36, 21},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semihonest",
+ 1000000, 0.05, 0.05, AdversaryClass::CURIOUS_SERVER, 38, 22},
+ // adversarial_rate = 0.05, dropout_rate = 0.05, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 10, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 100, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 16, 11},
+ {"1000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 1000, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 52, 37},
+ {"10000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 10000, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 62, 44},
+ {"100000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 100000, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 68, 48},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_005__"
+ "dropoutrate_005__adversary_class__semimalicious",
+ 1000000, 0.05, 0.05, AdversaryClass::SEMI_MALICIOUS_SERVER, 74, 52},
+ // adversarial_rate = 0.33, dropout_rate = 0.0, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 10, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 4, 4},
+ {"100_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 100, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 26, 25},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 1000, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 38, 36},
+ {"10000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 10000, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 42, 40},
+ {"100000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 100000, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 46, 44},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 1000000, 0.33, 0.0, AdversaryClass::CURIOUS_SERVER, 50, 47},
+ // adversarial_rate = 0.33, dropout_rate = 0.0, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 10, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 4, 4},
+ {"100_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 100, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 26, 26},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 1000, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 38, 37},
+ {"10000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 10000, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 42, 41},
+ {"100000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 100000, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 46, 45},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 1000000, 0.33, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 50, 49},
+ // adversarial_rate = 0.0, dropout_rate = 0.33, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 4, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 26, 2},
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 38, 2},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 10000, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 42, 2},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 100000, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 46, 2},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semihonest",
+ 1000000, 0.0, 0.33, AdversaryClass::CURIOUS_SERVER, 50, 2},
+ // adversarial_rate = 0.0, dropout_rate = 0.33, adversary_class =
+ // none, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 10, 0.0, 0.33, AdversaryClass::NONE, 4, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 100, 0.0, 0.33, AdversaryClass::NONE, 26, 2},
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 1000, 0.0, 0.33, AdversaryClass::NONE, 38, 2},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 10000, 0.0, 0.33, AdversaryClass::NONE, 42, 2},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 100000, 0.0, 0.33, AdversaryClass::NONE, 46, 2},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__none",
+ 1000000, 0.0, 0.33, AdversaryClass::NONE, 50, 2},
+ // adversarial_rate = 0.0, dropout_rate = 0.33, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 10, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 8, 5},
+ {"100_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 100, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 68, 35},
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 1000, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 228, 115},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 10000, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 326, 164},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 100000, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 372, 187},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 1000000, 0.0, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 410, 206},
+ // adversarial_rate = 0.0, dropout_rate = 0.0, adversary_class =
+ // semihonest, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 10, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 100, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 1000, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 10000, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 100000, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semihonest",
+ 1000000, 0.0, 0.0, AdversaryClass::CURIOUS_SERVER, 2, 2},
+ // adversarial_rate = 0.0, dropout_rate = 0.0, adversary_class =
+ // none, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__none",
+ 1000, 0.0, 0.0, AdversaryClass::NONE, 2, 2},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__none",
+ 10000, 0.0, 0.0, AdversaryClass::NONE, 2, 2},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__none",
+ 100000, 0.0, 0.0, AdversaryClass::NONE, 2, 2},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__none",
+ 1000000, 0.0, 0.0, AdversaryClass::NONE, 2, 2},
+ // adversarial_rate = 0.0, dropout_rate = 0.0, adversary_class =
+ // semimalicious, for number_of_clients in {10^i | i\in {1,2,3,4,5,6}}
+ {"10_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 10, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"100_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 100, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"1000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 1000, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"10000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 10000, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"100000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 100000, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ {"1000000_clients__security_40__correctness_20__adversaryrate_00__"
+ "dropoutrate_00__adversary_class__semimalicious",
+ 1000000, 0.0, 0.0, AdversaryClass::SEMI_MALICIOUS_SERVER, 2, 2},
+ }),
+ [](const ::testing::TestParamInfo<
+ HararyGraphParameterFinderTest_Feasible::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+class HararyGraphParameterFinderTest_InvalidOrUnfeasible
+ : public ::testing::TestWithParam<HararyGraphParameterFinderParams> {};
+
+TEST_P(HararyGraphParameterFinderTest_InvalidOrUnfeasible,
+ FailsOnIncorrectParameters) {
+ // This test tries to compute parameters for invalid (parameters with
+ // incorrect values) or unfeasible (combinations of valid parameter values
+ // that make the problem unsolvable) instances.
+ const HararyGraphParameterFinderParams& test_params = GetParam();
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversarial_client_rate(test_params.kAdversarialRate);
+ threat_model.set_estimated_dropout_rate(test_params.kDropoutRate);
+ threat_model.set_adversary_class(test_params.kAdversaryClass);
+ auto computed_params =
+ ComputeHararyGraphParameters(test_params.kNumClients, threat_model);
+ EXPECT_EQ(computed_params.ok(), false);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ HararyGraphParameterFinderTests,
+ HararyGraphParameterFinderTest_InvalidOrUnfeasible,
+ ::testing::ValuesIn<HararyGraphParameterFinderParams>({
+ {"0_clients__security_40__correctness_20__adversaryrate_01__"
+ "dropoutrate_01__adversary_class__semihonest",
+ 0, 0.1, 0.1, AdversaryClass::CURIOUS_SERVER, 0, 0},
+ {"1000_clients__security_40__correctness_20__adversaryrate_1__"
+ "dropoutrate_1__adversary_class__semihonest",
+ 1000, 0.1, 1., AdversaryClass::CURIOUS_SERVER, 0, 0},
+ {"1000_clients__security_40__correctness_20__adversaryrate_01__"
+ "dropoutrate_minus1__adversary_class__semihonest",
+ 1000, 0.1, -1., AdversaryClass::CURIOUS_SERVER, 0, 0},
+ // For semi_honest/honest-but-curious adversary, we need that
+ // adversary_rate + dropout_rate < 1 for the instance to be feasible
+ {"1000_clients__security_40__correctness_20__adversaryrate_05__"
+ "dropoutrate_05__adversary_class__semihonest",
+ 1000, 0.5, 0.5, AdversaryClass::CURIOUS_SERVER, 0, 0},
+ // For semi_malicious adversary, we need that adversary_rate +
+ // 2*dropout_rate < 1 for the instance to be feasible
+ {"1000_clients__security_40__correctness_20__adversaryrate_05__"
+ "dropoutrate_05__adversary_class__semimalicious",
+ 1000, 0.5, 0.5, AdversaryClass::SEMI_MALICIOUS_SERVER, 0, 0},
+ {"1000_clients__security_40__correctness_20__adversaryrate_033__"
+ "dropoutrate_033__adversary_class__semimalicious",
+ 1000, 0.33, 0.33, AdversaryClass::SEMI_MALICIOUS_SERVER, 0, 0},
+ // For the no-adversary setting we expect adversary_rate == 0
+ {"1000_clients__security_40__correctness_20__adversaryrate_05__"
+ "dropoutrate_05__adversary_class__none",
+ 1000, 0.5, 0.5, AdversaryClass::NONE, 0, 0},
+ }),
+ [](const ::testing::TestParamInfo<
+ HararyGraphParameterFinderTest_InvalidOrUnfeasible::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+TEST(FullGraphCeckParamsTest, ReturnsTrueOnValidThresholds) {
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversarial_client_rate(.05);
+ threat_model.set_estimated_dropout_rate(.3);
+ threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
+ int num_clients = 60;
+ for (int t = 42; t < num_clients; t++) {
+ EXPECT_THAT(CheckFullGraphParameters(num_clients, t, threat_model).ok(),
+ true)
+ << t;
+ }
+}
+
+TEST(FullGraphCeckParamsTest, ReturnsFalseOnInvalidThresholds) {
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversarial_client_rate(.05);
+ threat_model.set_estimated_dropout_rate(.3);
+ threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
+ int num_clients = 60;
+ for (int t = 0; t < 42; t++) {
+ EXPECT_THAT(CheckFullGraphParameters(num_clients, t, threat_model).ok(),
+ false);
+ }
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_SECAGG_SERVER_GRAPH_PARAMETER_FINDER_TEST_CC_
diff --git a/fcp/secagg/server/secagg_scheduler.cc b/fcp/secagg/server/secagg_scheduler.cc
new file mode 100644
index 0000000..eb5db06
--- /dev/null
+++ b/fcp/secagg/server/secagg_scheduler.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_scheduler.h"
+
+#include <functional>
+
+namespace fcp {
+namespace secagg {
+
+void SecAggScheduler::WaitUntilIdle() {
+ parallel_scheduler_->WaitUntilIdle();
+ sequential_scheduler_->WaitUntilIdle();
+}
+
+void SecAggScheduler::RunSequential(std::function<void()> function) {
+ sequential_scheduler_->Schedule(function);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_scheduler.h b/fcp/secagg/server/secagg_scheduler.h
new file mode 100644
index 0000000..6d0d004
--- /dev/null
+++ b/fcp/secagg/server/secagg_scheduler.h
@@ -0,0 +1,346 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
+#define FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/base/clock.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/reentrancy_guard.h"
+#include "fcp/base/scheduler.h"
+
+namespace fcp {
+namespace secagg {
+
+// Simple callback waiter that runs the function on Wakeup.
+class CallbackWaiter : public Clock::Waiter {
+ public:
+ explicit CallbackWaiter(std::function<void()> callback)
+ : callback_(std::move(callback)) {}
+
+ void WakeUp() override { callback_(); }
+
+ private:
+ std::function<void()> callback_;
+};
+
+// Provides Cancellation mechanism for SevAggScheduler.
+class CancellationImpl {
+ public:
+ virtual ~CancellationImpl() = default;
+
+ // Calling Cancel results in skipping the remaining, still pending
+ // ParallelGenerateSequentialReduce. The call blocks waiting for any
+ // currently active ongoing tasks to complete. Calling Cancel for the second
+ // time has no additional effect.
+ virtual void Cancel() = 0;
+};
+
+using CancellationToken = std::shared_ptr<CancellationImpl>;
+
+template <typename T>
+class Accumulator : public CancellationImpl,
+ public std::enable_shared_from_this<Accumulator<T>> {
+ public:
+ Accumulator(
+ std::unique_ptr<T> initial_value,
+ std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func,
+ Scheduler* parallel_scheduler, Scheduler* sequential_scheduler,
+ Clock* clock)
+ : parallel_scheduler_(parallel_scheduler),
+ sequential_scheduler_(sequential_scheduler),
+ accumulated_value_(std::move(initial_value)),
+ accumulator_func_(accumulator_func),
+ clock_(clock) {}
+
+ inline static std::function<void()> GetParallelScheduleFunc(
+ std::shared_ptr<Accumulator<T>> accumulator,
+ std::function<std::unique_ptr<T>()> generator) {
+ return [accumulator, generator] {
+ // Increment active count if the accumulator is not canceled, otherwise
+ // return without scheduling the task. By active count we mean the total
+ // number of scheduled tasks, both parallel and sequential. To cancel an
+ // accumulator, we wait until that this count is 0.
+ if (!accumulator->MaybeIncrementActiveCount()) {
+ return;
+ }
+ auto partial = generator();
+ FCP_CHECK(partial);
+ // Decrement the count for the parallel task that was just run as
+ // generator().
+ accumulator->DecrementActiveCount();
+ // Schedule sequential part of the generator, only if accumulator is not
+ // cancelled, otherwise return without scheduling it.
+ if (accumulator->IsCancelled()) {
+ return;
+ }
+ accumulator->RunSequential(
+ [=, partial = std::shared_ptr<T>(partial.release())] {
+ ReentrancyGuard guard;
+ FCP_CHECK_STATUS(guard.Check(accumulator->in_sequential_call()));
+ // mark that a task will be
+ // scheduled, if the accumulator is
+ // not canceled.
+ if (!accumulator->MaybeIncrementActiveCount()) {
+ return;
+ }
+ auto new_value = accumulator->accumulator_func_(
+ *accumulator->accumulated_value_, *partial);
+ FCP_CHECK(new_value);
+ accumulator->accumulated_value_ = std::move(new_value);
+ // At this point the sequantial task has been run, and we (i)
+ // decrement both active and remaining counts and possibly reset the
+ // unobserved work flag, (ii) get the callback, which might be
+ // empty, and (iii) call it if that is not the case.
+ auto callback = accumulator->UpdateCountsAndGetCallback();
+ if (callback) {
+ callback();
+ }
+ });
+ };
+ }
+
+ // Schedule a parallel generator that includes a delay. The result of the
+ // generator is fed to the accumulator_func
+ void Schedule(std::function<std::unique_ptr<T>()> generator,
+ absl::Duration delay) {
+ // IncrementRemainingCount() keeps track of the number of async tasks
+ // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
+ // to a starting batch of unobserved work.
+ auto shared_this = this->shared_from_this();
+ shared_this->IncrementRemainingCount();
+ clock_->WakeupWithDeadline(
+ clock_->Now() + delay,
+ std::make_shared<CallbackWaiter>([shared_this, generator] {
+ shared_this->RunParallel(
+ Accumulator<T>::GetParallelScheduleFunc(shared_this, generator));
+ }));
+ }
+
+ // Schedule a parallel generator. The result of the generator is fed to the
+ // accumulator_func
+ void Schedule(std::function<std::unique_ptr<T>()> generator) {
+ // IncrementRemainingCount() keeps track of the number of async tasks
+ // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
+ // to a starting batch of unobserved work.
+ auto shared_this = this->shared_from_this();
+ shared_this->IncrementRemainingCount();
+ RunParallel([shared_this, generator] {
+ shared_this->GetParallelScheduleFunc(shared_this, generator)();
+ });
+ }
+
+ // Returns true if the accumulator doesn't have any remaining tasks,
+ // even if their results have not been observed by a callback.
+ bool IsIdle() {
+ absl::MutexLock lock(&mutex_);
+ return remaining_sequential_tasks_count_ == 0;
+ }
+
+ // Returns false if no async work has happened since last time this function
+ // was called, or the first time it is called. Otherwise it returns true and
+ // schedules a callback to be called once the scheduler is idle.
+ bool SetAsyncObserver(std::function<void()> async_callback) {
+ bool idle;
+ {
+ absl::MutexLock lock(&mutex_);
+ if (!has_unobserved_work_) {
+ return false;
+ }
+ idle = (remaining_sequential_tasks_count_ == 0);
+ if (idle) {
+ // The flag is set to false, and the callback is run as soon as we leave
+ // the mutex's scope.
+ has_unobserved_work_ = false;
+ } else {
+ // The callbak is scheduled for later, as there is ongoing work.
+ async_callback_ = async_callback;
+ }
+ }
+ if (idle) {
+ auto shared_this = this->shared_from_this();
+ RunSequential([async_callback, shared_this] { async_callback(); });
+ }
+ return true;
+ }
+
+ // Updates the active and remaining task counts, and returns the callback to
+ // be executed, or nullptr if there's pending async work.
+ inline std::function<void()> UpdateCountsAndGetCallback() {
+ absl::MutexLock lock(&mutex_);
+ if (--active_count_ == 0 && is_cancelled_) {
+ inactive_cv_.SignalAll();
+ }
+ --remaining_sequential_tasks_count_;
+ if (remaining_sequential_tasks_count_ == 0 && async_callback_) {
+ has_unobserved_work_ = false;
+ auto callback = async_callback_;
+ async_callback_ = nullptr;
+ return callback;
+ } else {
+ return nullptr;
+ }
+ }
+
+ // Take the accumulated result and abort any further work. This method can
+ // only be called when the accumulator is idle
+ std::unique_ptr<T> GetResultAndCancel() {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(active_count_ == 0);
+ is_cancelled_ = true;
+ return std::move(accumulated_value_);
+ }
+
+ // CancellationImpl implementation
+ void Cancel() override {
+ mutex_.Lock();
+ is_cancelled_ = true;
+ while (active_count_ > 0) {
+ inactive_cv_.Wait(&mutex_);
+ }
+ mutex_.Unlock();
+ }
+
+ bool IsCancelled() {
+ absl::MutexLock lock(&mutex_);
+ return is_cancelled_;
+ }
+
+ bool MaybeIncrementActiveCount() {
+ absl::MutexLock lock(&mutex_);
+ if (is_cancelled_) {
+ return false;
+ }
+ active_count_++;
+ return true;
+ }
+
+ size_t DecrementActiveCount() {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(active_count_ > 0);
+ if (--active_count_ == 0 && is_cancelled_) {
+ inactive_cv_.SignalAll();
+ }
+ return active_count_;
+ }
+
+ void IncrementRemainingCount() {
+ absl::MutexLock lock(&mutex_);
+ has_unobserved_work_ |= (remaining_sequential_tasks_count_ == 0);
+ remaining_sequential_tasks_count_++;
+ }
+
+ std::atomic<bool>* in_sequential_call() { return &in_sequential_call_; }
+
+ void inline RunParallel(std::function<void()> function) {
+ parallel_scheduler_->Schedule(function);
+ }
+
+ void inline RunSequential(std::function<void()> function) {
+ sequential_scheduler_->Schedule(function);
+ }
+
+ private:
+ // Scheduler for sequential and parallel tasks, received from the
+ // SecAggScheduler instatiating this class
+ Scheduler* parallel_scheduler_;
+ Scheduler* sequential_scheduler_;
+
+ // Callback to be executed the next time that the sequential scheduler
+ // becomes idle.
+ std::function<void()> async_callback_ ABSL_GUARDED_BY(mutex_) =
+ std::function<void()>();
+ // Accumulated value - accessed by sequential tasks only.
+ std::unique_ptr<T> accumulated_value_;
+ // Accumulation function - accessed by sequential tasks only.
+ std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func_;
+ // Clock used for scheduling delays in parallel tasks
+ Clock* clock_;
+ // Remaining number of sequential tasks to be executed - accessed by
+ // sequential tasks only.
+ size_t remaining_sequential_tasks_count_ ABSL_GUARDED_BY(mutex_) = 0;
+ bool has_unobserved_work_ ABSL_GUARDED_BY(mutex_) = false;
+
+ // Number of active calls to either callback function.
+ size_t active_count_ ABSL_GUARDED_BY(mutex_) = 0;
+ // This is set to true when the run is aborted.
+ bool is_cancelled_ ABSL_GUARDED_BY(mutex_) = false;
+ // Protects active_count_ and cancelled_.
+ absl::Mutex mutex_;
+ // Used to notify cancellation about reaching inactive state;
+ absl::CondVar inactive_cv_;
+ // This is used by ReentrancyGuard to ensure that Sequential tasks are
+ // indeed sequential.
+ std::atomic<bool> in_sequential_call_ = false;
+};
+
+// Implementation of ParallelGenerateSequentialReduce based on fcp::Scheduler.
+// Takes two Schedulers, one which is responsible for parallel execution and
+// another for serial execution. Additionally, takes a clock that can be used to
+// induce delay in task executions.
+class SecAggScheduler {
+ public:
+ SecAggScheduler(Scheduler* parallel_scheduler,
+ Scheduler* sequential_scheduler,
+ Clock* clock = Clock::RealClock())
+ : parallel_scheduler_(parallel_scheduler),
+ sequential_scheduler_(sequential_scheduler),
+ clock_(clock) {}
+
+ // SecAggScheduler is neither copyable nor movable.
+ SecAggScheduler(const SecAggScheduler&) = delete;
+ SecAggScheduler& operator=(const SecAggScheduler&) = delete;
+
+ virtual ~SecAggScheduler() = default;
+
+ // Schedule a callback to be invoked on the sequential scheduler.
+ inline void ScheduleCallback(std::function<void()> callback) {
+ RunSequential(callback);
+ }
+
+ template <typename T>
+ std::shared_ptr<Accumulator<T>> CreateAccumulator(
+ std::unique_ptr<T> initial_value,
+ std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func) {
+ return std::make_shared<Accumulator<T>>(
+ std::move(initial_value), accumulator_func, parallel_scheduler_,
+ sequential_scheduler_, clock_);
+ }
+
+ void WaitUntilIdle();
+
+ protected:
+ // Virtual for testing
+ virtual void RunSequential(std::function<void()> function);
+
+ private:
+ Scheduler* parallel_scheduler_;
+ Scheduler* sequential_scheduler_;
+ Clock* clock_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
diff --git a/fcp/secagg/server/secagg_scheduler_test.cc b/fcp/secagg/server/secagg_scheduler_test.cc
new file mode 100644
index 0000000..5bd4e87
--- /dev/null
+++ b/fcp/secagg/server/secagg_scheduler_test.cc
@@ -0,0 +1,287 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_scheduler.h"
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/base/simulated_clock.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::IsFalse;
+using ::testing::Lt;
+using ::testing::StrictMock;
+using ::testing::Test;
+
+class MockScheduler : public Scheduler {
+ public:
+ MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
+ MOCK_METHOD(void, WaitUntilIdle, ());
+};
+
+// Wrap int in a struct to keep Clang-tidy happy.
+struct Integer {
+ Integer() : value(0) {}
+ explicit Integer(int v) : value(v) {}
+ int value;
+};
+
+std::vector<std::function<std::unique_ptr<Integer>()>> IntGenerators(int n) {
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators;
+ for (int i = 1; i <= n; ++i) {
+ generators.emplace_back([i]() { return std::make_unique<Integer>(i); });
+ }
+ return generators;
+}
+
+constexpr auto multiply_accumulator = [](const Integer& l, const Integer& r) {
+ return std::make_unique<Integer>(l.value * r.value);
+};
+constexpr auto call_fn = [](const std::function<void()>& f) { f(); };
+
+TEST(SecAggSchedulerTest, ScheduleCallback) {
+ StrictMock<MockScheduler> parallel_scheduler;
+ StrictMock<MockScheduler> sequential_scheduler;
+
+ EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(0);
+ EXPECT_CALL(sequential_scheduler, Schedule(_)).WillOnce(call_fn);
+
+ SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);
+
+ int r = 0;
+ runner.ScheduleCallback([&r]() { r = 5; });
+ EXPECT_THAT(r, Eq(5));
+}
+
+TEST(SecAggSchedulerTest, SingleCall) {
+ StrictMock<MockScheduler> parallel_scheduler;
+ StrictMock<MockScheduler> sequential_scheduler;
+
+ EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(6).WillRepeatedly(call_fn);
+ EXPECT_CALL(sequential_scheduler, Schedule(_))
+ .Times(7)
+ .WillRepeatedly(call_fn);
+
+ // Technically unsafe, but we know the pointers will be valid as long as
+ // runner is alive.
+ SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);
+
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators =
+ IntGenerators(6);
+
+ Integer result;
+ auto accumulator = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), multiply_accumulator);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ accumulator->SetAsyncObserver(
+ [&]() { result = *(accumulator->GetResultAndCancel()); });
+ EXPECT_THAT(result.value, Eq(720)); // 6! = 720
+}
+
+TEST(SecAggSchedulerTest, SingleCallWithDelay) {
+ StrictMock<MockScheduler> parallel_scheduler;
+ StrictMock<MockScheduler> sequential_scheduler;
+ SimulatedClock clock;
+
+ EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(6).WillRepeatedly(call_fn);
+ EXPECT_CALL(sequential_scheduler, Schedule(_))
+ .Times(6)
+ .WillRepeatedly(call_fn);
+
+ SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler, &clock);
+
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators =
+ IntGenerators(6);
+
+ Integer result;
+ auto accumulator = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), multiply_accumulator);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator, absl::Seconds(5));
+ }
+ accumulator->SetAsyncObserver(
+ [&]() { result = *(accumulator->GetResultAndCancel()); });
+
+ // Generators are still delayed.
+ EXPECT_THAT(result.value, Eq(0));
+
+ // Advance time by one second.
+ clock.AdvanceTime(absl::Seconds(1));
+ // Generators are still delayed.
+ EXPECT_THAT(result.value, Eq(0));
+
+ // Advance time by another 4 seconds.
+ clock.AdvanceTime(absl::Seconds(4));
+ EXPECT_THAT(result.value, Eq(720)); // 6! = 720
+}
+
+TEST(SecAggSchedulerTest, TwoCalls) {
+ StrictMock<MockScheduler> parallel_scheduler;
+ StrictMock<MockScheduler> sequential_scheduler;
+
+ EXPECT_CALL(parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
+ EXPECT_CALL(sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);
+
+ // Technically unsafe, but we know the pointers will be valid as long as
+ // runner is alive.
+ SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);
+
+ // First call
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators =
+ IntGenerators(6);
+
+ Integer result;
+ auto accumulator = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), multiply_accumulator);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ accumulator->SetAsyncObserver(
+ [&]() { result = *(accumulator->GetResultAndCancel()); });
+
+ EXPECT_THAT(result.value, Eq(720)); // 6! = 720
+
+ // Second call
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators2 =
+ IntGenerators(4);
+ auto accumulator2 = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), multiply_accumulator);
+
+ for (const auto& generator : generators2) {
+ accumulator2->Schedule(generator);
+ }
+ accumulator2->SetAsyncObserver(
+ [&]() { result = *(accumulator2->GetResultAndCancel()); });
+ EXPECT_THAT(result.value, Eq(24)); // 4! = 24
+}
+
+TEST(SecAggSchedulerAbortTest, Abort) {
+ auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
+ auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);
+
+ absl::Notification signal_abort;
+ std::atomic<int> callback_counter = 0;
+
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators;
+ for (int i = 1; i <= 100; ++i) {
+ generators.emplace_back([&, i]() {
+ callback_counter++;
+ // Signal abort when running 10th parallel task
+ if (i == 10) {
+ signal_abort.Notify();
+ }
+ absl::SleepFor(absl::Milliseconds(1));
+ return std::make_unique<Integer>(i);
+ });
+ }
+
+ auto accumulator_func = [&](const Integer& l, const Integer& r) {
+ callback_counter++;
+ return std::make_unique<Integer>(l.value * r.value);
+ };
+
+ SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
+ bool final_callback_called = false;
+ auto accumulator = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), accumulator_func);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ accumulator->SetAsyncObserver([&]() { final_callback_called = true; });
+
+ signal_abort.WaitForNotification();
+ accumulator->Cancel();
+
+ int count_after_abort = callback_counter.load();
+ FCP_LOG(INFO) << "count_after_abort = " << count_after_abort;
+
+ // Wait for all scheduled tasks to finish
+ runner.WaitUntilIdle();
+
+ // The final number of callbacks should not change since returning from
+ // Abort.
+ int final_count = callback_counter.load();
+ EXPECT_THAT(final_count, Eq(count_after_abort));
+ EXPECT_THAT(final_count, Lt(generators.size()));
+ EXPECT_THAT(final_callback_called, IsFalse());
+}
+
+// Tests that three batches of async work result in three calls to the callback,
+// which can be overriden in between calls.
+TEST(SecAggSchedulerTest, ThreeCallbackCalls) {
+ auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
+ auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);
+
+ SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
+
+ std::vector<std::function<std::unique_ptr<Integer>()>> generators =
+ IntGenerators(3);
+
+ auto accumulator = runner.CreateAccumulator<Integer>(
+ std::make_unique<Integer>(1), multiply_accumulator);
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ int callback_counter = 0;
+ accumulator->SetAsyncObserver([&]() { callback_counter++; });
+ runner.WaitUntilIdle();
+ EXPECT_THAT(callback_counter, Eq(1));
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ runner.WaitUntilIdle();
+ // The callback was not re-scheduled, so the second call to Schedule didn't
+ // trigger it. This results in unobserved work.
+ EXPECT_THAT(callback_counter, Eq(1));
+ bool has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
+ runner.WaitUntilIdle();
+ EXPECT_TRUE(has_work);
+ EXPECT_THAT(callback_counter, Eq(2));
+ // The accumulator should be idle and without unobserved work at this point.
+ has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
+ EXPECT_FALSE(has_work);
+ Integer result;
+ for (const auto& generator : generators) {
+ accumulator->Schedule(generator);
+ }
+ accumulator->SetAsyncObserver(
+ [&]() { result = *(accumulator->GetResultAndCancel()); });
+ runner.WaitUntilIdle();
+ // The last call to SetAsyncObserver overwrittes the previous callback.
+ EXPECT_THAT(callback_counter, Eq(2));
+ EXPECT_THAT(result.value, Eq(216)); // 6^3 = 216
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server.cc b/fcp/secagg/server/secagg_server.cc
new file mode 100644
index 0000000..a9013fb
--- /dev/null
+++ b/fcp/secagg/server/secagg_server.cc
@@ -0,0 +1,369 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/experiments_names.h"
+#include "fcp/secagg/server/graph_parameter_finder.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_aborted_state.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_messages.pb.h"
+#include "fcp/secagg/server/secagg_server_metrics_listener.h"
+#include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secagg_trace_utility.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/server/send_to_clients_interface.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServer::SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl) {
+ state_ = std::make_unique<SecAggServerR0AdvertiseKeysState>(std::move(impl));
+
+ // Start the span for the current state. The rest of the state span
+ // transitioning is done in TransitionState.
+ state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
+ span_.Ref(), TracingState(state_->State()));
+}
+
+StatusOr<std::unique_ptr<SecAggServer>> SecAggServer::Create(
+ int minimum_number_of_clients_to_proceed, int total_number_of_clients,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ SendToClientsInterface* sender,
+ std::unique_ptr<SecAggServerMetricsListener> metrics,
+ std::unique_ptr<SecAggScheduler> prng_runner,
+ std::unique_ptr<ExperimentsInterface> experiments,
+ const SecureAggregationRequirements& threat_model) {
+ TracingSpan<CreateSecAggServer> span;
+ SecretSharingGraphFactory factory;
+ std::unique_ptr<SecretSharingGraph> secret_sharing_graph;
+ ServerVariant server_variant = ServerVariant::UNKNOWN_VERSION;
+
+ bool is_fullgraph_protocol_variant =
+ experiments->IsEnabled(kFullgraphSecAggExperiment);
+ int degree, threshold;
+ // We first compute parameters degree and threshold for the subgraph variant,
+ // unless the kFullgraphSecAggExperiment is enabled, and then set
+ // is_subgraph_protocol_variant to false if the parameter finding procedure
+ // fails. In that case we resort to classical full-graph secagg.
+ // This will happen for very small values of total_number_of_clients (e.g. <
+ // 65), i.e. cohort sizes where subgraph-secagg does not give much advantage.
+ if (!is_fullgraph_protocol_variant) {
+ if (experiments->IsEnabled(kForceSubgraphSecAggExperiment)) {
+ // In kForceSubgraphSecAggExperiment (which is only for testing
+ // purposes) we fix the degree in the Harary graph to be half the number
+ // of clients (rounding to the next odd number to account for self-edges
+ // as above) and degree to be half of the degree (or 2 whatever is
+ // larger). This means that, for example in a simple test with 5
+ // clients, each client shares keys with 2 other clients and the
+ // threshold is one.
+ degree = total_number_of_clients / 2;
+ if (degree % 2 == 0) {
+ degree += 1;
+ }
+ threshold = std::max(2, degree / 2);
+
+ } else {
+ // kSubgraphSecAggCuriousServerExperiment sets the threat model to
+ // CURIOUS_SERVER in subgraph-secagg executions.
+ // This experiment was introduced as part of go/subgraph-secagg-rollout
+ // and is temporary (see b/191179307).
+ StatusOr<fcp::secagg::HararyGraphParameters>
+ computed_params_status_or_value;
+ if (experiments->IsEnabled(kSubgraphSecAggCuriousServerExperiment)) {
+ SecureAggregationRequirements alternate_threat_model = threat_model;
+ alternate_threat_model.set_adversary_class(
+ AdversaryClass::CURIOUS_SERVER);
+ computed_params_status_or_value = ComputeHararyGraphParameters(
+ total_number_of_clients, alternate_threat_model);
+ } else {
+ computed_params_status_or_value =
+ ComputeHararyGraphParameters(total_number_of_clients, threat_model);
+ }
+ if (computed_params_status_or_value.ok()) {
+ // We add 1 to the computed degree to account for a self-edge in the
+ // SecretSharingHararyGraph graph
+ degree = computed_params_status_or_value->degree + 1;
+ threshold = computed_params_status_or_value->threshold;
+ } else {
+ is_fullgraph_protocol_variant = true;
+ }
+ }
+ }
+
+ // In both the FullGraph and SubGraph variants, the protocol only successfully
+ // completes and returns a sum if no more than
+ // floor(total_number_of_clients * threat_model.estimated_dropout_rate())
+ // clients dropout before the end of the protocol execution. This ensure that
+ // at least ceil(total_number_of_clients *(1. -
+ // threat_model.estimated_dropout_rate() -
+ // threat_model.adversarial_client_rate)) values from honest clients are
+ // included in the final sum.
+ // The protocol allows to make that threshold larger by providing a larger
+ // value of minimum_number_of_clients_to_proceed to the create function, but
+ // never lower.
+ minimum_number_of_clients_to_proceed =
+ std::max(minimum_number_of_clients_to_proceed,
+ static_cast<int>(
+ std::ceil(total_number_of_clients *
+ (1. - threat_model.estimated_dropout_rate()))));
+ if (is_fullgraph_protocol_variant) {
+ // We're instantiating full-graph secagg, either because that was
+ // the intent of the caller (by setting kFullgraphSecAggExperiment), or
+ // because ComputeHararyGraphParameters returned and error.
+ FCP_RETURN_IF_ERROR(CheckFullGraphParameters(
+ total_number_of_clients, minimum_number_of_clients_to_proceed,
+ threat_model));
+ secret_sharing_graph = factory.CreateCompleteGraph(
+ total_number_of_clients, minimum_number_of_clients_to_proceed);
+ server_variant = ServerVariant::NATIVE_V1;
+ Trace<FullGraphServerParameters>(
+ total_number_of_clients, minimum_number_of_clients_to_proceed,
+ experiments->IsEnabled(kSecAggAsyncRound2Experiment));
+ } else {
+ secret_sharing_graph =
+ factory.CreateHararyGraph(total_number_of_clients, degree, threshold);
+ server_variant = ServerVariant::NATIVE_SUBGRAPH;
+ Trace<SubGraphServerParameters>(
+ total_number_of_clients, degree, threshold,
+ minimum_number_of_clients_to_proceed,
+ experiments->IsEnabled(kSecAggAsyncRound2Experiment));
+ }
+
+ return absl::WrapUnique(
+ new SecAggServer(std::make_unique<AesSecAggServerProtocolImpl>(
+ std::move(secret_sharing_graph), minimum_number_of_clients_to_proceed,
+ input_vector_specs, std::move(metrics),
+ std::make_unique<AesCtrPrngFactory>(), sender, std::move(prng_runner),
+ std::vector<ClientStatus>(total_number_of_clients,
+ ClientStatus::READY_TO_START),
+ server_variant, std::move(experiments))));
+}
+
+Status SecAggServer::Abort() {
+ const std::string reason = "Abort upon external request.";
+ TracingSpan<AbortSecAggServer> span(state_span_->Ref(), reason);
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ TransitionState(state_->Abort(reason, SecAggServerOutcome::EXTERNAL_REQUEST));
+ return FCP_STATUS(OK);
+}
+
+Status SecAggServer::Abort(const std::string& reason,
+ SecAggServerOutcome outcome) {
+ const std::string formatted_reason =
+ absl::StrCat("Abort upon external request for reason <", reason, ">.");
+ TracingSpan<AbortSecAggServer> span(state_span_->Ref(), formatted_reason);
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ TransitionState(state_->Abort(formatted_reason, outcome));
+ return FCP_STATUS(OK);
+}
+
+std::string MakeClientAbortMessage(ClientAbortReason reason) {
+ return absl::StrCat("The protocol is closing client with ClientAbortReason <",
+ ClientAbortReason_Name(reason), ">.");
+}
+
+Status SecAggServer::AbortClient(uint32_t client_id, ClientAbortReason reason) {
+ TracingSpan<AbortSecAggClient> span(
+ state_span_->Ref(), client_id,
+ ClientAbortReason_descriptor()->FindValueByNumber(reason)->name());
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
+ // By default, put all AbortClient calls in the same bucket (with some
+ // exceptions below).
+ ClientDropReason client_drop_reason =
+ ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT;
+ bool notify_client = false;
+ bool log_metrics = true;
+ std::string message;
+ // Handle all specific abortClient cases
+ switch (reason) {
+ case ClientAbortReason::INVALID_MESSAGE:
+ notify_client = true;
+ message = MakeClientAbortMessage(reason);
+ break;
+ case ClientAbortReason::CONNECTION_DROPPED:
+ client_drop_reason = ClientDropReason::CONNECTION_CLOSED;
+ break;
+ default:
+ log_metrics = false;
+ message = MakeClientAbortMessage(reason);
+ break;
+ }
+
+ state_->AbortClient(client_id, message, client_drop_reason, notify_client,
+ log_metrics);
+ return FCP_STATUS(OK);
+}
+
+Status SecAggServer::ProceedToNextRound() {
+ TracingSpan<ProceedToNextSecAggRound> span(state_span_->Ref());
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ StatusOr<std::unique_ptr<SecAggServerState>> status_or_next_state =
+ state_->ProceedToNextRound();
+ if (status_or_next_state.ok()) {
+ TransitionState(std::move(status_or_next_state.value()));
+ }
+ return status_or_next_state.status();
+}
+
+StatusOr<bool> SecAggServer::ReceiveMessage(
+ uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
+ TracingSpan<ReceiveSecAggMessage> span(state_span_->Ref(), client_id);
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ FCP_RETURN_IF_ERROR(ValidateClientId(client_id));
+ FCP_RETURN_IF_ERROR(state_->HandleMessage(client_id, std::move(message)));
+ return ReadyForNextRound();
+}
+
+bool SecAggServer::SetAsyncCallback(std::function<void()> async_callback) {
+ return state_->SetAsyncCallback(async_callback);
+}
+
+void SecAggServer::TransitionState(
+ std::unique_ptr<SecAggServerState> new_state) {
+ // Reset state_span_ before creating a new unscoped span for the next state
+ // to ensure old span is destructed before the new one is created.
+ state_span_.reset();
+ state_ = std::move(new_state);
+ state_span_ = std::make_unique<UnscopedTracingSpan<SecureAggServerState>>(
+ span_.Ref(), TracingState(state_->State()));
+ state_->EnterState();
+}
+
+absl::flat_hash_set<uint32_t> SecAggServer::AbortedClientIds() const {
+ return state_->AbortedClientIds();
+}
+
+StatusOr<std::string> SecAggServer::ErrorMessage() const {
+ return state_->ErrorMessage();
+}
+
+bool SecAggServer::IsAborted() const { return state_->IsAborted(); }
+
+bool SecAggServer::IsCompletedSuccessfully() const {
+ return state_->IsCompletedSuccessfully();
+}
+
+bool SecAggServer::IsNumberOfIncludedInputsCommitted() const {
+ return state_->IsNumberOfIncludedInputsCommitted();
+}
+
+StatusOr<int> SecAggServer::MinimumMessagesNeededForNextRound() const {
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ return state_->MinimumMessagesNeededForNextRound();
+}
+
+int SecAggServer::NumberOfAliveClients() const {
+ return state_->NumberOfAliveClients();
+}
+
+int SecAggServer::NumberOfClientsFailedAfterSendingMaskedInput() const {
+ return state_->NumberOfClientsFailedAfterSendingMaskedInput();
+}
+
+int SecAggServer::NumberOfClientsFailedBeforeSendingMaskedInput() const {
+ return state_->NumberOfClientsFailedBeforeSendingMaskedInput();
+}
+
+int SecAggServer::NumberOfClientsTerminatedWithoutUnmasking() const {
+ return state_->NumberOfClientsTerminatedWithoutUnmasking();
+}
+
+int SecAggServer::NumberOfIncludedInputs() const {
+ return state_->NumberOfIncludedInputs();
+}
+
+int SecAggServer::NumberOfPendingClients() const {
+ return state_->NumberOfPendingClients();
+}
+
+StatusOr<int> SecAggServer::NumberOfClientsReadyForNextRound() const {
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ return state_->NumberOfClientsReadyForNextRound();
+}
+
+StatusOr<int> SecAggServer::NumberOfMessagesReceivedInThisRound() const {
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ return state_->NumberOfMessagesReceivedInThisRound();
+}
+
+StatusOr<bool> SecAggServer::ReadyForNextRound() const {
+ FCP_RETURN_IF_ERROR(ErrorIfAbortedOrCompleted());
+ return state_->ReadyForNextRound();
+}
+
+StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServer::Result() {
+ return state_->Result();
+}
+
+int SecAggServer::NumberOfNeighbors() const {
+ return state_->number_of_neighbors();
+}
+
+int SecAggServer::MinimumSurvivingNeighborsForReconstruction() const {
+ return state_->minimum_surviving_neighbors_for_reconstruction();
+}
+
+SecAggServerStateKind SecAggServer::State() const { return state_->State(); }
+
+Status SecAggServer::ValidateClientId(uint32_t client_id) const {
+ if (client_id >= state_->total_number_of_clients()) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Client Id " << client_id
+ << " is outside of the expected bounds - 0 to "
+ << state_->total_number_of_clients();
+ }
+ return FCP_STATUS(OK);
+}
+
+Status SecAggServer::ErrorIfAbortedOrCompleted() const {
+ if (state_->IsAborted()) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The server has already aborted. The request cannot be "
+ "satisfied.";
+ }
+ if (state_->IsCompletedSuccessfully()) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The server has already completed the protocol. "
+ << "Call getOutput() to retrieve the output.";
+ }
+ return FCP_STATUS(OK);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server.h b/fcp/secagg/server/secagg_server.h
new file mode 100644
index 0000000..28ec113
--- /dev/null
+++ b/fcp/secagg/server/secagg_server.h
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_set.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/secagg/server/experiments_interface.h"
+#include "fcp/secagg/server/experiments_names.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_messages.pb.h"
+#include "fcp/secagg/server/secagg_server_metrics_listener.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+// Represents a server for the Secure Aggregation protocol. Each instance of
+// this class performs just *one* session of the protocol.
+//
+// To create a new instance, use the public constructor. Once constructed, the
+// server is ready to receive messages from clients with the ReceiveMessage
+// method.
+//
+// When enough messages have been received (i.e. when ReceiveMessage or
+// ReadyForNextRound return true) or any time after that, proceed to the next
+// round by calling ProceedToNextRound.
+//
+// After all client interaction is done, the server needs to do some
+// multi-threaded computation using the supplied Scheduler. Call StartPrng to
+// begin this computation.
+//
+// When the computation is complete, call Result to get the final result.
+//
+// This class is not thread-safe.
+
+class SecAggServer {
+ public:
+ // Constructs a new instance of the Secure Aggregation server.
+ //
+ // minimum_number_of_clients_to_proceed is the threshold lower bound on the
+ // total number of clients expected to complete the protocol. If there are
+ // ever fewer than this many clients still alive in the protocol, the server
+ // will abort (causing all clients to abort as well).
+ //
+ // total_number_of_clients is the number of clients selected to be in the
+ // cohort for this instance of Secure Aggregation.
+ //
+ // input_vector_specs must contain one InputVectorSpecification for each input
+ // vector which the protocol will aggregate.
+ //
+ // sender is used by the server to send messages to clients. The server will
+ // consume this object, taking ownership of it.
+ //
+ // sender may be called on a different thread than the thread used to call
+ // into SecAggServer, specifically in the PrngRunning state.
+ //
+ // prng_factory is a pointer to an instance of a subclass of AesPrngFactory.
+ // If this client will be communicating with the (C++) version of SecAggClient
+ // in this package, then the server and all clients should use
+ // AesCtrPrngFactory.
+ //
+ // metrics will be called over the course of the protocol to record message
+ // sizes and events. If it is null, no metrics will be recorded.
+ //
+ // threat_model includes the assumed maximum adversarial, maximum dropout
+ // rate, and adversary class.
+ //
+ //
+ // The protocol successfully
+ // completes and returns a sum if and only if no more than
+ // floor(total_number_of_clients * threat_model.estimated_dropout_rate())
+ // clients dropout before the end of the protocol execution. This ensure that
+ // at least ceil(total_number_of_clients
+ // *(1. - threat_model.estimated_dropout_rate() -
+ // threat_model.adversarial_client_rate)) values from honest clients are
+ // included in the final sum.
+ // The protocol allows to make that threshold larger by providing a larger
+ // value of minimum_number_of_clients_to_proceed, but
+ // never lower (if the provided minimum_number_of_clients_to_proceed is
+ // smaller than ceil(total_number_of_clients *(1. -
+ // threat_model.estimated_dropout_rate())), the protocol defaults to the
+ // latter value.
+ static StatusOr<std::unique_ptr<SecAggServer>> Create(
+ int minimum_number_of_clients_to_proceed, int total_number_of_clients,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ SendToClientsInterface* sender,
+ std::unique_ptr<SecAggServerMetricsListener> metrics,
+ std::unique_ptr<SecAggScheduler> prng_runner,
+ std::unique_ptr<ExperimentsInterface> experiments,
+ const SecureAggregationRequirements& threat_model);
+
+ ////////////////////////////// PROTOCOL METHODS //////////////////////////////
+
+ // Makes the server abort the protocol, sending a message to all still-alive
+ // clients that the protocol has been aborted. Most of the state will be
+ // erased except for some diagnostic information. A new instance of
+ // SecAggServer will be needed to restart the protocol.
+ //
+ // If a reason string is provided, it will be stored by the server and sent to
+ // the clients as diagnostic information.
+ // An optional outcome can be provided for diagnostic purposes to be recorded
+ // via SecAggServerMetricsListener. By default, EXTERNAL_REQUEST outcome is
+ // assumed.
+ //
+ // The status will be OK unless the protocol was already completed or aborted.
+ Status Abort();
+ Status Abort(const std::string& reason, SecAggServerOutcome outcome);
+
+ // Abort the specified client for the given reason.
+ //
+ // If the server is in a terminal state, returns a FAILED_PRECONDITION status.
+ Status AbortClient(uint32_t client_id, ClientAbortReason reason);
+
+ // Proceeds to the next round, doing necessary computation and sending
+ // messages to clients as appropriate.
+ //
+ // If the server is not ready to proceed, this method will do nothing and
+ // return an UNAVAILABLE status. If the server is already in a terminal state,
+ // this method will do nothing and return a FAILED_PRECONDITION status.
+ //
+ // If the server is ready to proceed, but not all clients have yet sent in
+ // responses, any client that hasn't yet sent a response will be aborted (and
+ // a message informing them of this will be sent).
+ //
+ // After proceeding to the next round, the server is ready to receive more
+ // messages from clients in rounds 1, 2, and 3. In the PrngRunning round, it
+ // is instead ready to have StartPrng called.
+ //
+ // Returns OK as long as the server has actually executed the transition to
+ // the next state.
+ Status ProceedToNextRound();
+
+ // Processes a message that has been received from a client with the given
+ // client_id.
+ //
+ // The boolean returned indicates whether the server is ready to proceed to
+ // the next round. This will be true when a number of clients equal to the
+ // minimum_number_of_clients_to_proceed threshold have sent in valid messages
+ // (and not subsequently aborted), including this one.
+ //
+ // If the message is invalid, the client who sent it will be aborted, and a
+ // message will be sent to them notifying them of the fact. A client may also
+ // send the server a message that it wishes to abort (in which case no further
+ // message to it is sent). This may cause a server that was previously ready
+ // for the next round to no longer be ready, or it may cause the server to
+ // abort if not enough clients remain alive.
+ //
+ // Returns a FAILED_PRECONDITION status if the server is in a terminal state
+ // or the PRNG_RUNNING state.
+ //
+ // Returns an ABORTED status to signify that the server has aborted after
+ // receiving this message. (This will cause all surviving clients to be
+ // notified as well.)
+ StatusOr<bool> ReceiveMessage(
+ uint32_t client_id,
+ std::unique_ptr<ClientToServerWrapperMessage> message);
+ // Sets up a callback to be invoked when any background asynchronous work
+ // has been done. The callback is guaranteed to invoked via the server's
+ // callback scheduler.
+ //
+ // Returns true if asynchronous processing is supported in the current
+ // server state and the callback has been setup successfully. Returns false
+ // if asynchronous processing isn't supported in the current server state or
+ // if no further asynchronous processing is possible. The callback argument
+ // is ignored in that case.
+ bool SetAsyncCallback(std::function<void()> async_callback);
+
+ /////////////////////////////// STATUS METHODS ///////////////////////////////
+
+ // Returns the set of clients that aborted the protocol. Can be used by the
+ // caller to close the relevant RPC connections or just start ignoring
+ // incoming messages from those clients for performance reasons.
+ absl::flat_hash_set<uint32_t> AbortedClientIds() const;
+
+ // Returns a string describing the reason that the protocol was aborted.
+ // If the protocol has not actually been aborted, returns an error Status
+ // with code PRECONDITION_FAILED.
+ StatusOr<std::string> ErrorMessage() const;
+
+ // Returns true if the protocol has been aborted, false else.
+ bool IsAborted() const;
+
+ // Returns true if the protocol has been successfully completed, false else.
+ // The Result method can be called exactly when this method returns true.
+ bool IsCompletedSuccessfully() const;
+
+ // Whether the set of inputs that will be included in the final aggregation
+ // has been fixed.
+ //
+ // If true, the value of NumberOfIncludedInputs will be fixed for the
+ // remainder of the protocol.
+ bool IsNumberOfIncludedInputsCommitted() const;
+
+ // Indicates the minimum number of valid messages needed to be able to
+ // successfully move to the next round.
+ //
+ // Note that this value is not guaranteed to be monotonically decreasing.
+ // Client failures can cause this value to increase.
+ //
+ // Calling this in a terminal state results in an error.
+ StatusOr<int> MinimumMessagesNeededForNextRound() const;
+
+ // Indicates the total number of clients that the server expects to receive
+ // a response from in this round (i.e. the ones that have not aborted). In
+ // the COMPLETED state, this returns the number of clients that survived to
+ // the final protocol message.
+ int NumberOfAliveClients() const;
+
+ // Number of clients that failed after submitting their masked input. These
+ // clients' inputs will be included in the aggregate value, even though
+ // these clients did not complete the protocol.
+ int NumberOfClientsFailedAfterSendingMaskedInput() const;
+
+ // Number of clients that failed before submitting their masked input. These
+ // clients' inputs won't be included in the aggregate value, even if the
+ // protocol succeeds.
+ int NumberOfClientsFailedBeforeSendingMaskedInput() const;
+
+ // Number of clients that submitted a masked value, but didn't report their
+ // unmasking values fast enough to have them used in the final unmasking
+ // process. These clients' inputs will be included in the aggregate value.
+ int NumberOfClientsTerminatedWithoutUnmasking() const;
+
+ // Returns the number of inputs that will appear in the final sum, if the
+ // protocol completes.
+ //
+ // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed
+ // for the remainder of the protocol.
+ //
+ // This will be 0 if the server is aborted. This will also be 0 if the
+ // server is in an early state, prior to receiving masked inputs. It is
+ // incremented only when the server receives a masked input from a client.
+ int NumberOfIncludedInputs() const;
+
+ // Returns the number of live clients that have not yet submitted the
+ // expected response for the current round. In terminal states, this will be
+ // 0.
+ int NumberOfPendingClients() const;
+
+ // Returns the number of clients that would still be alive if
+ // ProceedToNextRound were called immediately after. This value may be less
+ // than NumberOfMessagesReceivedInThisRound if a client fails after sending
+ // a message in this round.
+ //
+ // Note that this value is not guaranteed to be monotonically increasing,
+ // even within a round. Client failures can cause this value to decrease.
+ //
+ // Calling this in a terminal state results in an error.
+ StatusOr<int> NumberOfClientsReadyForNextRound() const;
+
+ // Returns the number of valid messages received by clients this round.
+ // Unlike NumberOfClientsReadyForNextRound, this number is monotonically
+ // increasing until ProceedToNextRound is called, or the server aborts.
+ //
+ // Calling this in a terminal state results in an error.
+ StatusOr<int> NumberOfMessagesReceivedInThisRound() const;
+
+ // Returns a boolean indicating if the server has received enough messages
+ // from clients (who have not subsequently aborted) to proceed to the next
+ // round. ProceedToNextRound will do nothing unless this returns true.
+ //
+ // Even after this method returns true, the server will remain in the
+ // current round until ProceedToNextRound is called.
+ //
+ // Calling this in a terminal state results in an error.
+ StatusOr<bool> ReadyForNextRound() const;
+
+ // Transfers ownership of the result of the protocol to the caller. Requires
+ // the server to be in a completed state; returns UNAVAILABLE otherwise.
+ // Can be called only once; any consequitive calls result in an error.
+ StatusOr<std::unique_ptr<SecAggVectorMap>> Result();
+
+ // Returns the number of neighbors of each client.
+ int NumberOfNeighbors() const;
+
+ // Returns the minimum number of neighbors of a client that must not
+ // drop-out for that client's contribution to be included in the sum. This
+ // corresponds to the threshold in the shamir secret sharing of self and
+ // pairwise masks.
+ int MinimumSurvivingNeighborsForReconstruction() const;
+
+ // Returns a value uniquely describing the current state of the client's
+ // FSM.
+ SecAggServerStateKind State() const;
+
+ private:
+ // Constructs a new instance of the Secure Aggregation server.
+ explicit SecAggServer(std::unique_ptr<SecAggServerProtocolImpl> impl);
+
+ // This causes the server to transition into a new state, and call the
+ // callback if one is provided.
+ void TransitionState(std::unique_ptr<SecAggServerState> new_state);
+
+ // Validates if the client_id is within the expected bounds.
+ Status ValidateClientId(uint32_t client_id) const;
+
+ // Returns an error if the server is in Aborted or Completed state.
+ Status ErrorIfAbortedOrCompleted() const;
+
+ // The internal state object, containing details about the server's current
+ // state.
+ std::unique_ptr<SecAggServerState> state_;
+
+ // Tracing span for this session of SecAggServer. This is bound to the
+ // lifetime of SecAggServer i.e. from the time the object is created till it
+ // is destroyed.
+ UnscopedTracingSpan<SecureAggServerSession> span_;
+
+ // Holds pointer to a tracing span corresponding to the current active
+ // SecAggServerState.
+ std::unique_ptr<UnscopedTracingSpan<SecureAggServerState>> state_span_;
+};
+
+} // namespace secagg
+} // namespace fcp
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_H_
diff --git a/fcp/secagg/server/secagg_server_aborted_state.cc b/fcp/secagg/server/secagg_server_aborted_state.cc
new file mode 100644
index 0000000..b0c21eb
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_aborted_state.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_aborted_state.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerAbortedState::SecAggServerAbortedState(
+ const std::string& error_message,
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::ABORTED, std::move(impl)),
+ error_message_(error_message) {}
+
+SecAggServerAbortedState::~SecAggServerAbortedState() {}
+
+bool SecAggServerAbortedState::IsAborted() const { return true; }
+
+StatusOr<std::string> SecAggServerAbortedState::ErrorMessage() const {
+ return error_message_;
+}
+
+bool SecAggServerAbortedState::IsNumberOfIncludedInputsCommitted() const {
+ return true;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_aborted_state.h b/fcp/secagg/server/secagg_server_aborted_state.h
new file mode 100644
index 0000000..f2f4273
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_aborted_state.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_ABORTED_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_ABORTED_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer after it has aborted. The server
+// cannot transition out of this state; a new SecAggServer object will be needed
+// to start a new run of the protocol. However, an aborted SecAggServer still
+// stores some of the information about the server before it aborted.
+
+class SecAggServerAbortedState : public SecAggServerState {
+ public:
+ SecAggServerAbortedState(
+ const std::string& error_message,
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerAbortedState() override;
+
+ // Returns true.
+ bool IsAborted() const override;
+
+ // Returns an error message explaining why the server aborted.
+ StatusOr<std::string> ErrorMessage() const override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ private:
+ const std::string error_message_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_ABORTED_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_aborted_state_test.cc b/fcp/secagg/server/secagg_server_aborted_state_test.cc
new file mode 100644
index 0000000..47126ed
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_aborted_state_test.cc
@@ -0,0 +1,309 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_aborted_state.h"
+
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_set.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ auto sender = std::unique_ptr<SendToClientsInterface>();
+ SecretSharingGraphFactory factory;
+ return std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(4, 3), // total number of clients is 4
+ 3, // minimum_number_of_clients_to_proceed,
+ std::vector<InputVectorSpecification>(), // input_vector_specs
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ nullptr, // prng_factory
+ sender.get(),
+ nullptr, // prng_runner
+ std::vector<ClientStatus>(
+ 4, DEAD_AFTER_SHARE_KEYS_RECEIVED), // client_statuses
+ ServerVariant::NATIVE_V1);
+}
+
+TEST(SecaggServerAbortedStateTest, IsAbortedReturnsTrue) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.IsAborted(), Eq(true));
+}
+
+TEST(SecaggServerAbortedStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecaggServerAbortedStateTest, ErrorMessageReturnsSelectedMessage) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.ErrorMessage().value(), Eq(test_error_message));
+}
+
+TEST(SecaggServerAbortedStateTest, ReadyForNextRoundReturnsFalse) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.ReadyForNextRound(), Eq(false));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ NumberOfMessagesReceivedInThisRoundReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ NumberOfClientsReadyForNextRoundReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfClientsReadyForNextRound(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest, NumberOfAliveClientsIsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfAliveClients(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(4));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ NumberOfClientsFailedAfterSendingMaskedInputReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ NumberOfClientsTerminatedWithoutUnmaskingReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest, NumberOfPendingClientsReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfPendingClients(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest, NumberOfIncludedInputsReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.NumberOfIncludedInputs(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ IsNumberOfIncludedInputsCommittedReturnsTrue) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.IsNumberOfIncludedInputsCommitted(), Eq(true));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ MinimumMessagesNeededForNextRoundReturnsZero) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.MinimumMessagesNeededForNextRound(), Eq(0));
+}
+
+TEST(SecaggServerAbortedStateTest,
+ minimum_number_of_clients_to_proceedIsAccurate) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.minimum_number_of_clients_to_proceed(), Eq(3));
+}
+
+TEST(SecaggServerAbortedStateTest, HandleMessageRaisesError) {
+ std::string test_error_message = "test error message";
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ ClientToServerWrapperMessage client_message;
+ EXPECT_CALL(*metrics, MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::MESSAGE_CONTENT_NOT_SET),
+ Eq(false), Eq(client_message.ByteSizeLong())));
+ EXPECT_THAT(aborted_state.HandleMessage(0, client_message).ok(), Eq(false));
+}
+
+TEST(SecaggServerAbortedStateTest, ProceedToNextRoundRaisesError) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.ProceedToNextRound().ok(), Eq(false));
+}
+
+TEST(SecaggServerAbortedStateTest, ResultRaisesErrorStatus) {
+ std::string test_error_message = "test error message";
+
+ SecAggServerAbortedState aborted_state(
+ test_error_message, CreateSecAggServerProtocolImpl(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 4, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(aborted_state.Result().ok(), Eq(false));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_completed_state.cc b/fcp/secagg/server/secagg_server_completed_state.cc
new file mode 100644
index 0000000..d5d03ee
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_completed_state.cc
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_completed_state.h"
+
+#include <memory>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerCompletedState::SecAggServerCompletedState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::COMPLETED, std::move(impl)) {
+ // Moving to this state means the protocol succeeded!
+ if (metrics()) {
+ metrics()->ProtocolOutcomes(SecAggServerOutcome::SUCCESS);
+ }
+ Trace<SecAggProtocolOutcome>(TracingSecAggServerOutcome_Success);
+}
+
+SecAggServerCompletedState::~SecAggServerCompletedState() {}
+
+bool SecAggServerCompletedState::IsCompletedSuccessfully() const {
+ return true;
+}
+
+int SecAggServerCompletedState::NumberOfIncludedInputs() const {
+ return total_number_of_clients() -
+ number_of_clients_failed_before_sending_masked_input_;
+}
+
+bool SecAggServerCompletedState::IsNumberOfIncludedInputsCommitted() const {
+ return true;
+}
+
+StatusOr<std::unique_ptr<SecAggVectorMap> >
+SecAggServerCompletedState::Result() {
+ auto result = impl()->TakeResult();
+ if (!result) {
+ return FCP_STATUS(UNAVAILABLE)
+ << "Result is uninitialized or requested more than once";
+ }
+ return std::move(result);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_completed_state.h b/fcp/secagg/server/secagg_server_completed_state.h
new file mode 100644
index 0000000..44ac741
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_completed_state.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_COMPLETED_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_COMPLETED_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer after it has successfully
+// completed the protocol. The server cannot transition out of this state; a new
+// SecAggServer object will be needed to start a new run of the protocol.
+// This state stores information about the final state of the protocol, such as
+// the number of inputs included in the output.
+
+class SecAggServerCompletedState : public SecAggServerState {
+ public:
+ SecAggServerCompletedState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerCompletedState() override;
+
+ // Returns true.
+ bool IsCompletedSuccessfully() const override;
+
+ int NumberOfIncludedInputs() const override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ StatusOr<std::unique_ptr<SecAggVectorMap> > Result() override;
+};
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_COMPLETED_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_completed_state_test.cc b/fcp/secagg/server/secagg_server_completed_state_test.cc
new file mode 100644
index 0000000..2dc007e
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_completed_state_test.cc
@@ -0,0 +1,261 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_completed_state.h"
+
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_set.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/secagg/testing/test_matchers.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ int total_number_of_clients = 4;
+ SecretSharingGraphFactory factory;
+ return std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(total_number_of_clients, 3),
+ 3, // minimum_number_of_clients_to_proceed
+ std::vector<InputVectorSpecification>(),
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ nullptr, // prng_factory
+ sender,
+ nullptr, // prng_runner
+ std::vector<ClientStatus>(total_number_of_clients,
+ ClientStatus::UNMASKING_RESPONSE_RECEIVED),
+ ServerVariant::NATIVE_V1);
+}
+
+SecAggServerCompletedState CreateState(
+ MockSendToClientsInterface* sender,
+ int number_of_clients_failed_after_sending_masked_input = 0,
+ int number_of_clients_failed_before_sending_masked_input = 0,
+ int number_of_clients_terminated_without_unmasking = 0,
+ std::unique_ptr<SecAggVectorMap> map = std::unique_ptr<SecAggVectorMap>(),
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ std::unique_ptr<AesSecAggServerProtocolImpl> impl =
+ CreateSecAggServerProtocolImpl(sender, metrics_listener);
+ impl->SetResult(std::move(map));
+ return SecAggServerCompletedState(
+ std::move(impl), number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking);
+}
+
+TEST(SecAggServerCompletedStateTest, IsAbortedReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.IsAborted(), Eq(false));
+}
+
+TEST(SecAggServerCompletedStateTest, IsCompletedSuccessfullyReturnsTrue) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.IsCompletedSuccessfully(), Eq(true));
+}
+
+TEST(SecAggServerCompletedStateTest, ErrorMessageRaisesError) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecAggServerCompletedStateTest, ReadyForNextRoundReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.ReadyForNextRound(), Eq(false));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ NumberOfMessagesReceivedInThisRoundReturnsZero) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ NumberOfClientsReadyForNextRoundReturnsZero) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.NumberOfClientsReadyForNextRound(), Eq(0));
+}
+
+TEST(SecAggServerCompletedStateTest, NumberOfAliveClientsIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 1); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state.NumberOfAliveClients(), Eq(3));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state.NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ NumberOfClientsFailedAfterSendingMaskedInputIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 1, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state.NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(1));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ NumberOfClientsTerminatedWithoutUnmaskingIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 1); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state.NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(1));
+}
+
+TEST(SecAggServerCompletedStateTest, NumberOfPendingClientsReturnsZero) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.NumberOfPendingClients(), Eq(0));
+}
+
+TEST(SecAggServerCompletedStateTest, NumberOfIncludedInputsIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 1, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state.NumberOfIncludedInputs(), Eq(4));
+
+ SecAggServerCompletedState completed_state_2 = CreateState(
+ sender.get(), 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+ EXPECT_THAT(completed_state_2.NumberOfIncludedInputs(), Eq(3));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ IsNumberOfIncludedInputsCommittedReturnsTrue) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.IsNumberOfIncludedInputsCommitted(), Eq(true));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ MinimumMessagesNeededForNextRoundReturnsZero) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.MinimumMessagesNeededForNextRound(), Eq(0));
+}
+
+TEST(SecAggServerCompletedStateTest,
+ MinimumNumberOfClientsToProceedIsAccurate) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.minimum_number_of_clients_to_proceed(), Eq(3));
+}
+
+TEST(SecAggServerCompletedStateTest, HandleMessageRaisesError) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+
+ SecAggServerCompletedState completed_state = CreateState(
+ sender.get(), 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0, // number_of_clients_terminated_without_unmasking
+ std::unique_ptr<SecAggVectorMap>(), metrics);
+
+ ClientToServerWrapperMessage client_message;
+ EXPECT_CALL(*metrics, MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::MESSAGE_CONTENT_NOT_SET),
+ Eq(false), Eq(client_message.ByteSizeLong())));
+ EXPECT_THAT(completed_state.HandleMessage(0, client_message).ok(), Eq(false));
+}
+
+TEST(SecAggServerCompletedStateTest, ProceedToNextRoundRaisesError) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerCompletedState completed_state = CreateState(sender.get());
+ EXPECT_THAT(completed_state.ProceedToNextRound().ok(), Eq(false));
+}
+
+TEST(SecAggServerCompletedStateTest, ResultGivesStoredResult) {
+ std::vector<uint64_t> vec = {1, 3, 6, 10};
+ auto result_map = std::make_unique<SecAggVectorMap>();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ result_map->emplace("foobar", SecAggVector(vec, 32));
+ SecAggServerCompletedState completed_state =
+ CreateState(sender.get(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0, // number_of_clients_terminated_without_unmasking
+ std::move(result_map));
+
+ auto result = completed_state.Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVector("foobar", SecAggVector(vec, 32)));
+}
+
+TEST(SecAggServerCompletedStateTest, ConstructorRecordsSuccessMetric) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+
+ EXPECT_CALL(*metrics, ProtocolOutcomes(Eq(SecAggServerOutcome::SUCCESS)));
+ SecAggServerCompletedState completed_state =
+ CreateState(sender.get(),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0, // number_of_clients_terminated_without_unmasking
+ std::unique_ptr<SecAggVectorMap>(), metrics);
+
+ EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
+ ElementsAre(IsEvent<SecAggProtocolOutcome>(
+ Eq(TracingSecAggServerOutcome_Success))));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_enums.proto b/fcp/secagg/server/secagg_server_enums.proto
new file mode 100644
index 0000000..f5b686f
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_enums.proto
@@ -0,0 +1,149 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package fcp.secagg;
+
+option java_package = "fcp.secagg.server";
+option java_outer_classname = "SecAggServerEnums";
+
+// Describes current state of SecAggServer.
+enum SecAggServerStateKind {
+ UNKNOWN_STATE = 0;
+ R0_ADVERTISE_KEYS = 1;
+ R1_SHARE_KEYS = 2;
+ R2_MASKED_INPUT_COLLECTION = 3;
+ R3_UNMASKING = 4;
+ PRNG_RUNNING = 5;
+ COMPLETED = 6;
+ ABORTED = 7;
+}
+
+// Describes version of SecAggServer implementation.
+enum ServerVariant {
+ UNKNOWN_VERSION = 0;
+ OBSOLETE_JAVA = 1;
+ NATIVE_V1 = 2;
+ RLWE_HOMOMORPHIC_KEYS = 3;
+ NATIVE_SUBGRAPH = 4;
+}
+
+// Describes the outcome of running SecAgg protocol on server.
+enum SecAggServerOutcome {
+ // A public abort() method of SecAggServerImpl was called.
+ EXTERNAL_REQUEST = 0;
+ // Too many clients dropped out for the protocol to continue.
+ NOT_ENOUGH_CLIENTS_REMAINING = 1;
+ // Some error occurred and was not otherwise handled.
+ UNHANDLED_ERROR = 2;
+ // The protocol ran to success and the server produced an output value.
+ SUCCESS = 3;
+}
+
+// Used by descendants of SecAggServerState to track the status of clients. This
+// is referred to as a "status" rather than a "state" because it does not
+// necessarily correspond with the client's actual state in the FSM.
+enum ClientStatus {
+ READY_TO_START = 0;
+ DEAD_BEFORE_SENDING_ANYTHING = 1;
+ ADVERTISE_KEYS_RECEIVED = 2;
+ DEAD_AFTER_ADVERTISE_KEYS_RECEIVED = 3;
+ SHARE_KEYS_RECEIVED = 4;
+ DEAD_AFTER_SHARE_KEYS_RECEIVED = 5;
+ MASKED_INPUT_RESPONSE_RECEIVED = 6;
+ DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED = 7;
+ UNMASKING_RESPONSE_RECEIVED = 8;
+ DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED = 9;
+}
+
+// Error codes summarizing the reason why a client was dropped.
+enum ClientDropReason {
+ // Received abort message from the client.
+ SENT_ABORT_MESSAGE = 0;
+ // Message type received different from expected.
+ UNEXPECTED_MESSAGE_TYPE = 1;
+ // Message type not recognized or not set.
+ UNKNOWN_MESSAGE_TYPE = 2;
+ // Not expecting an AdvertiseKeys message from this client.
+ ADVERTISE_KEYS_UNEXPECTED = 3;
+ // One of the public keys in an AdvertiseKeys message has length 0.
+ EMPTY_PUBLIC_KEY = 4;
+ // Did not send an AdvertiseKeys message before round ended.
+ NO_ADVERTISE_KEYS = 5;
+ // Not expecting a ShareKeysResponse message from this client.
+ SHARE_KEYS_UNEXPECTED = 6;
+ // ShareKeysResponse did not have the expected number of key shares.
+ WRONG_NUMBER_OF_KEY_SHARES = 7 [deprecated = true];
+ // ShareKeysResponse does not include key shares for all clients it should.
+ MISSING_KEY_SHARE = 8 [deprecated = true];
+ // ShareKeysResponse sent a key share for a client it shouldn't have.
+ EXTRA_KEY_SHARE = 9 [deprecated = true];
+ // Did not send a ShareKeysResponse message before round ended.
+ NO_SHARE_KEYS = 10;
+ // Not expecting a MaskedInputResponse message from this client.
+ MASKED_INPUT_UNEXPECTED = 11;
+ // Masked input received does not match the input specification.
+ INVALID_MASKED_INPUT = 12;
+ // Did not send a MaskedInputResponse message before round ended.
+ NO_MASKED_INPUT = 13;
+ // Not expecting an UnmaskingResponse message from this client.
+ UNMASKING_RESPONSE_UNEXPECTED = 14;
+ // UnmaskingResponse received does not contain the correct type of key shares.
+ INVALID_UNMASKING_RESPONSE = 15;
+ // Did not send an UnmaskingResponse message before round ended.
+ NO_UNMASKING_RESPONSE = 16;
+ // AdvertiseKeys message contained a public key of invalid size.
+ INVALID_PUBLIC_KEY = 17;
+ // Protocol aborted the client either due to early success or internal errors.
+ SERVER_PROTOCOL_ABORT_CLIENT = 18;
+ // Client is no longer needed but marks the protocol as success.
+ EARLY_SUCCESS = 19;
+ // Client connection closed.
+ CONNECTION_CLOSED = 20;
+ // Invalid ShareKeysResponse (e.g. one that doesn't have the expected number
+ // of key shares, doesn't include key shares for all clients it should, or
+ // has a key share for a client it shouldn't have).
+ INVALID_SHARE_KEYS_RESPONSE = 21;
+}
+
+// Error codes describing why the client was aborted by the protocol.
+enum ClientAbortReason {
+ // Client was aborted because it sent an invalid message.
+ INVALID_MESSAGE = 0;
+ // Client never checked-in with a handshake message.
+ NOT_CHECKED_IN = 1;
+ // Client connection dropped over the wire.
+ CONNECTION_DROPPED = 2;
+ // Client is running an obsolete version.
+ OBSOLETE_VERSION = 3 [deprecated = true];
+}
+enum AdversaryClass {
+ NONE = 0;
+ // A semi-honest/honest-but-curious adversary controlling the server and a
+ // fraction of the clients
+ CURIOUS_SERVER = 1;
+ // A semi-honest adversary controlling the server and a fraction of the
+ // clients that might perform the following malicious attack:
+ // Consider a client i that submits its masked input y. The server
+ // requests t (the shamir threshold) shares to recover the self-mask of i, and
+ // additionally (and this is the malicious behaviour)
+ // obtains another t shares to recover the pairwise masks of i from the
+ // number_of_neighbors - t clients from with a share of a self-mask had not
+ // been requested. Using both pairwise and self masks the value of i can be
+ // recovered by the server.
+ SEMI_MALICIOUS_SERVER = 2;
+}
diff --git a/fcp/secagg/server/secagg_server_messages.proto b/fcp/secagg/server/secagg_server_messages.proto
new file mode 100644
index 0000000..c5bffbc
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_messages.proto
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package fcp.secagg;
+
+import "fcp/secagg/server/secagg_server_enums.proto";
+
+option java_package = "fcp.secagg.server";
+option java_multiple_files = true;
+
+// Specifies the security and robustness requirements for an instance of secure
+// aggregation
+message SecureAggregationRequirements {
+ // The adversary class against which to protect.
+ AdversaryClass adversary_class = 1;
+
+ // Fraction of clients in the population which might be compromised
+ // by an adversary, expressed as a fraction (between 0.0 and 1.0).
+ double adversarial_client_rate = 2;
+
+ // Estimated client dropout rate, expressed as a fraction (between 0.0
+ // and 1.0).
+ double estimated_dropout_rate = 3;
+
+ // The minimum number of (non-adversarial) users' values that must be
+ // aggregated together before the server can gain access to the aggregate,
+ // even transiently (e.g. in RAM).
+ int32 minimum_clients_in_server_visible_aggregate = 4;
+}
diff --git a/fcp/secagg/server/secagg_server_metrics_listener.h b/fcp/secagg/server/secagg_server_metrics_listener.h
new file mode 100644
index 0000000..af2781d
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_metrics_listener.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_METRICS_LISTENER_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_METRICS_LISTENER_H_
+
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// Callback interface for reporting SecAggServer metrics.
+class SecAggServerMetricsListener {
+ public:
+ virtual ~SecAggServerMetricsListener() = default;
+
+ // Called each time SecAggServer is instantiated, starting the SecAgg
+ // protocol.
+ virtual void ProtocolStarts(ServerVariant server_variant) = 0;
+
+ // Size (in bytes) of a message sent by the SecAgg server to an individual
+ // user.
+ virtual void IndividualMessageSizes(
+ ServerToClientWrapperMessage::MessageContentCase message_type,
+ uint64_t size) = 0;
+
+ // Size (in bytes) of a message broadcast by the SecAgg server.
+ virtual void BroadcastMessageSizes(
+ ServerToClientWrapperMessage::MessageContentCase message_type,
+ uint64_t size) = 0;
+
+ // Size (in bytes) of a message received by the SecAgg server from a user.
+ virtual void MessageReceivedSizes(
+ ClientToServerWrapperMessage::MessageContentCase message_type,
+ bool message_expected, uint64_t size) = 0;
+
+ // Time (in milliseconds) taken for a client to send a response message to the
+ // server.
+ // Measured from the time the server sent the previous round's message to
+ // the time a new message was received. The first round is measured starting
+ // from the instantiation of the SecAggServer. Only messages received
+ // before the end of the round are monitored.
+ virtual void ClientResponseTimes(
+ ClientToServerWrapperMessage::MessageContentCase message_type,
+ uint64_t elapsed_millis) = 0;
+
+ // Time (in milliseconds) spent in each round.
+ // Counts end-to-end time spent in each state, starting from transitioning to
+ // that state and including waiting for the client messages necessary to
+ // transition to a next state.
+ virtual void RoundTimes(SecAggServerStateKind target_state, bool successful,
+ uint64_t elapsed_millis) = 0;
+
+ // Times (in milliseconds) taken to execute the PRF expansion step.
+ // During PRNG expansion, the server computes the map of masking vectors
+ // needed for unmasking. These are wall times measured over the execution of a
+ // multi-threaded process.
+ virtual void PrngExpansionTimes(uint64_t elapsed_millis) = 0;
+
+ // Number of clients at the end of each round.
+ virtual void RoundSurvivingClients(SecAggServerStateKind target_state,
+ uint64_t number_of_clients) = 0;
+
+ // Fraction of clients at each client state at the end of each round.
+ // Fractions are calculates off the total number of clients that the protocol
+ // starts with.
+ virtual void RoundCompletionFractions(SecAggServerStateKind target_state,
+ ClientStatus client_state,
+ double fraction) = 0;
+
+ // Records outcomes of SecAggServerImpl protocol runs.
+ // SUCCESS means the protocol ran through all phases and produced output.
+ virtual void ProtocolOutcomes(SecAggServerOutcome outcome) = 0;
+
+ // Called when a client drops during an execution of the SecAgg protocol.
+ virtual void ClientsDropped(ClientStatus abort_state,
+ ClientDropReason error_code) = 0;
+
+ // Time (in milliseconds) taken to reconstruct all users' keys from their
+ // Shamir secret shares.
+ // This includes all reconstruction operations for all shares, taken together.
+ virtual void ShamirReconstructionTimes(uint64_t elapsed_millis) = 0;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_METRICS_LISTENER_H_
diff --git a/fcp/secagg/server/secagg_server_prng_running_state.cc b/fcp/secagg/server/secagg_server_prng_running_state.cc
new file mode 100644
index 0000000..79bed14
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_prng_running_state.cc
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_prng_running_state.h"
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_completed_state.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerPrngRunningState::SecAggServerPrngRunningState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::PRNG_RUNNING, std::move(impl)),
+ completion_status_(std::nullopt) {}
+
+SecAggServerPrngRunningState::~SecAggServerPrngRunningState() {}
+
+Status SecAggServerPrngRunningState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ MessageReceived(message, false); // Messages are always unexpected here.
+ if (message.has_abort()) {
+ AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
+ /*notify=*/false);
+ } else {
+ AbortClient(client_id, "Non-abort message sent during PrngUnmasking step.",
+ ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
+ }
+ return FCP_STATUS(OK);
+}
+
+void SecAggServerPrngRunningState::HandleAbort() {
+ if (cancellation_token_) {
+ cancellation_token_->Cancel();
+ }
+}
+
+StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
+SecAggServerPrngRunningState::Initialize() {
+ // Shamir reconstruction part of PRNG
+ absl::Time reconstruction_start = absl::Now();
+ FCP_ASSIGN_OR_RETURN(auto shamir_reconstruction_result,
+ impl()->HandleShamirReconstruction());
+ auto elapsed_millis =
+ absl::ToInt64Milliseconds(absl::Now() - reconstruction_start);
+ if (metrics()) {
+ metrics()->ShamirReconstructionTimes(elapsed_millis);
+ }
+ Trace<ShamirReconstruction>(elapsed_millis);
+
+ // Generating workitems for PRNG computation.
+ return impl()->InitializePrng(std::move(shamir_reconstruction_result));
+}
+
+void SecAggServerPrngRunningState::EnterState() {
+ auto initialize_result = Initialize();
+
+ if (!initialize_result.ok()) {
+ absl::MutexLock lock(&mutex_);
+ completion_status_ = initialize_result.status();
+ return;
+ }
+
+ auto work_items = std::move(initialize_result).value();
+
+ // Scheduling workitems to run.
+ prng_started_time_ = absl::Now();
+
+ cancellation_token_ = impl()->StartPrng(
+ work_items, [this](Status status) { this->PrngRunnerFinished(status); });
+}
+
+bool SecAggServerPrngRunningState::SetAsyncCallback(
+ std::function<void()> async_callback) {
+ absl::MutexLock lock(&mutex_);
+ FCP_CHECK(async_callback != nullptr) << "async_callback is expected";
+
+ if (completion_status_.has_value()) {
+ // PRNG computation has already finished.
+ impl()->scheduler()->ScheduleCallback(async_callback);
+ } else {
+ prng_done_callback_ = async_callback;
+ }
+ return true;
+}
+
+void SecAggServerPrngRunningState::PrngRunnerFinished(Status final_status) {
+ auto elapsed_millis =
+ absl::ToInt64Milliseconds(absl::Now() - prng_started_time_);
+ if (metrics()) {
+ metrics()->PrngExpansionTimes(elapsed_millis);
+ }
+ Trace<PrngExpansion>(elapsed_millis);
+
+ std::function<void()> prng_done_callback;
+ {
+ absl::MutexLock lock(&mutex_);
+ completion_status_ = final_status;
+ prng_done_callback = prng_done_callback_;
+ }
+
+ if (prng_done_callback) {
+ prng_done_callback();
+ }
+}
+
+void SecAggServerPrngRunningState::HandleAbortClient(
+ uint32_t client_id, ClientDropReason reason_code) {
+ set_client_status(client_id,
+ ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED);
+}
+
+StatusOr<std::unique_ptr<SecAggServerState>>
+SecAggServerPrngRunningState::ProceedToNextRound() {
+ // Block if StartPrng is still being called. That done to ensure that
+ // StartPrng doesn't use *this* object after it has been destroyed by
+ // the code that called ProceedToNextRound.
+ absl::MutexLock lock(&mutex_);
+
+ if (!completion_status_.has_value()) {
+ return FCP_STATUS(UNAVAILABLE);
+ }
+
+ // Don't send any messages; every client either got an "early success"
+ // notification at the end of Round 3, marked itself completed after sending
+ // its Round 3 message, or was already aborted.
+ if (completion_status_.value().ok()) {
+ return std::make_unique<SecAggServerCompletedState>(
+ ExitState(StateTransition::kSuccess),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_);
+ } else {
+ return AbortState(std::string(completion_status_.value().message()),
+ SecAggServerOutcome::UNHANDLED_ERROR);
+ }
+}
+
+bool SecAggServerPrngRunningState::ReadyForNextRound() const {
+ absl::MutexLock lock(&mutex_);
+ return completion_status_.has_value();
+}
+
+int SecAggServerPrngRunningState::NumberOfIncludedInputs() const {
+ return total_number_of_clients() -
+ number_of_clients_failed_before_sending_masked_input_;
+}
+
+bool SecAggServerPrngRunningState::IsNumberOfIncludedInputsCommitted() const {
+ return true;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_prng_running_state.h b/fcp/secagg/server/secagg_server_prng_running_state.h
new file mode 100644
index 0000000..ab25f04
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_prng_running_state.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_PRNG_RUNNING_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_PRNG_RUNNING_STATE_H_
+
+#include <functional>
+#include <memory>
+#include <optional>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer when it has collected all secret
+// shares from the clients and is ready to compute its final output. The
+// protocol is essentially done, but this is a separate state from
+// SecAggClientCompletedState because there the server still needs to run the
+// potentially expensive step of using the PRNG to stretch client keys into
+// masking vectors.
+
+class SecAggServerPrngRunningState final : public SecAggServerState {
+ public:
+ SecAggServerPrngRunningState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerPrngRunningState() override;
+
+ void EnterState() override;
+
+ // Handles abort message from a client. Any other type of message is
+ // unexpected and results in the client being aborted.
+ Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message) override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ int NumberOfIncludedInputs() const override;
+
+ StatusOr<std::unique_ptr<SecAggServerState> > ProceedToNextRound() override;
+
+ bool ReadyForNextRound() const override;
+
+ bool SetAsyncCallback(std::function<void()> async_callback) override;
+
+ private:
+ void HandleAbort() override;
+
+ void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) override;
+
+ // Called to perform the initial synchronous part of PRNG state.
+ StatusOr<SecAggServerProtocolImpl::PrngWorkItems> Initialize();
+
+ // This is called when all computations are finished.
+ // final_status indicates whether PRNG computation has finished successfully.
+ void PrngRunnerFinished(Status final_status);
+
+ // The status is assigned when the state completes either successfully or
+ // unsuccessfully.
+ std::optional<Status> completion_status_ ABSL_GUARDED_BY(mutex_);
+
+ absl::Time prng_started_time_;
+ CancellationToken cancellation_token_;
+
+ std::function<void()> prng_done_callback_ ABSL_GUARDED_BY(mutex_);
+
+ // Protects this object from being destroyed while StartPrng call is still
+ // in progress. Also protects completion_status_ and prng_done_callback_.
+ mutable absl::Mutex mutex_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_PRNG_RUNNING_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_prng_running_state_test.cc b/fcp/secagg/server/secagg_server_prng_running_state_test.cc
new file mode 100644
index 0000000..06c837c
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_prng_running_state_test.cc
@@ -0,0 +1,997 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_prng_running_state.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/container/node_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/secagg/testing/server/test_async_runner.h"
+#include "fcp/secagg/testing/test_matchers.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::NiceMock;
+
+// For testing purposes, make an AesKey out of a string.
+AesKey MakeAesKey(const std::string& key) {
+ EXPECT_THAT(key.size(), Eq(AesKey::kSize));
+ return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
+}
+
+class MockScheduler : public Scheduler {
+ public:
+ MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
+ MOCK_METHOD(void, WaitUntilIdle, ());
+};
+
+constexpr auto call_fn = [](const std::function<void()>& f) { f(); };
+
+// Default test session_id.
+std::unique_ptr<SessionId> MakeTestSessionId() {
+ SessionId session_id = {"session id number, 32 bytes long"};
+ return std::make_unique<SessionId>(session_id);
+}
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ std::vector<InputVectorSpecification> input_vector_specs,
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ SecretSharingGraphFactory factory;
+ auto parallel_scheduler = std::make_unique<NiceMock<MockScheduler>>();
+ auto sequential_scheduler = std::make_unique<NiceMock<MockScheduler>>();
+ EXPECT_CALL(*parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
+ EXPECT_CALL(*sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);
+ auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(4, 3), // total number of clients is 4
+ 3, // minimum_number_of_clients_to_proceed
+ input_vector_specs,
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ std::make_unique<AesCtrPrngFactory>(), sender,
+ std::make_unique<TestAsyncRunner>(std::move(parallel_scheduler),
+ std::move(sequential_scheduler)),
+ std::vector<ClientStatus>(4, ClientStatus::UNMASKING_RESPONSE_RECEIVED),
+ ServerVariant::NATIVE_V1);
+ impl->set_session_id(MakeTestSessionId());
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < 4; ++i) {
+ impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
+ }
+ impl->set_masked_input(std::make_unique<SecAggUnpackedVectorMap>());
+ return impl;
+}
+
+// Mock class containing a callback that would be called when the PRNG is done.
+class MockPrngDone {
+ public:
+ MOCK_METHOD(void, Callback, ());
+};
+
+TEST(SecaggServerPrngRunningStateTest, IsAbortedReturnsFalse) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.IsAborted(), Eq(false));
+}
+
+TEST(SecaggServerPrngRunningStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.IsCompletedSuccessfully(), Eq(false));
+}
+
+TEST(SecaggServerPrngRunningStateTest, ErrorMessageRaisesError) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.ErrorMessage().ok(), Eq(false));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ NumberOfMessagesReceivedInThisRoundReturnsZero) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ NumberOfClientsReadyForNextRoundReturnsZero) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ HandleNonAbortMessageAbortsClientDoesNotRecordMetrics) {
+ TestTracingRecorder tracing_recorder;
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl =
+ CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Non-abort message sent during PrngUnmasking step.");
+
+ ClientToServerWrapperMessage client_message;
+ EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(abort_message)));
+ EXPECT_CALL(*metrics, MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::MESSAGE_CONTENT_NOT_SET),
+ Eq(false), Eq(client_message.ByteSizeLong())));
+ EXPECT_CALL(*metrics,
+ IndividualMessageSizes(
+ Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
+ Eq(abort_message.ByteSizeLong())));
+ EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
+
+ EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
+ EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
+ ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
+ ElementsAre(IsEvent<IndividualMessageSent>(
+ Eq(0), Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
+ ElementsAre(IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_MessageContentNotSet),
+ Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ HandleAbortMessageAbortsClientDoesNotRecordMetrics) {
+ TestTracingRecorder tracing_recorder;
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl =
+ CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ ClientToServerWrapperMessage client_message;
+ client_message.mutable_abort();
+ EXPECT_CALL(*metrics,
+ MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
+ Eq(false), Eq(client_message.ByteSizeLong())));
+ EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
+ EXPECT_CALL(*sender, Send(Eq(0), _)).Times(0);
+
+ EXPECT_THAT(state.HandleMessage(0, client_message), IsOk());
+ EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
+ ASSERT_THAT(state.AbortedClientIds().contains(0), Eq(true));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<ClientMessageReceived>(),
+ ElementsAre(IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_Abort),
+ Eq(client_message.ByteSizeLong()), Eq(false), Ge(0))));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ AbortReturnsValidStateAndNotifiesClients) {
+ TestTracingRecorder tracing_recorder;
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->try_emplace(
+ i, sharer.Share(
+ 3, 4,
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i))));
+ }
+
+ auto impl =
+ CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info("test abort reason");
+
+ EXPECT_CALL(*metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::UNHANDLED_ERROR)));
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ auto next_state =
+ state.Abort("test abort reason", SecAggServerOutcome::UNHANDLED_ERROR);
+
+ ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state->ErrorMessage().ok(), Eq(true));
+ EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ PrngGetsRightMasksWhenAllClientsSurvive) {
+ // First, set up necessary data for the SecAggServerPrngRunningState
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ }
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ prng_keys_to_subtract.push_back(
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ PrngGetsRightMasksWithOneDeadClientAfterSendingInput) {
+ // In this test, client 1 died after sending its masked input. Its input will
+ // still be included.
+ //
+ // First, set up necessary data for the SecAggServerPrngRunningState.
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto pairwise_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+
+ auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_client_status(
+ 1, ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
+
+ aborted_client_ids->insert(1);
+
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*self_shamir_share_table)[i][1] = {""};
+ }
+
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->insert(std::make_pair(
+ "foobar", SecAggUnpackedVector(std::vector<uint64_t>{0, 0, 0, 0}, 32)));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ prng_keys_to_subtract.push_back(
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 1, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 1); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ PrngGetsRightMasksWithOneDeadClientBeforeSendingInput) {
+ // In this test, client 1 died before sending its masked input but after other
+ // clients computed theirs, so its pairwise key will need to be canceled out.
+ //
+ // First, set up necessary data for the SecAggServerPrngRunningState.
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto pairwise_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+
+ auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
+ aborted_client_ids->insert(1);
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ // Client 1 died in the previous step, so the other clients will have sent
+ // shares of its pairwise key instead.
+ pairwise_shamir_share_table->insert(
+ std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*pairwise_shamir_share_table)[i][1] = {""};
+ } else {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*self_shamir_share_table)[i][1] = {""};
+ }
+ }
+
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ // We should subtract the self masks of clients 0, 2, and 3. We should
+ // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
+ // that 0 subtracted for 1.
+ auto aborted_client_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ continue;
+ }
+ prng_keys_to_subtract.push_back(
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
+ AesKey pairwise_key = aborted_client_key_agreement.value()
+ ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
+ .value();
+ if (i == 0) {
+ prng_keys_to_add.push_back(pairwise_key);
+ } else {
+ prng_keys_to_subtract.push_back(pairwise_key);
+ }
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ PrngGetsRightMasksAndCallsCallbackIfSpecified) {
+ // In this test, there is now a callback that should be called when the PRNG
+ // is done running.
+ //
+ // First, set up necessary data for the SecAggServerPrngRunningState.
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto pairwise_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+
+ auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
+ aborted_client_ids->insert(1);
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ // Client 1 died in the previous step, so the other clients will have sent
+ // shares of its pairwise key instead.
+ pairwise_shamir_share_table->insert(
+ std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*pairwise_shamir_share_table)[i][1] = {""};
+ } else {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*self_shamir_share_table)[i][1] = {""};
+ }
+ }
+
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ // We should subtract the self masks of clients 0, 2, and 3. We should
+ // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
+ // that 0 subtracted for 1.
+ auto aborted_client_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ continue;
+ }
+ prng_keys_to_subtract.push_back(
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
+ AesKey pairwise_key = aborted_client_key_agreement.value()
+ ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
+ .value();
+ if (i == 0) {
+ prng_keys_to_add.push_back(pairwise_key);
+ } else {
+ prng_keys_to_subtract.push_back(pairwise_key);
+ }
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest, SetAsyncCallbackCanBeCalledTwice) {
+ // StartPrng should have the property that it can be called after it has
+ // already run successfully without any problems. It should just return OK
+ // again.
+ //
+ // First, set up necessary data for the SecAggServerPrngRunningState.
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto pairwise_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ impl->set_client_status(1, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+
+ auto aborted_client_ids = std::make_unique<absl::flat_hash_set<uint32_t>>();
+ aborted_client_ids->insert(1);
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ // Client 1 died in the previous step, so the other clients will have sent
+ // shares of its pairwise key instead.
+ pairwise_shamir_share_table->insert(
+ std::make_pair(i, sharer.Share(3, 4, ecdh_keys.GetPrivateKey(i))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*pairwise_shamir_share_table)[i][1] = {""};
+ } else {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ // Blank out the share in position 1 because it would not have been sent.
+ (*self_shamir_share_table)[i][1] = {""};
+ }
+ }
+
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(std::move(pairwise_shamir_share_table));
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ // We should subtract the self masks of clients 0, 2, and 3. We should
+ // subtract the pairwise mask 2 and 3 added for 1, and add the pairwise mask
+ // that 0 subtracted for 1.
+ auto aborted_client_key_agreement =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh_keys.GetPrivateKey(1));
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ if (i == 1) {
+ continue;
+ }
+ prng_keys_to_subtract.push_back(
+ MakeAesKey(absl::StrCat("test 32 byte AES key for user #", i)));
+ AesKey pairwise_key = aborted_client_key_agreement.value()
+ ->ComputeSharedSecret(ecdh_keys.GetPublicKey(i))
+ .value();
+ if (i == 0) {
+ prng_keys_to_add.push_back(pairwise_key);
+ } else {
+ prng_keys_to_subtract.push_back(pairwise_key);
+ }
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ // Make sure we can call SetAsyncCallback again.
+ MockPrngDone prng_done_2;
+ EXPECT_CALL(prng_done_2, Callback());
+ state.SetAsyncCallback([&]() { prng_done_2.Callback(); });
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest,
+ PrngGetsRightMasksWhenClientsUse16BSelfKeys) {
+ // TODO(team): This test is only for ensuring Java compatibility.
+ // First, set up necessary data for the SecAggServerPrngRunningState
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ AesKey(reinterpret_cast<const uint8_t*>(
+ absl::StrCat("16B key of user", i).c_str()),
+ 16))));
+ }
+
+ // Generate the expected (negative) sum of masking vectors using MapofMasks.
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ for (int i = 0; i < 4; ++i) {
+ prng_keys_to_subtract.push_back(
+ AesKey(reinterpret_cast<const uint8_t*>(
+ absl::StrCat("16B key of user", i).c_str()),
+ 16));
+ }
+ auto session_id = MakeTestSessionId();
+ auto expected_map_of_masks =
+ MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ *session_id, AesCtrPrngFactory());
+
+ auto impl = CreateSecAggServerProtocolImpl(input_vector_specs, sender.get());
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ auto result = next_state.value()->Result();
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(*result.value(),
+ testing::MatchesSecAggVectorMap(*expected_map_of_masks));
+}
+
+TEST(SecaggServerPrngRunningStateTest, TimingMetricsAreRecorded) {
+ // First, set up necessary data for the SecAggServerPrngRunningState
+ TestTracingRecorder tracing_recorder;
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ FakePrng prng;
+ ShamirSecretSharing sharer;
+ auto self_shamir_share_table = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ for (int i = 0; i < 4; ++i) {
+ self_shamir_share_table->insert(std::make_pair(
+ i, sharer.Share(3, 4,
+ MakeAesKey(absl::StrCat(
+ "test 32 byte AES key for user #", i)))));
+ }
+
+ auto impl =
+ CreateSecAggServerProtocolImpl(input_vector_specs, sender.get(), metrics);
+ auto zero_map = std::make_unique<SecAggUnpackedVectorMap>();
+ zero_map->emplace("foobar", SecAggUnpackedVector({0, 0, 0, 0}, 32));
+ impl->set_masked_input(std::move(zero_map));
+ impl->set_pairwise_shamir_share_table(
+ std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>());
+ impl->set_self_shamir_share_table(std::move(self_shamir_share_table));
+
+ SecAggServerPrngRunningState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ MockPrngDone prng_done;
+ EXPECT_CALL(prng_done, Callback());
+
+ EXPECT_CALL(*metrics, PrngExpansionTimes(Ge(0)));
+ EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::PRNG_RUNNING),
+ Eq(true), Ge(0)));
+ EXPECT_CALL(*metrics, ShamirReconstructionTimes(Ge(0)));
+
+ state.EnterState();
+ state.SetAsyncCallback([&]() { prng_done.Callback(); });
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(true));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state.ok(), Eq(true));
+ ASSERT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::COMPLETED));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<ShamirReconstruction>(),
+ ElementsAre(IsEvent<ShamirReconstruction>(Ge(0))));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<PrngExpansion>(),
+ ElementsAre(IsEvent<PrngExpansion>(Ge(0))));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_protocol_impl.cc b/fcp/secagg/server/secagg_server_protocol_impl.cc
new file mode 100644
index 0000000..cceb251
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_protocol_impl.cc
@@ -0,0 +1,403 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_protocol_impl.h"
+
+#include <string>
+#include <utility>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace {
+
+// Defines an experiments object with no experiments enabled
+class EmptyExperiment : public fcp::secagg::ExperimentsInterface {
+ public:
+ bool IsEnabled(absl::string_view experiment_name) override { return false; }
+};
+
+} // namespace
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerProtocolImpl::SecAggServerProtocolImpl(
+ std::unique_ptr<SecretSharingGraph> graph,
+ int minimum_number_of_clients_to_proceed,
+ std::unique_ptr<SecAggServerMetricsListener> metrics,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ SendToClientsInterface* sender, std::unique_ptr<SecAggScheduler> scheduler,
+ std::vector<ClientStatus> client_statuses,
+ std::unique_ptr<ExperimentsInterface> experiments)
+ : secret_sharing_graph_(std::move(graph)),
+ minimum_number_of_clients_to_proceed_(
+ minimum_number_of_clients_to_proceed),
+ metrics_(std::move(metrics)),
+ prng_factory_(std::move(prng_factory)),
+ sender_(sender),
+ scheduler_(std::move(scheduler)),
+ total_number_of_clients_(client_statuses.size()),
+ client_statuses_(std::move(client_statuses)),
+ experiments_(experiments ? std::move(experiments)
+ : std::unique_ptr<ExperimentsInterface>(
+ new EmptyExperiment())),
+ pairwise_public_keys_(total_number_of_clients()),
+ pairs_of_public_keys_(total_number_of_clients()),
+ encrypted_shares_(total_number_of_clients(),
+ std::vector<std::string>(number_of_neighbors())) {}
+
+void SecAggServerProtocolImpl::SetResult(
+ std::unique_ptr<SecAggVectorMap> result) {
+ FCP_CHECK(!result_) << "Result can't be set twice";
+ result_ = std::move(result);
+}
+
+std::unique_ptr<SecAggVectorMap> SecAggServerProtocolImpl::TakeResult() {
+ return std::move(result_);
+}
+
+// -----------------------------------------------------------------------------
+// Round 0 methods
+// -----------------------------------------------------------------------------
+
+Status SecAggServerProtocolImpl::HandleAdvertiseKeys(
+ uint32_t client_id, const AdvertiseKeys& advertise_keys) {
+ const auto& pair_of_public_keys = advertise_keys.pair_of_public_keys();
+ if ((pair_of_public_keys.enc_pk().size() != EcdhPublicKey::kSize &&
+ (pair_of_public_keys.enc_pk().size() <
+ EcdhPublicKey::kUncompressedSize ||
+ pair_of_public_keys.noise_pk().size() <
+ EcdhPublicKey::kUncompressedSize)) ||
+ pair_of_public_keys.enc_pk().size() !=
+ pair_of_public_keys.noise_pk().size()) {
+ return ::absl::InvalidArgumentError(
+ "A public key sent by the client was not the correct size.");
+ }
+
+ if (pair_of_public_keys.noise_pk().size() == EcdhPublicKey::kSize) {
+ pairwise_public_keys_[client_id] =
+ EcdhPublicKey(reinterpret_cast<const uint8_t*>(
+ pair_of_public_keys.noise_pk().c_str()));
+ } else {
+ // Strip off the header, if any, and use the uncompressed ECDH key.
+ size_t key_size_with_header = pair_of_public_keys.noise_pk().size();
+ pairwise_public_keys_[client_id] = EcdhPublicKey(
+ reinterpret_cast<const uint8_t*>(
+ pair_of_public_keys.noise_pk()
+ .substr(key_size_with_header - EcdhPublicKey::kUncompressedSize)
+ .c_str()),
+ EcdhPublicKey::kUncompressed);
+ }
+
+ pairs_of_public_keys_[client_id] = pair_of_public_keys;
+ return ::absl::OkStatus();
+}
+
+void SecAggServerProtocolImpl::ErasePublicKeysForClient(uint32_t client_id) {
+ pairwise_public_keys_[client_id] = EcdhPublicKey();
+ pairs_of_public_keys_[client_id] = PairOfPublicKeys();
+}
+
+void SecAggServerProtocolImpl::ComputeSessionId() {
+ // This message contains all keys, and is only built for the purpose
+ // of deriving the session key from it
+ ShareKeysRequest share_keys_request;
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ *(share_keys_request.add_pairs_of_public_keys()) = pairs_of_public_keys_[i];
+ }
+ set_session_id(std::make_unique<SessionId>(
+ fcp::secagg::ComputeSessionId(share_keys_request)));
+}
+
+void SecAggServerProtocolImpl::PrepareShareKeysRequestForClient(
+ uint32_t client_id, ShareKeysRequest* request) const {
+ request->clear_pairs_of_public_keys();
+ for (int j = 0; j < secret_sharing_graph()->GetDegree(); ++j) {
+ *request->add_pairs_of_public_keys() =
+ pairs_of_public_keys_[secret_sharing_graph()->GetNeighbor(client_id,
+ j)];
+ }
+}
+
+void SecAggServerProtocolImpl::ClearPairsOfPublicKeys() {
+ pairs_of_public_keys_.clear();
+}
+
+// -----------------------------------------------------------------------------
+// Round 1 methods
+// -----------------------------------------------------------------------------
+
+Status SecAggServerProtocolImpl::HandleShareKeysResponse(
+ uint32_t client_id, const ShareKeysResponse& share_keys_response) {
+ // Verify that the message has the expected fields set before accepting it.
+ if (share_keys_response.encrypted_key_shares().size() !=
+ number_of_neighbors()) {
+ return ::absl::InvalidArgumentError(
+ "The ShareKeysResponse does not contain the expected number of "
+ "encrypted pairs of key shares.");
+ }
+
+ for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
+ bool i_is_empty = share_keys_response.encrypted_key_shares(i).empty();
+ int neighbor_id = GetNeighbor(client_id, i);
+ bool i_should_be_empty = (neighbor_id == client_id) ||
+ (client_status(neighbor_id) ==
+ ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
+ if (i_is_empty && !i_should_be_empty) {
+ return ::absl::InvalidArgumentError(
+ "Client omitted a key share that was expected.");
+ }
+ if (i_should_be_empty && !i_is_empty) {
+ return ::absl::InvalidArgumentError(
+ "Client sent a key share that was not expected.");
+ }
+ }
+
+ // Client sent a valid message.
+ for (int i = 0; i < number_of_neighbors(); ++i) {
+ int neighbor_id = GetNeighbor(client_id, i);
+ // neighbor_id and client_id are neighbors, and thus index_in_neighbors is
+ // in [0, number_neighbors()-1]
+ int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
+ encrypted_shares_[neighbor_id][index_in_neighbor] =
+ share_keys_response.encrypted_key_shares(i);
+ }
+
+ return ::absl::OkStatus();
+}
+
+void SecAggServerProtocolImpl::EraseShareKeysForClient(uint32_t client_id) {
+ for (int i = 0; i < number_of_neighbors(); ++i) {
+ int neighbor_id = GetNeighbor(client_id, i);
+ int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
+ encrypted_shares_[neighbor_id][index_in_neighbor].clear();
+ }
+}
+
+void SecAggServerProtocolImpl::PrepareMaskedInputCollectionRequestForClient(
+ uint32_t client_id, MaskedInputCollectionRequest* request) const {
+ request->clear_encrypted_key_shares();
+ for (int j = 0; j < number_of_neighbors(); ++j) {
+ request->add_encrypted_key_shares(encrypted_shares_[client_id][j]);
+ }
+}
+
+void SecAggServerProtocolImpl::ClearShareKeys() { encrypted_shares_.clear(); }
+
+// -----------------------------------------------------------------------------
+// Round 3 methods
+// -----------------------------------------------------------------------------
+
+// This enum and the following function relates the client status to whether
+// or not its pairwise mask, its self mask, or neither will appear in the
+// summed masked input.
+enum class ClientMask { kPairwiseMask, kSelfMask, kNoMask };
+
+// Returns the type of mask the server expects to receive a share for, for a
+// give client status.
+inline ClientMask ClientMaskType(const ClientStatus& client_status) {
+ switch (client_status) {
+ case ClientStatus::SHARE_KEYS_RECEIVED:
+ case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
+ return ClientMask::kPairwiseMask;
+ break;
+ case ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED:
+ case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
+ case ClientStatus::UNMASKING_RESPONSE_RECEIVED:
+ case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
+ return ClientMask::kSelfMask;
+ break;
+ case ClientStatus::READY_TO_START:
+ case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
+ case ClientStatus::ADVERTISE_KEYS_RECEIVED:
+ case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
+ default:
+ return ClientMask::kNoMask;
+ }
+}
+
+void SecAggServerProtocolImpl::SetUpShamirSharesTables() {
+ pairwise_shamir_share_table_ = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+ self_shamir_share_table_ = std::make_unique<
+ absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
+
+ // Prepare the share tables with rows for clients we expect to have shares for
+ for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
+ auto mask_type = ClientMaskType(client_status(i));
+ if (mask_type == ClientMask::kPairwiseMask) {
+ pairwise_shamir_share_table_->emplace(i, number_of_neighbors());
+ } else if (mask_type == ClientMask::kSelfMask) {
+ self_shamir_share_table_->emplace(i, number_of_neighbors());
+ }
+ }
+}
+
+Status SecAggServerProtocolImpl::HandleUnmaskingResponse(
+ uint32_t client_id, const UnmaskingResponse& unmasking_response) {
+ FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
+ self_shamir_share_table_ != nullptr)
+ << "Shamir Shares Tables haven't been initialized";
+
+ // Verify the client sent all the right types of shares.
+ for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
+ int ith_neighbor = GetNeighbor(client_id, i);
+ switch (ClientMaskType(client_status(ith_neighbor))) {
+ case ClientMask::kPairwiseMask:
+ if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
+ NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
+ return ::absl::InvalidArgumentError(
+ "Client did not include the correct type of key share.");
+ }
+ break;
+ case ClientMask::kSelfMask:
+ if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
+ NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
+ return ::absl::InvalidArgumentError(
+ "Client did not include the correct type of key share.");
+ }
+ break;
+ case ClientMask::kNoMask:
+ default:
+ if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
+ NoiseOrPrfKeyShare::OneofSharesCase::ONEOF_SHARES_NOT_SET) {
+ return ::absl::InvalidArgumentError(
+ "Client included a key share for which none was expected.");
+ }
+ }
+ }
+ // Prepare the received key shares for reconstruction by inserting them into
+ // the tables.
+ for (int i = 0; i < number_of_neighbors(); ++i) {
+ // Find the index of client_id in the list of neighbors of the ith
+ // neighbor of client_id
+ int ith_neighbor = GetNeighbor(client_id, i);
+ int index = GetNeighborIndexOrDie(ith_neighbor, client_id);
+ if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() ==
+ NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
+ (*pairwise_shamir_share_table_)[ith_neighbor][index].data =
+ unmasking_response.noise_or_prf_key_shares(i).noise_sk_share();
+ } else if (unmasking_response.noise_or_prf_key_shares(i)
+ .oneof_shares_case() ==
+ NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
+ (*self_shamir_share_table_)[ith_neighbor][index].data =
+ unmasking_response.noise_or_prf_key_shares(i).prf_sk_share();
+ }
+ }
+ return ::absl::OkStatus();
+}
+
+// -----------------------------------------------------------------------------
+// PRNG computation methods
+// -----------------------------------------------------------------------------
+
+StatusOr<SecAggServerProtocolImpl::ShamirReconstructionResult>
+SecAggServerProtocolImpl::HandleShamirReconstruction() {
+ FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
+ self_shamir_share_table_ != nullptr)
+ << "Shamir Shares Tables haven't been initialized";
+
+ ShamirReconstructionResult result;
+ ShamirSecretSharing reconstructor;
+
+ for (const auto& item : *pairwise_shamir_share_table_) {
+ FCP_ASSIGN_OR_RETURN(std::string reconstructed_key,
+ reconstructor.Reconstruct(
+ minimum_surviving_neighbors_for_reconstruction(),
+ item.second, EcdhPrivateKey::kSize));
+ auto key_agreement = EcdhKeyAgreement::CreateFromPrivateKey(EcdhPrivateKey(
+ reinterpret_cast<const uint8_t*>(reconstructed_key.c_str())));
+ if (!key_agreement.ok()) {
+ // The server was unable to reconstruct the private key, probably
+ // because some client(s) sent invalid key shares. The only way out is
+ // to abort.
+ return ::absl::InvalidArgumentError(
+ "Unable to reconstruct aborted client's private key from shares");
+ }
+ result.aborted_client_key_agreements.try_emplace(
+ item.first, std::move(*(key_agreement.value())));
+ }
+
+ for (const auto& item : *self_shamir_share_table_) {
+ FCP_ASSIGN_OR_RETURN(
+ AesKey reconstructed,
+ AesKey::CreateFromShares(
+ item.second, minimum_surviving_neighbors_for_reconstruction()));
+ result.self_keys.try_emplace(item.first, reconstructed);
+ }
+
+ return std::move(result);
+}
+
+StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
+SecAggServerProtocolImpl::InitializePrng(
+ const ShamirReconstructionResult& shamir_reconstruction_result) const {
+ PrngWorkItems work_items;
+
+ for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
+ // Although clients who are DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED and
+ // kDeadAfterUnmaskingResponseReceived have they did so after sending
+ // their masked input. Therefore, it is possible to include their
+ // contribution to the aggregate sum. So we treat them here as if they had
+ // completed the protocol correctly.
+ auto status = client_status(i);
+ if (status != ClientStatus::UNMASKING_RESPONSE_RECEIVED &&
+ status != ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED &&
+ status != ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED) {
+ continue;
+ }
+
+ // Since client i's value will be included in the sum, the server must
+ // remove its self mask.
+ auto it = shamir_reconstruction_result.self_keys.find(i);
+ FCP_CHECK(it != shamir_reconstruction_result.self_keys.end());
+ work_items.prng_keys_to_subtract.push_back(it->second);
+
+ // For clients that aborted, client i's sum contains an un-canceled
+ // pairwise mask generated between the two clients. The server must remove
+ // this pairwise mask from the sum.
+ for (const auto& item :
+ shamir_reconstruction_result.aborted_client_key_agreements) {
+ if (!AreNeighbors(i, item.first)) {
+ continue;
+ }
+ auto shared_key =
+ item.second.ComputeSharedSecret(pairwise_public_keys(i));
+ if (!shared_key.ok()) {
+ // Should not happen; invalid public keys should already be detected.
+ // But if it does happen, abort.
+ return ::absl::InvalidArgumentError(
+ "Invalid public key from client detected");
+ }
+ if (IsOutgoingNeighbor(i, item.first)) {
+ work_items.prng_keys_to_add.push_back(shared_key.value());
+ } else {
+ work_items.prng_keys_to_subtract.push_back(shared_key.value());
+ }
+ }
+ }
+
+ return std::move(work_items);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_protocol_impl.h b/fcp/secagg/server/secagg_server_protocol_impl.h
new file mode 100644
index 0000000..f60bf5b
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_protocol_impl.h
@@ -0,0 +1,383 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "fcp/secagg/server/experiments_interface.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_metrics_listener.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/send_to_clients_interface.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+
+// Interface that describes internal implementation of SecAgg protocol.
+//
+// The general design is the following
+//
+// +--------------+ +-------------------+ +--------------------------+
+// | SecAggServer |--->| SecAggServerState |--->| SecAggServerProtocolImpl |
+// +--------------+ +-------------------+ +--------------------------+
+// ^ ^
+// /-\ /-\
+// | |
+// +-------------------+ +------------------------+
+// | Specific State | | Specific protocol impl |
+// +-------------------+ +------------------------+
+//
+// Specific states implement logic specific to each logic SecAgg state, such as
+// R0AdvertiseKeys or PrngRunning, while specific protocol implementation is
+// shared between all states and is responsible for encapsulating the data
+// for the protocol and providing methods for manipulating the data.
+//
+
+class SecAggServerProtocolImpl {
+ public:
+ explicit SecAggServerProtocolImpl(
+ std::unique_ptr<SecretSharingGraph> graph,
+ int minimum_number_of_clients_to_proceed,
+ std::unique_ptr<SecAggServerMetricsListener> metrics,
+ std::unique_ptr<AesPrngFactory> prng_factory,
+ SendToClientsInterface* sender,
+ std::unique_ptr<SecAggScheduler> scheduler,
+ std::vector<ClientStatus> client_statuses,
+ std::unique_ptr<ExperimentsInterface> experiments = nullptr);
+ virtual ~SecAggServerProtocolImpl() = default;
+
+ SecAggServerProtocolImpl(const SecAggServerProtocolImpl& other) = delete;
+ SecAggServerProtocolImpl& operator=(const SecAggServerProtocolImpl& other) =
+ delete;
+
+ // Returns server variant for this protocol implementation.
+ virtual ServerVariant server_variant() const = 0;
+
+ // Returns the graph that represents the cohort of clients.
+ inline const SecretSharingGraph* secret_sharing_graph() const {
+ return secret_sharing_graph_.get();
+ }
+
+ // Returns the minimum threshold number of clients that need to send valid
+ // responses in order for the protocol to proceed from one round to the next.
+ inline int minimum_number_of_clients_to_proceed() const {
+ return minimum_number_of_clients_to_proceed_;
+ }
+
+ // Returns the callback interface for recording metrics.
+ inline SecAggServerMetricsListener* metrics() const { return metrics_.get(); }
+
+ // Returns a reference to an instance of a subclass of AesPrngFactory.
+ inline AesPrngFactory* prng_factory() const { return prng_factory_.get(); }
+
+ // Returns the callback interface for sending protocol buffer messages to the
+ // client.
+ inline SendToClientsInterface* sender() const { return sender_; }
+
+ // Returns the scheduler for scheduling parallel computation tasks and
+ // callbacks.
+ inline SecAggScheduler* scheduler() const { return scheduler_.get(); }
+
+ // Returns the experiments
+ inline ExperimentsInterface* experiments() const {
+ return experiments_.get();
+ }
+
+ // Getting or setting the protocol result.
+ //
+ // TODO(team): SetResult should not be needed (except for testing) once
+ // PRNG computation is moved into the protocol implementation.
+ void SetResult(std::unique_ptr<SecAggVectorMap> result);
+ std::unique_ptr<SecAggVectorMap> TakeResult();
+
+ // Gets the client status.
+ inline const ClientStatus& client_status(uint32_t client_id) const {
+ return client_statuses_.at(client_id);
+ }
+
+ // Sets the client status.
+ inline void set_client_status(uint32_t client_id, ClientStatus status) {
+ client_statuses_[client_id] = status;
+ }
+
+ // Gets the number of clients that the protocol starts with.
+ inline size_t total_number_of_clients() const {
+ return total_number_of_clients_;
+ }
+
+ // Returns the number of neighbors of each client.
+ inline const int number_of_neighbors() const {
+ return secret_sharing_graph()->GetDegree();
+ }
+
+ // Returns the minimum number of neighbors of a client that must not drop-out
+ // for that client's contribution to be included in the sum. This corresponds
+ // to the threshold in the shamir secret sharing of self and pairwise masks.
+ inline const int minimum_surviving_neighbors_for_reconstruction() const {
+ return secret_sharing_graph()->GetThreshold();
+ }
+
+ // Returns client_id's ith neighbor.
+ // This function assumes that 0 <= i < number_of_neighbors() and will throw a
+ // runtime error if that's not the case
+ inline const int GetNeighbor(int client_id, int i) const {
+ return secret_sharing_graph()->GetNeighbor(client_id, i);
+ }
+
+ // Returns the index of client_id_2 in the list of neighbors of client_id_1,
+ // if present
+ inline const std::optional<int> GetNeighborIndex(int client_id_1,
+ int client_id_2) const {
+ return secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2);
+ }
+
+ // Returns the index of client_id_2 in the list of neighbors of client_id_1
+ // This function assumes that client_id_1 and client_id_2 are neighbors, and
+ // will throw a runtime error if that's not the case
+ inline const int GetNeighborIndexOrDie(int client_id_1,
+ int client_id_2) const {
+ auto index =
+ secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2);
+ FCP_CHECK(index.has_value());
+ return index.value();
+ }
+
+ // Returns true if clients client_id_1 and client_id_1 are neighbors, else
+ // false.
+ inline const bool AreNeighbors(int client_id_1, int client_id_2) const {
+ return secret_sharing_graph()->AreNeighbors(client_id_1, client_id_2);
+ }
+
+ // Returns true if client_id_1 is an outgoing neighbor of client_id_2, else
+ // false.
+ inline const bool IsOutgoingNeighbor(int client_id_1, int client_id_2) const {
+ return secret_sharing_graph()->IsOutgoingNeighbor(client_id_1, client_id_2);
+ }
+
+ inline void SetPairwisePublicKeys(uint32_t client_id,
+ const EcdhPublicKey& pairwise_key) {
+ pairwise_public_keys_[client_id] = pairwise_key;
+ }
+
+ inline const EcdhPublicKey& pairwise_public_keys(uint32_t client_id) const {
+ return pairwise_public_keys_[client_id];
+ }
+
+ inline const SessionId& session_id() const {
+ FCP_CHECK(session_id_ != nullptr);
+ return *session_id_;
+ }
+
+ void set_session_id(std::unique_ptr<SessionId> session_id) {
+ FCP_CHECK(session_id != nullptr);
+ session_id_ = std::move(session_id);
+ }
+
+ // TODO(team): Review whether getters and setters below are needed.
+ // Most of these fields are needed only for testing.
+
+ void set_pairwise_shamir_share_table(
+ std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
+ pairwise_shamir_share_table) {
+ pairwise_shamir_share_table_ = std::move(pairwise_shamir_share_table);
+ }
+
+ void set_self_shamir_share_table(
+ std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
+ self_shamir_share_table) {
+ self_shamir_share_table_ = std::move(self_shamir_share_table);
+ }
+
+ // ---------------------------------------------------------------------------
+ // Round 0 methods
+ // ---------------------------------------------------------------------------
+
+ // Sets the public key pairs for a client.
+ Status HandleAdvertiseKeys(uint32_t client_id,
+ const AdvertiseKeys& advertise_keys);
+
+ // Erases public key pairs for a client.
+ void ErasePublicKeysForClient(uint32_t client_id);
+
+ // Compute session ID based on public key pairs advertised by clients.
+ void ComputeSessionId();
+
+ // This method allows a protocol implementation to populate fields that are
+ // common to the ShareKeysRequest sent to all clients.
+ virtual Status InitializeShareKeysRequest(
+ ShareKeysRequest* request) const = 0;
+
+ // Prepares ShareKeysRequest message to send to the client.
+ // This method will update fields in the request as needed, but will not clear
+ // any fields that are not specific to the share keys request for the specific
+ // client. The caller can therefore set up a single ShareKeysRequest object,
+ // populate fields that will be common to all clients, and repeatedly call
+ // this method to set the client-specific fields before serializing the
+ // message and sending it.
+ void PrepareShareKeysRequestForClient(uint32_t client_id,
+ ShareKeysRequest* request) const;
+
+ // Clears all pairs of public keys.
+ void ClearPairsOfPublicKeys();
+
+ // ---------------------------------------------------------------------------
+ // Round 1 methods
+ // ---------------------------------------------------------------------------
+
+ // Sets the encrypted shares received from a client.
+ Status HandleShareKeysResponse(uint32_t client_id,
+ const ShareKeysResponse& share_keys_response);
+
+ // Erases the encrypted shares for a client.
+ void EraseShareKeysForClient(uint32_t client_id);
+
+ // Prepares MaskedInputCollectionRequest message to send to the client.
+ // This method will update fields in the request as needed, but will not clear
+ // any fields that are not specific to the share keys request for the specific
+ // client. The caller can therefore set up a single ShareKeysRequest object,
+ // populate fields that will be common to all clients, and repeatedly call
+ // this method to set the client-specific fields before serializing the
+ // message and sending it.
+ void PrepareMaskedInputCollectionRequestForClient(
+ uint32_t client_id, MaskedInputCollectionRequest* request) const;
+
+ // Clears all encrypted shares.
+ void ClearShareKeys();
+
+ // ---------------------------------------------------------------------------
+ // Round 2 methods
+ // ---------------------------------------------------------------------------
+
+ // Sets up the sum of encrypted vectors received by the clients in R1. This
+ // must be called before any other R2 methods are called.
+ virtual std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
+ SetupMaskedInputCollection() = 0;
+
+ // Finalizes the async aggregation of R2 messages before moving to R3.
+ virtual void FinalizeMaskedInputCollection() = 0;
+
+ // Check that an encrypted vector received by the user is valid, and add it to
+ // the sum of encrypted vectors.
+ virtual Status HandleMaskedInputCollectionResponse(
+ std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) = 0;
+
+ // ---------------------------------------------------------------------------
+ // Round 3 methods
+ // ---------------------------------------------------------------------------
+
+ // This must be called in the beginning of round 3 to setup Shamir shares
+ // tables based on client states at the beginning of the round.
+ void SetUpShamirSharesTables();
+
+ // Populates Shamir shares tables with the data from UnmaskingResponse.
+ // Returning an error status means that the unmasking response was invalid.
+ Status HandleUnmaskingResponse(uint32_t client_id,
+ const UnmaskingResponse& unmasking_response);
+
+ // ---------------------------------------------------------------------------
+ // PRNG computation methods
+ // ---------------------------------------------------------------------------
+
+ // Result of performing Shamir secret sharing keys reconstruction.
+ struct ShamirReconstructionResult {
+ absl::flat_hash_map<uint32_t, EcdhKeyAgreement>
+ aborted_client_key_agreements;
+ absl::node_hash_map<uint32_t, AesKey> self_keys;
+ };
+
+ // Performs reconstruction secret sharing keys reconstruction step of
+ // the PRNG stage of the protocol.
+ StatusOr<ShamirReconstructionResult> HandleShamirReconstruction();
+
+ struct PrngWorkItems {
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ };
+
+ // Initializes PRNG work items.
+ StatusOr<PrngWorkItems> InitializePrng(
+ const ShamirReconstructionResult& shamir_reconstruction_result) const;
+
+ // Tells the PRNG stage of the protocol to start running asynchronously by
+ // executing PRNG work items.
+ // The returned cancellation token can be used to abort the asynchronous
+ // execution.
+ virtual CancellationToken StartPrng(
+ const PrngWorkItems& work_items,
+ std::function<void(Status)> done_callback) = 0;
+
+ private:
+ std::unique_ptr<SecretSharingGraph> secret_sharing_graph_;
+ int minimum_number_of_clients_to_proceed_;
+
+ std::vector<InputVectorSpecification> input_vector_specs_;
+ std::unique_ptr<SecAggServerMetricsListener> metrics_;
+ std::unique_ptr<AesPrngFactory> prng_factory_;
+ SendToClientsInterface* sender_;
+ std::unique_ptr<SecAggScheduler> scheduler_;
+
+ std::unique_ptr<SecAggVectorMap> result_;
+
+ size_t total_number_of_clients_;
+ std::vector<ClientStatus> client_statuses_;
+ std::unique_ptr<ExperimentsInterface> experiments_;
+
+ // This vector collects the public keys sent by the clients that will be used
+ // for running the PRNG later on.
+ std::vector<EcdhPublicKey> pairwise_public_keys_;
+
+ // This vector collects all pairs of public keys sent by the clients, so they
+ // can be forwarded at the end of Advertise Keys round.
+ std::vector<PairOfPublicKeys> pairs_of_public_keys_;
+
+ std::unique_ptr<SessionId> session_id_;
+
+ // Track the encrypted shares received from clients in preparation for sending
+ // them. encrypted_shares_table_[i][j] is an encryption of the pair of shares
+ // to be sent to client i, received from client j.
+ std::vector<std::vector<std::string>> encrypted_shares_;
+
+ // Shamir shares tables.
+ // These store shares that have been collected from clients, and will be built
+ // up over the course of round 3. For both tables, the map key represents
+ // the client whose key these are shares of; the index in the vector
+ // represents the client who provided that key share.
+ std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
+ pairwise_shamir_share_table_;
+ std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
+ self_shamir_share_table_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
diff --git a/fcp/secagg/server/secagg_server_r0_advertise_keys_state.cc b/fcp/secagg/server/secagg_server_r0_advertise_keys_state.cc
new file mode 100644
index 0000000..6dcf9c1
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r0_advertise_keys_state.cc
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_r1_share_keys_state.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerR0AdvertiseKeysState::SecAggServerR0AdvertiseKeysState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl)
+ : SecAggServerState(0, 0, 0, SecAggServerStateKind::R0_ADVERTISE_KEYS,
+ std::move(impl)) {
+ if (metrics()) {
+ // This is the initial state, so we count the start of the protocol from the
+ // moment it is constructed.
+ metrics()->ProtocolStarts(this->impl()->server_variant());
+ }
+}
+
+SecAggServerR0AdvertiseKeysState::~SecAggServerR0AdvertiseKeysState() {}
+
+Status SecAggServerR0AdvertiseKeysState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ if (message.has_abort()) {
+ MessageReceived(message, false);
+ AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
+ /*notify=*/false);
+ return FCP_STATUS(OK);
+ }
+ // If the client has aborted or sent a message already, ignore its messages.
+ if (client_status(client_id) != ClientStatus::READY_TO_START) {
+ MessageReceived(message, false);
+ AbortClient(
+ client_id,
+ "Not expecting an AdvertiseKeys message from this client - either the "
+ "client already aborted or one such message was already received.",
+ ClientDropReason::ADVERTISE_KEYS_UNEXPECTED);
+ return FCP_STATUS(OK);
+ }
+ if (!message.has_advertise_keys()) {
+ MessageReceived(message, false);
+ AbortClient(client_id,
+ "Message type received is different from what was expected.",
+ ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
+ return FCP_STATUS(OK);
+ }
+ MessageReceived(message, true);
+
+ Status status =
+ impl()->HandleAdvertiseKeys(client_id, message.advertise_keys());
+ if (!status.ok()) {
+ AbortClient(client_id, std::string(status.message()),
+ ClientDropReason::INVALID_PUBLIC_KEY);
+ return FCP_STATUS(OK);
+ }
+
+ set_client_status(client_id, ClientStatus::ADVERTISE_KEYS_RECEIVED);
+ number_of_clients_ready_for_next_round_++;
+ number_of_messages_received_in_this_round_++;
+
+ return FCP_STATUS(OK);
+}
+
+bool SecAggServerR0AdvertiseKeysState::IsNumberOfIncludedInputsCommitted()
+ const {
+ return false;
+}
+
+int SecAggServerR0AdvertiseKeysState::MinimumMessagesNeededForNextRound()
+ const {
+ return std::max(0, minimum_number_of_clients_to_proceed() -
+ number_of_clients_ready_for_next_round_);
+}
+
+int SecAggServerR0AdvertiseKeysState::NumberOfPendingClients() const {
+ return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
+}
+
+void SecAggServerR0AdvertiseKeysState::HandleAbortClient(
+ uint32_t client_id, ClientDropReason reason_code) {
+ number_of_clients_failed_before_sending_masked_input_++;
+ if (client_status(client_id) == ClientStatus::ADVERTISE_KEYS_RECEIVED) {
+ number_of_clients_ready_for_next_round_--;
+ // Remove that client's public keys as if they were never sent. This will
+ // avoid wasted computation and bandwidth.
+ impl()->ErasePublicKeysForClient(client_id);
+ }
+ set_client_status(client_id, ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
+ if (NumberOfAliveClients() < minimum_number_of_clients_to_proceed()) {
+ needs_to_abort_ = true;
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggServerState>>
+SecAggServerR0AdvertiseKeysState::ProceedToNextRound() {
+ if (!ReadyForNextRound()) {
+ return FCP_STATUS(UNAVAILABLE);
+ }
+ if (needs_to_abort_) {
+ std::string error_string = "Too many clients aborted.";
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info(error_string);
+ abort_message.mutable_abort()->set_early_success(false);
+ SendBroadcast(abort_message);
+
+ return AbortState(error_string,
+ SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
+ }
+
+ // Abort all clients that haven't yet sent a message.
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ if (!IsClientDead(i) &&
+ client_status(i) != ClientStatus::ADVERTISE_KEYS_RECEIVED) {
+ AbortClient(
+ i,
+ "Client did not send AdvertiseKeys message before round transition.",
+ ClientDropReason::NO_ADVERTISE_KEYS);
+ }
+ }
+
+ impl()->ComputeSessionId();
+
+ ServerToClientWrapperMessage message_to_client;
+ message_to_client.mutable_share_keys_request()->set_session_id(
+ impl()->session_id().data);
+ FCP_RETURN_IF_ERROR(impl()->InitializeShareKeysRequest(
+ message_to_client.mutable_share_keys_request()));
+
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ // Reuse the common parts of the ShareKeysRequest message and update the
+ // client-specific parts.
+ if (!IsClientDead(i)) {
+ impl()->PrepareShareKeysRequestForClient(
+ i, message_to_client.mutable_share_keys_request());
+ Send(i, message_to_client);
+ }
+ }
+
+ // Pairs of public keys are no longer needed beyond this point as the server
+ // has already forwarded them to the clients.
+ impl()->ClearPairsOfPublicKeys();
+
+ return {std::make_unique<SecAggServerR1ShareKeysState>(
+ ExitState(StateTransition::kSuccess),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_)};
+}
+
+bool SecAggServerR0AdvertiseKeysState::ReadyForNextRound() const {
+ return (number_of_clients_ready_for_next_round_ >=
+ minimum_number_of_clients_to_proceed()) ||
+ (needs_to_abort_);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r0_advertise_keys_state.h b/fcp/secagg/server/secagg_server_r0_advertise_keys_state.h
new file mode 100644
index 0000000..0799493
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r0_advertise_keys_state.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_R0_ADVERTISE_KEYS_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_R0_ADVERTISE_KEYS_STATE_H_
+
+#include <memory>
+
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer when it is in the
+// Round 0: Advertise state. The server begins in this state. This state
+// collects public ECDH keys from clients, stores the one of them that will be
+// used later in running the PRNG, and then sends to each client both public
+// keys from each of its neighbors. It also computes and stores the session ID
+// from the received keys, and sends the session ID to each client along with
+// its neighbor's keys. This state should transition to Round 1: Share Keys, but
+// might transition to Aborted if too many clients abort.
+
+class SecAggServerR0AdvertiseKeysState : public SecAggServerState {
+ public:
+ explicit SecAggServerR0AdvertiseKeysState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl);
+
+ ~SecAggServerR0AdvertiseKeysState() override;
+
+ // Handles an advertise keys response or abort message from a client.
+ Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message) override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ int MinimumMessagesNeededForNextRound() const override;
+
+ int NumberOfPendingClients() const override;
+
+ StatusOr<std::unique_ptr<SecAggServerState> > ProceedToNextRound() override;
+
+ // This will return true only after minimum_number_of_clients_to_proceed
+ // clients have sent messages (and not subsequently aborted).
+ bool ReadyForNextRound() const override;
+
+ private:
+ void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) override;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_R0_ADVERTISE_KEYS_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_r0_advertise_keys_state_test.cc b/fcp/secagg/server/secagg_server_r0_advertise_keys_state_test.cc
new file mode 100644
index 0000000..520d226
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r0_advertise_keys_state_test.cc
@@ -0,0 +1,795 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateAesSecAggServerProtocolImpl(
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ SecretSharingGraphFactory factory;
+
+ return std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(4, 3), // total number of clients is 4
+ 3, // minimum_number_of_clients_to_proceed
+ input_vector_specs,
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ std::make_unique<AesCtrPrngFactory>(), sender,
+ nullptr, // prng_runner
+ std::vector<ClientStatus>(4, ClientStatus::READY_TO_START),
+ ServerVariant::NATIVE_V1);
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, IsAbortedReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EXPECT_THAT(state.IsAborted(), IsFalse());
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EXPECT_THAT(state.IsCompletedSuccessfully(), IsFalse());
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, ErrorMessageRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EXPECT_THAT(state.ErrorMessage().ok(), IsFalse());
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, ResultRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EXPECT_THAT(state.Result().ok(), IsFalse());
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ AbortReturnsValidStateAndNotifiesClients) {
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info("test abort reason");
+
+ EXPECT_CALL(*metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ auto next_state =
+ state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);
+
+ ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ StateProceedsCorrectlyWithAllClientsValid) {
+ // In this test, all clients send two valid ECDH public keys apiece, and then
+ // the server proceeds to the next state.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(4);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ }
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
+ ElementsAre(IsEvent<IndividualMessageSent>(
+ 0, Eq(ServerToClientMessageType_ShareKeysRequest),
+ Eq(expected_server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 1, Eq(ServerToClientMessageType_ShareKeysRequest),
+ Eq(expected_server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 2, Eq(ServerToClientMessageType_ShareKeysRequest),
+ Eq(expected_server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 3, Eq(ServerToClientMessageType_ShareKeysRequest),
+ Eq(expected_server_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ StateProceedsCorrectlyWithInvalidKeysFromOneClient) {
+ // In this test, client 3 sends invalid public keys, so it should be forced to
+ // abort. But this should not stop the rest of the state proceeding normally.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(4);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 3; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ }
+ client_messages[3]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
+ client_messages[3]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk("This is too long to be a valid key.");
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys(); // this one will be empty
+
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "A public key sent by the client was not the correct size.");
+
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ StateProceedsCorrectlyWithNoMessageFromOneClient) {
+ // In this test, we proceed to the next state before client 3 sends any
+ // message, so it should be forced to abort. But this should not stop the rest
+ // of the state proceeding normally.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(3);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 3; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ }
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys(); // this one will be empty
+
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not send AdvertiseKeys message before round transition.");
+
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ StateNeedsToAbortIfTooManyClientsAbort) {
+ // In this test, the first two clients send abort messages, so the server
+ // should register that it needs to abort.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), Eq(i >= 2));
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - i));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ if (i < 2) {
+ // Have client abort
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
+ ASSERT_THAT(state.HandleMessage(i, abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 1));
+ }
+ }
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_abort()->set_early_success(false);
+ server_message.mutable_abort()->set_diagnostic_info(
+ "Too many clients aborted.");
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(1);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state.value()->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state.value()->ErrorMessage().value(),
+ Eq("Too many clients aborted."));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(server_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ StateProceedsCorrectlyWithAllUncompressedClientMessages) {
+ // In this test, all clients send two valid ECDH public keys apiece, and then
+ // the server proceeds to the next state.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get()));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(4);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));
+ }
+
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsStart) {
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ EXPECT_CALL(*metrics, ProtocolStarts(_));
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
+
+ EXPECT_THAT(state.Result().ok(), IsFalse());
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsMessageSizes) {
+ // In this test, client 3 sends invalid public keys, so it should be forced to
+ // abort. But this should not stop the rest of the state proceeding normally.
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ EXPECT_CALL(*metrics, ProtocolStarts(_));
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(4);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 3; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ }
+ client_messages[3]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
+ client_messages[3]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk("This is too long to be a valid key.");
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys(); // this one will be empty
+
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "A public key sent by the client was not the correct size.");
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
+
+ EXPECT_CALL(*metrics, IndividualMessageSizes(
+ Eq(ServerToClientWrapperMessage::
+ MessageContentCase::kShareKeysRequest),
+ Eq(expected_server_message.ByteSizeLong())))
+ .Times(3);
+ EXPECT_CALL(*metrics,
+ IndividualMessageSizes(
+ Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
+ Eq(abort_message.ByteSizeLong())));
+ EXPECT_CALL(
+ *metrics,
+ MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
+ Eq(true), Eq(client_messages[0].ByteSizeLong())))
+ .Times(3);
+ EXPECT_CALL(
+ *metrics,
+ MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
+ Eq(true), Eq(client_messages[3].ByteSizeLong())));
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ EXPECT_THAT(tracing_recorder.root()[i],
+ IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_AdvertiseKeys),
+ Eq(client_messages[i].ByteSizeLong()), Eq(true), Ge(0)));
+ }
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest,
+ ServerAndClientAbortsAreRecordedCorrectly) {
+ TestTracingRecorder tracing_recorder;
+ // In this test clients abort for a variety of reasons, and then ultimately
+ // the server aborts. Metrics should record all of these events.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ EcdhPregeneratedTestKeys ecdh_keys;
+
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
+
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::ADVERTISE_KEYS_UNEXPECTED)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::INVALID_PUBLIC_KEY)));
+ EXPECT_CALL(
+ *metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
+
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
+ ClientToServerWrapperMessage valid_message;
+ valid_message.mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(0));
+ valid_message.mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(4));
+ ClientToServerWrapperMessage invalid_message;
+ invalid_message.mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
+ invalid_message.mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk("This is too long to be a valid key.");
+ ClientToServerWrapperMessage wrong_message;
+ wrong_message.mutable_share_keys_response(); // wrong type of message
+
+ state.HandleMessage(0, abort_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(2, invalid_message).IgnoreError();
+ state.HandleMessage(3, wrong_message).IgnoreError();
+ state.ProceedToNextRound().IgnoreError(); // causes server abort
+
+ EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
+ ElementsAre(IsEvent<SecAggProtocolOutcome>(
+ Eq(TracingSecAggServerOutcome_NotEnoughClientsRemaining))));
+ EXPECT_THAT(
+ tracing_recorder.FindAllEvents<ClientsDropped>(),
+ ElementsAre(IsEvent<ClientsDropped>(
+ Eq(TracingClientStatus_DeadBeforeSendingAnything),
+ Eq(TracingClientDropReason_SentAbortMessage)),
+ IsEvent<ClientsDropped>(
+ Eq(TracingClientStatus_DeadBeforeSendingAnything),
+ Eq(TracingClientDropReason_AdvertiseKeysUnexpected)),
+ IsEvent<ClientsDropped>(
+ Eq(TracingClientStatus_DeadBeforeSendingAnything),
+ Eq(TracingClientDropReason_InvalidPublicKey)),
+ IsEvent<ClientsDropped>(
+ Eq(TracingClientStatus_DeadBeforeSendingAnything),
+ Eq(TracingClientDropReason_UnexpectedMessageType))));
+}
+
+TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsAreRecorded) {
+ // In this test, all clients send two valid ECDH public keys apiece, and then
+ // the server proceeds to the next state.
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ SecAggServerR0AdvertiseKeysState state(
+ CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
+
+ EcdhPregeneratedTestKeys ecdh_keys;
+ auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
+ std::vector<ClientToServerWrapperMessage> client_messages(4);
+ ServerToClientWrapperMessage expected_server_message;
+ for (int i = 0; i < 4; ++i) {
+ PairOfPublicKeys* public_keys =
+ expected_server_message.mutable_share_keys_request()
+ ->add_pairs_of_public_keys();
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ client_messages[i]
+ .mutable_advertise_keys()
+ ->mutable_pair_of_public_keys()
+ ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
+ public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
+ }
+
+ expected_server_message.mutable_share_keys_request()->set_session_id(
+ ComputeSessionId(expected_server_message.share_keys_request()).data);
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
+ .Times(1);
+ }
+ EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS),
+ Eq(true), Ge(0)));
+ EXPECT_CALL(*metrics,
+ RoundSurvivingClients(
+ Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS), Eq(4)));
+ EXPECT_CALL(
+ *metrics,
+ ClientResponseTimes(
+ Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
+ Ge(0)))
+ .Times(4);
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS));
+ EXPECT_THAT(
+ tracing_recorder.FindAllEvents<StateCompletion>(),
+ ElementsAre(IsEvent<StateCompletion>(
+ Eq(SecAggServerTraceState_R0AdvertiseKeys), Eq(true), Ge(0), Eq(4))));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r1_share_keys_state.cc b/fcp/secagg/server/secagg_server_r1_share_keys_state.cc
new file mode 100644
index 0000000..eb05bb8
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r1_share_keys_state.cc
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r1_share_keys_state.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerR1ShareKeysState::SecAggServerR1ShareKeysState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::R1_SHARE_KEYS, std::move(impl)) {
+}
+
+SecAggServerR1ShareKeysState::~SecAggServerR1ShareKeysState() {}
+
+Status SecAggServerR1ShareKeysState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ if (message.has_abort()) {
+ MessageReceived(message, false);
+ AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
+ /*notify=*/false);
+ return FCP_STATUS(OK);
+ }
+ // If the client has aborted or sent a message already, ignore its messages.
+ if (client_status(client_id) != ClientStatus::ADVERTISE_KEYS_RECEIVED) {
+ MessageReceived(message, false);
+ AbortClient(client_id,
+ "Not expecting an ShareKeysResponse from this "
+ "client - either the client already aborted or one such "
+ "message was already received.",
+ ClientDropReason::SHARE_KEYS_UNEXPECTED);
+ return FCP_STATUS(OK);
+ }
+ if (!message.has_share_keys_response()) {
+ MessageReceived(message, false);
+ AbortClient(client_id,
+ "Message type received is different from what was expected.",
+ ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
+ return FCP_STATUS(OK);
+ }
+ MessageReceived(message, true);
+
+ Status status =
+ impl()->HandleShareKeysResponse(client_id, message.share_keys_response());
+ if (!status.ok()) {
+ AbortClient(client_id, std::string(status.message()),
+ ClientDropReason::INVALID_SHARE_KEYS_RESPONSE);
+ return FCP_STATUS(OK);
+ }
+
+ set_client_status(client_id, ClientStatus::SHARE_KEYS_RECEIVED);
+ number_of_messages_received_in_this_round_++;
+ number_of_clients_ready_for_next_round_++;
+ return FCP_STATUS(OK);
+}
+
+bool SecAggServerR1ShareKeysState::IsNumberOfIncludedInputsCommitted() const {
+ return false;
+}
+
+int SecAggServerR1ShareKeysState::MinimumMessagesNeededForNextRound() const {
+ return std::max(0, minimum_number_of_clients_to_proceed() -
+ number_of_clients_ready_for_next_round_);
+}
+
+int SecAggServerR1ShareKeysState::NumberOfPendingClients() const {
+ return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
+}
+
+void SecAggServerR1ShareKeysState::HandleAbortClient(
+ uint32_t client_id, ClientDropReason reason_code) {
+ number_of_clients_failed_before_sending_masked_input_++;
+ if (client_status(client_id) == ClientStatus::SHARE_KEYS_RECEIVED) {
+ number_of_clients_ready_for_next_round_--;
+ // Remove that client's shared keys as if they were never sent. This will
+ // avoid wasted computation on both client and server ends.
+ impl()->EraseShareKeysForClient(client_id);
+ }
+ set_client_status(client_id,
+ ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED);
+ if (NumberOfAliveClients() < minimum_number_of_clients_to_proceed()) {
+ needs_to_abort_ = true;
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggServerState>>
+SecAggServerR1ShareKeysState::ProceedToNextRound() {
+ if (!ReadyForNextRound()) {
+ return FCP_STATUS(UNAVAILABLE);
+ }
+ if (needs_to_abort_) {
+ std::string error_string = "Too many clients aborted.";
+ ServerToClientWrapperMessage message;
+ message.mutable_abort()->set_diagnostic_info(error_string);
+ message.mutable_abort()->set_early_success(false);
+ SendBroadcast(message);
+
+ return AbortState(error_string,
+ SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
+ }
+
+ // Abort all clients that haven't yet sent a message, and send a message to
+ // all clients that are still alive.
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ if (!IsClientDead(i) &&
+ client_status(i) != ClientStatus::SHARE_KEYS_RECEIVED) {
+ AbortClient(
+ i, "Client did not send ShareKeysResponse before round transition.",
+ ClientDropReason::NO_SHARE_KEYS);
+ } else if (client_status(i) == ClientStatus::SHARE_KEYS_RECEIVED) {
+ ServerToClientWrapperMessage message;
+ impl()->PrepareMaskedInputCollectionRequestForClient(
+ i, message.mutable_masked_input_request());
+ Send(i, message);
+ }
+ }
+
+ // Encrypted shares are no longer needed beyond this point as the server has
+ // already forwarded them to the clients.
+ impl()->ClearShareKeys();
+
+ return std::make_unique<SecAggServerR2MaskedInputCollState>(
+ ExitState(StateTransition::kSuccess),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_);
+}
+
+bool SecAggServerR1ShareKeysState::ReadyForNextRound() const {
+ return (number_of_clients_ready_for_next_round_ >=
+ minimum_number_of_clients_to_proceed()) ||
+ (needs_to_abort_);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r1_share_keys_state.h b/fcp/secagg/server/secagg_server_r1_share_keys_state.h
new file mode 100644
index 0000000..53d54a9
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r1_share_keys_state.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_R1_SHARE_KEYS_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_R1_SHARE_KEYS_STATE_H_
+
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer when it is in the
+// Round 1: Share Keys state. This state collects encrypted key shares from
+// clients, and at the end of the round, delivers the appropriate shares to the
+// right clients. It should transition to Round 2: Masked Input Collection, but
+// might transition to Aborted if too many clients abort.
+
+class SecAggServerR1ShareKeysState : public SecAggServerState {
+ public:
+ SecAggServerR1ShareKeysState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerR1ShareKeysState() override;
+
+ // Handles a share keys response or abort message from a client.
+ Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message) override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ int MinimumMessagesNeededForNextRound() const override;
+
+ int NumberOfPendingClients() const override;
+
+ StatusOr<std::unique_ptr<SecAggServerState>> ProceedToNextRound() override;
+
+ // This will return true only after minimum_number_of_clients_to_proceed
+ // clients have sent messages (and not subsequently aborted).
+ bool ReadyForNextRound() const override;
+
+ private:
+ void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) override;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_R1_SHARE_KEYS_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_r1_share_keys_state_test.cc b/fcp/secagg/server/secagg_server_r1_share_keys_state_test.cc
new file mode 100644
index 0000000..0d0f308
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r1_share_keys_state_test.cc
@@ -0,0 +1,829 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r1_share_keys_state.h"
+
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+
+// Default test session_id.
+SessionId session_id = {"session id number, 32 bytes long"};
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ int minimum_number_of_clients_to_proceed, int total_number_of_clients,
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ SecretSharingGraphFactory factory;
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(total_number_of_clients,
+ minimum_number_of_clients_to_proceed),
+ minimum_number_of_clients_to_proceed, input_vector_specs,
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ std::make_unique<AesCtrPrngFactory>(), sender,
+ std::make_unique<SecAggScheduler>(
+ /*sequential_scheduler=*/nullptr,
+ /*parallel_scheduler=*/nullptr),
+ std::vector<ClientStatus>(total_number_of_clients,
+ ClientStatus::ADVERTISE_KEYS_RECEIVED),
+ ServerVariant::NATIVE_V1);
+ impl->set_session_id(std::make_unique<SessionId>(session_id));
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < total_number_of_clients; i++) {
+ impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
+ }
+ return impl;
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, IsAbortedReturnsFalse) {
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.IsAborted(), IsFalse());
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.IsCompletedSuccessfully(), IsFalse());
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, ErrorMessageRaisesErrorStatus) {
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.ErrorMessage().ok(), IsFalse());
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, ResultRaisesErrorStatus) {
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.Result().ok(), IsFalse());
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ AbortReturnsValidStateAndNotifiesClients) {
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info("test abort reason");
+
+ EXPECT_CALL(*metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ auto next_state =
+ state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);
+
+ ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ StateProceedsCorrectlyWithAllClientsValid) {
+ // In this test, all clients send inputs for the correct clients, and then the
+ // server proceeds to the next state. (The inputs aren't actually encrypted
+ // shared keys, but that doesn't matter for this test.)
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ // Have one client send the right vector of "encrypted keys" to the
+ // server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ std::vector<ServerToClientWrapperMessage> server_messages(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
+ }
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ StateProceedsCorrectlyWithOnePreviousDropout) {
+ // In this test, client 3 dropped out in round 0, so clients should not send
+ // key shares for it. All other clients proceed normally.
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+ auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
+ impl->set_client_status(3, ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
+
+ SecAggServerR1ShareKeysState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ // Have one client send the right vector of "encrypted keys" to the
+ // server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j || j == 3) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ std::vector<ServerToClientWrapperMessage> server_messages(3);
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (i == j || j == 3) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
+ }
+ EXPECT_CALL(*sender, Send(Eq(3), _)).Times(0);
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ StateProceedsCorrectlyWithAnAbortAfterSendingShares) {
+ // In this test, all clients send inputs for the correct clients, but then
+ // client 2 aborts. This should cause that client's message shared keys not to
+ // appear in the messages sent later.
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ // Have one client send the right vector of "encrypted key shares" to
+ // the server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("aborting for test");
+ ASSERT_THAT(state.HandleMessage(2, abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ std::vector<ServerToClientWrapperMessage> server_messages(4);
+ for (int i = 0; i < 4; ++i) {
+ if (i == 2) {
+ EXPECT_CALL(*sender, Send(Eq(2), _)).Times(0);
+ continue;
+ }
+ for (int j = 0; j < 4; ++j) {
+ if (i == j || j == 2) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
+ }
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ StateProceedsCorrectlyWithOneClientSendingInvalidShares) {
+ // In this test, all clients send encrypted shares, but client 0 omits an
+ // encrypted share for client 1. This should force client 0 to abort.
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ std::vector<ServerToClientWrapperMessage> server_messages(4);
+ server_messages[0].mutable_abort()->set_early_success(false);
+ server_messages[0].mutable_abort()->set_diagnostic_info(
+ "Client omitted a key share that was expected.");
+ EXPECT_CALL(*sender, Send(Eq(0), EqualsProto(server_messages[0]))).Times(1);
+ for (int i = 1; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (i == j || j == 0) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
+ }
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ ClientToServerWrapperMessage bad_message;
+ bad_message.mutable_share_keys_response()->add_encrypted_key_shares("");
+ bad_message.mutable_share_keys_response()->add_encrypted_key_shares("");
+ bad_message.mutable_share_keys_response()->add_encrypted_key_shares(
+ "encrypted key shares from 0 to 2");
+ bad_message.mutable_share_keys_response()->add_encrypted_key_shares(
+ "encrypted key shares from 0 to 3");
+ ASSERT_THAT(state.HandleMessage(0, bad_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+
+ for (int i = 1; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 4) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ // Have one client send the right vector of "encrypted key shares" to
+ // the server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, StateAbortsIfTooManyClientsAbort) {
+ // In this test, clients 0 and 1 send abort messages. This should cause the
+ // server state to register that it needs to abort immediately.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), Eq(i >= 2));
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - i));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ if (i < 2) {
+ // Have client abort
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
+ ASSERT_THAT(state.HandleMessage(i, abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 1));
+ }
+ }
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_abort()->set_early_success(false);
+ server_message.mutable_abort()->set_diagnostic_info(
+ "Too many clients aborted.");
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(1);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state.value()->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state.value()->ErrorMessage().value(),
+ Eq("Too many clients aborted."));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(server_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, MetricsRecordsMessageSizes) {
+ // In this test, all clients send inputs for the correct clients, and then the
+ // server proceeds to the next state. (The inputs aren't actually encrypted
+ // shared keys, but that doesn't matter for this test.)
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ // Have one client send the right vector of "encrypted keys" to the
+ // server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ EXPECT_CALL(*metrics, MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::kShareKeysResponse),
+ Eq(true), Eq(client_message.ByteSizeLong())));
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ EXPECT_THAT(tracing_recorder.root()[i],
+ IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_ShareKeysResponse),
+ Eq(client_message.ByteSizeLong()), Eq(true), Ge(0)));
+ }
+ }
+ std::vector<ServerToClientWrapperMessage> server_messages(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i])));
+ }
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
+ EXPECT_CALL(*metrics, IndividualMessageSizes(
+ Eq(ServerToClientWrapperMessage::
+ MessageContentCase::kMaskedInputRequest),
+ Eq(server_messages[0].ByteSizeLong())))
+ .Times(4);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+ EXPECT_THAT(
+ tracing_recorder.FindAllEvents<IndividualMessageSent>(),
+ ElementsAre(IsEvent<IndividualMessageSent>(
+ 0, Eq(ServerToClientMessageType_MaskedInputRequest),
+ Eq(server_messages[0].ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 1, Eq(ServerToClientMessageType_MaskedInputRequest),
+ Eq(server_messages[1].ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 2, Eq(ServerToClientMessageType_MaskedInputRequest),
+ Eq(server_messages[2].ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 3, Eq(ServerToClientMessageType_MaskedInputRequest),
+ Eq(server_messages[3].ByteSizeLong()))));
+}
+
+TEST(SecaggServerR1ShareKeysStateTest,
+ ServerAndClientAbortsAreRecordedCorrectly) {
+ // In this test clients abort for a variety of reasons, and then ultimately
+ // the server aborts. Metrics should record all of these events.
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_CALL(
+ *metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
+ Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
+ EXPECT_CALL(
+ *metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
+ Eq(ClientDropReason::SHARE_KEYS_UNEXPECTED)));
+ EXPECT_CALL(
+ *metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
+ Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
+ EXPECT_CALL(
+ *metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED),
+ Eq(ClientDropReason::INVALID_SHARE_KEYS_RESPONSE)))
+ .Times(3);
+ EXPECT_CALL(
+ *metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
+
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
+ ClientToServerWrapperMessage valid_message; // from client 1
+ for (int j = 0; j < 7; ++j) {
+ if (1 == j) {
+ valid_message.mutable_share_keys_response()->add_encrypted_key_shares("");
+ } else {
+ valid_message.mutable_share_keys_response()->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", 1, " to ", j));
+ }
+ }
+
+ ClientToServerWrapperMessage invalid_message_wrong_number; // from client 2
+ for (int j = 0; j <= 7; ++j) { // goes one past the end
+ if (2 == j) {
+ invalid_message_wrong_number.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ invalid_message_wrong_number.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", 2, " to ", j));
+ }
+ }
+
+ ClientToServerWrapperMessage invalid_message_missing_share; // from client 3
+ for (int j = 0; j < 7; ++j) {
+ if (3 == j || 0 == j) { // missing share for 0
+ invalid_message_missing_share.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ invalid_message_missing_share.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", 3, " to ", j));
+ }
+ }
+
+ ClientToServerWrapperMessage invalid_message_extra_share; // from client 4
+ for (int j = 0; j < 7; ++j) {
+ // including share for self, which is wrong
+ invalid_message_extra_share.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", 4, " to ", j));
+ }
+
+ ClientToServerWrapperMessage wrong_message;
+ wrong_message.mutable_advertise_keys(); // wrong type of message
+
+ state.HandleMessage(0, abort_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(2, invalid_message_wrong_number).IgnoreError();
+ state.HandleMessage(3, invalid_message_missing_share).IgnoreError();
+ state.HandleMessage(4, invalid_message_extra_share).IgnoreError();
+ state.HandleMessage(5, wrong_message).IgnoreError();
+ state.ProceedToNextRound().IgnoreError(); // causes server abort
+}
+
+TEST(SecaggServerR1ShareKeysStateTest, MetricsAreRecorded) {
+ // In this test, all clients send inputs for the correct clients, and then the
+ // server proceeds to the next state. (The inputs aren't actually encrypted
+ // shared keys, but that doesn't matter for this test.)
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_shared<MockSendToClientsInterface>();
+
+ SecAggServerR1ShareKeysState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_CALL(*metrics, ClientResponseTimes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::kShareKeysResponse),
+ Ge(0)))
+ .Times(4);
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ // Have one client send the right vector of "encrypted keys" to the
+ // server.
+ ClientToServerWrapperMessage client_message;
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares("");
+ } else {
+ client_message.mutable_share_keys_response()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", i, " to ", j));
+ }
+ }
+ ASSERT_THAT(state.HandleMessage(i, client_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ std::vector<ServerToClientWrapperMessage> server_messages(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (i == j) {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares("");
+ } else {
+ server_messages[i]
+ .mutable_masked_input_request()
+ ->add_encrypted_key_shares(
+ absl::StrCat("encrypted key shares from ", j, " to ", i));
+ }
+ }
+ EXPECT_CALL(*sender, Send(Eq(i), EqualsProto(server_messages[i]))).Times(1);
+ }
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R1_SHARE_KEYS),
+ Eq(true), Ge(0)));
+ EXPECT_CALL(*metrics, RoundSurvivingClients(
+ Eq(SecAggServerStateKind::R1_SHARE_KEYS), Eq(4)));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION));
+}
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.cc b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.cc
new file mode 100644
index 0000000..edf4617
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerR2MaskedInputCollState::SecAggServerR2MaskedInputCollState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION,
+ std::move(impl)) {
+ accumulator_ = this->impl()->SetupMaskedInputCollection();
+}
+
+SecAggServerR2MaskedInputCollState::~SecAggServerR2MaskedInputCollState() {}
+
+Status SecAggServerR2MaskedInputCollState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Call to deprecated HandleMessage method.";
+}
+
+Status SecAggServerR2MaskedInputCollState::HandleMessage(
+ uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
+ if (message->has_abort()) {
+ MessageReceived(*message, false);
+ AbortClient(client_id, "", ClientDropReason::SENT_ABORT_MESSAGE,
+ /*notify=*/false);
+ return FCP_STATUS(OK);
+ }
+ // If the client has aborted already, ignore its messages.
+ if (client_status(client_id) != ClientStatus::SHARE_KEYS_RECEIVED) {
+ MessageReceived(*message, false);
+ AbortClient(client_id,
+ "Not expecting an MaskedInputCollectionResponse from this "
+ "client - either the client already aborted or one such "
+ "message was already received.",
+ ClientDropReason::MASKED_INPUT_UNEXPECTED);
+ return FCP_STATUS(OK);
+ }
+ if (!message->has_masked_input_response()) {
+ MessageReceived(*message, false);
+ AbortClient(client_id,
+ "Message type received is different from what was expected.",
+ ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
+ return FCP_STATUS(OK);
+ }
+ MessageReceived(*message, true);
+
+ Status check_and_accumulate_status =
+ impl()->HandleMaskedInputCollectionResponse(
+ std::make_unique<MaskedInputCollectionResponse>(
+ std::move(*message->mutable_masked_input_response())));
+ if (!check_and_accumulate_status.ok()) {
+ AbortClient(client_id, std::string(check_and_accumulate_status.message()),
+ ClientDropReason::INVALID_MASKED_INPUT);
+ return FCP_STATUS(OK);
+ }
+ set_client_status(client_id, ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED);
+ number_of_messages_received_in_this_round_++;
+ number_of_clients_ready_for_next_round_++;
+ return FCP_STATUS(OK);
+}
+
+bool SecAggServerR2MaskedInputCollState::IsNumberOfIncludedInputsCommitted()
+ const {
+ return false;
+}
+
+int SecAggServerR2MaskedInputCollState::MinimumMessagesNeededForNextRound()
+ const {
+ return std::max(0, minimum_number_of_clients_to_proceed() -
+ number_of_clients_ready_for_next_round_);
+}
+
+int SecAggServerR2MaskedInputCollState::NumberOfIncludedInputs() const {
+ return number_of_messages_received_in_this_round_;
+}
+
+int SecAggServerR2MaskedInputCollState::NumberOfPendingClients() const {
+ return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
+}
+
+void SecAggServerR2MaskedInputCollState::HandleAbortClient(
+ uint32_t client_id, ClientDropReason reason_code) {
+ if (client_status(client_id) ==
+ ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
+ number_of_clients_ready_for_next_round_--;
+ number_of_clients_failed_after_sending_masked_input_++;
+ set_client_status(client_id,
+ ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
+ } else {
+ number_of_clients_failed_before_sending_masked_input_++;
+ clients_aborted_at_round_2_.push_back(client_id);
+ set_client_status(client_id, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+ }
+ if (NumberOfAliveClients() < minimum_number_of_clients_to_proceed()) {
+ needs_to_abort_ = true;
+ }
+}
+
+void SecAggServerR2MaskedInputCollState::HandleAbort() {
+ if (accumulator_) {
+ accumulator_->Cancel();
+ }
+}
+
+StatusOr<std::unique_ptr<SecAggServerState>>
+SecAggServerR2MaskedInputCollState::ProceedToNextRound() {
+ if (!ReadyForNextRound()) {
+ return FCP_STATUS(UNAVAILABLE);
+ }
+ if (needs_to_abort_) {
+ std::string error_string = "Too many clients aborted.";
+ ServerToClientWrapperMessage message;
+ message.mutable_abort()->set_diagnostic_info(error_string);
+ message.mutable_abort()->set_early_success(false);
+ SendBroadcast(message);
+ HandleAbort();
+
+ return AbortState(error_string,
+ SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
+ }
+
+ // Close all clients that haven't yet sent a message.
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ if (!IsClientDead(i) &&
+ client_status(i) != ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
+ AbortClient(i,
+ "Client did not send MaskedInputCollectionResponse before "
+ "round transition.",
+ ClientDropReason::NO_MASKED_INPUT);
+ }
+ }
+ // Send to each alive client the list of their aborted neighbors
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ if (IsClientDead(i)) {
+ continue;
+ }
+ ServerToClientWrapperMessage message_to_i;
+ // Set message to proper type
+ auto request = message_to_i.mutable_unmasking_request();
+ for (uint32_t aborted_client : clients_aborted_at_round_2_) {
+ // neighbor_index has a value iff i and aborted_client are neighbors
+ auto neighbor_index = GetNeighborIndex(i, aborted_client);
+ if (neighbor_index.has_value()) {
+ // TODO(team): Stop adding + 1 here once we don't need
+ // compatibility.
+ request->add_dead_3_client_ids(neighbor_index.value() + 1);
+ }
+ }
+ Send(i, message_to_i);
+ }
+
+ impl()->FinalizeMaskedInputCollection();
+
+ return {std::make_unique<SecAggServerR3UnmaskingState>(
+ ExitState(StateTransition::kSuccess),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_)};
+}
+
+bool SecAggServerR2MaskedInputCollState::SetAsyncCallback(
+ std::function<void()> async_callback) {
+ if (accumulator_) {
+ return accumulator_->SetAsyncObserver(async_callback);
+ }
+ return false;
+}
+
+bool SecAggServerR2MaskedInputCollState::ReadyForNextRound() const {
+ // Accumulator is not set (this is a synchronous session) or it does not have
+ // unobserved work.
+ bool accumulator_is_idle = (!accumulator_ || accumulator_->IsIdle());
+ return accumulator_is_idle && ((number_of_clients_ready_for_next_round_ >=
+ minimum_number_of_clients_to_proceed()) ||
+ (needs_to_abort_));
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h
new file mode 100644
index 0000000..4eab38d
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_R2_MASKED_INPUT_COLL_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_R2_MASKED_INPUT_COLL_STATE_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "fcp/secagg/server/secagg_scheduler.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer when it is in the
+// Round 2: Masked Input Collection state. This state receives masked inputs
+// from clients and adds them together in preparation for the unmasking step. At
+// the conclusion of masked input collection, if the server has collected enough
+// masked inputs, it sends the clients a message with the set of clients that
+// have not sent masked inputs and moved into Round 3: Unmasking. If too many
+// clients abort, it can abort instead.
+class SecAggServerR2MaskedInputCollState : public SecAggServerState {
+ public:
+ SecAggServerR2MaskedInputCollState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerR2MaskedInputCollState() override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ int MinimumMessagesNeededForNextRound() const override;
+
+ int NumberOfIncludedInputs() const override;
+
+ int NumberOfPendingClients() const override;
+
+ // This will return true only after minimum_number_of_clients_to_proceed
+ // clients have sent messages (and not subsequently aborted).
+ bool ReadyForNextRound() const override;
+
+ // Handles a masked input response or abort message from a client.
+ Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message) override;
+ Status HandleMessage(
+ uint32_t client_id,
+ std::unique_ptr<ClientToServerWrapperMessage> message) override;
+
+ StatusOr<std::unique_ptr<SecAggServerState> > ProceedToNextRound() override;
+
+ bool SetAsyncCallback(std::function<void()> async_callback) override;
+
+ protected:
+ // Track the clients who abort this round and send this list to the clients.
+ std::vector<uint32_t> clients_aborted_at_round_2_;
+
+ private:
+ std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>> accumulator_;
+ void HandleAbort() override;
+
+ void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) override;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_R2_MASKED_INPUT_COLL_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_r2_masked_input_coll_state_test.cc b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state_test.cc
new file mode 100644
index 0000000..82f9e7a
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r2_masked_input_coll_state_test.cc
@@ -0,0 +1,931 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/node_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/experiments_interface.h"
+#include "fcp/secagg/server/experiments_names.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/server/send_to_clients_interface.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/secagg/testing/server/test_secagg_experiments.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+
+class FakeScheduler : public Scheduler {
+ public:
+ void Schedule(std::function<void()> job) override { jobs_.push_back(job); }
+
+ void WaitUntilIdle() override {}
+
+ void Run() {
+ for (auto& job : jobs_) {
+ job();
+ }
+ jobs_.clear();
+ }
+
+ private:
+ std::vector<std::function<void()>> jobs_;
+};
+
+// Default test session_id.
+SessionId session_id = {"session id number, 32 bytes long"};
+
+struct SecAggR2StateTestParams {
+ const std::string test_name;
+ // Enables asymchronous processing of round 2 messages by the server.
+ bool enable_async_r2;
+};
+
+class SecaggServerR2MaskedInputCollStateTest
+ : public ::testing::TestWithParam<SecAggR2StateTestParams> {
+ protected:
+ std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ int minimum_number_of_clients_to_proceed, int total_number_of_clients,
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr,
+ bool enable_async_r2 = true) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ SecretSharingGraphFactory factory;
+ auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(total_number_of_clients,
+ minimum_number_of_clients_to_proceed),
+ minimum_number_of_clients_to_proceed, input_vector_specs,
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ std::make_unique<AesCtrPrngFactory>(), sender,
+ std::make_unique<SecAggScheduler>(&parallel_scheduler_,
+ &sequential_scheduler_),
+ std::vector<ClientStatus>(total_number_of_clients,
+ ClientStatus::SHARE_KEYS_RECEIVED),
+ ServerVariant::NATIVE_V1,
+ enable_async_r2
+ ? std::make_unique<TestSecAggExperiment>(
+ TestSecAggExperiment(kSecAggAsyncRound2Experiment))
+ : std::make_unique<TestSecAggExperiment>(TestSecAggExperiment()));
+ impl->set_session_id(std::make_unique<SessionId>(session_id));
+ EcdhPregeneratedTestKeys ecdh_keys;
+ for (int i = 0; i < total_number_of_clients; ++i) {
+ impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
+ }
+
+ return impl;
+ }
+
+ void RunSchedulers() {
+ parallel_scheduler_.Run();
+ sequential_scheduler_.Run();
+ }
+
+ private:
+ FakeScheduler parallel_scheduler_;
+ FakeScheduler sequential_scheduler_;
+};
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest, IsAbortedReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.IsAborted(), IsFalse());
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ IsCompletedSuccessfullyReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.IsCompletedSuccessfully(), IsFalse());
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest, ErrorMessageRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.ErrorMessage().ok(), IsFalse());
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest, ResultRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_THAT(state.Result().ok(), IsFalse());
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ AbortReturnsValidStateAndNotifiesClients) {
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info("test abort reason");
+
+ EXPECT_CALL(*metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ auto next_state =
+ state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);
+
+ ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ StateProceedsCorrectlyWithAllClientsValid) {
+ // In this test, all clients send in their valid masked inputs, and then the
+ // server proceeds to the next state.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+
+ if (i < 4) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_unmasking_request()
+ ->mutable_dead_3_client_ids()
+ ->Clear(); // Just to set it to an empty vector
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
+ ElementsAre(IsEvent<IndividualMessageSent>(
+ 0, Eq(ServerToClientMessageType_UnmaskingRequest),
+ Eq(server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 1, Eq(ServerToClientMessageType_UnmaskingRequest),
+ Eq(server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 2, Eq(ServerToClientMessageType_UnmaskingRequest),
+ Eq(server_message.ByteSizeLong())),
+ IsEvent<IndividualMessageSent>(
+ 3, Eq(ServerToClientMessageType_UnmaskingRequest),
+ Eq(server_message.ByteSizeLong()))));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ StateProceedsCorrectlyWithoutAllClients) {
+ // In this test, clients 0 through 2 send in valid masked inputs, and then we
+ // proceed to the next step even without client 3.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+
+ if (i < 3) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ ServerToClientWrapperMessage server_message;
+ // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
+ server_message.mutable_unmasking_request()->add_dead_3_client_ids(4);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not send MaskedInputCollectionResponse before round "
+ "transition.");
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ StateProceedsCorrectlyWithOneClientSendingInvalidInput) {
+ // In this test, client 0 sends an invalid masked input, so it is aborted. The
+ // rest of the round goes normally.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ ServerToClientWrapperMessage server_message;
+ // TODO(team): 1 -> 0 below, once backwards compatibility not needed.
+ server_message.mutable_unmasking_request()->add_dead_3_client_ids(1);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Masked input does not match input vector specification - vector is "
+ "wrong size.");
+
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 1; i < 4; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+ EXPECT_CALL(*sender, Send(0, EqualsProto(abort_message))).Times(1);
+
+ // Have client 0 send an invalid message.
+ auto invalid_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ encoded_vector.set_encoded_vector("not a real masked input vector - invalid");
+ (*invalid_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(0, std::move(invalid_message)), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ for (int i = 1; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 4) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 4));
+ }
+
+ if (i < 4) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ StateProceedsCorrectlyWithOneClientAbortingAfterSendingInput) {
+ // In this test, all clients send in their valid masked inputs, but then
+ // client 2 aborts before the server proceeds to the next state.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ if (i < 4) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
+ abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
+ ASSERT_THAT(state.HandleMessage(2, std::move(abort_message)), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_unmasking_request()
+ ->mutable_dead_3_client_ids()
+ ->Clear(); // Just to set it to an empty vector
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ if (i != 2) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ StateForcesAbortIfTooManyClientsAbort) {
+ // In this test, clients 0 and 1 abort, so the state aborts.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(),
+ nullptr /* metrics_listener */,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), Eq(i >= 2));
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - i));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ if (i < 2) {
+ // Have client abort
+ auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
+ abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
+ ASSERT_THAT(state.HandleMessage(i, std::move(abort_message)), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 1));
+ }
+ }
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_abort()->set_early_success(false);
+ server_message.mutable_abort()->set_diagnostic_info(
+ "Too many clients aborted.");
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(1);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state.value()->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state.value()->ErrorMessage().value(),
+ Eq("Too many clients aborted."));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(server_message.ByteSizeLong()))));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsRecordsMessageSizes) {
+ // In this test, all clients send in their valid masked inputs, but then
+ // client 2 aborts before the server proceeds to the next state.
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ if (i < 4) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ EXPECT_CALL(
+ *metrics,
+ MessageReceivedSizes(Eq(ClientToServerWrapperMessage::
+ MessageContentCase::kMaskedInputResponse),
+ Eq(true), Eq(client_message->ByteSizeLong())));
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
+ abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
+ EXPECT_CALL(*metrics,
+ MessageReceivedSizes(
+ Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
+ Eq(false), Eq(abort_message->ByteSizeLong())));
+
+ size_t abort_message_size = abort_message->ByteSizeLong();
+ ASSERT_THAT(state.HandleMessage(2, std::move(abort_message)), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(tracing_recorder.root(),
+ Contains(IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_Abort), Eq(abort_message_size),
+ Eq(false), Ge(0))));
+
+ ServerToClientWrapperMessage server_message;
+ server_message.mutable_unmasking_request()
+ ->mutable_dead_3_client_ids()
+ ->Clear(); // Just to set it to an empty vector
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(0);
+ for (int i = 0; i < 4; ++i) {
+ if (i != 2) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+ }
+ EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
+ EXPECT_CALL(*metrics, IndividualMessageSizes(
+ Eq(ServerToClientWrapperMessage::
+ MessageContentCase::kUnmaskingRequest),
+ Eq(server_message.ByteSizeLong())))
+ .Times(3);
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest,
+ ServerAndClientAbortsAreRecordedCorrectly) {
+ // In this test clients abort for a variety of reasons, and then ultimately
+ // the server aborts. Metrics should record all of these events.
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
+ Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::MASKED_INPUT_UNEXPECTED)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
+ Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
+ Eq(ClientDropReason::INVALID_MASKED_INPUT)))
+ .Times(3);
+ EXPECT_CALL(
+ *metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
+
+ auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
+ abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
+
+ ClientToServerWrapperMessage valid_message;
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, 9), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*valid_message.mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+
+ auto invalid_message_too_many_vectors =
+ std::make_unique<ClientToServerWrapperMessage>();
+ (*invalid_message_too_many_vectors->mutable_masked_input_response()
+ ->mutable_vectors())["extra"] = encoded_vector;
+
+ auto invalid_message_wrong_name =
+ std::make_unique<ClientToServerWrapperMessage>();
+ (*invalid_message_wrong_name->mutable_masked_input_response()
+ ->mutable_vectors())["wrong"] = encoded_vector;
+
+ auto invalid_message_wrong_size =
+ std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector large_encoded_vector;
+ SecAggVector large_masked_vector(std::vector<uint64_t>(7, 9), 32);
+ large_encoded_vector.set_encoded_vector(
+ large_masked_vector.GetAsPackedBytes());
+ (*invalid_message_wrong_size->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = large_encoded_vector;
+
+ auto wrong_message = std::make_unique<ClientToServerWrapperMessage>();
+ wrong_message->mutable_advertise_keys(); // wrong type of message
+
+ state.HandleMessage(0, std::move(abort_message)).IgnoreError();
+ state
+ .HandleMessage(
+ 1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
+ .IgnoreError();
+ state
+ .HandleMessage(
+ 1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
+ .IgnoreError();
+ state.HandleMessage(2, std::move(invalid_message_too_many_vectors))
+ .IgnoreError();
+ state.HandleMessage(3, std::move(invalid_message_wrong_name)).IgnoreError();
+ state.HandleMessage(4, std::move(invalid_message_wrong_size)).IgnoreError();
+ state.HandleMessage(5, std::move(wrong_message)).IgnoreError();
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ state.ProceedToNextRound().IgnoreError(); // causes server abort
+}
+
+TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsAreRecorded) {
+ // In this test, clients 0 through 2 send in valid masked inputs, and then we
+ // proceed to the next step even without client 3.
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR2MaskedInputCollState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
+ GetParam().enable_async_r2),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0 // number_of_clients_terminated_without_unmasking
+ );
+
+ EXPECT_CALL(*metrics, ClientResponseTimes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::kMaskedInputResponse),
+ Ge(0)))
+ .Times(3);
+
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ }
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ if (i < 3) {
+ // Have client send a vector of the correct size to the server
+ auto client_message = std::make_unique<ClientToServerWrapperMessage>();
+ MaskedInputVector encoded_vector;
+ SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
+ encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
+ (*client_message->mutable_masked_input_response()
+ ->mutable_vectors())["foobar"] = encoded_vector;
+ ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
+ if (GetParam().enable_async_r2) {
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ }
+
+ if (GetParam().enable_async_r2) {
+ RunSchedulers();
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+
+ ServerToClientWrapperMessage server_message;
+ // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
+ server_message.mutable_unmasking_request()->add_dead_3_client_ids(4);
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not send MaskedInputCollectionResponse before round "
+ "transition.");
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(0);
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
+ }
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
+ EXPECT_CALL(*metrics,
+ RoundTimes(Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION),
+ Eq(true), Ge(0)));
+ EXPECT_CALL(
+ *metrics,
+ RoundSurvivingClients(
+ Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION), Eq(3)));
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::R3_UNMASKING));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SecaggServerR2MaskedInputCollStateTests,
+ SecaggServerR2MaskedInputCollStateTest,
+ ::testing::ValuesIn<SecAggR2StateTestParams>(
+ {{"r2_async_processing_enabled", true},
+ {"r2_async_processing_disabled", false}}),
+ [](const ::testing::TestParamInfo<
+ SecaggServerR2MaskedInputCollStateTest::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r3_unmasking_state.cc b/fcp/secagg/server/secagg_server_r3_unmasking_state.cc
new file mode 100644
index 0000000..018c978
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r3_unmasking_state.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_prng_running_state.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerR3UnmaskingState::SecAggServerR3UnmaskingState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking)
+ : SecAggServerState(number_of_clients_failed_after_sending_masked_input,
+ number_of_clients_failed_before_sending_masked_input,
+ number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind::R3_UNMASKING, std::move(impl)) {
+ this->impl()->SetUpShamirSharesTables();
+}
+
+SecAggServerR3UnmaskingState::~SecAggServerR3UnmaskingState() {}
+
+Status SecAggServerR3UnmaskingState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ if (message.has_abort()) {
+ MessageReceived(message, false);
+ AbortClient(client_id, "Client sent abort message.",
+ ClientDropReason::SENT_ABORT_MESSAGE,
+ /*notify=*/false);
+ return FCP_STATUS(OK);
+ }
+ // If the client has aborted already, ignore its messages.
+ if (client_status(client_id) !=
+ ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED) {
+ MessageReceived(message, false);
+ AbortClient(
+ client_id,
+ "Not expecting an UnmaskingResponse from this client - either the "
+ "client already aborted or one such message was already received.",
+ ClientDropReason::UNMASKING_RESPONSE_UNEXPECTED);
+ return FCP_STATUS(OK);
+ }
+ if (!message.has_unmasking_response()) {
+ MessageReceived(message, false);
+ AbortClient(client_id,
+ "Message type received is different from what was expected.",
+ ClientDropReason::UNEXPECTED_MESSAGE_TYPE);
+ return FCP_STATUS(OK);
+ }
+ MessageReceived(message, true);
+
+ Status status =
+ impl()->HandleUnmaskingResponse(client_id, message.unmasking_response());
+ if (!status.ok()) {
+ AbortClient(client_id, std::string(status.message()),
+ ClientDropReason::INVALID_UNMASKING_RESPONSE);
+ return FCP_STATUS(OK);
+ }
+
+ set_client_status(client_id, ClientStatus::UNMASKING_RESPONSE_RECEIVED);
+ number_of_messages_received_in_this_round_++;
+ number_of_clients_ready_for_next_round_++;
+ return FCP_STATUS(OK);
+}
+
+bool SecAggServerR3UnmaskingState::IsNumberOfIncludedInputsCommitted() const {
+ return true;
+}
+
+int SecAggServerR3UnmaskingState::MinimumMessagesNeededForNextRound() const {
+ return std::max(0, minimum_number_of_clients_to_proceed() -
+ number_of_messages_received_in_this_round_);
+}
+
+int SecAggServerR3UnmaskingState::NumberOfIncludedInputs() const {
+ return total_number_of_clients() -
+ number_of_clients_failed_before_sending_masked_input_;
+}
+
+int SecAggServerR3UnmaskingState::NumberOfPendingClients() const {
+ return NumberOfAliveClients() - number_of_clients_ready_for_next_round_;
+}
+
+void SecAggServerR3UnmaskingState::HandleAbortClient(
+ uint32_t client_id, ClientDropReason reason_code) {
+ if (client_status(client_id) == ClientStatus::UNMASKING_RESPONSE_RECEIVED) {
+ set_client_status(client_id,
+ ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED);
+ return;
+ }
+ if (reason_code == ClientDropReason::EARLY_SUCCESS) {
+ number_of_clients_terminated_without_unmasking_++;
+ } else {
+ number_of_clients_failed_after_sending_masked_input_++;
+ }
+ set_client_status(client_id,
+ ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED);
+ if (NumberOfPendingClients() + number_of_messages_received_in_this_round_ <
+ minimum_number_of_clients_to_proceed()) {
+ needs_to_abort_ = true;
+ }
+}
+
+bool SecAggServerR3UnmaskingState::ReadyForNextRound() const {
+ return (number_of_messages_received_in_this_round_ >=
+ minimum_number_of_clients_to_proceed()) ||
+ (needs_to_abort_);
+}
+
+StatusOr<std::unique_ptr<SecAggServerState> >
+SecAggServerR3UnmaskingState::ProceedToNextRound() {
+ if (!ReadyForNextRound()) {
+ return FCP_STATUS(UNAVAILABLE);
+ }
+ if (needs_to_abort_) {
+ std::string error_string = "Too many clients aborted.";
+ ServerToClientWrapperMessage message;
+ message.mutable_abort()->set_diagnostic_info(error_string);
+ message.mutable_abort()->set_early_success(false);
+ SendBroadcast(message);
+
+ return AbortState(error_string,
+ SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING);
+ }
+
+ // Abort all clients that haven't yet sent a message, but let them count it as
+ // a success.
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ if (client_status(i) != ClientStatus::UNMASKING_RESPONSE_RECEIVED) {
+ AbortClient(
+ i,
+ "Client did not send unmasking response but protocol completed "
+ "successfully.",
+ ClientDropReason::EARLY_SUCCESS);
+ }
+ }
+
+ return {std::make_unique<SecAggServerPrngRunningState>(
+ ExitState(StateTransition::kSuccess),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_)};
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_r3_unmasking_state.h b/fcp/secagg/server/secagg_server_r3_unmasking_state.h
new file mode 100644
index 0000000..984e99d
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r3_unmasking_state.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_R3_UNMASKING_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_R3_UNMASKING_STATE_H_
+
+#include <memory>
+
+#include "fcp/secagg/server/secagg_server_state.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class is the State for the SecAggServer when it is in the
+// Round 3: Unmasking state. This state covers the process of collecting secret
+// shares from clients, based on which clients submitted masked input in the
+// previous round. Unless the server aborts a client (or itself), it should not
+// need to send messages this state. This state should transition to
+// SecAggServerPrngRunningState once enough secret shares have been collected.
+// Unlike previous steps, there is no particular reason to wait for more than
+// the bare minimum number of clients to proceed.
+
+class SecAggServerR3UnmaskingState : public SecAggServerState {
+ public:
+ SecAggServerR3UnmaskingState(
+ std::unique_ptr<SecAggServerProtocolImpl> impl,
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking);
+
+ ~SecAggServerR3UnmaskingState() override;
+
+ // Handles an unmasking response or abort message from a client.
+ Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message) override;
+
+ bool IsNumberOfIncludedInputsCommitted() const override;
+
+ int MinimumMessagesNeededForNextRound() const override;
+
+ int NumberOfIncludedInputs() const override;
+
+ int NumberOfPendingClients() const override;
+
+ StatusOr<std::unique_ptr<SecAggServerState> > ProceedToNextRound() override;
+
+ // This will return true only after minimum_number_of_clients_to_proceed
+ // messages have been received.
+ bool ReadyForNextRound() const override;
+
+ private:
+ void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) override;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_R3_UNMASKING_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_r3_unmasking_state_test.cc b/fcp/secagg/server/secagg_server_r3_unmasking_state_test.cc
new file mode 100644
index 0000000..4bdced0
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_r3_unmasking_state_test.cc
@@ -0,0 +1,924 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_r3_unmasking_state.h"
+
+#include <memory>
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+using ::testing::Ge;
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+using ::testing::Ne;
+
+// Default test session_id.
+SessionId session_id = {"session id number, 32 bytes long"};
+
+std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
+ int minimum_number_of_clients_to_proceed, int total_number_of_clients,
+ MockSendToClientsInterface* sender,
+ MockSecAggServerMetricsListener* metrics_listener = nullptr) {
+ auto input_vector_specs = std::vector<InputVectorSpecification>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ SecretSharingGraphFactory factory;
+ auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
+ factory.CreateCompleteGraph(total_number_of_clients,
+ minimum_number_of_clients_to_proceed),
+ minimum_number_of_clients_to_proceed, input_vector_specs,
+ std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
+ std::make_unique<AesCtrPrngFactory>(), sender,
+ nullptr, // prng_runner
+ std::vector<ClientStatus>(total_number_of_clients,
+ ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED),
+ ServerVariant::NATIVE_V1);
+ impl->set_session_id(std::make_unique<SessionId>(session_id));
+ EcdhPregeneratedTestKeys ecdh_keys;
+
+ for (int i = 0; i < total_number_of_clients; ++i) {
+ impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
+ }
+
+ return impl;
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, IsAbortedReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.IsAborted(), IsFalse());
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, IsCompletedSuccessfullyReturnsFalse) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.IsCompletedSuccessfully(), IsFalse());
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, ErrorMessageRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.ErrorMessage().ok(), IsFalse());
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, ResultRaisesErrorStatus) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_THAT(state.Result().ok(), IsFalse());
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ AbortClientAfterUnmaskingResponseReceived) {
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics);
+ impl->set_client_status(2, ClientStatus::UNMASKING_RESPONSE_RECEIVED);
+ SecAggServerR3UnmaskingState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ state.AbortClient(2, "close client message.",
+ ClientDropReason::SENT_ABORT_MESSAGE, false);
+ ASSERT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
+ // Metrics are not logged
+ EXPECT_CALL(*metrics, ClientsDropped(_, _)).Times(0);
+ // Client is not notified
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+ ASSERT_THAT(state.AbortedClientIds().contains(2), Eq(true));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ AbortReturnsValidStateAndNotifiesClients) {
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(false);
+ abort_message.mutable_abort()->set_diagnostic_info("test abort reason");
+
+ EXPECT_CALL(*metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ auto next_state =
+ state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);
+
+ ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(next_state->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state->ErrorMessage().value(), Eq("test abort reason"));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(abort_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ StateProceedsCorrectlyWithNoAbortsAndAllCorrectMessagesReceived) {
+ // In this test, no clients abort or aborted at any point, and all four
+ // clients send unmasking responses to the server before ProceedToNextRound is
+ // called.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+
+ // No clients should actually get a message in this round.
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ // i is the number of messages received so far
+ for (int i = 0; i <= 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ StateProceedsCorrectlyWithMinimumCorrectMessagesReceived) {
+ // In this test, no clients abort or aborted at any point, but
+ // ProceedToNextRound is called after only 3 clients have submitted masked
+ // input responses. This is perfectly valid because the threshold is 3.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+
+ // Only client 3 should get a message this round.
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_early_success(true);
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not send unmasking response but protocol completed "
+ "successfully.");
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
+ EXPECT_CALL(*sender, Send(Ne(3), _)).Times(0);
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::EARLY_SUCCESS)));
+
+ // i is the number of messages received so far. Stop after 3
+ for (int i = 0; i <= 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(1));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, StateProceedsCorrectlyWithOneFailure) {
+ // In this test, no clients abort or aborted at any point, but client 0 sends
+ // an invalid message. It should be aborted, but the other 3 clients should be
+ // enough to proceed.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+ // Add an incorrect response.
+ unmasking_responses[0]
+ .mutable_unmasking_response()
+ ->mutable_noise_or_prf_key_shares(2)
+ ->set_noise_sk_share("This is the wrong type of share!");
+
+ // Only client 0 should get a message this round.
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not include the correct type of key share.");
+ abort_message.mutable_abort()->set_early_success(false);
+ EXPECT_CALL(*sender, Send(0, EqualsProto(abort_message))).Times(1);
+ EXPECT_CALL(*sender, Send(Ne(0), _)).Times(0);
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ EXPECT_THAT(state.HandleMessage(0, unmasking_responses[0]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(1));
+ EXPECT_THAT(state.AbortedClientIds().contains(0), IsTrue());
+
+ // i is the number of messages received so far.
+ for (int i = 1; i <= 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 4) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ StateProceedsCorrectlyWithAnAbortInRound2) {
+ // In this test, client 3 never sent a masked input, so clients should send
+ // the pairwise key share for client 3.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
+ impl->set_client_status(3, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+
+ SecAggServerR3UnmaskingState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_noise_sk_share(
+ absl::StrCat("Test key share for client ", 3, " from client ", i));
+ }
+
+ // No clients should actually get a message in this round.
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ // i is the number of messages received so far
+ for (int i = 0; i <= 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ StateProceedsCorrectlyWithAnAbortInRound1) {
+ // In this test, client 3 never even finished the key share round, so the
+ // other clients should send no key share for client 3.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get());
+ impl->set_client_status(3, ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED);
+
+ SecAggServerR3UnmaskingState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ }
+
+ // No clients should actually get a message in this round.
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ // i is the number of messages received so far
+ for (int i = 0; i <= 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ StateProceedsCorrectlyEvenIfClientsAbortAfterSendingMessage) {
+ // In this test, clients 0 and 1 send valid messages but then abort. But since
+ // they sent valid messages, the server should proceed regardless.
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+
+ // No clients should actually get a message in this round.
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort();
+
+ // i is the number of messages received so far
+ for (int i = 0; i <= 4; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 4) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+ // These should not change anything.
+ EXPECT_THAT(state.HandleMessage(0, abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(state.HandleMessage(1, abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(4));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, StateAbortsIfTooManyClientsAbort) {
+ // In this test, clients 0 and 1 send abort messages rather than valid
+ // unmasking responses, so the server must abort
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get()),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+
+ // No individual clients should get a message, but the server should broadcast
+ // an abort message
+ ServerToClientWrapperMessage server_abort_message;
+ server_abort_message.mutable_abort()->set_diagnostic_info(
+ "Too many clients aborted.");
+ server_abort_message.mutable_abort()->set_early_success(false);
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_abort_message)))
+ .Times(1);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+
+ ClientToServerWrapperMessage client_abort_message;
+ client_abort_message.mutable_abort();
+
+ ASSERT_THAT(state.HandleMessage(0, client_abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ ASSERT_THAT(state.HandleMessage(1, client_abort_message), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ EXPECT_THAT(state.NeedsToAbort(), IsTrue());
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
+ EXPECT_THAT(next_state.value()->ErrorMessage(), IsOk());
+ EXPECT_THAT(next_state.value()->ErrorMessage().value(),
+ Eq("Too many clients aborted."));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
+ ElementsAre(IsEvent<BroadcastMessageSent>(
+ Eq(ServerToClientMessageType_Abort),
+ Eq(server_abort_message.ByteSizeLong()))));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, MetricsRecordsMessageSizes) {
+ // In this test, client 3 never sent a masked input, so clients should send
+ // the pairwise key share for client 3.
+ TestTracingRecorder tracing_recorder;
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto impl = CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics);
+ impl->set_client_status(3, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+
+ SecAggServerR3UnmaskingState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 1, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(3);
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_noise_sk_share(
+ absl::StrCat("Test key share for client ", 3, " from client ", i));
+ }
+
+ // No clients should actually get a message in this round.
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+ EXPECT_CALL(
+ *metrics,
+ MessageReceivedSizes(Eq(ClientToServerWrapperMessage::MessageContentCase::
+ kUnmaskingResponse),
+ Eq(true), Eq(unmasking_responses[0].ByteSizeLong())))
+ .Times(3);
+
+ // i is the number of messages received so far
+ for (int i = 0; i <= 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(3 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ EXPECT_THAT(
+ tracing_recorder.root()[i],
+ IsEvent<ClientMessageReceived>(
+ Eq(ClientToServerMessageType_UnmaskingResponse),
+ Eq(unmasking_responses[i].ByteSizeLong()), Eq(true), Ge(0)));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
+ Eq(0));
+ EXPECT_THAT(
+ next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
+ Eq(1));
+ EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
+ Eq(0));
+}
+
+TEST(SecaggServerR3UnmaskingStateTest,
+ ServerAndClientAbortsAreRecordedCorrectly) {
+ // In this test clients abort for a variety of reasons, and then ultimately
+ // the server aborts. Metrics should record all of these events.
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto impl = CreateSecAggServerProtocolImpl(2, 8, sender.get(), metrics);
+ impl->ErasePublicKeysForClient(7);
+ impl->set_client_status(6, ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED);
+ impl->set_client_status(7, ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
+
+ SecAggServerR3UnmaskingState state(
+ std::move(impl),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 2, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED), _))
+ .Times(0);
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::INVALID_UNMASKING_RESPONSE)))
+ .Times(3);
+ EXPECT_CALL(
+ *metrics,
+ ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
+
+ ClientToServerWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
+
+ ClientToServerWrapperMessage valid_message; // from client 1
+ for (int j = 0; j < 6; ++j) {
+ NoiseOrPrfKeyShare* share = valid_message.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 1"));
+ }
+ NoiseOrPrfKeyShare* share =
+ valid_message.mutable_unmasking_response()->add_noise_or_prf_key_shares();
+ share->set_noise_sk_share(
+ absl::StrCat("Test key share for client ", 6, " from client 1"));
+ share =
+ valid_message.mutable_unmasking_response()->add_noise_or_prf_key_shares();
+
+ ClientToServerWrapperMessage invalid_noise_instead_of_prf; // from client 2
+ for (int j = 0; j < 5; ++j) {
+ share = invalid_noise_instead_of_prf.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 2"));
+ }
+ for (int j = 5; j < 7; ++j) { // client 5 should not be included here
+ share = invalid_noise_instead_of_prf.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_noise_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 2"));
+ }
+ share = invalid_noise_instead_of_prf.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+
+ ClientToServerWrapperMessage invalid_prf_instead_of_noise; // from client 3
+ for (int j = 0; j < 7; ++j) { // client 6 should not be included here
+ share = invalid_prf_instead_of_noise.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 3"));
+ }
+ share = invalid_prf_instead_of_noise.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+
+ ClientToServerWrapperMessage invalid_noise_instead_of_blank; // from client 4
+ for (int j = 0; j < 6; ++j) {
+ share = invalid_noise_instead_of_blank.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 4"));
+ }
+ for (int j = 6; j < 8; ++j) { // client 7 should not be included here
+ share = invalid_noise_instead_of_blank.mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_noise_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client 4"));
+ }
+
+ ClientToServerWrapperMessage wrong_message;
+ wrong_message.mutable_advertise_keys(); // wrong type of message
+
+ state.HandleMessage(0, abort_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(1, valid_message).IgnoreError();
+ state.HandleMessage(2, invalid_noise_instead_of_prf).IgnoreError();
+ state.HandleMessage(3, invalid_prf_instead_of_noise).IgnoreError();
+ state.HandleMessage(4, invalid_noise_instead_of_blank).IgnoreError();
+ state.HandleMessage(5, wrong_message).IgnoreError();
+ state.ProceedToNextRound().IgnoreError(); // causes server abort
+}
+
+TEST(SecaggServerR3UnmaskingStateTest, MetricsAreRecorded) {
+ // In this test, no clients abort or aborted at any point, but
+ // ProceedToNextRound is called after only 3 clients have submitted masked
+ // input responses. This is perfectly valid because the threshold is 3.
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+
+ SecAggServerR3UnmaskingState state(
+ CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics),
+ 0, // number_of_clients_failed_after_sending_masked_input
+ 0, // number_of_clients_failed_before_sending_masked_input
+ 0); // number_of_clients_terminated_without_unmasking
+
+ // Set up correct responses
+ std::vector<ClientToServerWrapperMessage> unmasking_responses(4);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ NoiseOrPrfKeyShare* share = unmasking_responses[i]
+ .mutable_unmasking_response()
+ ->add_noise_or_prf_key_shares();
+ share->set_prf_sk_share(
+ absl::StrCat("Test key share for client ", j, " from client ", i));
+ }
+ }
+
+ // Only client 3 should get a message this round.
+ ServerToClientWrapperMessage abort_message;
+ abort_message.mutable_abort()->set_diagnostic_info(
+ "Client did not send unmasking response but protocol completed "
+ "successfully.");
+ abort_message.mutable_abort()->set_early_success(true);
+ EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
+ EXPECT_CALL(*sender, Send(Ne(3), _)).Times(0);
+ EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
+
+ EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R3_UNMASKING),
+ Eq(true), Ge(0)));
+ EXPECT_CALL(*metrics, RoundSurvivingClients(
+ Eq(SecAggServerStateKind::R3_UNMASKING), Eq(3)));
+ EXPECT_CALL(*metrics, ClientResponseTimes(
+ Eq(ClientToServerWrapperMessage::
+ MessageContentCase::kUnmaskingResponse),
+ Ge(0)))
+ .Times(3);
+ EXPECT_CALL(*metrics,
+ ClientsDropped(
+ Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
+ Eq(ClientDropReason::EARLY_SUCCESS)));
+
+ // i is the number of messages received so far. Stop after 3
+ for (int i = 0; i <= 3; ++i) {
+ EXPECT_THAT(state.NeedsToAbort(), IsFalse());
+ EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
+ EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
+ EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
+ if (i < 3) {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
+ EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
+ } else {
+ EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
+ EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
+ }
+ if (i < 3) {
+ ASSERT_THAT(state.HandleMessage(i, unmasking_responses[i]), IsOk());
+ EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
+ }
+ }
+
+ auto next_state = state.ProceedToNextRound();
+ ASSERT_THAT(next_state, IsOk());
+ EXPECT_THAT(next_state.value()->State(),
+ Eq(SecAggServerStateKind::PRNG_RUNNING));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_state.cc b/fcp/secagg/server/secagg_server_state.cc
new file mode 100644
index 0000000..8a6445e
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_state.cc
@@ -0,0 +1,320 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server_state.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_set.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_aborted_state.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_trace_utility.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+namespace secagg {
+
+SecAggServerState::SecAggServerState(
+ int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind state_kind,
+ std::unique_ptr<SecAggServerProtocolImpl> impl)
+ : needs_to_abort_(false),
+ number_of_clients_failed_after_sending_masked_input_(
+ number_of_clients_failed_after_sending_masked_input),
+ number_of_clients_failed_before_sending_masked_input_(
+ number_of_clients_failed_before_sending_masked_input),
+ number_of_clients_ready_for_next_round_(0),
+ number_of_clients_terminated_without_unmasking_(
+ number_of_clients_terminated_without_unmasking),
+ number_of_messages_received_in_this_round_(0),
+ round_start_(absl::Now()),
+ state_kind_(state_kind),
+ impl_(std::move(impl)) {}
+
+SecAggServerState::~SecAggServerState() {}
+
+std::unique_ptr<SecAggServerProtocolImpl>&& SecAggServerState::ExitState(
+ StateTransition state_transition_status) {
+ bool record_success = state_transition_status == StateTransition::kSuccess;
+ auto elapsed_time = absl::ToInt64Milliseconds(absl::Now() - round_start_);
+ if (metrics()) {
+ metrics()->RoundTimes(state_kind_, record_success, elapsed_time);
+ metrics()->RoundSurvivingClients(state_kind_, NumberOfAliveClients());
+
+ // Fractions of clients by state
+ absl::flat_hash_map<ClientStatus, int> counts_by_state;
+ for (uint32_t i = 0; i < total_number_of_clients(); i++) {
+ counts_by_state[client_status(i)]++;
+ }
+ for (const auto& count_by_state : counts_by_state) {
+ double fraction = static_cast<double>(count_by_state.second) /
+ total_number_of_clients();
+ Trace<ClientCountsPerState>(TracingState(state_kind_),
+ ClientStatusType(count_by_state.first),
+ count_by_state.second, fraction);
+ metrics()->RoundCompletionFractions(state_kind_, count_by_state.first,
+ fraction);
+ }
+ }
+ Trace<StateCompletion>(TracingState(state_kind_), record_success,
+ elapsed_time, NumberOfAliveClients());
+ return std::move(impl_);
+}
+
+// These methods return default values unless overridden.
+bool SecAggServerState::IsAborted() const { return false; }
+bool SecAggServerState::IsCompletedSuccessfully() const { return false; }
+int SecAggServerState::NumberOfPendingClients() const { return 0; }
+int SecAggServerState::NumberOfIncludedInputs() const { return 0; }
+int SecAggServerState::MinimumMessagesNeededForNextRound() const { return 0; }
+bool SecAggServerState::ReadyForNextRound() const { return false; }
+
+Status SecAggServerState::HandleMessage(
+ uint32_t client_id, const ClientToServerWrapperMessage& message) {
+ MessageReceived(message, false);
+ if (message.message_content_case() ==
+ ClientToServerWrapperMessage::MESSAGE_CONTENT_NOT_SET) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Server received a message of unknown type from client "
+ << client_id << " but was in state " << StateName();
+ } else {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Server received a message of type "
+ << message.message_content_case() << " from client " << client_id
+ << " but was in state " << StateName();
+ }
+}
+
+Status SecAggServerState::HandleMessage(
+ uint32_t client_id, std::unique_ptr<ClientToServerWrapperMessage> message) {
+ return HandleMessage(client_id, *message);
+}
+
+StatusOr<std::unique_ptr<SecAggServerState>>
+SecAggServerState::ProceedToNextRound() {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "The server cannot proceed to next round from state "
+ << StateName();
+}
+
+bool SecAggServerState::IsClientDead(uint32_t client_id) const {
+ switch (client_status(client_id)) {
+ case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
+ case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
+ case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
+ case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
+ case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
+ return true;
+ break;
+ default:
+ return false;
+ }
+}
+
+void SecAggServerState::AbortClient(uint32_t client_id,
+ const std::string& reason,
+ ClientDropReason reason_code, bool notify,
+ bool log_metrics) {
+ FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
+
+ if (IsClientDead(client_id)) {
+ return; // without sending a message
+ }
+
+ HandleAbortClient(client_id, reason_code);
+ if (notify) {
+ ServerToClientWrapperMessage message;
+ message.mutable_abort()->set_diagnostic_info(reason);
+ message.mutable_abort()->set_early_success(reason_code ==
+ ClientDropReason::EARLY_SUCCESS);
+ Send(client_id, message);
+ }
+ // Clients that have successfully completed the protocol should not be logging
+ // metrics.
+ if (metrics() && log_metrics &&
+ client_status(client_id) !=
+ ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED) {
+ metrics()->ClientsDropped(client_status(client_id), reason_code);
+ }
+ auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
+ Trace<ClientsDropped>(ClientStatusType(client_status(client_id)),
+ ClientDropReasonType(reason_code), elapsed_millis,
+ reason);
+}
+
+std::unique_ptr<SecAggServerState> SecAggServerState::AbortState(
+ const std::string& reason, SecAggServerOutcome outcome) {
+ if (metrics()) {
+ metrics()->ProtocolOutcomes(outcome);
+ }
+ Trace<SecAggProtocolOutcome>(ConvertSecAccServerOutcomeToTrace(outcome));
+ return std::make_unique<SecAggServerAbortedState>(
+ reason, ExitState(StateTransition::kAbort),
+ number_of_clients_failed_after_sending_masked_input_,
+ number_of_clients_failed_before_sending_masked_input_,
+ number_of_clients_terminated_without_unmasking_);
+}
+
+std::unique_ptr<SecAggServerState> SecAggServerState::Abort(
+ const std::string& reason, SecAggServerOutcome outcome) {
+ FCP_CHECK(!(IsAborted() || IsCompletedSuccessfully()));
+
+ HandleAbort();
+
+ ServerToClientWrapperMessage message;
+ message.mutable_abort()->set_early_success(false);
+ message.mutable_abort()->set_diagnostic_info(reason);
+ SendBroadcast(message);
+
+ return AbortState(reason, outcome);
+}
+
+StatusOr<std::string> SecAggServerState::ErrorMessage() const {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Error message requested, but server is in state " << StateName();
+}
+
+int SecAggServerState::NumberOfAliveClients() const {
+ return total_number_of_clients() -
+ number_of_clients_failed_before_sending_masked_input_ -
+ number_of_clients_failed_after_sending_masked_input_ -
+ number_of_clients_terminated_without_unmasking_;
+}
+
+int SecAggServerState::NumberOfMessagesReceivedInThisRound() const {
+ return number_of_messages_received_in_this_round_;
+}
+
+int SecAggServerState::NumberOfClientsReadyForNextRound() const {
+ return number_of_clients_ready_for_next_round_;
+}
+
+int SecAggServerState::NumberOfClientsFailedAfterSendingMaskedInput() const {
+ return number_of_clients_failed_after_sending_masked_input_;
+}
+
+int SecAggServerState::NumberOfClientsFailedBeforeSendingMaskedInput() const {
+ return number_of_clients_failed_before_sending_masked_input_;
+}
+
+int SecAggServerState::NumberOfClientsTerminatedWithoutUnmasking() const {
+ return number_of_clients_terminated_without_unmasking_;
+}
+
+bool SecAggServerState::NeedsToAbort() const { return needs_to_abort_; }
+
+absl::flat_hash_set<uint32_t> SecAggServerState::AbortedClientIds() const {
+ auto aborted_client_ids_ = absl::flat_hash_set<uint32_t>();
+ for (int i = 0; i < total_number_of_clients(); ++i) {
+ // Clients that have successfully completed the protocol are not reported
+ // as aborted.
+ if (IsClientDead(i)) {
+ aborted_client_ids_.insert(i);
+ }
+ }
+ return aborted_client_ids_;
+}
+
+bool SecAggServerState::SetAsyncCallback(std::function<void()> async_callback) {
+ return false;
+}
+
+StatusOr<std::unique_ptr<SecAggVectorMap>> SecAggServerState::Result() {
+ return FCP_STATUS(UNAVAILABLE)
+ << "Result requested, but server is in state " << StateName();
+}
+
+SecAggServerStateKind SecAggServerState::State() const { return state_kind_; }
+
+std::string SecAggServerState::StateName() const {
+ switch (state_kind_) {
+ case SecAggServerStateKind::ABORTED:
+ return "Aborted";
+ case SecAggServerStateKind::COMPLETED:
+ return "Completed";
+ case SecAggServerStateKind::PRNG_RUNNING:
+ return "PrngRunning";
+ case SecAggServerStateKind::R0_ADVERTISE_KEYS:
+ return "R0AdvertiseKeys";
+ case SecAggServerStateKind::R1_SHARE_KEYS:
+ return "R1ShareKeys";
+ case SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION:
+ return "R2MaskedInputCollection";
+ case SecAggServerStateKind::R3_UNMASKING:
+ return "R3Unmasking";
+ default:
+ return "Unknown";
+ }
+}
+
+void SecAggServerState::MessageReceived(
+ const ClientToServerWrapperMessage& message, bool expected) {
+ auto elapsed_millis = absl::ToInt64Milliseconds(absl::Now() - round_start_);
+ if (metrics()) {
+ if (expected) {
+ metrics()->ClientResponseTimes(message.message_content_case(),
+ elapsed_millis);
+ }
+ metrics()->MessageReceivedSizes(message.message_content_case(), expected,
+ message.ByteSizeLong());
+ }
+ Trace<ClientMessageReceived>(GetClientToServerMessageType(message),
+ message.ByteSizeLong(), expected,
+ elapsed_millis);
+}
+
+void SecAggServerState::SendBroadcast(
+ const ServerToClientWrapperMessage& message) {
+ FCP_CHECK(message.message_content_case() !=
+ ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
+ if (metrics()) {
+ metrics()->BroadcastMessageSizes(message.message_content_case(),
+ message.ByteSizeLong());
+ }
+ sender()->SendBroadcast(message);
+ Trace<BroadcastMessageSent>(GetServerToClientMessageType(message),
+ message.ByteSizeLong());
+}
+
+void SecAggServerState::Send(uint32_t recipient_id,
+ const ServerToClientWrapperMessage& message) {
+ FCP_CHECK(message.message_content_case() !=
+ ServerToClientWrapperMessage::MESSAGE_CONTENT_NOT_SET);
+ if (metrics()) {
+ metrics()->IndividualMessageSizes(message.message_content_case(),
+ message.ByteSizeLong());
+ }
+ sender()->Send(recipient_id, message);
+
+ Trace<IndividualMessageSent>(recipient_id,
+ GetServerToClientMessageType(message),
+ message.ByteSizeLong());
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_server_state.h b/fcp/secagg/server/secagg_server_state.h
new file mode 100644
index 0000000..5a00cee
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_state.h
@@ -0,0 +1,314 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
+#define FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/time/time.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_protocol_impl.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// This is an abstract class which is the parent of the other SecAggServer*State
+// classes. It should not be instantiated directly. Default versions of all the
+// methods declared here are provided for use by states which do not expect, and
+// therefore do not implement, those methods.
+
+class SecAggServerState {
+ public:
+ // Returns the number of clients selected to be in the cohort for this
+ // instance of Secure Aggregation.
+ inline const size_t total_number_of_clients() const {
+ return impl_->total_number_of_clients();
+ }
+
+ // Returns the number of neighbors of each client.
+ inline const int number_of_neighbors() const {
+ return impl_->number_of_neighbors();
+ }
+
+ // Returns the minimum number of neighbors of a client that must not drop-out
+ // for that client's contribution to be included in the sum. This corresponds
+ // to the threshold in the shamir secret sharing of self and pairwise masks.
+ inline const int minimum_surviving_neighbors_for_reconstruction() const {
+ return impl_->minimum_surviving_neighbors_for_reconstruction();
+ }
+
+ // Returns the index of client_id_2 in the list of neighbors of client_id_1,
+ // if present
+ inline const std::optional<int> GetNeighborIndex(int client_id_1,
+ int client_id_2) const {
+ return impl_->GetNeighborIndex(client_id_1, client_id_2);
+ }
+
+ // EnterState must be called just after transitioning to a state.
+ // States may use this to initialize their state or trigger work.
+ virtual void EnterState() {}
+
+ // Processes the received message in a way consistent with the current state.
+ //
+ // Returns OK status to indicate that the message has been handled
+ // successfully.
+ //
+ // Returns a FAILED_PRECONDITION status if the server is in a state from which
+ // it does not expect to receive any messages. In that case no reply will be
+ // sent.
+ virtual Status HandleMessage(uint32_t client_id,
+ const ClientToServerWrapperMessage& message);
+ // Analog of the above method, bu giving ownership of the message.
+ virtual Status HandleMessage(
+ uint32_t client_id,
+ std::unique_ptr<ClientToServerWrapperMessage> message);
+
+ // Proceeds to the next round, doing all necessary computation and sending
+ // messages to clients as appropriate. If the server is not yet ready to
+ // proceed, returns an UNAVAILABLE status.
+ //
+ // If the server is in a terminal state, returns a FAILED_PRECONDITION status.
+ //
+ // Otherwise, returns the new state. This may be an abort state if the server
+ // has aborted.
+ //
+ // If this method returns a new state (i.e. if the status is OK), then the old
+ // state is no longer valid and the new state must be considered the current
+ // state. If it returns a non-OK status, this method does not change the
+ // underlying state.
+ virtual StatusOr<std::unique_ptr<SecAggServerState>> ProceedToNextRound();
+
+ // Returns true if the client state is considered to be "dead" e.g. aborted or
+ // disconnected; otherwise returns false.
+ bool IsClientDead(uint32_t client_id) const;
+
+ // Abort the specified client for the given reason. If notify is true, send a
+ // notification message to the client. (If the client was already closed, no
+ // message will be sent).
+ //
+ // The reason code will be used for recording metrics if log_metrics is true,
+ // else no metrics are recorded. By default, metrics will always be logged.
+ void AbortClient(uint32_t client_id, const std::string& reason,
+ ClientDropReason reason_code, bool notify = true,
+ bool log_metrics = true);
+
+ // Aborts the protocol for the specified reason. Notifies all clients of
+ // the abort. Returns the new state.
+ // Calling this method on a terminal state isn't valid.
+ std::unique_ptr<SecAggServerState> Abort(const std::string& reason,
+ SecAggServerOutcome outcome);
+
+ // Returns true if the current state is Abort, false else.
+ virtual bool IsAborted() const;
+
+ // Returns true if the current state is ProtocolCompleted, false else.
+ virtual bool IsCompletedSuccessfully() const;
+
+ // Returns an error message explaining why the server aborted, if the current
+ // state is an abort state. If not returns an error Status with code
+ // FAILED_PRECONDITION.
+ virtual StatusOr<std::string> ErrorMessage() const;
+
+ // Returns an enum specifying the current state.
+ SecAggServerStateKind State() const;
+
+ // Returns the name of the current state in the form of a short string.
+ std::string StateName() const;
+
+ // Returns whether or not the server has received enough messages to be ready
+ // for the next phase of the protocol.
+ // In the PRNG Running state, it returns whether or not the PRNG has stopped
+ // running.
+ // Always false in a terminal state.
+ virtual bool ReadyForNextRound() const;
+
+ // Returns the number of valid messages received by clients this round.
+ int NumberOfMessagesReceivedInThisRound() const;
+
+ // Returns the number of clients that would still be alive if
+ // ProceedToNextRound were called immediately after. This value may be less
+ // than NumberOfMessagesReceivedInThisRound if a client fails after sending a
+ // message in this round.
+ // Note that this value is not guaranteed to be monotonically increasing, even
+ // within a round. Client failures can cause this value to decrease.
+ virtual int NumberOfClientsReadyForNextRound() const;
+
+ // Indicates the total number of clients that the server expects to receive a
+ // response from in this round (i.e. the ones that have not aborted).
+ // In the COMPLETED state, this returns the number of clients that survived to
+ // the final protocol message.
+ virtual int NumberOfAliveClients() const;
+
+ // Number of clients that failed before submitting their masked input. These
+ // clients' inputs won't be included in the aggregate value, even if the
+ // protocol succeeds.
+ int NumberOfClientsFailedBeforeSendingMaskedInput() const;
+
+ // Number of clients that failed after submitting their masked input. These
+ // clients' inputs will be included in the aggregate value, even though these
+ // clients did not complete the protocol.
+ int NumberOfClientsFailedAfterSendingMaskedInput() const;
+
+ // Number of clients that submitted a masked value, but didn't report their
+ // unmasking values fast enough to have them used in the final unmasking
+ // process. These clients' inputs will be included in the aggregate value.
+ int NumberOfClientsTerminatedWithoutUnmasking() const;
+
+ // Returns the number of live clients that have not yet submitted the expected
+ // response for the current round. In terminal states, this will be 0.
+ virtual int NumberOfPendingClients() const;
+
+ // Returns the number of inputs that will appear in the final sum, if the
+ // protocol completes.
+ // Once IsNumberOfIncludedInputsCommitted is true, this value will be fixed
+ // for the remainder of the protocol.
+ // This will be 0 if the server is aborted. This will also be 0 if the server
+ // is in an early state, prior to receiving masked inputs. It is incremented
+ // only when the server receives a masked input from a client.
+ virtual int NumberOfIncludedInputs() const;
+
+ // Whether the set of inputs that will be included in the final aggregation is
+ // fixed.
+ // If true, the value of NumberOfIncludedInputs will be fixed for the
+ // remainder of the protocol.
+ virtual bool IsNumberOfIncludedInputsCommitted() const = 0;
+
+ // Indicates the minimum number of valid messages needed to be able to
+ // successfully move to the next round.
+ // Note that this value is not guaranteed to be monotonically decreasing.
+ // Client failures can cause this value to increase.
+ // In terminal states, this returns 0.
+ virtual int MinimumMessagesNeededForNextRound() const;
+
+ // Returns the minimum threshold number of clients that need to send valid
+ // responses in order for the protocol to proceed from one round to the next.
+ inline const int minimum_number_of_clients_to_proceed() const {
+ return impl_->minimum_number_of_clients_to_proceed();
+ }
+
+ // Returns the set of clients that aborted the protocol. Can be used by the
+ // caller to close the relevant RPC connections or just start ignoring
+ // incoming messages from those clients for performance reasons.
+ absl::flat_hash_set<uint32_t> AbortedClientIds() const;
+
+ // Returns true if the server has determined that it needs to abort itself,
+ // If the server is in a terminal state, returns false.
+ bool NeedsToAbort() const;
+
+ // Sets up a callback to be triggered when any background asynchronous work
+ // has been done. The callback is guaranteed to invoked via the server's
+ // callback scheduler.
+ //
+ // Returns true if the state supports asynchronous processing and the callback
+ // has been setup successfully.
+ // Returns false if the state doesn't support asynchronous processing or if
+ // no further asynchronous processing is possible. The callback argument is
+ // ignored in this case.
+ virtual bool SetAsyncCallback(std::function<void()> async_callback);
+
+ // Transfers ownership of the result of the protocol to the caller. Requires
+ // the server to be in a completed state; returns UNAVAILABLE otherwise.
+ // Can be called only once; any consecutive calls result in an error.
+ virtual StatusOr<std::unique_ptr<SecAggVectorMap>> Result();
+
+ virtual ~SecAggServerState();
+
+ protected:
+ // SecAggServerState should never be instantiated directly.
+ SecAggServerState(int number_of_clients_failed_after_sending_masked_input,
+ int number_of_clients_failed_before_sending_masked_input,
+ int number_of_clients_terminated_without_unmasking,
+ SecAggServerStateKind state_kind,
+ std::unique_ptr<SecAggServerProtocolImpl> impl);
+
+ SecAggServerProtocolImpl* impl() { return impl_.get(); }
+
+ // Returns the callback interface for recording metrics.
+ inline SecAggServerMetricsListener* metrics() const {
+ return impl_->metrics();
+ }
+
+ // Returns the callback interface for sending protocol buffer messages to the
+ // client.
+ inline SendToClientsInterface* sender() const { return impl_->sender(); }
+
+ inline const ClientStatus& client_status(uint32_t client_id) const {
+ return impl_->client_status(client_id);
+ }
+
+ inline void set_client_status(uint32_t client_id, ClientStatus status) {
+ impl_->set_client_status(client_id, status);
+ }
+
+ // Records information about a message that was received from a client.
+ void MessageReceived(const ClientToServerWrapperMessage& message,
+ bool expected);
+
+ // Broadcasts the message and records metrics.
+ void SendBroadcast(const ServerToClientWrapperMessage& message);
+
+ // Sends the message to the given client and records metrics.
+ void Send(uint32_t recipient_id, const ServerToClientWrapperMessage& message);
+
+ // Returns an aborted version of the current state, storing the specified
+ // reason. Calling this method makes the current state unusable. The caller is
+ // responsible for sending any failure messages that need to be sent, and for
+ // doing so BEFORE calling this method.
+ // The SecAggServerOutcome outcome is used for recording metrics.
+ std::unique_ptr<SecAggServerState> AbortState(const std::string& reason,
+ SecAggServerOutcome outcome);
+
+ // ExitState must be called on the current state just before transitioning to
+ // a new state to record metrics and transfer out the shared state.
+ enum class StateTransition {
+ // Indicates a successful state transition to any state other than Aborted.
+ kSuccess = 0,
+ // Indicates transition to Aborted state.
+ kAbort = 1
+ };
+ std::unique_ptr<SecAggServerProtocolImpl>&& ExitState(
+ StateTransition state_transition_status);
+
+ bool needs_to_abort_;
+ int number_of_clients_failed_after_sending_masked_input_;
+ int number_of_clients_failed_before_sending_masked_input_;
+ int number_of_clients_ready_for_next_round_;
+ int number_of_clients_terminated_without_unmasking_;
+ int number_of_messages_received_in_this_round_;
+ absl::Time round_start_;
+ SecAggServerStateKind state_kind_;
+
+ private:
+ // Performs state specific action when a client is aborted.
+ virtual void HandleAbortClient(uint32_t client_id,
+ ClientDropReason reason_code) {}
+
+ // Performs state specific action when the server is aborted.
+ virtual void HandleAbort() {}
+
+ std::unique_ptr<SecAggServerProtocolImpl> impl_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_SERVER_STATE_H_
diff --git a/fcp/secagg/server/secagg_server_test.cc b/fcp/secagg/server/secagg_server_test.cc
new file mode 100644
index 0000000..834bf7e
--- /dev/null
+++ b/fcp/secagg/server/secagg_server_test.cc
@@ -0,0 +1,404 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secagg_server.h"
+
+#include <cstddef>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_state.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
+#include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
+#include "fcp/secagg/testing/server/test_secagg_experiments.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::_;
+using ::testing::Eq;
+
+std::unique_ptr<SecAggServer> CreateServer(
+ SendToClientsInterface* sender,
+ SecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener(),
+ std::unique_ptr<TestSecAggExperiment> experiments =
+ std::make_unique<TestSecAggExperiment>()) {
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
+ threat_model.set_adversarial_client_rate(.3);
+ threat_model.set_estimated_dropout_rate(.3);
+ std::unique_ptr<AesPrngFactory> prng_factory;
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto status_or_server = SecAggServer::Create(
+ 100, // minimum_number_of_clients_to_proceed
+ 1000, // total_number_of_clients
+ input_vector_specs, sender,
+ std::unique_ptr<SecAggServerMetricsListener>(metrics),
+ /*prng_runner=*/nullptr, std::move(experiments), threat_model);
+ EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status();
+ return std::move(status_or_server.value());
+}
+
+template <typename... M>
+auto TraceRecorderHas(const M&... matchers) {
+ return ElementsAre(AllOf(
+ IsSpan<CreateSecAggServer>(),
+ ElementsAre(
+ IsEvent<SubGraphServerParameters>(
+ 1000, // number_of_clients
+ 219, // degree
+ 116, // threshold
+ 700, // minimum_number_of_clients_to_proceed
+ false), // is_r2_async_aggregation_enabled
+ AllOf(IsSpan<SecureAggServerSession>(), ElementsAre(matchers...)))));
+}
+
+TEST(SecaggServerTest, ConstructedWithCorrectState) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+
+ EXPECT_THAT(server->IsAborted(), Eq(false));
+ EXPECT_THAT(server->NumberOfNeighbors(), Eq(219));
+ EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false));
+ EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
+ EXPECT_THAT(tracing_recorder.root(),
+ TraceRecorderHas(IsSpan<SecureAggServerState>(
+ SecAggServerTraceState_R0AdvertiseKeys)));
+}
+
+TEST(SecaggServerTest, FullgraphSecAggExperimentTakesEffect) {
+ // Tests FullgraphSecAggExperiment by instatiating
+ // a server under that experiment , and
+ // checking that it results in the expected number of neighbors for the given
+ // setting (1000 clients) and threat model (.3 dropout rate and .3 adversarial
+ // client rate).
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
+ threat_model.set_adversarial_client_rate(.3);
+ threat_model.set_estimated_dropout_rate(.3);
+ std::unique_ptr<AesPrngFactory> prng_factory;
+ std::vector<InputVectorSpecification> input_vector_specs;
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ std::set<std::string> experiment_names = {kFullgraphSecAggExperiment};
+ auto status_or_server = SecAggServer::Create(
+ 100, // minimum_number_of_clients_to_proceed
+ 1000, // total_number_of_clients
+ input_vector_specs, sender.get(),
+ std::unique_ptr<SecAggServerMetricsListener>(
+ new MockSecAggServerMetricsListener()),
+ /*prng_runner=*/nullptr,
+ std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
+ EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status();
+ EXPECT_THAT(status_or_server.value()->NumberOfNeighbors(), Eq(1000));
+ EXPECT_THAT(status_or_server.value()->IsAborted(), Eq(false));
+ EXPECT_THAT(status_or_server.value()->IsCompletedSuccessfully(), Eq(false));
+ EXPECT_THAT(status_or_server.value()->State(),
+ Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
+}
+
+TEST(SecaggServerTest, SubgraphSecAggResortsToFullGraphOnSmallCohorts) {
+ // Tests that a small number of clients for which subgraph-secagg does not
+ // have favorable parameters results in executiong the full-graph varian
+ SecureAggregationRequirements threat_model;
+ threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER);
+ threat_model.set_adversarial_client_rate(.45);
+ threat_model.set_estimated_dropout_rate(.45);
+ std::unique_ptr<AesPrngFactory> prng_factory;
+ std::vector<InputVectorSpecification> input_vector_specs;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
+ std::set<std::string> experiment_names = {};
+ auto status_or_server = SecAggServer::Create(
+ 5, // minimum_number_of_clients_to_proceed
+ 25, // total_number_of_clients
+ input_vector_specs, sender.get(),
+ std::unique_ptr<SecAggServerMetricsListener>(
+ new MockSecAggServerMetricsListener()),
+ /*prng_runner=*/nullptr,
+ std::make_unique<TestSecAggExperiment>(experiment_names), threat_model);
+ EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status();
+ EXPECT_THAT(status_or_server.value()->NumberOfNeighbors(), Eq(25));
+ EXPECT_THAT(
+ status_or_server.value()->MinimumSurvivingNeighborsForReconstruction(),
+ Eq(14));
+ EXPECT_THAT(status_or_server.value()->IsAborted(), Eq(false));
+ EXPECT_THAT(status_or_server.value()->IsCompletedSuccessfully(), Eq(false));
+ EXPECT_THAT(status_or_server.value()->State(),
+ Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS));
+}
+
+TEST(SecaggServerTest, AbortClientWithInvalidIdThrowsError) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+
+ EXPECT_THAT(
+ server->AbortClient(1001, ClientAbortReason::CONNECTION_DROPPED).code(),
+ Eq(FAILED_PRECONDITION));
+}
+
+TEST(SecaggServerTest, ReceiveMessageWithInvalidIdThrowsError) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+
+ ClientToServerWrapperMessage client_abort_message;
+ client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test.");
+ EXPECT_THAT(
+ server
+ ->ReceiveMessage(1001, std::make_unique<ClientToServerWrapperMessage>(
+ client_abort_message))
+ .status()
+ .code(),
+ Eq(FAILED_PRECONDITION));
+}
+
+TEST(SecaggServerTest, AbortCausesStateTransitionAndMessageToBeSent) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+
+ const ServerToClientWrapperMessage abort_message = PARSE_TEXT_PROTO(R"pb(
+ abort: {
+ early_success: false
+ diagnostic_info: "Abort upon external request."
+ })pb");
+
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ Status result = server->Abort();
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(server->IsAborted(), Eq(true));
+ EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(server->ErrorMessage().ok(), Eq(true));
+ EXPECT_THAT(server->ErrorMessage().value(),
+ Eq("Abort upon external request."));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(
+ AllOf(IsSpan<SecureAggServerState>(
+ SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAre(
+ IsSpan<AbortSecAggServer>("Abort upon external request."))),
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted)));
+}
+
+TEST(SecaggServerTest, AbortWithReasonCausesStateTransitionAndMessageToBeSent) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+
+ const ServerToClientWrapperMessage abort_message = PARSE_TEXT_PROTO(R"pb(
+ abort: {
+ early_success: false
+ diagnostic_info: "Abort upon external request for reason <Test reason.>."
+ })pb");
+
+ EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message)));
+ Status result =
+ server->Abort("Test reason.", SecAggServerOutcome::EXTERNAL_REQUEST);
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(server->IsAborted(), Eq(true));
+ EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED));
+ ASSERT_THAT(server->ErrorMessage().ok(), Eq(true));
+ EXPECT_THAT(server->ErrorMessage().value(),
+ Eq("Abort upon external request for reason <Test reason.>."));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(
+ AllOf(IsSpan<SecureAggServerState>(
+ SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAre(IsSpan<AbortSecAggServer>(
+ "Abort upon external request for reason <Test "
+ "reason.>."))),
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted)));
+}
+
+TEST(SecaggServerTest, AbortClientNotCheckedIn) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto server = CreateServer(sender.get(), metrics);
+
+ EXPECT_CALL(*metrics, ClientsDropped(
+ Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)))
+ .Times(0);
+ // Client is not notified
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+ Status result = server->AbortClient(2, ClientAbortReason::NOT_CHECKED_IN);
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(AllOf(
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAre(IsSpan<AbortSecAggClient>(2, "NOT_CHECKED_IN")))));
+}
+
+TEST(SecaggServerTest, AbortClientWhenConnectionDropped) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto server = CreateServer(sender.get(), metrics);
+
+ EXPECT_CALL(*metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::CONNECTION_CLOSED)));
+ // Client is not notified
+ EXPECT_CALL(*sender, Send(_, _)).Times(0);
+ Status result = server->AbortClient(2, ClientAbortReason::CONNECTION_DROPPED);
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(AllOf(
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAre(IsSpan<AbortSecAggClient>(2, "CONNECTION_DROPPED")))));
+}
+
+TEST(SecaggServerTest, AbortClientWhenInvalidMessageSent) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ MockSecAggServerMetricsListener* metrics =
+ new MockSecAggServerMetricsListener();
+ auto server = CreateServer(sender.get(), metrics);
+
+ const ServerToClientWrapperMessage message = PARSE_TEXT_PROTO(R"pb(
+ abort: {
+ early_success: false
+ diagnostic_info: "The protocol is closing client with ClientAbortReason <INVALID_MESSAGE>."
+ })pb");
+ EXPECT_CALL(*sender, Send(2, EqualsProto(message)));
+
+ EXPECT_CALL(
+ *metrics,
+ ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
+ Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT)));
+ Status result = server->AbortClient(2, ClientAbortReason::INVALID_MESSAGE);
+
+ EXPECT_THAT(result.code(), Eq(OK));
+ EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(AllOf(
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAre(IsSpan<AbortSecAggClient>(2, "INVALID_MESSAGE")))));
+}
+
+TEST(SecaggServerTest, ReceiveMessageCausesServerToAbortIfTooManyClientsAbort) {
+ // The actual behavior of the server upon receipt of messages is tested in the
+ // state class test files, but this tests the special behavior that the server
+ // should automatically transition to an abort state if it cannot continue.
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+ StatusOr<int> clients_needed = server->MinimumMessagesNeededForNextRound();
+ ASSERT_THAT(clients_needed.ok(), Eq(true));
+ int maximum_number_of_aborts =
+ server->NumberOfAliveClients() - clients_needed.value();
+ EcdhPregeneratedTestKeys ecdh_keys;
+ ClientToServerWrapperMessage client_abort_message;
+ client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test.");
+
+ // Receiving `maximum_number_of_aborts - 1` aborts should not cause the entire
+ // protocol to abort.
+ std::vector<Matcher<const TestTracingRecorder::SpanOrEvent&>> matchers;
+ for (int i = 0; i < maximum_number_of_aborts; ++i) {
+ StatusOr<bool> result = server->ReceiveMessage(
+ i,
+ std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
+ matchers.push_back(IsSpan<ReceiveSecAggMessage>(i));
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(result.value(), Eq(false));
+ EXPECT_THAT(server->IsAborted(), Eq(false));
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(AllOf(IsSpan<SecureAggServerState>(
+ SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAreArray(matchers))));
+ }
+ // Receiving `maximum_number_of_aborts` aborts means the protocol is ready to
+ // proceed to the aborted state, which is indicated by ReceiveMessage
+ // returning true.
+ StatusOr<bool> result = server->ReceiveMessage(
+ maximum_number_of_aborts,
+ std::make_unique<ClientToServerWrapperMessage>(client_abort_message));
+ matchers.push_back(IsSpan<ReceiveSecAggMessage>(maximum_number_of_aborts));
+ ASSERT_THAT(result.ok(), Eq(true));
+ EXPECT_THAT(result.value(), Eq(true));
+ // However the server is not aborted until ProceedToNextRound is called.
+ EXPECT_THAT(server->IsAborted(), Eq(false));
+
+ EXPECT_THAT(server->ProceedToNextRound(), IsOk());
+ matchers.push_back(IsSpan<ProceedToNextSecAggRound>());
+ EXPECT_THAT(server->IsAborted(), Eq(true));
+ EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED));
+
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ TraceRecorderHas(
+ AllOf(IsSpan<SecureAggServerState>(
+ SecAggServerTraceState_R0AdvertiseKeys),
+ ElementsAreArray(matchers)),
+ IsSpan<SecureAggServerState>(SecAggServerTraceState_Aborted)));
+}
+
+TEST(SecaggServerTest, VerifyErrorsInAbortedState) {
+ TestTracingRecorder tracing_recorder;
+ auto sender = std::make_unique<MockSendToClientsInterface>();
+ auto server = CreateServer(sender.get());
+ EXPECT_THAT(server->Abort(), IsOk());
+
+ EXPECT_THAT(
+ server->ReceiveMessage(1, std::make_unique<ClientToServerWrapperMessage>(
+ ClientToServerWrapperMessage{})),
+ IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(server->ProceedToNextRound(), IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(server->MinimumMessagesNeededForNextRound(),
+ IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(server->NumberOfMessagesReceivedInThisRound(),
+ IsCode(FAILED_PRECONDITION));
+ EXPECT_THAT(server->ReadyForNextRound(), IsCode(FAILED_PRECONDITION));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_trace_utility.cc b/fcp/secagg/server/secagg_trace_utility.cc
new file mode 100644
index 0000000..62e28d0
--- /dev/null
+++ b/fcp/secagg/server/secagg_trace_utility.cc
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/secagg/server/secagg_trace_utility.h"
+
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+TracingClientStatus ClientStatusType(ClientStatus client_status) {
+ switch (client_status) {
+ case (ClientStatus::READY_TO_START):
+ return TracingClientStatus_ReadyToStart;
+ case (ClientStatus::DEAD_BEFORE_SENDING_ANYTHING):
+ return TracingClientStatus_DeadBeforeSendingAnything;
+ case (ClientStatus::ADVERTISE_KEYS_RECEIVED):
+ return TracingClientStatus_AdvertiseKeysReceived;
+ case (ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED):
+ return TracingClientStatus_DeadAfterAdvertiseKeysReceived;
+ case (ClientStatus::SHARE_KEYS_RECEIVED):
+ return TracingClientStatus_ShareKeysReceived;
+ case (ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED):
+ return TracingClientStatus_DeadAfterShareKeysReceived;
+ case (ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED):
+ return TracingClientStatus_MaskedInputResponseReceived;
+ case (ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED):
+ return TracingClientStatus_DeadAfterMaskedInputResponseReceived;
+ case (ClientStatus::UNMASKING_RESPONSE_RECEIVED):
+ return TracingClientStatus_UnmaskingResponseReceived;
+ case (ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED):
+ return TracingClientStatus_DeadAfterUnmaskingResponseReceived;
+ default:
+ return TracingClientStatus_Unknown;
+ }
+}
+
+TracingClientDropReason ClientDropReasonType(ClientDropReason reason_code) {
+ switch (reason_code) {
+ case (ClientDropReason::SENT_ABORT_MESSAGE):
+ return TracingClientDropReason_SentAbortMessage;
+ case (ClientDropReason::UNEXPECTED_MESSAGE_TYPE):
+ return TracingClientDropReason_UnexpectedMessageType;
+ case (ClientDropReason::UNKNOWN_MESSAGE_TYPE):
+ return TracingClientDropReason_UnknownMessageType;
+ case (ClientDropReason::ADVERTISE_KEYS_UNEXPECTED):
+ return TracingClientDropReason_AdvertiseKeysUnexpected;
+ case (ClientDropReason::EMPTY_PUBLIC_KEY):
+ return TracingClientDropReason_EmptyPublicKey;
+ case (ClientDropReason::NO_ADVERTISE_KEYS):
+ return TracingClientDropReason_NoAdvertiseKeys;
+ case (ClientDropReason::SHARE_KEYS_UNEXPECTED):
+ return TracingClientDropReason_ShareKeysUnexpected;
+ case (ClientDropReason::WRONG_NUMBER_OF_KEY_SHARES):
+ return TracingClientDropReason_WrongNumberOfKeyShares;
+ case (ClientDropReason::MISSING_KEY_SHARE):
+ return TracingClientDropReason_MissingKeyShare;
+ case (ClientDropReason::EXTRA_KEY_SHARE):
+ return TracingClientDropReason_ExtraKeyShare;
+ case (ClientDropReason::NO_SHARE_KEYS):
+ return TracingClientDropReason_NoShareKeys;
+ case (ClientDropReason::MASKED_INPUT_UNEXPECTED):
+ return TracingClientDropReason_MaskedInputUnexpected;
+ case (ClientDropReason::INVALID_MASKED_INPUT):
+ return TracingClientDropReason_InvalidMaskedInput;
+ case (ClientDropReason::NO_MASKED_INPUT):
+ return TracingClientDropReason_NoMaskedInput;
+ case (ClientDropReason::UNMASKING_RESPONSE_UNEXPECTED):
+ return TracingClientDropReason_UnmaskingResponseUnexpected;
+ case (ClientDropReason::INVALID_UNMASKING_RESPONSE):
+ return TracingClientDropReason_InvalidUnmaskingResponse;
+ case (ClientDropReason::NO_UNMASKING_RESPONSE):
+ return TracingClientDropReason_NoUnmaskingResponse;
+ case (ClientDropReason::INVALID_PUBLIC_KEY):
+ return TracingClientDropReason_InvalidPublicKey;
+ case (ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT):
+ return TracingClientDropReason_ServerProtocolAbortClient;
+ case (ClientDropReason::EARLY_SUCCESS):
+ return TracingClientDropReason_EarlySuccess;
+ case (ClientDropReason::CONNECTION_CLOSED):
+ return TracingClientDropReason_ConnectionClosed;
+ default:
+ return TracingClientDropReason_Unknown;
+ }
+}
+
+ClientToServerMessageType GetClientToServerMessageType(
+ const ClientToServerWrapperMessage& message) {
+ switch (message.message_content_case()) {
+ case ClientToServerWrapperMessage::MESSAGE_CONTENT_NOT_SET:
+ return ClientToServerMessageType_MessageContentNotSet;
+ case ClientToServerWrapperMessage::kAbort:
+ return ClientToServerMessageType_Abort;
+ case ClientToServerWrapperMessage::kAdvertiseKeys:
+ return ClientToServerMessageType_AdvertiseKeys;
+ case ClientToServerWrapperMessage::kShareKeysResponse:
+ return ClientToServerMessageType_ShareKeysResponse;
+ case ClientToServerWrapperMessage::kMaskedInputResponse:
+ return ClientToServerMessageType_MaskedInputResponse;
+ case ClientToServerWrapperMessage::kUnmaskingResponse:
+ return ClientToServerMessageType_UnmaskingResponse;
+ }
+}
+
+ServerToClientMessageType GetServerToClientMessageType(
+ const ServerToClientWrapperMessage& message) {
+ switch (message.message_content_case()) {
+ case ServerToClientWrapperMessage::kAbort:
+ return ServerToClientMessageType_Abort;
+ case ServerToClientWrapperMessage::kShareKeysRequest:
+ return ServerToClientMessageType_ShareKeysRequest;
+ case ServerToClientWrapperMessage::kMaskedInputRequest:
+ return ServerToClientMessageType_MaskedInputRequest;
+ case ServerToClientWrapperMessage::kUnmaskingRequest:
+ return ServerToClientMessageType_UnmaskingRequest;
+ default:
+ return ServerToClientMessageType_MessageContentNotSet;
+ }
+}
+
+TracingSecAggServerOutcome ConvertSecAccServerOutcomeToTrace(
+ SecAggServerOutcome outcome) {
+ switch (outcome) {
+ case (SecAggServerOutcome::EXTERNAL_REQUEST):
+ return TracingSecAggServerOutcome_ExternalRequest;
+ case (SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING):
+ return TracingSecAggServerOutcome_NotEnoughClientsRemaining;
+ case (SecAggServerOutcome::UNHANDLED_ERROR):
+ return TracingSecAggServerOutcome_UnhandledError;
+ case (SecAggServerOutcome::SUCCESS):
+ return TracingSecAggServerOutcome_Success;
+ default:
+ return TracingSecAggServerOutcome_Unknown;
+ }
+}
+
+SecAggServerTraceState TracingState(SecAggServerStateKind state_kind) {
+ switch (state_kind) {
+ case SecAggServerStateKind::ABORTED:
+ return SecAggServerTraceState_Aborted;
+ case SecAggServerStateKind::COMPLETED:
+ return SecAggServerTraceState_Completed;
+ case SecAggServerStateKind::PRNG_RUNNING:
+ return SecAggServerTraceState_PrngRunning;
+ case SecAggServerStateKind::R0_ADVERTISE_KEYS:
+ return SecAggServerTraceState_R0AdvertiseKeys;
+ case SecAggServerStateKind::R1_SHARE_KEYS:
+ return SecAggServerTraceState_R1ShareKeys;
+ case SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION:
+ return SecAggServerTraceState_R2MaskedInputCollection;
+ case SecAggServerStateKind::R3_UNMASKING:
+ return SecAggServerTraceState_R3Unmasking;
+ default:
+ return SecAggServerTraceState_UnknownState;
+ }
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secagg_trace_utility.h b/fcp/secagg/server/secagg_trace_utility.h
new file mode 100644
index 0000000..3a2638d
--- /dev/null
+++ b/fcp/secagg/server/secagg_trace_utility.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_SECAGG_SERVER_SECAGG_TRACE_UTILITY_H_
+#define FCP_SECAGG_SERVER_SECAGG_TRACE_UTILITY_H_
+
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/tracing_schema.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+// Returns the ClientStatus state to be used for the context of tracing
+TracingClientStatus ClientStatusType(ClientStatus client_status);
+
+// Returns the ClientDropReason state to be used for the context of tracing
+TracingClientDropReason ClientDropReasonType(ClientDropReason reason_code);
+
+// Returns the ClientToServerWrapperMessage state
+// to be used for the context of tracing
+ClientToServerMessageType GetClientToServerMessageType(
+ const ClientToServerWrapperMessage& message);
+
+// Returns the ClientToServerWrapperMessage state
+// to be used for the context of tracing
+ServerToClientMessageType GetServerToClientMessageType(
+ const ServerToClientWrapperMessage& message);
+
+// Returns the SecAggServerOutcome state
+// to be used for the context of tracing
+TracingSecAggServerOutcome ConvertSecAccServerOutcomeToTrace(
+ SecAggServerOutcome outcome);
+
+// Returns the SecAggServerStateKind state
+// to be used for the context of tracing
+SecAggServerTraceState TracingState(SecAggServerStateKind state_kind);
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECAGG_TRACE_UTILITY_H_
diff --git a/fcp/secagg/server/secret_sharing_complete_graph.h b/fcp/secagg/server/secret_sharing_complete_graph.h
new file mode 100644
index 0000000..a633d4d
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_complete_graph.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECRET_SHARING_COMPLETE_GRAPH_H_
+#define FCP_SECAGG_SERVER_SECRET_SHARING_COMPLETE_GRAPH_H_
+
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+
+namespace fcp {
+namespace secagg {
+
+// SecretSharingGraph built from a complete (directed) graph.
+// For example, in a SecretSharingCompleteGraph with 4
+// nodes the list of neighbors of each node are:
+// 0 -> 0, 1, 2, 3
+// 1 -> 0, 1, 2, 3
+// 2 -> 0, 1, 2, 3
+// 3 -> 0, 1, 2, 3
+// Thus, the (single) outgoing neighbor of 0 is 0.
+// The outgoing neighbors of 1 are 0, 1.
+// The outgoing neighbors of 2 are 0, 1, 2.
+// The outgoing neighbors of 3 are 0, 1, 2, 3.
+
+// SecretSharingCompleteGraph must be instantiated via
+// SecretSharingGraphFactory.
+class SecretSharingCompleteGraph : public SecretSharingGraph {
+ public:
+ SecretSharingCompleteGraph(const SecretSharingCompleteGraph&) = delete;
+ SecretSharingCompleteGraph& operator=(const SecretSharingCompleteGraph&) =
+ delete;
+ ~SecretSharingCompleteGraph() override = default;
+
+ int GetNumNodes() const override { return num_nodes_; }
+
+ int GetDegree() const override {
+ // All nodes have degree num_nodes.
+ return num_nodes_;
+ }
+
+ int GetThreshold() const override { return threshold_; }
+
+ int GetNeighbor(int curr_node, int i) const override {
+ FCP_CHECK(IsValidNode(curr_node));
+ FCP_CHECK(IsValidNode(i)); // i must be in [0, num_nodes)
+ // Each node has all other nodes as a neighbor, including itself.
+ return i;
+ }
+
+ std::optional<int> GetNeighborIndex(int node_1, int node_2) const override {
+ // Lists of neighbors are sorted by client id
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ return node_2;
+ }
+
+ bool AreNeighbors(int node_1, int node_2) const override {
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ return true;
+ }
+
+ bool IsOutgoingNeighbor(int node_1, int node_2) const override {
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ return node_2 >= node_1;
+ }
+
+ private:
+ // Number of nodes in the graph, with indices [0, num_nodes).
+ int num_nodes_;
+ int threshold_;
+ explicit SecretSharingCompleteGraph(int num_nodes, int threshold)
+ : num_nodes_(num_nodes), threshold_(threshold) {}
+ friend class SecretSharingGraphFactory;
+
+ bool IsValidNode(int node) const { return 0 <= node && node < num_nodes_; }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECRET_SHARING_COMPLETE_GRAPH_H_
diff --git a/fcp/secagg/server/secret_sharing_complete_graph_test.cc b/fcp/secagg/server/secret_sharing_complete_graph_test.cc
new file mode 100644
index 0000000..2498656
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_complete_graph_test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+
+#include "absl/status/status.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+static constexpr int kNumNodes = 10;
+static constexpr int kThreshold = 5;
+
+TEST(SecretSharingCompleteGraphTest, GetNumNodes) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ EXPECT_EQ(graph->GetNumNodes(), kNumNodes);
+}
+
+TEST(SecretSharingCompleteGraphTest, GetDegree) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ EXPECT_EQ(graph->GetDegree(), kNumNodes);
+}
+
+TEST(SecretSharingCompleteGraphTest, GetThreshold_Valid) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ EXPECT_EQ(graph->GetThreshold(), kThreshold);
+}
+
+TEST(SecretSharingCompleteGraphTest, Threshold_OutOfRange) {
+ SecretSharingGraphFactory factory;
+ EXPECT_DEATH(factory.CreateCompleteGraph(kNumNodes, kNumNodes + 1), "");
+}
+
+TEST(SecretSharingCompleteGraphTest, GetNeighbor_Valid) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ for (int i = 0; i < graph->GetDegree(); i++) {
+ EXPECT_EQ(graph->GetNeighbor(0, i), i);
+ }
+}
+
+TEST(SecretSharingCompleteGraphTest, GetNeighbor_OutOfRange) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ EXPECT_DEATH(graph->GetNeighbor(0, -1), "");
+ EXPECT_DEATH(graph->GetNeighbor(0, kNumNodes), "");
+}
+
+TEST(SecretSharingCompleteGraphTest, AreNeighbors_Valid) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ for (int i = 0; i < graph->GetDegree(); i++) {
+ EXPECT_TRUE(graph->AreNeighbors(0, i));
+ }
+}
+
+TEST(SecretSharingCompleteGraphTest, AreNeighbors_OutOfRange) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ EXPECT_DEATH(graph->AreNeighbors(0, -1), "");
+ EXPECT_DEATH(graph->AreNeighbors(0, kNumNodes), "");
+}
+
+TEST(SecretSharingCompleteGraphTest, GetNeighborIndex) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ for (int i = 0; i < graph->GetNumNodes(); i++) {
+ for (int j = 0; j < graph->GetDegree(); j++) {
+ EXPECT_EQ(graph->GetNeighborIndex(i, j), j);
+ }
+ }
+}
+
+TEST(SecretSharingCompleteGraphTest, IsOutgoingNeighbor) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateCompleteGraph(kNumNodes, kThreshold);
+ for (int i = 0; i < graph->GetNumNodes(); i++) {
+ for (int j = 0; j < graph->GetDegree(); j++) {
+ EXPECT_EQ(graph->IsOutgoingNeighbor(i, j), i <= j);
+ }
+ }
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secret_sharing_graph.h b/fcp/secagg/server/secret_sharing_graph.h
new file mode 100644
index 0000000..0b7d4ab
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_graph.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_H_
+#define FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_H_
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace secagg {
+
+// Abstract class representing a regular directed graph.
+// Nodes are integers in [0..NumNodes - 1]. For each node i (representing client
+// with id i), the graph specifies the set of other nodes with which i shares
+// its keys and computes pairwise masks. The graph is directed, and the
+// neighbors of each node are ordered (i.e. we have a notion of the 1st neighbor
+// of client i, the second neighbor, etc).
+
+// The intuitive way to visualize the graph is by means of a complete mapping
+// between nodes and the *ordered* lists of their neighbors (of length the
+// degree k):
+
+// 0 - > 1st_neighbor_of_0, ..., kth_neighbor_of_0
+// 1 - > 1st_neighbor_of_1, ..., kth_neighbor_of_1
+// ...
+// n-1 - > 1st_neighbor_of_n-1, ..., kth_neighbor_of_n-1
+
+// For every node i, its list of neighbors includes i, because as mentioned
+// above *each node has a self loop*.
+
+// The direction of each edge adjacent to node i is given implicitly by the
+// order of the neighbors of i: nodes occurring in the list of neighbors of i
+// *strictly* after i itself are called *outgoing* neighbors, and nodes ocurring
+// before i (including i itself) are called *incoming* neighbors.
+
+// The SecretSharingGraph class includes functions to (a) retrieve the index of
+// a neighbor of a node in the list of neighbors , (b) retrieve the neighbor of
+// a node at a given index, and (c) check if a nodes are neighbors, and of which
+// kind (i.e. incoming vs outgoing).
+
+// There are multiple subclasses of SecretSharingGraph. The complete graph
+// variant implemented as the SecretSharingCompleteGraph subclass, and the
+// (random) Harary graph variant implemented as the SecretSharingCompleteGraph
+// subclass.
+class SecretSharingGraph {
+ public:
+ virtual ~SecretSharingGraph() = default;
+
+ // Returns the number of nodes in the graph.
+ virtual int GetNumNodes() const = 0;
+
+ // Returns the degree of the graph.
+ virtual int GetDegree() const = 0;
+
+ // Returns the threshold of the secret sharing
+ virtual int GetThreshold() const = 0;
+
+ // Returns curr_node's ith neighbor.
+ // This function assumes that 0 <= i < GetDegree() and will throw a runtime
+ // error if that's not the case
+ virtual int GetNeighbor(int curr_node, int i) const = 0;
+
+ // Returns the index of node_2 in the list of neighbors of node_1, if present
+ virtual std::optional<int> GetNeighborIndex(int node_1, int node_2) const = 0;
+
+ // Returns true if node_1 and node_2 are neighbors, else false.
+ virtual bool AreNeighbors(int node_1, int node_2) const = 0;
+
+ // Returns true if node_1 is an outgoing neighbor of node_2, else false.
+ virtual bool IsOutgoingNeighbor(int node_1, int node_2) const = 0;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_H_
diff --git a/fcp/secagg/server/secret_sharing_graph_factory.h b/fcp/secagg/server/secret_sharing_graph_factory.h
new file mode 100644
index 0000000..64963c6
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_graph_factory.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_FACTORY_H_
+#define FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_FACTORY_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secret_sharing_complete_graph.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/secret_sharing_harary_graph.h"
+#include "fcp/secagg/server/ssl_bit_gen.h"
+
+namespace fcp {
+namespace secagg {
+
+// Factory class that constructs non-copyable instances of children classes of
+// SecretSharingGraph.
+class SecretSharingGraphFactory {
+ public:
+ // Creates a SecretSharingCompleteGraph.
+ static std::unique_ptr<SecretSharingCompleteGraph> CreateCompleteGraph(
+ int num_nodes, int threshold) {
+ FCP_CHECK(num_nodes >= 1)
+ << "num_nodes must be >= 1, given value was " << num_nodes;
+ FCP_CHECK(threshold >= 1)
+ << "threshold must be >= 1, given value was " << threshold;
+ FCP_CHECK(threshold <= num_nodes)
+ << "threshold must be <= num_nodes, given values were " << threshold
+ << ", " << num_nodes;
+ return absl::WrapUnique(
+ new SecretSharingCompleteGraph(num_nodes, threshold));
+ }
+
+ // Creates a SecretSharingHararyGraph.
+ static std::unique_ptr<SecretSharingHararyGraph> CreateHararyGraph(
+ int num_nodes, int degree, int threshold, bool is_random = true) {
+ FCP_CHECK(num_nodes >= 1)
+ << "num_nodes must be >= 1, given value was " << num_nodes;
+ FCP_CHECK(degree <= num_nodes)
+ << "degree must be <= num_nodes, given values were " << num_nodes
+ << ", " << degree;
+ FCP_CHECK(degree % 2 == 1)
+ << "degree must be odd, given value was " << degree;
+ FCP_CHECK(threshold >= 1)
+ << "threshold must be >= 1, given value was " << threshold;
+ FCP_CHECK(threshold <= degree)
+ << "threshold must be <= degree, given values were " << threshold
+ << ", " << degree;
+ auto permutation = std::vector<int>(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ permutation[i] = i;
+ }
+ if (is_random) {
+ std::shuffle(permutation.begin(), permutation.end(), SslBitGen());
+ }
+ return absl::WrapUnique(new SecretSharingHararyGraph(
+ degree, threshold, std::move(permutation)));
+ }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECRET_SHARING_GRAPH_FACTORY_H_
diff --git a/fcp/secagg/server/secret_sharing_harary_graph.cc b/fcp/secagg/server/secret_sharing_harary_graph.cc
new file mode 100644
index 0000000..d43849a
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_harary_graph.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/secret_sharing_harary_graph.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+
+namespace fcp {
+namespace secagg {
+
+SecretSharingHararyGraph::SecretSharingHararyGraph(int degree, int threshold,
+ std::vector<int> permutation)
+ : number_of_nodes_(static_cast<int>(permutation.size())),
+ degree_(degree),
+ threshold_(threshold),
+ permutation_(std::move(permutation)) {
+ inverse_permutation_ = std::vector<int>(number_of_nodes_);
+ for (int i = 0; i < number_of_nodes_; ++i) {
+ inverse_permutation_[permutation_[i]] = i;
+ }
+}
+
+int SecretSharingHararyGraph::GetNeighbor(int curr_node,
+ int neighbor_index) const {
+ FCP_CHECK(IsValidNode(curr_node));
+ FCP_CHECK(IsValidNeighborIndex(neighbor_index));
+
+ int curr_node_before_renaming = inverse_permutation_[curr_node];
+ int zeroth_neighbor_before_renaming = curr_node_before_renaming - degree_ / 2;
+ // We add number_of_nodes_ as a way to handle negative numbers
+ return permutation_[(zeroth_neighbor_before_renaming + neighbor_index +
+ number_of_nodes_) %
+ number_of_nodes_];
+}
+
+std::optional<int> SecretSharingHararyGraph::GetNeighborIndex(
+ int node_1, int node_2) const {
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ int sub_before_renaming =
+ std::abs(inverse_permutation_[node_1] - inverse_permutation_[node_2]);
+ // Compute distance between nodes before applying the permutation.
+ // node_1 and node_2 are connected iff, before renaming, they were at modular
+ // distance <= degree_ / 2
+ int mod_dist_before_renaming =
+ std::min(sub_before_renaming, number_of_nodes_ - sub_before_renaming);
+ if (mod_dist_before_renaming > degree_ / 2) {
+ return {};
+ }
+ // Check that node_2 occurs before node_1 in the list of neighbors of node_1,
+ // i.e. node_2 is an incoming neighbor of node_1
+ // We add number_of_nodes_ as a way to handle negative numbers
+ if ((inverse_permutation_[node_1] - mod_dist_before_renaming +
+ number_of_nodes_) %
+ number_of_nodes_ ==
+ inverse_permutation_[node_2]) {
+ return degree_ / 2 - mod_dist_before_renaming;
+ }
+ return degree_ / 2 + mod_dist_before_renaming;
+}
+
+bool SecretSharingHararyGraph::AreNeighbors(int node_1, int node_2) const {
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ return GetNeighborIndex(node_1, node_2).value_or(-1) >= 0;
+}
+
+bool SecretSharingHararyGraph::IsOutgoingNeighbor(int node_1,
+ int node_2) const {
+ FCP_CHECK(IsValidNode(node_1));
+ FCP_CHECK(IsValidNode(node_2));
+ return GetNeighborIndex(node_1, node_2).value_or(-1) >= degree_ / 2;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/secret_sharing_harary_graph.h b/fcp/secagg/server/secret_sharing_harary_graph.h
new file mode 100644
index 0000000..9e35495
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_harary_graph.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SECRET_SHARING_HARARY_GRAPH_H_
+#define FCP_SECAGG_SERVER_SECRET_SHARING_HARARY_GRAPH_H_
+
+#include <vector>
+
+#include "fcp/secagg/server/secret_sharing_graph.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents a regular undirected graph specifying, for each client,
+// among which other clients it shares its keys and computes pairwise masks.
+// The construction of the graph is randomized, by permuting node ids, which is
+// crucial for security. More concretely, the graph is a (degree-1,
+// num_nodes)-Harary graph with randomly permuted node labels and an additional
+// self-edge in each node. For simplicity of the construction we require that
+// the degree is odd.
+
+// A SecretSharingHararyGraph(num_nodes, degree) graph can be constructed by
+// putting all num_nodes nodes in a circle and, for each node: (i) adding
+// degree/2 edges to the immediately preceding nodes, (ii) adding degree/2 edges
+// to the immediately successive nodes, and (iii) adding a self edge. Finally,
+// nodes are given ids in [0..num_nodes - 1] uniformaly at random (or
+// equivalently one applies a permutation to the original ids 0...num_nodes -
+// 1).
+
+// For example, if degree = 5 and num_nodes = 10, the adjacency list of a
+// (num_nodes, degree)-SecretSharingHararyGraph (before permuting ids) is:
+//
+// 0 -> 8, 9, 0, 1, 2
+// 1 -> 9, 0, 1, 2, 3
+// 2 -> 0, 1, 2, 3, 4
+// 3 -> 1, 2, 3, 4, 5
+// 4 -> 2, 3, 4, 5, 6
+// 5 -> 3, 4, 5, 6, 7
+// 6 -> 4, 5, 6, 7, 8
+// 7 -> 5, 6, 7, 8, 9
+// 8 -> 6, 7, 8, 9, 0
+// 9 -> 7, 8, 9, 0, 1
+//
+//
+// SecretSharingHararyGraph additionally have permuted node ids (iff is_random
+// == true) according to a uniformly random permutation. For example, if that
+// permutation was (3, 2, 5, 4, 1, 8, 9, 0, 6, 7) then the resulting
+// SecretSharingHararyGraph is the result of applying the permutation (aka node
+// renaming) to the above adjacency list:
+
+// 3 -> 6, 7, 3, 2, 5
+// 2 -> 7, 3, 2, 5, 4
+// 5 -> 3, 2, 5, 4, 1
+// 4 -> 2, 5, 4, 1, 8
+// 1 -> 5, 4, 1, 8, 9
+// 8 -> 4, 1, 8, 9, 0
+// 9 -> 1, 8, 9, 0, 6
+// 0 -> 8, 9, 0, 6, 7
+// 6 -> 9, 0, 6, 7, 3
+// 7 -> 0, 6, 7, 3, 1
+
+// Thus, the outgoing neighbors of 3 are 6, 7, 3, The outgoing neighbors of 2
+// are 7, 3, 2, and so on.
+
+// Although the above example aludes to an adjacency list based representation
+// of the graph, this is only for clarity, as this is not stored explicitly.
+// Instead, storing the random permutation (that is (3, 2, 5, 4, 1, 8, 9, 0, 6,
+// 7) in the above example) and its inverse (which is (7, 4, 1, 0, 3, 2,
+// 8, 9, 5, 6)) leads to a more space efficient implementation with constant
+// time cost for all class functions.
+
+// This class must be instantiated through SecretSharingGraphFactory.
+class SecretSharingHararyGraph : public SecretSharingGraph {
+ public:
+ SecretSharingHararyGraph(const SecretSharingHararyGraph&) = delete;
+ SecretSharingHararyGraph& operator=(const SecretSharingHararyGraph&) = delete;
+ ~SecretSharingHararyGraph() override = default;
+
+ int GetNumNodes() const override { return number_of_nodes_; }
+
+ int GetDegree() const override { return degree_; }
+
+ int GetThreshold() const override { return threshold_; }
+
+ int GetNeighbor(int curr_node, int neighbor_index) const override;
+
+ std::optional<int> GetNeighborIndex(int node_1, int node_2) const override;
+
+ bool AreNeighbors(int node_1, int node_2) const override;
+
+ bool IsOutgoingNeighbor(int node_1, int node_2) const override;
+
+ // Returns the permutation that was applied to the nodes in the construction.
+ // This function is only used for testing purposes.
+ std::vector<int> GetPermutationForTesting() const { return permutation_; }
+
+ private:
+ int number_of_nodes_;
+ int degree_;
+ int threshold_;
+ // random permutation applied to the node ids in the SecretSharingHararyGraph
+ // construction.
+ const std::vector<int> permutation_;
+ // Inverse of the above permutation.
+ std::vector<int> inverse_permutation_;
+ SecretSharingHararyGraph(int degree, int threshold,
+ std::vector<int> permutation);
+ friend class SecretSharingGraphFactory;
+
+ bool IsValidNode(int node) const {
+ return 0 <= node && node < number_of_nodes_;
+ }
+
+ bool IsValidNeighborIndex(int index) const {
+ return 0 <= index && index < degree_;
+ }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SECRET_SHARING_HARARY_GRAPH_H_
diff --git a/fcp/secagg/server/secret_sharing_harary_graph_test.cc b/fcp/secagg/server/secret_sharing_harary_graph_test.cc
new file mode 100644
index 0000000..85fd32e
--- /dev/null
+++ b/fcp/secagg/server/secret_sharing_harary_graph_test.cc
@@ -0,0 +1,315 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "fcp/secagg/server/secret_sharing_graph.h"
+#include "fcp/secagg/server/secret_sharing_graph_factory.h"
+#include "fcp/testing/testing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+namespace secret_sharing_harary_graph_test_internal {
+// Auxiliary function returning the index of j in the list of neighbors of i,
+// in a graph g represented as an adjacency list
+std::optional<int> GetNeighborIndexFromAdjacencyList(
+ const std::vector<std::vector<int>>& g, int i, int j) {
+ auto index = std::find(std::begin(g[i]), std::end(g[i]), j);
+ if (index != std::end(g[i])) {
+ return *index;
+ }
+ return {};
+}
+} // namespace secret_sharing_harary_graph_test_internal
+
+static constexpr int kNumNodes = 10;
+static constexpr int kDegree = 5;
+static constexpr int kThreshold = 2;
+
+TEST(SecretSharingHararyGraphTest, GetPermutationReturnsPermutation) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold);
+ std::vector<int> permutation = graph->GetPermutationForTesting();
+ EXPECT_EQ(permutation.size(), kNumNodes);
+ std::vector<int> counters(kNumNodes, 0);
+ for (int i = 0; i < permutation.size(); ++i) {
+ counters[permutation[i]]++;
+ }
+ for (auto x : counters) {
+ EXPECT_EQ(x, 1);
+ }
+}
+
+TEST(SecretSharingHararyGraphTest,
+ GetPermutationInDeterministicHararyGraphReturnsIdentityPermutation) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold, false);
+ std::vector<int> permutation = graph->GetPermutationForTesting();
+ for (int i = 0; i < permutation.size(); ++i) {
+ EXPECT_EQ(permutation[i], i);
+ }
+}
+
+TEST(SecretSharingHararyGraphTest,
+ GetPermutationInRandomHararyGraphDoesNotReturnIdentityPermutation) {
+ SecretSharingGraphFactory factory;
+ // We use a larger number of nodes so that the probability of getting the
+ // identity permutation by change is negligible
+ int larger_num_nodes = 100;
+ auto graph = factory.CreateHararyGraph(larger_num_nodes, kDegree, kThreshold);
+ std::vector<int> permutation = graph->GetPermutationForTesting();
+ // Find j so that permutation[j] != j. This will be the case for most i's (all
+ // but one in expectation), but one is sufficient in this test.
+ bool found = false;
+ for (int i = 0; i < permutation.size(); ++i) {
+ found = found || (i != permutation[i]);
+ }
+ EXPECT_EQ(found, true);
+}
+
+TEST(SecretSharingHararyGraphTest, AreNeighborsIsCorrectInRandomHararyGraph) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold);
+
+ EXPECT_EQ(graph->GetNumNodes(), kNumNodes);
+ EXPECT_EQ(graph->GetDegree(), kDegree);
+ EXPECT_EQ(graph->GetThreshold(), kThreshold);
+
+ std::vector<int> p = graph->GetPermutationForTesting();
+ std::vector<std::vector<int>> adjacency_list(kNumNodes,
+ std::vector<int>(kDegree));
+ adjacency_list[p[0]] = {p[8], p[9], p[0], p[1], p[2]};
+ adjacency_list[p[1]] = {p[9], p[0], p[1], p[2], p[3]};
+ adjacency_list[p[2]] = {p[0], p[1], p[2], p[3], p[4]};
+ adjacency_list[p[3]] = {p[1], p[2], p[3], p[4], p[5]};
+ adjacency_list[p[4]] = {p[2], p[3], p[4], p[5], p[6]};
+ adjacency_list[p[5]] = {p[3], p[4], p[5], p[6], p[7]};
+ adjacency_list[p[6]] = {p[4], p[5], p[6], p[7], p[8]};
+ adjacency_list[p[7]] = {p[5], p[6], p[7], p[8], p[9]};
+ adjacency_list[p[8]] = {p[6], p[7], p[8], p[9], p[0]};
+ adjacency_list[p[9]] = {p[7], p[8], p[9], p[0], p[1]};
+
+ for (int i = 0; i < kNumNodes; ++i) {
+ for (int j = 0; j < kNumNodes; ++j) {
+ bool are_neighbors =
+ secret_sharing_harary_graph_test_internal::
+ GetNeighborIndexFromAdjacencyList(adjacency_list, i, j)
+ .value_or(-1) >= 0;
+ EXPECT_EQ(graph->AreNeighbors(i, j), are_neighbors) << i << "," << j;
+ }
+ }
+}
+
+TEST(SecretSharingHararyGraphTest,
+ AreNeighborsIsCorrectInDeterministicHararyGraph) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold, false);
+
+ EXPECT_EQ(graph->GetNumNodes(), kNumNodes);
+ EXPECT_EQ(graph->GetDegree(), kDegree);
+ EXPECT_EQ(graph->GetThreshold(), kThreshold);
+
+ std::vector<int> p = graph->GetPermutationForTesting();
+ std::vector<std::vector<int>> adjacency_list(kNumNodes,
+ std::vector<int>(kDegree));
+ adjacency_list[0] = {8, 9, 0, 1, 2};
+ adjacency_list[1] = {9, 0, 1, 2, 3};
+ adjacency_list[2] = {0, 1, 2, 3, 4};
+ adjacency_list[3] = {1, 2, 3, 4, 5};
+ adjacency_list[4] = {2, 3, 4, 5, 6};
+ adjacency_list[5] = {3, 4, 5, 6, 7};
+ adjacency_list[6] = {4, 5, 6, 7, 8};
+ adjacency_list[7] = {5, 6, 7, 8, 9};
+ adjacency_list[8] = {6, 7, 8, 9, 0};
+ adjacency_list[9] = {7, 8, 9, 0, 1};
+
+ for (int i = 0; i < kNumNodes; ++i) {
+ for (int j = 0; j < kNumNodes; ++j) {
+ bool are_neighbors =
+ secret_sharing_harary_graph_test_internal::
+ GetNeighborIndexFromAdjacencyList(adjacency_list, i, j)
+ .value_or(-1) >= 0;
+ EXPECT_EQ(graph->AreNeighbors(i, j), are_neighbors) << i << "," << j;
+ }
+ }
+}
+
+TEST(SecretSharingHararyGraphTest, GetNeighborsIsCorrectInRandomHararyGraph) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold);
+
+ EXPECT_EQ(graph->GetNumNodes(), kNumNodes);
+ EXPECT_EQ(graph->GetDegree(), kDegree);
+ EXPECT_EQ(graph->GetThreshold(), kThreshold);
+
+ std::vector<int> p = graph->GetPermutationForTesting();
+ std::vector<std::vector<int>> adjacency_list(kNumNodes,
+ std::vector<int>(kDegree));
+ adjacency_list[p[0]] = {p[8], p[9], p[0], p[1], p[2]};
+ adjacency_list[p[1]] = {p[9], p[0], p[1], p[2], p[3]};
+ adjacency_list[p[2]] = {p[0], p[1], p[2], p[3], p[4]};
+ adjacency_list[p[3]] = {p[1], p[2], p[3], p[4], p[5]};
+ adjacency_list[p[4]] = {p[2], p[3], p[4], p[5], p[6]};
+ adjacency_list[p[5]] = {p[3], p[4], p[5], p[6], p[7]};
+ adjacency_list[p[6]] = {p[4], p[5], p[6], p[7], p[8]};
+ adjacency_list[p[7]] = {p[5], p[6], p[7], p[8], p[9]};
+ adjacency_list[p[8]] = {p[6], p[7], p[8], p[9], p[0]};
+ adjacency_list[p[9]] = {p[7], p[8], p[9], p[0], p[1]};
+
+ for (int i = 0; i < kNumNodes; ++i) {
+ for (int j = 0; j < kDegree; ++j) {
+ auto x = graph->GetNeighbor(i, j);
+ EXPECT_EQ(adjacency_list[i][j], x);
+ }
+ }
+}
+
+TEST(SecretSharingHararyGraphTest,
+ GetNeighborsIsCorrectInDeterministicHararyGraph) {
+ SecretSharingGraphFactory factory;
+ auto graph = factory.CreateHararyGraph(kNumNodes, kDegree, kThreshold, false);
+
+ EXPECT_EQ(graph->GetNumNodes(), kNumNodes);
+ EXPECT_EQ(graph->GetDegree(), kDegree);
+ EXPECT_EQ(graph->GetThreshold(), kThreshold);
+
+ std::vector<int> p = graph->GetPermutationForTesting();
+ std::vector<std::vector<int>> adjacency_list(kNumNodes,
+ std::vector<int>(kDegree));
+ adjacency_list[0] = {8, 9, 0, 1, 2};
+ adjacency_list[1] = {9, 0, 1, 2, 3};
+ adjacency_list[2] = {0, 1, 2, 3, 4};
+ adjacency_list[3] = {1, 2, 3, 4, 5};
+ adjacency_list[4] = {2, 3, 4, 5, 6};
+ adjacency_list[5] = {3, 4, 5, 6, 7};
+ adjacency_list[6] = {4, 5, 6, 7, 8};
+ adjacency_list[7] = {5, 6, 7, 8, 9};
+ adjacency_list[8] = {6, 7, 8, 9, 0};
+ adjacency_list[9] = {7, 8, 9, 0, 1};
+
+ for (int i = 0; i < kNumNodes; ++i) {
+ for (int j = 0; j < kDegree; ++j) {
+ auto x = graph->GetNeighbor(i, j);
+ EXPECT_EQ(adjacency_list[i][j], x);
+ }
+ }
+}
+
+struct HararyGraphParams {
+ const std::string test_name;
+ const int kNumNodes;
+ const int kDegree;
+ const int kThreshold;
+};
+
+class SecretSharingHararyGraphParamTest_Valid
+ : public ::testing::TestWithParam<HararyGraphParams> {};
+
+TEST_P(SecretSharingHararyGraphParamTest_Valid,
+ GetNeighborIndexIsCorrectInHararyGraph) {
+ const HararyGraphParams& graph_params = GetParam();
+ SecretSharingGraphFactory factory;
+ std::unique_ptr<SecretSharingGraph> graph = factory.CreateHararyGraph(
+ graph_params.kNumNodes, graph_params.kDegree, graph_params.kThreshold);
+
+ EXPECT_EQ(graph->GetNumNodes(), graph_params.kNumNodes);
+ EXPECT_EQ(graph->GetDegree(), graph_params.kDegree);
+ EXPECT_EQ(graph->GetThreshold(), graph_params.kThreshold);
+
+ for (int i = 0; i < graph_params.kNumNodes; ++i) {
+ for (int j = 0; j < graph_params.kDegree; ++j) {
+ auto x = graph->GetNeighbor(i, j);
+ EXPECT_EQ(graph->GetNeighborIndex(i, x), j);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SecretSharingHararyGraphParamTests, SecretSharingHararyGraphParamTest_Valid,
+ ::testing::ValuesIn<HararyGraphParams>({
+ {"10_nodes__degree_1", 10, 1, 1},
+ {"10_nodes__degree_3", 10, 3, 2},
+ {"10_nodes__degree_5", 10, 5, 3},
+ {"100_nodes__degree_23", 100, 23, 10},
+ {"1000_nodes__degree_43", 1000, 43, 20},
+ {"10000_nodes__degree_300", 10000, 301, 100},
+ }),
+ [](const ::testing::TestParamInfo<
+ SecretSharingHararyGraphParamTest_Valid::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+class SecretSharingHararyGraphParamTest_InvalidDegree
+ : public ::testing::TestWithParam<HararyGraphParams> {};
+
+TEST_P(SecretSharingHararyGraphParamTest_InvalidDegree,
+ ConstructionFailsOnEvenDegree) {
+ const HararyGraphParams& graph_params = GetParam();
+ SecretSharingGraphFactory factory;
+ EXPECT_DEATH(
+ factory.CreateHararyGraph(graph_params.kNumNodes, graph_params.kDegree,
+ graph_params.kThreshold),
+ absl::StrCat("degree must be odd, given value was ",
+ graph_params.kDegree));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SecretSharingHararyGraphParamTests,
+ SecretSharingHararyGraphParamTest_InvalidDegree,
+ ::testing::ValuesIn<HararyGraphParams>({
+ {"10_nodes__degree_4", 10, 4, 2},
+ {"50_nodes__degree_20", 50, 20, 10},
+ }),
+ [](const ::testing::TestParamInfo<
+ SecretSharingHararyGraphParamTest_InvalidDegree::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+class SecretSharingHararyGraphParamTest_InvalidThreshold
+ : public ::testing::TestWithParam<HararyGraphParams> {};
+
+TEST_P(SecretSharingHararyGraphParamTest_InvalidThreshold,
+ ConstructionFailsOnThresholdOutOfObounds) {
+ const HararyGraphParams& graph_params = GetParam();
+ SecretSharingGraphFactory factory;
+ EXPECT_DEATH(
+ factory.CreateHararyGraph(graph_params.kNumNodes, graph_params.kDegree,
+ graph_params.kThreshold),
+ "");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SecretSharingHararyGraphParamTests,
+ SecretSharingHararyGraphParamTest_InvalidThreshold,
+ ::testing::ValuesIn<HararyGraphParams>({
+ {"10_nodes__degree_4_under", 10, 4, -1},
+ {"10_nodes__degree_4_over", 10, 4, 6},
+ {"50_nodes__degree_20_under", 50, 20, -1},
+ {"50_nodes__degree_20_over", 50, 20, 21},
+ }),
+ [](const ::testing::TestParamInfo<
+ SecretSharingHararyGraphParamTest_InvalidThreshold::ParamType>& info) {
+ return info.param.test_name;
+ });
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/send_to_clients_interface.h b/fcp/secagg/server/send_to_clients_interface.h
new file mode 100644
index 0000000..428a8a1
--- /dev/null
+++ b/fcp/secagg/server/send_to_clients_interface.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SEND_TO_CLIENTS_INTERFACE_H_
+#define FCP_SECAGG_SERVER_SEND_TO_CLIENTS_INTERFACE_H_
+
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// Used to provide a SecAggServer with a private and authenticated channel with
+// the clients, which can be used to send protocol buffer messages.
+
+class SendToClientsInterface {
+ public:
+ // Sends the message to all of the active clients involved in the current
+ // protocol session.
+ virtual void SendBroadcast(const ServerToClientWrapperMessage& message) = 0;
+
+ // Sends a message to a single client.
+ virtual void Send(uint32_t recipient_id,
+ const ServerToClientWrapperMessage& message) = 0;
+
+ virtual ~SendToClientsInterface() = default;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SEND_TO_CLIENTS_INTERFACE_H_
diff --git a/fcp/secagg/server/ssl_bit_gen.cc b/fcp/secagg/server/ssl_bit_gen.cc
new file mode 100644
index 0000000..f8d3459
--- /dev/null
+++ b/fcp/secagg/server/ssl_bit_gen.cc
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/server/ssl_bit_gen.h"
+
+#include <type_traits>
+
+// Thread-safety guarantees only apply in BoringSSL.
+#include "fcp/base/monitoring.h"
+#include "openssl/is_boringssl.h"
+#include "openssl/rand.h"
+
+namespace fcp {
+namespace secagg {
+
+SslBitGen::result_type SslBitGen::operator()() {
+ static_assert(std::is_same<uint8_t, unsigned char>::value,
+ "uint8_t being other than unsigned char isn't supported by "
+ "BoringSSL");
+ SslBitGen::result_type random_integer;
+ int success = RAND_bytes(reinterpret_cast<unsigned char*>(&random_integer),
+ sizeof(random_integer));
+ // RAND_bytes always returns 1 in BoringSSL
+ FCP_CHECK(success == 1);
+
+ return random_integer;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/server/ssl_bit_gen.h b/fcp/secagg/server/ssl_bit_gen.h
new file mode 100644
index 0000000..c0adf44
--- /dev/null
+++ b/fcp/secagg/server/ssl_bit_gen.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SERVER_SSL_BIT_GEN_H_
+#define FCP_SECAGG_SERVER_SSL_BIT_GEN_H_
+
+#include <cstdint>
+#include <limits>
+
+namespace fcp {
+namespace secagg {
+
+// A secure BitGen class (analogous to absl::BitGen) for use with absl random
+// APIs, which uses RAND_bytes as a source of randomness. This type satisfies
+// the UniformRandomBitGenerator (URBG) concept:
+// https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator
+//
+// For generating a large quantity of random bytes (e.g. a cryptographic key),
+// it is more appropriate to use RAND_bytes directly.
+//
+// Thread safety: SslBitGen is thread safe.
+//
+// SslBitGen construction is free, and instances don't need to be
+// reused. In addition, it's probably better to make it clear at the call site
+// when a SslBitGen is being used, as opposed to a different URBG. So
+// rather than storing the SslBitGen, if possible, prefer to create one
+// at the time it is needed:
+//
+// int x = absl::Uniform(SslBitGen(), 0, 100);
+//
+class SslBitGen {
+ public:
+ using result_type = uint64_t;
+
+ SslBitGen() = default;
+
+ // SslBitGen cannot be copied or moved. This allows uses of it to easily be
+ // replaced with a stateful UniformRandomBitGenerator.
+ SslBitGen(const SslBitGen&) = delete;
+ SslBitGen& operator=(const SslBitGen&) = delete;
+
+ bool operator==(const SslBitGen&) const = delete;
+ bool operator!=(const SslBitGen&) const = delete;
+
+ // Returns a random number from a CSPRNG.
+ result_type operator()();
+
+ static constexpr result_type min() {
+ return std::numeric_limits<result_type>::min();
+ }
+ static constexpr result_type max() {
+ return std::numeric_limits<result_type>::max();
+ }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SERVER_SSL_BIT_GEN_H_
diff --git a/fcp/secagg/server/tracing_schema.fbs b/fcp/secagg/server/tracing_schema.fbs
new file mode 100644
index 0000000..c57ef62
--- /dev/null
+++ b/fcp/secagg/server/tracing_schema.fbs
@@ -0,0 +1,248 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "fcp/tracing/tracing_schema_common.fbs";
+
+enum SecAggServerTraceState : short {
+ UnknownState,
+ R0AdvertiseKeys,
+ R1ShareKeys,
+ R2MaskedInputCollection,
+ R3Unmasking,
+ PrngRunning,
+ Completed,
+ Aborted,
+}
+
+enum ServerToClientMessageType : short {
+ Abort,
+ ShareKeysRequest,
+ MaskedInputRequest,
+ UnmaskingRequest,
+ MessageContentNotSet,
+}
+
+enum ClientToServerMessageType: short {
+ Abort,
+ AdvertiseKeys,
+ ShareKeysResponse,
+ MaskedInputResponse,
+ UnmaskingResponse,
+ MessageContentNotSet,
+ }
+
+enum TracingClientStatus : short {
+ ReadyToStart,
+ DeadBeforeSendingAnything,
+ AdvertiseKeysReceived,
+ DeadAfterAdvertiseKeysReceived,
+ ShareKeysReceived,
+ DeadAfterShareKeysReceived,
+ MaskedInputResponseReceived,
+ DeadAfterMaskedInputResponseReceived,
+ UnmaskingResponseReceived,
+ DeadAfterUnmaskingResponseReceived,
+ Unknown,
+}
+
+enum TracingClientDropReason : short {
+ SentAbortMessage,
+ UnexpectedMessageType,
+ UnknownMessageType,
+ AdvertiseKeysUnexpected,
+ EmptyPublicKey,
+ NoAdvertiseKeys,
+ ShareKeysUnexpected,
+ WrongNumberOfKeyShares,
+ MissingKeyShare,
+ ExtraKeyShare,
+ NoShareKeys,
+ MaskedInputUnexpected,
+ InvalidMaskedInput,
+ NoMaskedInput,
+ UnmaskingResponseUnexpected,
+ InvalidUnmaskingResponse,
+ NoUnmaskingResponse,
+ InvalidPublicKey,
+ ServerProtocolAbortClient,
+ EarlySuccess,
+ ConnectionClosed,
+ InvalidShareKeys,
+ Unknown,
+}
+
+enum TracingSecAggServerOutcome : short {
+ ExternalRequest,
+ NotEnoughClientsRemaining,
+ UnhandledError,
+ Success,
+ Unknown,
+}
+
+// Spans
+// Span that records the lifetime of SecAggServer i.e from starting the SecAgg
+// protocol to its termination.
+table SecureAggServerSession(tag: "SASS", span) {}
+
+// Span that records the lifetime of each state within SecAggServer.
+table SecureAggServerState(tag: "SAST", span) {
+ // Name of the current SecAggServerState.
+ name: SecAggServerTraceState;
+}
+
+// Span that records the duration of SecAggServer::Create method.
+table CreateSecAggServer(tag: "CSAS", span) {}
+
+// Metric that records the parameters of the complete graph SecAggServer
+// instance.
+table FullGraphServerParameters(tag: "FGSP") {
+ number_of_clients: uint64;
+ minimum_number_of_clients_to_proceed: uint64;
+ is_r2_async_aggregation_enabled: bool;
+}
+
+// Metric that records the parameters of the SubGraph SecAggServer instance.
+table SubGraphServerParameters(tag: "SGSP") {
+ number_of_clients: uint64;
+ degree: uint64;
+ threshold: uint64;
+ minimum_number_of_clients_to_proceed: uint64;
+ is_r2_async_aggregation_enabled: bool;
+}
+
+// Span that records the duration of an external abort call to the SecAggServer.
+table AbortSecAggServer(tag: "ABSR", span) {
+ // Reason why the server is being aborted.
+ reason: string;
+}
+
+// Span that records the duration of an external abort client call.
+table AbortSecAggClient(tag: "ABCL", span) {
+ // Client id that needs to be aborted.
+ client_id: uint32;
+ // Reason why the client is being aborted.
+ reason: string;
+}
+
+// Span that records the duration of an external ProceedToNextRound call.
+table ProceedToNextSecAggRound(tag: "PTNR", span) {}
+
+// Span that records the duration of an external ReceiveMessage call.
+table ReceiveSecAggMessage(tag: "RCMS", span) {
+ // Client id that sent the message.
+ client_id: uint32;
+}
+
+// Span that records the duration of an external StartPrng call.
+table StartPrngForSecAgg(tag: "STPR", span) {}
+
+// Metrics
+// Metric that records the message sent by the SecAgg server to an individual
+// user.
+table IndividualMessageSent(tag: "IMSG") {
+ // Client id of the client receiving this message.
+ client_id: uint32;
+ // Type of message such as abort etc.
+ message_type: ServerToClientMessageType;
+ // Size of the message in bytes.
+ size: uint64;
+}
+
+// Metric that records the message broadcasted by the SecAgg server.
+table BroadcastMessageSent(tag: "BMSG") {
+ // Type of message such as abort etc.
+ message_type: ServerToClientMessageType;
+ // Size of the message in bytes.
+ size: uint64;
+}
+
+// Metric that records the message received by the SecAgg server from a user.
+table ClientMessageReceived(tag: "CMSG") {
+ // Type of message such as abort etc.
+ message_type: ClientToServerMessageType;
+ // Size of the message in bytes.
+ size: uint64;
+ // True, if message was expected from the client, false otherwise.
+ expected: bool;
+ // Elapsed time since the round started. 0, if the message was not expected.
+ elapsed_millis: uint64;
+}
+
+table Round2AsyncWorkScheduled(tag: "R2WS") {}
+
+// Metric that records the event of a queue of round 2 client messages being
+// taken by an asynchronous task.
+table Round2MessageQueueTaken(tag: "R2MT") {
+ // Queue length
+ queue_length: uint64;
+}
+
+// Metric that records the time taken to execute the PRF expansion step.
+table PrngExpansion(tag: "PRNG") {
+ // Time taken to complete the step (in milliseconds).
+ elapsed_millis: uint64;
+}
+
+// Metric that records the time taken to reconstruct all users' keys from their
+// Shamir secret shares.
+table ShamirReconstruction(tag: "SHRC") {
+ // Time taken to complete the step (in milliseconds).
+ elapsed_millis: uint64;
+}
+
+// Metric that records details about client drops during an execution of the
+// SecAgg protocol.
+table ClientsDropped(tag: "CLDR") {
+ // Status of the client when it was aborted.
+ client_status: TracingClientStatus;
+ // Reason for abort.
+ reason: TracingClientDropReason;
+ // Elapsed time since the round started.
+ elapsed_millis: uint64;
+ // Optional error message for the client to be aborted.
+ message: string;
+}
+
+// Metric that records the outcome of the SecAgg protocol
+table SecAggProtocolOutcome(tag: "SAPO") {
+ // Outcome of the protocol e.g SUCCESS means the protocol ran through all
+ // phases and completed.
+ outcome: TracingSecAggServerOutcome;
+}
+
+// Metric that records details about each state of the SecAgg protocol.
+table StateCompletion(tag: "ROCP") {
+ // Current SecAggServerState that the protocol was running within.
+ state: SecAggServerTraceState;
+ // True if current state successfully transitioned to the next state, false
+ // otherwise.
+ is_success: bool;
+ // E2E time (in milliseconds) spent in current state, starting from
+ // transitioning to that state and including waiting for the client messages
+ // necessary to transition to a next state.
+ elapsed_millis: uint64;
+ // Number of clients at the end of current state.
+ number_of_surviving_clients: uint64;
+}
+
+table ClientCountsPerState(tag: "CLPS") {
+ // Current SecAggServerState that the protocol was running within.
+ state: SecAggServerTraceState;
+ // Client status.
+ client_status: TracingClientStatus;
+ // Number of clients corresponding to the status above.
+ count: uint64;
+ // Fraction of clients corresponding to client_status above.
+ fraction: double;
+}
diff --git a/fcp/secagg/shared/BUILD b/fcp/secagg/shared/BUILD
new file mode 100644
index 0000000..213ef41
--- /dev/null
+++ b/fcp/secagg/shared/BUILD
@@ -0,0 +1,255 @@
+# Description:
+# SecAgg components shared between client and server.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+proto_library(
+ name = "proto",
+ srcs = ["secagg_messages.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_protobuf//:any_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":proto",
+ ],
+)
+
+cc_library(
+ name = "shared",
+ srcs = [
+ "aes_ctr_prng.cc",
+ "aes_ctr_prng_factory.cc",
+ "aes_gcm_encryption.cc",
+ "aes_key.cc",
+ "compute_session_id.cc",
+ "crypto_rand_prng.cc",
+ "ecdh_key_agreement.cc",
+ "input_vector_specification.cc",
+ "map_of_masks.cc",
+ "secagg_vector.cc",
+ "shamir_secret_sharing.cc",
+ ],
+ hdrs = [
+ "aes_ctr_prng.h",
+ "aes_ctr_prng_factory.h",
+ "aes_gcm_encryption.h",
+ "aes_key.h",
+ "aes_prng_factory.h",
+ "async_abort.h",
+ "compute_session_id.h",
+ "crypto_rand_prng.h",
+ "ecdh_key_agreement.h",
+ "ecdh_keys.h",
+ "input_vector_specification.h",
+ "key.h",
+ "map_of_masks.h",
+ "math.h",
+ "prng.h",
+ "secagg_vector.h",
+ "shamir_secret_sharing.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":cc_proto",
+ "//fcp/base",
+ "@boringssl//:crypto",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/base:endian",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/numeric:int128",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "aes_gcm_encryption_test",
+ size = "small",
+ srcs = [
+ "aes_gcm_encryption_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "aes_key_test",
+ size = "small",
+ srcs = [
+ "aes_key_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "aes_prng_tests",
+ size = "small",
+ srcs = [
+ "aes_ctr_prng_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "compute_session_id_test",
+ size = "small",
+ srcs = [
+ "compute_session_id_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "ecdh_test",
+ size = "small",
+ srcs = ["ecdh_key_agreement_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "//fcp/secagg/testing:common_mocks",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "input_vector_specification_test",
+ size = "small",
+ srcs = [
+ "input_vector_specification_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "map_of_masks_test",
+ size = "small",
+ srcs = [
+ "map_of_masks_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "map_of_masks_bench",
+ size = "large",
+ srcs = [
+ "map_of_masks_bench.cc",
+ ],
+ copts = FCP_COPTS,
+ linkstatic = 1,
+ deps = [
+ ":shared",
+ "@com_google_absl//absl/numeric:bits",
+ "@com_google_absl//absl/strings",
+ "@com_google_benchmark//:benchmark_main",
+ ],
+)
+
+cc_test(
+ name = "math_test",
+ size = "small",
+ srcs = [
+ "math_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "secagg_vector_test",
+ size = "large",
+ srcs = [
+ "secagg_vector_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "secagg_vector_bench",
+ size = "large",
+ srcs = [
+ "secagg_vector_bench.cc",
+ ],
+ copts = FCP_COPTS,
+ linkstatic = 1,
+ deps = [
+ ":shared",
+ "@com_google_benchmark//:benchmark_main",
+ ],
+)
+
+cc_test(
+ name = "shamir_secret_sharing_test",
+ size = "small",
+ srcs = [
+ "shamir_secret_sharing_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "//fcp/secagg/testing:common_mocks",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_binary(
+ name = "add_maps_bench",
+ testonly = 1,
+ srcs = [
+ "add_maps_bench.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":shared",
+ "@com_google_benchmark//:benchmark_main",
+ ],
+)
diff --git a/fcp/secagg/shared/add_maps_bench.cc b/fcp/secagg/shared/add_maps_bench.cc
new file mode 100644
index 0000000..92ef209
--- /dev/null
+++ b/fcp/secagg/shared/add_maps_bench.cc
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+// Open-source version of benchmarking library
+#include "benchmark//benchmark.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+// Open-source version of benchmarking library
+using benchmark::internal::Benchmark;
+
+// This function produces varied pairs of {bit_width, size} for the benchmark.
+static void CustomArguments(Benchmark* b) {
+ constexpr int bit_widths[] = {8, 25, 38, 53,
+ absl::bit_width(SecAggVector::kMaxModulus - 1)};
+ for (int bit_width : bit_widths) {
+ for (int size = 32; size <= 32 * 1024 * 1024; size *= 32) {
+ b->ArgPair(bit_width, size);
+ }
+ }
+}
+
+std::unique_ptr<SecAggVectorMap> MakeMap(int64_t bit_width, int64_t size,
+ uint64_t start, uint64_t step) {
+ std::vector<uint64_t> vec;
+ vec.resize(size);
+
+ uint64_t modulus = 1ULL << bit_width;
+ uint64_t v = start;
+ for (int64_t i = 0; i < size; i++) {
+ vec[i] = v;
+ v = (v + step) % modulus;
+ }
+
+ auto map = std::make_unique<SecAggVectorMap>();
+ map->emplace("test", SecAggVector(vec, modulus));
+ return map;
+}
+
+void BM_AddMaps(benchmark::State& state) {
+ auto map_a = MakeMap(state.range(0), state.range(1), 1, 1);
+ auto map_b = MakeMap(state.range(0), state.range(1), 2, 3);
+ for (auto _ : state) {
+ auto map_sum = AddMaps(*map_a, *map_b);
+ benchmark::DoNotOptimize(map_sum);
+ }
+}
+
+BENCHMARK(BM_AddMaps)->Apply(CustomArguments);
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_ctr_prng.cc b/fcp/secagg/shared/aes_ctr_prng.cc
new file mode 100644
index 0000000..e24c701
--- /dev/null
+++ b/fcp/secagg/shared/aes_ctr_prng.cc
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/aes_ctr_prng.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/prng.h"
+#include "openssl/cipher.h"
+#include "openssl/evp.h"
+
+namespace fcp {
+namespace secagg {
+
+AesCtrPrng::AesCtrPrng(const AesKey& seed) {
+ uint8_t iv[kIvSize];
+ memset(iv, 0, kIvSize);
+ FCP_CHECK(ctx_ = EVP_CIPHER_CTX_new());
+
+ FCP_CHECK(1 == EVP_EncryptInit_ex(ctx_, EVP_aes_256_ctr(), nullptr,
+ seed.data(), iv));
+
+ // Initializing these to one past the end, in order to force a call to
+ // GenerateBytes on the first attempt to use each cache.
+ next_byte_pos_ = kCacheSize;
+ blocks_generated_ = 0;
+}
+
+AesCtrPrng::~AesCtrPrng() { EVP_CIPHER_CTX_free(ctx_); }
+
+void AesCtrPrng::GenerateBytes(uint8_t* cache, int cache_size) {
+ FCP_CHECK((cache_size % kBlockSize) == 0)
+ << "Number of bytes generated by AesCtrPrng must be a multiple of "
+ << kBlockSize;
+ FCP_CHECK(cache_size <= kCacheSize)
+ << "Requested number of bytes " << cache_size
+ << " exceeds maximum cache size " << kCacheSize;
+ FCP_CHECK(blocks_generated_ <= kMaxBlocks)
+ << "AesCtrPrng generated " << kMaxBlocks
+ << " blocks and needs a new seed.";
+ int bytes_written;
+ FCP_CHECK(
+ EVP_EncryptUpdate(ctx_, cache, &bytes_written, kAllZeroes, cache_size));
+ FCP_CHECK(bytes_written == cache_size);
+ blocks_generated_ += static_cast<size_t>(cache_size) / kBlockSize;
+}
+
+uint8_t AesCtrPrng::Rand8() {
+ if (next_byte_pos_ >= kCacheSize) {
+ GenerateBytes(cache_, kCacheSize);
+ next_byte_pos_ = 0;
+ }
+ // Return the next byte and then increment the position.
+ return cache_[next_byte_pos_++];
+}
+
+uint64_t AesCtrPrng::Rand64() {
+ uint64_t output = 0;
+ for (size_t i = 0; i < sizeof(uint64_t); ++i) {
+ output |= static_cast<uint64_t>(Rand8()) << 8 * i;
+ }
+ return output;
+}
+
+int AesCtrPrng::RandBuffer(uint8_t* buffer, int buffer_size) {
+ buffer_size = std::min(buffer_size, kCacheSize);
+ GenerateBytes(buffer, buffer_size);
+ return buffer_size;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_ctr_prng.h b/fcp/secagg/shared/aes_ctr_prng.h
new file mode 100644
index 0000000..6049f3c
--- /dev/null
+++ b/fcp/secagg/shared/aes_ctr_prng.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_AES_CTR_PRNG_H_
+#define FCP_SECAGG_SHARED_AES_CTR_PRNG_H_
+
+#include <cstdint>
+#include <string>
+
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/prng.h"
+#include "openssl/evp.h"
+
+namespace fcp {
+namespace secagg {
+
+// A cryptographically strong Deterministic Pseudorandom Number Generator based
+// on AES-CTR in OpenSSL. The seed must be supplied by the user.
+//
+// This code is used for the very specific purpose of generating *reproduceable*
+// numbers which appear pseudorandom to any party not in possession of the input
+// seed. DO NOT use AesCtrPrng in any situation where real randomness would be
+// useful, because it uses no entropy at all and will always produce the same
+// numbers if given the same seed.
+//
+// A single instance of AesCtrPrng can generate up to 2^36 bytes of
+// pseudorandom output. If more than 2^36 bytes of output are needed, multiple
+// instances of AesCtrPrng with different seeds should be used.
+//
+// This class is not thread-safe.
+
+class AesCtrPrng : public SecureBatchPrng {
+ public:
+ // This class should only be instantiated via AesCtrPrngFactory.
+ friend class AesCtrPrngFactory;
+
+ ~AesCtrPrng() override;
+
+ // Returns a new pseudorandom number of the specified size, generating new
+ // pseudorandom bytes as needed.
+ uint8_t Rand8() override;
+ uint64_t Rand64() override;
+
+ // Fills the provided buffer with pseudorandom bytes. Returns the number of
+ // bytes that has been generated, which can be smaller than the requested
+ // buffer_size if it exceeds the maximum buffer size returned by
+ // GetMaxBufferSize().
+ int RandBuffer(uint8_t* buffer, int buffer_size) override;
+
+ // Get the maximum size of a buffer that can be filled by RandBuffer() in a
+ // single call.
+ size_t GetMaxBufferSize() const override { return kCacheSize; }
+
+ private:
+ static constexpr size_t kIvSize = 16; // IV size, in bytes
+
+ // Constructs the PRNG with the given seed, and an IV of all zeroes.
+ // This is ONLY secure if the seed is never used more than once.
+ explicit AesCtrPrng(const AesKey& seed);
+
+ // Number of AES blocks in the cache.
+ // The number of blocks is optimized to make kCacheSize to be a multiple
+ // of any possible number of bytes in a SecAgg output (i.e. 1 to 8).
+ static constexpr size_t kBatchSize = 3 * 5 * 7;
+
+ // Block size, in bytes
+ static constexpr size_t kBlockSize = 16;
+
+ // Size of our cache, in bytes. We cache blocks to save leftover bytes.
+ static constexpr int kCacheSize = kBatchSize * kBlockSize;
+
+ // For security, we don't want to generate more than 2^32-1 blocks.
+ static constexpr size_t kMaxBlocks = 0xFFFFFFFF;
+
+ // Fills the selected cache with deterministic pseudorandomly generated bytes.
+ // After this, the associated next_byte_pos counter must be set to 0.
+ void GenerateBytes(uint8_t* cache, int cache_size);
+
+ // Cache used by both Rand8() and Rand64()
+ uint8_t cache_[kCacheSize];
+ size_t next_byte_pos_;
+
+ // This is used to generate bytes.
+ static constexpr uint8_t kAllZeroes[kCacheSize] = {0};
+
+ EVP_CIPHER_CTX* ctx_;
+ size_t blocks_generated_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_AES_CTR_PRNG_H_
diff --git a/fcp/secagg/shared/aes_ctr_prng_factory.cc b/fcp/secagg/shared/aes_ctr_prng_factory.cc
new file mode 100644
index 0000000..5d64280
--- /dev/null
+++ b/fcp/secagg/shared/aes_ctr_prng_factory.cc
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+
+#include <memory>
+#include "fcp/secagg/shared/aes_ctr_prng.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/prng.h"
+
+namespace fcp {
+namespace secagg {
+
+std::unique_ptr<SecurePrng> AesCtrPrngFactory::MakePrng(
+ const AesKey& key) const {
+ return std::unique_ptr<SecurePrng>(new AesCtrPrng(key));
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_ctr_prng_factory.h b/fcp/secagg/shared/aes_ctr_prng_factory.h
new file mode 100644
index 0000000..faf4ed9
--- /dev/null
+++ b/fcp/secagg/shared/aes_ctr_prng_factory.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_AES_CTR_PRNG_FACTORY_H_
+#define FCP_SECAGG_SHARED_AES_CTR_PRNG_FACTORY_H_
+
+#include <memory>
+
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/prng.h"
+
+namespace fcp {
+namespace secagg {
+
+// Factory for the OpenSSL-based AesCtrPrng.
+class AesCtrPrngFactory : public AesPrngFactory {
+ public:
+ AesCtrPrngFactory() = default;
+
+ // Creates and returns an instance of AesCtrPrng, given an AES key.
+ // For security reasons, the key MUST be suitable for immediate use in AES,
+ // i.e. it must not be a shared ECDH secret that has not yet been hashed.
+ std::unique_ptr<SecurePrng> MakePrng(const AesKey& key) const override;
+
+ // TODO(team): Remove this when transition to the batch mode of
+ // SecurePrng is fully done.
+ bool SupportsBatchMode() const override { return true; }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_AES_CTR_PRNG_FACTORY_H_
diff --git a/fcp/secagg/shared/aes_ctr_prng_test.cc b/fcp/secagg/shared/aes_ctr_prng_test.cc
new file mode 100644
index 0000000..5688d68
--- /dev/null
+++ b/fcp/secagg/shared/aes_ctr_prng_test.cc
@@ -0,0 +1,222 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/aes_ctr_prng.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Contains;
+using ::testing::Eq;
+using ::testing::Ne;
+using ::testing::Not;
+using ::testing::Pointwise;
+
+TEST(AesCtrPrngTest, Rand8ReturnsSameValuesGivenSameInputs) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng0 = factory.MakePrng(seed);
+ std::unique_ptr<SecurePrng> prng1 = factory.MakePrng(seed);
+ std::vector<uint8_t> output0;
+ std::vector<uint8_t> output1;
+ for (int i = 0; i < 16; ++i) {
+ output0.push_back(prng0->Rand8());
+ output1.push_back(prng1->Rand8());
+ }
+ EXPECT_THAT(output0, Eq(output1));
+}
+
+TEST(AesCtrPrngTest, Rand64ReturnsSameValuesGivenSameInputs) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng0 = factory.MakePrng(seed);
+ std::unique_ptr<SecurePrng> prng1 = factory.MakePrng(seed);
+ std::vector<uint64_t> output0;
+ std::vector<uint64_t> output1;
+ for (int i = 0; i < 16; ++i) {
+ output0.push_back(prng0->Rand64());
+ output1.push_back(prng1->Rand64());
+ }
+ EXPECT_THAT(output0, Eq(output1));
+}
+
+TEST(AesCtrPrngTest, MixedRandCallsReturnSameValuesGivenSameInputs) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng0 = factory.MakePrng(seed);
+ std::unique_ptr<SecurePrng> prng1 = factory.MakePrng(seed);
+ std::vector<uint64_t> output0;
+ std::vector<uint64_t> output1;
+ for (int i = 0; i < 5; ++i) {
+ output0.push_back(prng0->Rand8());
+ output1.push_back(prng1->Rand8());
+ }
+ for (int i = 0; i < 5; ++i) {
+ output0.push_back(prng0->Rand64());
+ output1.push_back(prng1->Rand64());
+ }
+ for (int i = 0; i < 10; ++i) {
+ output0.push_back(prng0->Rand8());
+ output1.push_back(prng1->Rand8());
+ }
+ EXPECT_THAT(output0, Eq(output1));
+}
+
+// While for random seeds or IVs there would be a very small chance of
+// duplication, these tests are not flaky because this PRNG is deterministic.
+TEST(AesCtrPrngTest, DifferentSeedsGenerateDifferentValues) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed1(seed_data);
+ memset(seed_data, '3', 32);
+ AesKey seed2(seed_data);
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng0 = factory.MakePrng(seed1);
+ std::unique_ptr<SecurePrng> prng1 = factory.MakePrng(seed2);
+ std::vector<uint64_t> output0;
+ std::vector<uint64_t> output1;
+ for (int i = 0; i < 16; ++i) {
+ output0.push_back(prng0->Rand64());
+ output1.push_back(prng1->Rand64());
+ }
+ // output0 differs from output1 at every point
+ EXPECT_THAT(output0, Pointwise(Ne(), output1));
+}
+
+TEST(AesCtrPrngTest, DoesntGenerateRepeatedValues) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng = factory.MakePrng(seed);
+ std::vector<uint64_t> output;
+ uint64_t val;
+ for (int i = 0; i < 16; ++i) {
+ val = prng->Rand64();
+ EXPECT_THAT(output, Not(Contains(val)));
+ output.push_back(val);
+ }
+}
+
+TEST(AesCtrPrngTest, GeneratesExpectedValues) {
+ uint8_t iv[16];
+ memset(iv, 0, sizeof(iv));
+
+ uint8_t seed_data[32];
+ memset(seed_data, '1', sizeof(seed_data));
+ AesKey seed(seed_data);
+
+ EVP_CIPHER_CTX* ctx;
+ ctx = EVP_CIPHER_CTX_new();
+ ASSERT_THAT(ctx, Ne(nullptr));
+
+ ASSERT_THAT(
+ EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), nullptr, seed_data, iv),
+ Eq(1));
+
+ const int kBlockSize = 16 * 32;
+
+ static constexpr uint8_t zeroes[kBlockSize] = {0};
+
+ // These are processed separately in the class
+ uint8_t expected_uint8_t[kBlockSize];
+ uint8_t expected_uint64_t[kBlockSize];
+
+ // Obtain the ciphertext incrementally to verify identical output of versions
+ // using a different block size.
+ int len;
+ for (auto i = 0; i < kBlockSize; i += 16) {
+ ASSERT_THAT(EVP_EncryptUpdate(ctx, &expected_uint8_t[i], &len, zeroes, 16),
+ Ne(0));
+ ASSERT_THAT(len, 16);
+ }
+ for (auto i = 0; i < kBlockSize; i += 16) {
+ ASSERT_THAT(EVP_EncryptUpdate(ctx, &expected_uint64_t[i], &len, zeroes, 16),
+ Ne(0));
+ ASSERT_THAT(len, 16);
+ }
+
+ AesCtrPrngFactory factory;
+ std::unique_ptr<SecurePrng> prng = factory.MakePrng(seed);
+
+ for (int i = 0; i < sizeof(expected_uint8_t); i++) {
+ EXPECT_THAT(prng->Rand8(), Eq(expected_uint8_t[i]));
+ }
+ for (int i = 0; i < sizeof(expected_uint64_t) / sizeof(uint64_t); i++) {
+ uint64_t value = 0;
+ for (int j = 0; j < sizeof(uint64_t); j++) {
+ value |=
+ static_cast<uint64_t>(expected_uint64_t[i * sizeof(uint64_t) + j])
+ << (8 * j);
+ }
+ EXPECT_THAT(prng->Rand64(), Eq(value));
+ }
+ EVP_CIPHER_CTX_free(ctx);
+}
+
+TEST(AesCtrPrngTest, RandBufferIsConsistentWithRand8) {
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+
+ AesCtrPrngFactory factory1;
+ AesCtrPrngFactory factory2;
+ std::unique_ptr<SecurePrng> prng1 = factory1.MakePrng(seed);
+ std::unique_ptr<SecurePrng> prng2 = factory2.MakePrng(seed);
+ auto batch_prng = static_cast<SecureBatchPrng*>(prng2.get());
+
+ constexpr int kSize = 16000;
+ std::vector<uint8_t> output1(kSize);
+ std::vector<uint8_t> output2(kSize);
+
+ // Fill output1 using Rand8
+ for (int i = 0; i < kSize; ++i) {
+ output1[i] = prng1->Rand8();
+ }
+
+ // Fill output2 using RandBuffer
+ int bytes_received = 0;
+ while (bytes_received < kSize) {
+ bytes_received += batch_prng->RandBuffer(output2.data() + bytes_received,
+ kSize - bytes_received);
+ }
+
+ // output1 and output2 should be the same.
+ EXPECT_THAT(output1, Eq(output2));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_gcm_encryption.cc b/fcp/secagg/shared/aes_gcm_encryption.cc
new file mode 100644
index 0000000..5c9f2af
--- /dev/null
+++ b/fcp/secagg/shared/aes_gcm_encryption.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+
+#include <cstdint>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/prng.h"
+#include "openssl/cipher.h"
+#include "openssl/evp.h"
+#include "openssl/rand.h"
+
+namespace fcp {
+namespace secagg {
+
+constexpr int kIvSize = 12;
+constexpr int kTagSize = 16;
+
+AesGcmEncryption::AesGcmEncryption() {}
+
+std::string AesGcmEncryption::Encrypt(const AesKey& key,
+ const std::string& plaintext) {
+ FCP_CHECK(key.size() != 0) << "Encrypt called with blank key.";
+ FCP_CHECK(key.size() == AesKey::kSize)
+ << "Encrypt called with key of " << key.size()
+ << " bytes, but 32 bytes are required.";
+ std::vector<uint8_t> ciphertext_buffer;
+ ciphertext_buffer.resize(kIvSize + plaintext.length() + kTagSize);
+ FCP_CHECK(RAND_bytes(ciphertext_buffer.data(), kIvSize));
+
+ // ScopedEVP_AEAD_CTX will automatically call EVP_AEAD_CTX_cleanup when going
+ // out of scope.
+ bssl::ScopedEVP_AEAD_CTX ctx;
+ FCP_CHECK(EVP_AEAD_CTX_init(ctx.get(), EVP_aead_aes_256_gcm(),
+ const_cast<uint8_t*>(key.data()), key.size(),
+ EVP_AEAD_DEFAULT_TAG_LENGTH, nullptr) == 1);
+ size_t len;
+ FCP_CHECK(EVP_AEAD_CTX_seal(
+ ctx.get(), ciphertext_buffer.data() + kIvSize, &len,
+ plaintext.size() + kTagSize, ciphertext_buffer.data(), kIvSize,
+ reinterpret_cast<const uint8_t*>(plaintext.c_str()),
+ plaintext.size(), nullptr, 0) == 1);
+ return std::string(ciphertext_buffer.begin(), ciphertext_buffer.end());
+}
+
+StatusOr<std::string> AesGcmEncryption::Decrypt(const AesKey& key,
+ const std::string& ciphertext) {
+ FCP_CHECK(key.size() != 0) << "Decrypt called with blank key.";
+ FCP_CHECK(key.size() == AesKey::kSize)
+ << "Decrypt called with key of " << key.size()
+ << " bytes, but 32 bytes are required.";
+ if (ciphertext.size() < kIvSize + kTagSize) {
+ return FCP_STATUS(DATA_LOSS) << "Ciphertext is too short.";
+ }
+ size_t len;
+ std::vector<uint8_t> plaintext_buffer;
+ plaintext_buffer.resize(ciphertext.size() - kIvSize - kTagSize);
+
+ // ScopedEVP_AEAD_CTX will automatically call EVP_AEAD_CTX_cleanup when going
+ // out of scope.
+ bssl::ScopedEVP_AEAD_CTX ctx;
+ FCP_CHECK(EVP_AEAD_CTX_init(ctx.get(), EVP_aead_aes_256_gcm(),
+ const_cast<uint8_t*>(key.data()), key.size(),
+ EVP_AEAD_DEFAULT_TAG_LENGTH, nullptr) == 1);
+ if (EVP_AEAD_CTX_open(
+ ctx.get(), plaintext_buffer.data(), &len, plaintext_buffer.size(),
+ reinterpret_cast<const uint8_t*>(ciphertext.data()), kIvSize,
+ reinterpret_cast<const uint8_t*>(ciphertext.data() + kIvSize),
+ ciphertext.size() - kIvSize, nullptr, 0) != 1) {
+ return FCP_STATUS(DATA_LOSS) << "Verification of ciphertext failed.";
+ }
+ return std::string(plaintext_buffer.begin(), plaintext_buffer.end());
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_gcm_encryption.h b/fcp/secagg/shared/aes_gcm_encryption.h
new file mode 100644
index 0000000..f5dc97d
--- /dev/null
+++ b/fcp/secagg/shared/aes_gcm_encryption.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_AES_GCM_ENCRYPTION_H_
+#define FCP_SECAGG_SHARED_AES_GCM_ENCRYPTION_H_
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "openssl/evp.h"
+
+namespace fcp {
+namespace secagg {
+
+// A class to handle encryption and decryption using AES-256-GCM.
+// This class is NOT thread-safe.
+class AesGcmEncryption {
+ public:
+ AesGcmEncryption();
+
+ // Encrypts the plaintext with the given key, using AES-256-GCM. Prepends an
+ // IV randomly generated with the given prng to the ciphertext, and appends
+ // the AES-GCM tag.
+ std::string Encrypt(const AesKey& key, const std::string& plaintext);
+
+ // Decrypts the plaintext with the given key, using AES-256-GCM. Expects the
+ // IV to be prepended to the ciphertext, and the tag to be appended. If the
+ // tag does not authenticate, returns a DATA_LOSS error status.
+ StatusOr<std::string> Decrypt(const AesKey& key,
+ const std::string& ciphertext);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_AES_GCM_ENCRYPTION_H_
diff --git a/fcp/secagg/shared/aes_gcm_encryption_test.cc b/fcp/secagg/shared/aes_gcm_encryption_test.cc
new file mode 100644
index 0000000..d504bde
--- /dev/null
+++ b/fcp/secagg/shared/aes_gcm_encryption_test.cc
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/aes_gcm_encryption.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/aes_key.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Ne;
+
+// For testing purposes, make an AesKey out of a string.
+AesKey MakeAesKey(const std::string& key) {
+ EXPECT_THAT(key.size(), Eq(AesKey::kSize));
+ return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()));
+}
+
+TEST(AesGcmEncryptionTest, EncryptionThenDecryptionWorks) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ StatusOr<std::string> plaintext = aes.Decrypt(key, ciphertext);
+ ASSERT_TRUE(plaintext.ok());
+ EXPECT_THAT(plaintext.value(), Eq(test_str));
+}
+
+TEST(AesGcmEncryptionTest, MultipleOperationsWithSameObjectWork) {
+ AesGcmEncryption aes;
+ AesKey key1 = MakeAesKey("Just some random 32 byte AES key");
+ AesKey key2 = MakeAesKey("A different 32-byte-long AES key");
+ std::string test_str1 = "This is a test. Should work on arbitrary strings.";
+ std::string test_str2 = "Another test string.";
+
+ std::string ciphertext1 = aes.Encrypt(key1, test_str1);
+ std::string ciphertext2 = aes.Encrypt(key2, test_str2);
+ StatusOr<std::string> plaintext1 = aes.Decrypt(key1, ciphertext1);
+ StatusOr<std::string> plaintext2 = aes.Decrypt(key2, ciphertext2);
+ ASSERT_TRUE(plaintext1.ok());
+ EXPECT_THAT(plaintext1.value(), Eq(test_str1));
+ ASSERT_TRUE(plaintext2.ok());
+ EXPECT_THAT(plaintext2.value(), Eq(test_str2));
+}
+
+TEST(AesGcmEncryptionTest, EncryptionsWithDifferentKeysAreDifferent) {
+ AesGcmEncryption aes;
+ AesKey key1 = MakeAesKey("Just some random 32 byte AES key");
+ AesKey key2 = MakeAesKey("A different 32-byte-long AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext1 = aes.Encrypt(key1, test_str);
+ std::string ciphertext2 = aes.Encrypt(key2, test_str);
+ EXPECT_THAT(ciphertext1, Ne(ciphertext2));
+ StatusOr<std::string> plaintext1 = aes.Decrypt(key1, ciphertext1);
+ StatusOr<std::string> plaintext2 = aes.Decrypt(key2, ciphertext2);
+ ASSERT_TRUE(plaintext1.ok());
+ EXPECT_THAT(plaintext1.value(), Eq(test_str));
+ ASSERT_TRUE(plaintext2.ok());
+ EXPECT_THAT(plaintext2.value(), Eq(test_str));
+}
+
+TEST(AesGcmEncryptionTest, VerificationFailsOnBadTag) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ ciphertext[ciphertext.size() - 1] = 'X';
+ StatusOr<std::string> plaintext = aes.Decrypt(key, ciphertext);
+ EXPECT_THAT(plaintext.ok(), Eq(false));
+}
+
+TEST(AesGcmEncryptionTest, VerificationFailsOnBadCiphertext) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ for (int i = 0; i < ciphertext.size(); i++) {
+ // modify every bit of the ciphertext
+ for (int j = 0; j < 8; j++) {
+ ciphertext[i] ^= (1 << j);
+
+ StatusOr<std::string> plaintext = aes.Decrypt(key, ciphertext);
+ EXPECT_THAT(plaintext.ok(), Eq(false));
+
+ // reset the ciphertext
+ ciphertext[i] ^= (1 << j);
+ }
+ }
+}
+
+TEST(AesGcmEncryptionTest, VerificationFailsOnWrongKey) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ AesKey key2 = MakeAesKey("A different 32-byte-long AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ StatusOr<std::string> plaintext = aes.Decrypt(key2, ciphertext);
+ EXPECT_THAT(plaintext.ok(), Eq(false));
+}
+
+TEST(AesGcmEncryptionTest, EncryptionDiesOnEmptyKey) {
+ AesGcmEncryption aes;
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ EXPECT_DEATH(aes.Encrypt(AesKey(), test_str),
+ "Encrypt called with blank key.");
+}
+
+TEST(AesGcmEncryptionTest, DecryptionDiesOnEmptyKey) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ EXPECT_DEATH(aes.Decrypt(AesKey(), ciphertext).IgnoreError(),
+ "Decrypt called with blank key.");
+}
+TEST(AesGcmEncryptionTest, EncryptionDiesOnShortKey) {
+ AesGcmEncryption aes;
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string bad_key_input = "only 16 byte key";
+ EXPECT_DEATH(
+ aes.Encrypt(
+ AesKey(reinterpret_cast<const uint8_t*>(bad_key_input.c_str()), 16),
+ test_str),
+ "Encrypt called with key of 16 bytes, but 32 bytes are required.");
+}
+
+TEST(AesGcmEncryptionTest, DecryptionDiesOnShortKey) {
+ AesGcmEncryption aes;
+ AesKey key = MakeAesKey("Just some random 32 byte AES key");
+ std::string test_str = "This is a test. Should work on arbitrary strings.";
+
+ std::string ciphertext = aes.Encrypt(key, test_str);
+ std::string bad_key_input = "short 17 byte key";
+ EXPECT_DEATH(
+ aes.Decrypt(
+ AesKey(reinterpret_cast<const uint8_t*>(bad_key_input.c_str()),
+ 17),
+ ciphertext)
+ .IgnoreError(),
+ "Decrypt called with key of 17 bytes, but 32 bytes are required.");
+}
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_key.cc b/fcp/secagg/shared/aes_key.cc
new file mode 100644
index 0000000..721399f
--- /dev/null
+++ b/fcp/secagg/shared/aes_key.cc
@@ -0,0 +1,80 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/secagg/shared/aes_key.h"
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+static constexpr int kLegacyKeySize = 17;
+
+namespace fcp {
+namespace secagg {
+
+AesKey::AesKey(const uint8_t* data, int key_size) : Key(data, key_size) {
+ FCP_CHECK((key_size > 0 && key_size <= 17) || (key_size == 32));
+}
+
+StatusOr<AesKey> AesKey::CreateFromShares(
+ const std::vector<ShamirShare>& shares, int threshold) {
+ ShamirSecretSharing reconstructor;
+ // TODO(team): Once Java support is removed, assume 32 byte keys.
+ int key_length = 0;
+ // For compatibility, we need to know if the key that was shared was 128 or
+ // 256 bits long. It can only have been one of those two lengths, so the
+ // shares should be either 20 or 36 bytes long respectively.
+ for (int i = 0; i < shares.size() && key_length == 0; ++i) {
+ if (shares[i].data.size() == 36) {
+ key_length = kSize;
+ } else if (shares[i].data.size() == 20) {
+ key_length = kLegacyKeySize; // May be 17 bytes or shorter, see below
+ } else {
+ // Key share must be missing if it's not one of those lengths.
+ FCP_CHECK(shares[i].data.empty());
+ }
+ }
+ FCP_CHECK(key_length != 0);
+ std::string reconstructed;
+ FCP_ASSIGN_OR_RETURN(
+ reconstructed, reconstructor.Reconstruct(threshold, shares, key_length));
+
+ if (key_length == kLegacyKeySize) {
+ // The key produced on Java side normally has 16 bytes, however when
+ // exporting the key from BigInteger to byte array an extra zero byte is
+ // added at the front if the high-order bit was '1' to indicate that the
+ // BigInteger value was positive (to avoid treating the high order bit
+ // as the sign bit). However the byte array may also be shorter than
+ // 16 bytes if the BigInteger value was smaller.
+ // For compatibility with Java behavior any leading zero byte that isn't
+ // followed by a byte with '1' in the high-order bit need to be removed.
+ int index = 0;
+ while (index < kLegacyKeySize - 1 &&
+ static_cast<uint8_t>(reconstructed[index]) == 0 &&
+ static_cast<uint8_t>(reconstructed[index + 1]) <= 127) {
+ index++;
+ }
+
+ if (index > 0) {
+ reconstructed.erase(0, index);
+ key_length -= index;
+ }
+ }
+ return AesKey(reinterpret_cast<const uint8_t*>(reconstructed.c_str()),
+ key_length);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_key.h b/fcp/secagg/shared/aes_key.h
new file mode 100644
index 0000000..d43f183
--- /dev/null
+++ b/fcp/secagg/shared/aes_key.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_AES_KEY_H_
+#define FCP_SECAGG_SHARED_AES_KEY_H_
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/key.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+// A Key specifically intended for use with AES symmetric encryption.
+// Keys originating on Java clients are 17 bytes or shorter (typically
+// 16 or 17 bytes, but sometimes shorter).
+// Keys originating on C++ clients must have 32 bytes.
+// A 0-byte key should not be used for anything, and represents the absence of
+// a key in a collection of keys.
+//
+class AesKey : public Key {
+ public:
+ static constexpr int kSize = 32; // Expected key size for AES-256
+
+ // The key is blank.
+ AesKey() : Key() {}
+
+ // The key is a standard-size 32 byte key.
+ explicit AesKey(const uint8_t* data, int key_size = kSize);
+
+ // Create a key by reconstructing it from key shares. Length depends on the
+ // key shares, and may not be 32 bytes. Threshold is the threshold used when
+ // the secret was shared, i.e. the minimum number of clients to reconstruct.
+ static StatusOr<AesKey> CreateFromShares(
+ const std::vector<ShamirShare>& shares, int threshold);
+};
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_AES_KEY_H_
diff --git a/fcp/secagg/shared/aes_key_test.cc b/fcp/secagg/shared/aes_key_test.cc
new file mode 100644
index 0000000..e5ed0ee
--- /dev/null
+++ b/fcp/secagg/shared/aes_key_test.cc
@@ -0,0 +1,95 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/secagg/shared/aes_key.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+// For testing purposes, make an AesKey out of a string.
+AesKey AesKeyFromString(const std::string& key) {
+ return AesKey(reinterpret_cast<const uint8_t*>(key.c_str()),
+ static_cast<int>(key.size()));
+}
+
+// Suppose the randomly chosen key can be expressed in bit_length <= 128 bits.
+// Java did not express the key in 128 bits, but rather will have used
+// (bit_length + 1) bits. The extra bit is the highest-order bit, and is a sign
+// bit, guaranteed to be 0. The next highest-order bit is guaranteed to be 1.
+std::string JavaStyleKey(int bit_length) {
+ EXPECT_TRUE(0 < bit_length && bit_length <= 128);
+ std::string key = "16 byte test key";
+ int byte_with_sign_bit = (127 - bit_length) / 8;
+ int pos_of_sign_bit = (127 - bit_length) % 8;
+ if (bit_length == 128) {
+ pos_of_sign_bit = 7;
+ key = absl::StrCat("\0", key);
+ } else {
+ key.erase(0, byte_with_sign_bit);
+ }
+ // Make sure the high-order bit is the sign bit 0, and the next highest-order
+ // bit is 1.
+ key[0] = static_cast<char>(127 >> pos_of_sign_bit);
+ if (pos_of_sign_bit == 7) {
+ key[1] = static_cast<char>(255);
+ }
+ return key;
+}
+
+TEST(AesKeyTest, CreateFromSharesHandles32BKeys) {
+ AesKey original_key = AesKeyFromString("32 byte AES key for testing only");
+ ShamirSecretSharing shamir;
+ auto shares = shamir.Share(5, 7, original_key);
+ auto key_or_error = AesKey::CreateFromShares(shares, 5);
+ EXPECT_THAT(key_or_error.ok(), Eq(true));
+ EXPECT_THAT(key_or_error.value(), Eq(original_key));
+}
+
+TEST(AesKeyTest, CreateFromSharesHandlesShortKeys) {
+ ShamirSecretSharing shamir;
+ for (int i = 1; i <= 128; ++i) {
+ std::string original_key_string = JavaStyleKey(i);
+ AesKey original_key = AesKeyFromString(original_key_string);
+ std::string key_string_for_sharing;
+ if (original_key_string.size() < 16) {
+ key_string_for_sharing =
+ absl::StrCat(std::string(16 - original_key_string.size(), '\0'),
+ original_key_string);
+ } else if (original_key_string.size() == 17) {
+ key_string_for_sharing = original_key_string.substr(1);
+ } else {
+ key_string_for_sharing = original_key_string;
+ }
+ auto shares = shamir.Share(5, 7, AesKeyFromString(key_string_for_sharing));
+ auto key_or_error = AesKey::CreateFromShares(shares, 5);
+ EXPECT_THAT(key_or_error.ok(), Eq(true));
+ EXPECT_THAT(key_or_error.value(), Eq(original_key))
+ << i << " bit key fails";
+ }
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/aes_prng_factory.h b/fcp/secagg/shared/aes_prng_factory.h
new file mode 100644
index 0000000..0550b82
--- /dev/null
+++ b/fcp/secagg/shared/aes_prng_factory.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_AES_PRNG_FACTORY_H_
+#define FCP_SECAGG_SHARED_AES_PRNG_FACTORY_H_
+
+#include <memory>
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/prng.h"
+
+namespace fcp {
+namespace secagg {
+
+// Factory interface for AES-based deterministic pseudorandom number generators.
+class AesPrngFactory {
+ public:
+ virtual ~AesPrngFactory() = default;
+ virtual std::unique_ptr<SecurePrng> MakePrng(const AesKey& key) const = 0;
+ // TODO(team): Remove this when transition to the batch mode of
+ // SecurePrng is fully done.
+ // The batch mode allows to retrive a large batch of preuso-random numbers
+ // in a single call.
+ virtual bool SupportsBatchMode() const { return false; }
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_AES_PRNG_FACTORY_H_
diff --git a/fcp/secagg/shared/async_abort.h b/fcp/secagg/shared/async_abort.h
new file mode 100644
index 0000000..5934806
--- /dev/null
+++ b/fcp/secagg/shared/async_abort.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_ASYNC_ABORT_H_
+#define FCP_SECAGG_SHARED_ASYNC_ABORT_H_
+
+#include <atomic>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace secagg {
+
+// A helper to allow polling for asynchronous aborts. For ease of testing, this
+// class does not manage its own atomic, which allows the atomic to be easily
+// allocated on its own page.
+//
+// This class is thread-safe.
+class AsyncAbort {
+ public:
+ explicit AsyncAbort(std::atomic<std::string*>* signal)
+ : signal_(signal), mu_() {
+ FCP_CHECK(signal_);
+ }
+ virtual ~AsyncAbort() = default;
+
+ // AsyncAbort is neither copyable nor movable.
+ AsyncAbort(const AsyncAbort&) = delete;
+ AsyncAbort& operator=(const AsyncAbort&) = delete;
+
+ // Signal an async. abort. The abort message may not be reflected in
+ // SecAggClient if it has already transitioned to a terminal state (aborted
+ // or completed).
+ void Abort(std::string message) {
+ absl::WriterMutexLock _(&mu_);
+ message_ = message;
+ *signal_ = &message_;
+ }
+
+ // Returns whether the abort signal is raised.
+ ABSL_MUST_USE_RESULT bool Signalled() const {
+ return signal_->load(std::memory_order_relaxed);
+ }
+
+ // Returns the abort message specified by the abort signal.
+ // If Signalled() returns false, the value is undefined.
+ ABSL_MUST_USE_RESULT std::string Message() const {
+ absl::ReaderMutexLock _(&mu_);
+ return **signal_;
+ }
+
+ std::atomic<std::string*>* signal_;
+ mutable absl::Mutex mu_;
+ std::string message_ ABSL_GUARDED_BY(mu_);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_ASYNC_ABORT_H_
diff --git a/fcp/secagg/shared/compute_session_id.cc b/fcp/secagg/shared/compute_session_id.cc
new file mode 100644
index 0000000..adea423
--- /dev/null
+++ b/fcp/secagg/shared/compute_session_id.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/secagg/shared/compute_session_id.h"
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+#include "openssl/evp.h"
+
+namespace fcp {
+namespace secagg {
+
+SessionId ComputeSessionId(const ShareKeysRequest& request) {
+ EVP_MD_CTX* ctx;
+ FCP_CHECK(ctx = EVP_MD_CTX_create());
+ FCP_CHECK(EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr));
+ for (const PairOfPublicKeys& keys : request.pairs_of_public_keys()) {
+ int noise_pk_size = keys.noise_pk().size();
+ std::string noise_pk_size_data = IntToByteString(noise_pk_size);
+ int enc_pk_size = keys.enc_pk().size();
+ std::string enc_pk_size_data = IntToByteString(noise_pk_size);
+ FCP_CHECK(EVP_DigestUpdate(ctx, noise_pk_size_data.c_str(), sizeof(int)));
+ FCP_CHECK(EVP_DigestUpdate(ctx, keys.noise_pk().c_str(), noise_pk_size));
+ FCP_CHECK(EVP_DigestUpdate(ctx, enc_pk_size_data.c_str(), sizeof(int)));
+ FCP_CHECK(EVP_DigestUpdate(ctx, keys.enc_pk().c_str(), enc_pk_size));
+ }
+
+ char digest[kSha256Length];
+ uint32_t digest_length = 0;
+ FCP_CHECK(EVP_DigestFinal_ex(ctx, reinterpret_cast<uint8_t*>(digest),
+ &digest_length));
+ FCP_CHECK(digest_length == kSha256Length);
+ EVP_MD_CTX_destroy(ctx);
+ return {std::string(digest, kSha256Length)};
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/compute_session_id.h b/fcp/secagg/shared/compute_session_id.h
new file mode 100644
index 0000000..5ea2820
--- /dev/null
+++ b/fcp/secagg/shared/compute_session_id.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_COMPUTE_SESSION_ID_H_
+#define FCP_SECAGG_SHARED_COMPUTE_SESSION_ID_H_
+
+#include <string>
+
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+inline constexpr int kSha256Length = 32;
+
+// A SessionId is the id of a given SecAgg session. Every session's SessionId
+// should be unique.
+typedef struct SessionId {
+ std::string data;
+} SessionId;
+
+// Computes the session ID for a specific protocol session given the first
+// message (ShareKeysRequest) sent by the server.
+//
+// The session id is computed as a SHA-256 hash of the concatenation of all the
+// PairOfPublicKeys inside the request message (in the same order in which they
+// appear in the message). More specifically, for each PairOfPublicKeys inside
+// the message, the following are concatenated to the input of the hash
+// function:
+//
+// - The length of the prng ECDH public key
+// - The prng ECDH public key
+// - The length of the encryption ECDH public key
+// - The encryption ECDH public key
+//
+// Lengths are prepended to the keys so that the encoding is not ambiguous and
+// there are no unexpected collisions.
+//
+// The output of this method is 32 bytes long.
+SessionId ComputeSessionId(const ShareKeysRequest& request);
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_COMPUTE_SESSION_ID_H_
diff --git a/fcp/secagg/shared/compute_session_id_test.cc b/fcp/secagg/shared/compute_session_id_test.cc
new file mode 100644
index 0000000..9c032f3
--- /dev/null
+++ b/fcp/secagg/shared/compute_session_id_test.cc
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/secagg/shared/compute_session_id.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(ComputeSessionIdTest, OutputIsCorrectLength) {
+ ShareKeysRequest request;
+ PairOfPublicKeys* keys = request.add_pairs_of_public_keys();
+ keys->set_noise_pk("abcdefghijklmnopqrstuvwxyz7890123");
+ keys->set_enc_pk("1234567abcdefghijklmnopqrstuvwxyz");
+
+ SessionId session_id = ComputeSessionId(request);
+ EXPECT_THAT(session_id.data.size(), Eq(32));
+}
+
+TEST(ComputeSessionIdTest, OutputIsDeterministic) {
+ ShareKeysRequest request1;
+ PairOfPublicKeys* keys1 = request1.add_pairs_of_public_keys();
+ keys1->set_noise_pk("abcdefghijklmnopqrstuvwxyz7890123");
+ keys1->set_enc_pk("1234567abcdefghijklmnopqrstuvwxyz");
+ ShareKeysRequest request2;
+ PairOfPublicKeys* keys2 = request2.add_pairs_of_public_keys();
+ keys2->set_noise_pk("abcdefghijklmnopqrstuvwxyz7890123");
+ keys2->set_enc_pk("1234567abcdefghijklmnopqrstuvwxyz");
+
+ SessionId session_id_1 = ComputeSessionId(request1);
+ SessionId session_id_2 = ComputeSessionId(request2);
+ EXPECT_THAT(session_id_2.data, Eq(session_id_1.data));
+}
+
+TEST(ComputeSessionIdTest, OutputChangesOnDifferentInputs) {
+ ShareKeysRequest request1;
+ PairOfPublicKeys* keys1 = request1.add_pairs_of_public_keys();
+ keys1->set_noise_pk("abcdefghijklmnopqrstuvwxyz7890123");
+ keys1->set_enc_pk("1234567abcdefghijklmnopqrstuvwxyz");
+ ShareKeysRequest request2;
+ PairOfPublicKeys* keys2 = request2.add_pairs_of_public_keys();
+ keys2->set_noise_pk("123456789012345678901234567890123");
+ keys2->set_enc_pk("ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFG");
+
+ SessionId session_id_1 = ComputeSessionId(request1);
+ SessionId session_id_2 = ComputeSessionId(request2);
+ EXPECT_THAT(session_id_2.data, Ne(session_id_1.data));
+}
+
+// Hard coded keys and output generated by Java
+TEST(ComputeSessionIdTest, OutputMatchesHardCodedValues) {
+ ShareKeysRequest request1;
+ PairOfPublicKeys* keys1 = request1.add_pairs_of_public_keys();
+ keys1->set_noise_pk(
+ "\002Y\256\332c\202\214\367\234F\f\370M;\301P}\b\220)\267\206C*"
+ "\253f\363\375Z\262\300\214(");
+ keys1->set_enc_pk(
+ "\002m\003C\234\217\"\037\025{\354~\345G\233\277~"
+ "\222\220\036Tkl\334C\241Ln\256\023\315k]");
+ PairOfPublicKeys* keys2 = request1.add_pairs_of_public_keys();
+ keys2->set_noise_pk(
+ "\002\023\313\267\331\211\031\332fn8\035Qx\241\217\002K\345\"\260\377:"
+ "\231~\222\246,\232?\030m\032");
+ keys2->set_enc_pk(
+ "\003\204\243\326["
+ "I\273\326\301\336\254X\300\332\201\334\371\023\351\021\022\323\371\234`"
+ "\301\352p\251\vR\217I");
+
+ SessionId expected;
+ uint8_t precomputed[32] = {120, 175, 110, 210, 30, 111, 197, 231,
+ 253, 35, 163, 25, 159, 204, 80, 79,
+ 173, 180, 27, 166, 83, 53, 85, 161,
+ 228, 232, 97, 20, 242, 62, 142, 114};
+ expected.data = std::string(reinterpret_cast<const char*>(precomputed), 32);
+ SessionId session_id = ComputeSessionId(request1);
+ EXPECT_THAT(session_id.data, Eq(expected.data));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/crypto_rand_prng.cc b/fcp/secagg/shared/crypto_rand_prng.cc
new file mode 100644
index 0000000..4704866
--- /dev/null
+++ b/fcp/secagg/shared/crypto_rand_prng.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/crypto_rand_prng.h"
+
+#include <cstdint>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/prng.h"
+#include "openssl/rand.h"
+
+namespace fcp {
+namespace secagg {
+
+template <typename Output>
+static Output Rand() {
+ Output output;
+ uint8_t bytes[sizeof(output)];
+ FCP_CHECK(RAND_bytes(bytes, sizeof(output)));
+ memcpy(&output, bytes, sizeof(output));
+ return output;
+}
+
+uint8_t CryptoRandPrng::Rand8() { return Rand<uint8_t>(); }
+uint64_t CryptoRandPrng::Rand64() { return Rand<uint64_t>(); }
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/crypto_rand_prng.h b/fcp/secagg/shared/crypto_rand_prng.h
new file mode 100644
index 0000000..e229dcc
--- /dev/null
+++ b/fcp/secagg/shared/crypto_rand_prng.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_CRYPTO_RAND_PRNG_H_
+#define FCP_SECAGG_SHARED_CRYPTO_RAND_PRNG_H_
+
+#include <cstdint>
+
+#include "fcp/secagg/shared/prng.h"
+
+namespace fcp {
+namespace secagg {
+
+// A cryptographically strong Pseudorandom Number Generator based on OpenSSL,
+// which seeds using /dev/urandom on UNIX-like operating systems, and other
+// sources of randomness on Windows.
+//
+// This class is thread-safe.
+
+class CryptoRandPrng : public SecurePrng {
+ public:
+ CryptoRandPrng() = default;
+
+ uint8_t Rand8() override;
+ uint64_t Rand64() override;
+};
+
+} // namespace secagg
+} // namespace fcp
+#endif // FCP_SECAGG_SHARED_CRYPTO_RAND_PRNG_H_
diff --git a/fcp/secagg/shared/ecdh_key_agreement.cc b/fcp/secagg/shared/ecdh_key_agreement.cc
new file mode 100644
index 0000000..fdc86f8
--- /dev/null
+++ b/fcp/secagg/shared/ecdh_key_agreement.cc
@@ -0,0 +1,153 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+
+#include <memory>
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "openssl/bn.h"
+#include "openssl/ec.h"
+#include "openssl/ecdh.h"
+#include "openssl/evp.h"
+#include "openssl/mem.h"
+#include "openssl/sha.h"
+
+namespace fcp {
+namespace secagg {
+
+EcdhKeyAgreement::EcdhKeyAgreement() : key_(nullptr, EC_KEY_free) {}
+
+EcdhKeyAgreement::EcdhKeyAgreement(EC_KEY* key) : key_(key, EC_KEY_free) {}
+
+StatusOr<std::unique_ptr<EcdhKeyAgreement>>
+EcdhKeyAgreement::CreateFromRandomKeys() {
+ EC_KEY* key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
+ EC_KEY_generate_key(key);
+ if (EC_KEY_check_key(key)) {
+ return std::make_unique<EcdhKeyAgreement>(key);
+ } else {
+ return FCP_STATUS(INTERNAL);
+ }
+}
+
+StatusOr<std::unique_ptr<EcdhKeyAgreement>>
+EcdhKeyAgreement::CreateFromPrivateKey(const EcdhPrivateKey& private_key) {
+ if (private_key.size() != EcdhPrivateKey::kSize) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Private key must be of length " << EcdhPrivateKey::kSize;
+ }
+ // Wrap a raw pointer in a unique_ptr with a deleter to guarantee deletion.
+ BIGNUM* private_key_bn_raw =
+ BN_bin2bn(private_key.data(), private_key.size(), nullptr);
+ std::unique_ptr<BIGNUM, void (*)(BIGNUM*)> private_key_bn(private_key_bn_raw,
+ BN_free);
+ EC_KEY* key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
+ if (!EC_KEY_set_private_key(key, private_key_bn.get())) {
+ return FCP_STATUS(INVALID_ARGUMENT) << "Invalid private key.";
+ }
+ return std::make_unique<EcdhKeyAgreement>(key);
+}
+
+StatusOr<std::unique_ptr<EcdhKeyAgreement>> EcdhKeyAgreement::CreateFromKeypair(
+ const EcdhPrivateKey& private_key, const EcdhPublicKey& public_key) {
+ if (public_key.size() != EcdhPublicKey::kSize &&
+ public_key.size() != EcdhPublicKey::kUncompressedSize) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Public key must be of length " << EcdhPublicKey::kSize << " or "
+ << EcdhPublicKey::kUncompressedSize;
+ }
+
+ // Create using a private key only first, then add the public key if it
+ // is valid.
+ auto ecdh = CreateFromPrivateKey(private_key);
+ if (!ecdh.ok()) {
+ return ecdh;
+ }
+
+ // Copy this raw pointer for simplicity.
+ EC_KEY* key_ptr = ecdh.value()->key_.get();
+
+ // Wrap a raw pointer in a unique_ptr with a deleter to guarantee deletion.
+ EC_POINT* public_key_point_raw = EC_POINT_new(EC_KEY_get0_group(key_ptr));
+ std::unique_ptr<EC_POINT, void (*)(EC_POINT*)> public_key_point(
+ public_key_point_raw, EC_POINT_free);
+
+ if (!EC_POINT_oct2point(EC_KEY_get0_group(key_ptr), public_key_point.get(),
+ public_key.data(), public_key.size(), nullptr)) {
+ return FCP_STATUS(INVALID_ARGUMENT) << "Invalid public key.";
+ }
+
+ // This makes a copy of the public key, so deletion is safe.
+ if (!EC_KEY_set_public_key(key_ptr, public_key_point.get())) {
+ return FCP_STATUS(INVALID_ARGUMENT) << "Invalid public key.";
+ }
+
+ if (EC_KEY_check_key(key_ptr)) {
+ return ecdh;
+ } else {
+ return FCP_STATUS(INVALID_ARGUMENT) << "Invalid keypair.";
+ }
+}
+
+EcdhPrivateKey EcdhKeyAgreement::PrivateKey() const {
+ const BIGNUM* private_key_bn = EC_KEY_get0_private_key(key_.get());
+ uint8_t private_key[EcdhPrivateKey::kSize];
+ FCP_CHECK(
+ BN_bn2bin_padded(private_key, EcdhPrivateKey::kSize, private_key_bn));
+ return EcdhPrivateKey(private_key);
+}
+
+EcdhPublicKey EcdhKeyAgreement::PublicKey() const {
+ const EC_POINT* public_key_point = EC_KEY_get0_public_key(key_.get());
+ if (public_key_point == nullptr) {
+ return EcdhPublicKey();
+ }
+ uint8_t public_key[EcdhPublicKey::kSize];
+ int public_key_size = EC_POINT_point2oct(
+ EC_KEY_get0_group(key_.get()), public_key_point,
+ POINT_CONVERSION_COMPRESSED, public_key, EcdhPublicKey::kSize, nullptr);
+ FCP_CHECK(public_key_size == EcdhPublicKey::kSize);
+ return EcdhPublicKey(public_key);
+}
+
+StatusOr<AesKey> EcdhKeyAgreement::ComputeSharedSecret(
+ const EcdhPublicKey& other_key) const {
+ if (other_key.size() != EcdhPublicKey::kSize &&
+ other_key.size() != EcdhPublicKey::kUncompressedSize) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "Public key must be of length " << EcdhPublicKey::kSize << " or "
+ << EcdhPublicKey::kUncompressedSize;
+ }
+ // Wrap a raw pointer in a unique_ptr with a deleter to guarantee deletion.
+ EC_POINT* other_point_raw = EC_POINT_new(EC_KEY_get0_group(key_.get()));
+ std::unique_ptr<EC_POINT, void (*)(EC_POINT*)> other_point(other_point_raw,
+ EC_POINT_free);
+ if (!EC_POINT_oct2point(EC_KEY_get0_group(key_.get()), other_point.get(),
+ other_key.data(), other_key.size(), nullptr)) {
+ return FCP_STATUS(INVALID_ARGUMENT) << "Invalid ECDH public key.";
+ }
+ uint8_t secret[AesKey::kSize];
+ ECDH_compute_key(secret, AesKey::kSize, other_point.get(), key_.get(),
+ nullptr);
+ return AesKey(secret);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/ecdh_key_agreement.h b/fcp/secagg/shared/ecdh_key_agreement.h
new file mode 100644
index 0000000..43b04a6
--- /dev/null
+++ b/fcp/secagg/shared/ecdh_key_agreement.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_ECDH_KEY_AGREEMENT_H_
+#define FCP_SECAGG_SHARED_ECDH_KEY_AGREEMENT_H_
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "openssl/base.h"
+#include "openssl/ec.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class represents a participant in the ECDH Key Agreement protocol. It
+// serves to hold one private/public ECDH key pair, or just one private key
+// (with no public key associated).
+//
+// The curve used for this is NIST P-256, also known as prime256v1.
+//
+// The private and public keys can be retrieved from the object in the form of
+// strings. The compressed representation of the public key is used.
+//
+// The shared secret is always hashed with SHA-256 to produce a valid AES-256
+// key before returning.
+//
+// Because only certain strings are valid ECDH keys, even if they have the
+// correct length, this class uses factory methods to instantiate new objects.
+// The constructors should never be called directly by a user of this class.
+class EcdhKeyAgreement {
+ public:
+ // FACTORY METHODS:
+ // Use one of the CreateFrom* factory methods below to instantiate a new
+ // EcdhKeyAgreement object, and note that these methods may fail. Do not use a
+ // constructor directly to instantiate an EcdhKeyAgreement object.
+
+ // Returns a new EcdhKeyAgreement object containing a keypair randomly
+ // generated using OpenSSL's randomness. Only fails if OpenSSL has an internal
+ // error.
+ static StatusOr<std::unique_ptr<EcdhKeyAgreement>> CreateFromRandomKeys();
+
+ // Returns a new EcdhKeyAgreement object containing the supplied private key,
+ // and no public key. This object can still be used to do ECDH with other
+ // public keys.
+ // Fails if the supplied private key is invalid.
+ static StatusOr<std::unique_ptr<EcdhKeyAgreement>> CreateFromPrivateKey(
+ const EcdhPrivateKey& private_key);
+
+ // Returns a new EcdhKeyAgreement object containing the supplied
+ // private/public keypair.
+ // Fails if the supplied keypair is invalid.
+ static StatusOr<std::unique_ptr<EcdhKeyAgreement>> CreateFromKeypair(
+ const EcdhPrivateKey& private_key, const EcdhPublicKey& public_key);
+
+ // Returns a representation of the private key stored in this object.
+ EcdhPrivateKey PrivateKey() const;
+
+ // Returns a representation of the public key stored in this object. This uses
+ // the compressed ECDH public key representation.
+ //
+ // If object has no public key (i.e. it was constructed with
+ // CreateFromPrivateKey), this method will return an empty string.
+ EcdhPublicKey PublicKey() const;
+
+ // Returns the shared secret AES key generated by ECDH, using with the stored
+ // ECDH private key and the supplied ECDH public key, and then hashed with
+ // SHA-256. The output will be an AES-256 key.
+ //
+ // If the other_key is not a valid public key, instead returns an error status
+ // with code INVALID_ARGUMENT.
+ StatusOr<AesKey> ComputeSharedSecret(const EcdhPublicKey& other_key) const;
+
+ // DO NOT USE THESE CONSTRUCTORS.
+ // Instead, one of the CreateFrom* factory methods below.
+ // These constructors are made public only as an implementation detail.
+ // See https://abseil.io/tips/134 for details.
+ EcdhKeyAgreement();
+ explicit EcdhKeyAgreement(EC_KEY* key);
+
+ private:
+ std::unique_ptr<EC_KEY, void (*)(EC_KEY*)> key_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_ECDH_KEY_AGREEMENT_H_
diff --git a/fcp/secagg/shared/ecdh_key_agreement_test.cc b/fcp/secagg/shared/ecdh_key_agreement_test.cc
new file mode 100644
index 0000000..5359eea
--- /dev/null
+++ b/fcp/secagg/shared/ecdh_key_agreement_test.cc
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Contains;
+using ::testing::Eq;
+using ::testing::Not;
+
+TEST(EcdhKeyAgreementTest, CanRecoverFromPrivateKey) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh1 = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto ecdh2 =
+ EcdhKeyAgreement::CreateFromPrivateKey(ecdh1->PrivateKey()).value();
+ EXPECT_THAT(ecdh2->PrivateKey(), Eq(ecdh1->PrivateKey()));
+ EXPECT_THAT(ecdh2->PublicKey().size(), Eq(0));
+}
+
+TEST(EcdhKeyAgreementTest, CanRecoverKeypairFromPrivateAndPublicKeys) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh1 = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto ecdh2 = EcdhKeyAgreement::CreateFromKeypair(ecdh1->PrivateKey(),
+ ecdh1->PublicKey())
+ .value();
+ EXPECT_THAT(ecdh2->PrivateKey(), Eq(ecdh1->PrivateKey()));
+ EXPECT_THAT(ecdh2->PublicKey(), Eq(ecdh1->PublicKey()));
+}
+
+TEST(EcdhKeyAgreementTest, PrivateKeyIsExpectedLength) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ EXPECT_THAT(ecdh->PrivateKey().size(), Eq(EcdhPrivateKey::kSize));
+}
+
+TEST(EcdhKeyAgreementTest, PublicKeyIsExpectedLength) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ EXPECT_THAT(ecdh->PublicKey().size(), Eq(EcdhPublicKey::kSize));
+}
+
+TEST(EcdhKeyAgreementTest, RandomKeypairIsntTheSameEveryTime) {
+ std::vector<Key> private_keys;
+ std::vector<Key> public_keys;
+ private_keys.reserve(16);
+ public_keys.reserve(16);
+ for (int i = 0; i < 16; ++i) {
+ auto ecdh = EcdhKeyAgreement::CreateFromRandomKeys().value();
+ EXPECT_THAT(private_keys, Not(Contains(ecdh->PrivateKey())));
+ EXPECT_THAT(public_keys, Not(Contains(ecdh->PublicKey())));
+ private_keys.push_back(ecdh->PrivateKey());
+ public_keys.push_back(ecdh->PublicKey());
+ }
+}
+
+TEST(EcdhKeyAgreementTest, SharedSecretsHaveCorrectLength) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto secret = ecdh->ComputeSharedSecret(keys.GetPublicKey(1));
+ ASSERT_TRUE(secret.ok());
+ EXPECT_THAT(secret.value().size(), Eq(AesKey::kSize));
+}
+
+TEST(EcdhKeyAgreementTest, SharedSecretsAreDeterministic) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto secret1 = ecdh->ComputeSharedSecret(keys.GetPublicKey(1));
+ auto secret2 = ecdh->ComputeSharedSecret(keys.GetPublicKey(1));
+ ASSERT_TRUE(secret1.ok());
+ ASSERT_TRUE(secret2.ok());
+ EXPECT_THAT(secret1.value(), Eq(secret2.value()));
+}
+
+TEST(EcdhKeyAgreementTest, SharedSecretsAreConsistent) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh1 = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto ecdh2 = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(1),
+ keys.GetPublicKey(1))
+ .value();
+ auto secret1 = ecdh1->ComputeSharedSecret(ecdh2->PublicKey());
+ auto secret2 = ecdh2->ComputeSharedSecret(ecdh1->PublicKey());
+ ASSERT_TRUE(secret1.ok());
+ ASSERT_TRUE(secret2.ok());
+ EXPECT_THAT(secret1.value(), Eq(secret2.value()));
+}
+
+TEST(EcdhKeyAgreementTest, SharedSecretsAreConsistentWithoutPublicKey) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh1 =
+ EcdhKeyAgreement::CreateFromPrivateKey(keys.GetPrivateKey(0)).value();
+ auto ecdh2 =
+ EcdhKeyAgreement::CreateFromPrivateKey(keys.GetPrivateKey(1)).value();
+ auto secret1 = ecdh1->ComputeSharedSecret(keys.GetPublicKey(1));
+ auto secret2 = ecdh2->ComputeSharedSecret(keys.GetPublicKey(0));
+ ASSERT_TRUE(secret1.ok());
+ ASSERT_TRUE(secret2.ok());
+ EXPECT_THAT(secret1.value(), Eq(secret2.value()));
+}
+
+TEST(EcdhKeyAgreementTest, CreateFromKeypairErrorsOnInconsistentKeys) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(1));
+ EXPECT_THAT(ecdh.ok(), Eq(false));
+}
+
+TEST(EcdhKeyAgreementTest, ComputeSharedSecretErrorsOnGarbagePublicKey) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromPrivateKey(keys.GetPrivateKey(0));
+ ASSERT_TRUE(ecdh.ok());
+
+ // first byte valid at least
+ const char bad_key[] =
+ "\x2"
+ "23456789012345678901234567890123";
+
+ auto secret = ecdh.value()->ComputeSharedSecret(
+ EcdhPublicKey(reinterpret_cast<const uint8_t*>(bad_key)));
+ EXPECT_THAT(secret.ok(), Eq(false));
+}
+
+TEST(EcdhKeyAgreementTest, SharedSecretsWorkWithUncompressedPublicKeys) {
+ EcdhPregeneratedTestKeys keys;
+ auto ecdh = EcdhKeyAgreement::CreateFromKeypair(keys.GetPrivateKey(0),
+ keys.GetPublicKey(0))
+ .value();
+ auto secret = ecdh->ComputeSharedSecret(keys.GetUncompressedPublicKey(0));
+ ASSERT_TRUE(secret.ok());
+ EXPECT_THAT(secret.value().size(), Eq(AesKey::kSize));
+}
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/ecdh_keys.h b/fcp/secagg/shared/ecdh_keys.h
new file mode 100644
index 0000000..f49079b
--- /dev/null
+++ b/fcp/secagg/shared/ecdh_keys.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_ECDH_KEYS_H_
+#define FCP_SECAGG_SHARED_ECDH_KEYS_H_
+
+#include "fcp/secagg/shared/key.h"
+
+// This file contains definitions for ECDH public key and private key types.
+
+namespace fcp {
+namespace secagg {
+// A Key that serves as a private key for use with ECDH, with the NIST P-256
+// curve. Works the same as Key, but is guaranteed to have either 0 or 32 bytes.
+// A 0-byte key should not be used for anything, and represents the absence of
+// a key in a collection of keys.
+class EcdhPrivateKey : public Key {
+ public:
+ static constexpr int kSize = 32;
+
+ // The key is blank.
+ EcdhPrivateKey() : Key() {}
+
+ // The data MUST have 32 bytes.
+ explicit EcdhPrivateKey(const uint8_t* data) : Key(data, kSize) {}
+};
+
+// A Key that serves as a public key for use with ECDH, with the NIST P-256
+// curve. Works the same as Key, but is guaranteed to have either 0, 33, or 65
+// bytes (depending on whether the key is compressed or not). Clients and the
+// server should both produce compressed keys, but legacy Java clients send
+// keys in uncompressed format.
+// A 0-byte key should not be used for anything, and represents the absence of
+// a key in a collection of keys.
+class EcdhPublicKey : public Key {
+ public:
+ static constexpr int kSize = 33;
+ // TODO(team): Remove uncompressed support when Java SecAgg deprecated.
+ static constexpr int kUncompressedSize = 65;
+ enum Format { kCompressed, kUncompressed };
+
+ // The key is blank.
+ EcdhPublicKey() : Key() {}
+
+ // If the key is compressed, data must have 33 bytes.
+ // If the key is uncompressed, data must have 65 bytes and the uncompressed
+ // format must be specified.
+ explicit EcdhPublicKey(const uint8_t* data, Format format = kCompressed)
+ : Key(data, format == kCompressed ? kSize : kUncompressedSize) {}
+};
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_ECDH_KEYS_H_
diff --git a/fcp/secagg/shared/input_vector_specification.cc b/fcp/secagg/shared/input_vector_specification.cc
new file mode 100644
index 0000000..6deed54
--- /dev/null
+++ b/fcp/secagg/shared/input_vector_specification.cc
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/input_vector_specification.h"
+
+#include <string>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+
+InputVectorSpecification::InputVectorSpecification(const std::string& name,
+ int length, uint64_t modulus)
+ : name_(name), length_(length), modulus_(modulus) {
+ FCP_CHECK(length >= 0) << "Length must be >= 0, given value was" << length;
+ FCP_CHECK(modulus > 1 && modulus <= SecAggVector::kMaxModulus)
+ << "The specified modulus is not valid: must be > 1 and <= "
+ << SecAggVector::kMaxModulus << ", supplied value : " << modulus;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/input_vector_specification.h b/fcp/secagg/shared/input_vector_specification.h
new file mode 100644
index 0000000..fd8c63b
--- /dev/null
+++ b/fcp/secagg/shared/input_vector_specification.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_INPUT_VECTOR_SPECIFICATION_H_
+#define FCP_SECAGG_SHARED_INPUT_VECTOR_SPECIFICATION_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/base/attributes.h"
+
+namespace fcp {
+namespace secagg {
+
+// Used to specify the name and either:
+//
+// 1. For the original protocol, the length and bit width of each input vector
+// which the protocol will aggregate.
+// 2. For the RLWE version, the length, the polynomial degree, and the modulus
+// for each input vector. In this case the length must be a multiple of the
+// degree.
+class InputVectorSpecification {
+ public:
+ InputVectorSpecification(const std::string& name, int length,
+ uint64_t modulus);
+
+ virtual ~InputVectorSpecification() = default;
+
+ ABSL_MUST_USE_RESULT inline const std::string& name() const { return name_; }
+
+ ABSL_MUST_USE_RESULT inline int length() const { return length_; }
+
+ ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; }
+
+ private:
+ const std::string name_;
+ const int length_;
+ const uint64_t modulus_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_INPUT_VECTOR_SPECIFICATION_H_
diff --git a/fcp/secagg/shared/input_vector_specification_test.cc b/fcp/secagg/shared/input_vector_specification_test.cc
new file mode 100644
index 0000000..bb89c46
--- /dev/null
+++ b/fcp/secagg/shared/input_vector_specification_test.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/input_vector_specification.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+TEST(InputVectorSpecificationTest, GettersReturnAppropriateValues) {
+ std::string name = "test";
+ int length = 2;
+ uint64_t modulus = 256;
+
+ InputVectorSpecification vec_spec(name, length, modulus);
+ EXPECT_THAT(vec_spec.modulus(), Eq(modulus));
+ EXPECT_THAT(vec_spec.length(), Eq(length));
+ EXPECT_THAT(vec_spec.name(), Eq(name));
+}
+
+TEST(InputVectorSpecificationTest, ConstructorDiesOnSmallModulus) {
+ EXPECT_DEATH(InputVectorSpecification vec_spec("test", 5, 1),
+ "The specified modulus is not valid");
+}
+
+TEST(InputVectorSpecificationTest, ConstructorDiesOnModulusGreatorThanMax) {
+ EXPECT_DEATH(InputVectorSpecification vec_spec("test", 5,
+ SecAggVector::kMaxModulus + 1),
+ "The specified modulus is not valid");
+}
+
+TEST(InputVectorSpecificationTest, ConstructorDiesOnNegativeLength) {
+ EXPECT_DEATH(InputVectorSpecification vec_spec("test", -1, 256),
+ "Length must be >= 0");
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/key.h b/fcp/secagg/shared/key.h
new file mode 100644
index 0000000..53def3d
--- /dev/null
+++ b/fcp/secagg/shared/key.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_KEY_H_
+#define FCP_SECAGG_SHARED_KEY_H_
+
+#include <cstdint>
+#include <string>
+
+namespace fcp {
+namespace secagg {
+// An immutable type that encapsulates a key to be used with OpenSSL. Stores the
+// key as std::string, but for better interaction with the OpenSSL API, the Key
+// API treats the key as either a string or a const uint8_t*.
+//
+// Note that this doesn't replace any OpenSSL structure, it simply allows for
+// storage of keys at rest without needing to store associated OpenSSL data.
+class Key {
+ public:
+ Key() : data_("") {}
+
+ Key(const uint8_t* data, int size)
+ : data_(reinterpret_cast<const char*>(data), size) {}
+
+ inline const uint8_t* data() const {
+ return reinterpret_cast<const uint8_t*>(data_.c_str());
+ }
+
+ inline const int size() const { return data_.size(); }
+
+ inline const std::string AsString() const { return data_; }
+
+ friend inline bool operator==(const Key& lhs, const Key& rhs) {
+ return lhs.data_ == rhs.data_;
+ }
+
+ private:
+ std::string data_; // The binary key data.
+};
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_KEY_H_
diff --git a/fcp/secagg/shared/map_of_masks.cc b/fcp/secagg/shared/map_of_masks.cc
new file mode 100644
index 0000000..d6dba00
--- /dev/null
+++ b/fcp/secagg/shared/map_of_masks.cc
@@ -0,0 +1,372 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/map_of_masks.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/numeric/bits.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/prng.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+#include "openssl/evp.h"
+
+namespace fcp {
+namespace secagg {
+
+// Constant for backwards compatibility with legacy clients. Even though it is
+// no longer needed, removing it would be disruptive due to making a large
+// number of clients incompatible while not providing any benefits.
+uint8_t kPrngSeedConstant = 0x02;
+
+// We specifically avoid sample_bits == 64 to sidestep numerical precision
+// issues, e.g. a uint64_t cannot represent the associated modulus.
+constexpr int kMaxSampleBits = 63;
+
+// We consider using at most 16 additional random bits from the underlying
+// PRNG per sample.
+//
+constexpr int kMaxSampleBitsExpansion = 16;
+
+static AesKey DigestKey(EVP_MD_CTX* mdctx, const std::string& prng_input,
+ int bit_width, const AesKey& prng_key) {
+ int input_size = prng_input.size();
+ std::string input_size_data = IntToByteString(input_size);
+ std::string bit_width_data = IntToByteString(bit_width);
+ FCP_CHECK(EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr));
+ FCP_CHECK(EVP_DigestUpdate(mdctx, bit_width_data.c_str(), sizeof(int)));
+ FCP_CHECK(EVP_DigestUpdate(mdctx, prng_key.data(), prng_key.size()));
+ FCP_CHECK(EVP_DigestUpdate(mdctx, &kPrngSeedConstant, 1));
+ FCP_CHECK(EVP_DigestUpdate(mdctx, input_size_data.c_str(), sizeof(int)));
+ FCP_CHECK(EVP_DigestUpdate(mdctx, prng_input.c_str(), input_size));
+
+ uint8_t digest[AesKey::kSize];
+ uint32_t digest_length = 0;
+ FCP_CHECK(EVP_DigestFinal_ex(mdctx, digest, &digest_length));
+ FCP_CHECK(digest_length == AesKey::kSize);
+ return AesKey(digest);
+}
+
+// Determines whether sample_bits_1 or sample_bits_2 will be more efficient
+// for sampling uniformly from [0, modulus).
+//
+int choose_better_sample_bits(uint64_t modulus, int sample_bits_1,
+ int sample_bits_2) {
+ FCP_CHECK(sample_bits_1 <= sample_bits_2);
+ FCP_CHECK(sample_bits_2 <= kMaxSampleBits);
+ FCP_CHECK(sample_bits_2 - sample_bits_1 <= kMaxSampleBitsExpansion);
+
+ uint64_t sample_modulus_1 = 1ULL << sample_bits_1;
+ FCP_CHECK(modulus <= sample_modulus_1);
+
+ if (sample_bits_1 == sample_bits_2) {
+ return sample_bits_1;
+ }
+
+ uint64_t sample_modulus_2 = 1ULL << sample_bits_2;
+ uint64_t sample_modulus_2_over_1 = 1ULL << (sample_bits_2 - sample_bits_1);
+ uint32_t cost_per_sample_1 = DivideRoundUp(sample_bits_1, 8);
+ uint32_t cost_per_sample_2 = DivideRoundUp(sample_bits_2, 8);
+ uint64_t modulus_reps_1 = sample_modulus_1 / modulus;
+ uint64_t modulus_reps_2 = sample_modulus_2 / modulus;
+ uint64_t cost_product_1 = cost_per_sample_1 * modulus_reps_1;
+ uint64_t cost_product_2 =
+ cost_per_sample_2 * modulus_reps_2 * sample_modulus_2_over_1;
+ return cost_product_1 > cost_product_2 ? sample_bits_2 : sample_bits_1;
+}
+
+// Computes the sample_bits that minimizes the expected number of bytes of
+// randomness that will be consumed when drawing a uniform sample from
+// [0, modulus) using our rejection sampling algorithm.
+//
+int compute_best_sample_bits(uint64_t modulus) {
+ int min_sample_bits = static_cast<int>(absl::bit_width(modulus - 1ULL));
+ int max_sample_bits = std::min(kMaxSampleBitsExpansion,
+ min_sample_bits + kMaxSampleBitsExpansion);
+ int best_sample_bits = min_sample_bits;
+ for (int sample_bits = min_sample_bits + 1; sample_bits <= max_sample_bits;
+ sample_bits++) {
+ best_sample_bits =
+ choose_better_sample_bits(modulus, best_sample_bits, sample_bits);
+ }
+ return best_sample_bits;
+}
+
+// PrngBuffer implements the logic for generating pseudo-random masks while
+// fetching and caching buffers of psedo-random uint8_t numbers.
+// Two important factors of this implementation compared to using SecurePrng
+// directly are:
+// 1) The implementation is fully inlineable allowing the the compiler to
+// greatly optimize the resulting code.
+// 2) Checking whether a new buffer of pseudo-random bytes needs to be filled is
+// done only once per mask as opposed to doing that for every byte, which
+// optimizes the most nested loop.
+class PrngBuffer {
+ public:
+ PrngBuffer(std::unique_ptr<SecurePrng> prng, uint8_t msb_mask,
+ size_t bytes_per_output)
+ : prng_(static_cast<SecureBatchPrng*>(prng.release())),
+ msb_mask_(msb_mask),
+ bytes_per_output_(bytes_per_output),
+ buffer_(prng_->GetMaxBufferSize()),
+ buffer_end_(buffer_.data() + buffer_.size()) {
+ FCP_CHECK((prng_->GetMaxBufferSize() % bytes_per_output) == 0)
+ << "PRNG buffer size must be a multiple bytes_per_output.";
+ FillBuffer();
+ }
+
+ inline uint64_t NextMask() {
+ if (buffer_ptr_ == buffer_end_) {
+ FillBuffer();
+ }
+
+ auto output = static_cast<uint64_t>((*buffer_ptr_++) & msb_mask_);
+ for (size_t i = 1; i < bytes_per_output_; ++i) {
+ output <<= 8UL;
+ output |= static_cast<uint64_t>(*buffer_ptr_++);
+ }
+ return output;
+ }
+
+ private:
+ inline int buffer_size() { return static_cast<int>(buffer_.size()); }
+
+ inline void FillBuffer() {
+ buffer_ptr_ = buffer_.data();
+ FCP_CHECK(prng_->RandBuffer(buffer_.data(), buffer_size()) ==
+ buffer_size());
+ }
+
+ std::unique_ptr<SecureBatchPrng> prng_;
+ const uint8_t msb_mask_;
+ const size_t bytes_per_output_;
+ std::vector<uint8_t> buffer_;
+ const uint8_t* buffer_ptr_ = nullptr;
+ const uint8_t* const buffer_end_;
+};
+
+struct AddModAdapter {
+ inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
+ return AddMod(a, b, z);
+ }
+ inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
+ return SubtractMod(a, b, z);
+ }
+};
+
+struct AddModOptAdapter {
+ inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
+ return AddModOpt(a, b, z);
+ }
+ inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
+ return SubtractModOpt(a, b, z);
+ }
+};
+
+// Templated implementation of MapOfMasks that allows substituting
+// AddMod and SubtractMod implementations.
+template <typename TAdapter, typename TVector, typename TVectorMap>
+inline std::unique_ptr<TVectorMap> MapOfMasksImpl(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort) {
+ FCP_CHECK(prng_factory.SupportsBatchMode());
+
+ auto map_of_masks = std::make_unique<TVectorMap>();
+ std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
+ EVP_MD_CTX_destroy);
+ FCP_CHECK(mdctx.get());
+ for (const InputVectorSpecification& vector_spec : input_vector_specs) {
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ int bit_width =
+ static_cast<int>(absl::bit_width(vector_spec.modulus() - 1ULL));
+ std::string prng_input =
+ absl::StrCat(session_id.data, IntToByteString(bit_width),
+ IntToByteString(vector_spec.length()), vector_spec.name());
+ std::vector<uint64_t> mask_vector_buffer(vector_spec.length(), 0);
+
+ bool modulus_is_power_of_two = (1ULL << bit_width == vector_spec.modulus());
+ if (modulus_is_power_of_two) {
+ // Because the modulus is a power of two, we can sample uniformly
+ // simply by drawing the correct number of random bits.
+ int bytes_per_output = DivideRoundUp(bit_width, 8);
+ // msb = "most significant byte"
+ size_t bits_in_msb = bit_width - ((bytes_per_output - 1) * 8);
+ uint8_t msb_mask = (1UL << bits_in_msb) - 1;
+
+ for (const auto& prng_key : prng_keys_to_add) {
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ AesKey digest_key =
+ DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
+ PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
+ bytes_per_output);
+ for (auto& v : mask_vector_buffer) {
+ v = TAdapter::AddModImpl(v, prng.NextMask(), vector_spec.modulus());
+ }
+ }
+
+ for (const auto& prng_key : prng_keys_to_subtract) {
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ AesKey digest_key =
+ DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
+ PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
+ bytes_per_output);
+ for (auto& v : mask_vector_buffer) {
+ v = TAdapter::SubtractModImpl(v, prng.NextMask(),
+ vector_spec.modulus());
+ }
+ }
+ } else {
+ // Rejection Sampling algorithm for arbitrary moduli.
+ // Follows Algorithm 3 from:
+ // "Fast Random Integer Generation in an Interval," Daniel Lemire, 2018.
+ // https://arxiv.org/pdf/1805.10941.pdf.
+ //
+ // The inner loops are structured to avoid conditional branches
+ // and the associated branch misprediction errors they would entail.
+ //
+ // We choose sample_bits to minimize the expected number of bytes
+ // drawn from the PRNG.
+
+ int sample_bits = compute_best_sample_bits(vector_spec.modulus());
+ int bytes_per_output = DivideRoundUp(sample_bits, 8);
+ // msb = "most significant byte"
+ size_t bits_in_msb = sample_bits - ((bytes_per_output - 1) * 8);
+ uint8_t msb_mask = (1UL << bits_in_msb) - 1;
+
+ uint64_t sample_modulus = 1ULL << sample_bits;
+ uint64_t rejection_threshold =
+ (sample_modulus - vector_spec.modulus()) % vector_spec.modulus();
+
+ for (const auto& prng_key : prng_keys_to_add) {
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ AesKey digest_key =
+ DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
+ PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
+ bytes_per_output);
+ int i = 0;
+ while (i < vector_spec.length()) {
+ auto& v = mask_vector_buffer[i];
+ auto mask = prng.NextMask();
+ auto reject = mask < rejection_threshold;
+ auto inc = reject ? 0 : 1;
+ mask = reject ? 0 : mask;
+ v = TAdapter::AddModImpl(v, mask % vector_spec.modulus(),
+ vector_spec.modulus());
+ i += inc;
+ }
+ }
+
+ for (const auto& prng_key : prng_keys_to_subtract) {
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ AesKey digest_key =
+ DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
+ PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
+ bytes_per_output);
+ int i = 0;
+ while (i < vector_spec.length()) {
+ auto& v = mask_vector_buffer[i];
+ auto mask = prng.NextMask();
+ auto reject = mask < rejection_threshold;
+ auto inc = reject ? 0 : 1;
+ mask = reject ? 0 : mask;
+ v = TAdapter::SubtractModImpl(v, mask % vector_spec.modulus(),
+ vector_spec.modulus());
+ i += inc;
+ }
+ }
+ }
+
+ if (async_abort && async_abort->Signalled()) return nullptr;
+ map_of_masks->emplace(vector_spec.name(),
+ TVector(mask_vector_buffer, vector_spec.modulus()));
+ }
+ return map_of_masks;
+}
+
+std::unique_ptr<SecAggVectorMap> MapOfMasks(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort) {
+ return MapOfMasksImpl<AddModAdapter, SecAggVector, SecAggVectorMap>(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
+ prng_factory, async_abort);
+}
+
+std::unique_ptr<SecAggVectorMap> MapOfMasksV3(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort) {
+ return MapOfMasksImpl<AddModOptAdapter, SecAggVector, SecAggVectorMap>(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
+ prng_factory, async_abort);
+}
+
+SecAggVector AddVectors(const SecAggVector& a, const SecAggVector& b) {
+ FCP_CHECK(a.modulus() == b.modulus() && a.num_elements() == b.num_elements());
+ uint64_t modulus = a.modulus();
+ SecAggVector::Decoder decoder_a(a);
+ SecAggVector::Decoder decoder_b(b);
+ SecAggVector::Coder sum_coder(modulus, static_cast<int>(a.bit_width()),
+ a.num_elements());
+ for (int remaining_elements = static_cast<int>(a.num_elements());
+ remaining_elements > 0; --remaining_elements) {
+ sum_coder.WriteValue((decoder_a.ReadValue() + decoder_b.ReadValue()) %
+ modulus);
+ }
+ return std::move(sum_coder).Create();
+}
+
+std::unique_ptr<SecAggVectorMap> AddMaps(const SecAggVectorMap& a,
+ const SecAggVectorMap& b) {
+ auto result = std::make_unique<SecAggVectorMap>();
+ for (const auto& item : a) {
+ result->emplace(item.first, AddVectors(item.second, b.at(item.first)));
+ }
+ return result;
+}
+
+std::unique_ptr<SecAggUnpackedVectorMap> UnpackedMapOfMasks(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort) {
+ return MapOfMasksImpl<AddModOptAdapter, SecAggUnpackedVector,
+ SecAggUnpackedVectorMap>(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
+ prng_factory, async_abort);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/map_of_masks.h b/fcp/secagg/shared/map_of_masks.h
new file mode 100644
index 0000000..1b46643
--- /dev/null
+++ b/fcp/secagg/shared/map_of_masks.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_SECAGG_SHARED_MAP_OF_MASKS_H_
+#define FCP_SECAGG_SHARED_MAP_OF_MASKS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "fcp/secagg/shared/aes_key.h"
+#include "fcp/secagg/shared/aes_prng_factory.h"
+#include "fcp/secagg/shared/async_abort.h"
+#include "fcp/secagg/shared/compute_session_id.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+// This file contains two unbound functions for generating and adding maps of
+// mask vectors.
+
+namespace fcp {
+namespace secagg {
+
+// Generates and returns a map of masks for all the vectors that need to be
+// masked, given all the keys that need to be used to mask (or unmask) those
+// vectors.
+//
+// prng_factory is an instance of a subclass of AesPrngFactory.
+// For clients communicating with the (C++) version of SecAggServer in this
+// package, or the SecAggServer itself, this must be an instance of
+// AesCtrPrngFactory.
+//
+// Returns a nullptr value if the operation was aborted, as detected via the
+// optional async_abort parameter.
+std::unique_ptr<SecAggVectorMap> MapOfMasks(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+// Optimized version of MapOfMasks that uses optimized AddModOpt and
+// SubtractModOpt modulus operations.
+std::unique_ptr<SecAggVectorMap> MapOfMasksV3(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+// Optimized version of MapOfMasks that uses optimized AddMapOpt and
+// SubtractMapOpt modululs operations and produces map of unpacked vectors.
+std::unique_ptr<SecAggUnpackedVectorMap> UnpackedMapOfMasks(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory,
+ AsyncAbort* async_abort = nullptr);
+
+// Adds two vectors together and returns a new sum vector.
+SecAggVector AddVectors(const SecAggVector& a, const SecAggVector& b);
+
+// Takes two maps of masks/masked vectors, and adds them together, returning the
+// sum.
+std::unique_ptr<SecAggVectorMap> AddMaps(const SecAggVectorMap& a,
+ const SecAggVectorMap& b);
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_MAP_OF_MASKS_H_
diff --git a/fcp/secagg/shared/map_of_masks_bench.cc b/fcp/secagg/shared/map_of_masks_bench.cc
new file mode 100644
index 0000000..3c3ebee
--- /dev/null
+++ b/fcp/secagg/shared/map_of_masks_bench.cc
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/numeric/bits.h"
+#include "absl/strings/str_cat.h"
+#include "benchmark//benchmark.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/map_of_masks.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+constexpr auto kVectorSize = 1024 * 1024;
+constexpr auto kNumKeys = 128;
+
+inline void BM_MapOfMasks_Impl(benchmark::State& state, uint64_t modulus) {
+ state.PauseTiming();
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.reserve(kNumKeys);
+ for (int i = 0; i < kNumKeys; i++) {
+ prng_keys_to_add.emplace_back(key);
+ }
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.emplace_back("unused", kVectorSize, modulus);
+
+ state.ResumeTiming();
+ benchmark::DoNotOptimize(MapOfMasks(
+ prng_keys_to_add, prng_keys_to_subtract, vector_specs, session_id,
+ static_cast<const AesPrngFactory&>(AesCtrPrngFactory())));
+
+ state.SetItemsProcessed(kVectorSize);
+}
+
+inline void BM_MapOfMasksV3_Impl(benchmark::State& state, uint64_t modulus) {
+ state.PauseTiming();
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.reserve(kNumKeys);
+ for (int i = 0; i < kNumKeys; i++) {
+ prng_keys_to_add.emplace_back(key);
+ }
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.emplace_back("unused", kVectorSize, modulus);
+
+ state.ResumeTiming();
+ benchmark::DoNotOptimize(MapOfMasksV3(
+ prng_keys_to_add, prng_keys_to_subtract, vector_specs, session_id,
+ static_cast<const AesPrngFactory&>(AesCtrPrngFactory())));
+
+ state.SetItemsProcessed(kVectorSize);
+}
+
+void BM_MapOfMasks_PowerOfTwo(benchmark::State& state) {
+ for (auto s : state) {
+ int bitwidth = static_cast<int>(state.range(0));
+ BM_MapOfMasks_Impl(state, 1ULL << bitwidth);
+ }
+}
+
+void BM_MapOfMasks_Arbitrary(benchmark::State& state) {
+ for (auto s : state) {
+ uint64_t modulus = static_cast<uint64_t>(state.range(0));
+ BM_MapOfMasks_Impl(state, modulus);
+ }
+}
+
+void BM_MapOfMasksV3_PowerOfTwo(benchmark::State& state) {
+ for (auto s : state) {
+ int bitwidth = static_cast<int>(state.range(0));
+ BM_MapOfMasksV3_Impl(state, 1ULL << bitwidth);
+ }
+}
+
+void BM_MapOfMasksV3_Arbitrary(benchmark::State& state) {
+ for (auto s : state) {
+ uint64_t modulus = static_cast<uint64_t>(state.range(0));
+ BM_MapOfMasksV3_Impl(state, modulus);
+ }
+}
+
+BENCHMARK(BM_MapOfMasks_PowerOfTwo)
+ ->Arg(9)
+ ->Arg(25)
+ ->Arg(41)
+ ->Arg(53)
+ ->Arg(absl::bit_width(SecAggVector::kMaxModulus - 1));
+
+BENCHMARK(BM_MapOfMasks_Arbitrary)
+ ->Arg(5)
+ ->Arg(39)
+ ->Arg(485)
+ ->Arg(2400)
+ ->Arg(14901)
+ ->Arg(51813)
+ ->Arg(532021)
+ ->Arg(13916946)
+ ->Arg(39549497)
+ ->Arg(548811945)
+ ->Arg(590549014)
+ ->Arg(48296031686)
+ ->Arg(156712951284)
+ ->Arg(2636861836189)
+ ->Arg(14673852658160)
+ ->Arg(92971495438615)
+ ->Arg(304436005557271)
+ ->Arg(14046234330484262)
+ ->Arg(38067457113486645)
+ ->Arg(175631339105057682);
+
+BENCHMARK(BM_MapOfMasksV3_PowerOfTwo)
+ ->Arg(9)
+ ->Arg(25)
+ ->Arg(41)
+ ->Arg(53)
+ ->Arg(absl::bit_width(SecAggVector::kMaxModulus - 1));
+
+BENCHMARK(BM_MapOfMasksV3_Arbitrary)
+ ->Arg(5)
+ ->Arg(39)
+ ->Arg(485)
+ ->Arg(2400)
+ ->Arg(14901)
+ ->Arg(51813)
+ ->Arg(532021)
+ ->Arg(13916946)
+ ->Arg(39549497)
+ ->Arg(548811945)
+ ->Arg(590549014)
+ ->Arg(48296031686)
+ ->Arg(156712951284)
+ ->Arg(2636861836189)
+ ->Arg(14673852658160)
+ ->Arg(92971495438615)
+ ->Arg(304436005557271)
+ ->Arg(14046234330484262)
+ ->Arg(38067457113486645)
+ ->Arg(175631339105057682);
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/map_of_masks_test.cc b/fcp/secagg/shared/map_of_masks_test.cc
new file mode 100644
index 0000000..e62eec3
--- /dev/null
+++ b/fcp/secagg/shared/map_of_masks_test.cc
@@ -0,0 +1,553 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/map_of_masks.h"
+
+#include <array>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/numeric/bits.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
+#include "fcp/secagg/shared/input_vector_specification.h"
+#include "fcp/secagg/shared/math.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+using ::testing::Lt;
+using ::testing::Ne;
+
+const std::array<uint64_t, 20> kArbitraryModuli{5,
+ 39,
+ 485,
+ 2400,
+ 14901,
+ 51813,
+ 532021,
+ 13916946,
+ 39549497,
+ 548811945,
+ 590549014,
+ 48296031686,
+ 156712951284,
+ 2636861836189,
+ 14673852658160,
+ 92971495438615,
+ 304436005557271,
+ 14046234330484262,
+ 38067457113486645,
+ 175631339105057682};
+
+TEST(AddMapsTest, AddMapsGetsRightSum_PowerOfTwo) {
+ std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
+ std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
+ SecAggVectorMap map_a;
+ map_a.emplace("test", SecAggVector(vec_a, 256));
+ SecAggVectorMap map_b;
+ map_b.emplace("test", SecAggVector(vec_b, 256));
+
+ auto map_sum = AddMaps(map_a, map_b);
+ std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
+ for (int i = 0; i < vec_a.size(); ++i) {
+ EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % 256));
+ }
+}
+
+TEST(AddMapsTest, AddMapsGetsRightSum_AribraryModuli) {
+ std::vector<uint64_t> vec_a{25, 50, 75, 100, 150};
+ std::vector<uint64_t> vec_b{50, 100, 150, 200, 250};
+ SecAggVectorMap map_a;
+ map_a.emplace("test", SecAggVector(vec_a, 255));
+ SecAggVectorMap map_b;
+ map_b.emplace("test", SecAggVector(vec_b, 255));
+
+ auto map_sum = AddMaps(map_a, map_b);
+ std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
+ for (int i = 0; i < vec_a.size(); ++i) {
+ EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % 255));
+ }
+}
+
+TEST(AddMapsTest, AddMapsExhaustiveTest_PowerOfTwo) {
+ // Make SecurePrng instance to be used as a consistent pseudo-random number
+ // generator.
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+ AesCtrPrngFactory prng_factory;
+ std::unique_ptr<SecurePrng> prng = prng_factory.MakePrng(seed);
+
+ // Iterate through all possible bitwidths, add two random vectors, and
+ // verify the results.
+ for (int number_of_bits = 1;
+ number_of_bits <= absl::bit_width(SecAggVector::kMaxModulus - 1);
+ ++number_of_bits) {
+ uint64_t modulus = 1ULL << number_of_bits;
+ constexpr size_t kSize = 1000;
+ std::vector<uint64_t> vec_a(kSize);
+ std::vector<uint64_t> vec_b(kSize);
+ for (size_t i = 0; i < kSize; i++) {
+ vec_a[i] = prng->Rand64() % modulus;
+ vec_b[i] = prng->Rand64() % modulus;
+ }
+
+ SecAggVectorMap map_a;
+ map_a.emplace("test", SecAggVector(vec_a, modulus));
+ SecAggVectorMap map_b;
+ map_b.emplace("test", SecAggVector(vec_b, modulus));
+
+ auto map_sum = AddMaps(map_a, map_b);
+ std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
+ for (size_t i = 0; i < kSize; i++) {
+ EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % modulus));
+ }
+ }
+}
+
+TEST(AddMapsTest, AddMapsExhaustiveTest_ArbitraryModuli) {
+ // Make SecurePrng instance to be used as a consistent pseudo-random number
+ // generator.
+ uint8_t seed_data[32];
+ memset(seed_data, '1', 32);
+ AesKey seed(seed_data);
+ AesCtrPrngFactory prng_factory;
+ std::unique_ptr<SecurePrng> prng = prng_factory.MakePrng(seed);
+
+ // Iterate through all possible bitwidths, add two random vectors, and
+ // verify the results.
+ for (uint64_t modulus : kArbitraryModuli) {
+ constexpr size_t kSize = 1000;
+ std::vector<uint64_t> vec_a(kSize);
+ std::vector<uint64_t> vec_b(kSize);
+ for (size_t i = 0; i < kSize; i++) {
+ vec_a[i] = prng->Rand64() % modulus;
+ vec_b[i] = prng->Rand64() % modulus;
+ }
+
+ SecAggVectorMap map_a;
+ map_a.emplace("test", SecAggVector(vec_a, modulus));
+ SecAggVectorMap map_b;
+ map_b.emplace("test", SecAggVector(vec_b, modulus));
+
+ auto map_sum = AddMaps(map_a, map_b);
+ std::vector<uint64_t> vec_sum = map_sum->at("test").GetAsUint64Vector();
+ for (size_t i = 0; i < kSize; i++) {
+ EXPECT_THAT(vec_sum[i], Eq((vec_a[i] + vec_b[i]) % modulus));
+ }
+ }
+}
+
+enum MapOfMasksVersion { CURRENT, V3, UNPACKED };
+
+class MapOfMasksTest : public ::testing::TestWithParam<MapOfMasksVersion> {
+ public:
+ using Uint64VectorMap =
+ absl::node_hash_map<std::string, std::vector<uint64_t>>;
+
+ std::unique_ptr<Uint64VectorMap> MapOfMasks(
+ const std::vector<AesKey>& prng_keys_to_add,
+ const std::vector<AesKey>& prng_keys_to_subtract,
+ const std::vector<InputVectorSpecification>& input_vector_specs,
+ const SessionId& session_id, const AesPrngFactory& prng_factory) {
+ if (GetParam() == MapOfMasksVersion::UNPACKED) {
+ return ToUint64VectorMap(fcp::secagg::UnpackedMapOfMasks(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, prng_factory));
+ } else if (GetParam() == MapOfMasksVersion::V3) {
+ return ToUint64VectorMap(fcp::secagg::MapOfMasksV3(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, prng_factory));
+ } else {
+ return ToUint64VectorMap(fcp::secagg::MapOfMasks(
+ prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
+ session_id, prng_factory));
+ }
+ }
+
+ private:
+ std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
+ std::unique_ptr<SecAggVectorMap> map) {
+ auto result = std::make_unique<Uint64VectorMap>();
+ for (auto& [name, vec] : *map) {
+ result->emplace(name, vec.GetAsUint64Vector());
+ }
+ return result;
+ }
+
+ std::unique_ptr<Uint64VectorMap> ToUint64VectorMap(
+ std::unique_ptr<SecAggUnpackedVectorMap> map) {
+ auto result = std::make_unique<Uint64VectorMap>();
+ for (auto& [name, vec] : *map) {
+ result->emplace(name, std::move(vec));
+ }
+ return result;
+ }
+};
+
+// AES MapOfMasks: Power-of-two Moduli
+
+TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks->size(), Eq(1));
+ std::vector<uint64_t> zeroes(10, 0);
+ EXPECT_THAT(masks->at("test"), Eq(std::vector<uint64_t>(10, 0)));
+}
+
+TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks->size(), Eq(1));
+ EXPECT_THAT(masks->at("test"), Ne(std::vector<uint64_t>(10, 0)));
+}
+
+TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t
+ key[AesKey::kSize]; // This key is reusable because AesKey makes a copy
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ auto masks2 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ EXPECT_THAT(masks2->at("test"), Ne(masks1->at("test")));
+}
+
+TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ std::vector<uint64_t> mask_vector1 = masks1->at("test");
+ std::vector<uint64_t> mask_vector2 = masks2->at("test");
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], 1ULL << 20), Eq(0));
+ }
+}
+
+TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ memset(key, 'B', AesKey::kSize);
+ std::vector<AesKey> prng_keys_to_subtract;
+ prng_keys_to_subtract.push_back(AesKey(key));
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, 1ULL << 20));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ std::vector<uint64_t> mask_vector1 = masks1->at("test");
+ std::vector<uint64_t> mask_vector2 = masks2->at("test");
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], 1ULL << 20), Eq(0));
+ }
+}
+
+TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_PowerOfTwo) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+
+ // Check a variety of bit_widths
+ std::vector<uint64_t> moduli{1ULL << 1, 1ULL << 4, 1ULL << 20, 1ULL << 24,
+ SecAggVector::kMaxModulus};
+ for (uint64_t i : moduli) {
+ vector_specs.push_back(
+ InputVectorSpecification(absl::StrCat("test", i), 50, i));
+ }
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ // Make sure all elements are less than the bound, and also at least one of
+ // them has the highest-allowed bit set.
+ for (uint64_t modulus : moduli) {
+ auto vec = masks->at(absl::StrCat("test", modulus));
+ bool high_order_bit_set = false;
+ for (uint64_t mask : vec) {
+ EXPECT_THAT(mask, Lt(modulus));
+ if (mask >= (modulus >> 1)) {
+ high_order_bit_set = true;
+ }
+ }
+ EXPECT_THAT(high_order_bit_set, Eq(true));
+ }
+}
+
+// AES MapOfMasks: Arbitrary Moduli
+
+TEST_P(MapOfMasksTest, ReturnsZeroIfNoKeysSpecified_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ std::vector<AesKey> prng_keys_to_add;
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks->size(), Eq(1));
+ std::vector<uint64_t> zeroes(10, 0);
+ EXPECT_THAT(masks->at("test"), Eq(std::vector<uint64_t>(10, 0)));
+}
+
+TEST_P(MapOfMasksTest, ReturnsNonZeroIfOneKeySpecified_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks->size(), Eq(1));
+ EXPECT_THAT(masks->at("test"), Ne(std::vector<uint64_t>(10, 0)));
+}
+
+TEST_P(MapOfMasksTest, MapWithOneKeyDiffersFromMapWithTwoKeys_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t
+ key[AesKey::kSize]; // This key is reusable because AesKey makes a copy
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ auto masks2 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ EXPECT_THAT(masks2->at("test"), Ne(masks1->at("test")));
+}
+
+TEST_P(MapOfMasksTest, MapsAreDeterministic_KeysToAdd_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ uint8_t key[AesKey::kSize];
+ // prng_keys_to_add includes A
+ std::vector<AesKey> prng_keys_to_add;
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+
+ // prng_keys_to_subtract includes B
+ std::vector<AesKey> prng_keys_to_subtract;
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_subtract.push_back(AesKey(key));
+
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ auto masks2 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ std::vector<uint64_t> mask_vector1 = masks1->at("test");
+ std::vector<uint64_t> mask_vector2 = masks2->at("test");
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_THAT(mask_vector1[i], Eq(mask_vector2[i]));
+ }
+}
+
+TEST_P(MapOfMasksTest, MapsWithOppositeMasksCancel_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ uint8_t key[AesKey::kSize];
+ // prng_keys_to_add includes A & B
+ std::vector<AesKey> prng_keys_to_add;
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ // prng_keys_to_subtract is empty
+ std::vector<AesKey> prng_keys_to_subtract;
+
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ std::vector<uint64_t> mask_vector1 = masks1->at("test");
+ std::vector<uint64_t> mask_vector2 = masks2->at("test");
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], modulus), Eq(0));
+ }
+}
+
+TEST_P(MapOfMasksTest, MapsWithMixedOppositeMasksCancel_ArbitraryModuli) {
+ uint64_t modulus = 2636861836189;
+ uint8_t key[AesKey::kSize];
+ // prng_keys_to_add includes A
+ std::vector<AesKey> prng_keys_to_add;
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ // prng_keys_to_subtract includes B
+ std::vector<AesKey> prng_keys_to_subtract;
+ memset(key, 'B', AesKey::kSize);
+ prng_keys_to_subtract.push_back(AesKey(key));
+
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+ vector_specs.push_back(InputVectorSpecification("test", 10, modulus));
+
+ auto masks1 = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ auto masks2 = MapOfMasks(prng_keys_to_subtract, prng_keys_to_add,
+ vector_specs, session_id, AesCtrPrngFactory());
+
+ EXPECT_THAT(masks1->size(), Eq(1));
+ EXPECT_THAT(masks2->size(), Eq(1));
+ std::vector<uint64_t> mask_vector1 = masks1->at("test");
+ std::vector<uint64_t> mask_vector2 = masks2->at("test");
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_THAT(AddMod(mask_vector1[i], mask_vector2[i], modulus), Eq(0));
+ }
+}
+
+TEST_P(MapOfMasksTest, PrngMaskGeneratesCorrectBitwidthMasks_ArbitraryModuli) {
+ std::vector<AesKey> prng_keys_to_add;
+ uint8_t key[AesKey::kSize];
+ memset(key, 'A', AesKey::kSize);
+ prng_keys_to_add.push_back(AesKey(key));
+ std::vector<AesKey> prng_keys_to_subtract;
+ SessionId session_id = {std::string(32, 'Z')};
+ std::vector<InputVectorSpecification> vector_specs;
+
+ // Check a variety of bit_widths
+ for (uint64_t i : kArbitraryModuli) {
+ vector_specs.push_back(
+ InputVectorSpecification(absl::StrCat("test", i), 50, i));
+ }
+
+ auto masks = MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, vector_specs,
+ session_id, AesCtrPrngFactory());
+
+ // Make sure all elements are less than the bound, and also at least one of
+ // them has the highest-allowed bit set.
+ for (uint64_t modulus : kArbitraryModuli) {
+ auto vec = masks->at(absl::StrCat("test", modulus));
+ bool high_order_bit_set = false;
+ for (uint64_t mask : vec) {
+ EXPECT_THAT(mask, Lt(modulus));
+ if (mask >= (modulus >> 1)) {
+ high_order_bit_set = true;
+ }
+ }
+ EXPECT_THAT(high_order_bit_set, Eq(true));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(MapOfMasksTest, MapOfMasksTest,
+ ::testing::Values<MapOfMasksVersion>(CURRENT, V3,
+ UNPACKED));
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/math.h b/fcp/secagg/shared/math.h
new file mode 100644
index 0000000..02fe14c
--- /dev/null
+++ b/fcp/secagg/shared/math.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This class contains some simple inline math methods commonly used elsewhere
+// within SecAgg. No error checking or bounds checking is performed. The calling
+// code is responsible for making sure the operations do not overflow, except as
+// noted.
+
+#ifndef FCP_SECAGG_SHARED_MATH_H_
+#define FCP_SECAGG_SHARED_MATH_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/base/internal/endian.h"
+#include "absl/numeric/int128.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+namespace secagg {
+
+// Integer division rounded up.
+static inline uint32_t DivideRoundUp(uint32_t a, uint32_t b) {
+ return (a + b - 1) / b;
+}
+
+// Addition modulo non-zero integer z.
+static inline uint64_t AddMod(uint64_t a, uint64_t b, uint64_t z) {
+ return (a + b) % z;
+}
+
+// Optimized version of AddMod that assumes that a and b are smaller than mod.
+// This version produces a code with branchless CMOVB instruction and is at
+// least 2x faster than AddMod on x64.
+// TODO(team): Eventually this should replace AddMod.
+inline uint64_t AddModOpt(uint64_t a, uint64_t b, uint64_t mod) {
+#ifndef NDEBUG
+ // Verify assumption that a and b are smaller than mod to start with.
+ FCP_CHECK(a < mod && b < mod);
+ // Make sure there is no overflow when adding a and b.
+ FCP_CHECK(a <= (a + b) && b <= (a + b));
+#endif
+ uint64_t sum = a + b;
+ return sum < mod ? sum : sum - mod;
+}
+
+// Subtraction modulo non-zero integer z. Handles underflow correctly if b > a.
+static inline uint64_t SubtractMod(uint64_t a, uint64_t b, uint64_t z) {
+ return ((a - b) + z) % z;
+}
+
+// Optimized version of SubtractMod that assumes that a and b are smaller than
+// mod. This version produces a code with branchless CMOVB instruction and is
+// at least 2x faster than SubtractMod on x64.
+// TODO(team): Eventually this should replace SubtractMod.
+inline uint64_t SubtractModOpt(uint64_t a, uint64_t b, uint64_t mod) {
+#ifndef NDEBUG
+ // Verify assumption that a and b are smaller than mod to start with.
+ FCP_CHECK(a < mod && b < mod);
+#endif
+ return a >= b ? a - b : mod - b + a;
+}
+
+// Multiplication of 32-bit integers modulo a non-zero integer z.
+// Guarantees the output is a 32-bit integer and avoids overflow by casting both
+// factors to uint64_t first.
+static inline uint32_t MultiplyMod(uint32_t a, uint32_t b, uint64_t z) {
+ return static_cast<uint32_t>((uint64_t{a} * uint64_t{b}) % z);
+}
+
+// Multiplication of 64-bit integers modulo a non-zero integer z.
+// Guarantees the output is a 64-bit integer and avoids overflow by casting both
+// factors to uint128 first.
+static inline uint64_t MultiplyMod64(uint64_t a, uint64_t b, uint64_t z) {
+ return absl::Uint128Low64((absl::uint128(a) * absl::uint128(b)) %
+ absl::uint128(z));
+}
+
+// Modular inverse of a 64-bit integer modulo a prime z via Fermat's little
+// theorem. Assumes that z is prime.
+static inline uint64_t InverseModPrime(uint64_t a, uint64_t z) {
+ uint64_t inverse = 1;
+ uint64_t exponent = z - 2;
+
+ while (exponent > 0) {
+ if (exponent & 1) {
+ inverse = MultiplyMod64(inverse, a, z);
+ }
+
+ exponent >>= 1;
+ a = MultiplyMod64(a, a, z);
+ }
+
+ return inverse;
+}
+
+// Converts ints to big-endian byte string representation. Provides platform-
+// independence only in converting known integer values to byte strings for use
+// in cryptographic methods, not for general processing of binary data.
+static inline std::string IntToByteString(uint32_t input) {
+ char bytes[4];
+ absl::big_endian::Store32(bytes, input);
+ return std::string(bytes, 4);
+}
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_MATH_H_
diff --git a/fcp/secagg/shared/math_test.cc b/fcp/secagg/shared/math_test.cc
new file mode 100644
index 0000000..82d008d
--- /dev/null
+++ b/fcp/secagg/shared/math_test.cc
@@ -0,0 +1,205 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/math.h"
+
+#include <cstdint>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::Eq;
+
+TEST(MathTest, DivideRoundUpIsAccurate) {
+ EXPECT_THAT(DivideRoundUp(0, 8), Eq(0));
+ EXPECT_THAT(DivideRoundUp(1, 8), Eq(1));
+ EXPECT_THAT(DivideRoundUp(8, 8), Eq(1));
+ EXPECT_THAT(DivideRoundUp(12, 8), Eq(2));
+ EXPECT_THAT(DivideRoundUp(31, 8), Eq(4));
+ EXPECT_THAT(DivideRoundUp(32, 8), Eq(4));
+ EXPECT_THAT(DivideRoundUp(33, 8), Eq(5));
+}
+
+TEST(MathTest, AddModIsAccurate) {
+ // power-of-2 moduli
+ EXPECT_THAT(AddMod(2, 5, 8), Eq(7));
+ EXPECT_THAT(AddMod(4, 5, 8), Eq(1));
+ EXPECT_THAT(AddMod(0, 5, 8), Eq(5));
+ EXPECT_THAT(AddMod(5, 0, 8), Eq(5));
+ EXPECT_THAT(AddMod(7, 7, 8), Eq(6));
+ EXPECT_THAT(AddMod(9223372036854775806ULL, 9223372036854775807ULL,
+ 9223372036854775808ULL),
+ Eq(9223372036854775805ULL));
+
+ // non-power-of-2 moduli
+ EXPECT_THAT(AddMod(2, 5, 7), Eq(0));
+ EXPECT_THAT(AddMod(4, 5, 7), Eq(2));
+ EXPECT_THAT(AddMod(0, 5, 7), Eq(5));
+ EXPECT_THAT(AddMod(5, 0, 7), Eq(5));
+ EXPECT_THAT(AddMod(7, 7, 7), Eq(0));
+ EXPECT_THAT(AddMod(9223372036854775805ULL, 9223372036854775806ULL,
+ 9223372036854775807ULL),
+ Eq(9223372036854775804ULL));
+}
+
+TEST(MathTest, AddModOptIsAccurate) {
+ // power-of-2 moduli
+ EXPECT_THAT(AddModOpt(2, 5, 8), Eq(7));
+ EXPECT_THAT(AddModOpt(4, 5, 8), Eq(1));
+ EXPECT_THAT(AddModOpt(0, 5, 8), Eq(5));
+ EXPECT_THAT(AddModOpt(5, 0, 8), Eq(5));
+ EXPECT_THAT(AddModOpt(7, 7, 8), Eq(6));
+ EXPECT_THAT(AddModOpt(9223372036854775806ULL, 9223372036854775807ULL,
+ 9223372036854775808ULL),
+ Eq(9223372036854775805ULL));
+
+ // non-power-of-2 moduli
+ EXPECT_THAT(AddModOpt(2, 5, 7), Eq(0));
+ EXPECT_THAT(AddModOpt(4, 5, 7), Eq(2));
+ EXPECT_THAT(AddModOpt(0, 5, 7), Eq(5));
+ EXPECT_THAT(AddModOpt(5, 0, 7), Eq(5));
+ EXPECT_THAT(AddModOpt(6, 6, 7), Eq(5));
+ EXPECT_THAT(AddModOpt(9223372036854775805ULL, 9223372036854775806ULL,
+ 9223372036854775807ULL),
+ Eq(9223372036854775804ULL));
+}
+
+TEST(MathTest, SubtractModWorksAndHandlesUnderflow) {
+ EXPECT_THAT(SubtractMod(3, 4, 10), Eq(9));
+ EXPECT_THAT(SubtractMod(2, 9, 10), Eq(3));
+ EXPECT_THAT(SubtractMod(0, 6, 10), Eq(4));
+ EXPECT_THAT(SubtractMod(0, 5, 10), Eq(5));
+ EXPECT_THAT(SubtractMod(7, 3, 10), Eq(4));
+ EXPECT_THAT(SubtractMod(9, 0, 10), Eq(9));
+ EXPECT_THAT(SubtractMod(0, 0, 10), Eq(0));
+ EXPECT_THAT(SubtractMod(7, 7, 10), Eq(0));
+ EXPECT_THAT(SubtractMod(9223372036854775807ULL, 0, 9223372036854775808ULL),
+ Eq(9223372036854775807ULL));
+ EXPECT_THAT(SubtractMod(0, 9223372036854775807ULL, 9223372036854775808ULL),
+ Eq(1));
+ EXPECT_THAT(SubtractMod(9223372036854775805ULL, 9223372036854775807ULL,
+ 9223372036854775808ULL),
+ Eq(9223372036854775806ULL));
+
+ EXPECT_THAT(SubtractMod(9223372036854775806ULL, 0, 9223372036854775807ULL),
+ Eq(9223372036854775806ULL));
+ EXPECT_THAT(SubtractMod(0, 9223372036854775806ULL, 9223372036854775807ULL),
+ Eq(1));
+ EXPECT_THAT(SubtractMod(9223372036854775805ULL, 9223372036854775806ULL,
+ 9223372036854775807ULL),
+ Eq(9223372036854775806ULL));
+}
+
+TEST(MathTest, SubtractModOptWorksAndHandlesUnderflow) {
+ EXPECT_THAT(SubtractModOpt(3, 4, 10), Eq(9));
+ EXPECT_THAT(SubtractModOpt(2, 9, 10), Eq(3));
+ EXPECT_THAT(SubtractModOpt(0, 6, 10), Eq(4));
+ EXPECT_THAT(SubtractModOpt(0, 5, 10), Eq(5));
+ EXPECT_THAT(SubtractModOpt(7, 3, 10), Eq(4));
+ EXPECT_THAT(SubtractModOpt(9, 0, 10), Eq(9));
+ EXPECT_THAT(SubtractModOpt(0, 0, 10), Eq(0));
+ EXPECT_THAT(SubtractModOpt(7, 7, 10), Eq(0));
+ EXPECT_THAT(SubtractModOpt(9223372036854775807ULL, 0, 9223372036854775808ULL),
+ Eq(9223372036854775807ULL));
+ EXPECT_THAT(SubtractModOpt(0, 9223372036854775807ULL, 9223372036854775808ULL),
+ Eq(1));
+ EXPECT_THAT(SubtractModOpt(9223372036854775805ULL, 9223372036854775807ULL,
+ 9223372036854775808ULL),
+ Eq(9223372036854775806ULL));
+
+ EXPECT_THAT(SubtractModOpt(9223372036854775806ULL, 0, 9223372036854775807ULL),
+ Eq(9223372036854775806ULL));
+ EXPECT_THAT(SubtractModOpt(0, 9223372036854775806ULL, 9223372036854775807ULL),
+ Eq(1));
+ EXPECT_THAT(SubtractModOpt(9223372036854775805ULL, 9223372036854775806ULL,
+ 9223372036854775807ULL),
+ Eq(9223372036854775806ULL));
+}
+
+TEST(MathTest, MultiplyModAvoidsOverflow) {
+ uint64_t p = 2147483659ULL; // 2 ^ 31 + 11; a prime number
+ uint32_t a = 2147483646; // 2 ^ 31 - 2; -13 mod p
+ uint32_t b = 2147483640; // 2 ^ 31 - 8; -19 mod p
+ uint32_t res1 = 169; // -13 * -13
+ uint32_t res2 = 247; // -13 * -19
+ uint32_t res3 = 361; // -19 * -19
+ EXPECT_THAT(MultiplyMod(a, a, p), Eq(res1));
+ EXPECT_THAT(MultiplyMod(a, b, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod(b, a, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod(b, b, p), Eq(res3));
+}
+
+TEST(MathTest, MultiplyMod64AvoidsOverflow) {
+ {
+ uint64_t p = 2147483659ULL; // 2 ^ 31 + 11; a prime number
+ uint32_t a = 2147483646; // 2 ^ 31 - 2; -13 mod p
+ uint32_t b = 2147483640; // 2 ^ 31 - 8; -19 mod p
+ uint32_t res1 = 169; // -13 * -13
+ uint32_t res2 = 247; // -13 * -19
+ uint32_t res3 = 361; // -19 * -19
+ EXPECT_THAT(MultiplyMod64(a, a, p), Eq(res1));
+ EXPECT_THAT(MultiplyMod64(a, b, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod64(b, a, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod64(b, b, p), Eq(res3));
+ }
+
+ {
+ uint64_t p = 4503599627371499ULL; // 2 ^ 52 + 1003; a prime number
+ uint64_t a = 1099511627776ULL; // 2 ^ 40
+ uint64_t b = 36028797018963971ULL; // 2 ^ 55 + 3
+ uint64_t res1 = 4503330386609131ULL; // a * a
+ uint64_t res2 = 188016488351702ULL; // a * b
+ uint64_t res3 = 64336441ULL; // b * b
+ EXPECT_THAT(MultiplyMod64(a, a, p), Eq(res1));
+ EXPECT_THAT(MultiplyMod64(a, b, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod64(b, a, p), Eq(res2));
+ EXPECT_THAT(MultiplyMod64(b, b, p), Eq(res3));
+ }
+}
+TEST(MathTest, InverseModPrimeIsAccurate) {
+ // All mods assumed to be prime
+ EXPECT_THAT(InverseModPrime(12, 31), Eq(13));
+ EXPECT_THAT(InverseModPrime(13, 31), Eq(12));
+ EXPECT_THAT(InverseModPrime(13, 2147483659ULL), Eq(1651910507));
+ EXPECT_THAT(InverseModPrime(2147483646, 2147483659ULL), Eq(495573152));
+}
+
+TEST(MathTest, IntToByteStringProvidesBigEndianString) {
+ uint32_t big_low_bits = 0x01234567;
+ uint32_t big_high_bits = 0xFEDCBA98;
+ uint32_t max_val = 0xFFFFFFFF;
+ uint32_t min_val = 0x00000000;
+ uint8_t expected0[4] = {0x1, 0x23, 0x45, 0x67};
+ EXPECT_THAT(IntToByteString(big_low_bits),
+ Eq(std::string(reinterpret_cast<char*>(expected0), 4)));
+ uint8_t expected1[4] = {0xFE, 0xDC, 0xBA, 0x98};
+ EXPECT_THAT(IntToByteString(big_high_bits),
+ Eq(std::string(reinterpret_cast<char*>(expected1), 4)));
+ uint8_t expected2[4] = {0xFF, 0xFF, 0xFF, 0xFF};
+ EXPECT_THAT(IntToByteString(max_val),
+ Eq(std::string(reinterpret_cast<char*>(expected2), 4)));
+ uint8_t expected3[4] = {0x0, 0x0, 0x0, 0x0};
+ EXPECT_THAT(IntToByteString(min_val),
+ Eq(std::string(reinterpret_cast<char*>(expected3), 4)));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/prng.h b/fcp/secagg/shared/prng.h
new file mode 100644
index 0000000..7365218
--- /dev/null
+++ b/fcp/secagg/shared/prng.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_PRNG_H_
+#define FCP_SECAGG_SHARED_PRNG_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace fcp {
+namespace secagg {
+
+// An interface for a secure pseudo-random number generator.
+class SecurePrng {
+ public:
+ virtual uint8_t Rand8() = 0;
+ virtual uint64_t Rand64() = 0;
+ virtual ~SecurePrng() = default;
+};
+
+// Extension of SecurePrng interface that supports batch mode - getting multiple
+// pseudo-random numbers in a single call.
+class SecureBatchPrng : public SecurePrng {
+ public:
+ // Get the maximum size of a buffer that can be filled by RandBuffer() in a
+ // single call.
+ virtual size_t GetMaxBufferSize() const = 0;
+
+ // Fills the provided buffer with pseudorandom bytes. Returns the number of
+ // bytes that has been generated, which can be smaller than the requested
+ // buffer_size if it exceeds the maximum buffer size.
+ virtual int RandBuffer(uint8_t* buffer, int buffer_size) = 0;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_PRNG_H_
diff --git a/fcp/secagg/shared/secagg_messages.proto b/fcp/secagg/shared/secagg_messages.proto
new file mode 100644
index 0000000..aa2e0dc
--- /dev/null
+++ b/fcp/secagg/shared/secagg_messages.proto
@@ -0,0 +1,272 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Definition of the protocol buffers for the SecAgg protocol
+
+syntax = "proto3";
+
+package fcp.secagg;
+
+import "google/protobuf/any.proto";
+
+option java_package = "fcp.secagg.shared";
+option java_outer_classname = "SecAggMessages";
+
+// TODO(team): remove container class, using
+// option java_multiple_files = true;
+
+// # MESSAGE WRAPPERS EXPOSED TO THE OUTSIDE
+
+// This message is a wrapper (simulating polymorphism) for all messages sent
+// from a client to the server during the protocol.
+message ClientToServerWrapperMessage {
+ // Each message from client to server contains exactly one of these
+ // semantic types, depending on the phase of the protocol.
+ oneof message_content {
+ // Abort the protocol.
+ AbortMessage abort = 1;
+
+ // Round 0 message; see details at message definition.
+ AdvertiseKeys advertise_keys = 2;
+
+ // Round 1 message; see details at message definition.
+ ShareKeysResponse share_keys_response = 3;
+
+ // Round 2 message; see details at message definition.
+ MaskedInputCollectionResponse masked_input_response = 4;
+
+ // Round 3 message; see details at message definition.
+ UnmaskingResponse unmasking_response = 5;
+ }
+}
+
+// This message is a wrapper (simulating polymorphism) for all messages sent
+// from the server to a client during the protocol.
+message ServerToClientWrapperMessage {
+ // Each message from server to client contains exactly one of these
+ // semantic types, depending on the phase of the protocol.
+ oneof message_content {
+ // Abort the protocol.
+ AbortMessage abort = 1;
+
+ // Round 1 message; see details at message definition.
+ ShareKeysRequest share_keys_request = 2;
+
+ // Round 2 message; see details at message definition.
+ MaskedInputCollectionRequest masked_input_request = 3;
+
+ // Round 3 message; see details at message definition.
+ UnmaskingRequest unmasking_request = 4;
+ }
+}
+
+// # MESSAGES INTERNAL TO THE LIBRARY
+
+// ## ABORT MESSAGE
+// Sent by the server to the client to cause the client to abort and erase all
+// the state relate to the current session. Can signify either an abort due to
+// some error or an abort because the server needs no more messages.
+// Sent by the client to the server to notify it that the sending client
+// aborted.
+message AbortMessage {
+ // Can contain optional logging/diagnostic info.
+ string diagnostic_info = 1;
+ // If true, the client will halt early but mark the protocol as a success,
+ // rather than as aborted.
+ bool early_success = 2;
+}
+
+// ## ROUND 0 MESSAGES: ADVERTISE KEYS
+// AdvertiseKeys is sent by clients who wish to participate in the protocol.
+
+// Part of a ClientToServerWrapperMessage. Contains a pair of public keys.
+message AdvertiseKeys {
+ // Pair of public keys for this client.
+ PairOfPublicKeys pair_of_public_keys = 1;
+}
+
+// A pair of public keys. Used as part of AdvertiseKeys and ShareKeysRequest.
+// TODO(team): move away from comprossed encoding of ECPoint into
+// encoding the complete PublicKey, including the other parameters.
+message PairOfPublicKeys {
+ // An encoding of the Diffie Hellman public key used to generate correlated
+ // noise to be added to ciphertexts.
+ // The encoding used is the one returned by the getEncoded(compressed=true)
+ // method of the class org.bouncycastle.math.ec.ECPoint
+ // TODO(team): figure out the actual encoding used.
+ bytes noise_pk = 1;
+
+ // An encoding of the Diffie Hellman public key used to establish private
+ // channels between the protocol parties.
+ // The encoding used is the one returned by the getEncoded(compressed=true)
+ // method of the class org.bouncycastle.math.ec.ECPoint
+ // TODO(team): figure out the actual encoding used.
+ bytes enc_pk = 2;
+}
+
+// ## ROUND 1 MESSAGES: SHARE KEYS
+// Clients who are selected to participate in the protocol receive a list of the
+// (pairs of) public keys of all other clients. Each client secret-shares its
+// own noise_sk and prf_sk with all the other clients (encrypting shares for
+// client j with their own enc_pk) and sends all these encrypted shares to the
+// server for distribution.
+
+// Part of a ServerToClientWrapperMessage. Contains a list of pairs of public
+// keys, as well as the logging ID for the SecAgg execution.
+message ShareKeysRequest {
+ // List of public keys for all clients, ordered by the clients'
+ // logical ids. Each client infers its logical id "i" from the
+ // position "i" of its pair of keys in the list.
+ // Note that the logical ids are assumed to be 1-indexed (i.e. the first
+ // public key in the repeated field corresponds to client with logical id 1)
+ repeated PairOfPublicKeys pairs_of_public_keys = 1;
+ // The logging ID for the SecAgg execution. All participants in the protocol
+ // will use this ID while logging, to allow metrics for the entire execution
+ // to be collected.
+ int64 sec_agg_execution_logging_id = 2;
+
+ // May be populated with implementation-specific data.
+ repeated google.protobuf.Any extra_data = 3;
+
+ // The session ID for the Secagg execution. All clients will use this session
+ // ID as part of the seed of PRNGs used during the execution of the protocol,
+ bytes session_id = 4 ;
+}
+
+// Part of a ClientToServerWrapperMessage. Contains a list of encrypted pairs of
+// key shares (one for each other client).
+message ShareKeysResponse {
+ // The (1-indexed) j'th element of the repeated field (of a ShareKeysResponse
+ // message sent from client i) contains the j'th shares of each of client i's
+ // noise_sk and prf_sk, encrypted under client j's enc_pk and intended to be
+ // sent to him. The i'th share (i.e. the share the a client would send to
+ // himself) will be empty.
+ // The client indexes above refer to their logical ids.
+ // This field is opaque, as it will be the output of an AES/GCM encryption.
+ // However, once decrypted, each entry should be interpreted as the
+ // serialization of a PairOfKeyShares message.
+ repeated bytes encrypted_key_shares = 1;
+}
+
+// Each of the encrypted_key_shares bytes in ShareKeysResponse and
+// MaskedInputCollectionRequest is an encryption of the serialization of this
+// message.
+message PairOfKeyShares {
+ // The two shares are encodings of BigIntegers which represents elements of
+ // the group Z_p, where p is the prime denoting the field over which the
+ // elliptic curve prime256v1 is defined. See the ShamirSecretSharing
+ // class for more details on the sharing.
+ //
+ // The following comes from the BigInteger documentation (the encoding is
+ // computed using the toByteArray() method):
+ // [The encoding consists of] a byte array containing the two's-complement
+ // representation of the BigInteger and is in big-endian byte-order: the most
+ // significant byte is in the zeroth element.
+ // The array will contain the minimum number of bytes required to represent
+ // this BigInteger, including at least one sign bit, which is
+ // (ceil((this.bitLength() + 1)/8)).
+ bytes noise_sk_share = 1;
+ bytes prf_sk_share = 2;
+}
+
+// ## ROUND 2 MESSAGES: MASKED INPUT COLLECTION
+// The server gives each client his shares of the keys of each other client who
+// responded in the previous round. Each client computes a masked version of its
+// own input and sends it to the server.
+
+// Part of a ServerToClientWrapperMessage. Contains a list of shares of other
+// clients' keys encrypted and intended for the client who receives this
+// message.
+message MaskedInputCollectionRequest {
+ // Each bytes field of this message (intended for party with logical id i)
+ // contains an encryption of a PairOfKeyShares message.
+ // The bytes are ordered according to the clients' logical ids. If a client
+ // j did not send the ShareKeysResponse in the previous round, the
+ // corresponding encrypted_key_shares bytes field will be empty (the receiving
+ // client i will therefore consider such a client j dead from now on).
+ repeated bytes encrypted_key_shares = 1;
+}
+
+// Part of a ClientToServerWrapperMessage. Contains a map of masked input
+// vectors (each identified by a name string).
+message MaskedInputCollectionResponse {
+ // The string key of the map represents the name of the input vector.
+ map<string, MaskedInputVector> vectors = 1;
+}
+
+// Part of a MaskedInputCollectionResponse message. Contains an encoding of an
+// input vector masked by the addition of pseudo-random values (which will be
+// removed later).
+message MaskedInputVector {
+ // The vector contains a packed representation of a masked SecAggVector,
+ // where each element is packed in little-endian order.
+ bytes encoded_vector = 1;
+
+ // May be populated with implementation-specific data.
+ repeated google.protobuf.Any extra_data = 2;
+}
+
+// ## ROUND 3 MESSAGES: UNMASKING
+// The server communicates to the clients the list of clients which did not
+// complete the previous round (and therefore are considered dead). Each client
+// i, for each other client j, responds with either the share of noise_sk (if
+// client j is dead) OR share of prf_sk (if client j is still alive).
+
+// Part of a ServerToClientWrapperMessage.
+message UnmaskingRequest {
+ // Clients which were alive at round 2, but did not send a ROUND 2 response,
+ // so they are considered dead from round 3 onwards.
+ // NOTE: to minimize bandwidth, this DOES NOT include clients which dropped
+ // before round 2 (i.e. those who did not send a round 1 response).
+ repeated uint32 dead_3_client_ids = 1;
+}
+
+// Part of a ClientToServerWrapperMessage.
+message UnmaskingResponse {
+ // These shares are NOT encrypted. For each other client j, client i sends to
+ // the server his share of noise_sk (if client j is dead) OR share of prf_sk
+ // (if client j is still alive).
+ // The entries corresponding to clients who died before round 2 (i.e. the
+ // ones for which the client sending the message never got any key shares) are
+ // empty. The entry corresponding to the client sending the message must be
+ // a prf_sk_share (otherwise the client would be dead and would not be sending
+ // this response.
+ repeated NoiseOrPrfKeyShare noise_or_prf_key_shares = 1;
+}
+
+// Part of an UnmaskingResponse. Contains either a share of a noise secret key
+// or a share of a prf secret key.
+message NoiseOrPrfKeyShare {
+ // Either a share of a noise secrete key or a share of a prf secret key.
+ // An honest client never reveals both, therefore oneof is appropriate.
+ oneof oneof_shares {
+ // A share of a noise secret key.
+ bytes noise_sk_share = 1;
+
+ // A share of a prf secret key.
+ bytes prf_sk_share = 2;
+ }
+}
+
+// Version(s) supported on each client.
+enum ClientVariant {
+ // No SecAgg versions are supported on this client.
+ SECAGG_CLIENT_VARIANT_NONE = 0;
+ // The Java implementation of the SegAgg client.
+ SECAGG_CLIENT_VARIANT_JAVA = 1 [deprecated = true];
+ // The native (C++) implementation of the SegAgg client.
+ SECAGG_CLIENT_VARIANT_NATIVE_V1 = 2;
+}
diff --git a/fcp/secagg/shared/secagg_vector.cc b/fcp/secagg/shared/secagg_vector.cc
new file mode 100644
index 0000000..3733a01
--- /dev/null
+++ b/fcp/secagg/shared/secagg_vector.cc
@@ -0,0 +1,386 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/secagg_vector.h"
+
+#include <inttypes.h>
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/secagg/shared/math.h"
+
+namespace fcp {
+namespace secagg {
+
+const uint64_t SecAggVector::kMaxModulus;
+
+SecAggVector::SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus,
+ bool branchless_codec)
+ : modulus_(modulus),
+ bit_width_(SecAggVector::GetBitWidth(modulus)),
+ num_elements_(span.size()),
+ branchless_codec_(branchless_codec) {
+ FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
+ << "The specified modulus is not valid: must be > 1 and <= "
+ << kMaxModulus << "; supplied value : " << modulus_;
+ // Ensuring the supplied vector has the appropriate modulus.
+ for (uint64_t element : span) {
+ FCP_CHECK(element >= 0)
+ << "Only non negative elements are allowed in the vector.";
+ FCP_CHECK(element < modulus_)
+ << "The span does not have the appropriate modulus: element "
+ "with value "
+ << element << " found, max value allowed " << (modulus_ - 1ULL);
+ }
+
+ // Packs the long vector into a string, initialized to all null.
+ if (branchless_codec_) {
+ PackUint64IntoByteStringBranchless(span);
+ } else {
+ int num_bytes_needed =
+ DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
+ packed_bytes_ = std::string(num_bytes_needed, '\0');
+ for (int i = 0; static_cast<size_t>(i) < span.size(); ++i) {
+ PackUint64IntoByteStringAt(i, span[i]);
+ }
+ }
+}
+
+SecAggVector::SecAggVector(std::string packed_bytes, uint64_t modulus,
+ size_t num_elements, bool branchless_codec)
+ : packed_bytes_(std::move(packed_bytes)),
+ modulus_(modulus),
+ bit_width_(SecAggVector::GetBitWidth(modulus)),
+ num_elements_(num_elements),
+ branchless_codec_(branchless_codec) {
+ FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
+ << "The specified modulus is not valid: must be > 1 and <= "
+ << kMaxModulus << "; supplied value : " << modulus_;
+ int expected_num_bytes = DivideRoundUp(num_elements_ * bit_width_, 8);
+ FCP_CHECK(packed_bytes_.size() == static_cast<size_t>(expected_num_bytes))
+ << "The supplied string is not the right size for " << num_elements_
+ << " packed elements: given string has a limit of "
+ << packed_bytes_.size() << " bytes, " << expected_num_bytes
+ << " bytes would have been needed.";
+}
+
+std::vector<uint64_t> SecAggVector::GetAsUint64Vector() const {
+ CheckHasValue();
+ std::vector<uint64_t> long_vector;
+ if (branchless_codec_) {
+ UnpackByteStringToUint64VectorBranchless(&long_vector);
+ } else {
+ long_vector.reserve(num_elements_);
+ for (int i = 0; i < num_elements_; ++i) {
+ long_vector.push_back(
+ UnpackUint64FromByteStringAt(i, bit_width_, packed_bytes_));
+ }
+ }
+ return long_vector;
+}
+
+void SecAggVector::PackUint64IntoByteStringAt(int index, uint64_t element) {
+ // The element will be packed starting with the least significant (leftmost)
+ // bits.
+ //
+ // TODO(team): Optimize out this extra per element computation.
+ int leftmost_bit_position = index * bit_width_;
+ int current_byte_index = leftmost_bit_position / 8;
+ int bits_left_to_pack = bit_width_;
+
+ // If leftmost_bit_position is in the middle of a byte, first fill that byte.
+ if (leftmost_bit_position % 8 != 0) {
+ int starting_bit_position = leftmost_bit_position % 8;
+ int empty_bits_left = 8 - starting_bit_position;
+ // Extract enough bits from "element" to fill the current byte, and shift
+ // them to the correct position.
+ uint64_t mask = (1ULL << std::min(empty_bits_left, bits_left_to_pack)) - 1L;
+ uint64_t value_to_add = (element & mask) << starting_bit_position;
+ packed_bytes_[current_byte_index] |= static_cast<char>(value_to_add);
+
+ bits_left_to_pack -= empty_bits_left;
+ element >>= empty_bits_left;
+ current_byte_index++;
+ }
+
+ // Current bit position is now aligned with the start of the current byte.
+ // Pack as many whole bytes as possible.
+ uint64_t lower_eight_bit_mask = 255L;
+ while (bits_left_to_pack >= 8) {
+ packed_bytes_[current_byte_index] =
+ static_cast<char>(element & lower_eight_bit_mask);
+
+ bits_left_to_pack -= 8;
+ element >>= 8;
+ current_byte_index++;
+ }
+
+ // Pack the remaining partial byte, if necessary.
+ if (bits_left_to_pack > 0) {
+ // there should be < 8 bits left, so pack all remaining bits at once.
+ packed_bytes_[current_byte_index] |= static_cast<char>(element);
+ }
+}
+
+uint64_t SecAggVector::UnpackUint64FromByteStringAt(
+ int index, int bit_width, const std::string& byte_string) {
+ // all the bits starting from, and including, this bit are copied.
+ int leftmost_bit_position = index * bit_width;
+ // byte containing the lowest order bit to be copied
+ int leftmost_byte_index = leftmost_bit_position / 8;
+ // all bits up to, but not including this bit, are copied.
+ int right_boundary_bit_position = ((index + 1) * bit_width);
+ // byte containing the highest order bit to copy
+ int rightmost_byte_index = (right_boundary_bit_position - 1) / 8;
+
+ // Special case: when the entire long value to unpack is contained in a single
+ // byte, then extract that long value in a single step.
+ if (leftmost_byte_index == rightmost_byte_index) {
+ int num_bits_to_skip = (leftmost_bit_position % 8);
+ int mask = ((1 << bit_width) - 1) << num_bits_to_skip;
+ // drop the extraneous bits below and above the value to unpack.
+ uint64_t unpacked_element =
+ (byte_string[leftmost_byte_index] & mask) >> num_bits_to_skip;
+ return unpacked_element;
+ }
+
+ // Normal case: the value to unpack spans one or more byte boundaries.
+ // The element will be unpacked in reverse order, starting from the most
+ // significant (rightmost) bits.
+ int current_byte_index = rightmost_byte_index;
+ uint64_t unpacked_element = 0;
+ int bits_left_to_unpack = bit_width;
+
+ // If right_boundary_bit_position is in the middle of a byte, unpack the bits
+ // up to right_boundary_bit_position within that byte.
+ if (right_boundary_bit_position % 8 != 0) {
+ int bits_to_copy_from_current_byte = (right_boundary_bit_position % 8);
+ int lower_bits_mask = (1 << bits_to_copy_from_current_byte) - 1;
+ unpacked_element |= (byte_string[current_byte_index] & lower_bits_mask);
+
+ bits_left_to_unpack -= bits_to_copy_from_current_byte;
+ current_byte_index--;
+ }
+
+ // Current bit position is now aligned with a byte boundary. Unpack as many
+ // whole bytes as possible.
+ while (bits_left_to_unpack >= 8) {
+ unpacked_element <<= 8;
+ unpacked_element |= byte_string[current_byte_index] & 0xff;
+
+ bits_left_to_unpack -= 8;
+ current_byte_index--;
+ }
+
+ // Unpack the remaining partial byte, if necessary.
+ if (bits_left_to_unpack > 0) {
+ unpacked_element <<= bits_left_to_unpack;
+ int bits_to_skip_in_current_byte = 8 - bits_left_to_unpack;
+ unpacked_element |= (byte_string[current_byte_index] & 0xff) >>
+ bits_to_skip_in_current_byte;
+ }
+
+ return unpacked_element;
+}
+
+void SecAggVector::PackUint64IntoByteStringBranchless(
+ const absl::Span<const uint64_t> span) {
+ SecAggVector::Coder coder(modulus_, bit_width_, num_elements_);
+ for (uint64_t element : span) {
+ coder.WriteValue(element);
+ }
+ packed_bytes_ = std::move(coder).Create().TakePackedBytes();
+}
+
+void SecAggVector::UnpackByteStringToUint64VectorBranchless(
+ std::vector<uint64_t>* long_vector) const {
+ long_vector->resize(num_elements_);
+ Decoder decoder(*this);
+ for (uint64_t& element : *long_vector) {
+ element = decoder.ReadValue();
+ }
+}
+
+SecAggVector::Decoder::Decoder(absl::string_view packed_bytes, uint64_t modulus)
+ : read_cursor_(packed_bytes.data()),
+ cursor_sentinel_(packed_bytes.data() + packed_bytes.size()),
+ cursor_read_value_(0),
+ scratch_(0),
+ read_cursor_bit_(0),
+ bit_width_(SecAggVector::GetBitWidth(modulus)),
+ mask_((1ULL << bit_width_) - 1),
+ modulus_(modulus) {
+ ReadData();
+}
+
+inline void SecAggVector::Decoder::ReadData() {
+ static constexpr ssize_t kBlockSizeBytes = sizeof(cursor_read_value_);
+ const ptrdiff_t bytes_remaining = cursor_sentinel_ - read_cursor_;
+ // Here, we use memcpy() to avoid the undefined behavior of
+ // reinterpret_cast<> on unaligned reads, opportunistically reading up to
+ // eight bytes at a time.
+ if (bytes_remaining >= kBlockSizeBytes) {
+ memcpy(&cursor_read_value_, read_cursor_, kBlockSizeBytes);
+ } else {
+ memcpy(&cursor_read_value_, read_cursor_,
+ bytes_remaining > 0 ? bytes_remaining : 0);
+ }
+ scratch_ |= cursor_read_value_ << static_cast<unsigned>(read_cursor_bit_);
+}
+
+uint64_t SecAggVector::Decoder::ReadValue() {
+ static constexpr int kBlockSizeBits = sizeof(cursor_read_value_) * 8;
+ // Get the current value.
+ const uint64_t current_value = scratch_ & mask_;
+ // Advance to the next value.
+ scratch_ >>= bit_width_;
+ int unwritten_bits = read_cursor_bit_;
+ read_cursor_bit_ -= bit_width_;
+ // Because we read in eight byte chunks on byte boundaries, and only keep
+ // eight bytes of scratch, a portion of the read could not fit, and now
+ // belongs at the back of scratch. The following assignments are compiled
+ // to a branchless conditional move on Clang X86_64 and Clang ARMv{7, 8}.
+ int read_bit_shift = bit_width_ - unwritten_bits;
+ unsigned int right_shift_value = read_bit_shift > 0 ? read_bit_shift : 0;
+ unsigned int left_shift_value = read_bit_shift < 0 ? -read_bit_shift : 0;
+ cursor_read_value_ >>= right_shift_value;
+ cursor_read_value_ <<= left_shift_value;
+ scratch_ |= cursor_read_value_;
+ int valid_scratch_bits = kBlockSizeBits - bit_width_ + unwritten_bits;
+ valid_scratch_bits = (valid_scratch_bits > kBlockSizeBits)
+ ? kBlockSizeBits
+ : valid_scratch_bits;
+ int new_read_cursor_bit =
+ read_cursor_bit_ +
+ static_cast<signed>(
+ (static_cast<unsigned>(valid_scratch_bits - read_cursor_bit_) & ~7U));
+ new_read_cursor_bit = new_read_cursor_bit == kBlockSizeBits
+ ? static_cast<int>(kBlockSizeBits - 8)
+ : new_read_cursor_bit;
+ read_cursor_ +=
+ static_cast<unsigned>((new_read_cursor_bit - read_cursor_bit_)) / 8;
+ read_cursor_bit_ = new_read_cursor_bit;
+ ReadData();
+ // The current_value is guaranteed to be in [0, 2 * modulus_) range due to the
+ // relationship between modulus_ and bit_width_, and therefore the below
+ // statement guarantees the return value to be in [0, modulus_) range.
+ return current_value < modulus_ ? current_value : current_value - modulus_;
+}
+
+SecAggVector::Coder::Coder(uint64_t modulus, int bit_width, size_t num_elements)
+ : modulus_(modulus),
+ bit_width_(bit_width),
+ num_elements_(num_elements),
+ target_cursor_value_(0),
+ starting_bit_position_(0) {
+ num_bytes_needed_ =
+ DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
+ // The branchless variant assumes eight bytes of scratch space.
+ // The string is resized to the correct size at the end.
+ packed_bytes_ = std::string(num_bytes_needed_ + 8, '\0');
+ write_cursor_ = &packed_bytes_[0];
+}
+
+void SecAggVector::Coder::WriteValue(uint64_t value) {
+ static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
+ // Here, we use memcpy() to avoid the undefined behavior of
+ // reinterpret_cast<> on unaligned stores, opportunistically writing eight
+ // bytes at a time.
+ target_cursor_value_ &= (1ULL << starting_bit_position_) - 1;
+ target_cursor_value_ |= value << starting_bit_position_;
+ std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
+ const auto new_write_cursor =
+ write_cursor_ + (starting_bit_position_ + bit_width_) / 8;
+ const auto new_starting_bit_position =
+ (starting_bit_position_ + bit_width_) % 8;
+ // Because we write in eight byte chunks, a portion of element may have
+ // been missed, and now belongs at the front of target_cursor_value. The
+ // following assignments are compiled to a branchless conditional move on
+ // Clang X86_64 and Clang ARMv{7, 8}.
+ auto runt_cursor_value =
+ new_starting_bit_position
+ ? value >> (static_cast<unsigned>(kBlockSize * 8 -
+ starting_bit_position_) &
+ (kBlockSize * 8 - 1)) // Prevent unused UB warning.
+ : 0;
+ // Otherwise, remove fully written values from our scratch space.
+ target_cursor_value_ >>=
+ (static_cast<unsigned>(new_write_cursor - write_cursor_) * 8) &
+ (kBlockSize * 8 - 1); // Prevent unused UB warning.
+ target_cursor_value_ = (new_write_cursor - write_cursor_ == kBlockSize)
+ ? runt_cursor_value
+ : target_cursor_value_;
+ write_cursor_ = new_write_cursor;
+ starting_bit_position_ = new_starting_bit_position;
+}
+
+SecAggVector SecAggVector::Coder::Create() && {
+ static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
+ std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
+ packed_bytes_.resize(num_bytes_needed_);
+ return SecAggVector(std::move(packed_bytes_), modulus_, num_elements_,
+ /* branchless_codec=*/true);
+}
+
+void SecAggUnpackedVector::Add(const SecAggVector& other) {
+ FCP_CHECK(num_elements() == other.num_elements());
+ FCP_CHECK(modulus() == other.modulus());
+ SecAggVector::Decoder decoder(other);
+ for (auto& v : *this) {
+ v = AddModOpt(v, decoder.ReadValue(), modulus());
+ }
+}
+
+void SecAggUnpackedVectorMap::Add(const SecAggVectorMap& other) {
+ FCP_CHECK(size() == other.size());
+ for (auto& [name, vector] : *this) {
+ auto it = other.find(name);
+ FCP_CHECK(it != other.end());
+ vector.Add(it->second);
+ }
+}
+
+std::unique_ptr<SecAggUnpackedVectorMap> SecAggUnpackedVectorMap::AddMaps(
+ const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b) {
+ auto result = std::make_unique<SecAggUnpackedVectorMap>();
+ for (const auto& entry : a) {
+ auto name = entry.first;
+ auto length = entry.second.num_elements();
+ auto modulus = entry.second.modulus();
+ const auto& a_at_name = entry.second;
+ const auto& b_at_name = b.at(name);
+ SecAggUnpackedVector result_vector(length, modulus);
+ for (int j = 0; j < length; ++j) {
+ result_vector[j] = AddModOpt(a_at_name[j], b_at_name[j], modulus);
+ }
+ result->emplace(name, std::move(result_vector));
+ }
+ return result;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/secagg_vector.h b/fcp/secagg/shared/secagg_vector.h
new file mode 100644
index 0000000..7cfbf17
--- /dev/null
+++ b/fcp/secagg/shared/secagg_vector.h
@@ -0,0 +1,311 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
+#define FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/numeric/bits.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "fcp/base/monitoring.h"
+
+// Represents an immutable vector of nonnegative integers, where each entry has
+// the same specified bit width. This is used in the SecAgg package both to
+// provide input to SecAggClient and by SecAggServer to provide its output (more
+// specifically, inputs and outputs are of type
+// unordered_map<std::string, SecAggVector>, where the key denotes a name
+// associated with the vector).
+//
+// This class is backed by a packed byte representation of a uint64_t vector, in
+// little endian order, where each consecutive bit_width sequence of bits of
+// the packed vector corresponds to an integer value between
+// 0 and modulus.
+
+namespace fcp {
+namespace secagg {
+
+class SecAggVector {
+ public:
+ static constexpr uint64_t kMaxModulus = 1ULL << 62; // max 62 bitwidth
+
+ // Creates a SecAggVector of the specified modulus, using the specified
+ // span of uint64s. The integers are converted into a packed byte
+ // representation and stored in that format.
+ //
+ // Each element of span must be in [0, modulus-1].
+ //
+ // modulus itself must be > 1 and <= kMaxModulus.
+ SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus,
+ bool branchless_codec = false);
+
+ // Creates a SecAggVector from the given little-endian packed byte
+ // representation. The packed representation should have num_elements longs,
+ // each of bit_width length (in bits).
+ //
+ // packed_bytes must be in the same format as the output of GetAsPackedBytes.
+ //
+ // modulus must be > 1 and <= kMaxModulus.
+ //
+ // For large strings, copying may be avoided by specifying an rvalue for
+ // packed bytes, e.g. std::move(large_caller_string), which should move the
+ // contents.
+ SecAggVector(std::string packed_bytes, uint64_t modulus, size_t num_elements,
+ bool branchless_codec = false);
+
+ // Disallow memory expensive copying of SecAggVector.
+ SecAggVector(const SecAggVector&) = delete;
+ SecAggVector& operator=(const SecAggVector&) = delete;
+
+ // Enable move semantics.
+ SecAggVector(SecAggVector&& other) { other.MoveTo(this); }
+
+ SecAggVector& operator=(SecAggVector&& other) {
+ other.MoveTo(this);
+ return *this;
+ }
+
+ // Calculates bitwith for the specified modulus.
+ inline static int GetBitWidth(uint64_t modulus) {
+ return static_cast<int>(absl::bit_width(modulus - 1ULL));
+ }
+
+ ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; }
+ ABSL_MUST_USE_RESULT inline size_t bit_width() const { return bit_width_; }
+ ABSL_MUST_USE_RESULT inline size_t num_elements() const {
+ return num_elements_;
+ }
+
+ // Produces and returns a representation of this SecAggVector as a vector of
+ // uint64_t. The returned vector is obtained by unpacking the stored packed
+ // representation of the vector.
+ ABSL_MUST_USE_RESULT std::vector<uint64_t> GetAsUint64Vector() const;
+
+ // Returns the stored, compressed representation of the SecAggVector.
+ // The bytes are stored in little-endian order, using only bit_width bits to
+ // represent each element of the vector.
+ ABSL_MUST_USE_RESULT inline const std::string& GetAsPackedBytes() const {
+ CheckHasValue();
+ return packed_bytes_;
+ }
+
+ // Takes out the stored, compressed representation of the SecAggVector.
+ // This call "consumes" the SecAggVector instance, and after that it becomes
+ // invalid.
+ // The bytes are stored in little-endian order, using only bit_width bits to
+ // represent each element of the vector.
+ ABSL_MUST_USE_RESULT inline std::string TakePackedBytes() && {
+ CheckHasValue();
+ modulus_ = 0;
+ bit_width_ = 0;
+ num_elements_ = 0;
+ return std::move(packed_bytes_);
+ }
+
+ inline friend bool operator==(const SecAggVector& lhs,
+ const SecAggVector& rhs) {
+ return lhs.packed_bytes_ == rhs.packed_bytes_;
+ }
+
+ // Decoder for unpacking SecAggVector values one by one.
+ class Decoder {
+ public:
+ explicit Decoder(const SecAggVector& v)
+ : Decoder(v.packed_bytes_, v.modulus_) {}
+
+ explicit Decoder(absl::string_view packed_bytes, uint64_t modulus);
+
+ // Unpacks and returns the next value.
+ // Result of this operation is undetermined when the decoder has already
+ // decoded all values. For performance reasons ReadValue doesn't validate
+ // the state.
+ uint64_t ReadValue();
+
+ private:
+ inline void ReadData();
+
+ const char* read_cursor_;
+ const char* const cursor_sentinel_;
+ uint64_t cursor_read_value_;
+ uint64_t scratch_;
+ int read_cursor_bit_;
+ uint8_t bit_width_;
+ const uint64_t mask_;
+ uint64_t modulus_;
+ };
+
+ // Coder for packing SecAggVector values one by one.
+ class Coder {
+ public:
+ explicit Coder(uint64_t modulus, int bit_width, size_t num_elements);
+
+ // Pack and write value to packed buffer.
+ void WriteValue(uint64_t value);
+
+ // Consumes the coder and creates SecAggVector with the packed buffer.
+ SecAggVector Create() &&;
+
+ private:
+ std::string packed_bytes_;
+ int num_bytes_needed_;
+ uint64_t modulus_;
+ int bit_width_;
+ size_t num_elements_;
+ char* write_cursor_;
+ uint64_t target_cursor_value_;
+ uint8_t starting_bit_position_;
+ };
+
+ private:
+ std::string packed_bytes_;
+ uint64_t modulus_;
+ int bit_width_;
+ size_t num_elements_;
+ bool branchless_codec_;
+
+ // Moves this object's value to the target one and resets this object's state.
+ inline void MoveTo(SecAggVector* target) {
+ target->modulus_ = modulus_;
+ target->bit_width_ = bit_width_;
+ target->num_elements_ = num_elements_;
+ target->branchless_codec_ = branchless_codec_;
+ target->packed_bytes_ = std::move(packed_bytes_);
+ modulus_ = 0;
+ bit_width_ = 0;
+ num_elements_ = 0;
+ branchless_codec_ = false;
+ }
+
+ // Verifies that this SecAggVector value can't be accessed after swapping it
+ // with another SecAggVector via std::move().
+ void CheckHasValue() const {
+ FCP_CHECK(modulus_ > 0) << "SecAggVector has no value";
+ }
+
+ void PackUint64IntoByteStringAt(int index, uint64_t element);
+ // A version without expensive branches or multiplies.
+ void PackUint64IntoByteStringBranchless(absl::Span<const uint64_t> span);
+
+ static ABSL_MUST_USE_RESULT uint64_t UnpackUint64FromByteStringAt(
+ int index, int bit_width, const std::string& byte_string);
+ // A version without expensive branches or multiplies.
+ void UnpackByteStringToUint64VectorBranchless(
+ std::vector<uint64_t>* long_vector) const;
+}; // class SecAggVector
+
+// This is equivalent to
+// using SecAggVectorMap = absl::node_hash_map<std::string, SecAggVector>;
+// except copy construction and assignment are explicitly prohibited.
+class SecAggVectorMap : public absl::node_hash_map<std::string, SecAggVector> {
+ public:
+ using Base = absl::node_hash_map<std::string, SecAggVector>;
+ using Base::Base;
+ using Base::operator=;
+ SecAggVectorMap(const SecAggVectorMap&) = delete;
+ SecAggVectorMap& operator=(const SecAggVectorMap&) = delete;
+};
+
+// Unpacked vector is simply a pair vector<uint64_t> and the modulus used with
+// each element.
+class SecAggUnpackedVector : public std::vector<uint64_t> {
+ public:
+ explicit SecAggUnpackedVector(size_t size, uint64_t modulus)
+ : vector(size), modulus_(modulus) {}
+
+ explicit SecAggUnpackedVector(std::vector<uint64_t> elements,
+ uint64_t modulus)
+ : vector(std::move(elements)), modulus_(modulus) {}
+
+ ABSL_MUST_USE_RESULT inline uint64_t modulus() const { return modulus_; }
+ ABSL_MUST_USE_RESULT inline size_t num_elements() const { return size(); }
+
+ // Disallow memory expensive copying of SecAggVector.
+ SecAggUnpackedVector(const SecAggUnpackedVector&) = delete;
+ SecAggUnpackedVector& operator=(const SecAggUnpackedVector&) = delete;
+
+ explicit SecAggUnpackedVector(const SecAggVector& other)
+ : vector(other.num_elements()), modulus_(other.modulus()) {
+ SecAggVector::Decoder decoder(other);
+ for (auto& v : *this) {
+ v = decoder.ReadValue();
+ }
+ }
+
+ // Enable move semantics.
+ SecAggUnpackedVector(SecAggUnpackedVector&& other)
+ : vector(std::move(other)), modulus_(other.modulus_) {
+ other.modulus_ = 0;
+ }
+
+ SecAggUnpackedVector& operator=(SecAggUnpackedVector&& other) {
+ modulus_ = other.modulus_;
+ other.modulus_ = 0;
+ vector::operator=(std::move(other));
+ return *this;
+ }
+
+ // Combines this vector with another (packed) vector by adding elements of
+ // this vector to corresponding elements of the other vector.
+ // It is assumed that both vectors have the same modulus. The modulus is
+ // applied to each sum.
+ void Add(const SecAggVector& other);
+
+ private:
+ uint64_t modulus_;
+};
+
+// This is mostly equivalent to
+// using SecAggUnpackedVectorMap =
+// absl::node_hash_map<std::string, SecAggUnpackedVector>;
+// except copy construction and assignment are explicitly prohibited and
+// Add method is added.
+class SecAggUnpackedVectorMap
+ : public absl::node_hash_map<std::string, SecAggUnpackedVector> {
+ public:
+ using Base = absl::node_hash_map<std::string, SecAggUnpackedVector>;
+ using Base::Base;
+ using Base::operator=;
+ SecAggUnpackedVectorMap(const SecAggUnpackedVectorMap&) = delete;
+ SecAggUnpackedVectorMap& operator=(const SecAggUnpackedVectorMap&) = delete;
+
+ explicit SecAggUnpackedVectorMap(const SecAggVectorMap& other) {
+ for (auto& [name, vector] : other) {
+ this->emplace(name, SecAggUnpackedVector(vector));
+ }
+ }
+
+ // Combines this map with another (packed) map by adding all vectors in this
+ // map to corresponding vectors in the other map.
+ // It is assumed that names of vectors match in both maps.
+ void Add(const SecAggVectorMap& other);
+
+ // Analogous to the above, as a static method. Also assumes that names of
+ // vectors match in both maps.
+ static std::unique_ptr<SecAggUnpackedVectorMap> AddMaps(
+ const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b);
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_SECAGG_VECTOR_H_
diff --git a/fcp/secagg/shared/secagg_vector_bench.cc b/fcp/secagg/shared/secagg_vector_bench.cc
new file mode 100644
index 0000000..49b13e0
--- /dev/null
+++ b/fcp/secagg/shared/secagg_vector_bench.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>
+#include <vector>
+
+#include "benchmark//benchmark.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+constexpr auto kVectorSize = 32 * 1024 * 1024;
+
+static void BM_CreatePowerOfTwo(benchmark::State& state) {
+ auto items_processed = 0;
+ std::vector<uint64_t> input;
+ input.resize(kVectorSize);
+ for (auto s : state) {
+ uint64_t modulus = 1ULL << static_cast<int>(state.range(1));
+ SecAggVector vec(input, modulus, state.range(0));
+ benchmark::DoNotOptimize(vec.GetAsUint64Vector());
+ items_processed += vec.num_elements();
+ }
+ state.SetItemsProcessed(items_processed);
+}
+
+static void BM_CreateArbitrary(benchmark::State& state) {
+ auto items_processed = 0;
+ std::vector<uint64_t> input;
+ input.resize(kVectorSize);
+ for (auto s : state) {
+ uint64_t modulus = static_cast<uint64_t>(state.range(1));
+ SecAggVector vec(input, modulus, state.range(0));
+ benchmark::DoNotOptimize(vec.GetAsUint64Vector());
+ items_processed += vec.num_elements();
+ }
+ state.SetItemsProcessed(items_processed);
+}
+
+BENCHMARK(BM_CreatePowerOfTwo)
+ ->RangeMultiplier(2)
+ ->Ranges({{false, true},
+ {1, absl::bit_width(SecAggVector::kMaxModulus - 1ULL)}});
+
+BENCHMARK(BM_CreatePowerOfTwo)->Args({false, 41})->Args({true, 41});
+
+BENCHMARK(BM_CreateArbitrary)
+ ->Args({false, 5})
+ ->Args({false, 39})
+ ->Args({false, 485})
+ ->Args({false, 2400})
+ ->Args({false, 14901})
+ ->Args({false, 51813})
+ ->Args({false, 532021})
+ ->Args({false, 13916946})
+ ->Args({false, 39549497})
+ ->Args({false, 548811945})
+ ->Args({false, 590549014})
+ ->Args({false, 48296031686})
+ ->Args({false, 156712951284})
+ ->Args({false, 2636861836189})
+ ->Args({false, 14673852658160})
+ ->Args({false, 92971495438615})
+ ->Args({false, 304436005557271})
+ ->Args({false, 14046234330484262})
+ ->Args({false, 38067457113486645})
+ ->Args({false, 175631339105057682})
+ ->Args({true, 5})
+ ->Args({true, 39})
+ ->Args({true, 485})
+ ->Args({true, 2400})
+ ->Args({true, 14901})
+ ->Args({true, 51813})
+ ->Args({true, 532021})
+ ->Args({true, 13916946})
+ ->Args({true, 39549497})
+ ->Args({true, 548811945})
+ ->Args({true, 590549014})
+ ->Args({true, 48296031686})
+ ->Args({true, 156712951284})
+ ->Args({true, 2636861836189})
+ ->Args({true, 14673852658160})
+ ->Args({true, 92971495438615})
+ ->Args({true, 304436005557271})
+ ->Args({true, 14046234330484262})
+ ->Args({true, 38067457113486645})
+ ->Args({true, 175631339105057682});
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/secagg_vector_test.cc b/fcp/secagg/shared/secagg_vector_test.cc
new file mode 100644
index 0000000..ac12755
--- /dev/null
+++ b/fcp/secagg/shared/secagg_vector_test.cc
@@ -0,0 +1,477 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/secagg_vector.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/math.h"
+
+namespace fcp {
+namespace secagg {
+namespace {
+
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using SecAggVectorTest = ::testing::TestWithParam<bool>;
+
+static std::array<uint64_t, 20> kArbitraryModuli{5,
+ 39,
+ 485,
+ 2400,
+ 14901,
+ 51813,
+ 532021,
+ 13916946,
+ 39549497,
+ 548811945,
+ 590549014,
+ 48296031686,
+ 156712951284,
+ 2636861836189,
+ 14673852658160,
+ 92971495438615,
+ 304436005557271,
+ 14046234330484262,
+ 38067457113486645,
+ 175631339105057682};
+
+TEST_P(SecAggVectorTest, GettersReturnAppropriateValuesOnConstructedVector) {
+ std::vector<uint64_t> raw_vector = {4, 5};
+ uint64_t modulus = 256;
+ SecAggVector vector(raw_vector, modulus, GetParam());
+ EXPECT_THAT(modulus, Eq(vector.modulus()));
+ EXPECT_THAT(8, Eq(vector.bit_width()));
+ EXPECT_THAT(raw_vector.size(), Eq(vector.num_elements()));
+ EXPECT_THAT(raw_vector, Eq(vector.GetAsUint64Vector()));
+}
+
+TEST_P(SecAggVectorTest, ConstructorDoesNotDieOnInputsCloseToModulusBound) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4, GetParam());
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnInputEqualsModulus) {
+ std::vector<uint64_t> raw_vector = {4};
+ EXPECT_DEATH(SecAggVector vector(raw_vector, 4, GetParam()),
+ "The span does not have the appropriate modulus");
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnInputBiggerThanMaxModulus) {
+ std::vector<uint64_t> raw_vector = {SecAggVector::kMaxModulus};
+ EXPECT_DEATH(
+ SecAggVector vector(raw_vector, SecAggVector::kMaxModulus, GetParam()),
+ "The span does not have the appropriate modulus");
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnNegativeModulus) {
+ std::vector<uint64_t> raw_vector = {4};
+ EXPECT_DEATH(SecAggVector vector(raw_vector, -2, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnModulusZero) {
+ std::vector<uint64_t> raw_vector = {4};
+ EXPECT_DEATH(SecAggVector vector(raw_vector, 0, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnModulusOne) {
+ std::vector<uint64_t> raw_vector = {4};
+ EXPECT_DEATH(SecAggVector vector(raw_vector, 1, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, ConstructorDiesOnModulusTooLarge) {
+ std::vector<uint64_t> raw_vector = {4};
+ EXPECT_DEATH(SecAggVector vector(raw_vector, SecAggVector::kMaxModulus + 1,
+ GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorSucceedsOnValidInputs) {
+ std::string packed_bytes(3, '\0');
+ SecAggVector vector(packed_bytes, 4, 12, GetParam());
+
+ // empty vector
+ std::string packed_bytes2 = "";
+ SecAggVector vector2(packed_bytes2, 32, 0, GetParam());
+
+ // lines up with byte boundary
+ std::string packed_bytes3(4, '\0');
+ SecAggVector vector3(packed_bytes3, 1ULL << 16, 2, GetParam());
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnNegativeModulus) {
+ std::string packed_bytes(3, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, -2, 4, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnModulusZero) {
+ std::string packed_bytes(3, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, 0, 4, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnModulusOne) {
+ std::string packed_bytes(3, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, 1, 4, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnModulusTooLarge) {
+ std::string packed_bytes(3, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, SecAggVector::kMaxModulus + 1,
+ 4, GetParam()),
+ "The specified modulus is not valid");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnTooShortString) {
+ int num_elements = 4;
+ uint64_t modulus = 16;
+ int bit_width = 4;
+ int expected_length = DivideRoundUp(num_elements * bit_width, 8);
+
+ std::string packed_bytes(expected_length - 1, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, modulus, 4, GetParam()),
+ "The supplied string is not the right size");
+}
+
+TEST_P(SecAggVectorTest, StringConstructorDiesOnTooLongString) {
+ int num_elements = 4;
+ uint64_t modulus = 16;
+ int bit_width = 4;
+ int expected_length = DivideRoundUp(num_elements * bit_width, 8);
+
+ std::string packed_bytes(expected_length + 1, '\0');
+ EXPECT_DEATH(SecAggVector vector(packed_bytes, modulus, 4, GetParam()),
+ "The supplied string is not the right size");
+}
+
+TEST_P(SecAggVectorTest, PackedVectorHasCorrectSize) {
+ std::vector<uint64_t> raw_vector = {0, 1, 2, 3, 4};
+ uint64_t modulus = 32;
+ int bit_width = 5;
+ SecAggVector vector(raw_vector, modulus, GetParam());
+ std::string packed_bytes = vector.GetAsPackedBytes();
+ int expected_length = DivideRoundUp(raw_vector.size() * bit_width, 8);
+ EXPECT_THAT(expected_length, Eq(packed_bytes.size()));
+
+ // empty vector
+ std::vector<uint64_t> empty_raw_vector = {};
+ modulus = 32;
+ bit_width = 5;
+ SecAggVector vector2(empty_raw_vector, modulus, GetParam());
+ packed_bytes = vector2.GetAsPackedBytes();
+ expected_length = 0;
+ EXPECT_THAT(expected_length, Eq(packed_bytes.size()));
+
+ // packed_bytes lines up with byte boundary
+ modulus = 1ULL << 16;
+ bit_width = 16;
+ SecAggVector vector3(raw_vector, modulus, GetParam());
+ packed_bytes = vector3.GetAsPackedBytes();
+ expected_length = DivideRoundUp(raw_vector.size() * bit_width, 8);
+ EXPECT_THAT(expected_length, Eq(packed_bytes.size()));
+
+ // max bit_width
+ modulus = 1ULL << 62;
+ bit_width = 62;
+ SecAggVector vector4(raw_vector, modulus, GetParam());
+ packed_bytes = vector4.GetAsPackedBytes();
+ expected_length = DivideRoundUp(raw_vector.size() * bit_width, 8);
+ EXPECT_THAT(expected_length, Eq(packed_bytes.size()));
+}
+
+TEST_P(SecAggVectorTest, PackedVectorUnpacksToSameValues) {
+ std::vector<uint64_t> raw_vector = {};
+ uint64_t modulus = 32;
+ SecAggVector vector(raw_vector, modulus, GetParam());
+ std::string packed_bytes = vector.GetAsPackedBytes();
+ SecAggVector unpacked_vector(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector.GetAsUint64Vector()));
+
+ // bit_width 1
+ raw_vector = {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
+ modulus = 2;
+ SecAggVector vector2(raw_vector, modulus, GetParam());
+ packed_bytes = vector2.GetAsPackedBytes();
+ SecAggVector unpacked_vector2(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector2.GetAsUint64Vector()));
+
+ // bit_width lines up with byte boundary
+ raw_vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ modulus = 1ULL << 16;
+ SecAggVector vector3(raw_vector, modulus, GetParam());
+ packed_bytes = vector3.GetAsPackedBytes();
+ SecAggVector unpacked_vector3(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector3.GetAsUint64Vector()));
+
+ // bit_width one less than with byte boundary
+ raw_vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ modulus = 1ULL << 15;
+ SecAggVector vector4(raw_vector, modulus, GetParam());
+ packed_bytes = vector4.GetAsPackedBytes();
+ SecAggVector unpacked_vector4(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector4.GetAsUint64Vector()));
+
+ // bit_width one greater than byte boundary
+ raw_vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ modulus = 1ULL << 17;
+ SecAggVector vector5(raw_vector, modulus, GetParam());
+ packed_bytes = vector5.GetAsPackedBytes();
+ SecAggVector unpacked_vector5(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector5.GetAsUint64Vector()));
+
+ // bit_width relatively prime to byte boundary
+ raw_vector.clear();
+ raw_vector.resize(100, 1L);
+ modulus = 1ULL << 19;
+ SecAggVector vector6(raw_vector, modulus, GetParam());
+ packed_bytes = vector6.GetAsPackedBytes();
+ SecAggVector unpacked_vector6(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector6.GetAsUint64Vector()));
+
+ // max bit_width, where each array entry has its lowest bit set
+ modulus = 1ULL << 62;
+ SecAggVector vector7(raw_vector, modulus, GetParam());
+ packed_bytes = vector7.GetAsPackedBytes();
+ SecAggVector unpacked_vector7(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector7.GetAsUint64Vector()));
+
+ // max bit_width, where each array entry has its highest bit set
+ uint64_t val = SecAggVector::kMaxModulus - 1;
+ raw_vector.clear();
+ raw_vector.resize(100, val);
+ modulus = 1ULL << 62;
+ SecAggVector vector8(raw_vector, modulus, GetParam());
+ packed_bytes = vector8.GetAsPackedBytes();
+ SecAggVector unpacked_vector8(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector8.GetAsUint64Vector()));
+
+ // small non power-of-2 modulus
+ raw_vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ modulus = 11;
+ SecAggVector vector9(raw_vector, modulus, GetParam());
+ packed_bytes = vector9.GetAsPackedBytes();
+ SecAggVector unpacked_vector9(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector9.GetAsUint64Vector()));
+
+ // large non power-of-2 modulus
+ raw_vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2636861836188};
+ modulus = 2636861836189;
+ SecAggVector vector10(raw_vector, modulus, GetParam());
+ packed_bytes = vector10.GetAsPackedBytes();
+ SecAggVector unpacked_vector10(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector10.GetAsUint64Vector()));
+}
+
+TEST_P(SecAggVectorTest, PackedVectorUnpacksToSameValuesExhaustive_PowerOf2) {
+ for (auto i = 1; i < absl::bit_width(SecAggVector::kMaxModulus - 1); ++i) {
+ for (auto j = 0; j < 1024; ++j) {
+ for (auto val : {1ULL, 1ULL << (i - 1), (1ULL << (i - 1)) - 1,
+ i & ~((1ULL << i) - 1)}) {
+ auto bit_width = i;
+ uint64_t modulus = 1ULL << bit_width;
+ std::vector<uint64_t> raw_vector(j, val);
+ SecAggVector vector(raw_vector, modulus, GetParam());
+ const auto& packed_bytes = vector.GetAsPackedBytes();
+ SecAggVector unpacked_vector(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector.GetAsUint64Vector()));
+ }
+ }
+ }
+}
+
+TEST_P(SecAggVectorTest, PackedVectorUnpacksToSameValuesExhaustive_Arbitrary) {
+ for (auto modulus : kArbitraryModuli) {
+ for (auto j = 0; j < 1024; ++j) {
+ for (uint64_t val :
+ {static_cast<uint64_t>(0UL), static_cast<uint64_t>(1UL),
+ static_cast<uint64_t>((modulus >> 1) - 1),
+ static_cast<uint64_t>(modulus >> 1),
+ static_cast<uint64_t>((modulus >> 1) + 1),
+ static_cast<uint64_t>(modulus - 1)}) {
+ std::vector<uint64_t> raw_vector(j, val);
+ SecAggVector vector(raw_vector, modulus, GetParam());
+ const auto& packed_bytes = vector.GetAsPackedBytes();
+ SecAggVector unpacked_vector(packed_bytes, modulus, raw_vector.size(),
+ GetParam());
+ EXPECT_THAT(raw_vector, Eq(unpacked_vector.GetAsUint64Vector()));
+ }
+ }
+ }
+}
+
+TEST_P(SecAggVectorTest, VerifyPackingExample1) {
+ std::vector<uint64_t> correct_unpacked = {1, 3, 7, 15};
+ char correct_packed_array[] = {static_cast<char>(0b01100001),
+ static_cast<char>(0b10011100),
+ static_cast<char>(0b00000111)};
+ std::string correct_packed(correct_packed_array, 3);
+ uint64_t modulus = 32;
+
+ SecAggVector from_unpacked_vector(correct_unpacked, modulus, GetParam());
+ const std::string& packed_bytes = from_unpacked_vector.GetAsPackedBytes();
+ EXPECT_THAT(correct_packed, Eq(packed_bytes));
+
+ SecAggVector from_packed_vector(correct_packed, modulus,
+ correct_unpacked.size(), GetParam());
+ EXPECT_THAT(correct_unpacked, Eq(from_packed_vector.GetAsUint64Vector()));
+}
+
+TEST_P(SecAggVectorTest, VerifyPackingExample2) {
+ std::vector<uint64_t> correct_unpacked = {13, 17, 19};
+ char correct_packed_array[] = {
+ static_cast<char>(0b00001101), static_cast<char>(0b00100010),
+ static_cast<char>(0b01001100), static_cast<char>(0b00000000)};
+ std::string correct_packed(correct_packed_array, 4);
+ uint64_t modulus = 512;
+
+ SecAggVector from_unpacked_vector(correct_unpacked, modulus, GetParam());
+ const std::string& packed_bytes = from_unpacked_vector.GetAsPackedBytes();
+ EXPECT_THAT(correct_packed, Eq(packed_bytes));
+
+ SecAggVector from_packed_vector(correct_packed, modulus,
+ correct_unpacked.size(), GetParam());
+ EXPECT_THAT(correct_unpacked, Eq(from_packed_vector.GetAsUint64Vector()));
+}
+
+TEST_P(SecAggVectorTest, MoveConstructor) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4, GetParam());
+ SecAggVector other(std::move(vector));
+ EXPECT_THAT(other.GetAsUint64Vector(), Eq(raw_vector));
+}
+
+TEST_P(SecAggVectorTest, MoveAssignment) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4, GetParam());
+ SecAggVector other = std::move(vector);
+ EXPECT_THAT(other.GetAsUint64Vector(), Eq(raw_vector));
+}
+
+TEST_P(SecAggVectorTest, VerifyGetAsPackedBytesDiesAfterMoving) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4, GetParam());
+ SecAggVector other = std::move(vector);
+
+ ASSERT_DEATH(auto i = vector.GetAsPackedBytes(), // NOLINT
+ "SecAggVector has no value");
+}
+
+TEST_P(SecAggVectorTest, VerifyGetAsUint64VectorDiesAfterMoving) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4, GetParam());
+ SecAggVector other = std::move(vector);
+
+ ASSERT_DEATH(auto vec = vector.GetAsUint64Vector(), // NOLINT
+ "SecAggVector has no value");
+}
+
+TEST(SecAggVectorTest, VerifyTakePackedBytesDiesAfterMoving) {
+ std::vector<uint64_t> raw_vector = {0, 3};
+ SecAggVector vector(raw_vector, 4);
+ SecAggVector other = std::move(vector);
+
+ ASSERT_DEATH(auto i = std::move(vector).TakePackedBytes(), // NOLINT
+ "SecAggVector has no value");
+}
+
+INSTANTIATE_TEST_SUITE_P(Branchless, SecAggVectorTest, ::testing::Bool(),
+ ::testing::PrintToStringParamName());
+
+TEST(SecAggUnpackedVectorTest, VerifyBasicOperations) {
+ SecAggUnpackedVector vector(100, 32);
+ EXPECT_THAT(vector.num_elements(), Eq(100));
+ EXPECT_THAT(vector.modulus(), Eq(32));
+
+ SecAggUnpackedVector vector2({1, 2, 3}, 32);
+ EXPECT_THAT(vector2.num_elements(), Eq(3));
+ EXPECT_THAT(vector2.modulus(), Eq(32));
+ EXPECT_THAT(vector2.size(), Eq(3));
+ EXPECT_THAT(vector2[1], Eq(2));
+}
+
+TEST(SecAggUnpackedVectorTest, VerifyMoveConstructor) {
+ SecAggUnpackedVector vector({1, 2, 3}, 32);
+ SecAggUnpackedVector vector2(std::move(vector));
+ EXPECT_THAT(vector.modulus(), Eq(0)); // NOLINT
+ EXPECT_THAT(vector2.num_elements(), Eq(3));
+ EXPECT_THAT(vector2.modulus(), Eq(32));
+ EXPECT_THAT(vector2[2], Eq(3));
+}
+
+TEST(SecAggUnpackedVectorTest, VerifyConstructorFromSecAggVector) {
+ std::vector<uint64_t> raw_vector = {1, 2, 3};
+ SecAggVector vector(raw_vector, 32);
+ SecAggUnpackedVector vector2(vector);
+ EXPECT_THAT(vector2.num_elements(), Eq(3));
+ EXPECT_THAT(vector2.modulus(), Eq(32));
+ EXPECT_THAT(vector2[2], Eq(3));
+}
+
+TEST(SecAggUnpackedVectorTest, VerifyMoveAssignment) {
+ SecAggUnpackedVector vector({1, 2, 3}, 32);
+ SecAggUnpackedVector vector2 = std::move(vector);
+ EXPECT_THAT(vector.modulus(), Eq(0)); // NOLINT
+ EXPECT_THAT(vector2.num_elements(), Eq(3));
+ EXPECT_THAT(vector2.modulus(), Eq(32));
+ EXPECT_THAT(vector2[0], Eq(1));
+}
+
+TEST(SecAggUnpackedVectorTest, AddSecAggVectorMap) {
+ auto unpacked_map = std::make_unique<SecAggUnpackedVectorMap>();
+ unpacked_map->emplace("foobar", SecAggUnpackedVector({0, 10, 20, 30}, 32));
+
+ auto packed_map = std::make_unique<SecAggVectorMap>();
+ packed_map->emplace("foobar", SecAggVector({5, 5, 5, 5}, 32));
+
+ unpacked_map->Add(*packed_map);
+ EXPECT_THAT(unpacked_map->size(), Eq(1));
+ EXPECT_THAT(unpacked_map->at("foobar"), ElementsAreArray({5, 15, 25, 3}));
+}
+
+TEST(SecAggUnpackedVectorTest, AddUnpackedSecAggVectorMaps) {
+ SecAggUnpackedVectorMap unpacked_map_1, unpacked_map_2;
+ unpacked_map_1.emplace("foobar", SecAggUnpackedVector({0, 10, 20, 30}, 32));
+ unpacked_map_2.emplace("foobar", SecAggUnpackedVector({5, 5, 5, 5}, 32));
+
+ auto result =
+ SecAggUnpackedVectorMap::AddMaps(unpacked_map_1, unpacked_map_2);
+ EXPECT_THAT(result->size(), Eq(1));
+ EXPECT_THAT(result->at("foobar"), ElementsAreArray({5, 15, 25, 3}));
+}
+
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/shamir_secret_sharing.cc b/fcp/secagg/shared/shamir_secret_sharing.cc
new file mode 100644
index 0000000..2d96a79
--- /dev/null
+++ b/fcp/secagg/shared/shamir_secret_sharing.cc
@@ -0,0 +1,295 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
+#include "fcp/secagg/shared/math.h"
+#include "openssl/rand.h"
+
+namespace fcp {
+namespace secagg {
+
+const uint64_t ShamirSecretSharing::kPrime;
+constexpr size_t kSubsecretSize = sizeof(uint32_t);
+
+ShamirSecretSharing::ShamirSecretSharing() {}
+
+std::vector<ShamirShare> ShamirSecretSharing::Share(
+ int threshold, int num_shares, const std::string& to_share) {
+ FCP_CHECK(!to_share.empty()) << "to_share must not be empty";
+ FCP_CHECK(num_shares > 1) << "num_shares must be greater than 1";
+ FCP_CHECK(2 <= threshold && threshold <= num_shares)
+ << "threshold must be at least 2 and at most num_shares";
+
+ std::vector<uint32_t> subsecrets = DivideIntoSubsecrets(to_share);
+
+ // Each ShamirShare is specified as a string of length 4 * subsecrets.size().
+ // The first four characters of the ShamirShare are the share of the first
+ // subsecret stored in big-endian order, and so on.
+ std::vector<ShamirShare> shares(num_shares);
+ for (auto& share : shares) {
+ share.data.reserve(kSubsecretSize * subsecrets.size());
+ }
+
+ for (uint32_t subsecret : subsecrets) {
+ std::vector<uint32_t> coefficients;
+ coefficients.reserve(threshold);
+ coefficients.push_back(subsecret);
+
+ for (int i = 1; i < threshold; ++i) {
+ coefficients.push_back(RandomFieldElement());
+ }
+
+ for (int i = 0; i < num_shares; ++i) {
+ // The client with id x gets the share of the polynomial evaluated at x+1.
+ uint32_t subshare = EvaluatePolynomial(coefficients, i + 1);
+ // Big-endian encoding
+ shares[i].data += IntToByteString(subshare);
+ }
+ }
+ return shares;
+}
+
+StatusOr<std::string> ShamirSecretSharing::Reconstruct(
+ int threshold, const std::vector<ShamirShare>& shares, int secret_length) {
+ FCP_CHECK(threshold > 1) << "threshold must be at least 2";
+ FCP_CHECK(secret_length > 0) << "secret_length must be positive";
+ FCP_CHECK(static_cast<int>(shares.size()) >= threshold)
+ << "A vector of size " << shares.size()
+ << " was provided, but threshold was specified as " << threshold;
+
+ // The max possible number of subsecrets is based on the secret_length.
+ int max_num_subsecrets =
+ ((8 * secret_length) + kBitsPerSubsecret - 1) / kBitsPerSubsecret;
+ // The number of subsecrets may be different due to compatibility with the
+ // legacy Java implementation and may be smaller than max_num_subsecrets.
+ // The actual number is determined below.
+ int num_subsecrets = 0;
+
+ // The X values of the participating clients' shares. The i-th share will be
+ // given an X value of i+1, to account for the fact that shares are 0-indexed.
+ // We want exactly threshold participating clients.
+ std::vector<int> x_values;
+
+ for (int i = 0; i < static_cast<int>(shares.size()) &&
+ static_cast<int>(x_values.size()) < threshold;
+ ++i) {
+ if (shares[i].data.empty()) {
+ continue;
+ }
+
+ FCP_CHECK(shares[i].data.size() % kSubsecretSize == 0)
+ << "Share with index " << i << " is invalid: a share of size "
+ << shares[i].data.size() << " was provided but a multiple of "
+ << kSubsecretSize << " is expected";
+ if (num_subsecrets == 0) {
+ num_subsecrets = static_cast<int>(shares[i].data.size() / kSubsecretSize);
+ FCP_CHECK(num_subsecrets > 0 && num_subsecrets <= max_num_subsecrets)
+ << "Share with index " << i << " is invalid: "
+ << "the number of subsecrets is " << num_subsecrets
+ << " but between 1 and " << max_num_subsecrets << " is expected";
+ } else {
+ FCP_CHECK(shares[i].data.size() == num_subsecrets * kSubsecretSize)
+ << "Share with index " << i << " is invalid: "
+ << "all shares must match sizes: "
+ << "shares[i].data.size() = " << shares[i].data.size()
+ << ", num_subsecrets = " << num_subsecrets;
+ }
+ x_values.push_back(i + 1);
+ }
+ if (static_cast<int>(x_values.size()) < threshold) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "Only " << x_values.size()
+ << " valid shares were provided, but threshold was specified as "
+ << threshold;
+ }
+
+ // Recover the sharing polynomials using Lagrange polynomial interpolation.
+ std::vector<uint32_t> coefficients = LagrangeCoefficients(x_values);
+ std::vector<uint32_t> subsecrets;
+ for (int i = 0; i < num_subsecrets; ++i) {
+ subsecrets.push_back(0);
+ for (int j = 0; j < static_cast<int>(x_values.size()); ++j) {
+ int share_index = x_values[j] - 1;
+ uint32_t subshare = 0;
+ // Big-endian decoding
+ for (int k = 0; k < kSubsecretSize; ++k) {
+ subshare <<= 8;
+ subshare += static_cast<uint8_t>(
+ shares[share_index].data[kSubsecretSize * i + k]);
+ }
+ subsecrets[i] += MultiplyMod(subshare, coefficients[j], kPrime);
+ subsecrets[i] %= kPrime;
+ }
+ }
+
+ return RebuildFromSubsecrets(subsecrets, secret_length);
+}
+
+// Helper function for ModInverse.
+static uint32_t ModPow(uint32_t x, uint32_t y) {
+ if (y == 0) {
+ return 1;
+ }
+ uint32_t p = ModPow(x, y / 2) % ShamirSecretSharing::kPrime;
+ uint32_t q = MultiplyMod(p, p, ShamirSecretSharing::kPrime);
+ return ((y & 0x01) == 0) ? q : MultiplyMod(x, q, ShamirSecretSharing::kPrime);
+}
+
+uint32_t ShamirSecretSharing::ModInverse(uint32_t n) {
+ FCP_CHECK(n > 0 && n < kPrime) << "Invalid value " << n << " for ModInverse";
+ while (inverses_.size() < n) {
+ // Fermat's Little Theorem guarantees n^-1 = n^(P-2) mod P.
+ inverses_.push_back(ModPow(inverses_.size() + 1, kPrime - 2));
+ }
+ return inverses_[n - 1];
+}
+
+std::vector<uint32_t> ShamirSecretSharing::LagrangeCoefficients(
+ const std::vector<int>& x_values) {
+ FCP_CHECK(x_values.size() > 1) << "Must have at least 2 x_values";
+ for (int x : x_values) {
+ FCP_CHECK(x > 0) << "x_values must all be positive, but got a value of "
+ << x;
+ }
+
+ if (x_values == last_lc_input_) {
+ return last_lc_output_;
+ }
+ last_lc_input_ = x_values;
+ last_lc_output_.clear();
+
+ for (int i = 0; i < static_cast<int>(x_values.size()); ++i) {
+ last_lc_output_.push_back(1);
+ for (int j = 0; j < static_cast<int>(x_values.size()); ++j) {
+ if (i == j) {
+ continue;
+ }
+ last_lc_output_[i] = MultiplyMod(last_lc_output_[i], x_values[j], kPrime);
+ if (x_values[j] > x_values[i]) {
+ last_lc_output_[i] = MultiplyMod(
+ last_lc_output_[i], ModInverse(x_values[j] - x_values[i]), kPrime);
+ } else {
+ // Factor out -1 (mod kPrime)
+ last_lc_output_[i] =
+ MultiplyMod(last_lc_output_[i], kPrime - 1, kPrime);
+ last_lc_output_[i] = MultiplyMod(
+ last_lc_output_[i], ModInverse(x_values[i] - x_values[j]), kPrime);
+ }
+ }
+ }
+
+ return last_lc_output_;
+}
+
+std::vector<uint32_t> ShamirSecretSharing::DivideIntoSubsecrets(
+ const std::string& to_share) {
+ std::vector<uint32_t> secret_parts(DivideRoundUp(
+ static_cast<uint32_t>(to_share.size()) * 8, kBitsPerSubsecret));
+
+ int bits_done = 0;
+ auto current_subsecret = secret_parts.rbegin();
+
+ // This is a packing of the bits in to_share into the bits in secret_parts.
+ // The last 31 bits in to_share are kept in the same order and placed into
+ // the last element of secret_parts, the second-to-last 31 bits are placed in
+ // the second-to-last element, and so on. The high-order bit of every element
+ // of secret_parts is 0. And the first element of secret_parts will contain
+ // the remaining bits at the front of to_share.
+ for (int i = to_share.size() - 1; i >= 0; --i) {
+ // Ensure high-order characters are treated consistently
+ uint8_t current_byte = static_cast<uint8_t>(to_share[i]);
+ if (kBitsPerSubsecret - bits_done > 8) {
+ *current_subsecret |= static_cast<uint32_t>(current_byte) << bits_done;
+ bits_done += 8;
+ } else {
+ uint8_t current_byte_right =
+ current_byte & (0xFF >> (8 - (kBitsPerSubsecret - bits_done)));
+ *current_subsecret |= static_cast<uint32_t>(current_byte_right)
+ << bits_done;
+ // Make sure we're not in the edge case where we're exactly done.
+ if (!(i == 0 && bits_done + 8 == kBitsPerSubsecret)) {
+ bits_done = (bits_done + 8) % kBitsPerSubsecret;
+ ++current_subsecret;
+ *current_subsecret |= current_byte >> (8 - bits_done);
+ }
+ }
+ }
+ // We should have been working on the 0th element of the vector.
+ FCP_CHECK(current_subsecret + 1 == secret_parts.rend());
+ return secret_parts;
+}
+
+std::string ShamirSecretSharing::RebuildFromSubsecrets(
+ const std::vector<uint32_t>& secret_parts, int secret_length) {
+ std::string secret(secret_length, 0);
+ int bits_done = 0;
+ auto subsecret = secret_parts.crbegin();
+ // Exactly reverse the process in DivideIntoSubsecrets.
+ for (int i = static_cast<int>(secret.size()) - 1;
+ i >= 0 && subsecret != secret_parts.crend(); --i) {
+ if (kBitsPerSubsecret - bits_done > 8) {
+ secret[i] = static_cast<uint8_t>((*subsecret >> bits_done) & 0xFF);
+ bits_done += 8;
+ } else {
+ uint8_t next_low_bits = static_cast<uint8_t>(*subsecret >> bits_done);
+ ++subsecret;
+ if (subsecret != secret_parts.crend()) {
+ secret[i] = static_cast<uint8_t>(
+ *subsecret & (0xFF >> (kBitsPerSubsecret - bits_done)));
+ }
+ bits_done = (bits_done + 8) % kBitsPerSubsecret;
+ secret[i] <<= 8 - bits_done;
+ secret[i] |= next_low_bits;
+ }
+ }
+
+ return secret;
+}
+
+uint32_t ShamirSecretSharing::EvaluatePolynomial(
+ const std::vector<uint32_t>& polynomial, uint32_t x) const {
+ uint64_t sum = 0;
+
+ for (int i = polynomial.size() - 1; i > 0; --i) {
+ sum += polynomial[i];
+ sum *= x;
+ sum %= kPrime;
+ }
+
+ sum += polynomial[0];
+ sum %= kPrime;
+
+ return static_cast<uint32_t>(sum);
+}
+
+uint32_t ShamirSecretSharing::RandomFieldElement() {
+ uint32_t rand = 0;
+ do {
+ rand = 0;
+ RAND_bytes(reinterpret_cast<uint8_t*>(&rand), sizeof(uint32_t));
+ } while (rand >= kPrime);
+ return rand;
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/shared/shamir_secret_sharing.h b/fcp/secagg/shared/shamir_secret_sharing.h
new file mode 100644
index 0000000..735c315
--- /dev/null
+++ b/fcp/secagg/shared/shamir_secret_sharing.h
@@ -0,0 +1,139 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_SHARED_SHAMIR_SECRET_SHARING_H_
+#define FCP_SECAGG_SHARED_SHAMIR_SECRET_SHARING_H_
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "fcp/secagg/shared/key.h"
+
+namespace fcp {
+namespace secagg {
+
+// A ShamirShare represents one share of a shared secret, stored as binary data.
+typedef struct ShamirShare {
+ std::string data;
+} ShamirShare;
+
+// This class encapsulates all of the logic needed to perform t-of-n Shamir
+// Secret Sharing on arbitrary-size secrets. For efficiency, the secrets are
+// subdivided into 31-bit chunks called "subsecrets" - this allows us to use
+// native unsigned 64-bit integer multiplication without worrying about
+// overflow. This should be invisible to users of this class, as one ShamirShare
+// still holds one user's share of the secret, represented as one share of each
+// subsecret.
+//
+// This class is not thread-safe.
+
+class ShamirSecretSharing {
+ public:
+ // This is the smallest 32-bit prime, 2^31+11. Everything this class does is
+ // modulo kPrime. We need all values to be no more than 32 bits, so that we
+ // can multiply using native types without overflow.
+ static constexpr uint64_t kPrime = 2147483659L;
+
+ // Constructs the ShamirSecretSharing object.
+ ShamirSecretSharing();
+
+ // Splits the arbitrary-length value stored in to_share into shares, following
+ // threshold-out-of-num_shares Shamir Secret Sharing.
+ //
+ // The output is a vector such that the i-th element of the vector is the i-th
+ // share of the secret.
+ std::vector<ShamirShare> Share(int threshold, int num_shares,
+ const std::string& to_share);
+
+ // Convenience method to share a key instead of an arbitrary string.
+ inline std::vector<ShamirShare> Share(int threshold, int num_shares,
+ const Key& to_share) {
+ return Share(threshold, num_shares, to_share.AsString());
+ }
+
+ // Reconstructs a secret, based on a vector of shares. The vector is
+ // interpreted such that the i-th element of the vector is the i-th share. If
+ // the i-th element of the vector is set to the default ShamirShare (an empty
+ // string), that share is considered not to be present.
+ //
+ // secret_length should be set to the expected length of the reconstructed
+ // secret, in bytes.
+ //
+ // At least threshold of the shares must be set to non-empty strings, or this
+ // operation will fail.
+ //
+ // Reconstruct is most efficient when consecutive calls to this method use
+ // shares with the same indices, because this allows for caching of
+ // intermediate values that depend only on the x-value of the Shamir shares.
+ absl::StatusOr<std::string> Reconstruct(
+ int threshold, const std::vector<ShamirShare>& shares, int secret_length);
+
+ private:
+ // Returns the modular inverse of n mod kPrime, getting the value from a cache
+ // if possible. If not, extends the cache to contain modular inverses from
+ // integers from 1 to n.
+ //
+ // Fails if n is not between 1 and kPrime-1, inclusive.
+ //
+ // For most efficiency, call this method using the largest value of n that
+ // will be needed before calling Reconstruct.
+ uint32_t ModInverse(uint32_t n);
+
+ // Returns the Lagrange coefficients needed to reconstruct secrets for this
+ // exact set of shares. The Lagrange coefficient for the i-th value is
+ // the product, for all j != i, of x_values[j] / (x_values[j] - x_values[i]).
+ //
+ // If this method is called twice in a row on the same input, the output is
+ // returned from cache instead of being recomputed.
+ std::vector<uint32_t> LagrangeCoefficients(const std::vector<int>& x_values);
+
+ // Divides a secret into subsecrets. This takes place when Share is called,
+ // before any further secret sharing work.
+ std::vector<uint32_t> DivideIntoSubsecrets(const std::string& to_share);
+
+ // Rebuilds a secret from subsecrets. This takes place at the end of
+ // the Reconstruct operation, after all the secret reconstruction is already
+ // finished.
+ std::string RebuildFromSubsecrets(const std::vector<uint32_t>& secret_parts,
+ int secret_length);
+
+ // We will split our secret into sub-secrets of no more than 31 bits each.
+ // This allows us to multiply two field elements using only native types.
+ static constexpr int kBitsPerSubsecret = 31;
+
+ // Returns a pseudorandom number uniformly between 0 and kPrime-1.
+ uint32_t RandomFieldElement();
+
+ // Returns the evaluation of x on the specified polynomial.
+ // polynomial[i] is the i-degree coefficient of the polynomial.
+ uint32_t EvaluatePolynomial(const std::vector<uint32_t>& polynomial,
+ uint32_t x) const;
+
+ // Caches previously computed modular inverses.
+ // inverses_[i] = (i+1)^-1 mod kPrime
+ std::vector<uint32_t> inverses_;
+
+ // Store a copy of the last input/output from LagrangeCoefficients.
+ std::vector<int> last_lc_input_;
+ std::vector<uint32_t> last_lc_output_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_SHARED_SHAMIR_SECRET_SHARING_H_
diff --git a/fcp/secagg/shared/shamir_secret_sharing_test.cc b/fcp/secagg/shared/shamir_secret_sharing_test.cc
new file mode 100644
index 0000000..47d8b73
--- /dev/null
+++ b/fcp/secagg/shared/shamir_secret_sharing_test.cc
@@ -0,0 +1,190 @@
+
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/secagg/shared/shamir_secret_sharing.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+#include "fcp/secagg/testing/fake_prng.h"
+namespace fcp {
+namespace secagg {
+namespace {
+using ::testing::Eq;
+TEST(ShamirSecretSharingTest, ShareReturnsTheAppropriateNumberOfShares) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ std::vector<ShamirShare> shares;
+ for (int num_shares = 2; num_shares < 5; ++num_shares) {
+ for (int threshold = 2; threshold <= num_shares; ++threshold) {
+ shares = shamir.Share(threshold, num_shares, secret);
+ EXPECT_THAT(shares.size(), Eq(num_shares));
+ }
+ }
+}
+TEST(ShamirSecretSharingTest, ShareFailsWhenTheSecretIsEmpty) {
+ ShamirSecretSharing shamir;
+ std::string secret = "";
+ EXPECT_DEATH(shamir.Share(2, 5, secret), "to_share must not be empty");
+}
+TEST(ShamirSecretSharingTest, ShareFailsWhenNumberOfSharesIsSmall) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ EXPECT_DEATH(shamir.Share(1, 1, secret), "num_shares must be greater than 1");
+}
+TEST(ShamirSecretSharingTest, ShareFailsWhenTheThresholdIsOutOfBounds) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ EXPECT_DEATH(shamir.Share(6, 5, secret),
+ "threshold must be at least 2 and at most num_shares");
+ EXPECT_DEATH(shamir.Share(1, 5, secret),
+ "threshold must be at least 2 and at most num_shares");
+}
+TEST(ShamirSecretSharingTest, ShareAndReconstructIntegrate) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ std::vector<ShamirShare> shares;
+ int num_shares = 6;
+ int threshold = 4;
+ shares = shamir.Share(threshold, num_shares, secret);
+ auto reconstructed_or_error =
+ shamir.Reconstruct(threshold, shares, secret.size());
+ EXPECT_THAT(reconstructed_or_error.ok(), Eq(true));
+ EXPECT_THAT(reconstructed_or_error.value(), Eq(secret));
+}
+TEST(ShamirSecretSharingTest, ShareAndReconstructIntegrateWithMissingShares) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ std::vector<ShamirShare> shares;
+ int num_shares = 6;
+ int threshold = 4;
+ shares = shamir.Share(threshold, num_shares, secret);
+ shares[0].data = "";
+ shares[2].data = "";
+ auto reconstructed_or_error =
+ shamir.Reconstruct(threshold, shares, secret.size());
+ EXPECT_THAT(reconstructed_or_error.ok(), Eq(true));
+ EXPECT_THAT(reconstructed_or_error.value(), Eq(secret));
+}
+TEST(ShamirSecretSharingTest, ShareAndReconstructIntegrateWithZeroInSecret) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ secret[26] = '\0';
+ std::vector<ShamirShare> shares;
+ int num_shares = 6;
+ int threshold = 4;
+ shares = shamir.Share(threshold, num_shares, secret);
+ auto reconstructed_or_error =
+ shamir.Reconstruct(threshold, shares, secret.size());
+ EXPECT_THAT(reconstructed_or_error.ok(), Eq(true));
+ EXPECT_THAT(reconstructed_or_error.value(), Eq(secret));
+}
+TEST(ShamirSecretSharingTest,
+ ShareAndReconstructIntegrateWithHighOrderCharactersInSecret) {
+ ShamirSecretSharing shamir;
+ std::string secret = "abcdefghijklmnopqrstuvwxyz123456";
+ secret[10] = static_cast<char>(128);
+ secret[20] = static_cast<char>(197);
+ secret[30] = static_cast<char>(255);
+ std::vector<ShamirShare> shares;
+ int num_shares = 6;
+ int threshold = 4;
+ shares = shamir.Share(threshold, num_shares, secret);
+ auto reconstructed_or_error =
+ shamir.Reconstruct(threshold, shares, secret.size());
+ EXPECT_THAT(reconstructed_or_error.ok(), Eq(true));
+ EXPECT_THAT(reconstructed_or_error.value(), Eq(secret));
+}
+TEST(ShamirSecretSharingTest, ShareAndReconstructIntegrateWithKeys) {
+ ShamirSecretSharing shamir;
+ EcdhPregeneratedTestKeys keys;
+ std::vector<ShamirShare> shares;
+ int num_shares = 6;
+ int threshold = 4;
+ shares = shamir.Share(threshold, num_shares, keys.GetPrivateKeyString(3));
+ auto reconstructed_string_or_error =
+ shamir.Reconstruct(threshold, shares, EcdhPrivateKey::kSize);
+ EXPECT_THAT(reconstructed_string_or_error.ok(), Eq(true));
+ EcdhPrivateKey reconstructed(reinterpret_cast<const uint8_t*>(
+ reconstructed_string_or_error.value().c_str()));
+ EXPECT_THAT(reconstructed, Eq(keys.GetPrivateKey(3)));
+ EXPECT_THAT(reconstructed_string_or_error.value(),
+ Eq(keys.GetPrivateKeyString(3)));
+}
+TEST(ShamirSecretSharingTest, ReconstructFailsIfThresholdIsInvalid) {
+ ShamirSecretSharing shamir;
+ std::vector<ShamirShare> shares(5, {"fake"});
+ EXPECT_DEATH(auto secret_or_error = shamir.Reconstruct(1, shares, 16),
+ "threshold must be at least 2");
+ EXPECT_DEATH(
+ auto secret_or_error = shamir.Reconstruct(6, shares, 16),
+ "A vector of size 5 was provided, but threshold was specified as 6");
+}
+TEST(ShamirSecretSharingTest, ReconstructFailsIfSecretLengthSmall) {
+ ShamirSecretSharing shamir;
+ std::vector<ShamirShare> shares(5, {"fake"});
+ EXPECT_DEATH(auto secret_or_error = shamir.Reconstruct(2, shares, 0),
+ "secret_length must be positive");
+}
+TEST(ShamirSecretSharingTest, ReconstructFailsIfSharesAreInvalid) {
+ ShamirSecretSharing shamir;
+ std::vector<ShamirShare> shares(5, {"fakefakefakefakefake"});
+ shares[0].data = "bad";
+ EXPECT_DEATH(auto secret_or_error = shamir.Reconstruct(5, shares, 16),
+ "Share with index 0 is invalid: a share of size 3 was provided "
+ "but a multiple of 4 is expected");
+ shares[0].data = "baad";
+ EXPECT_DEATH(auto secret_or_error = shamir.Reconstruct(5, shares, 16),
+ "Share with index 1 is invalid: all shares must match sizes");
+ shares[0].data = "baadbaadbaadbaadbaadbaad";
+ EXPECT_DEATH(auto secret_or_error = shamir.Reconstruct(5, shares, 16),
+ "Share with index 0 is invalid: the number of subsecrets is 6 "
+ "but between 1 and 5 is expected");
+ shares[0].data = "";
+ auto secret_or_error = shamir.Reconstruct(5, shares, 16);
+ EXPECT_THAT(secret_or_error.ok(), Eq(false));
+ EXPECT_THAT(secret_or_error.status().message(),
+ testing::HasSubstr("Only 4 valid shares were provided, but "
+ "threshold was specified as 5"));
+}
+TEST(ShamirSecretSharingTest, ReconstructWorksWithPrecomputedShares) {
+ ShamirSecretSharing shamir;
+ std::vector<ShamirShare> shares(5);
+ int threshold = 3;
+ // These shares were generated by the legacy Java code.
+ uint8_t shares0[] = {112, 207, 118, 46, 110, 212, 170, 28};
+ shares[0].data = std::string(reinterpret_cast<char*>(shares0), 8);
+ uint8_t shares1[] = {48, 160, 197, 172, 38, 235, 145, 204};
+ shares[1].data = std::string(reinterpret_cast<char*>(shares1), 8);
+ uint8_t shares2[] = {63, 115, 238, 144, 40, 68, 183, 71};
+ shares[2].data = std::string(reinterpret_cast<char*>(shares2), 8);
+ uint8_t shares3[] = {29, 72, 240, 207, 114, 224, 26, 141};
+ shares[3].data = std::string(reinterpret_cast<char*>(shares3), 8);
+ uint8_t shares4[] = {74, 31, 204, 116, 6, 189, 187, 136};
+ shares[4].data = std::string(reinterpret_cast<char*>(shares4), 8);
+ ASSERT_THAT(shares[0].data.size(), Eq(8));
+ auto reconstructed_or_error = shamir.Reconstruct(threshold, shares, 4);
+ EXPECT_THAT(reconstructed_or_error.ok(), Eq(true));
+ EXPECT_THAT(reconstructed_or_error.value(), Eq(std::string({0, 0, 0, 33})));
+}
+} // namespace
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/testing/BUILD b/fcp/secagg/testing/BUILD
new file mode 100644
index 0000000..3332538
--- /dev/null
+++ b/fcp/secagg/testing/BUILD
@@ -0,0 +1,57 @@
+# Description:
+# Mocks for SecAgg.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+cc_library(
+ name = "common_mocks",
+ testonly = 1,
+ srcs = [
+ "ecdh_pregenerated_test_keys.cc",
+ ],
+ hdrs = [
+ "ecdh_pregenerated_test_keys.h",
+ "fake_prng.h",
+ ],
+ deps = [
+ "//fcp/secagg/shared",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "client_mocks",
+ testonly = 1,
+ hdrs = [
+ "mock_send_to_server_interface.h",
+ "mock_state_transition_listener.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/secagg/client",
+ "//fcp/secagg/client:state_transition_listener",
+ "//fcp/secagg/shared:cc_proto",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "testing",
+ testonly = 1,
+ srcs = [
+ "test_matchers.cc",
+ ],
+ hdrs = [
+ "test_matchers.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/secagg/shared",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/secagg/testing/ecdh_pregenerated_test_keys.cc b/fcp/secagg/testing/ecdh_pregenerated_test_keys.cc
new file mode 100644
index 0000000..969bcd4
--- /dev/null
+++ b/fcp/secagg/testing/ecdh_pregenerated_test_keys.cc
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/secagg/shared/ecdh_key_agreement.h"
+#include "fcp/secagg/shared/ecdh_keys.h"
+
+namespace fcp {
+namespace secagg {
+
+constexpr int EcdhPregeneratedTestKeys::kNumTestEcdhKeys;
+
+EcdhPregeneratedTestKeys::EcdhPregeneratedTestKeys() {
+ // These were generated by EcdhKeyAgreement::CreateFromRandomKeys.
+ // Despite the odd line wrapping, there are 8 individual strings in each
+ // vector.
+ private_key_strings_ = {
+ "\x6\x8D\xCC\xDD\x96Sv\x1Ams\xED\x83\x86\xB0vW\xB2"
+ "9,\xEA\xBE\x1C\xAD\xEF\x8E\xC6\xC1\xEE\xD5@\xC1\x98",
+ "\x99\xB7\x18\x91\x8E\xC1\xD5=\x86\xDB\xF0\x1D"
+ "4\x9D\xFD\xDC\xE6\xCC\xBB`=\x94iO\xBBU<x'%\xF9\x99",
+ "V\xA2"
+ "fLD\xA8"
+ "a\xA8\xCD"
+ "B\x93N\xF1|\xE7:\xDB U&g\\\xAA\xD8'P\xEF\x8F\xA5\xF8"
+ "do",
+ "R\xAB\xCF\x11\xC6\xA1P@\xD5VS}C\x3\xD2\x1\xC8\x8A\x96\x10Z\xF8"
+ "C\x1\xA0\xD2\x9C\xB6\xF9\x6\n\xEF",
+ "\x8A\xFCxv\x95\xB5n\xA3v{\xD4\x9F"
+ "6\x89^@\xB0Oa\xB3\x8B"
+ "C\xB0\x2\xA5\xD7\xE9\x14\x8DXf1",
+ "\xF4\xDD\xAF\x19"
+ "D\xCBi\x94\x8F\x9C\xEAq\xFWXE\xE0\xFA\xF2\x8B[\xCD \xE4"
+ "Aa;\x9A\x8E}\xA0\x85",
+ "pC\x18\xDD\xFW\xDF\xA0\n\v9\xBB\xE8\xCAU\x1D\xCA\x93"
+ "B\xA2\x12-\x9DoJju\xDB\x9FVD\xD6",
+ "\xF0\xD4\x99\xB7\xBE\x1B\xC3QMU\x99\xE1\xA1"
+ "f\xF2X\x13\xAC\xB4\x96"
+ "2)=U\xE6\xEC#\xCE\xA5y\xCE\x9B"};
+ public_key_strings_ = {
+ "\x2^\x8D\x95\xEmW\xA4\x99y\x19ZT\x87;\xC1.R\xC3_r\xAE"
+ "1=nh\xD2\x82xZ\xC9L\x83",
+ "\x3S\xD3zo\xB5.\xB3"
+ "3f\x94\x14\xAE\xC3,\xC3\xB9r,\xB3\\\xE8\xBD\x8E\xF2\t\xEB\x1\f\xE1\x83"
+ "9\xED",
+ "\x2 \xB3-\xC4\x1\x8D\x9A"
+ "BM\x98W\xB3=Y\x1\xEA\x80WGS&r\x1A\x12\xE0z\x9EzC:\xCA\x9D",
+ "\x2"
+ "3\xCBn\x5\x5\xBC\xB5\x6y\x13h\xD5\x14\xA8\x19L\xCBG/"
+ "N\x12\xEA\x98\xA7\xA2/\xE4\x5\xCD\xC2\xE1\xC1",
+ "\x2\x1C"
+ "C\x85\xE3\xFF\x1B\xA8\xF1\xE7\xF2\xF9\x8C(\x85\x1CZ\xC9\xFB\xE3.\x8D "
+ "'\xB5\xCC"
+ "A\x12\x19\x80\xED)\xD3",
+ "\x3\xF3\xB6h\t\xB3\x80\xD0%\xCD> "
+ "'\xA2\xFBz\xA5\xB5K\x94\xE5==\t\aIt\xCD\x98'\x1"
+ "5)",
+ "\x2"
+ "4\xFF\xA5\xA2\x1B\xF1!q\xBEH\xEB\xE7*\xD8\xBF\x87(!\xBDM "
+ "(\xC1mK\xC2\xB8\xDA\xAB"
+ "al\x8E",
+ "\x3Jr+~\x19kht\xE1\xD7\x10\x9B\xB9\xAD}U\xB2\xE4\xF1\x89%\xCF"
+ "b\xAA\n\x83"
+ "5\x8B"
+ "7G\xA3\xF2"};
+
+ // These were generated by a Java ECDH implementation. They do not correspond
+ // to the private keys, and have X.509 headers.
+ uncompressed_public_key_strings_ = {
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004Pv\301\226\312Y\326\323I\271\265\310Vj\353\236"
+ "\336}\230a\265\312:kh\315\365\270nW\332\271$\v\b\277%"
+ "\246\375d\361H\274\343q\362\254\235\252\220\204a\303\222\020\234\000\372"
+ "\332]%p3\006",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\221f\267\002lX?\325\261\353\\\3430U\365["
+ "\005\222\214ey\300s`"
+ "F\\\251\370\t\310\310\n4\375\005\324\360\316\377\025\320\341\363e]4["
+ "\317??\276\227\"p\353\314\377\260\213\024\300\003\250\253",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\256\000g\0247\365!<$"
+ "\234eYA\323u\341\346x\350\271=F\220n\257\201\210\034z0\315s\221="
+ "\357\226\'%O\342\270\346\035\277\2465\223\036^H+"
+ "Q\246mAZ\021cf\000\016\221\350c",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\034\230\204\000y\022e\226\330\3547\346\212~5Z-"
+ "0\2342)\315\'B\271\235\226R~\252]\256\305\353$%"
+ "Sq\376\337\254Z\022\213eH%!\225\271-,/\227s\263!J3Ln\3503,",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004u\273R8\212\335\302c\341\237\373\"\270\350UL\342R"
+ "\330]l\272C\346[\312\240\fl\207[\201\241\022\a\035? !\210 "
+ "l\264\241H}2*\004<E\034n\367!8\226Z\263p\214\tY@",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\375\211:E5j\373\357\257\276|"
+ "\023\003\a\214\2751\200n\374\306\300\313\266EL\306\364Vq\205\000\031R6"
+ "\357I2\253r\024j\235\264^\340_\363\225*\306?"
+ "\352\270\017\371\351EN\362K8\020\244",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\322\333\350*"
+ "3\016\246D\t\177\226GL\332\356\336\365k\303\331?H\020\235\235\206\373*"
+ "\322c\006\206\030\200k\363\233]\231\260=\363\003\2221\337:\021E?"
+ "ra\370\026\340\267\"5\n\317\vN\375\235",
+ "0Y0\023\006\a*\206H\316=\002\001\006\b*\206H\316="
+ "\003\001\a\003B\000\004\224\326\324\2577{z\364sTf\321YD\317\177QR\034}"
+ "\030\241\211\026\366\255\256t\352\322\363\027\201\031F\026\260\3706h\330"
+ "\026\022\\\022\'\255Y!\321a\324c\324\337\276\256?\002\235 \2636\034"};
+
+ for (int i = 0; i < kNumTestEcdhKeys; ++i) {
+ private_keys_.push_back(EcdhPrivateKey(
+ reinterpret_cast<const uint8_t*>(private_key_strings_[i])));
+ public_keys_.push_back(EcdhPublicKey(
+ reinterpret_cast<const uint8_t*>(public_key_strings_[i])));
+ // Move pointer ahead 26 bytes to skip header
+ uncompressed_public_keys_.push_back(EcdhPublicKey(
+ reinterpret_cast<const uint8_t*>(uncompressed_public_key_strings_[i] +
+ 26),
+ EcdhPublicKey::kUncompressed));
+ }
+}
+
+EcdhPrivateKey EcdhPregeneratedTestKeys::GetPrivateKey(size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return private_keys_[index];
+}
+EcdhPublicKey EcdhPregeneratedTestKeys::GetPublicKey(size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return public_keys_[index];
+}
+EcdhPublicKey EcdhPregeneratedTestKeys::GetUncompressedPublicKey(size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return uncompressed_public_keys_[index];
+}
+
+std::string EcdhPregeneratedTestKeys::GetPrivateKeyString(size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return std::string(private_key_strings_[index], EcdhPrivateKey::kSize);
+}
+std::string EcdhPregeneratedTestKeys::GetPublicKeyString(size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return std::string(public_key_strings_[index], EcdhPublicKey::kSize);
+}
+std::string EcdhPregeneratedTestKeys::GetUncompressedPublicKeyString(
+ size_t index) {
+ EXPECT_THAT(index, testing::Lt(kNumTestEcdhKeys));
+ return std::string(public_key_strings_[index], EcdhPublicKey::kSize);
+}
+
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/testing/ecdh_pregenerated_test_keys.h b/fcp/secagg/testing/ecdh_pregenerated_test_keys.h
new file mode 100644
index 0000000..16607cc
--- /dev/null
+++ b/fcp/secagg/testing/ecdh_pregenerated_test_keys.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_ECDH_PREGENERATED_TEST_KEYS_H_
+#define FCP_SECAGG_TESTING_ECDH_PREGENERATED_TEST_KEYS_H_
+
+#include <string>
+#include <vector>
+
+#include "fcp/secagg/shared/ecdh_keys.h"
+
+namespace fcp {
+namespace secagg {
+
+// This class contains some pregenerated ECDH public/private keypairs. In no
+// actual implementation should pregenerated keys such as these be used.
+class EcdhPregeneratedTestKeys {
+ public:
+ // Valid inputs for all functions are integers from 0 to kNumTestEcdhKeys.
+ static constexpr int kNumTestEcdhKeys = 8;
+
+ EcdhPregeneratedTestKeys();
+
+ // Returns a public or private key.
+ EcdhPrivateKey GetPrivateKey(size_t index);
+ EcdhPublicKey GetPublicKey(size_t index);
+
+ // Returns a public or private key in the form of a string.
+ std::string GetPrivateKeyString(size_t index);
+ std::string GetPublicKeyString(size_t index);
+
+ // Returns an uncompressed public key.
+ EcdhPublicKey GetUncompressedPublicKey(size_t index);
+ // Returns an uncompressed public key in the form of a string, with X.509
+ // header.
+ std::string GetUncompressedPublicKeyString(size_t index);
+
+ private:
+ std::vector<const char*> private_key_strings_;
+ std::vector<const char*> public_key_strings_;
+ std::vector<EcdhPrivateKey> private_keys_;
+ std::vector<EcdhPublicKey> public_keys_;
+ std::vector<EcdhPublicKey> uncompressed_public_keys_;
+ std::vector<const char*> uncompressed_public_key_strings_;
+};
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_ECDH_PREGENERATED_TEST_KEYS_H_
diff --git a/fcp/secagg/testing/fake_prng.h b/fcp/secagg/testing/fake_prng.h
new file mode 100644
index 0000000..2fafd51
--- /dev/null
+++ b/fcp/secagg/testing/fake_prng.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_FAKE_PRNG_H_
+#define FCP_SECAGG_TESTING_FAKE_PRNG_H_
+
+#include <cstdint>
+
+#include "fcp/secagg/shared/prng.h"
+
+namespace fcp {
+namespace secagg {
+
+// Fake Implementation of SecurePrng that just returns constantly incrementing
+// values.
+
+class FakePrng : public SecurePrng {
+ public:
+ // Returns 1, 2, 3, etc.
+ FakePrng() = default;
+
+ // Returns the selected value first, and increments by 1 each time from there.
+ explicit FakePrng(uint64_t value) : value_(value - 1) {}
+
+ uint8_t Rand8() override { return static_cast<uint8_t>(++value_); }
+ uint64_t Rand64() override { return ++value_; }
+
+ private:
+ uint64_t value_ = 0;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_FAKE_PRNG_H_
diff --git a/fcp/secagg/testing/mock_send_to_server_interface.h b/fcp/secagg/testing/mock_send_to_server_interface.h
new file mode 100644
index 0000000..b855e5a
--- /dev/null
+++ b/fcp/secagg/testing/mock_send_to_server_interface.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_MOCK_SEND_TO_SERVER_INTERFACE_H_
+#define FCP_SECAGG_TESTING_MOCK_SEND_TO_SERVER_INTERFACE_H_
+
+#include "gmock/gmock.h"
+#include "fcp/secagg/client/send_to_server_interface.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// GMock Implementation of SendToServerInterface.
+
+class MockSendToServerInterface : public SendToServerInterface {
+ public:
+ MOCK_METHOD(void, Send, (ClientToServerWrapperMessage * message));
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_MOCK_SEND_TO_SERVER_INTERFACE_H_
diff --git a/fcp/secagg/testing/mock_state_transition_listener.h b/fcp/secagg/testing/mock_state_transition_listener.h
new file mode 100644
index 0000000..c01e1cf
--- /dev/null
+++ b/fcp/secagg/testing/mock_state_transition_listener.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_MOCK_STATE_TRANSITION_LISTENER_H_
+#define FCP_SECAGG_TESTING_MOCK_STATE_TRANSITION_LISTENER_H_
+
+#include "gmock/gmock.h"
+#include "fcp/secagg/client/state_transition_listener_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+// GMock Implementation of SendToServerInterface.
+
+class MockStateTransitionListener : public StateTransitionListenerInterface {
+ public:
+ MOCK_METHOD(void, Transition, (ClientState state));
+ MOCK_METHOD(void, Started, (ClientState state));
+ MOCK_METHOD(void, Stopped, (ClientState state));
+ MOCK_METHOD(void, set_execution_session_id, (int64_t execution_session_id));
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_MOCK_STATE_TRANSITION_LISTENER_H_
diff --git a/fcp/secagg/testing/server/BUILD b/fcp/secagg/testing/server/BUILD
new file mode 100644
index 0000000..e090c97
--- /dev/null
+++ b/fcp/secagg/testing/server/BUILD
@@ -0,0 +1,45 @@
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+cc_library(
+ name = "server_mocks",
+ testonly = 1,
+ hdrs = [
+ "mock_secagg_server_metrics_listener.h",
+ "mock_send_to_clients_interface.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/secagg/server:secagg_server_metrics_listener",
+ "//fcp/secagg/server:send_to_clients_interface",
+ "//fcp/secagg/server:server_cc_proto",
+ "//fcp/secagg/shared:cc_proto",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "experiments",
+ testonly = 1,
+ hdrs = [
+ "test_secagg_experiments.h",
+ ],
+ deps = [
+ "//fcp/secagg/server:experiments_interface",
+ ],
+)
+
+cc_library(
+ name = "async_runner",
+ testonly = 1,
+ hdrs = [
+ "test_async_runner.h",
+ ],
+ deps = [
+ "//fcp/base:scheduler",
+ "//fcp/secagg/server:secagg_scheduler",
+ ],
+)
diff --git a/fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h b/fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h
new file mode 100644
index 0000000..7fe7807
--- /dev/null
+++ b/fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_SERVER_MOCK_SECAGG_SERVER_METRICS_LISTENER_H_
+#define FCP_SECAGG_TESTING_SERVER_MOCK_SECAGG_SERVER_METRICS_LISTENER_H_
+
+#include "gmock/gmock.h"
+#include "fcp/secagg/server/secagg_server_enums.pb.h"
+#include "fcp/secagg/server/secagg_server_metrics_listener.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// GMock Implementation of SecAggServerMetricsListener.
+class MockSecAggServerMetricsListener : public SecAggServerMetricsListener {
+ public:
+ MOCK_METHOD(void, ProtocolStarts, (ServerVariant server_variant), (override));
+ MOCK_METHOD(void, IndividualMessageSizes,
+ (ServerToClientWrapperMessage::MessageContentCase message_type,
+ uint64_t size),
+ (override));
+ MOCK_METHOD(void, BroadcastMessageSizes,
+ (ServerToClientWrapperMessage::MessageContentCase message_type,
+ uint64_t size),
+ (override));
+ MOCK_METHOD(void, MessageReceivedSizes,
+ (ClientToServerWrapperMessage::MessageContentCase message_type,
+ bool message_expected, uint64_t size),
+ (override));
+ MOCK_METHOD(void, ClientResponseTimes,
+ (ClientToServerWrapperMessage::MessageContentCase message_type,
+ uint64_t elapsed_millis),
+ (override));
+ MOCK_METHOD(void, RoundTimes,
+ (SecAggServerStateKind target_state, bool successful,
+ uint64_t elapsed_millis),
+ (override));
+ MOCK_METHOD(void, PrngExpansionTimes, (uint64_t elapsed_millis), (override));
+ MOCK_METHOD(void, RoundSurvivingClients,
+ (SecAggServerStateKind target_state, uint64_t number_of_clients),
+ (override));
+ MOCK_METHOD(void, RoundCompletionFractions,
+ (SecAggServerStateKind target_state, ClientStatus client_state,
+ double fraction),
+ (override));
+ MOCK_METHOD(void, ProtocolOutcomes, (SecAggServerOutcome outcome),
+ (override));
+ MOCK_METHOD(void, ClientsDropped,
+ (ClientStatus abort_state, ClientDropReason error_code),
+ (override));
+ MOCK_METHOD(void, ShamirReconstructionTimes, (uint64_t elapsed_millis),
+ (override));
+};
+
+} // namespace secagg
+} // namespace fcp
+#endif // FCP_SECAGG_TESTING_SERVER_MOCK_SECAGG_SERVER_METRICS_LISTENER_H_
diff --git a/fcp/secagg/testing/server/mock_send_to_clients_interface.h b/fcp/secagg/testing/server/mock_send_to_clients_interface.h
new file mode 100644
index 0000000..1a2ddd7
--- /dev/null
+++ b/fcp/secagg/testing/server/mock_send_to_clients_interface.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2018 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_SERVER_MOCK_SEND_TO_CLIENTS_INTERFACE_H_
+#define FCP_SECAGG_TESTING_SERVER_MOCK_SEND_TO_CLIENTS_INTERFACE_H_
+
+#include "gmock/gmock.h"
+#include "fcp/secagg/server/send_to_clients_interface.h"
+#include "fcp/secagg/shared/secagg_messages.pb.h"
+
+namespace fcp {
+namespace secagg {
+
+// GMock Implementation of SendToClientsInterface.
+
+class MockSendToClientsInterface : public SendToClientsInterface {
+ public:
+ MOCK_METHOD(void, SendBroadcast,
+ (const ServerToClientWrapperMessage& message), (override));
+ MOCK_METHOD(void, Send,
+ (uint32_t recipient_id,
+ const ServerToClientWrapperMessage& message),
+ (override));
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_SERVER_MOCK_SEND_TO_CLIENTS_INTERFACE_H_
diff --git a/fcp/secagg/testing/server/test_async_runner.h b/fcp/secagg/testing/server/test_async_runner.h
new file mode 100644
index 0000000..7bd9cb6
--- /dev/null
+++ b/fcp/secagg/testing/server/test_async_runner.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_SERVER_TEST_ASYNC_RUNNER_H_
+#define FCP_SECAGG_TESTING_SERVER_TEST_ASYNC_RUNNER_H_
+
+#include <memory>
+#include <utility>
+
+#include "fcp/base/scheduler.h"
+#include "fcp/secagg/server/secagg_scheduler.h"
+
+namespace fcp {
+namespace secagg {
+
+// Defines a scheduler used for testing, owning pointers to underlying
+// schedulers in SecAggScheduler
+class TestAsyncRunner : public SecAggScheduler {
+ public:
+ TestAsyncRunner(std::unique_ptr<Scheduler> worker_scheduler,
+ std::unique_ptr<Scheduler> callback_scheduler)
+ : SecAggScheduler(worker_scheduler.get(), callback_scheduler.get()),
+ worker_scheduler_(std::move(worker_scheduler)),
+ callback_scheduler_(std::move(callback_scheduler)) {}
+
+ ~TestAsyncRunner() override { WaitUntilIdle(); }
+
+ private:
+ std::unique_ptr<Scheduler> worker_scheduler_;
+ std::unique_ptr<Scheduler> callback_scheduler_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_SERVER_TEST_ASYNC_RUNNER_H_
diff --git a/fcp/secagg/testing/server/test_secagg_experiments.h b/fcp/secagg/testing/server/test_secagg_experiments.h
new file mode 100644
index 0000000..8ac9c1f
--- /dev/null
+++ b/fcp/secagg/testing/server/test_secagg_experiments.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_SERVER_TEST_SECAGG_EXPERIMENTS_H_
+#define FCP_SECAGG_TESTING_SERVER_TEST_SECAGG_EXPERIMENTS_H_
+
+#include <set>
+#include <string>
+
+#include "fcp/secagg/server/experiments_interface.h"
+
+namespace fcp {
+namespace secagg {
+
+// Defines an experiment class to set secagg experiments.
+class TestSecAggExperiment : public ExperimentsInterface {
+ public:
+ explicit TestSecAggExperiment(
+ const std::set<std::string>& enabled_experiment_names)
+ : enabled_experiment_names_(enabled_experiment_names) {}
+ explicit TestSecAggExperiment(std::string enabled_experiment_name) {
+ enabled_experiment_names_ =
+ std::set<std::string>({enabled_experiment_name});
+ }
+ explicit TestSecAggExperiment() {}
+ bool IsEnabled(absl::string_view experiment_name) override {
+ return enabled_experiment_names_.find(static_cast<std::string>(
+ experiment_name)) != enabled_experiment_names_.end();
+ }
+
+ private:
+ std::set<std::string> enabled_experiment_names_;
+};
+
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_SERVER_TEST_SECAGG_EXPERIMENTS_H_
diff --git a/fcp/secagg/testing/test_matchers.cc b/fcp/secagg/testing/test_matchers.cc
new file mode 100644
index 0000000..548acb6
--- /dev/null
+++ b/fcp/secagg/testing/test_matchers.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/secagg/testing/test_matchers.h"
+
+#include <string>
+
+namespace fcp {
+namespace secagg {
+namespace testing {
+
+SecAggVectorData ToSecAggVectorData(const SecAggVector& vector){
+ return std::make_pair(vector.modulus(), vector.GetAsUint64Vector());
+}
+
+SecAggVectorDataMap ToDataMap(const SecAggVectorMap& secagg_vector_map) {
+ SecAggVectorDataMap result;
+ for (const auto& item : secagg_vector_map) {
+ result.emplace(item.first, ToSecAggVectorData(item.second));
+ }
+ return result;
+}
+
+class SecAggVectorMapMatcherImpl
+ : public ::testing::MatcherInterface<const SecAggVectorMap&> {
+ public:
+ explicit SecAggVectorMapMatcherImpl(SecAggVectorDataMap expected)
+ : expected_(expected) {}
+ void DescribeTo(::std::ostream* os) const override {
+ for (const auto& item : expected_) {
+ *os << "{name: \"" << item.first << "\", modulus: " << item.second.first
+ << ", vector:";
+ for (uint64_t val : item.second.second) {
+ *os << " " << val;
+ }
+ *os << "} ";
+ }
+ }
+
+ bool MatchAndExplain(
+ const SecAggVectorMap& arg,
+ ::testing::MatchResultListener* listener) const override {
+ return ::testing::ExplainMatchResult(
+ ::testing::UnorderedElementsAreArray(expected_), ToDataMap(arg),
+ listener);
+ }
+
+ private:
+ SecAggVectorDataMap expected_;
+};
+
+SecAggVectorMapMatcher::operator ::testing::Matcher<const SecAggVectorMap&>()
+ const {
+ return ::testing::MakeMatcher(new SecAggVectorMapMatcherImpl(expected_));
+}
+
+SecAggVectorMapMatcher MatchesSecAggVectorMap(const SecAggVectorMap& expected) {
+ return SecAggVectorMapMatcher(ToDataMap(expected));
+}
+
+SecAggVectorMapMatcher MatchesSecAggVector(const std::string& name,
+ const SecAggVector& vector) {
+ SecAggVectorDataMap expected;
+ expected.emplace(name, ToSecAggVectorData(vector));
+ return SecAggVectorMapMatcher(expected);
+}
+
+} // namespace testing
+} // namespace secagg
+} // namespace fcp
diff --git a/fcp/secagg/testing/test_matchers.h b/fcp/secagg/testing/test_matchers.h
new file mode 100644
index 0000000..217416f
--- /dev/null
+++ b/fcp/secagg/testing/test_matchers.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_SECAGG_TESTING_TEST_MATCHERS_H_
+#define FCP_SECAGG_TESTING_TEST_MATCHERS_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "absl/container/node_hash_map.h"
+#include "fcp/secagg/shared/secagg_vector.h"
+
+namespace fcp {
+namespace secagg {
+namespace testing {
+
+using SecAggVectorData = std::pair<int, std::vector<uint64_t> >;
+using SecAggVectorDataMap = absl::node_hash_map<std::string, SecAggVectorData>;
+
+class SecAggVectorMapMatcher {
+ public:
+ explicit SecAggVectorMapMatcher(SecAggVectorDataMap expected)
+ : expected_(expected) {}
+ // Intentionally allowed to be implicit.
+ operator ::testing::Matcher<const SecAggVectorMap&>() const; // NOLINT
+
+ private:
+ SecAggVectorDataMap expected_;
+};
+
+SecAggVectorMapMatcher MatchesSecAggVectorMap(const SecAggVectorMap& expected);
+SecAggVectorMapMatcher MatchesSecAggVector(const std::string& name,
+ const SecAggVector& vector);
+
+} // namespace testing
+} // namespace secagg
+} // namespace fcp
+
+#endif // FCP_SECAGG_TESTING_TEST_MATCHERS_H_
diff --git a/fcp/tensorflow/BUILD b/fcp/tensorflow/BUILD
new file mode 100644
index 0000000..4b2d936
--- /dev/null
+++ b/fcp/tensorflow/BUILD
@@ -0,0 +1,974 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@system_provided_tf//:system_provided_tf.bzl", "tf_custom_op_library")
+load("@rules_python//python:defs.bzl", "py_binary", "py_test")
+load("//fcp:config.bzl", "FCP_COPTS")
+load("//fcp/tracing:build_defs.bzl", "tracing_schema_cc_library")
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py")
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+default_visibility = ["//fcp:internal"]
+
+package(
+ default_visibility = default_visibility,
+ licenses = ["notice"], # Apache 2.0
+)
+
+tf_cc_test(
+ name = "tf_smoke_test",
+ srcs = ["tf_smoke_test.cc"],
+ extra_copts = FCP_COPTS,
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/cc:cc_ops",
+ "@org_tensorflow//tensorflow/cc:client_session",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:tensorflow_opensource",
+ "@org_tensorflow//tensorflow/core:testlib",
+ ],
+)
+
+py_test(
+ name = "tf_py_smoke_test",
+ srcs = ["tf_py_smoke_test.py"],
+ python_version = "PY3",
+)
+
+# Library for converting between the FCP and TensorFlow versions of a Status.
+# Note that this library is intended to be usable in an op .so, thus it depends
+# on TF headers but *not* an implementation (the final binary needs to link it
+# in). We must also use the right copy of headers, depending on whether the
+# build is targeting a system-provided TF library or a bazel-built one.
+cc_library(
+ name = "status",
+ srcs = [
+ "status.cc",
+ ],
+ hdrs = [
+ "status.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base",
+ ] + select({
+ "@system_provided_tf//:system_provided_tf_build": ["@system_provided_tf//:tf_headers"],
+ "//conditions:default": ["@org_tensorflow//tensorflow/core:framework_headers_lib"],
+ }),
+)
+
+cc_test(
+ name = "status_test",
+ srcs = [
+ "status_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":status",
+ "@com_google_googletest//:gtest_main",
+ # See remarks on :status about the TF framework dependency
+ "@org_tensorflow//tensorflow/core:framework",
+ ],
+)
+
+cc_library(
+ name = "host_object",
+ srcs = [
+ "host_object.cc",
+ ],
+ hdrs = [
+ "host_object.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = default_visibility + [
+ ],
+ deps = [
+ "//fcp/base",
+ "//fcp/base:random_token",
+ "//fcp/base:unique_value",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_test(
+ name = "host_object_test",
+ srcs = [
+ "host_object_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":host_object",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tracing_schema_cc_library(
+ name = "tracing_schema",
+ srcs = ["tracing_schema.fbs"],
+)
+
+cc_library(
+ name = "tf_session",
+ srcs = ["tf_session.cc"],
+ hdrs = ["tf_session.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":status",
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/base:process_unique_id",
+ "//fcp/base:result",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/tracing",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ "@org_tensorflow//tensorflow/core:core_cpu",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_test(
+ name = "tf_session_test",
+ srcs = ["tf_session_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":tf_session",
+ ":tracing_schema",
+ "//fcp/base:tracing_schema",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/tensorflow/testing:tf_helper",
+ "//fcp/testing:result_matchers",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/cc:cc_ops",
+ "@org_tensorflow//tensorflow/cc:scope",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core:tensorflow",
+ "@org_tensorflow//tensorflow/core:testlib",
+ ],
+)
+
+# C++ interfaces for implementing an 'external dataset' (a kind of host object).
+# Note this does *not* depend on TensorFlow.
+cc_library(
+ name = "external_dataset",
+ srcs = [
+ ],
+ hdrs = [
+ "external_dataset.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = default_visibility + [
+ ],
+ deps = [
+ ":host_object",
+ "//fcp/base:bounds",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+# The targets below produce a custom op, which involves a native library as
+# well as Python wrappers. There is some significant complexity arising from
+# the various ways that an op / kernel might be linked in, which we'll try to
+# explain here:
+#
+# - Ops / kernels are meant to be buildable as DSOs (dynamic shared objects,
+# i.e. .so files). Yet, all ops and kernels must agree on the same
+# 'framework' for registration etc. When building a DSO, ops and kernels
+# can include TensorFlow framework *headers*, with implementations provided
+# by libtensorflow_framework.so at runtime.
+#
+# - When using ops / kernels (and TensorFlow) from a standard Python
+# interpreter, they *must* be loaded as DSOs.
+#
+# - When using ops / kernels (and TensorFlow) from C++, we have the option of
+# linking a monolithic binary, with Bazel's usual handling of deps. This is
+# in fact necessary to generate Python wrapper code (the generator links in
+# cc_library deps).
+#
+# Below, we generate *both* a DSO and cc_library variant of the ExternalDataset
+# op and kernel:
+# cc_library: :external_dataset_op_lib
+# DSO: _external_dataset_op.so
+#
+# The ExternalDataset op is a peculiar case, since it is specifically intended
+# to use objects provided by the program hosting TensorFlow (beyond the usual
+# TensorFlow APIs). This is problematic, since separate host and DSO binaries
+# each end up with their own definitions of symbols from common libraries (and
+# likely export them!). Though this might appear to work sometimes, it must be
+# avoided.
+# See e.g. https://github.com/abseil/abseil-cpp/issues/125
+#
+# ---------------------------
+# | _external_dataset_op.so |
+# ------------- -> | absl |
+# | Host | / | fcp/base |
+# | absl | ---------------------------
+# | fcp/base | \ |
+# ------------- \ v
+# \ ------------------------------
+# -> | libtensorflow_framework.so |
+# ------------------------------
+#
+# When using the cc_library version and Bazel's usual handling of the deps
+# graph, this is of course not a problem.
+#
+# As such, the DSO version is specifically useful for *building graphs in
+# Python* and therefore targets the system-provided Python TensorFlow package.
+# (C++) host programs must use the cc_library version.
+
+EXTERNAL_DATASET_OP_SRCS = ["external_dataset_op.cc"]
+
+EXTERNAL_DATASET_OP_DEPS = [
+ ":external_dataset",
+ ":status",
+ "@com_google_absl//absl/strings:str_format",
+ "//fcp/base:random_token",
+]
+
+# Public: TensorFlow op and op-kernel, that delegates to an ExternalDatasetStub
+# host object. This is the cc_library version. See explanation above.
+cc_library(
+ name = "external_dataset_op_lib",
+ srcs = EXTERNAL_DATASET_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = EXTERNAL_DATASET_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+# DSO version of :external_dataset_op_lib, intended to be loaded by Python
+# wrappers. See explanation above.
+tf_custom_op_library(
+ name = "_external_dataset_op.so",
+ srcs = EXTERNAL_DATASET_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = EXTERNAL_DATASET_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python. As this is a dataset op,
+# it's not useful directly; see :external_dataset_py.
+tf_gen_op_wrapper_py(
+ name = "gen_external_dataset_py",
+ out = "gen_external_dataset_py.py",
+ deps = [
+ ":external_dataset_op_lib",
+ ],
+)
+
+# Public: Python library for ExternalDataset.
+tf_custom_op_py_library(
+ name = "external_dataset_py",
+ srcs = ["external_dataset.py"],
+ dso = [":_external_dataset_op.so"],
+ kernels = [
+ ":external_dataset_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_external_dataset_py"],
+)
+
+# The dataset API isn't really usable from C++, so we generate a GraphDef for
+# testing using Python.
+py_binary(
+ name = "make_external_dataset_test_graph",
+ testonly = True,
+ srcs = ["make_external_dataset_test_graph.py"],
+ python_version = "PY3",
+ deps = [":external_dataset_py"],
+)
+
+genrule(
+ name = "external_dataset_test_graph",
+ testonly = True,
+ srcs = [],
+ outs = ["external_dataset_test.pbtxt"],
+ cmd = "$(location :make_external_dataset_test_graph) --output \"$@\"",
+ tools = [":make_external_dataset_test_graph"],
+)
+
+# Selector proto used in test dataset stubs and example selector fuser op.
+proto_library(
+ name = "test_selector_proto",
+ testonly = True,
+ srcs = [
+ "test_selector.proto",
+ ],
+)
+
+cc_proto_library(
+ name = "test_selector_cc_proto",
+ testonly = True,
+ deps = [":test_selector_proto"],
+)
+
+tf_cc_test(
+ name = "external_dataset_op_test",
+ srcs = ["external_dataset_op_test.cc"],
+ data = [
+ "external_dataset_test.pbtxt",
+ ],
+ extra_copts = FCP_COPTS,
+ deps = [
+ ":external_dataset",
+ ":external_dataset_op_lib",
+ ":test_selector_cc_proto",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:core_cpu",
+ "@org_tensorflow//tensorflow/core:direct_session",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core:tensorflow_opensource",
+ "@org_tensorflow//tensorflow/core:testlib",
+ ],
+)
+
+CRC32_OP_SRCS = [
+ "crc32_op.cc",
+ "tensor_crc32.cc",
+]
+
+# Custom op to compute the CRC32 checksum of a tensor.
+cc_library(
+ name = "crc32_op_lib",
+ srcs = [
+ "crc32_op.cc",
+ "tensor_crc32.cc",
+ ],
+ hdrs = ["tensor_crc32.h"],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_crc32_op.so",
+ srcs = CRC32_OP_SRCS + ["tensor_crc32.h"],
+ copts = FCP_COPTS,
+)
+
+# Generates the basic op wrapper for use in Python.
+tf_gen_op_wrapper_py(
+ name = "gen_crc32_py",
+ out = "gen_crc32_py.py",
+ deps = [
+ ":crc32_op_lib",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "crc32_py",
+ srcs = ["crc32.py"],
+ dso = [":_crc32_op.so"],
+ kernels = [
+ ":crc32_op_lib",
+ ],
+ deps = [":gen_crc32_py"],
+)
+
+py_test(
+ name = "crc32_test",
+ srcs = ["crc32_test.py"],
+ python_version = "PY3",
+ deps = [":crc32_py"],
+)
+
+EXAMPLE_SELECTOR_FUSER_OP_SRCS = ["example_selector_fuser_op.cc"]
+
+EXAMPLE_SELECTOR_FUSER_OP_DEPS = [
+ "@com_google_protobuf//:protobuf",
+ "//fcp/protos:plan_cc_proto",
+]
+
+# Custom op to add resumption token to example selector.
+cc_library(
+ name = "example_selector_fuser_op_lib",
+ srcs = EXAMPLE_SELECTOR_FUSER_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = EXAMPLE_SELECTOR_FUSER_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_example_selector_fuser_op.so",
+ srcs = EXAMPLE_SELECTOR_FUSER_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = EXAMPLE_SELECTOR_FUSER_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python.
+tf_gen_op_wrapper_py(
+ name = "gen_example_selector_fuser_op",
+ out = "gen_example_selector_fuser_op.py",
+ deps = [
+ ":example_selector_fuser_op_lib",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "example_selector_fuser_py",
+ srcs = ["example_selector_fuser.py"],
+ dso = [":_example_selector_fuser_op.so"],
+ kernels = [
+ ":example_selector_fuser_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_example_selector_fuser_op"],
+)
+
+py_proto_library(
+ name = "test_selector_py_pb2",
+ testonly = True,
+ deps = [
+ ":test_selector_proto",
+ ],
+)
+
+py_test(
+ name = "example_selector_fuser_test",
+ srcs = ["example_selector_fuser_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":example_selector_fuser_py",
+ ":test_selector_py_pb2",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+# C++ library to set and access callbacks for slice serving requests.
+# Used by the `ServeSlices` custom op below.
+cc_library(
+ name = "serve_slices_registry",
+ hdrs = [
+ "serve_slices_registry.h",
+ ],
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":host_object",
+ ],
+)
+
+cc_test(
+ name = "serve_slices_registry_test",
+ srcs = ["serve_slices_registry_test.cc"],
+ deps = [
+ ":host_object",
+ ":serve_slices_registry",
+ "//fcp/base:random_token",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/core:framework",
+ ],
+)
+
+SERVE_SLICES_OP_SRCS = ["serve_slices_op.cc"]
+
+SERVE_SLICES_OP_DEPS = [
+ ":serve_slices_registry",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+]
+
+# Custom op to register slices to serve for a `federated_select`.
+cc_library(
+ name = "serve_slices_op_lib",
+ srcs = SERVE_SLICES_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = SERVE_SLICES_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+# DSO version of `:serve_slices_op_lib`, intended to be loaded by Python
+# wrappers. See explanation above starting with "The targets below...".
+tf_custom_op_library(
+ name = "_serve_slices_op.so",
+ srcs = SERVE_SLICES_OP_SRCS + [
+ # Bundling the registry and op ensures that the same HostObjectRegistry is used by both.
+ "//fcp/tensorflow/python:serve_slices_registry.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = SERVE_SLICES_OP_DEPS + [
+ "@pybind11",
+ "@pybind11_abseil//pybind11_abseil:absl_casters",
+ ],
+)
+
+# Generates the basic op wrapper for use in Python.
+# Don't use this directly: use `:serve_slices_py` to ensure that the
+# appropriate shared libraries are loaded.
+tf_gen_op_wrapper_py(
+ name = "gen_serve_slices_py",
+ out = "gen_serve_slices_py.py",
+ deps = [
+ ":serve_slices_op_lib",
+ ],
+)
+
+# Public: Python library for ServeSlices.
+tf_custom_op_py_library(
+ name = "serve_slices_py",
+ srcs = ["serve_slices.py"],
+ dso = [":_serve_slices_op.so"],
+ kernels = [
+ ":serve_slices_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_serve_slices_py"],
+)
+
+# Generate a GraphDef for testing `ServeSlices` using Python.
+py_binary(
+ name = "make_serve_slices_test_graph",
+ testonly = True,
+ srcs = ["make_serve_slices_test_graph.py"],
+ python_version = "PY3",
+ deps = [":serve_slices_py"],
+)
+
+genrule(
+ name = "serve_slices_test_graph",
+ testonly = True,
+ srcs = [],
+ outs = ["serve_slices_test.pbtxt"],
+ cmd = "$(location :make_serve_slices_test_graph) --output \"$@\"",
+ tools = [":make_serve_slices_test_graph"],
+)
+
+tf_cc_test(
+ name = "serve_slices_op_test",
+ srcs = ["serve_slices_op_test.cc"],
+ data = [
+ "serve_slices_test.pbtxt",
+ ],
+ extra_copts = FCP_COPTS,
+ deps = [
+ ":serve_slices_op_lib",
+ ":serve_slices_registry",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_protobuf//:protobuf",
+ "@org_tensorflow//tensorflow/core:core_cpu",
+ "@org_tensorflow//tensorflow/core:direct_session",
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core:tensorflow_opensource",
+ "@org_tensorflow//tensorflow/core:testlib",
+ "@org_tensorflow//tensorflow/core/platform:status_matchers",
+ ],
+)
+
+MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_SRCS = ["make_slices_selector_example_selector_op.cc"]
+
+MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_DEPS = [
+ "@com_google_absl//absl/strings:str_format",
+ "//fcp/protos:plan_cc_proto",
+ "//fcp/client:federated_select",
+]
+
+# Custom op to serialize an ExampleSelector containing a SlicesSelector proto.
+cc_library(
+ name = "make_slices_selector_example_selector_op_lib",
+ srcs = MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = ["@com_google_protobuf//:protobuf"] + MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_make_slices_selector_example_selector_op.so",
+ srcs = MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = ["@com_google_protobuf//:protobuf"] + MAKE_SLICES_SELECTOR_EXAMPLE_SELECTOR_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python.
+# Don't use this directly: use `:make_slices_selector_py` to ensure that the
+# appropriate shared libraries are loaded.
+tf_gen_op_wrapper_py(
+ name = "gen_make_slices_selector_example_selector_py",
+ out = "gen_make_slices_selector_example_selector_py.py",
+ deps = [
+ ":make_slices_selector_example_selector_op_lib",
+ ],
+)
+
+# Public: Python library for the `MakeSlicesSelectorExampleSelector` op.
+tf_custom_op_py_library(
+ name = "make_slices_selector_example_selector_py",
+ srcs = ["make_slices_selector_example_selector.py"],
+ dso = [":_make_slices_selector_example_selector_op.so"],
+ kernels = [
+ ":make_slices_selector_example_selector_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_make_slices_selector_example_selector_py"],
+)
+
+# Test `MakeSlicesSelectorExampleSelector` using Python.
+py_test(
+ name = "make_slices_selector_example_selector_test",
+ testonly = True,
+ srcs = ["make_slices_selector_example_selector_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":make_slices_selector_example_selector_py",
+ "//fcp/protos:plan_py_pb2",
+ ],
+)
+
+APPEND_SLICES_OP_SRCS = ["append_slices_op.cc"]
+
+APPEND_SLICES_OP_DEPS = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/synchronization",
+ "@org_tensorflow//tensorflow/core/util:saved_tensor_slice_proto_cc",
+]
+
+# Custom op to serialize an ExampleSelector containing a SlicesSelector proto.
+cc_library(
+ name = "append_slices_op_lib",
+ srcs = APPEND_SLICES_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = APPEND_SLICES_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ "@org_tensorflow//tensorflow/core/kernels:save_restore_tensor",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_append_slices_op.so",
+ srcs = APPEND_SLICES_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = APPEND_SLICES_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python.
+# Don't use this directly: use `:append_slices_py` to ensure that the
+# appropriate shared libraries are loaded.
+tf_gen_op_wrapper_py(
+ name = "gen_append_slices_py",
+ out = "gen_append_slices_py.py",
+ deps = [
+ ":append_slices_op_lib",
+ ],
+)
+
+# Public: Python library for the `AppendSlices` and `MergeAppendedSlices` ops.
+tf_custom_op_py_library(
+ name = "append_slices_py",
+ srcs = ["append_slices.py"],
+ dso = [":_append_slices_op.so"],
+ kernels = [
+ ":append_slices_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_append_slices_py"],
+)
+
+# Test `AppendSlices` and `MergeAppendedSlices` using Python.
+py_test(
+ name = "append_slices_test",
+ testonly = True,
+ srcs = ["append_slices_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":append_slices_py",
+ ":delete_file_py",
+ ],
+)
+
+DELETE_FILE_OP_SRCS = ["delete_file_op.cc"]
+
+DELETE_FILE_OP_DEPS = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/synchronization",
+]
+
+# Custom op to serialize an ExampleSelector containing a SlicesSelector proto.
+cc_library(
+ name = "delete_file_op_lib",
+ srcs = DELETE_FILE_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = DELETE_FILE_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_delete_file_op.so",
+ srcs = DELETE_FILE_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = DELETE_FILE_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python.
+# Don't use this directly: use `:delete_file_py` to ensure that the
+# appropriate shared libraries are loaded.
+tf_gen_op_wrapper_py(
+ name = "gen_delete_file_py",
+ out = "gen_delete_file_py.py",
+ deps = [
+ ":delete_file_op_lib",
+ ],
+)
+
+# Public: Python library for the `DeleteFile` ops.
+tf_custom_op_py_library(
+ name = "delete_file_py",
+ srcs = ["delete_file.py"],
+ dso = [":_delete_file_op.so"],
+ kernels = [
+ ":delete_file_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_delete_file_py"],
+)
+
+# Test `DeleteFile` using Python.
+py_test(
+ name = "delete_file_test",
+ testonly = True,
+ srcs = ["delete_file_test.py"],
+ python_version = "PY3",
+ deps = [":delete_file_py"],
+)
+
+TENSOR_NAME_OP_SRCS = ["tensor_name_op.cc"]
+
+TENSOR_NAME_OP_DEPS = [
+ "@com_google_absl//absl/strings:str_format",
+]
+
+# Custom op to get the name of a tensor in the final graph at runtime.
+cc_library(
+ name = "tensor_name_op_lib",
+ srcs = TENSOR_NAME_OP_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = TENSOR_NAME_OP_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+# DSO version of `:tensor_name_op_lib`, intended to be loaded by Python
+# wrappers. See explanation above starting with "The targets below...".
+tf_custom_op_library(
+ name = "_tensor_name_op.so",
+ srcs = TENSOR_NAME_OP_SRCS,
+ copts = FCP_COPTS,
+ deps = TENSOR_NAME_OP_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python.
+# Don't use this directly: use `:tensor_name_py` to ensure that the
+# appropriate shared libraries are loaded.
+tf_gen_op_wrapper_py(
+ name = "gen_tensor_name_py",
+ out = "gen_tensor_name_py.py",
+ deps = [
+ ":tensor_name_op_lib",
+ ],
+)
+
+# Public: Python library for the `TensorName` op.
+tf_custom_op_py_library(
+ name = "tensor_name_py",
+ srcs = ["tensor_name.py"],
+ dso = [":_tensor_name_op.so"],
+ kernels = [
+ ":tensor_name_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_tensor_name_py"],
+)
+
+# Test `TensorName` using Python.
+py_test(
+ name = "tensor_name_test",
+ testonly = True,
+ srcs = ["tensor_name_test.py"],
+ python_version = "PY3",
+ deps = [":tensor_name_py"],
+)
+
+TASK_ELIGIBILITY_INFO_OPS_SRCS = ["task_eligibility_info_ops.cc"]
+
+TASK_ELIGIBILITY_INFO_OPS_DEPS = [
+ "//fcp/protos:federated_api_cc_proto",
+]
+
+cc_library(
+ name = "task_eligibility_info_ops_lib",
+ srcs = TASK_ELIGIBILITY_INFO_OPS_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = TASK_ELIGIBILITY_INFO_OPS_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_task_eligibility_info_ops.so",
+ srcs = TASK_ELIGIBILITY_INFO_OPS_SRCS,
+ copts = FCP_COPTS,
+ deps = TASK_ELIGIBILITY_INFO_OPS_DEPS,
+)
+
+# Generates the basic op wrapper for use in Python. We don't expose this wrapper
+# directly, and rather we create a more user-friendly wrapper below, which uses
+# this auto-generated one.
+tf_gen_op_wrapper_py(
+ name = "gen_task_eligibility_info_ops_py",
+ out = "gen_task_eligibility_info_ops.py",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":task_eligibility_info_ops_lib",
+ ],
+)
+
+# Python library exposing the user-facing task eligibility info ops.
+tf_custom_op_py_library(
+ name = "task_eligibility_info_ops_py",
+ srcs = ["task_eligibility_info_ops.py"],
+ dso = [":_task_eligibility_info_ops.so"],
+ kernels = [
+ ":task_eligibility_info_ops_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [":gen_task_eligibility_info_ops_py"],
+)
+
+py_test(
+ name = "task_eligibility_info_ops_test",
+ srcs = ["task_eligibility_info_ops_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":task_eligibility_info_ops_py",
+ "//fcp/protos:federated_api_py_pb2",
+ ],
+)
+
+DICTIONARY_OPS_SRCS = ["dictionary_ops.cc"]
+
+DICTIONARY_OPS_DEPS = [
+ "//fcp/base",
+ "//fcp/dictionary:dictionary_lib",
+ "//fcp/dictionary:dictionary_cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+]
+
+cc_library(
+ name = "dictionary_ops_lib",
+ srcs = DICTIONARY_OPS_SRCS,
+ copts = FCP_COPTS,
+ visibility = ["//visibility:public"],
+ deps = DICTIONARY_OPS_DEPS + [
+ "@org_tensorflow//tensorflow/core:framework",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ # Uses TensorFlow's registration macros
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "_dictionary_ops.so",
+ srcs = DICTIONARY_OPS_SRCS,
+ copts = FCP_COPTS,
+ deps = DICTIONARY_OPS_DEPS,
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dictionary_ops_py",
+ out = "gen_dictionary_ops.py",
+ op_allowlist = [
+ "DictionarySize",
+ "DictionaryLookup",
+ "DictionaryReverseLookup",
+ ],
+ visibility = ["//visibility:private"],
+ deps = [":dictionary_ops_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "dictionary_ops_py",
+ srcs = ["dictionary_ops.py"],
+ dso = [":_dictionary_ops.so"],
+ kernels = [
+ ":dictionary_ops_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":gen_dictionary_ops_py",
+ "//fcp/dictionary:dictionary_py_pb2",
+ ],
+)
+
+py_test(
+ name = "dictionary_ops_test",
+ srcs = ["dictionary_ops_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":dictionary_ops_py",
+ "//fcp/dictionary:dictionary_py_pb2",
+ ],
+)
diff --git a/fcp/tensorflow/append_slices.py b/fcp/tensorflow/append_slices.py
new file mode 100644
index 0000000..d5324f4
--- /dev/null
+++ b/fcp/tensorflow/append_slices.py
@@ -0,0 +1,74 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `append_slices` and `merge_appended_slices operations.
+
+This wraps the generated ops and ensures that necessary shared libraries
+are loaded.
+"""
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_append_slices_py
+
+_append_slices_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_append_slices_op.so'))
+
+
+def append_slices(filename, tensor_names, shapes_and_slices, data, name=None):
+ """Append slices to `filename`.
+
+ Must be paired with `merge_appended_slices`.
+
+ This op is identical to `tf.raw_ops.SaveSlices`, except that it appends the
+ resulting checkpoint to `filename` rather than erasing the contents of
+ `filename`.
+
+ Note: the resulting file at `filename` will not be in checkpoint format until
+ `merge_appended_slices` has been called.
+
+ Args:
+ filename: A `Tensor` fo type `string`. Must have a single element. The name
+ of the file to which the tensor should be appended.
+ tensor_names: A `Tensor` of type `string`. Shape `[N]`. The names of the
+ tensors to be saved.
+ shapes_and_slices: A `Tensor` of type `string`. Shape `[N]`. The shapes and
+ slice specifications to use when saving the tensors.
+ data: A list of `Tensor` objects. `N` tensors to save.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_append_slices_py.append_slices(
+ filename, tensor_names, shapes_and_slices, data, name=name)
+
+
+def merge_appended_slices(filename, name=None):
+ """Merges the appended file created by `append_slices` to a single checkpoint.
+
+ The immediate file output of `append_slices` is not in checkpoint format. It
+ must be converted to a checkpoint using this function `merge_appended_slices`.
+
+ Note: Users must call `control_dependencies` or other mechanisms to ensure
+ that the `append_slices` calls have executed prior to the execution of
+ `merge_appended_slices`.
+
+ Args:
+ filename: The name of a file appended to by calls to `append_slices`.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_append_slices_py.merge_appended_slices(filename, name)
diff --git a/fcp/tensorflow/append_slices_op.cc b/fcp/tensorflow/append_slices_op.cc
new file mode 100644
index 0000000..eb48e48
--- /dev/null
+++ b/fcp/tensorflow/append_slices_op.cc
@@ -0,0 +1,591 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <queue>
+#include <string>
+#include <utility>
+
+#include "absl/base/attributes.h"
+#include "absl/base/const_init.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/framework/bounds_check.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace fcp {
+namespace {
+
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernelContext;
+
+constexpr absl::string_view kSavedTensorSlicesKey = "";
+
+// Returns the host-endian byte representation of `value`.
+//
+// The `value` must be non-null and must continue to be valid as long as the
+// return value is used.
+absl::string_view Int64ToHostEndianBytes(int64_t* value) {
+ return absl::string_view(reinterpret_cast<const char*>(value),
+ sizeof(int64_t));
+}
+
+// Returns `value` intepreted as the host-endian bytes of an `int64_t`.
+int64_t Int64FromHostEndianBytes(const char value[sizeof(int64_t)]) {
+ return *reinterpret_cast<const int64_t*>(value);
+}
+
+// Implementation of the save ops.
+//
+// This is copied without change from save_restore_tensor.cc because that target
+// cannot be included in `tf_custom_op_library` targets due to its dependency
+// on `//third_party/tensorflow/core:framework`.
+void SaveTensors(
+ OpKernelContext* context,
+ tensorflow::checkpoint::TensorSliceWriter::CreateBuilderFunction
+ builder_func,
+ bool save_slices) {
+ const tensorflow::Tensor& filename_t = context->input(0);
+ {
+ const int64_t size = filename_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ tensorflow::errors::InvalidArgument(
+ "Input 0 (filename) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ const std::string& filename = filename_t.scalar<tensorflow::tstring>()();
+
+ // Path, names, and slices if save_slices is true.
+ const int kFixedInputs = save_slices ? 3 : 2;
+ const tensorflow::Tensor& tensor_names_t = context->input(1);
+ OP_REQUIRES(
+ context,
+ tensorflow::FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
+ std::numeric_limits<int>::max()),
+ tensorflow::errors::InvalidArgument("Too many inputs to SaveTensors"));
+ const int N = static_cast<int>(tensor_names_t.NumElements());
+ const tensorflow::tstring* tensor_shapes_and_slices_ptr = nullptr;
+ if (save_slices) {
+ const tensorflow::Tensor& tensor_shapes_and_slices_t = context->input(2);
+ OP_REQUIRES(
+ context,
+ tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N),
+ tensorflow::errors::InvalidArgument(
+ "Expected ", N,
+ " elements for the tensor "
+ "shapes and slices but got ",
+ tensor_shapes_and_slices_t.NumElements()));
+ tensor_shapes_and_slices_ptr =
+ tensor_shapes_and_slices_t.flat<tensorflow::tstring>().data();
+ }
+ OP_REQUIRES(
+ context, context->num_inputs() == N + kFixedInputs,
+ tensorflow::errors::InvalidArgument(
+ "Expected totally ", N + kFixedInputs,
+ " inputs as input #1 (which is a string "
+ "tensor of saved names) contains ",
+ N, " names, but received ", context->num_inputs(), " inputs"));
+
+ VLOG(1) << "About to save tensors to file " << filename << "...";
+ tensorflow::checkpoint::TensorSliceWriter writer(filename,
+ std::move(builder_func));
+
+ tensorflow::Status s;
+ auto tensor_names_flat = tensor_names_t.flat<tensorflow::tstring>();
+
+ // Process tensors in sorted name order. This allows us to avoid seeking
+ // during restoration in the common case where we are restoring a full
+ // checkpoint.
+ // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't
+ // strictly necessary anymore. However, restores with TF version <= 2.7 will
+ // still benefit.
+ std::vector<int> sorted_name_idx(tensor_names_flat.size());
+ std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
+ std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
+ [&tensor_names_flat](size_t a, size_t b) {
+ return tensor_names_flat(a) < tensor_names_flat(b);
+ });
+
+ for (const int i : sorted_name_idx) {
+ const std::string& name = tensor_names_flat(i);
+ const tensorflow::Tensor& input = context->input(i + kFixedInputs);
+ tensorflow::TensorShape shape(input.shape());
+ tensorflow::TensorSlice slice(input.dims());
+ if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
+ const tensorflow::tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
+ tensorflow::TensorShape slice_shape;
+ OP_REQUIRES_OK(context, tensorflow::checkpoint::ParseShapeAndSlice(
+ shape_spec, &shape, &slice, &slice_shape));
+ OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
+ tensorflow::errors::InvalidArgument(
+ "Slice in shape_and_slice "
+ "specification does not match the "
+ "shape of the tensor to save: ",
+ shape_spec, ", tensor: ", input.shape().DebugString()));
+ }
+
+#define WRITER_ADD(T) \
+ case tensorflow::DataTypeToEnum<T>::value: \
+ s = writer.Add(name, shape, slice, input.flat<T>().data()); \
+ break;
+
+ switch (input.dtype()) {
+ TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
+ default:
+ context->SetStatus(tensorflow::errors::Unimplemented(
+ "Saving data type ", DataTypeString(input.dtype()),
+ " not yet supported"));
+ return;
+ }
+#undef WRITER_ADD
+ if (!s.ok()) {
+ context->SetStatus(s);
+ return;
+ }
+ }
+
+ s = writer.Finish();
+ if (!s.ok()) {
+ context->SetStatus(s);
+ }
+}
+
+// A `WritableFile` that wraps an existing file, appending a chunk with a length
+// footer to the end of it.
+//
+// File start position is stored as a footer since `WritableFile` does not allow
+// `Seek`ing to modify an earlier position in the file.
+class AppendedFileWithStartPosFooter : public tensorflow::WritableFile {
+ public:
+ static tensorflow::Status FromFile(
+ std::unique_ptr<tensorflow::WritableFile> file,
+ std::unique_ptr<tensorflow::WritableFile>& wrapped_file_out) {
+ int64_t body_start;
+ TF_RETURN_IF_ERROR(file->Tell(&body_start));
+ VLOG(1) << "Appending to checkpoint with starting position " << body_start;
+ // Note: cannot use `make_unique` due to private constructor.
+ wrapped_file_out = std::unique_ptr<tensorflow::WritableFile>(
+ new AppendedFileWithStartPosFooter(std::move(file), body_start));
+ return tensorflow::OkStatus();
+ }
+ tensorflow::Status Append(tensorflow::StringPiece data) override {
+ return file_->Append(data);
+ }
+ tensorflow::Status Close() override {
+ TF_RETURN_IF_ERROR(file_->Append(Int64ToHostEndianBytes(&body_start_)));
+ return file_->Close();
+ }
+ tensorflow::Status Flush() override { return file_->Flush(); }
+ tensorflow::Status Sync() override { return file_->Sync(); }
+ tensorflow::Status Tell(int64_t* position) override {
+ int64_t internal_position;
+ TF_RETURN_IF_ERROR(file_->Tell(&internal_position));
+ *position = internal_position - body_start_;
+ return tensorflow::OkStatus();
+ }
+
+ private:
+ AppendedFileWithStartPosFooter(std::unique_ptr<tensorflow::WritableFile> file,
+ int64_t body_start)
+ : file_(std::move(file)), body_start_(body_start) {}
+
+ std::unique_ptr<tensorflow::WritableFile> file_;
+ int64_t body_start_;
+};
+
+// An implementation of the `TensorSliceWriter::Builder` interface which
+// delegates to `tensorflow::table::TableBuilder`.
+class TableBuilder : public tensorflow::checkpoint::TensorSliceWriter::Builder {
+ public:
+ TableBuilder(std::string name, std::unique_ptr<tensorflow::WritableFile> file)
+ : name_(std::move(name)), file_(std::move(file)) {
+ tensorflow::table::Options option;
+ option.compression = tensorflow::table::kNoCompression;
+ builder_ =
+ std::make_unique<tensorflow::table::TableBuilder>(option, file_.get());
+ }
+ void Add(tensorflow::StringPiece key, tensorflow::StringPiece val) override {
+ builder_->Add(key, val);
+ }
+ tensorflow::Status Finish(int64_t* file_size) override {
+ *file_size = -1;
+ tensorflow::Status s = builder_->Finish();
+ if (s.ok()) {
+ s = file_->Close();
+ if (s.ok()) {
+ *file_size = builder_->FileSize();
+ }
+ }
+ if (!s.ok()) {
+ s = tensorflow::errors::Internal(
+#if TF_GRAPH_DEF_VERSION < 1467
+ "Error writing (tmp) checkpoint file: ", name_, ": ",
+ s.error_message());
+#else
+ "Error writing (tmp) checkpoint file: ", name_, ": ", s.message());
+#endif
+ }
+ return s;
+ }
+
+ private:
+ std::string name_;
+ std::unique_ptr<tensorflow::WritableFile> file_;
+ std::unique_ptr<tensorflow::table::TableBuilder> builder_;
+};
+
+// Creates a new `TensorSliceWriter::Builder` which will append the tensor
+// slices to `filename` along with a footer indicating the start position of
+// this particular chunk of slices.
+//
+// If this method returns `OK`, `builder` will contain a new owned pointer to
+// a `TensorSliceWriter::Builder`.
+tensorflow::Status CreateAppendingTensorSliceBuilder(
+ const std::string& filename,
+ tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
+ *builder = nullptr;
+ if (VLOG_IS_ON(1)) {
+ uint64_t file_size = 0;
+ if (tensorflow::Env::Default()->GetFileSize(filename, &file_size).ok()) {
+ VLOG(1) << "Appending checkpoint to file " << filename << " with size "
+ << file_size;
+ } else {
+ VLOG(1) << "Appending checkpoint to new file " << filename;
+ }
+ }
+ std::unique_ptr<tensorflow::WritableFile> file;
+ TF_RETURN_IF_ERROR(
+ tensorflow::Env::Default()->NewAppendableFile(filename, &file));
+ std::unique_ptr<tensorflow::WritableFile> wrapped_file;
+ TF_RETURN_IF_ERROR(
+ AppendedFileWithStartPosFooter::FromFile(std::move(file), wrapped_file));
+ *builder = new TableBuilder(filename, std::move(wrapped_file));
+ return tensorflow::OkStatus();
+}
+
+// A `RandomAccessFile` which wraps another `RandomAccessFile`, providing access
+// to only a portion of the file.
+class PartialRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+ // Constructs a `PartialRandomAccessFile` pointing to a segment of `file`.
+ //
+ // `file` must be non-null and must continue to be valid as long as the
+ // return value is used.
+ PartialRandomAccessFile(tensorflow::RandomAccessFile* file, int64_t start,
+ int64_t end)
+ : file_(file), start_(start), end_(end) {}
+ ~PartialRandomAccessFile() override = default;
+ tensorflow::Status Read(uint64_t offset, size_t n,
+ tensorflow::StringPiece* result,
+ char* scratch) const override {
+ const size_t max_allowable_n = end_ - (start_ + offset);
+ bool read_too_long = n > max_allowable_n;
+ if (read_too_long) {
+ n = max_allowable_n;
+ }
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ file_->Read(offset + start_, n, result, scratch),
+ absl::StrCat("Reading from PartialRandomAccessFile at offset ", offset,
+ " from start position ", start_));
+ if (read_too_long) {
+ return tensorflow::Status(
+ static_cast<tensorflow::errors::Code>(absl::StatusCode::kOutOfRange),
+ "Attempted to read past end of file chunk.");
+ }
+ return tensorflow::OkStatus();
+ }
+
+ private:
+ tensorflow::RandomAccessFile* file_;
+ int64_t start_;
+ int64_t end_;
+};
+
+struct TableIteratorComparator {
+ // Returns whether `i1` should come after `i2` in the priority queue.
+ // That is, whether `i1` has *lower* priority than `i2`.
+ bool operator()(const std::unique_ptr<tensorflow::table::Iterator>& i1,
+ const std::unique_ptr<tensorflow::table::Iterator>& i2) {
+ // Ensure that iterators which have no remaining elements go last in the
+ // list.
+ if (!i2->Valid()) {
+ return false;
+ }
+ if (!i1->Valid()) {
+ return true;
+ }
+ if ((i2->key() == kSavedTensorSlicesKey) &&
+ (i1->key() != kSavedTensorSlicesKey)) {
+ return true;
+ }
+ return i1->key() > i2->key();
+ }
+};
+
+// Pops and returns the top element of a `std::priority_queue`.
+template <class Element, class Container, class Comparator>
+Element PopWithElement(
+ std::priority_queue<Element, Container, Comparator>& queue) {
+ Element e = std::move(const_cast<Element&>(queue.top()));
+ queue.pop();
+ return e;
+}
+
+// Parses a `serialized` into a `SavedTensorSlices` stored in `meta_out`.
+tensorflow::Status MetadataFromString(absl::string_view serialized,
+ tensorflow::SavedTensorSlices& meta_out) {
+ // NOTE: The conversion to `std::string` is unfortunately necessary here
+ // because the OSS version of `ParseFromString` takes a `const std::string&`
+ // rather than a `absl::string_view`.
+ if (!meta_out.ParseFromString(std::string(serialized))) {
+ return tensorflow::Status(
+ static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
+ absl::StrCat("Failed to parse table entry as `SavedTensorSlices`: ",
+ serialized));
+ }
+ return tensorflow::OkStatus();
+}
+
+// Merges appended checkpoints in `filename` into a single checkpoint.
+//
+// Note: this function accepts `filename` as a `const std::string&` rather than
+// `string_view` because that is the type accepted by the functions it calls
+// (`GetFileSize` and `NewRandomAccessFile`). This avoids unnecessary
+// allocation.
+tensorflow::Status LoadAndMergeAppendedSlices(const std::string& filename) {
+ tensorflow::Env* env = tensorflow::Env::Default();
+ uint64_t file_size;
+ TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
+ // Short-circuit on empty files so that we can assume at least a single entry
+ // below.
+ if (file_size == 0) {
+ return tensorflow::OkStatus();
+ }
+ std::unique_ptr<tensorflow::RandomAccessFile> file;
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
+
+ // Overwrite the underlying file, relying on `file` above to provide a handle
+ // into the old file contents even after it is overwritten.
+ TF_RETURN_IF_ERROR(tensorflow::Env::Default()->DeleteFile(filename));
+
+ // `chunk_files` and `chunk_tables` must be kept around since they are
+ // referenced internally by `chunk_iterators`.
+ std::vector<std::unique_ptr<tensorflow::RandomAccessFile>> chunk_files;
+ std::vector<std::unique_ptr<tensorflow::table::Table>> chunk_tables;
+ std::priority_queue<std::unique_ptr<tensorflow::table::Iterator>,
+ std::vector<std::unique_ptr<tensorflow::table::Iterator>>,
+ TableIteratorComparator>
+ chunk_iterators;
+
+ tensorflow::SavedTensorSlices merged_sts;
+ tensorflow::SavedTensorSliceMeta* merged_meta = merged_sts.mutable_meta();
+ std::set<std::string> slices_added;
+
+ // Read all of the chunks into tables.
+ int64_t chunk_footer_end = file_size;
+ bool version_was_set = false;
+ while (chunk_footer_end > 0) {
+ // Read in the footer telling us where the chunk started.
+ char footer_scratch[sizeof(int64_t)];
+ tensorflow::StringPiece chunk_footer;
+ TF_RETURN_IF_ERROR(file->Read(chunk_footer_end - sizeof(int64_t),
+ sizeof(int64_t), &chunk_footer,
+ footer_scratch));
+ int64_t chunk_start = Int64FromHostEndianBytes(chunk_footer.data());
+ int64_t chunk_end = chunk_footer_end - sizeof(int64_t);
+ int64_t chunk_len = chunk_end - chunk_start;
+ std::unique_ptr<tensorflow::RandomAccessFile> chunk_file =
+ std::make_unique<PartialRandomAccessFile>(file.get(), chunk_start,
+ chunk_end);
+ tensorflow::table::Options options;
+ tensorflow::table::Table* raw_table;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ tensorflow::table::Table::Open(options, chunk_file.get(), chunk_len,
+ &raw_table),
+ absl::StrCat("Error opening sub-table of file ", filename,
+ " starting at ", chunk_start, " and ending at ", chunk_end,
+ ". Total file size: ", file_size));
+ std::unique_ptr<tensorflow::table::Table> table(raw_table);
+ tensorflow::table::Iterator* raw_iterator = table->NewIterator();
+ std::unique_ptr<tensorflow::table::Iterator> iterator(raw_iterator);
+ iterator->SeekToFirst();
+ if (!iterator->Valid()) {
+ return tensorflow::Status(
+ static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
+ "Unexpected immediately-invalid iterator. "
+ "Expected table to iterator to have at least a "
+ "single entry (metadata)");
+ }
+ if (iterator->key() != kSavedTensorSlicesKey) {
+ return tensorflow::Status(
+ static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
+ absl::StrCat("Expected table iterator to have an initial metadata "
+ "entry with key `",
+ kSavedTensorSlicesKey, "`, found key `", iterator->key(),
+ "`"));
+ }
+ tensorflow::SavedTensorSlices sts;
+ TF_RETURN_IF_ERROR(MetadataFromString(iterator->value(), sts));
+ iterator->Next();
+ if (!version_was_set) {
+ version_was_set = true;
+ *merged_meta->mutable_versions() = sts.meta().versions();
+ }
+ for (const tensorflow::SavedSliceMeta& slice_meta : sts.meta().tensor()) {
+ if (slices_added.find(slice_meta.name()) != slices_added.end()) {
+ return tensorflow::Status(
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ static_cast<tensorflow::errors::Code>(
+ absl::StatusCode::kInvalidArgument),
+ absl::StrCat(
+ "Attempted to merge two checkpoint entries for slice name: `",
+ slice_meta.name(), "`. Only one entry per name is permitted."));
+ }
+ slices_added.insert(slice_meta.name());
+ }
+ merged_meta->mutable_tensor()->MergeFrom(sts.meta().tensor());
+ chunk_iterators.push(std::move(iterator));
+ chunk_files.push_back(std::move(chunk_file));
+ chunk_tables.push_back(std::move(table));
+ chunk_footer_end = chunk_start;
+ }
+ VLOG(1) << "Merging " << chunk_files.size() << " checkpoint chunks from file "
+ << filename;
+
+ tensorflow::checkpoint::TensorSliceWriter::Builder* raw_builder;
+ TF_RETURN_IF_ERROR(tensorflow::checkpoint::CreateTableTensorSliceBuilder(
+ filename, &raw_builder));
+ std::unique_ptr<tensorflow::checkpoint::TensorSliceWriter::Builder> builder(
+ raw_builder);
+
+ // First, we add the merged entry which holds a `SavedTensorSlices` proto.
+ builder->Add(kSavedTensorSlicesKey, merged_sts.SerializeAsString());
+
+ // Then the remaining entries are concatenated alphabetically.
+ while (chunk_iterators.top()->Valid()) {
+ std::unique_ptr<tensorflow::table::Iterator> iter =
+ PopWithElement(chunk_iterators);
+ VLOG(2) << "Merging table entry for key " << iter->key();
+ builder->Add(iter->key(), iter->value());
+ iter->Next();
+ chunk_iterators.push(std::move(iter));
+ }
+ int64_t resulting_file_size;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(builder->Finish(&resulting_file_size),
+ "Finishing TensorSliceWriter::Builder");
+ return tensorflow::OkStatus();
+}
+
+ABSL_CONST_INIT absl::Mutex append_mutex(absl::kConstInit);
+
+} // namespace
+
+class AppendSlicesOp : public OpKernel {
+ public:
+ explicit AppendSlicesOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ absl::MutexLock lock(&append_mutex);
+ const tensorflow::Tensor& filename_t = context->input(0);
+ tensorflow::tstring filename = filename_t.flat<tensorflow::tstring>()(0);
+ SaveTensors(
+ context,
+ [context](
+ const std::string& target_filename,
+ tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
+ // `TensorSliceWriter` targets writing to a new temporary file which
+ // it then moves into the location of the final file once complete.
+ // In order to comply with this behavior while still retaining
+ // "append" semantics, the original file (if it exists) is first moved
+ // into the temporary target location.
+ tensorflow::tstring original_filename =
+ context->input(0).scalar<tensorflow::tstring>()();
+ tensorflow::Status status = tensorflow::Env::Default()->RenameFile(
+ original_filename, target_filename);
+ if (status.ok()) {
+ VLOG(1) << "Appending to existing file " << original_filename
+ << " via move to temporary location " << target_filename;
+ } else if (status.code() == tensorflow::error::NOT_FOUND) {
+ VLOG(1) << "Appending to new file " << original_filename
+ << " in temporary location " << target_filename;
+ } else {
+ return status;
+ }
+ return CreateAppendingTensorSliceBuilder(target_filename, builder);
+ },
+ /*save_slices=*/true);
+ }
+};
+
+class MergeAppendedSlicesOp : public OpKernel {
+ public:
+ explicit MergeAppendedSlicesOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ absl::MutexLock lock(&append_mutex);
+ const tensorflow::Tensor* filename_tensor;
+ OP_REQUIRES_OK(context, context->input("filename", &filename_tensor));
+ const tensorflow::tstring filename =
+ filename_tensor->scalar<tensorflow::tstring>()();
+ OP_REQUIRES_OK(context, LoadAndMergeAppendedSlices(filename));
+ }
+};
+
+// Note: `key` *must* come last so that the indices of the other arguments are
+// as expected by `SaveTensors`.
+REGISTER_OP("AppendSlices")
+ .Input("filename: string")
+ .Input("tensor_names: string")
+ .Input("shapes_and_slices: string")
+ .Input("data: T")
+ .Attr("T: list(type)")
+ .SetIsStateful();
+
+REGISTER_KERNEL_BUILDER(Name("AppendSlices").Device(tensorflow::DEVICE_CPU),
+ AppendSlicesOp);
+
+REGISTER_OP("MergeAppendedSlices").Input("filename: string").SetIsStateful();
+
+REGISTER_KERNEL_BUILDER(
+ Name("MergeAppendedSlices").Device(tensorflow::DEVICE_CPU),
+ MergeAppendedSlicesOp);
+
+} // namespace fcp
diff --git a/fcp/tensorflow/append_slices_test.py b/fcp/tensorflow/append_slices_test.py
new file mode 100644
index 0000000..4b39e66
--- /dev/null
+++ b/fcp/tensorflow/append_slices_test.py
@@ -0,0 +1,183 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the `append_slices` and `merge_appended_slices` custom ops."""
+
+import os
+import tensorflow as tf
+
+from fcp.tensorflow import append_slices
+from fcp.tensorflow import delete_file
+
+
+class AppendSlicesTest(tf.test.TestCase):
+
+ def new_tempfile_path(self):
+ """Returns a path that can be used to store a new tempfile."""
+ return os.path.join(self.create_tempdir(), 'checkpoint.ckp')
+
+ def test_converts_single_element_once_appended_file_to_checkpoint(self):
+ checkpoint_path = self.new_tempfile_path()
+ tensor_name = 'a'
+ tensor = tf.constant(42, dtype=tf.int32)
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[tensor_name],
+ data=[tensor],
+ shapes_and_slices=[''])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=[tensor_name],
+ shape_and_slices=[''],
+ dtypes=[tf.int32])
+ self.assertEqual(restored[0], 42)
+
+ def test_converts_single_element_twice_appended_file_to_checkpoint(self):
+ checkpoint_path = self.new_tempfile_path()
+ tensor_names = ['a', 'b']
+ tensor_values = [tf.constant(x, dtype=tf.int32) for x in (7, 11)]
+ for (tensor_name, tensor_value) in zip(tensor_names, tensor_values):
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[tensor_name],
+ data=[tensor_value],
+ shapes_and_slices=[''])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=tensor_names,
+ shape_and_slices=[''] * 2,
+ dtypes=[tf.int32] * 2)
+ self.assertEqual(restored[0], 7)
+ self.assertEqual(restored[1], 11)
+
+ def test_converts_two_element_once_appended_file_to_checkpoint(self):
+ checkpoint_path = self.new_tempfile_path()
+ tensors = [('a', 16), ('b', 17)]
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[name for (name, value) in tensors],
+ data=[tf.constant(value, tf.int32) for (name, value) in tensors],
+ shapes_and_slices=['' for _ in tensors])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=['a', 'b'],
+ shape_and_slices=[''] * 2,
+ dtypes=[tf.int32] * 2)
+ self.assertEqual(restored[0], 16)
+ self.assertEqual(restored[1], 17)
+
+ def test_converts_two_element_multi_twice_appended_file_to_checkpoint(self):
+ # Note: the interleaved ordering ensures that the resulting merged
+ # checkpoint is able to mix together the two input checkpoints properly.
+ checkpoint_path = self.new_tempfile_path()
+ tensors = [
+ [('a', 12), ('c', 55)],
+ [('b', 40), ('d', 88)],
+ ]
+ for tensors_for_checkpoint in tensors:
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[name for (name, value) in tensors_for_checkpoint],
+ data=[
+ tf.constant(value, tf.int32)
+ for (name, value) in tensors_for_checkpoint
+ ],
+ shapes_and_slices=['' for _ in tensors_for_checkpoint])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=['a', 'b', 'c', 'd'],
+ shape_and_slices=[''] * 4,
+ dtypes=[tf.int32] * 4)
+ self.assertEqual(restored[0], 12)
+ self.assertEqual(restored[1], 40)
+ self.assertEqual(restored[2], 55)
+ self.assertEqual(restored[3], 88)
+
+ def test_converts_nonalphabetical_two_element_multi_twice_appended_file_to_checkpoint(
+ self):
+ # Note: the interleaved ordering ensures that the resulting merged
+ # checkpoint is able to mix together the two input checkpoints properly.
+ checkpoint_path = self.new_tempfile_path()
+ tensors = [
+ [('b', 12), ('a', 55)],
+ [('d', 40), ('c', 88)],
+ ]
+ for tensors_for_checkpoint in tensors:
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[name for (name, value) in tensors_for_checkpoint],
+ data=[
+ tf.constant(value, tf.int32)
+ for (name, value) in tensors_for_checkpoint
+ ],
+ shapes_and_slices=['' for _ in tensors_for_checkpoint])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=['d', 'c', 'b', 'a'],
+ shape_and_slices=[''] * 4,
+ dtypes=[tf.int32] * 4)
+ self.assertEqual(restored[0], 40)
+ self.assertEqual(restored[1], 88)
+ self.assertEqual(restored[2], 12)
+ self.assertEqual(restored[3], 55)
+
+ def test_merge_missing_checkpoint_file_raises(self):
+ checkpoint_path = self.new_tempfile_path()
+ with self.assertRaises(tf.errors.NotFoundError):
+ append_slices.merge_appended_slices(checkpoint_path)
+
+ def test_duplicate_named_tensor_raises(self):
+ checkpoint_path = self.new_tempfile_path()
+ tensor_values = [tf.constant(x, dtype=tf.int32) for x in (7, 11)]
+ for tensor_value in tensor_values:
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=['a'],
+ data=[tensor_value],
+ shapes_and_slices=[''])
+ with self.assertRaisesRegex(
+ tf.errors.InvalidArgumentError,
+ 'Attempted to merge two checkpoint entries for slice name: `a`'):
+ append_slices.merge_appended_slices(checkpoint_path)
+
+ def test_append_and_merge_using_same_filename(self):
+ checkpoint_path = self.new_tempfile_path()
+ for _ in range(2):
+ # Without calling this we might append to a previously used file.
+ delete_file.delete_file(checkpoint_path)
+
+ tensor_names = ['a', 'b']
+ tensor_values = [tf.constant(x, dtype=tf.int32) for x in (7, 11)]
+ for (tensor_name, tensor_value) in zip(tensor_names, tensor_values):
+ append_slices.append_slices(
+ filename=checkpoint_path,
+ tensor_names=[tensor_name],
+ data=[tensor_value],
+ shapes_and_slices=[''])
+ append_slices.merge_appended_slices(checkpoint_path)
+ restored = tf.raw_ops.RestoreV2(
+ prefix=checkpoint_path,
+ tensor_names=tensor_names,
+ shape_and_slices=[''] * 2,
+ dtypes=[tf.int32] * 2)
+ self.assertEqual(restored[0], 7)
+ self.assertEqual(restored[1], 11)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/crc32.py b/fcp/tensorflow/crc32.py
new file mode 100644
index 0000000..dcff3b5
--- /dev/null
+++ b/fcp/tensorflow/crc32.py
@@ -0,0 +1,40 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `crc32` operation.
+
+This wraps the generated op and ensures that necessary shared libraries
+are loaded.
+"""
+
+from typing import Optional
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_crc32_py
+
+_crc32_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_crc32_op.so'))
+
+
+def crc32(tensor: tf.Tensor, name: Optional[str] = None) -> tf.Operation:
+ """Computes the CRC32 checksum of a Tensor.
+
+ Args:
+ tensor: The input `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_crc32_py.crc32(tensor, name=name)
diff --git a/fcp/tensorflow/crc32_op.cc b/fcp/tensorflow/crc32_op.cc
new file mode 100644
index 0000000..77fa31d
--- /dev/null
+++ b/fcp/tensorflow/crc32_op.cc
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/tensorflow/tensor_crc32.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+namespace fcp {
+namespace tensorflow {
+namespace checksums {
+
+using ::tensorflow::DEVICE_CPU;
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernelContext;
+using ::tensorflow::StringPiece;
+using ::tensorflow::Tensor;
+using ::tensorflow::shape_inference::InferenceContext;
+
+REGISTER_OP("CRC32")
+ .Input("input: T")
+ .Attr("T: type")
+ .Output("checksum: uint32")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return ::tensorflow::OkStatus();
+ });
+
+class CRC32Op : public OpKernel {
+ public:
+ explicit CRC32Op(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Create an output tensor
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
+
+ // Store CRC32 of input tensor in output.
+ output_tensor->scalar<uint32_t>()() = TensorToCRC32(context->input(0));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("CRC32").Device(DEVICE_CPU), CRC32Op);
+} // namespace checksums
+} // namespace tensorflow
+} // namespace fcp
diff --git a/fcp/tensorflow/crc32_test.py b/fcp/tensorflow/crc32_test.py
new file mode 100644
index 0000000..a6c6fcf
--- /dev/null
+++ b/fcp/tensorflow/crc32_test.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import tensorflow as tf
+from fcp.tensorflow import crc32
+
+
+class CRC32Test(tf.test.TestCase):
+
+ def test_crc32(self):
+ assert 0 == crc32.crc32([])
+
+ # Constants taken from rfc3720 B.4
+ # 1. Tests on byte arrays
+ assert 0x8a9136aa == crc32.crc32(np.zeros(32, dtype=np.uint8))
+ assert 0x62a8ab43 == crc32.crc32(255 * np.ones(32, dtype=np.uint8))
+ x = np.arange(0, 0x20, dtype=np.uint8)
+ assert 0x46dd794e == crc32.crc32(x)
+ # 2. Tests for higher dimensional tensor shapes
+ assert 0x46dd794e == crc32.crc32(x.reshape(2, -1))
+ assert 0x46dd794e == crc32.crc32(x.reshape(4, 4, 2))
+ # Transpose will change memory order so checksum should change.
+ assert 0x46dd794e != crc32.crc32(x.reshape(2, -1).transpose())
+
+ # 3. Tests on int32, int64
+ assert crc32.crc32(tf.constant(0x123456789abcdef0,
+ dtype=tf.int64)) == crc32.crc32(
+ tf.constant([0x9abcdef0, 0x12345678],
+ dtype=tf.int32))
+
+ # 4. IEEE float test. Not much to test here other than that the checksum
+ # produces the same result as it would on an integer tensor with the same
+ # memory representation.
+ assert crc32.crc32(tf.constant([-0.0, 0.0],
+ dtype=tf.float32)) == crc32.crc32(
+ tf.constant([0x80000000, 0x00000000],
+ dtype=tf.int32))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/fcp/tensorflow/delete_file.py b/fcp/tensorflow/delete_file.py
new file mode 100644
index 0000000..802fba0
--- /dev/null
+++ b/fcp/tensorflow/delete_file.py
@@ -0,0 +1,50 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `delete_file` operation.
+
+This wraps the generated ops and ensures that necessary shared libraries
+are loaded.
+"""
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_delete_file_py
+
+_delete_file_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_delete_file_op.so'))
+
+
+def delete_file(filename: tf.Tensor) -> tf.Operation:
+ """Delete file if the filename exists.
+
+ Args:
+ filename: The filename to delete.
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_delete_file_py.delete_file(filename)
+
+
+def delete_dir(dirname: tf.Tensor, recursively: bool = False) -> tf.Operation:
+ """Delete directory if the dirname exists.
+
+ Args:
+ dirname: The directory to delete.
+ recursively: If true the op attempts to delete also the content.
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_delete_file_py.delete_dir(dirname, recursively)
diff --git a/fcp/tensorflow/delete_file_op.cc b/fcp/tensorflow/delete_file_op.cc
new file mode 100644
index 0000000..f8828b7
--- /dev/null
+++ b/fcp/tensorflow/delete_file_op.cc
@@ -0,0 +1,127 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+
+namespace fcp {
+namespace {
+
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernelContext;
+
+} // namespace
+
+class DeleteDirOp : public OpKernel {
+ public:
+ explicit DeleteDirOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const tensorflow::Tensor& dirname_t = context->input(0);
+ const tensorflow::Tensor& recursively_t = context->input(1);
+ {
+ const int64_t size = dirname_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ tensorflow::errors::InvalidArgument(
+ "Input 0 (dirname) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ {
+ const int64_t size = recursively_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ tensorflow::errors::InvalidArgument(
+ "Input 1 (recursively) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ const tensorflow::tstring& dirname =
+ dirname_t.scalar<tensorflow::tstring>()();
+ const bool recursively = recursively_t.scalar<bool>()();
+ if (context->env()->IsDirectory(dirname).ok()) {
+ if (recursively) {
+ int64_t undeleted_files = 0;
+ int64_t undeleted_dirs = 0;
+ tensorflow::Status delete_status = context->env()->DeleteRecursively(
+ dirname, &undeleted_files, &undeleted_dirs);
+ if (!delete_status.ok()) {
+ // The directory could be already deleted by another op. Let's not
+ // propagate this error.
+ LOG(WARNING) << "Failed to recursively delete the directory '"
+ << dirname << "' (remaining files: " << undeleted_files
+ << ", remaining dirs: " << undeleted_dirs << "). "
+ << delete_status;
+ }
+ } else {
+ tensorflow::Status delete_status = context->env()->DeleteDir(dirname);
+ if (!delete_status.ok()) {
+ // The directory could be already deleted by another op. Let's not
+ // propagate this error.
+ LOG(WARNING) << "Failed to delete the directory '" << dirname << "'. "
+ << delete_status;
+ }
+ }
+ }
+ }
+};
+
+REGISTER_OP("DeleteDir")
+ .Input("dirname: string")
+ .Input("recursively: bool")
+ .SetIsStateful();
+REGISTER_KERNEL_BUILDER(Name("DeleteDir").Device(tensorflow::DEVICE_CPU),
+ DeleteDirOp);
+
+class DeleteFileOp : public OpKernel {
+ public:
+ explicit DeleteFileOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const tensorflow::Tensor& filename_t = context->input(0);
+ {
+ const int64_t size = filename_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ tensorflow::errors::InvalidArgument(
+ "Input 0 (filename) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ const tensorflow::tstring& filename =
+ filename_t.scalar<tensorflow::tstring>()();
+ if (context->env()->FileExists(filename).ok()) {
+ tensorflow::Status delete_status = context->env()->DeleteFile(filename);
+ if (!delete_status.ok()) {
+ // The file could be already deleted by another op. Let's not propagate
+ // this error.
+ LOG(WARNING) << "Failed to delete the file '" << filename << "'. "
+ << delete_status;
+ }
+ }
+ }
+};
+
+REGISTER_OP("DeleteFile").Input("filename: string").SetIsStateful();
+REGISTER_KERNEL_BUILDER(
+ Name("DeleteFile").Device(tensorflow::DEVICE_CPU),
+ DeleteFileOp);
+
+} // namespace fcp
diff --git a/fcp/tensorflow/delete_file_test.py b/fcp/tensorflow/delete_file_test.py
new file mode 100644
index 0000000..1ba615f
--- /dev/null
+++ b/fcp/tensorflow/delete_file_test.py
@@ -0,0 +1,100 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the `delete_file` custom op."""
+
+import os
+
+import tensorflow as tf
+
+from fcp.tensorflow import delete_file
+
+
+class DeleteOpTest(tf.test.TestCase):
+
+ def setup_temp_dir(self) -> tuple[str, str]:
+ """Sets up a temporary directory suitable for testing.
+
+ The filesystem consist of directory with one file inside.
+
+ Returns:
+ Tuple of directory and checkpoint paths.
+ """
+ temp_dir = self.create_tempdir().full_path
+ temp_file = os.path.join(temp_dir, 'checkpoint.ckp')
+
+ expected_content = 'content'
+ tf.io.write_file(temp_file, expected_content)
+ read_content = tf.io.read_file(temp_file)
+ self.assertEqual(expected_content, read_content)
+
+ self.assertTrue(os.path.isdir(temp_dir))
+ self.assertTrue(os.path.exists(temp_file))
+ return temp_dir, temp_file
+
+ def test_delete_file_op(self):
+ _, temp_file = self.setup_temp_dir()
+
+ delete_file.delete_file(temp_file)
+ # Delete one more time to make sure no error when the file doesn't exist.
+ delete_file.delete_file(temp_file)
+ self.assertFalse(os.path.exists(temp_file))
+
+ def test_delete_file_op_exceptions(self):
+ with self.subTest(name='non_string_dtype'):
+ with self.assertRaises(TypeError):
+ delete_file.delete_file(1.0)
+ with self.subTest(name='non_scalar'):
+ with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
+ '.*must be a string scalar.*'):
+ _, checkpoint_path = self.setup_temp_dir()
+ delete_file.delete_file([checkpoint_path, checkpoint_path])
+
+ def test_delete_file_and_dir_succeeds(self):
+ temp_dir, temp_file = self.setup_temp_dir()
+ delete_file.delete_file(temp_file)
+ self.assertFalse(os.path.exists(temp_file))
+
+ delete_file.delete_dir(temp_dir)
+ # Delete dir more time to make sure no error when the dir doesn't exist.
+ delete_file.delete_dir(temp_dir)
+ self.assertFalse(os.path.isdir(temp_dir))
+
+ def test_delete_non_empty_dir_fails(self):
+ temp_dir, temp_file = self.setup_temp_dir()
+
+ delete_file.delete_dir(temp_dir)
+ self.assertTrue(os.path.isdir(temp_dir))
+ self.assertTrue(os.path.exists(temp_file))
+
+ def test_recursive_delete_non_empty_dir_succeeds(self):
+ temp_dir, temp_file = self.setup_temp_dir()
+
+ delete_file.delete_dir(temp_dir, recursively=True)
+ self.assertFalse(os.path.isdir(temp_dir))
+ self.assertFalse(os.path.exists(temp_file))
+
+ def test_delete_dir_op_exceptions(self):
+ with self.subTest(name='non_string_dtype'):
+ with self.assertRaises(TypeError):
+ delete_file.delete_dir(1.0)
+ with self.subTest(name='non_scalar'):
+ with self.assertRaisesRegex(
+ tf.errors.InvalidArgumentError, '.*must be a string scalar.*'
+ ):
+ temp_dir, _ = self.setup_temp_dir()
+ delete_file.delete_dir([temp_dir, temp_dir])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/dictionary_ops.cc b/fcp/tensorflow/dictionary_ops.cc
new file mode 100644
index 0000000..0868048
--- /dev/null
+++ b/fcp/tensorflow/dictionary_ops.cc
@@ -0,0 +1,252 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "fcp/base/monitoring.h"
+#include "fcp/dictionary/dictionary.h"
+#include "fcp/dictionary/dictionary.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tf = tensorflow;
+
+namespace fcp {
+namespace tensorflow {
+
+using fcp::dictionary::Dictionary;
+using fcp::dictionary::DictionaryDescription;
+
+namespace {
+
+// Base class for ops that work with a Dictionary.
+//
+// Subclasses need to provide Compute and register appropriately using
+// REGISTER_OP.
+class AbstractDictionaryOp : public tf::OpKernel {
+ public:
+ explicit AbstractDictionaryOp(tf::OpKernelConstruction* context,
+ int32_t num_expected_inputs)
+ : tf::OpKernel(context), num_expected_inputs_(num_expected_inputs) {
+ std::string dictionary_description_string;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("dictionary_description_proto",
+ &dictionary_description_string));
+
+ DictionaryDescription parsed_dictionary_description;
+ OP_REQUIRES(context,
+ parsed_dictionary_description.ParseFromString(
+ dictionary_description_string),
+ tf::errors::InvalidArgument(
+ "Cannot parse provided DictionaryDescription."));
+
+ if (parsed_dictionary_description.has_vocabulary()) {
+ // Fully specified dictionary.
+ absl::StatusOr<std::unique_ptr<Dictionary>> dictionary(
+ Dictionary::Create(parsed_dictionary_description));
+ OP_REQUIRES(context, dictionary.ok(),
+ tf::errors::InvalidArgument(dictionary.status().ToString()));
+ dictionary_ = *std::move(dictionary);
+ parsed_dictionary_description.clear_vocabulary(); // Save space.
+ }
+ dictionary_description_ = parsed_dictionary_description;
+ }
+
+ void Compute(tf::OpKernelContext* context) override {
+ FCP_CHECK(num_expected_inputs_ == context->num_inputs());
+
+ // Use the dictionary_ constructed at setup.
+ OP_REQUIRES(context, dictionary_ != nullptr,
+ tf::errors::InvalidArgument(
+ "DictionaryDescription does not contain a vocabulary. "));
+ absl::Status status = DoCompute(context, *dictionary_);
+ OP_REQUIRES(context, status.ok(),
+ tf::errors::InvalidArgument(std::string(status.message())));
+ }
+
+ protected:
+ // Computes using the given dictionary.
+ virtual absl::Status DoCompute(tf::OpKernelContext* context,
+ const Dictionary& dictionary) = 0;
+
+ private:
+ DictionaryDescription dictionary_description_;
+ std::unique_ptr<Dictionary> dictionary_;
+ const int32_t num_expected_inputs_;
+};
+
+} // namespace
+
+class DictionarySize : public AbstractDictionaryOp {
+ public:
+ explicit DictionarySize(tf::OpKernelConstruction* context)
+ : AbstractDictionaryOp(context, 0 /* num_expected_inputs */) {}
+
+ protected:
+ absl::Status DoCompute(tf::OpKernelContext* context,
+ const Dictionary& dictionary) override {
+ tf::Tensor* size_tensor;
+ auto status =
+ context->allocate_output(0, tf::TensorShape({}), &size_tensor);
+ if (!status.ok()) {
+#if TF_GRAPH_DEF_VERSION < 1467
+ return absl::InternalError(status.error_message());
+#else
+ return absl::InternalError(status.message());
+#endif
+ }
+ size_tensor->flat<int64_t>()(0) = dictionary.Size();
+ return absl::OkStatus();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DictionarySize").Device(tf::DEVICE_CPU),
+ DictionarySize);
+REGISTER_OP("DictionarySize")
+ .Output("size: int64")
+ .Attr("dictionary_description_proto: string = ''")
+ .SetShapeFn(::tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Returns the number of ids in the given dictionary.
+
+The dictionary should be fully specified at construction time via the
+dictionary_description_proto.
+
+dictionary_description_proto: A `DictionaryDescription` as a string.
+)doc");
+
+class DictionaryLookup : public AbstractDictionaryOp {
+ public:
+ explicit DictionaryLookup(tf::OpKernelConstruction* context)
+ : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
+
+ protected:
+ absl::Status DoCompute(tf::OpKernelContext* context,
+ const Dictionary& dictionary) override {
+ const tf::Tensor& token_tensor = context->input(0);
+ tf::Tensor* ids_tensor;
+ auto status =
+ context->allocate_output(0, token_tensor.shape(), &ids_tensor);
+ if (!status.ok()) {
+#if TF_GRAPH_DEF_VERSION < 1467
+ return absl::InternalError(status.error_message());
+#else
+ return absl::InternalError(status.message());
+#endif
+ }
+
+ if (token_tensor.dtype() != tf::DataType::DT_STRING) {
+ return absl::InvalidArgumentError("Expected input of 'tokens'.");
+ }
+ if (ids_tensor->dtype() != tf::DataType::DT_INT64) {
+ return absl::InvalidArgumentError("Expected output of 'ids'.");
+ }
+ if (token_tensor.shape() != ids_tensor->shape()) {
+ return absl::InvalidArgumentError("Wrong shape for ids_tensor");
+ }
+ const auto tokens_flat = token_tensor.flat<tf::tstring>();
+ auto ids_flat = ids_tensor->flat<int64_t>();
+ const int64_t num_tokens = tokens_flat.size();
+ for (int i = 0; i < num_tokens; ++i) {
+ ids_flat(i) = dictionary.TokenToId(tokens_flat(i));
+ }
+ return absl::OkStatus();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DictionaryLookup").Device(tf::DEVICE_CPU),
+ DictionaryLookup);
+REGISTER_OP("DictionaryLookup")
+ .Input("tokens: string")
+ .Output("token_ids: int64")
+ .Attr("dictionary_description_proto: string = ''")
+ .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Maps each string to an id by lookup in the dictionary.
+
+The dictionary should be fully specified at construction time via the
+dictionary_description_proto. Output has the same shape as input.
+
+tokens: A `Tensor` of strings to lookup in the dictionary.
+dictionary_description_proto: A `DictionaryDescription` as a string.
+)doc");
+
+class DictionaryReverseLookup : public AbstractDictionaryOp {
+ public:
+ explicit DictionaryReverseLookup(tf::OpKernelConstruction* context)
+ : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
+
+ protected:
+ absl::Status DoCompute(tf::OpKernelContext* context,
+ const Dictionary& dictionary) override {
+ const tf::Tensor& ids_tensor = context->input(0);
+ tf::Tensor* token_tensor;
+ auto status =
+ context->allocate_output(0, ids_tensor.shape(), &token_tensor);
+ if (!status.ok()) {
+#if TF_GRAPH_DEF_VERSION < 1467
+ return absl::InternalError(status.error_message());
+#else
+ return absl::InternalError(status.message());
+#endif
+ }
+
+ if (token_tensor->dtype() != tf::DataType::DT_STRING) {
+ return absl::InvalidArgumentError("Expected input of 'tokens'.");
+ }
+ if (ids_tensor.dtype() != tf::DataType::DT_INT64) {
+ return absl::InvalidArgumentError("Expected output of 'ids'.");
+ }
+ if (ids_tensor.shape() != token_tensor->shape()) {
+ return absl::InvalidArgumentError("Wrong shape for token_tensor");
+ }
+
+ const auto ids_flat = ids_tensor.flat<int64_t>();
+ auto tokens_flat = token_tensor->flat<tf::tstring>();
+ const int64_t num_tokens = ids_flat.size();
+ for (int i = 0; i < num_tokens; ++i) {
+ tokens_flat(i) = dictionary.IdToToken(ids_flat(i));
+ }
+ return absl::OkStatus();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DictionaryReverseLookup").Device(tf::DEVICE_CPU),
+ DictionaryReverseLookup);
+REGISTER_OP("DictionaryReverseLookup")
+ .Input("token_ids: int64")
+ .Output("tokens: string")
+ .Attr("dictionary_description_proto: string = ''")
+ .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Maps each id to its string by performing a reverse lookup in the dictionary.
+
+The dictionary should be fully specified at construction time via the
+dictionary_description_proto. Output has the same shape as input.
+
+token_ids: A `Tensor` of int64 ids to lookup in the dictionary.
+dictionary_description_proto: A `DictionaryDescription` as a string.
+)doc");
+} // namespace tensorflow
+} // namespace fcp
diff --git a/fcp/tensorflow/dictionary_ops.py b/fcp/tensorflow/dictionary_ops.py
new file mode 100644
index 0000000..b168087
--- /dev/null
+++ b/fcp/tensorflow/dictionary_ops.py
@@ -0,0 +1,372 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Python and TensorFlow functions to work with dictionaries.
+
+Please see fcp/dictionary/dictionary.h for more on this type of
+dictionary.
+
+Python Classes:
+
+* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h
+ that includes additional helpers for dictionary construction.
+
+TensorFlow ops:
+
+* dictionary_size
+ Queries the size of a dictionary.
+
+* dictionary_lookup
+ Looks up ids for string tokens in the dictionary.
+
+* dictionary_reverse_lookup
+ Looks up string tokens from ids in the dictionary.
+
+Canonical use (note that the dictionary is known at graph construction time):
+ dictionary = Dictionary.from_tokens(
+ tokens=['some', 'token', 'list'], unk_id=0,
+ vocabulary_type=VocabularyType.TOKEN_INDEX)
+
+ with tf.Graph().as_default():
+ tokens = tf.compat.v1.placeholder(tf.String, ...) # Tokens to look up.
+ ids = dictionary_lookup(
+ tokens, dictionary.dictionary_description_proto)
+"""
+
+import collections
+import enum
+
+import tensorflow as tf
+
+from fcp.dictionary.dictionary_pb2 import DictionaryDescription # pylint: disable=g-importing-member
+from fcp.tensorflow.gen_dictionary_ops import dictionary_lookup
+from fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup
+from fcp.tensorflow.gen_dictionary_ops import dictionary_size
+
+_dictionary_ops = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_dictionary_ops.so'))
+
+
+def ignore_ids_mask(token_ids, ignore_ids, name=None):
+ """Creates a bool mask with True everywhere token_ids is not in ignore_ids."""
+ with tf.op_scope([token_ids, ignore_ids], name, 'ignore_ids_mask'):
+ # Yay broadcasting
+ all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids)
+ check = tf.reduce_all(all_check, reduction_indices=tf.rank(all_check) - 1)
+ check.set_shape(token_ids.get_shape())
+ return check
+
+
+def mask_and_replace_padding(token_ids,
+ lengths,
+ eos_id=None,
+ special_tokens=(),
+ name=None):
+ """Creates a mask of valid tokens and sets padded values in id space.
+
+ This creates a mask the same shape as token_ids with a boolean indicating
+ if the id was a valid token (i.e not padding or a special token). If
+ provided, this also remaps tokens after lengths to the eos_id. Since the
+ dictionary doesn't map tokens to eos or bos ids, it would generally be the
+ unknown token id which is not correct if you need to predict the eos.
+
+ Args:
+ token_ids: A matrix `Tensor` of integer ids.
+ lengths: A vector `Tensor` of lengths for each row in token_ids.
+ eos_id: The end of sequence id, if provided then all token ids after length
+ in a row will be replaced with `eos_id`.
+ special_tokens: An iterable of special tokens for ids that are not
+ considered valid.
+ name: Name scope for these ops.
+
+ Returns:
+ token_ids: `token_ids` with all tokens after a row's length replaced with
+ eos if provided.
+ mask: A bool `Tensor` the same shape as `token_ids` indicating which tokens
+ are valid.
+ """
+ with tf.op_scope([token_ids, lengths, eos_id, special_tokens], name,
+ 'mask_and_replace_padding'):
+ ranges = tf.range(0, tf.gather(tf.shape(token_ids), 1))
+
+ # Yay! Broadcasting.
+ selected = tf.less(ranges, tf.expand_dims(lengths, -1))
+
+ if eos_id is not None:
+ token_ids = tf.where(
+ selected, token_ids,
+ tf.fill(
+ tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype)))
+ if special_tokens:
+ mask = tf.logical_and(
+ ignore_ids_mask(token_ids, special_tokens), selected)
+ else:
+ mask = selected
+ return token_ids, mask
+
+tf.no_gradient('DictionarySize')
+tf.no_gradient('DictionaryLookup')
+tf.no_gradient('DictionaryReverseLookup')
+
+
+class VocabularyType(enum.Enum):
+ """Valid vocabulary types for Dictionary construction.
+
+ TOKEN_INDEX: dictionary.dictionary_description contains an embedded map of
+ string names stored in order with ids assigned starting from the lowest
+ non-special id. Preserves order but is not compact.
+ """
+ TOKEN_INDEX = 3
+
+
+class Dictionary(object):
+ """Utility for working with fcp/dictionary/ via TensorFlow."""
+
+ def __init__(
+ self,
+ dictionary_description
+ ):
+ """Creates a dictionary from a dictionary_description.
+
+ Use static from_* constructor methods for building dictionaries from
+ common data types.
+
+ Args:
+ dictionary_description: A `dictionary_pb2.DictionaryDescription`
+ describing the dictionary.
+
+ Raises:
+ ValueError: An invalid dictionary description.
+ """
+ if not isinstance(dictionary_description, DictionaryDescription):
+ raise ValueError('Expected a DictionaryDescription')
+ if not dictionary_description.HasField('vocabulary'):
+ raise ValueError('dictionary_description has no vocabulary')
+
+ self._dictionary_description = dictionary_description
+
+ # Lazily constructed fields for lookup.
+ self._lookup_graph = None
+ self._lookup_placeholder = None
+ self._lookup_result = None
+ self._reverse_lookup_placeholder = None
+ self._reverse_lookup_result = None
+
+ @classmethod
+ def from_tokens(
+ cls,
+ tokens,
+ bos_id=None,
+ eos_id=None,
+ unk_id=None,
+ output_blocklist_tokens=None,
+ output_size=None,
+ vocabulary_type=VocabularyType.TOKEN_INDEX
+ ):
+ """Creates a dictionary from a provided list of tokens.
+
+ The id mappings to token ids depend on the vocabulary_type requested.
+
+ NB: the special tokens must be the first ids [0, num-specials)
+
+ Args:
+ tokens: An unordered iterable of tokens for the dictionary.
+ bos_id: Token id for start of sequence.
+ eos_id: Token id for end of sequence.
+ unk_id: Token id for unknown words.
+ output_blocklist_tokens: A list of vocabulary tokens that should be
+ filtered from predictions (e.g., punctuation, bad words etc.).
+ output_size: If a positive integer, tokens with ids greater than this are
+ automatically added to the output blocklist.
+ vocabulary_type: `VocabularyType` to use, defaults to TOKEN_INDEX.
+
+ Returns:
+ A `Dictionary` instance.
+
+ Raises:
+ ValueError: If the special tokens don't have the lowest ids.
+ ValueError: If there are duplicates in tokens.
+ """
+ dictionary_description = DictionaryDescription()
+
+ # Special ids.
+ special_ids = []
+ if unk_id is not None:
+ dictionary_description.special_ids.unk = unk_id
+ special_ids.append(unk_id)
+ if bos_id is not None:
+ dictionary_description.special_ids.bos = bos_id
+ special_ids.append(bos_id)
+ if eos_id is not None:
+ dictionary_description.special_ids.eos = eos_id
+ special_ids.append(eos_id)
+ if sorted(special_ids) != list(range(len(special_ids))):
+ raise ValueError(
+ 'Special ids must be the first items of the dictionary starting at 0'
+ 'or None. eos: %s; bos %s; unk: %s' % (eos_id, bos_id, unk_id))
+
+ # Vocabulary.
+ if len(tokens) != len(set(tokens)):
+ raise ValueError('Duplicate tokens provided')
+ for token in tokens:
+ if not isinstance(token, (str, bytes)):
+ raise ValueError('Bad type in tokens %s' % token)
+ if vocabulary_type == VocabularyType.TOKEN_INDEX:
+ for token in tokens:
+ dictionary_description.vocabulary.index.token.append(token)
+ else:
+ raise AssertionError('Unsupported vocabulary_type: %s' % vocabulary_type)
+
+ # Output blocklist.
+ output_blocklist_tokens = list(output_blocklist_tokens or [])
+ if output_size:
+ assert output_size >= len(special_ids), (
+ 'Cannot blocklist special tokens via output_size.')
+ assert isinstance(tokens, list) # Make sure order preserving pre-slice.
+ output_blocklist_tokens.extend(tokens[output_size - len(special_ids):])
+ for token in output_blocklist_tokens:
+ assert token in tokens, "Unexpected blocklist token: '%s'" % token
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+ output_blocklist_ids = sess.run(
+ dictionary_lookup(output_blocklist_tokens,
+ dictionary_description.SerializeToString()))
+ dictionary_description.output_blocklist_ids.id.extend(
+ sorted(output_blocklist_ids))
+ assert (len(set(dictionary_description.output_blocklist_ids.id)) == len(
+ output_blocklist_tokens)), 'blocklist contains dups or unks?'
+
+ # Return completed dictionary.
+ return cls(
+ dictionary_description=dictionary_description)
+
+ @classmethod
+ def from_dictionary_description(cls,
+ dictionary_description):
+ """Returns a Dictionary from a DictionaryDescription."""
+ return cls(
+ dictionary_description=dictionary_description)
+
+ def _get_lookup_graph(self):
+ """Returns a graph to use for lookup, reverse lookup, and size queries."""
+ if self._lookup_graph is None:
+ self._lookup_graph = tf.Graph()
+ serialized_description_proto = (
+ self._dictionary_description.SerializeToString())
+ with self._lookup_graph.as_default():
+ self._lookup_placeholder = tf.compat.v1.placeholder(
+ tf.string, shape=None)
+ self._reverse_lookup_placeholder = tf.compat.v1.placeholder(
+ tf.int64, shape=None)
+
+ # Use Dictionary(Op) (without blob) variants.
+ self._lookup_result = dictionary_lookup(
+ self._lookup_placeholder,
+ dictionary_description_proto=serialized_description_proto)
+ self._reverse_lookup_result = dictionary_reverse_lookup(
+ self._reverse_lookup_placeholder,
+ dictionary_description_proto=serialized_description_proto)
+ self._size_result = dictionary_size(
+ dictionary_description_proto=serialized_description_proto)
+
+ return self._lookup_graph
+
+ def lookup(self, tokens):
+ """Maps a list of tokens to a list of ids.
+
+ Args:
+ tokens: A list of tokens to lookup.
+
+ Returns:
+ A list of token ids of the same size.
+
+ Raises:
+ ValueError: If tokens is not a list.
+ """
+ if not isinstance(tokens, list):
+ raise ValueError('lookup expected a list of tokens.')
+
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return sess.run(self._lookup_result, {
+ self._lookup_placeholder: tokens
+ }).tolist()
+
+ def reverse_lookup(self, ids):
+ """Maps a list of ids to tokens.
+
+ Args:
+ ids: A list of ids to map back to tokens.
+
+ Returns:
+ A list of tokens corresponding to those ids.
+
+ Raises:
+ ValueError: If ids is not a list.
+ """
+ if not isinstance(ids, list):
+ raise ValueError('reverse_lookup expected a list of ids.')
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return list(
+ sess.run(self._reverse_lookup_result,
+ {self._reverse_lookup_placeholder: ids}))
+
+ @property
+ def special_ids(self):
+ """Returns a list of special token ids."""
+ return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None]
+
+ @property
+ def eos_id(self):
+ eos_id = self._dictionary_description.special_ids.eos
+ return eos_id if eos_id >= 0 else None
+
+ @property
+ def bos_id(self):
+ bos_id = self._dictionary_description.special_ids.bos
+ return bos_id if bos_id >= 0 else None
+
+ @property
+ def unk_id(self):
+ unk_id = self._dictionary_description.special_ids.unk
+ return unk_id if unk_id >= 0 else None
+
+ @property
+ def size(self):
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return sess.run(self._size_result)
+
+ @property
+ def output_blocklist_ids(self):
+ return list(self._dictionary_description.output_blocklist_ids.id)
+
+ @property
+ def output_blocklist_tokens(self):
+ return self.reverse_lookup(self.output_blocklist_ids)
+
+ @property
+ def tokens(self):
+ return self.reverse_lookup(list(range(len(self.special_ids), self.size)))
+
+ @property
+ def dictionary_description_proto(self):
+ """Serialized proto containing self.dictionary_description."""
+ return self.dictionary_description.SerializeToString()
+
+ @property
+ def dictionary_description(self):
+ """Returns the `DictionaryDescription` proto describing this dictionary.
+ """
+ desc = self._dictionary_description
+ return desc
+
+ def __len__(self):
+ return self.size
diff --git a/fcp/tensorflow/dictionary_ops_test.py b/fcp/tensorflow/dictionary_ops_test.py
new file mode 100644
index 0000000..9fee48a
--- /dev/null
+++ b/fcp/tensorflow/dictionary_ops_test.py
@@ -0,0 +1,110 @@
+from absl.testing import parameterized
+import tensorflow as tf
+from google.protobuf import text_format
+from fcp.dictionary import dictionary_pb2
+from fcp.tensorflow import dictionary_ops
+
+
+class DictionaryOpsTest(tf.test.TestCase, parameterized.TestCase):
+
+ def test_direct_tf_use_literal_dictionary(self):
+ dictionary = dictionary_pb2.DictionaryDescription()
+ text_format.Merge(
+ 'special_ids: < unk: 0 > '
+ 'vocabulary: < '
+ ' index: < token: "a" token: "b" token: "c" token: "d" >'
+ '>',
+ dictionary)
+
+ lookup = dictionary_ops.dictionary_lookup(
+ tf.constant(['a', 'b', 'a', 'a', 'd', 'X']),
+ dictionary_description_proto=dictionary.SerializeToString())
+ with tf.compat.v1.Session() as sess:
+ tokenized = sess.run(lookup)
+ self.assertEqual([1, 2, 1, 1, 4, 0], tokenized.tolist())
+
+ @parameterized.named_parameters(
+ ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
+ def test_build_dictionary_with_output_blocklist(self, vocabulary_type):
+ # Build a dictionary, explicitly blocklisting the first token and
+ # implicitly blocklisting the last token via output_size.
+ dictionary = dictionary_ops.Dictionary.from_tokens(
+ ['01', '02', '10', '11'],
+ unk_id=0,
+ output_blocklist_tokens=['01'],
+ output_size=4,
+ vocabulary_type=vocabulary_type)
+
+ if vocabulary_type in (
+ dictionary_ops.VocabularyType.TOKEN_INDEX,
+ ):
+ result = dictionary_ops.dictionary_lookup(
+ [['01', '02', '10', '11', '12']],
+ dictionary_description_proto=dictionary.dictionary_description_proto)
+
+ with tf.compat.v1.Session() as sess:
+ tokenized = sess.run(result)
+ self.assertEqual([[1, 2, 3, 4, 0]], tokenized.tolist())
+ self.assertEqual(
+ [1, 4], list(dictionary.dictionary_description.output_blocklist_ids.id))
+
+ @parameterized.named_parameters(
+ ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
+ def test_build_dictionary(self, vocabulary_type):
+ dictionary = dictionary_ops.Dictionary.from_tokens(
+ ['A', 'a', 'B', 'c'],
+ unk_id=0,
+ vocabulary_type=vocabulary_type)
+
+ result = dictionary_ops.dictionary_lookup(
+ [['A', 'a', 'B', 'b', 'C', 'c', 'D', 'd']],
+ dictionary_description_proto=dictionary.dictionary_description_proto)
+ expected = [[1, 2, 3, 0, 0, 4, 0, 0]]
+ with tf.compat.v1.Session() as sess:
+ tokenized = sess.run(result)
+ self.assertEqual(expected, tokenized.tolist())
+
+ @parameterized.named_parameters(
+ ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
+ def test_dictionary_should_raise_with_duplicate_tokens(self, vocabulary_type):
+ with self.assertRaisesRegex(ValueError, 'Duplicate tokens'):
+ dictionary_ops.Dictionary.from_tokens(['01', '02', '11', '10', '11'],
+ vocabulary_type=vocabulary_type)
+
+ @parameterized.named_parameters(
+ ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
+ def test_lookup_in_python(self, vocabulary_type):
+ dictionary = dictionary_ops.Dictionary.from_tokens(
+ ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
+ self.assertLen(dictionary, 5)
+ self.assertListEqual([1, 2, 3, 4, 0],
+ dictionary.lookup(['01', '02', '10', '11', '12']))
+
+ @parameterized.named_parameters(
+ ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
+ def test_reverse_lookup_in_python(self, vocabulary_type):
+ dictionary = dictionary_ops.Dictionary.from_tokens(
+ ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
+ self.assertLen(dictionary, 5)
+ rlookup = [
+ t.decode('utf-8') for t in dictionary.reverse_lookup([3, 2, 1, 4, 0])
+ ]
+ self.assertListEqual(['10', '02', '01', '11', ''], rlookup)
+
+ def test_literal_dictionary_in_python(self):
+ dictionary_description = dictionary_pb2.DictionaryDescription()
+ text_format.Merge(
+ 'special_ids: < unk: 0 > '
+ 'vocabulary: < '
+ ' index: < token: "a" token: "b" token: "c" token: "d" >'
+ '>',
+ dictionary_description)
+ dictionary = dictionary_ops.Dictionary.from_dictionary_description(
+ dictionary_description)
+ self.assertListEqual([b'a', b'b', b'c', b'd'], dictionary.tokens)
+
+
+if __name__ == '__main__':
+ # Required since the test still relies on v1 Session.run behavior.
+ tf.compat.v1.disable_v2_behavior()
+ tf.test.main()
diff --git a/fcp/tensorflow/example_selector_fuser.py b/fcp/tensorflow/example_selector_fuser.py
new file mode 100644
index 0000000..86a6ed5
--- /dev/null
+++ b/fcp/tensorflow/example_selector_fuser.py
@@ -0,0 +1,50 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `example_selector_fuser` operation.
+
+This wraps the generated op and ensures that necessary shared libraries
+are loaded.
+"""
+
+from typing import Optional
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_example_selector_fuser_op
+
+_example_selector_fuser_op_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile(
+ './_example_selector_fuser_op.so'))
+
+
+def example_selector_fuser(example_selector: tf.Tensor,
+ resumption_token_type_url: tf.Tensor,
+ resumption_token_content: tf.Tensor,
+ name: Optional[str] = None) -> tf.Operation:
+ """Fills the resumption token of an existing ExampleSelector message.
+
+ Args:
+ example_selector: The serialized ExampleSelector message.
+ resumption_token_type_url: The type URL of the resumption token.
+ resumption_token_content: The serialized content of the resumption token.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created `Operation`.
+ """
+ return gen_example_selector_fuser_op.example_selector_fuser(
+ example_selector,
+ resumption_token_type_url,
+ resumption_token_content,
+ name=name)
diff --git a/fcp/tensorflow/example_selector_fuser_op.cc b/fcp/tensorflow/example_selector_fuser_op.cc
new file mode 100644
index 0000000..684336c
--- /dev/null
+++ b/fcp/tensorflow/example_selector_fuser_op.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string>
+
+#include "fcp/protos/plan.pb.h"
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace fcp {
+
+using ::google::internal::federated::plan::ExampleSelector;
+using ::tensorflow::DEVICE_CPU;
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernelContext;
+using ::tensorflow::Tensor;
+using ::tensorflow::data::ParseScalarArgument;
+
+/**
+ * ExampleSelectorFuserOp op-kernel.
+ *
+ * ExampleSelectorFuser fills the resumption token field for an existing
+ * ExampleSelector protobuf message. The resumption token field is an Any proto
+ * which can be any user defined protobuf message. The user needs to provide the
+ * type url and content for the resumption token.
+ *
+ * Inputs:
+ * example_selector: A string scalar encodes an ExampleSelector protobuf
+ * message.
+ * resumption_token_type_url: String scalar. The type_url for the resumption
+ * token.
+ * resumption_token_content: String scalar. The bytes for the resumption
+ * token message.
+ *
+ * Output:
+ * A string tensor contains the fused ExampleSelector message serialized to
+ * string.
+ */
+class ExampleSelectorFuserOp : public OpKernel {
+ public:
+ explicit ExampleSelectorFuserOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ tensorflow::tstring example_selector_str;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
+ ctx, "example_selector", &example_selector_str));
+ tensorflow::tstring resumption_token_type_url_str;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
+ ctx, "resumption_token_type_url",
+ &resumption_token_type_url_str));
+ tensorflow::tstring resumption_token_content_str;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
+ ctx, "resumption_token_content",
+ &resumption_token_content_str));
+ ExampleSelector example_selector;
+ if (!example_selector.ParseFromString(
+ std::string(example_selector_str.data()))) {
+ ctx->SetStatus(tensorflow::Status(
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ static_cast<tensorflow::errors::Code>(
+ absl::StatusCode::kInvalidArgument),
+ tensorflow::StringPiece("Cannot parse ExampleSelector")));
+ return;
+ }
+ example_selector.mutable_resumption_token()->set_type_url(
+ std::string(resumption_token_type_url_str.data()));
+ example_selector.mutable_resumption_token()->set_value(
+ std::string(resumption_token_content_str.data()));
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor));
+ output_tensor->flat<tensorflow::tstring>()(0) =
+ example_selector.SerializeAsString();
+ }
+};
+
+REGISTER_OP("ExampleSelectorFuser")
+ .Input("example_selector: string")
+ .Input("resumption_token_type_url: string")
+ .Input("resumption_token_content: string")
+ .Output("fused_example_selector: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+REGISTER_KERNEL_BUILDER(Name("ExampleSelectorFuser").Device(DEVICE_CPU),
+ ExampleSelectorFuserOp);
+} // namespace fcp
diff --git a/fcp/tensorflow/example_selector_fuser_test.py b/fcp/tensorflow/example_selector_fuser_test.py
new file mode 100644
index 0000000..634e78d
--- /dev/null
+++ b/fcp/tensorflow/example_selector_fuser_test.py
@@ -0,0 +1,60 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from fcp.protos import plan_pb2
+from fcp.tensorflow import test_selector_pb2
+
+import fcp.tensorflow.example_selector_fuser as fuser
+
+
+class ExampleSelectorFuserTest(tf.test.TestCase):
+
+ def test_example_selector_fuser(self):
+ selector = plan_pb2.ExampleSelector(collection_uri='app:/test_collection')
+ criteria = test_selector_pb2.TestCriteria(max_examples=10)
+ selector.criteria.Pack(criteria)
+ resumption_token = test_selector_pb2.ResumptionToken(last_index=25)
+ fused_selector_tensor = fuser.example_selector_fuser(
+ tf.convert_to_tensor(selector.SerializeToString(), dtype=tf.string),
+ tf.convert_to_tensor(
+ 'type.googleapis.com/fcp.ResumptionToken', dtype=tf.string),
+ tf.convert_to_tensor(
+ resumption_token.SerializeToString(), dtype=tf.string))
+
+ fused_selector = plan_pb2.ExampleSelector()
+ fused_selector.ParseFromString(fused_selector_tensor.numpy())
+ assert fused_selector.collection_uri == 'app:/test_collection'
+ unpacked_criteria = test_selector_pb2.TestCriteria()
+ assert fused_selector.criteria.Unpack(unpacked_criteria)
+ assert unpacked_criteria.max_examples == 10
+
+ unpacked_token = test_selector_pb2.ResumptionToken()
+ assert fused_selector.resumption_token.Unpack(unpacked_token)
+ assert unpacked_token.last_index == 25
+
+ def test_example_selector_fuser_error(self):
+ resumption_token = test_selector_pb2.ResumptionToken(last_index=25)
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ fuser.example_selector_fuser(
+ tf.convert_to_tensor(b'1234', dtype=tf.string),
+ tf.convert_to_tensor(
+ 'type.googleapis.com/fcp.ResumptionToken', dtype=tf.string),
+ tf.convert_to_tensor(
+ resumption_token.SerializeToString(), dtype=tf.string))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/external_dataset.h b/fcp/tensorflow/external_dataset.h
new file mode 100644
index 0000000..0b2559f
--- /dev/null
+++ b/fcp/tensorflow/external_dataset.h
@@ -0,0 +1,179 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
+#define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/bounds.h"
+#include "fcp/tensorflow/host_object.h"
+
+namespace fcp {
+
+/**
+ * Interface for an iterator, created from a particular dataset. A single
+ * dataset may be used to create multiple iterators.
+ */
+class ExternalDatasetIterator {
+ public:
+ virtual ~ExternalDatasetIterator() = default;
+
+ /**
+ * Returns the next element, if possible. Indicates end-of-stream with
+ * OUT_OF_RANGE, even when repeatedly called. Corresponds to
+ * tensorflow::data::IteratorBase::GetNext.
+ *
+ * Implementations must be thread-safe.
+ */
+ virtual absl::StatusOr<std::string> GetNext() = 0;
+};
+
+namespace external_dataset_internal {
+
+template <typename FuncType>
+class DatasetFromFunction;
+
+} // namespace external_dataset_internal
+
+/**
+ * Interface for a particular dataset - created from an ExternalDatasetProvider
+ * (during dataset op execution), for a particular selector. A dataset may be
+ * used zero or more times to create an ExternalDatasetIterator.
+ *
+ * Dataset implementations are often trivial, just needing to capture some
+ * values (like the selector) for the iterator constructor. Consider using
+ * ExternalDataset::FromFunction.
+ */
+class ExternalDataset {
+ public:
+ virtual ~ExternalDataset() = default;
+
+ /**
+ * Creates a new iterator. Corresponds to
+ * tensorflow::data::DatasetBase::MakeIterator.
+ */
+ virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
+
+ /**
+ * Creates an ExternalDataset that wraps a callable object 'f', implementing
+ * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
+ * by-reference lambda captures are almost always unsafe here).
+ */
+ template <typename F>
+ static std::unique_ptr<ExternalDataset> FromFunction(F f) {
+ return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
+ std::move(f));
+ }
+};
+
+/**
+ * Interface for an ExternalDataset op's host object.
+ *
+ * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
+ * Here, 'Selector' is a string provided to the dataset op (typically, an
+ * encoded proto). The returned ExternalDataset may be used (perhaps multiple
+ * times) to create an iterator.
+ *
+ * When implementing a dataset provider and the selector is a proto message,
+ * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for
+ * some message type T).
+ */
+class ExternalDatasetProvider {
+ public:
+ virtual ~ExternalDatasetProvider() = default;
+
+ /**
+ * Creates a dataset for a given selector.
+ *
+ * This function can usually be implemented succinctly, using
+ * ExternalDataset::FromFunction.
+ *
+ * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
+ */
+ virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ absl::string_view selector) = 0;
+
+ /**
+ * Base class for dataset providers that expect a selector of a particular
+ * proto message type. If inheriting from UsingProtoSelector<T>, then one
+ * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
+ */
+ template <typename T>
+ class UsingProtoSelector;
+};
+
+/**
+ * HostObjectRegistry for the ExternalDataset interface.
+ */
+using ExternalDatasetProviderRegistry =
+ HostObjectRegistry<ExternalDatasetProvider>;
+
+namespace external_dataset_internal {
+
+template <typename T>
+absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
+ T msg;
+ if (!msg.ParseFromArray(selector.data(),
+ CastIntegerChecked<int>(selector.size()))) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "Failed to parse selector proto of type ", msg.GetTypeName()));
+ }
+
+ return msg;
+}
+
+template <typename FuncType>
+class DatasetFromFunction : public ExternalDataset {
+ public:
+ explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
+
+ std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
+ return func_();
+ }
+
+ private:
+ FuncType func_;
+};
+
+} // namespace external_dataset_internal
+
+template <typename T>
+class ExternalDatasetProvider::UsingProtoSelector
+ : public ExternalDatasetProvider {
+ public:
+ absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ absl::string_view selector) final {
+ auto maybe_msg =
+ external_dataset_internal::TryParseProtoSelector<T>(selector);
+ if (!maybe_msg.ok()) {
+ return maybe_msg.status();
+ }
+
+ return MakeDataset(std::move(maybe_msg).value());
+ }
+
+ virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ T selector) = 0;
+};
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
diff --git a/fcp/tensorflow/external_dataset.py b/fcp/tensorflow/external_dataset.py
new file mode 100644
index 0000000..8c0a9ab
--- /dev/null
+++ b/fcp/tensorflow/external_dataset.py
@@ -0,0 +1,54 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides the 'ExternalDataset' implementation of tf.Data.Dataset.
+
+This wraps the generated op (in external_dataset_py_wrapper).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_external_dataset_py
+
+_external_dataset_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile(
+ "./_external_dataset_op.so"))
+
+
+class ExternalDataset(tf.data.Dataset):
+ """An ExternalDataset is defined by whomever is running the graph.
+
+ To use an ExternalDataset, the graph must be fed a 'token' indicating what
+ external dataset to use. It also takes a 'selector' input - an opaque string,
+ to be interpreted by that external implementation.
+ """
+
+ def __init__(self, token, selector):
+ token = tf.convert_to_tensor(token, dtype=tf.string, name="token")
+ selector = tf.convert_to_tensor(selector, dtype=tf.string, name="selector")
+ variant_tensor = gen_external_dataset_py.ExternalDataset(
+ token=token, selector=selector)
+ super(ExternalDataset, self).__init__(variant_tensor)
+
+ @property
+ def element_spec(self):
+ return tf.TensorSpec([], tf.string)
+
+ def _inputs(self):
+ return []
diff --git a/fcp/tensorflow/external_dataset_op.cc b/fcp/tensorflow/external_dataset_op.cc
new file mode 100644
index 0000000..16a373b
--- /dev/null
+++ b/fcp/tensorflow/external_dataset_op.cc
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_format.h"
+#include "fcp/base/random_token.h"
+#include "fcp/tensorflow/external_dataset.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/version.h"
+
+namespace fcp {
+
+/**
+ * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found
+ * from the ExternalDatasetProviderRegistry (a HostObjectRegistry).
+ *
+ * Inputs:
+ * selector: An opaque string scalar. Forwarded to the stub.
+ * token: String scalar. It should encode a token obtained from
+ * ExternalDatasetProviderRegistry::Register.
+ *
+ * See TensorFlow's guide to making custom dataset ops:
+ * https://www.tensorflow.org/guide/extend/formats
+ */
+class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel {
+ public:
+ using tensorflow::data::DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(tensorflow::OpKernelContext* ctx,
+ tensorflow::data::DatasetBase** output) override {
+ tensorflow::tstring token_str;
+ OP_REQUIRES_OK(ctx,
+ tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
+ ctx, "token", &token_str));
+ absl::Span<char const> token_bytes = token_str;
+ OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes,
+ tensorflow::errors::InvalidArgument(absl::StrFormat(
+ "Tokens have a fixed size. Expected: %d; Actual %d",
+ kRandomTokenSizeInBytes, token_bytes.size())));
+ RandomToken token = RandomToken::FromBytes(token_bytes);
+
+ tensorflow::tstring selector_str;
+ OP_REQUIRES_OK(ctx,
+ tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
+ ctx, "selector", &selector_str));
+
+ std::optional<std::shared_ptr<ExternalDatasetProvider>> maybe_provider =
+ ExternalDatasetProviderRegistry::TryLookup(token);
+ OP_REQUIRES(ctx, maybe_provider.has_value(),
+ tensorflow::errors::InvalidArgument(
+ "A dataset provider is not currently registered for the "
+ "provided token: ",
+ token.ToPrintableString()));
+
+ std::shared_ptr<ExternalDatasetProvider> provider =
+ *std::move(maybe_provider);
+ StatusOr<std::unique_ptr<ExternalDataset>> maybe_dataset =
+ provider->MakeDataset(selector_str);
+ // The provider might not like the given selector.
+ if (!maybe_dataset.ok()) {
+ ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status()));
+ return;
+ }
+
+ *output = new Dataset(ctx, std::move(maybe_dataset).value());
+ }
+
+ private:
+ class Dataset : public tensorflow::data::DatasetBase {
+ public:
+ Dataset(tensorflow::OpKernelContext* ctx,
+ std::unique_ptr<ExternalDataset> stub)
+ : DatasetBase(tensorflow::data::DatasetContext(ctx)),
+ stub_(std::move(stub)) {}
+
+ std::unique_ptr<tensorflow::data::IteratorBase> MakeIteratorInternal(
+ const std::string& prefix) const override {
+ std::unique_ptr<ExternalDatasetIterator> iter = stub_->MakeIterator();
+ Iterator::Params params{
+ this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")};
+ return std::unique_ptr<tensorflow::data::IteratorBase>(
+ new Iterator(params, std::move(iter)));
+ }
+
+ // Each iterator element is just a scalar string.
+
+ const tensorflow::DataTypeVector& output_dtypes() const override {
+ static auto* const dtypes =
+ new tensorflow::DataTypeVector({tensorflow::DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<tensorflow::PartialTensorShape>& output_shapes()
+ const override {
+ static std::vector<tensorflow::PartialTensorShape>* shapes =
+ new std::vector<tensorflow::PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::string DebugString() const override {
+ return "ExternalDatasetOp::Dataset";
+ }
+
+ tensorflow::Status InputDatasets(
+ std::vector<const DatasetBase*>* inputs) const override {
+ // ExternalDatast has no input datasets, so just return OK.
+ return tensorflow::OkStatus();
+ }
+
+// The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We
+// use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if
+// we should add its override.
+#if TF_GRAPH_DEF_VERSION > 125
+ tensorflow::Status CheckExternalState() const override {
+ return tensorflow::OkStatus();
+ }
+#endif
+
+ protected:
+ tensorflow::Status AsGraphDefInternal(
+ tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
+ tensorflow::Node** output) const override {
+ return ::tensorflow::errors::Unimplemented(
+ DebugString(), " does not support serialization.");
+ }
+
+ private:
+ class Iterator : public tensorflow::data::DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params,
+ std::unique_ptr<ExternalDatasetIterator> stub)
+ : DatasetIterator<Dataset>(params), stub_(std::move(stub)) {}
+
+ tensorflow::Status GetNextInternal(
+ tensorflow::data::IteratorContext* ctx,
+ std::vector<tensorflow::Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ StatusOr<std::string> maybe_element;
+ {
+ absl::MutexLock _(&mu_);
+ maybe_element = stub_->GetNext();
+ }
+
+ if (maybe_element.ok()) {
+ std::string element = std::move(maybe_element).value();
+
+ // The {} at the end specifies a scalar tensor.
+ tensorflow::Tensor element_tensor(ctx->allocator({}),
+ tensorflow::DT_STRING, {});
+ element_tensor.scalar<tensorflow::tstring>()() = element;
+
+ *end_of_sequence = false;
+ out_tensors->push_back(std::move(element_tensor));
+ return tensorflow::OkStatus();
+ } else {
+ *end_of_sequence = true;
+ if (maybe_element.status().code() == StatusCode::kOutOfRange) {
+ return tensorflow::OkStatus();
+ } else {
+ return ConvertToTensorFlowStatus(maybe_element.status());
+ }
+ }
+ }
+
+ protected:
+ tensorflow::Status SaveInternal(
+// `::tensorflow::data::SerializationContext` argument was added on
+// 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343.
+#if TF_GRAPH_DEF_VERSION > 343
+ tensorflow::data::SerializationContext* ctx,
+#endif
+ tensorflow::data::IteratorStateWriter* writer) override {
+ return ::tensorflow::errors::Unimplemented(
+ "Save / Restore of an ExternalDataset iterator is not supported");
+ }
+ tensorflow::Status RestoreInternal(
+ tensorflow::data::IteratorContext* ctx,
+ tensorflow::data::IteratorStateReader* reader) override {
+ return ::tensorflow::errors::Unimplemented(
+ "Save / Restore of an ExternalDataset iterator is not supported");
+ }
+
+ private:
+ std::unique_ptr<ExternalDatasetIterator> stub_;
+ absl::Mutex mu_;
+ };
+
+ // Private members of Dataset
+
+ std::unique_ptr<ExternalDataset> stub_;
+ };
+};
+
+REGISTER_OP("ExternalDataset")
+ .Input("token: string")
+ .Input("selector: string")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU),
+ ExternalDatasetOp);
+
+} // namespace fcp
diff --git a/fcp/tensorflow/external_dataset_op_test.cc b/fcp/tensorflow/external_dataset_op_test.cc
new file mode 100644
index 0000000..b326b0a
--- /dev/null
+++ b/fcp/tensorflow/external_dataset_op_test.cc
@@ -0,0 +1,320 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Tests for the ExternalDataset op.
+ *
+ * In make_external_dataset_test_graph.py, we generate a GraphDef that connects
+ * an ExternalDataset (producing serialized tf.Example protos), through
+ * tf.parse_example (parsing each to an int64_t scalar), to a Reduce (sum).
+ *
+ * Here, we load that graph and try to run it, with some ExternalDataset
+ * provided as we call Session::Run. We just need to Run once since Reduce
+ * should consume the entire dataset iterator.
+ */
+
+#include <fcntl.h>
+#include <stdint.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/tensorflow/external_dataset.h"
+#include "fcp/tensorflow/test_selector.pb.h"
+#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+//
+// Constants related to the GraphDef we test with
+// See make_external_dataset_test_graph.py
+//
+
+char const* const kExampleGraphPath =
+ "fcp/tensorflow/external_dataset_test.pbtxt";
+char const* const kFeatureName = "val";
+char const* const kTokenPlaceholderName = "token";
+char const* const kSelectorPlaceholderName = "selector";
+char const* const kOutputName = "total:0";
+
+//
+// TensorFlow boilerplate
+//
+
+tensorflow::GraphDef LoadExampleGraph() {
+ int fd = open(kExampleGraphPath, O_RDONLY);
+ FCP_CHECK(fd != -1) << "Failed to open the example graph, using path "
+ << kExampleGraphPath;
+
+ google::protobuf::io::FileInputStream fs(fd);
+ fs.SetCloseOnDelete(true);
+
+ tensorflow::GraphDef graph;
+ bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
+ FCP_CHECK(parsed) << "Invalid text-format GraphDef";
+
+ return graph;
+}
+
+std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
+ tensorflow::GraphDef graph = LoadExampleGraph();
+
+ std::unique_ptr<tensorflow::Session> session;
+ {
+ tensorflow::SessionOptions options;
+ tensorflow::Session* raw_session = nullptr;
+ tensorflow::Status session_new_status =
+ tensorflow::NewSession(options, &raw_session);
+ TF_CHECK_OK(session_new_status);
+ session = std::unique_ptr<tensorflow::Session>(raw_session);
+ }
+
+ tensorflow::Status graph_build_status = session->Create(graph);
+ TF_CHECK_OK(graph_build_status);
+ return session;
+}
+
+tensorflow::Example MakeExample(int64_t value) {
+ tensorflow::Example example;
+ tensorflow::AppendFeatureValues({value}, kFeatureName, &example);
+ return example;
+}
+
+std::string SerializeExample(int64_t value) {
+ std::string serialized;
+ FCP_CHECK(MakeExample(value).SerializeToString(&serialized));
+ return serialized;
+}
+
+tensorflow::Status RunSession(tensorflow::Session* session,
+ RandomToken dataset_token,
+ tensorflow::Tensor selector,
+ tensorflow::Tensor* output) {
+ auto token_tensor =
+ tensorflow::test::AsScalar<tensorflow::tstring>(dataset_token.ToString());
+
+ std::vector<tensorflow::Tensor> outputs;
+ tensorflow::Status run_status =
+ session->Run({{kTokenPlaceholderName, token_tensor},
+ {kSelectorPlaceholderName, selector}},
+ {kOutputName}, {}, &outputs);
+
+ if (run_status.ok() && output) {
+ FCP_CHECK(outputs.size() == 1);
+ *output = outputs[0];
+ }
+
+ return run_status;
+}
+
+tensorflow::Status RunSession(tensorflow::Session* session,
+ RandomToken dataset_token,
+ TestSelector const& selector,
+ tensorflow::Tensor* output) {
+ std::string selector_str;
+ FCP_CHECK(selector.SerializeToString(&selector_str));
+ auto selector_tensor =
+ tensorflow::test::AsScalar<tensorflow::tstring>(selector_str);
+ return RunSession(session, dataset_token, selector_tensor, output);
+}
+
+tensorflow::Tensor RunSessionAndGetOutput(tensorflow::Session* session,
+ RandomToken dataset_token,
+ TestSelector const& selector) {
+ tensorflow::Tensor output;
+ tensorflow::Status run_status =
+ RunSession(session, dataset_token, selector, &output);
+ TF_CHECK_OK(run_status);
+ return output;
+}
+
+//
+// ExternalDataset host object implementations for testing
+//
+
+class TestDatasetIterator : public ExternalDatasetIterator {
+ public:
+ explicit TestDatasetIterator(
+ std::shared_ptr<std::vector<int64_t> const> examples,
+ int64_t lower_inclusive, int64_t upper_inclusive)
+ : examples_(std::move(examples)),
+ lower_inclusive_(lower_inclusive),
+ upper_inclusive_(upper_inclusive) {}
+
+ absl::StatusOr<std::string> GetNext() final {
+ while (index_ < examples_->size()) {
+ int64_t ex = examples_->at(index_);
+ index_++;
+
+ if (ex >= lower_inclusive_ && ex < upper_inclusive_) {
+ return SerializeExample(ex);
+ }
+ }
+
+ return absl::OutOfRangeError("");
+ }
+
+ private:
+ std::shared_ptr<std::vector<int64_t> const> examples_;
+ int index_ = 0;
+ int64_t lower_inclusive_;
+ int64_t upper_inclusive_;
+};
+
+class TestDatasetProvider
+ : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
+ public:
+ explicit TestDatasetProvider(std::vector<int64_t> examples) {
+ auto ex = std::make_shared<std::vector<int64_t>>(std::move(examples));
+ examples_ = std::move(ex);
+ }
+
+ absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ TestSelector selector) final {
+ int64_t lower = selector.has_lower_inclusive()
+ ? selector.lower_inclusive().value()
+ : std::numeric_limits<int64_t>::min();
+ int64_t upper = selector.has_upper_inclusive()
+ ? selector.upper_inclusive().value()
+ : std::numeric_limits<int64_t>::max();
+ auto examples = examples_;
+ return ExternalDataset::FromFunction([examples, lower, upper]() {
+ return std::make_unique<TestDatasetIterator>(examples, lower, upper);
+ });
+ }
+
+ private:
+ std::shared_ptr<std::vector<int64_t> const> examples_;
+};
+
+class FailingIterator : public ExternalDatasetIterator {
+ public:
+ absl::StatusOr<std::string> GetNext() final {
+ return absl::NotFoundError("");
+ }
+};
+
+class FailingIteratorDatasetProvider
+ : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
+ public:
+ absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
+ TestSelector selector) final {
+ return ExternalDataset::FromFunction(
+ []() { return std::make_unique<FailingIterator>(); });
+ }
+};
+
+//
+// Actual tests
+//
+
+TEST(ExternalDatasetOpTest, RunExampleGraph) {
+ std::vector<int64_t> examples{123, 456, 789};
+
+ // Default selector (no filtering)
+ TestSelector selector;
+
+ tensorflow::Tensor expected = tensorflow::test::AsTensor<tensorflow::int64>(
+ {123 + 456 + 789}, tensorflow::TensorShape({1}));
+
+ auto stub = std::make_shared<TestDatasetProvider>(std::move(examples));
+ auto stub_reg = ExternalDatasetProviderRegistry::Register(stub);
+
+ auto session = PrepareExampleGraphSession();
+ tensorflow::Tensor output =
+ RunSessionAndGetOutput(session.get(), stub_reg.token(), selector);
+
+ tensorflow::test::ExpectTensorEqual<tensorflow::int64>(output, expected);
+}
+
+TEST(ExternalDatasetOpTest, RunExampleGraph_SelectorFilter) {
+ std::vector<int64_t> examples{123, 456, 789, 1024};
+
+ TestSelector selector;
+ selector.mutable_lower_inclusive()->set_value(124);
+ selector.mutable_upper_inclusive()->set_value(1023);
+
+ // Expecting some of the examples to be skipped, due to the filter.
+ tensorflow::Tensor expected = tensorflow::test::AsTensor<tensorflow::int64>(
+ {456 + 789}, tensorflow::TensorShape({1}));
+
+ auto stub = std::make_shared<TestDatasetProvider>(std::move(examples));
+ auto stub_reg = ExternalDatasetProviderRegistry::Register(stub);
+
+ auto session = PrepareExampleGraphSession();
+ tensorflow::Tensor output =
+ RunSessionAndGetOutput(session.get(), stub_reg.token(), selector);
+
+ tensorflow::test::ExpectTensorEqual<tensorflow::int64>(output, expected);
+}
+
+TEST(ExternalDatasetOpTest, TokenNotFound) {
+ TestSelector selector;
+ auto session = PrepareExampleGraphSession();
+ tensorflow::Status status =
+ RunSession(session.get(), RandomToken::Generate(), selector, nullptr);
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ EXPECT_THAT(
+ status.code(),
+ Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
+}
+
+TEST(ExternalDatasetOpTest, FailingIterator) {
+ auto stub = std::make_shared<FailingIteratorDatasetProvider>();
+ auto stub_reg = ExternalDatasetProviderRegistry::Register(stub);
+
+ TestSelector selector;
+
+ auto session = PrepareExampleGraphSession();
+ tensorflow::Status status =
+ RunSession(session.get(), stub_reg.token(), selector, nullptr);
+ EXPECT_THAT(status.code(), Eq(tensorflow::error::NOT_FOUND));
+}
+
+TEST(ExternalDatasetOpTest, RunExampleGraph_InvalidSelector) {
+ std::vector<int64_t> examples{123};
+
+ // This is interpreted as a varint. The MSB is set, so it asks for another
+ // byte (but there aren't any).
+ std::string bad_selector = "\xFF";
+ tensorflow::Tensor bad_selector_tensor =
+ tensorflow::test::AsScalar<tensorflow::tstring>(bad_selector);
+ auto stub = std::make_shared<TestDatasetProvider>(std::move(examples));
+ auto stub_reg = ExternalDatasetProviderRegistry::Register(stub);
+
+ auto session = PrepareExampleGraphSession();
+ tensorflow::Status status =
+ RunSession(session.get(), stub_reg.token(), bad_selector_tensor, nullptr);
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ EXPECT_THAT(
+ status.code(),
+ Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/host_object.cc b/fcp/tensorflow/host_object.cc
new file mode 100644
index 0000000..5db753b
--- /dev/null
+++ b/fcp/tensorflow/host_object.cc
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/host_object.h"
+
+#include <utility>
+
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+namespace host_object_internal {
+
+std::optional<std::shared_ptr<void>> HostObjectRegistryImpl::TryLookup(
+ RandomToken token) {
+ std::shared_ptr<void> p = nullptr;
+
+ {
+ absl::ReaderMutexLock lock{&mutex_};
+ auto it = objects_.find(token);
+ if (it != objects_.end()) {
+ p = it->second;
+ }
+ }
+
+ if (p == nullptr) {
+ return std::nullopt;
+ } else {
+ return p;
+ }
+}
+
+void HostObjectRegistryImpl::Register(RandomToken token,
+ std::shared_ptr<void> p) {
+ absl::WriterMutexLock lock{&mutex_};
+ auto r = objects_.insert({token, std::move(p)});
+ FCP_CHECK(r.second)
+ << "An object has already been registered with the provided token";
+}
+
+void HostObjectRegistryImpl::Unregister(RandomToken token) {
+ absl::WriterMutexLock lock{&mutex_};
+ size_t erased = objects_.erase(token);
+ FCP_CHECK(erased == 1)
+ << "An object is not currently registered for the provided token";
+}
+
+} // namespace host_object_internal
+
+} // namespace fcp
diff --git a/fcp/tensorflow/host_object.h b/fcp/tensorflow/host_object.h
new file mode 100644
index 0000000..e7803ed
--- /dev/null
+++ b/fcp/tensorflow/host_object.h
@@ -0,0 +1,162 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_HOST_OBJECT_H_
+#define FCP_TENSORFLOW_HOST_OBJECT_H_
+
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/random_token.h"
+#include "fcp/base/unique_value.h"
+
+namespace fcp {
+
+/**
+ * Op-kernels are instantiated by TensorFlow, and can only be parameterized by
+ * graph 'attrs' and tensor inputs. So, op-kernels which access the 'outside
+ * world' tend to use ambient, process-global resources - for example, consider
+ * op-kernels which interpret a string tensor as a filesystem path.
+ *
+ * In some uses, we'd like to parameterize an op-kernel on some 'host'-side,
+ * non-Tensor objects (for example, a virtual filesystem) at the site of
+ * Session::Run (i.e. maintaining functional composition).
+ *
+ * This file defines a mechanism to register 'host objects' (in a
+ * HostObjectRegistry) outside of a session, pass them to Session::Run, and
+ * refer to them inside of the graph (and op-kernel implementations) using
+ * DT_STRING scalars ('tokens'). We could instead use DT_VARIANT tensors (which
+ * can wrap C++ objects directly), but DT_STRING is much more convenient to
+ * marshal (for example, Python's Session::Run wrapper accepts Python strings
+ * for placeholder bindings, but not existing Tensor objects).
+ *
+ * To register a host object:
+ * Use HostObjectRegistry<I> for some interface type 'I'. This returns a
+ * HostObjectRegistration object, which de-registers on destruction.
+ * To pass in a host object:
+ * Bind the token() (from the HostObjectRegistration) to some placeholder,
+ * when calling Session::Run.
+ * To access a host object in an op-kernel:
+ * Use HostObjectRegistry<I>::TryLookup (the op should take a DT_STRING scalar
+ * for the token to use).
+ */
+
+namespace host_object_internal {
+
+/**
+ * HostObjectRegistry implementation for a particular interface type.
+ *
+ * For each I, HostObjectRegistry<I> defines a HostObjectRegistryImpl with
+ * static storage duration.
+ */
+class HostObjectRegistryImpl {
+ public:
+ std::optional<std::shared_ptr<void>> TryLookup(RandomToken token);
+ void Register(RandomToken token, std::shared_ptr<void> p);
+ void Unregister(RandomToken token);
+ private:
+ absl::Mutex mutex_;
+ absl::flat_hash_map<RandomToken, std::shared_ptr<void>> objects_
+ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace host_object_internal
+
+/**
+ * Active registration of a host object, under token(). To reference this object
+ * in a TensorFlow graph, pass in token() as a DT_STRING tensor.
+ *
+ * De-registers when destructed. Note that the registered object *may* stay
+ * alive; an op-kernel can retain an std::shared_ptr ref from TryLookup.
+ */
+class HostObjectRegistration final {
+ public:
+ HostObjectRegistration(HostObjectRegistration&&) = default;
+ HostObjectRegistration& operator=(HostObjectRegistration&&) = default;
+
+ ~HostObjectRegistration() {
+ if (token_.has_value()) {
+ registry_->Unregister(*token_);
+ }
+ }
+
+ /**
+ * Token under which the object is registered. It can be passed into a graph
+ * (as a string tensor) and used to look up the object.
+ */
+ RandomToken token() const { return *token_; }
+
+ private:
+ template<typename T>
+ friend class HostObjectRegistry;
+
+ HostObjectRegistration(host_object_internal::HostObjectRegistryImpl* registry,
+ RandomToken token)
+ : registry_(registry), token_(token) {}
+
+ host_object_internal::HostObjectRegistryImpl* registry_;
+ UniqueValue<RandomToken> token_;
+};
+
+/**
+ * Registry of host objects, for a particular interface type.
+ * See file remarks.
+ */
+template<typename T>
+class HostObjectRegistry {
+ public:
+ /**
+ * Registers the provided host object, yielding a new HostObjectRegistration
+ * with a unique token(). The object is de-registered when the
+ * HostObjectRegistration is destructed.
+ */
+ static HostObjectRegistration Register(std::shared_ptr<T> p) {
+ RandomToken token = RandomToken::Generate();
+ GetImpl()->Register(token, std::move(p));
+ return HostObjectRegistration(GetImpl(), token);
+ }
+
+ /**
+ * Looks up a host object. Returns std::nullopt if nothing is currently
+ * registered for the provided token (and interface T).
+ */
+ static std::optional<std::shared_ptr<T>> TryLookup(RandomToken token) {
+ std::optional<std::shared_ptr<void>> maybe_p = GetImpl()->TryLookup(token);
+ if (maybe_p.has_value()) {
+ std::shared_ptr<void> p = *std::move(maybe_p);
+ return std::static_pointer_cast<T>(std::move(p));
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ private:
+ HostObjectRegistry();
+
+ static host_object_internal::HostObjectRegistryImpl* GetImpl() {
+ static auto* global_registry =
+ new host_object_internal::HostObjectRegistryImpl();
+ return global_registry;
+ }
+};
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_HOST_OBJECT_H_
diff --git a/fcp/tensorflow/host_object_test.cc b/fcp/tensorflow/host_object_test.cc
new file mode 100644
index 0000000..9359994
--- /dev/null
+++ b/fcp/tensorflow/host_object_test.cc
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/host_object.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+class WidgetInterface {
+ public:
+ virtual ~WidgetInterface() = default;
+ virtual void Poke(int value) = 0;
+};
+
+class WidgetImpl : public WidgetInterface {
+ public:
+ void Poke(int value) final {
+ counter_ += value;
+ }
+
+ int counter() const {
+ return counter_;
+ }
+ private:
+ int counter_ = 0;
+};
+
+TEST(HostObjectTest, LookupFailure) {
+ std::optional<std::shared_ptr<WidgetInterface>> p =
+ HostObjectRegistry<WidgetInterface>::TryLookup(RandomToken::Generate());
+ EXPECT_THAT(p, Eq(std::nullopt));
+}
+
+TEST(HostObjectTest, LookupSuccess) {
+ std::shared_ptr<WidgetImpl> obj = std::make_shared<WidgetImpl>();
+ HostObjectRegistration reg =
+ HostObjectRegistry<WidgetInterface>::Register(obj);
+
+ std::optional<std::shared_ptr<WidgetInterface>> p =
+ HostObjectRegistry<WidgetInterface>::TryLookup(reg.token());
+ EXPECT_TRUE(p.has_value());
+
+ (*p)->Poke(123);
+ EXPECT_THAT(obj->counter(), Eq(123));
+ EXPECT_THAT(p->get(), Eq(obj.get()));
+}
+
+TEST(HostObjectTest, Unregister) {
+ std::shared_ptr<WidgetImpl> obj = std::make_shared<WidgetImpl>();
+
+ std::optional<RandomToken> token;
+ {
+ HostObjectRegistration reg =
+ HostObjectRegistry<WidgetInterface>::Register(obj);
+ token = reg.token();
+ }
+
+ std::optional<std::shared_ptr<WidgetInterface>> p =
+ HostObjectRegistry<WidgetInterface>::TryLookup(*token);
+ EXPECT_THAT(p, Eq(std::nullopt));
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/make_external_dataset_test_graph.py b/fcp/tensorflow/make_external_dataset_test_graph.py
new file mode 100644
index 0000000..0bb405b
--- /dev/null
+++ b/fcp/tensorflow/make_external_dataset_test_graph.py
@@ -0,0 +1,58 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Prints a GraphDef to stdout (for testing ExternalDataset)."""
+
+import argparse
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from fcp.tensorflow import external_dataset
+
+
+def _ParseSingleExample(p):
+ # parse_example doesn't like scalars, so we reshape with [-1].
+ features = tf.parse_example(
+ tf.reshape(p, [-1]), {"val": tf.FixedLenFeature([], dtype=tf.int64)})
+ return features["val"]
+
+
+def MakeGraph():
+ """Makes a GraphDef."""
+
+ graph = tf.Graph()
+
+ with graph.as_default():
+ serialized_examples = external_dataset.ExternalDataset(
+ token=tf.placeholder(name="token", dtype=tf.string),
+ selector=tf.placeholder(name="selector", dtype=tf.string))
+
+ examples = serialized_examples.map(_ParseSingleExample)
+
+ total = examples.reduce(np.int64(0), lambda x, y: x + y)
+ total = tf.identity(total, name="total")
+
+ return graph
+
+
+def _ParseArgs():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output", required=True, type=argparse.FileType("w"))
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ args = _ParseArgs()
+ with args.output:
+ graph_def = MakeGraph().as_graph_def()
+ args.output.write(str(graph_def))
diff --git a/fcp/tensorflow/make_serve_slices_test_graph.py b/fcp/tensorflow/make_serve_slices_test_graph.py
new file mode 100644
index 0000000..e457672
--- /dev/null
+++ b/fcp/tensorflow/make_serve_slices_test_graph.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Writes a GraphDef to a file for testing `ServeSlices`."""
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+
+from fcp.tensorflow import serve_slices
+
+CALLBACK_TOKEN_PLACEHOLDER_TENSOR = 'callback_token'
+SERVED_AT_TENSOR = 'served_at_id'
+SERVER_VAL = (1, 2.0, 'foo')
+MAX_KEY = 44
+SELECT_FN_INITIALIZE_OP = 'init_the_things'
+SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
+SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
+SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
+SELECT_FN_TARGET_TENSOR_NAME = 'goobler'
+
+flags.DEFINE_string('output', None, 'The path to the output file.')
+FLAGS = flags.FLAGS
+
+
+def make_graph():
+ """Builds and returns a `tf.Graph` which calls `ServeSlices`."""
+ graph = tf.Graph()
+ with graph.as_default():
+ # Create a placeholder with a fixed name to allow the code running the graph
+ # to provide input.
+ callback_token = tf.compat.v1.placeholder(
+ name=CALLBACK_TOKEN_PLACEHOLDER_TENSOR, dtype=tf.string)
+ served_at_id = serve_slices.serve_slices(
+ callback_token=callback_token,
+ server_val=SERVER_VAL,
+ max_key=MAX_KEY,
+ select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
+ select_fn_server_val_input_tensor_names=SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
+ select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
+ select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
+ select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME)
+ # Create a tensor with a fixed name to allow the code running the graph to
+ # receive output.
+ tf.identity(served_at_id, name=SERVED_AT_TENSOR)
+ return graph
+
+
+def main(argv):
+ del argv
+ graph_def_str = str(make_graph().as_graph_def())
+ with open(FLAGS.output, 'w') as output_file:
+ output_file.write(graph_def_str)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('output')
+ app.run(main)
diff --git a/fcp/tensorflow/make_slices_selector_example_selector.py b/fcp/tensorflow/make_slices_selector_example_selector.py
new file mode 100644
index 0000000..66927d8
--- /dev/null
+++ b/fcp/tensorflow/make_slices_selector_example_selector.py
@@ -0,0 +1,34 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `make_slices_selector_example_selector` operation.
+
+This wraps the generated op and ensures that necessary shared libraries
+are loaded.
+"""
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_make_slices_selector_example_selector_py
+
+_make_slices_selector_example_selector_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile(
+ './_make_slices_selector_example_selector_op.so'))
+
+
+def make_slices_selector_example_selector(served_at_id, keys):
+ """Serializes a proto `ExampleSelector` containing a `SlicesSelector`."""
+ return gen_make_slices_selector_example_selector_py.make_slices_selector_example_selector(
+ served_at_id=tf.convert_to_tensor(served_at_id, tf.string),
+ keys=tf.convert_to_tensor(keys, tf.int32),
+ )
diff --git a/fcp/tensorflow/make_slices_selector_example_selector_op.cc b/fcp/tensorflow/make_slices_selector_example_selector_op.cc
new file mode 100644
index 0000000..14cf627
--- /dev/null
+++ b/fcp/tensorflow/make_slices_selector_example_selector_op.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/strings/str_format.h"
+#include "fcp/client/federated_select.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace fcp {
+
+namespace {
+
+REGISTER_OP("MakeSlicesSelectorExampleSelector")
+ .Input("served_at_id: string")
+ .Input("keys: int32")
+ .Output("serialized_proto: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+class MakeSlicesSelectorExampleSelectorOp : public tensorflow::OpKernel {
+ public:
+ explicit MakeSlicesSelectorExampleSelectorOp(
+ tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {
+ const tensorflow::Tensor* served_at_id_tensor;
+ OP_REQUIRES_OK(context,
+ context->input("served_at_id", &served_at_id_tensor));
+ std::string served_at_id =
+ served_at_id_tensor->scalar<tensorflow::tstring>()();
+
+ const tensorflow::Tensor* keys_tensor;
+ OP_REQUIRES_OK(context, context->input("keys", &keys_tensor));
+ tensorflow::TTypes<int32_t>::ConstFlat keys = keys_tensor->flat<int32_t>();
+
+ google::internal::federated::plan::SlicesSelector slices_selector;
+ slices_selector.set_served_at_id(std::move(served_at_id));
+ slices_selector.mutable_keys()->Reserve(keys.size());
+ for (size_t i = 0; i < keys.size(); i++) {
+ slices_selector.add_keys(keys(i));
+ }
+
+ google::internal::federated::plan::ExampleSelector example_selector;
+ example_selector.mutable_criteria()->PackFrom(slices_selector);
+ example_selector.set_collection_uri(
+ fcp::client::kFederatedSelectCollectionUri);
+ // `resumption_token` not set.
+
+ tensorflow::Tensor* output_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
+ output_tensor->scalar<tensorflow::tstring>()() =
+ example_selector.SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MakeSlicesSelectorExampleSelector").Device(tensorflow::DEVICE_CPU),
+ MakeSlicesSelectorExampleSelectorOp);
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/tensorflow/make_slices_selector_example_selector_test.py b/fcp/tensorflow/make_slices_selector_example_selector_test.py
new file mode 100644
index 0000000..d131cdf
--- /dev/null
+++ b/fcp/tensorflow/make_slices_selector_example_selector_test.py
@@ -0,0 +1,42 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for `make_slices_selector_example_selector` custom op."""
+
+import tensorflow as tf
+
+from fcp.protos import plan_pb2
+from fcp.tensorflow import make_slices_selector_example_selector
+
+
+class MakeSlicesSelectorExampleSelectorTest(tf.test.TestCase):
+
+ def test_returns_serialized_proto(self):
+ served_at_id = 'test_served_at_id'
+ keys = [1, 3, 5, 20]
+ serialized_proto_tensor = make_slices_selector_example_selector.make_slices_selector_example_selector(
+ served_at_id, keys)
+ self.assertIsInstance(serialized_proto_tensor, tf.Tensor)
+ self.assertEqual(serialized_proto_tensor.dtype, tf.string)
+ serialized_proto = serialized_proto_tensor.numpy()
+ example_selector = plan_pb2.ExampleSelector.FromString(serialized_proto)
+ self.assertEqual(example_selector.collection_uri,
+ 'internal:/federated_select')
+ slices_selector = plan_pb2.SlicesSelector()
+ self.assertTrue(example_selector.criteria.Unpack(slices_selector))
+ self.assertEqual(slices_selector.served_at_id, served_at_id)
+ self.assertEqual(slices_selector.keys, keys)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/python/BUILD b/fcp/tensorflow/python/BUILD
new file mode 100644
index 0000000..f23574f
--- /dev/null
+++ b/fcp/tensorflow/python/BUILD
@@ -0,0 +1,15 @@
+# pybind11 bindings for //fcp/tensorflow.
+
+load("@rules_python//python:defs.bzl", "py_test")
+
+package(
+ default_visibility = ["//fcp:internal"],
+)
+
+exports_files(["serve_slices_registry.cc"])
+
+py_test(
+ name = "serve_slices_registry_test",
+ srcs = ["serve_slices_registry_test.py"],
+ deps = ["//fcp/tensorflow:serve_slices_py"],
+)
diff --git a/fcp/tensorflow/python/serve_slices_registry.cc b/fcp/tensorflow/python/serve_slices_registry.cc
new file mode 100644
index 0000000..6a860e0
--- /dev/null
+++ b/fcp/tensorflow/python/serve_slices_registry.cc
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/serve_slices_registry.h"
+
+#include <pybind11/functional.h>
+#include <pybind11/pybind11.h>
+
+#include <functional>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "fcp/base/random_token.h"
+#include "fcp/tensorflow/host_object.h"
+#include "pybind11_abseil/absl_casters.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+
+namespace pybind11::detail {
+
+// Type caster converting a Tensor from C++ to Python.
+template <>
+struct type_caster<tensorflow::Tensor> {
+ PYBIND11_TYPE_CASTER(tensorflow::Tensor, const_name("Tensor"));
+
+ static handle cast(const tensorflow::Tensor& tensor, return_value_policy,
+ handle) {
+ // We'd ideally use tensorflow::TensorToNdarray, but that function isn't
+ // available to code running in custom ops. Instead, we pass the Tensor
+ // as a serialized proto and convert to an ndarray in Python.
+ tensorflow::TensorProto proto;
+ if (tensor.dtype() == tensorflow::DT_STRING) {
+ // Strings encoded using AsProtoTensorContent are incompatible with
+ // tf.make_ndarray.
+ tensor.AsProtoField(&proto);
+ } else {
+ tensor.AsProtoTensorContent(&proto);
+ }
+ std::string serialized = proto.SerializeAsString();
+ return PyBytes_FromStringAndSize(serialized.data(), serialized.size());
+ }
+};
+
+} // namespace pybind11::detail
+
+namespace {
+
+namespace py = ::pybind11;
+
+// A variant of fcp::ServeSlicesCallback with Python-friendly types.
+using ServeSlicesCallback = std::function<std::string(
+ /*callback_token=*/py::bytes,
+ /*server_val=*/std::vector<tensorflow::Tensor>,
+ /*max_key=*/int32_t,
+ /*select_fn_initialize_op=*/std::string,
+ /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
+ /*select_fn_key_input_tensor_name=*/absl::string_view,
+ /*select_fn_filename_input_tensor_name=*/absl::string_view,
+ /*select_fn_target_tensor_name=*/absl::string_view)>;
+
+// A fcp::HostObjectRegistration wrapper allowing use as a context manager.
+class ServeSlicesCallbackRegistration {
+ public:
+ explicit ServeSlicesCallbackRegistration(ServeSlicesCallback callback)
+ : callback_(std::move(callback)) {}
+
+ py::bytes enter() {
+ registration_ = fcp::register_serve_slices_callback(
+ [this](fcp::RandomToken callback_token,
+ std::vector<tensorflow::Tensor> server_val, int32_t max_key,
+ std::string select_fn_initialize_op,
+ std::vector<std::string> select_fn_server_val_input_tensor_names,
+ absl::string_view select_fn_key_input_tensor_name,
+ absl::string_view select_fn_filename_input_tensor_name,
+ absl::string_view select_fn_target_tensor_name) {
+ // The GIL isn't normally held in the context of ServeSlicesCallbacks,
+ // which are typically invoked from the ServeSlices TensorFlow op.
+ py::gil_scoped_acquire acquire;
+ return callback_(callback_token.ToString(), std::move(server_val),
+ max_key, std::move(select_fn_initialize_op),
+ std::move(select_fn_server_val_input_tensor_names),
+ select_fn_key_input_tensor_name,
+ select_fn_filename_input_tensor_name,
+ select_fn_target_tensor_name);
+ });
+ return registration_->token().ToString();
+ }
+
+ void exit(py::object, py::object, py::object) { registration_.reset(); }
+
+ private:
+ ServeSlicesCallback callback_;
+ std::optional<fcp::HostObjectRegistration> registration_;
+};
+
+PYBIND11_MODULE(_serve_slices_op, m) {
+ py::class_<ServeSlicesCallbackRegistration>(m,
+ "ServeSlicesCallbackRegistration")
+ .def("__enter__", &ServeSlicesCallbackRegistration::enter)
+ .def("__exit__", &ServeSlicesCallbackRegistration::exit);
+
+ m.def(
+ "register_serve_slices_callback",
+ [](ServeSlicesCallback callback) {
+ return ServeSlicesCallbackRegistration(std::move(callback));
+ },
+ py::return_value_policy::move);
+}
+
+} // namespace
diff --git a/fcp/tensorflow/python/serve_slices_registry_test.py b/fcp/tensorflow/python/serve_slices_registry_test.py
new file mode 100644
index 0000000..53e824b
--- /dev/null
+++ b/fcp/tensorflow/python/serve_slices_registry_test.py
@@ -0,0 +1,76 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for serve_slices_registry."""
+
+from unittest import mock
+
+from absl.testing import absltest
+import numpy as np
+import tensorflow as tf
+
+from fcp.tensorflow import serve_slices
+from fcp.tensorflow import serve_slices as serve_slices_registry
+
+SERVER_VAL = (1, 2.0, b'foo')
+SERVER_VAL_NP_DTYPE = (np.int32, np.float32, object)
+MAX_KEY = 44
+SELECT_FN_INITIALIZE_OP = 'init_the_things'
+SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
+SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
+SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
+SELECT_FN_TARGET_TENSOR_NAME = 'goobler'
+
+
+class ServeSlicesRegistryTest(absltest.TestCase):
+
+ def test_register_serve_slices_callback(self):
+ with tf.Graph().as_default() as graph:
+ # Create a placeholder with a fixed name to allow the code running the
+ # graph to provide input.
+ callback_token = tf.compat.v1.placeholder(dtype=tf.string)
+ served_at_id = serve_slices.serve_slices(
+ callback_token=callback_token,
+ server_val=SERVER_VAL,
+ max_key=MAX_KEY,
+ select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
+ select_fn_server_val_input_tensor_names=SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
+ select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
+ select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
+ select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME)
+
+ served_at_value = 'address.at.which.data.is.served'
+ mock_callback = mock.Mock(return_value=served_at_value)
+ with serve_slices_registry.register_serve_slices_callback(
+ mock_callback) as token:
+ with tf.compat.v1.Session(graph=graph) as session:
+ served_at_out = session.run(
+ served_at_id, feed_dict={callback_token: token})
+ self.assertEqual(served_at_out, served_at_value.encode())
+ mock_callback.assert_called_once_with(
+ token,
+ [
+ np.array(v, dtype=dtype)
+ for v, dtype in zip(SERVER_VAL, SERVER_VAL_NP_DTYPE)
+ ],
+ MAX_KEY,
+ SELECT_FN_INITIALIZE_OP,
+ SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
+ SELECT_FN_KEY_INPUT_TENSOR_NAME,
+ SELECT_FN_FILENAME_TENSOR_NAME,
+ SELECT_FN_TARGET_TENSOR_NAME,
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/fcp/tensorflow/serve_slices.py b/fcp/tensorflow/serve_slices.py
new file mode 100644
index 0000000..0ab7872
--- /dev/null
+++ b/fcp/tensorflow/serve_slices.py
@@ -0,0 +1,107 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `serve_slices` operation.
+
+This wraps the generated op and ensures that necessary shared libraries
+are loaded.
+"""
+
+import tensorflow as tf
+
+from fcp.tensorflow import _serve_slices_op
+from fcp.tensorflow import gen_serve_slices_py
+
+_serve_slices_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_serve_slices_op.so'))
+
+
+def _to_tensor_list(list_of_python_values, dtype=None):
+ return [
+ tf.convert_to_tensor(subvalue, dtype=dtype)
+ for subvalue in list_of_python_values
+ ]
+
+
+def serve_slices(callback_token, server_val, max_key, select_fn_initialize_op,
+ select_fn_server_val_input_tensor_names,
+ select_fn_key_input_tensor_name,
+ select_fn_filename_input_tensor_name,
+ select_fn_target_tensor_name):
+ """Calls into a preregistered `callback_token` to serve slices of a value.
+
+ In addition to the arguments to this function, `serve_slices` requires that
+ a TensorFlow graph containing a selection function (`select_fn`) be provided
+ to the server running `serve_slices`. `serve_slices` is responsible for
+ providing the server with the names of the placeholder tensor inputs to the
+ selection function (`select_fn_X_input_tensor_names`,
+ `select_fn_key_input_tensor_name`, and `select_fn_filename_input_tensor_name`)
+ and the target tensor to evalate to ensure that the slice is written to the
+ provided filename (`select_fn_target_tensor_name`).
+
+ Args:
+ callback_token: An string ID corresponding to a callback registered with the
+ `register_serve_slices_callback` function. This function will be invoked
+ when `serve_slices` is called.
+ server_val: A list of arbitrary-typed tensors from which slices may be
+ generated using `select_fn`. These tensors must be passed into the
+ `select_fn` by writing them to the placeholder tensors named by
+ `select_fn_server_val_input_names`, which must contain exactly one tensor
+ name for each tensor in `server_val`.
+ max_key: An integer indicating the maxiumum slice index which may be
+ requested. Slice indices start at zero and may go up to `max_key`
+ (inclusive).
+ select_fn_initialize_op: An op to run before each call to `select_fn` in
+ order to reinitialize any state `select_fn` may contain.
+ select_fn_server_val_input_tensor_names: A list of names of the tensors that
+ make up the `server_val` portion of the inputs to `select_fn`. Must be the
+ same length as the number of tensors in `server_val`.
+ select_fn_key_input_tensor_name: The name of the tensor that is the `key`
+ input to `select_fn`.
+ select_fn_filename_input_tensor_name: The name of the placeholder tensor
+ that is the `filename` input to `select_fn`. The `filename` is used to
+ specify where the resulting slice should be written.
+ select_fn_target_tensor_name: The name of the `target` tensor to run which
+ will result in `select_fn`'s output being written to `filename`.
+
+ Returns:
+ A string identifier given by the underlying callback which can be used by
+ clients to access the generated slices.
+ """
+ return gen_serve_slices_py.serve_slices(
+ callback_token=tf.convert_to_tensor(callback_token, dtype=tf.string),
+ server_val=_to_tensor_list(server_val),
+ max_key=tf.convert_to_tensor(max_key, dtype=tf.int32),
+ select_fn_initialize_op=tf.convert_to_tensor(
+ select_fn_initialize_op, dtype=tf.string),
+ select_fn_server_val_input_tensor_names=_to_tensor_list(
+ select_fn_server_val_input_tensor_names, dtype=tf.string),
+ select_fn_key_input_tensor_name=tf.convert_to_tensor(
+ select_fn_key_input_tensor_name, dtype=tf.string),
+ select_fn_filename_input_tensor_name=tf.convert_to_tensor(
+ select_fn_filename_input_tensor_name, dtype=tf.string),
+ select_fn_target_tensor_name=tf.convert_to_tensor(
+ select_fn_target_tensor_name, dtype=tf.string))
+
+
+def register_serve_slices_callback(callback):
+ """Registers a callback to be invoked by the `ServeSlices` op."""
+ def callback_adapter(callback_token, server_val, *args):
+ # Convert the serialized TensorProtos to ndarrays.
+ tensor_proto = tf.make_tensor_proto(0)
+ converted_server_val = [
+ tf.make_ndarray(tensor_proto.FromString(val)) for val in server_val
+ ]
+ return callback(callback_token, converted_server_val, *args)
+
+ return _serve_slices_op.register_serve_slices_callback(callback_adapter)
diff --git a/fcp/tensorflow/serve_slices_op.cc b/fcp/tensorflow/serve_slices_op.cc
new file mode 100644
index 0000000..bd41f73
--- /dev/null
+++ b/fcp/tensorflow/serve_slices_op.cc
@@ -0,0 +1,192 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "fcp/tensorflow/serve_slices_registry.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace fcp {
+
+namespace {
+
+REGISTER_OP("ServeSlices")
+ .Attr("NumTensorsInServerVal: int")
+ .Attr("ServerValType: list(type)")
+ .Input("callback_token: string")
+ .Input("server_val: ServerValType")
+ .Input("max_key: int32")
+ .Input("select_fn_initialize_op: string")
+ .Input(
+ "select_fn_server_val_input_tensor_names: NumTensorsInServerVal * "
+ "string")
+ .Input("select_fn_key_input_tensor_name: string")
+ .Input("select_fn_filename_input_tensor_name: string")
+ .Input("select_fn_target_tensor_name: string")
+ .Output("served_at_id: string")
+ .SetIsStateful()
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+template <class T>
+tensorflow::Status get_scalar_input(tensorflow::OpKernelContext* context,
+ tensorflow::StringPiece name,
+ T* scalar_out) {
+ const tensorflow::Tensor* tensor;
+ TF_RETURN_IF_ERROR(context->input(name, &tensor));
+ *scalar_out = tensor->scalar<T>()();
+ return tensorflow::OkStatus();
+}
+
+tensorflow::Status get_arbitrary_input_list_as_tensor_vector(
+ tensorflow::OpKernelContext* context, tensorflow::StringPiece name,
+ std::vector<tensorflow::Tensor>* out) {
+ tensorflow::OpInputList input_list;
+ TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
+ out->reserve(input_list.size());
+ for (const tensorflow::Tensor& tensor : input_list) {
+ out->push_back(tensor);
+ }
+ return tensorflow::OkStatus();
+}
+
+tensorflow::Status get_string_list_input(tensorflow::OpKernelContext* context,
+ tensorflow::StringPiece name,
+ std::vector<std::string>* out) {
+ tensorflow::OpInputList input_list;
+ TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
+ out->reserve(input_list.size());
+ for (const tensorflow::Tensor& tensor : input_list) {
+ out->emplace_back(tensor.scalar<tensorflow::tstring>()());
+ }
+ return tensorflow::OkStatus();
+}
+
+// ServeSlices op-kernel.
+//
+// The ServeSlicesOp registers values present on a federated computation server
+// to be sliced and served to clients for a `federated_select`
+//
+// Inputs:
+// callback_token: The ID of the C++ callback to invoke in order to register
+// the
+// given value. Callbacks must first be registered using
+// `register_serve_slices_callback`.
+// server_val: A series of arbitrary-typed tensors from which slices may be
+// generated using a selection function (referred to as `select_fn`).
+// These tensors must be passed into the `select_fn` by writing them to the
+// placeholder tensors named by `select_fn_server_val_input_names`, which
+// must contain exactly one tensor name for each tensor in `server_val`.
+// max_key: An integer indicating the maximum slice index which may be
+// requested. Slice indices start at zero and may go up to `max_key`
+// (inclusive).
+// select_fn_initialize_op: An op to run before each call to `select_fn` in
+// order to reinitialize any state `select_fn` may contain.
+// select_fn_server_val_input_tensor_names: A list of names of the tensors
+// that make up the `server_val` portion of the inputs to `select_fn`. Must
+// be the same length as the number of tensors in `server_val`.
+// select_fn_key_input_tensor_name: The name of the tensor that is the `key`
+// input to `select_fn`.
+// select_fn_filename_input_tensor_name: The name of the placeholder tensor
+// that is the `filename` input to `select_fn`. The `filename` is used to
+// specify where the resulting slice should be written.
+// select_fn_target_tensor_name: The name of the `target` tensor to run which
+// will result in `select_fn`'s output being written to `filename`.
+//
+// Outputs:
+// served_at_id: A string ID under which the resulting slices will be served.
+// This can then be provided to the `FetchSlicesOp` running on clients.
+class ServeSlicesOp : public tensorflow::OpKernel {
+ public:
+ explicit ServeSlicesOp(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(tensorflow::OpKernelContext* context) override {
+ tensorflow::tstring callback_token_tensor;
+ OP_REQUIRES_OK(context, get_scalar_input(context, "callback_token",
+ &callback_token_tensor));
+ absl::Span<char const> callback_token_bytes = callback_token_tensor;
+ OP_REQUIRES(context, callback_token_bytes.size() == kRandomTokenSizeInBytes,
+ tensorflow::errors::InvalidArgument(absl::StrFormat(
+ "Tokens have a fixed size. Expected: %d; Actual %d",
+ kRandomTokenSizeInBytes, callback_token_bytes.size())));
+ RandomToken callback_token = RandomToken::FromBytes(callback_token_bytes);
+
+ std::vector<tensorflow::Tensor> server_val;
+ OP_REQUIRES_OK(context, get_arbitrary_input_list_as_tensor_vector(
+ context, "server_val", &server_val));
+
+ int32_t max_key;
+ OP_REQUIRES_OK(context, get_scalar_input(context, "max_key", &max_key));
+
+ tensorflow::tstring select_fn_initialize_op;
+ OP_REQUIRES_OK(context, get_scalar_input(context, "select_fn_initialize_op",
+ &select_fn_initialize_op));
+
+ std::vector<std::string> select_fn_server_val_input_tensor_names;
+ OP_REQUIRES_OK(context,
+ get_string_list_input(
+ context, "select_fn_server_val_input_tensor_names",
+ &select_fn_server_val_input_tensor_names));
+
+ tensorflow::tstring select_fn_key_input_tensor_name;
+ OP_REQUIRES_OK(context,
+ get_scalar_input(context, "select_fn_key_input_tensor_name",
+ &select_fn_key_input_tensor_name));
+
+ tensorflow::tstring select_fn_filename_input_tensor_name;
+ OP_REQUIRES_OK(context, get_scalar_input(
+ context, "select_fn_filename_input_tensor_name",
+ &select_fn_filename_input_tensor_name));
+
+ tensorflow::tstring select_fn_target_tensor_name;
+ OP_REQUIRES_OK(context,
+ get_scalar_input(context, "select_fn_target_tensor_name",
+ &select_fn_target_tensor_name));
+
+ std::optional<std::shared_ptr<ServeSlicesCallback>> callback =
+ get_serve_slices_callback(callback_token);
+ OP_REQUIRES(context, callback.has_value(),
+ tensorflow::errors::InvalidArgument(
+ absl::StrCat("No `ServeSlices` callback found for token ",
+ callback_token.ToPrintableString())));
+ std::string served_at_id =
+ (**callback)(callback_token, std::move(server_val), max_key,
+ std::move(select_fn_initialize_op),
+ std::move(select_fn_server_val_input_tensor_names),
+ std::move(select_fn_key_input_tensor_name),
+ std::move(select_fn_filename_input_tensor_name),
+ std::move(select_fn_target_tensor_name));
+
+ tensorflow::Tensor* output_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
+ output_tensor->scalar<tensorflow::tstring>()() = std::move(served_at_id);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ServeSlices").Device(tensorflow::DEVICE_CPU),
+ ServeSlicesOp);
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/tensorflow/serve_slices_op_test.cc b/fcp/tensorflow/serve_slices_op_test.cc
new file mode 100644
index 0000000..49fe647
--- /dev/null
+++ b/fcp/tensorflow/serve_slices_op_test.cc
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fcntl.h>
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "fcp/tensorflow/serve_slices_registry.h"
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/platform/status_matchers.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace fcp {
+namespace {
+
+using ::testing::_;
+using ::testing::HasSubstr;
+using ::testing::Return;
+
+using MockServeSlicesCallback = ::testing::MockFunction<std::string(
+ RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
+ std::vector<std::string>, absl::string_view, absl::string_view,
+ absl::string_view)>;
+
+// Constants related to the GraphDef we test with
+// See make_serve_slices_test_graph.py
+
+char const* const kExampleGraphPath =
+ "fcp/tensorflow/serve_slices_test.pbtxt";
+char const* const kCallbackTokenPlaceholderName = "callback_token";
+char const* const kServedAtTensorName = "served_at_id:0";
+
+// Loads the example graph created by `make_serve_slices_test_graph.py`.
+tensorflow::GraphDef LoadExampleGraph() {
+ int fd = open(kExampleGraphPath, O_RDONLY);
+ CHECK(fd != -1) << "Failed to open example graph at path "
+ << kExampleGraphPath;
+
+ google::protobuf::io::FileInputStream fs(fd);
+ fs.SetCloseOnDelete(true);
+
+ tensorflow::GraphDef graph;
+ bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
+ CHECK(parsed) << "Invalid text-format GraphDef";
+
+ return graph;
+}
+
+// Loads the example graph created by `make_serve_slices_test_graph.py` into a
+// `tensorflow::Session`.
+std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
+ tensorflow::GraphDef graph = LoadExampleGraph();
+
+ std::unique_ptr<tensorflow::Session> session;
+ {
+ tensorflow::SessionOptions options;
+ tensorflow::Session* raw_session = nullptr;
+ tensorflow::Status session_new_status =
+ tensorflow::NewSession(options, &raw_session);
+ TF_CHECK_OK(session_new_status);
+ session = std::unique_ptr<tensorflow::Session>(raw_session);
+ }
+
+ tensorflow::Status graph_build_status = session->Create(graph);
+ TF_CHECK_OK(graph_build_status);
+ return session;
+}
+
+class ServeSlicesOpTest : public ::testing::Test {
+ protected:
+ void SetUp() override { session_ = PrepareExampleGraphSession(); }
+
+ // Runs a `ServeSlices` session and returns the result.
+ //
+ // Inputs:
+ // callback_token: A `tensorflow::Tensor` to use as the `callback_token`
+ // argument to `ServeSlices`. For successful calls, this must be a
+ // `RandomToken` corresponding to the `HostObjectRegistration returned by
+ // `register_serve_slices_callback`.
+ // served_at_id_out: An output parameter into which the `served_at_id`
+ // returned from `ServeSlices` is stored.
+ //
+ // Outputs:
+ // The status of the `session.Run` invocation.
+ tensorflow::Status RunSession(tensorflow::Tensor callback_token,
+ tensorflow::Tensor& served_at_id_out) {
+ std::vector<tensorflow::Tensor> outputs;
+ tensorflow::Status run_status =
+ session_->Run({{kCallbackTokenPlaceholderName, callback_token}},
+ {kServedAtTensorName}, {}, &outputs);
+
+ if (run_status.ok()) {
+ CHECK(outputs.size() == 1)
+ << "Expected one output, found " << outputs.size();
+ served_at_id_out = outputs[0];
+ }
+
+ return run_status;
+ }
+
+ // Runs a `ServeSlices` session and returns the result.
+ //
+ // This method is similar to `RunSession`, but it expects that the run is
+ // successful and enforces that the inputs and outputs are correctly-typed.
+ //
+ // Inputs:
+ // callback_token: The `CallbackToken` of the callback to invoke from
+ // `ServeSlices`.
+ //
+ // Outputs:
+ // The `served_at_id` returned from `ServeSlices`.
+ std::string RunSessionExpectingSuccess(RandomToken callback_token) {
+ tensorflow::Tensor served_at_id_out;
+ TF_CHECK_OK(RunSession(tensorflow::Tensor(callback_token.ToString()),
+ served_at_id_out));
+ return served_at_id_out.scalar<tensorflow::tstring>()();
+ }
+
+ private:
+ std::unique_ptr<tensorflow::Session> session_;
+};
+
+TEST_F(ServeSlicesOpTest, SessionRunCallsBackIntoCPP) {
+ std::string mock_served_at_id = "mock_served_at_id";
+ MockServeSlicesCallback mock_callback;
+ HostObjectRegistration callback_registration =
+ register_serve_slices_callback(mock_callback.AsStdFunction());
+ RandomToken callback_token = callback_registration.token();
+ EXPECT_CALL(mock_callback, Call(callback_token, _, _, _, _, _, _, _))
+ .WillOnce(Return(mock_served_at_id));
+ std::string served_at_id = RunSessionExpectingSuccess(callback_token);
+ EXPECT_EQ(served_at_id, "mock_served_at_id");
+}
+
+TEST_F(ServeSlicesOpTest, SessionRunFailsOnMissingCallback) {
+ std::optional<RandomToken> callback_token;
+ {
+ MockServeSlicesCallback mock_callback;
+ HostObjectRegistration callback_registration =
+ register_serve_slices_callback(mock_callback.AsStdFunction());
+ callback_token = callback_registration.token();
+ // The registration gets destructed here.
+ }
+ tensorflow::Tensor callback_token_tensor(callback_token->ToString());
+ tensorflow::Tensor served_at_id_out;
+ tensorflow::Status status =
+ RunSession(callback_token_tensor, served_at_id_out);
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ EXPECT_THAT(
+ status,
+ tensorflow::testing::StatusIs(
+ static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument),
+ HasSubstr("No `ServeSlices` callback found")));
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tensorflow/serve_slices_registry.h b/fcp/tensorflow/serve_slices_registry.h
new file mode 100644
index 0000000..249b950
--- /dev/null
+++ b/fcp/tensorflow/serve_slices_registry.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
+#define FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "fcp/tensorflow/host_object.h"
+
+// Forward declare Tensor to avoid an explicit dependency on the TensorFlow
+// framework. Dependencies of custom ops (which this target is) are not able to
+// depend on the full TensorFlow framework.
+namespace tensorflow {
+
+class Tensor;
+
+} // namespace tensorflow
+
+namespace fcp {
+
+// A callback to invoke when the `ServeSlices` custom op is called.
+//
+// Callbacks are responsible for ensuring that the provided `server_val` is
+// sliced up using the provided selection function (`select_fn`) and that the
+// resulting slices are made available to clients.
+//
+// May be invoked from other threads by the TensorFlow runtime.
+//
+// Inputs:
+// callback_token: The random token associated with this callback by the
+// `HostObjectRegistration` returned by
+// `register_serve_slices_callback(...)`.
+// server_val: A series of arbitrary-typed tensors from which slices may be
+// generated using a selection function (referred to as `select_fn`).
+// These tensors must be passed into the `select_fn` by writing them to the
+// placeholder tensors named by `select_fn_server_val_input_names`, which
+// must contain exactly one tensor name for each tensor in `server_val`.
+// max_key: An integer indicating the maximum slice index which may be
+// requested. Slice indices start at zero and may go up to `max_key`
+// (inclusive).
+// select_fn_initialize_op: An op to run before each call to `select_fn` in
+// order to reinitialize any state `select_fn` may contain.
+// select_fn_server_val_input_tensor_names: A list of names of the tensors
+// that make up the `server_val` portion of the inputs to `select_fn`. Must
+// be the same length as the number of tensors in `server_val`.
+// select_fn_key_input_tensor_name: The name of the tensor that is the `key`
+// input to `select_fn`.
+// select_fn_filename_input_tensor_name: The name of the placeholder tensor
+// that is the `filename` input to `select_fn`. The `filename` is used to
+// specify where the resulting slice should be written.
+// select_fn_target_tensor_name: The name of the `target` tensor to run which
+// will result in `select_fn`'s output being written to `filename`.
+//
+// Outputs:
+// served_at_id: A string ID under which the resulting slices will be served.
+// This can then be provided to the `FetchSlicesOp` running on clients.
+using ServeSlicesCallback = std::function<std::string(
+ /*callback_token=*/RandomToken,
+ /*server_val=*/std::vector<tensorflow::Tensor>,
+ /*max_key=*/int32_t,
+ /*select_fn_initialize_op=*/std::string,
+ /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
+ /*select_fn_key_input_tensor_name=*/absl::string_view,
+ /*select_fn_filename_input_tensor_name=*/absl::string_view,
+ /*select_fn_target_tensor_name=*/absl::string_view)>;
+
+// Registers a callback to be invoked by the `ServeSlices` op.
+//
+// Inputs:
+// callback: The callback to register.
+//
+// Outputs:
+// A `HostObjectRegistration` value which owns the association of the callback
+// with the global callback registry. When this object is destroyed, the
+// callback will be unregistered. To refer to this callback in other methods,
+// use the `token()` method on this object.
+inline HostObjectRegistration register_serve_slices_callback(
+ ServeSlicesCallback callback) {
+ return HostObjectRegistry<ServeSlicesCallback>::Register(
+ std::make_shared<ServeSlicesCallback>(std::move(callback)));
+}
+
+// Returns the callback registered with the given `token` if one exists.
+inline std::optional<std::shared_ptr<ServeSlicesCallback>>
+get_serve_slices_callback(RandomToken token) {
+ return HostObjectRegistry<ServeSlicesCallback>::TryLookup(token);
+}
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
diff --git a/fcp/tensorflow/serve_slices_registry_test.cc b/fcp/tensorflow/serve_slices_registry_test.cc
new file mode 100644
index 0000000..80ef1df
--- /dev/null
+++ b/fcp/tensorflow/serve_slices_registry_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tensorflow/serve_slices_registry.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/random_token.h"
+#include "fcp/tensorflow/host_object.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace fcp {
+namespace {
+
+using ::testing::_;
+using ::testing::Return;
+
+using MockServeSlicesCallback = ::testing::MockFunction<std::string(
+ RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
+ std::vector<std::string>, absl::string_view, absl::string_view,
+ absl::string_view)>;
+
+TEST(ServeSlicesRegistryTest, CanRegisterGetAndUnregisterCallback) {
+ MockServeSlicesCallback mock_callback;
+ std::optional<RandomToken> id = std::nullopt;
+ {
+ HostObjectRegistration registration =
+ register_serve_slices_callback(mock_callback.AsStdFunction());
+ id = registration.token();
+ std::optional<std::shared_ptr<ServeSlicesCallback>> returned_callback =
+ get_serve_slices_callback(*id);
+ ASSERT_TRUE(returned_callback.has_value());
+
+ std::string mock_served_at_id = "served_at_id";
+ EXPECT_CALL(mock_callback, Call(*id, _, _, _, _, _, _, _))
+ .WillOnce(Return(mock_served_at_id));
+ EXPECT_EQ(mock_served_at_id,
+ (**returned_callback)(*id, {}, 0, "", {}, "", "", ""));
+ }
+ // Check that it is gone after `registration` has been destroyed.
+ EXPECT_EQ(std::nullopt, get_serve_slices_callback(*id));
+}
+
+TEST(ServeSlicesRegistryTest, CanRegisterMultipleDifferentCallbacks) {
+ constexpr int8_t num_callbacks = 5;
+ MockServeSlicesCallback mock_callbacks[num_callbacks];
+ std::vector<HostObjectRegistration> callback_tokens;
+ // Register all callbacks.
+ for (int8_t i = 0; i < num_callbacks; i++) {
+ callback_tokens.push_back(
+ register_serve_slices_callback(mock_callbacks[i].AsStdFunction()));
+ }
+ // Get and invoke all callbacks.
+ for (int8_t i = 0; i < num_callbacks; i++) {
+ RandomToken id = callback_tokens[i].token();
+ std::optional<std::shared_ptr<ServeSlicesCallback>> returned_callback =
+ get_serve_slices_callback(id);
+ ASSERT_TRUE(returned_callback.has_value());
+
+ std::string mock_served_at_id = absl::StrCat("served_at_id_", i);
+ EXPECT_CALL(mock_callbacks[i], Call(id, _, _, _, _, _, _, _))
+ .WillOnce(Return(mock_served_at_id));
+ EXPECT_EQ(mock_served_at_id,
+ (**returned_callback)(id, {}, 0, "", {}, "", "", ""));
+ }
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tensorflow/status.cc b/fcp/tensorflow/status.cc
new file mode 100644
index 0000000..a4fe54f
--- /dev/null
+++ b/fcp/tensorflow/status.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/status.h"
+
+#include "tensorflow/core/public/version.h"
+
+namespace fcp {
+
+tensorflow::Status ConvertToTensorFlowStatus(Status const& status) {
+ absl::StatusCode code = status.code();
+ if (code == absl::StatusCode::kOk) {
+ return tensorflow::Status();
+ } else {
+ // tensorflow::Status constructor asserts that code != OK if a message is
+ // provided.
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ return tensorflow::Status(static_cast<tensorflow::errors::Code>(code),
+ status.message());
+ }
+}
+
+Status ConvertFromTensorFlowStatus(tensorflow::Status const& tf_status) {
+ return Status(static_cast<absl::StatusCode>(tf_status.code()),
+#if TF_GRAPH_DEF_VERSION < 1467
+ tf_status.error_message());
+#else
+ tf_status.message());
+#endif
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/status.h b/fcp/tensorflow/status.h
new file mode 100644
index 0000000..b2efb10
--- /dev/null
+++ b/fcp/tensorflow/status.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_STATUS_H_
+#define FCP_TENSORFLOW_STATUS_H_
+
+#include "fcp/base/monitoring.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace fcp {
+
+/**
+ * Converts an FCP Status to a tensorflow::Status.
+ */
+tensorflow::Status ConvertToTensorFlowStatus(Status const&);
+
+/**
+ * Converts a tensorflow::Status to an FCP Status.
+ */
+Status ConvertFromTensorFlowStatus(tensorflow::Status const&);
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_STATUS_H_
diff --git a/fcp/tensorflow/status_test.cc b/fcp/tensorflow/status_test.cc
new file mode 100644
index 0000000..ad161cd
--- /dev/null
+++ b/fcp/tensorflow/status_test.cc
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/status.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace fcp {
+
+using ::testing::Eq;
+
+TEST(StatusTest, ToTensorFlow_Ok) {
+ EXPECT_THAT(ConvertToTensorFlowStatus(FCP_STATUS(OK)),
+ Eq(tensorflow::OkStatus()));
+}
+
+TEST(StatusTest, ToTensorFlow_Error) {
+ Status error = FCP_STATUS(NOT_FOUND) << "Where is my mind?";
+ EXPECT_THAT(ConvertToTensorFlowStatus(error),
+ // Remove the cast after TF 2.12 is released and used in FCP.
+ Eq(tensorflow::Status(
+ static_cast<tsl::errors::Code>(absl::StatusCode::kNotFound),
+ error.message())));
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/system_provided_tf/BUILD b/fcp/tensorflow/system_provided_tf/BUILD
new file mode 100644
index 0000000..82bab3f
--- /dev/null
+++ b/fcp/tensorflow/system_provided_tf/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/fcp/tensorflow/system_provided_tf/README.md b/fcp/tensorflow/system_provided_tf/README.md
new file mode 100644
index 0000000..939efbe
--- /dev/null
+++ b/fcp/tensorflow/system_provided_tf/README.md
@@ -0,0 +1,20 @@
+
+When building libraries (such as custom op libraries) agains the TensorFlow pip
+package, care must be taken to ensure those libraries build against that
+package's headers and with the same compiler and linker flags as that package
+was compiled with. These utilities help ensure that's the case.
+
+First, add the following to your `WORKSPACE` file to configure a repository that
+provides the C++ headers and libraries provided by the TensorFlow pip package.
+
+```
+load("//fcp/tensorflow/system_provided_tf:system_provided_tf.bzl", "system_provided_tf")
+system_provided_tf(name = "system_provided_tf")
+```
+
+Then simply load `tf_custom_op_library` from
+`@system_provided_tf//:system_provided_tf.bzl` instead of
+`@org_tensorflow//tensorflow:tensorflow.bzl`.
+
+NOTE: The `gpu_srcs` and `gpu_deps` parameters supported by TensorFlow's version
+of `tf_custom_op_library` are not supported by this version.
diff --git a/fcp/tensorflow/system_provided_tf/system_provided_tf.bzl b/fcp/tensorflow/system_provided_tf/system_provided_tf.bzl
new file mode 100644
index 0000000..b3f77d7
--- /dev/null
+++ b/fcp/tensorflow/system_provided_tf/system_provided_tf.bzl
@@ -0,0 +1,161 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Sets up a repository for using the system-provided TensorFlow."""
+
+def _process_compile_flags(repository_ctx, python3, headers_dir):
+ """Processes compilation flags required by the system-provided TF package.
+
+ The tf.sysconfig module provides the compilation flags that should be used
+ for custom operators. These will include the directory containing the
+ TensorFlow C++ headers ("-I/some/path") and possibly other flags (e.g.,
+ "-DSOME_FLAG=2").
+
+ A symlink is created from `headers_dir` to the directory containing the C++
+ headers. The list of other flags is returned.
+ """
+ result = repository_ctx.execute([
+ python3,
+ "-c",
+ ";".join([
+ "import tensorflow as tf",
+ "print('\\0'.join(tf.sysconfig.get_compile_flags()))",
+ ]),
+ ])
+ if result.return_code != 0:
+ fail("Failed to determine TensorFlow compile flags; is TensorFlow installed?")
+ include_dir = None
+ copts = []
+ cxxopts = []
+ for flag in result.stdout.strip().split("\0"):
+ if flag.startswith("-I"):
+ if include_dir != None:
+ fail("Only one TensorFlow headers directory is supported.")
+ include_dir = flag[2:]
+ elif flag.startswith("--std=c++"): # Don't add C++-only flags to copts.
+ cxxopts.append(flag)
+ else:
+ copts.append(flag)
+
+ if not include_dir:
+ fail("Unable to find TensorFlow headers directory.")
+ repository_ctx.symlink(include_dir, headers_dir)
+
+ return copts, cxxopts
+
+def _process_link_flags(repository_ctx, python3, library_file):
+ """Processes linker flags required by the system-provided TF package.
+
+ The tf.sysconfig module provides the linker flags that should be used
+ for custom operators. These will include the directory containing
+ libtensorflow_framework.so ("-L/some/path"), the library to link
+ ("-l:libtensorflow_framework.so.2"), and possibly other flags.
+
+ A symlink is created from `library_file` to libtensorflow_framework.so. The
+ list of other flags is returned.
+ """
+ result = repository_ctx.execute([
+ python3,
+ "-c",
+ ";".join([
+ "import tensorflow as tf",
+ "print('\\0'.join(tf.sysconfig.get_link_flags()))",
+ ]),
+ ])
+ if result.return_code != 0:
+ fail("Failed to determine TensorFlow link flags; is TensorFlow installed?")
+ link_dir = None
+ library = None
+ linkopts = []
+ for flag in result.stdout.strip().split("\0"):
+ if flag.startswith("-L"):
+ if link_dir != None:
+ fail("Only one TensorFlow libraries directory is supported.")
+ link_dir = flag[2:]
+ elif flag.startswith("-l"):
+ if library != None:
+ fail("Only one TensorFlow library is supported.")
+
+ # "-l" may be followed by ":" to force the linker to use exact
+ # library name resolution.
+ library = flag[2:].lstrip(":")
+ else:
+ linkopts.append(flag)
+
+ if not link_dir or not library:
+ fail("Unable to find TensorFlow library.")
+ repository_ctx.symlink(link_dir + "/" + library, library_file)
+
+ return linkopts
+
+def _tf_custom_op_configure_impl(repository_ctx):
+ """Defines a repository for using the system-provided TensorFlow package.
+
+ This is a lot like new_local_repository except that (a) the files to
+ include are dynamically determined using TensorFlow's `tf.sysconfig` Python
+ module, and (b) it provides build rules to compile and link C++ code with
+ the necessary options to be compatible with the system-provided TensorFlow
+ package.
+ """
+ python3 = repository_ctx.os.environ.get("PYTHON_BIN_PATH", "python3")
+
+ # Name of the sub-directory that will link to TensorFlow C++ headers.
+ headers_dir = "headers"
+
+ # Name of the file that will link to libtensorflow_framework.so.
+ library_file = "libtensorflow_framework.so"
+
+ copts, cxxopts = _process_compile_flags(repository_ctx, python3, headers_dir)
+ linkopts = _process_link_flags(repository_ctx, python3, library_file)
+
+ # Create a BUILD file providing targets for the TensorFlow C++ headers and
+ # framework library.
+ repository_ctx.template(
+ "BUILD",
+ Label("//fcp/tensorflow/system_provided_tf:templates/BUILD.tpl"),
+ substitutions = {
+ "%{HEADERS_DIR}": headers_dir,
+ "%{LIBRARY_FILE}": library_file,
+ },
+ executable = False,
+ )
+
+ # Create a bzl file providing rules for compiling C++ code compatible with
+ # the TensorFlow package.
+ repository_ctx.template(
+ "system_provided_tf.bzl",
+ Label("//fcp/tensorflow/system_provided_tf:templates/system_provided_tf.bzl.tpl"),
+ substitutions = {
+ "%{COPTS}": str(copts),
+ "%{CXXOPTS}": str(cxxopts),
+ "%{LINKOPTS}": str(linkopts),
+ "%{REPOSITORY_NAME}": repository_ctx.name,
+ },
+ executable = False,
+ )
+
+system_provided_tf = repository_rule(
+ implementation = _tf_custom_op_configure_impl,
+ configure = True,
+ doc = """Creates a repository with targets for the system-provided TensorFlow.
+
+This repository defines (a) //:tf_headers providing the C++ TensorFlow headers,
+(b) //:libtensorflow_framework providing the TensorFlow framework shared
+library, and (c) //:system_provided_tf.bzl for building custom op libraries
+that are compatible with the system-provided TensorFlow package.
+""",
+ environ = [
+ "PYTHON_BIN_PATH",
+ ],
+ local = True,
+)
diff --git a/fcp/tensorflow/system_provided_tf/templates/BUILD.tpl b/fcp/tensorflow/system_provided_tf/templates/BUILD.tpl
new file mode 100644
index 0000000..31385fb
--- /dev/null
+++ b/fcp/tensorflow/system_provided_tf/templates/BUILD.tpl
@@ -0,0 +1,36 @@
+load("@bazel_skylib//rules:common_settings.bzl", "bool_setting")
+
+package(default_visibility = ["//visibility:public"])
+
+# Config setting to use in select()'s to distinguish building for the
+# system-provided TensorFlow package.
+config_setting(
+ name = "system_provided_tf_build",
+ flag_values = {":system_provided_tf_build_setting": "True"},
+)
+
+# Non-configurable build setting to indicate building using the system-provided
+# TensorFlow package.
+bool_setting(
+ name = "system_provided_tf_build_setting",
+ build_setting_default = False,
+ visibility = ["//visibility:private"],
+)
+
+# Internal config setting to distinguish clang from other compilers. This target
+# should not be used directly.
+config_setting(
+ name = "clang_compiler",
+ flag_values = {"@bazel_tools//tools/cpp:compiler": "clang"},
+)
+
+cc_library(
+ name = "tf_headers",
+ hdrs = glob(["%{HEADERS_DIR}/**"]),
+ includes = ["%{HEADERS_DIR}"]
+)
+
+cc_library(
+ name = "libtensorflow_framework",
+ srcs = ["%{LIBRARY_FILE}"],
+)
diff --git a/fcp/tensorflow/system_provided_tf/templates/system_provided_tf.bzl.tpl b/fcp/tensorflow/system_provided_tf/templates/system_provided_tf.bzl.tpl
new file mode 100644
index 0000000..1024f26
--- /dev/null
+++ b/fcp/tensorflow/system_provided_tf/templates/system_provided_tf.bzl.tpl
@@ -0,0 +1,125 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides rules for building custom TensorFlow ops compatible with pip."""
+
+def _force_system_provided_tf_transition_impl(settings, attr):
+ copts = list(settings["//command_line_option:copt"])
+ cxxopts = list(settings["//command_line_option:cxxopt"])
+ linkopts = list(settings["//command_line_option:linkopt"])
+ copts += %{COPTS}
+ cxxopts += %{CXXOPTS}
+ linkopts += %{LINKOPTS}
+ # TensorFlow's pip package was built with libstdc++; ensure the right
+ # standard library is used if the compiler supports multiple.
+ if attr.compiler_supports_stdlib:
+ cxxopts += ["-stdlib=libstdc++"]
+ linkopts += ["-stdlib=libstdc++"]
+ return {
+ "//:system_provided_tf_build_setting": True,
+ "//command_line_option:copt": copts,
+ "//command_line_option:cxxopt": cxxopts,
+ "//command_line_option:linkopt": linkopts,
+ }
+
+_force_system_provided_tf_transition = transition(
+ implementation = _force_system_provided_tf_transition_impl,
+ inputs = [
+ "//command_line_option:copt",
+ "//command_line_option:cxxopt",
+ "//command_line_option:linkopt",
+ ],
+ outputs = [
+ "//:system_provided_tf_build_setting",
+ "//command_line_option:copt",
+ "//command_line_option:cxxopt",
+ "//command_line_option:linkopt",
+ ],
+)
+
+def _force_system_provided_tf_impl(ctx):
+ cc_binary = ctx.attr.cc_binary[0]
+ output_file = ctx.actions.declare_file(ctx.label.name)
+ ctx.actions.symlink(
+ output = output_file,
+ target_file = cc_binary.files.to_list()[0],
+ )
+ return DefaultInfo(
+ files = depset([output_file]),
+ data_runfiles = ctx.runfiles(transitive_files = depset([output_file])),
+ )
+
+_force_system_provided_tf = rule(
+ doc = """Forces a shared library to be built in a way that's compatible
+with the system-provided Python TensorFlow package.""",
+ implementation = _force_system_provided_tf_impl,
+ attrs = {
+ "cc_binary": attr.label(
+ cfg = _force_system_provided_tf_transition,
+ mandatory = True,
+ doc = "The cc_binary target to build with TensorFlow compatibility.",
+ ),
+ # Compiler information cannot be read by the transition directly because
+ # the StarlarkAttributeTransitionProvider doesn't yet support providers
+ # for dependency-typed attributes.
+ "compiler_supports_stdlib": attr.bool(
+ doc = "Whether the compiler supports the --stdlib flag.",
+ ),
+ "_allowlist_function_transition": attr.label(
+ default = "@bazel_tools//tools/allowlists/function_transition_allowlist",
+ ),
+ },
+)
+
+def tf_custom_op_library(
+ name,
+ srcs = [],
+ deps = [],
+ tags = [],
+ visibility = None,
+ **kwargs):
+ """Helper to build a dynamic library (.so) from the sources containing
+ implementations of custom ops and kernels.
+
+ This rule will force a transition to an environment that targets the
+ system-provided TF library. This means that all deps of this target and the
+ target's own sources will be compiled with the necessary compiler flags to
+ correctly target a system-provided TF library.
+
+ The `@system_provided_tf//:system_provided_tf_build` setting will also be
+ true when those deps are built for this target.
+ """
+
+ native.cc_binary(
+ name = name + "_lib",
+ srcs = srcs,
+ linkshared = 1,
+ deps = deps + [
+ "@%{REPOSITORY_NAME}//:libtensorflow_framework",
+ "@%{REPOSITORY_NAME}//:tf_headers",
+ ],
+ tags = tags + ["manual"],
+ visibility = ["//visibility:private"],
+ **kwargs
+ )
+
+ _force_system_provided_tf(
+ name = name,
+ cc_binary = name + "_lib",
+ compiler_supports_stdlib = select({
+ "@%{REPOSITORY_NAME}//:clang_compiler": True,
+ "//conditions:default": False,
+ }),
+ visibility = visibility,
+ tags = tags,
+ )
diff --git a/fcp/tensorflow/task_eligibility_info_ops.cc b/fcp/tensorflow/task_eligibility_info_ops.cc
new file mode 100644
index 0000000..3ca968c
--- /dev/null
+++ b/fcp/tensorflow/task_eligibility_info_ops.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+
+#include "fcp/protos/federated_api.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/version.h"
+
+namespace fcp {
+
+using ::google::internal::federatedml::v2::TaskEligibilityInfo;
+using ::google::internal::federatedml::v2::TaskWeight;
+
+/**
+ * CreateTaskEligibilityInfo op-kernel. Converts a set of input tensors into a
+ * `TaskEligibilityInfo` proto serialized into a string tensor.
+ *
+ * This op is used to generate `TaskEligibilityInfo` protos from a model at
+ * runtime, since TF Mobile does not support the standard TensorFlow ops for
+ * encoding/decoding protos.
+ */
+class CreateTaskEligibilityInfoOp : public tensorflow::OpKernel {
+ public:
+ explicit CreateTaskEligibilityInfoOp(
+ tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(tensorflow::OpKernelContext* ctx) override {
+ // Note: We use the tensorflow::data::ParseScalar/VectorArgument helpers
+ // here, even though this op isn't strictly related to our tf.Dataset
+ // integration. The helpers are public though, and we already use them in
+ // our ExternalDataset implementation, so we might as well use them here
+ // too.
+
+ // Parse/validate the input arguments.
+ tensorflow::int64 version;
+ OP_REQUIRES_OK(
+ ctx, tensorflow::data::ParseScalarArgument(ctx, "version", &version));
+ std::vector<tensorflow::tstring> task_names;
+ OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(ctx, "task_names",
+ &task_names));
+ std::vector<float> task_weights;
+ OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(
+ ctx, "task_weights", &task_weights));
+ OP_REQUIRES(ctx, task_names.size() == task_weights.size(),
+ tensorflow::errors::InvalidArgument(absl::StrCat(
+ "task_names length must match task_weights length: ",
+ task_names.size(), " vs. ", task_weights.size())));
+
+ // Create the output proto, based on the inputs.
+ TaskEligibilityInfo eligibility_info;
+ eligibility_info.set_version(version);
+ // Create a `TaskWeight` message for each pair of `task_names` and
+ // `task_weights` elements.
+ auto task_weight_it = task_weights.cbegin();
+ for (const tensorflow::tstring& task_name : task_names) {
+ float task_weight = *task_weight_it++;
+ TaskWeight* task_weight_proto = eligibility_info.add_task_weights();
+ task_weight_proto->set_task_name(std::string(task_name));
+ task_weight_proto->set_weight(task_weight);
+ }
+
+ // Place the serialized output proto into the output tensor.
+ tensorflow::Tensor* output_tensor;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output("output", tensorflow::TensorShape({}),
+ &output_tensor));
+ output_tensor->scalar<tensorflow::tstring>()() =
+ eligibility_info.SerializeAsString();
+ }
+};
+
+REGISTER_OP("CreateTaskEligibilityInfo")
+ .Input("version: int64")
+ .Input("task_names: string")
+ .Input("task_weights: float32")
+ .Output("output: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+REGISTER_KERNEL_BUILDER(
+ Name("CreateTaskEligibilityInfo").Device(tensorflow::DEVICE_CPU),
+ CreateTaskEligibilityInfoOp);
+
+} // namespace fcp
diff --git a/fcp/tensorflow/task_eligibility_info_ops.py b/fcp/tensorflow/task_eligibility_info_ops.py
new file mode 100644
index 0000000..298cf7b
--- /dev/null
+++ b/fcp/tensorflow/task_eligibility_info_ops.py
@@ -0,0 +1,58 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ops for creating TaskEligibilityInfo results."""
+
+import tensorflow as tf
+
+# Ops implemented in C++
+from fcp.tensorflow import gen_task_eligibility_info_ops
+
+_task_eligibility_info_ops_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile(
+ "./_task_eligibility_info_ops.so"))
+
+
+def create_task_eligibility_info(version, task_names, task_weights):
+ """Outputs a serialized `TaskEligibilityInfo` proto based on the given inputs.
+
+ This op is used to generate `TaskEligibilityInfo` protos from a model at
+ runtime, since TF Mobile does not support the standard TensorFlow ops for
+ encoding/decoding protos.
+
+ See the `TaskEligibilityInfo` and `TaskWeight` proto message documentation for
+ more information.
+
+ Args:
+ version: an int64 value to place in the `TaskEligibilityInfo.version` field.
+ task_names: a rank-1 string tensor containing the task names to assign
+ weights to. Each entry in this tensor will be combined with the
+ corresponding entry into the `task_weights` tensor at the same index, to
+ form a `TaskWeight` message.
+ task_weights: a rank-1 float tensor containing the task weight for each task
+ (see `task_names`). Note: this tensor must have the same number of
+ elements as `task_names`.
+
+ Returns:
+ a string tensor containing the serialized proto.
+ """
+ # Convert the inputs to tensors, as a convenience to callers. This ensures
+ # that they can easily pass regular Python or numpy types in addition to
+ # actual tensors.
+ version = tf.convert_to_tensor(version, dtype=tf.int64, name="version")
+ task_names = tf.convert_to_tensor(
+ task_names, dtype=tf.string, name="task_names")
+ task_weights = tf.convert_to_tensor(
+ task_weights, dtype=tf.float32, name="task_weights")
+ return gen_task_eligibility_info_ops.create_task_eligibility_info(
+ version=version, task_names=task_names, task_weights=task_weights)
diff --git a/fcp/tensorflow/task_eligibility_info_ops_test.py b/fcp/tensorflow/task_eligibility_info_ops_test.py
new file mode 100644
index 0000000..6304273
--- /dev/null
+++ b/fcp/tensorflow/task_eligibility_info_ops_test.py
@@ -0,0 +1,96 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from fcp.protos import federated_api_pb2
+from fcp.tensorflow.task_eligibility_info_ops import create_task_eligibility_info
+
+
+class TaskEligibilityInfoOpsTest(tf.test.TestCase):
+
+ def test_create_task_eligibility_info_succeeds(self):
+ # Run the op and parse its result into the expected proto type.
+ actual_serialized_value = create_task_eligibility_info(
+ version=555,
+ task_names=['foo_task', 'bar_task'],
+ task_weights=[123.456, 789.012])
+ tf.debugging.assert_scalar(actual_serialized_value)
+ tf.debugging.assert_type(actual_serialized_value, tf.string)
+
+ actual_value = federated_api_pb2.TaskEligibilityInfo()
+ # Note: the .numpy() call converts the string tensor to a Python string we
+ # can parse the proto from.
+ actual_value.ParseFromString(actual_serialized_value.numpy())
+
+ # Ensure the resulting proto contains the expected data.
+ expected_value = federated_api_pb2.TaskEligibilityInfo(
+ version=555,
+ task_weights=[
+ federated_api_pb2.TaskWeight(task_name='foo_task', weight=123.456),
+ federated_api_pb2.TaskWeight(task_name='bar_task', weight=789.012)
+ ])
+ assert actual_value == expected_value
+
+ def test_create_task_eligibility_info_empty_task_list_succeeds(self):
+ """Tests that an empty `task_names` input is allowed & handled correctly."""
+ actual_serialized_value = create_task_eligibility_info(
+ version=555, task_names=[], task_weights=[])
+ actual_value = federated_api_pb2.TaskEligibilityInfo()
+ actual_value.ParseFromString(actual_serialized_value.numpy())
+
+ # Ensure the resulting proto contains the expected data.
+ expected_value = federated_api_pb2.TaskEligibilityInfo(version=555)
+ assert actual_value == expected_value
+
+ def test_create_task_eligibility_info_non_scalar_version_raises_error(self):
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ create_task_eligibility_info(
+ version=[555], task_names=['foo_task'], task_weights=[123.456])
+
+ def test_create_task_eligibility_info_non_vector_task_names_list_raises_error(
+ self):
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ create_task_eligibility_info(
+ version=555, task_names=[['foo_task']], task_weights=[123.456])
+
+ def test_create_task_eligibility_info_non_vector_task_weights_list_raises_error(
+ self):
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ create_task_eligibility_info(
+ version=555, task_names=['foo_task'], task_weights=[[123.456]])
+
+ def test_create_task_eligibility_info_differing_names_weights_length_raises_error(
+ self):
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ create_task_eligibility_info(
+ version=555, task_names=['foo_task', 'bar_task'], task_weights=[123])
+
+ def test_create_task_eligibility_info_invalid_task_names_type_raises_error(
+ self):
+ with self.assertRaises(TypeError):
+ create_task_eligibility_info(
+ version=555, task_names=[111, 222], task_weights=[123.456, 789.012])
+
+ def test_create_task_eligibility_info_invalid_task_weights_type_raises_error(
+ self):
+ with self.assertRaises(TypeError):
+ create_task_eligibility_info(
+ version=555,
+ task_names=['foo_task', 'bar_task'],
+ task_weights=['hello', 'world'])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/tensor_crc32.cc b/fcp/tensorflow/tensor_crc32.cc
new file mode 100644
index 0000000..b781e7d
--- /dev/null
+++ b/fcp/tensorflow/tensor_crc32.cc
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "fcp/tensorflow/tensor_crc32.h"
+
+namespace fcp {
+namespace tensorflow {
+namespace checksums {
+
+using ::tensorflow::StringPiece;
+using ::tensorflow::Tensor;
+using ::tensorflow::crc32c::Value;
+
+uint32_t TensorToCRC32(const Tensor& tensor) {
+ StringPiece tensor_data = tensor.tensor_data();
+ return Value(tensor_data.data(), tensor_data.size());
+}
+
+} // namespace checksums
+} // namespace tensorflow
+} // namespace fcp
diff --git a/fcp/tensorflow/tensor_crc32.h b/fcp/tensorflow/tensor_crc32.h
new file mode 100644
index 0000000..5a5fced
--- /dev/null
+++ b/fcp/tensorflow/tensor_crc32.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_TENSORFLOW_TENSOR_CRC32_H_
+#define FCP_TENSORFLOW_TENSOR_CRC32_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+namespace fcp {
+namespace tensorflow {
+namespace checksums {
+
+/* Computes the CRC32c checksum of the in-memory representation of a Tensor. */
+uint32_t TensorToCRC32(const ::tensorflow::Tensor& tensor);
+
+} // namespace checksums
+} // namespace tensorflow
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_TENSOR_CRC32_H_
diff --git a/fcp/tensorflow/tensor_name.py b/fcp/tensorflow/tensor_name.py
new file mode 100644
index 0000000..73262f6
--- /dev/null
+++ b/fcp/tensorflow/tensor_name.py
@@ -0,0 +1,33 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides the `tensor_name` operation.
+
+This wraps the generated op and ensures that necessary shared libraries
+are loaded.
+"""
+
+import tensorflow as tf
+
+from fcp.tensorflow import gen_tensor_name_py
+
+_tensor_name_so = tf.load_op_library(
+ tf.compat.v1.resource_loader.get_path_to_datafile('./_tensor_name_op.so'))
+
+
+def tensor_name(tensor):
+ """Returns the final graph name of a tensor as a string tensor."""
+ if not tf.is_tensor(tensor):
+ raise TypeError('`tensor_name` expected a tensor, found object of type '
+ f'{type(tensor)}.')
+ return gen_tensor_name_py.tensor_name(input_tensor=tensor)
diff --git a/fcp/tensorflow/tensor_name_op.cc b/fcp/tensorflow/tensor_name_op.cc
new file mode 100644
index 0000000..b27295d
--- /dev/null
+++ b/fcp/tensorflow/tensor_name_op.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace fcp {
+
+namespace {
+
+REGISTER_OP("TensorName")
+ .Attr("InputType: type")
+ .Input("input_tensor: InputType")
+ .Output("tensor_name: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+class TensorNameOp : public tensorflow::OpKernel {
+ public:
+ explicit TensorNameOp(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {
+ const tensorflow::NodeDef& def = context->def();
+ // Note: more than one input is allowed since the "true" input node may be
+ // followed by any number of control inputs.
+ OP_REQUIRES(
+ context, def.input_size() >= 1,
+ tensorflow::errors::InvalidArgument("Expected an input, found none."));
+ input_name_ = def.input(0);
+ }
+
+ void Compute(tensorflow::OpKernelContext* context) override {
+ tensorflow::Tensor* output_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
+ output_tensor->scalar<tensorflow::tstring>()() = input_name_;
+ }
+
+ private:
+ tensorflow::tstring input_name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TensorName").Device(tensorflow::DEVICE_CPU),
+ TensorNameOp);
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/tensorflow/tensor_name_test.py b/fcp/tensorflow/tensor_name_test.py
new file mode 100644
index 0000000..68ee59d
--- /dev/null
+++ b/fcp/tensorflow/tensor_name_test.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for `tensor_name` custom op."""
+
+import tensorflow as tf
+
+from fcp.tensorflow import tensor_name
+
+
+class TensorNameTest(tf.test.TestCase):
+
+ def test_returns_simple_name(self):
+ test_name = b'placeholder_test_name'
+ with tf.Graph().as_default() as graph:
+ placeholder = tf.compat.v1.placeholder_with_default(
+ input='default_value', shape=(), name=test_name)
+ tensor_name_out = tensor_name.tensor_name(placeholder)
+ with tf.compat.v1.Session(graph=graph) as sess:
+ result = sess.run(tensor_name_out)
+ self.assertEqual(test_name, result)
+
+ def test_returns_modified_name_after_reimport(self):
+ test_name = b'placeholder_test_name'
+ with tf.Graph().as_default() as inner_graph:
+ placeholder = tf.compat.v1.placeholder_with_default(
+ input='default_value', shape=(), name=test_name)
+ inner_tensor_name_out = tensor_name.tensor_name(placeholder)
+ import_prefix = b'import_prefix_'
+ with tf.Graph().as_default() as outer_graph:
+ tensor_name_out = tf.graph_util.import_graph_def(
+ graph_def=inner_graph.as_graph_def(),
+ input_map={},
+ return_elements=[inner_tensor_name_out.name],
+ name=import_prefix)[0]
+ with tf.compat.v1.Session(graph=outer_graph) as sess:
+ result = sess.run(tensor_name_out)
+ self.assertEqual(b'/'.join([import_prefix, test_name]), result)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/fcp/tensorflow/test_selector.proto b/fcp/tensorflow/test_selector.proto
new file mode 100644
index 0000000..801195f
--- /dev/null
+++ b/fcp/tensorflow/test_selector.proto
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Defines TestSelector, TestCriteria and ResumptionToken, for use in
+// ExternalDataset and ExampleSelectorFuser testing.
+
+syntax = "proto3";
+
+package fcp;
+
+message Limit {
+ int64 value = 1;
+}
+
+// Specifies that examples (int64 scalars) in the range [lower_inclusive,
+// upper_inclusive] should be included.
+message TestSelector {
+ Limit lower_inclusive = 1;
+ Limit upper_inclusive = 2;
+}
+
+// Simple example selection criteria which limits the maximum number of examples
+message TestCriteria {
+ // The max number of examples that should be returned by this query.
+ int32 max_examples = 1;
+}
+
+message ResumptionToken {
+ int32 last_index = 1;
+}
diff --git a/fcp/tensorflow/testing/BUILD b/fcp/tensorflow/testing/BUILD
new file mode 100644
index 0000000..0f89567
--- /dev/null
+++ b/fcp/tensorflow/testing/BUILD
@@ -0,0 +1,39 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = [
+ "//fcp:internal",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tf_helper",
+ testonly = 1,
+ srcs = ["tf_helper.cc"],
+ hdrs = ["tf_helper.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base:result",
+ "//fcp/tensorflow:tf_session",
+ "@com_google_absl//absl/strings:cord",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/cc:cc_ops",
+ "@org_tensorflow//tensorflow/cc:scope",
+ "@org_tensorflow//tensorflow/core:protos_all_cc",
+ ],
+)
diff --git a/fcp/tensorflow/testing/tf_helper.cc b/fcp/tensorflow/testing/tf_helper.cc
new file mode 100644
index 0000000..285dddc
--- /dev/null
+++ b/fcp/tensorflow/testing/tf_helper.cc
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/testing/tf_helper.h"
+
+namespace fcp {
+
+absl::Cord CreateGraph(tensorflow::Scope* root) {
+ tensorflow::GraphDef def;
+ tensorflow::Status to_graph_status = root->ToGraphDef(&def);
+ EXPECT_TRUE(to_graph_status.ok()) << to_graph_status;
+ // TODO(team): Use SerializeAsCord when available.
+ return absl::Cord(def.SerializeAsString());
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/testing/tf_helper.h b/fcp/tensorflow/testing/tf_helper.h
new file mode 100644
index 0000000..8f1defb
--- /dev/null
+++ b/fcp/tensorflow/testing/tf_helper.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_TESTING_TF_HELPER_H_
+#define FCP_TENSORFLOW_TESTING_TF_HELPER_H_
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/strings/cord.h"
+#include "fcp/base/result.h"
+#include "fcp/tensorflow/tf_session.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/io_ops.h"
+#include "tensorflow/cc/ops/state_ops.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+
+namespace fcp {
+
+/**
+ * Get a serialized graph with the operations defined on the provided scope.
+ */
+absl::Cord CreateGraph(tensorflow::Scope* root);
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_TESTING_TF_HELPER_H_
diff --git a/fcp/tensorflow/tf_py_smoke_test.py b/fcp/tensorflow/tf_py_smoke_test.py
new file mode 100644
index 0000000..74c39e9
--- /dev/null
+++ b/fcp/tensorflow/tf_py_smoke_test.py
@@ -0,0 +1,38 @@
+#!/usr/bin/python
+#
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+tf.enable_v2_behavior()
+
+
+@tf.function
+def matmul(a, b):
+ return tf.matmul(a, b)
+
+
+class TfPySmokeTest(tf.test.TestCase):
+
+ def test_matmul(self):
+ actual = matmul(tf.constant([[1, 2]]), tf.transpose([[3, 4]]))
+ self.assertEqual(11, actual)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/fcp/tensorflow/tf_session.cc b/fcp/tensorflow/tf_session.cc
new file mode 100644
index 0000000..ead7cb4
--- /dev/null
+++ b/fcp/tensorflow/tf_session.cc
@@ -0,0 +1,191 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/tf_session.h"
+
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <utility>
+
+#include "absl/strings/cord.h"
+#include "fcp/base/platform.h"
+#include "fcp/base/process_unique_id.h"
+#include "fcp/base/result.h"
+#include "fcp/tensorflow/status.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+
+namespace fcp {
+
+#define TF_STATUS_EXPECT_OK(tf_status) \
+ Result(ConvertFromTensorFlowStatus(tf_status)).Then(ExpectOk())
+
+using CheckpointOp = google::internal::federated::plan::CheckpointOp;
+
+TfSession::TfSession(const std::filesystem::path& tmp_dir,
+ const absl::Cord& graph)
+ : tmp_dir_(StripTrailingPathSeparator(tmp_dir.c_str())),
+ session_(tensorflow::NewSession(tensorflow::SessionOptions{})) {
+ // Parse GraphDef.
+ tensorflow::GraphDef graph_def;
+ // TODO(team): Replace with ParseFromCord (check if it is available).
+ std::string graph_str;
+ absl::CopyCordToString(graph, &graph_str);
+ if (!graph_def.ParseFromString(graph_str)) {
+ session_status_ = FCP_STATUS(INVALID_ARGUMENT)
+ << "Could not parse GraphDef.";
+ return;
+ }
+ session_status_ = ConvertFromTensorFlowStatus(session_->Create(graph_def));
+}
+
+TfSession::TfSession(const std::filesystem::path& tmp_dir,
+ absl::string_view graph)
+ : TfSession(tmp_dir, absl::Cord(graph)) {}
+
+Result<Unit> TfSession::Ready() {
+ return Result(session_status_).Then(ExpectOk());
+}
+
+Result<Unit> TfSession::RunOp(absl::string_view op) {
+ FCP_TRY(Ready());
+ if (op.empty()) {
+ return Unit{};
+ }
+ TracingSpan<RunTfOp> span(op);
+ std::vector<std::string> target_node_names;
+ target_node_names.emplace_back(op);
+ FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
+ /*inputs=*/{},
+ /*output_tensor_names=*/{}, target_node_names,
+ /*outputs=*/nullptr)));
+ return Unit{};
+}
+
+Result<Unit> TfSession::RunOp(const NamedTensorList& inputs,
+ absl::string_view op) {
+ FCP_TRY(Ready());
+ if (op.empty()) {
+ return Unit{};
+ }
+ std::vector<std::string> target_node_names;
+ target_node_names.emplace_back(op);
+ FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(inputs,
+ /*output_tensor_names=*/{},
+ target_node_names,
+ /*outputs=*/nullptr)));
+ return Unit{};
+}
+
+Result<std::unique_ptr<TfSession::NamedTensorMap>> TfSession::GetOutputs(
+ std::unique_ptr<std::vector<std::string>> output_names) {
+ FCP_TRY(Ready());
+ auto outputs = std::make_unique<TfSession::NamedTensorMap>();
+ if (output_names->empty()) {
+ return std::move(outputs);
+ }
+ std::vector<tensorflow::Tensor> output_list;
+ FCP_TRY(TF_STATUS_EXPECT_OK(session_->Run(
+ /*inputs=*/{}, *output_names,
+ /*target_tensor_names=*/{}, &output_list)));
+ FCP_CHECK(output_names->size() == output_list.size());
+ for (int i = 0; i < output_names->size(); i++) {
+ outputs->emplace(std::move((*output_names)[i]), std::move(output_list[i]));
+ }
+ return std::move(outputs);
+}
+
+void DeleteTmpFile(const std::string& tmp_file_name) {
+ if (std::remove(tmp_file_name.c_str()) > 0) {
+ Trace<TmpFileNotDeleted>(tmp_file_name);
+ }
+}
+
+Result<absl::Cord> TfSession::SaveState(const CheckpointOp& op) {
+ FCP_TRY(Ready());
+ TracingSpan<SaveToCheckpoint> span(
+ op.before_save_op(),
+ op.has_saver_def() ? op.saver_def().save_tensor_name() : "",
+ op.after_save_op());
+ FCP_TRY(RunOp(op.before_save_op()));
+ Result<absl::Cord> res = absl::Cord("");
+ if (op.has_saver_def()) {
+ const tensorflow::SaverDef& def = op.saver_def();
+ absl::string_view save_op = def.save_tensor_name();
+ // TODO(team): Workaround due to difference between python and c++
+ // TensorFlow APIs.
+ save_op = absl::StripSuffix(save_op, ":0");
+ std::string tmp_file_name = GetTmpCheckpointFileName("save_checkpoint");
+ res =
+ RunOp({{def.filename_tensor_name(), tensorflow::Tensor(tmp_file_name)}},
+ save_op)
+ .Then([&tmp_file_name](Unit u) -> Result<StatusOr<absl::Cord>> {
+ return Result(fcp::ReadFileToCord(tmp_file_name));
+ })
+ .Then(ExpectOk());
+ DeleteTmpFile(tmp_file_name);
+ }
+ FCP_TRY(RunOp(op.after_save_op()));
+ return res;
+}
+
+Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
+ const absl::Cord& checkpoint) {
+ FCP_TRY(Ready());
+ TracingSpan<RestoreFromCheckpoint> span(
+ op.before_restore_op(),
+ op.has_saver_def() ? op.saver_def().restore_op_name() : "",
+ op.after_restore_op());
+ FCP_TRY(RunOp(op.before_restore_op()));
+ Result<Unit> res = Unit{};
+ if (op.has_saver_def()) {
+ const tensorflow::SaverDef& def = op.saver_def();
+ std::string tmp_file_name = GetTmpCheckpointFileName("restore_checkpoint");
+ res = Result(fcp::WriteCordToFile(tmp_file_name, checkpoint))
+ .Then(ExpectOk())
+ .Then([this, &def, &tmp_file_name](Unit u) -> Result<Unit> {
+ return RunOp({{def.filename_tensor_name(),
+ tensorflow::Tensor(tmp_file_name)}},
+ def.restore_op_name());
+ });
+ DeleteTmpFile(tmp_file_name);
+ }
+ FCP_TRY(RunOp(op.after_restore_op()));
+ return res;
+}
+
+Result<Unit> TfSession::RestoreState(const CheckpointOp& op,
+ const NamedTensorList& restore_inputs) {
+ FCP_TRY(Ready());
+ TracingSpan<RestoreFromTensors> span(op.before_restore_op(),
+ op.after_restore_op());
+ if (op.has_saver_def()) {
+ return TraceError<InvalidCheckpointOp>(
+ "saver_def",
+ "Cannot call RestoreState with a list of named tensors with a "
+ "checkpoint op containing a SaverDef.");
+ }
+ FCP_TRY(RunOp(restore_inputs, op.before_restore_op()));
+ return RunOp(op.after_restore_op());
+}
+
+std::string TfSession::GetTmpCheckpointFileName(absl::string_view name) {
+ return ConcatPath(
+ tmp_dir_, absl::StrCat(name, ProcessUniqueId::Next().value(), ".ckp"));
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/tf_session.h b/fcp/tensorflow/tf_session.h
new file mode 100644
index 0000000..e061984
--- /dev/null
+++ b/fcp/tensorflow/tf_session.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TENSORFLOW_TF_SESSION_H_
+#define FCP_TENSORFLOW_TF_SESSION_H_
+
+#include <filesystem>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/result.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/tensorflow/tracing_schema.h"
+#include "fcp/tracing/tracing_span.h"
+#include "tensorflow/core/public/session.h"
+
+namespace fcp {
+
+class TfSession {
+ public:
+ /**
+ * Starts a tensorflow client session with the provided graph def
+ * @param tmp_dir A directory in which to create tmp files used while saving
+ * or restoring checkpoints. This directory can be the same for multiple
+ * TfSessions created in the same process, even if they are running
+ * concurrently, but it must not be the same directory passed to a
+ * TfSession in a different process.
+ * @param graph Serialized graph describing how to aggregate client updates
+ * into a global model. Must be parseable into a tesnorflow::GraphDef
+ * proto.
+ */
+ TfSession(const std::filesystem::path& tmp_dir, const absl::Cord& graph);
+ TfSession(const std::filesystem::path& tmp_dir, absl::string_view graph);
+
+ // TfSession is neither copyable nor movable.
+ TfSession(const TfSession&) = delete;
+ TfSession& operator=(const TfSession&) = delete;
+
+ using NamedTensorList =
+ std::vector<std::pair<std::string, tensorflow::Tensor>>;
+ using NamedTensorMap = absl::flat_hash_map<std::string, tensorflow::Tensor>;
+
+ // Returns Error if the TfSession is in a bad state (for example if the
+ // provided GraphDef was invalid.) Allows failing fast while recording a
+ // useful error for debugging.
+ // If Ready() returns Error, all other methods will return Error as well.
+ Result<Unit> Ready();
+
+ // Run a single operation only if the operation is nonempty. The operation
+ // must be present in the GraphDef that was provided in the constructor.
+ Result<Unit> RunOp(absl::string_view op);
+
+ // Returns a map of name, output tensor pairs for the outputs specified by
+ // output_names.
+ Result<std::unique_ptr<NamedTensorMap>> GetOutputs(
+ std::unique_ptr<std::vector<std::string>> output_names);
+
+ /**
+ * Saves the current state of the session.
+ * @param op Contains instructions for how to save the session state.
+ * @return the state of the session as a serialized checkpoint.
+ */
+ Result<absl::Cord> SaveState(
+ const google::internal::federated::plan::CheckpointOp& op);
+
+ /**
+ * Restores state into the session.
+ * @param op Contains instructions for operations to run to restore the
+ * state.
+ * @param checkpoint Serialized tensorflow checkpoint that should be loaded
+ * into the session.
+ */
+ Result<Unit> RestoreState(
+ const google::internal::federated::plan::CheckpointOp& op,
+ const absl::Cord& checkpoint);
+
+ /**
+ * Restores state into the session.
+ * @param op Contains instructions for operations to run to restore the state.
+ * saver_def must not be set on the op.
+ * @param restore_inputs A collection of tensor variables that should be
+ * loaded into the session.
+ */
+ Result<Unit> RestoreState(
+ const google::internal::federated::plan::CheckpointOp& op,
+ const NamedTensorList& restore_inputs);
+
+ private:
+ // Overload to allow providing inputs to operations.
+ Result<Unit> RunOp(const NamedTensorList& inputs, absl::string_view op);
+ std::string GetTmpCheckpointFileName(absl::string_view name);
+
+ std::string tmp_dir_;
+ std::unique_ptr<tensorflow::Session> session_;
+ fcp::Status session_status_;
+};
+
+} // namespace fcp
+
+#endif // FCP_TENSORFLOW_TF_SESSION_H_
diff --git a/fcp/tensorflow/tf_session_test.cc b/fcp/tensorflow/tf_session_test.cc
new file mode 100644
index 0000000..b035c2f
--- /dev/null
+++ b/fcp/tensorflow/tf_session_test.cc
@@ -0,0 +1,296 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/tensorflow/tf_session.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "fcp/base/tracing_schema.h"
+#include "fcp/protos/plan.pb.h"
+#include "fcp/tensorflow/testing/tf_helper.h"
+#include "fcp/tensorflow/tracing_schema.h"
+#include "fcp/testing/result_matchers.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+
+namespace fcp {
+
+using google::internal::federated::plan::CheckpointOp;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::ops::Add;
+using tensorflow::ops::Assign;
+using tensorflow::ops::Const;
+using tensorflow::ops::Mul;
+using tensorflow::ops::Placeholder;
+using tensorflow::ops::Restore;
+using tensorflow::ops::Save;
+using tensorflow::ops::Variable;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+using testing::_;
+using testing::Not;
+
+template <typename T>
+void CheckOutput(TfSession* sess, const std::string& output_op,
+ Tensor expected) {
+ Result<std::unique_ptr<TfSession::NamedTensorMap>> outputs =
+ sess->GetOutputs(std::make_unique<std::vector<std::string>>(
+ std::initializer_list<std::string>{output_op}));
+ EXPECT_THAT(outputs, Not(IsError()));
+ ExpectTensorEqual<T>((*outputs.GetValueOrDie())[output_op], expected);
+}
+
+TEST(TfSessionTest, InitializeWithEmptyGraph) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ TestTracingRecorder tracing_recorder;
+ TfSession sess("foo/bar", CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+ // Running an empty operation is a no-op.
+ EXPECT_THAT(sess.RunOp(""), Not(IsError()));
+ // Getting an empty list of outputs is a no-op.
+ EXPECT_THAT(sess.GetOutputs(std::make_unique<std::vector<std::string>>()),
+ Not(IsError()));
+ // There are no ops registered in the GraphDef, so trying to run an op won't
+ // work.
+ tracing_recorder.ExpectError<ResultExpectStatusError>();
+ EXPECT_THAT(sess.RunOp("sum"), IsError());
+ // Validate the expected hierarchy of tracing spans. There should be only one
+ // RunTfOp span, as we don't want to bother recording a noop if the op is
+ // empty.
+ EXPECT_THAT(tracing_recorder.root(),
+ ElementsAre(AllOf(
+ IsSpan<RunTfOp>(),
+ ElementsAre(IsEvent<ResultExpectStatusError>(
+ static_cast<int>(fcp::OK),
+ static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)))));
+}
+
+TEST(TfSessionTest, InvalidGraphBytes) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ TestTracingRecorder tracing_recorder;
+ tracing_recorder.ExpectError<ResultExpectStatusError>();
+ TfSession sess("foo/bar", "garbage");
+ ASSERT_THAT(sess.Ready(), IsError());
+ EXPECT_THAT(tracing_recorder.root(),
+ ElementsAre(IsEvent<ResultExpectStatusError>(
+ static_cast<int>(fcp::OK),
+ static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)));
+}
+
+TEST(TfSessionTest, RunGraphOp) {
+ // Construct a TensorFlow graph with all desired operations.
+ // This graph just assigns the result of multiplying two constants "a" and "b"
+ // to a variable "c", and makes it possible to output "c".
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
+ auto b = Const<int32_t>(root, {{2}});
+ auto c = Variable(root.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
+ auto assign_c = Assign(root.WithOpName("assign_c"), c, Mul(root, a, b));
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ TfSession sess("foo/bar", CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+
+ // Run an operation on the session and validate the result.
+ EXPECT_THAT(sess.RunOp("assign_c"), Not(IsError()));
+ CheckOutput<int32_t>(&sess, "c",
+ AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2})));
+}
+
+TEST(TfSessionTest, RestoreFromTensor) {
+ // Construct a TensorFlow graph with all desired operations.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto input = Placeholder(root.WithOpName("i"), tensorflow::DT_INT32);
+ auto a = Variable(root.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
+ auto restore = Assign(root.WithOpName("restore_a"), a, input);
+ auto double_a = Assign(root.WithOpName("double_a"), a,
+ Mul(root, a, Const<int32_t>(root, {{2}})));
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ TfSession sess(testing::TempDir(), CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+
+ CheckpointOp restore_checkpoint_op;
+ restore_checkpoint_op.set_before_restore_op("restore_a");
+ restore_checkpoint_op.set_after_restore_op("double_a");
+
+ tensorflow::Input::Initializer i({{1, 2}, {3, 4}});
+ EXPECT_THAT(sess.RestoreState(restore_checkpoint_op, {{"i", i.tensor}}),
+ Not(IsError()));
+
+ CheckOutput<int32_t>(&sess, "a",
+ AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2})));
+}
+
+TEST(TfSessionTest, RestoreFromTensorNoSaverDefAllowed) {
+ // Construct a TensorFlow graph with all desired operations.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto input = Placeholder(root.WithOpName("i"), tensorflow::DT_INT32);
+ auto a = Variable(root, {2, 2}, tensorflow::DT_INT32);
+ auto restore = Assign(root.WithOpName("restore_a"), a, input);
+ auto double_a = Assign(root.WithOpName("double_a"), a,
+ Mul(root, a, Const<int32_t>(root, {{2}})));
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ tracing_recorder.ExpectError<InvalidCheckpointOp>();
+ TfSession sess(testing::TempDir(), CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+
+ CheckpointOp restore_checkpoint_op;
+ restore_checkpoint_op.set_before_restore_op("restore_a");
+ restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
+ restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
+ "filename");
+ restore_checkpoint_op.set_after_restore_op("double_a");
+
+ tensorflow::Input::Initializer i({{1, 2}, {3, 4}});
+ EXPECT_THAT(sess.RestoreState(restore_checkpoint_op, {{"i", i.tensor}}),
+ IsError());
+}
+
+TEST(TfSessionTest, SaveAndRestoreCheckpointBytes) {
+ // Construct a TensorFlow graph with all desired operations.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
+ // Save the current value of constant "a" in a serialized checkpoint.
+ auto filename =
+ Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
+ auto save_a = Save(root.WithOpName("save"), filename, {"a"},
+ std::initializer_list<tensorflow::Input>{a});
+ // Restore the value saved in the serialized checkpoint to variable "c".
+ auto c = Variable(root.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
+ auto restore = Assign(root.WithOpName("restore"), c,
+ Restore(root, filename, "a", tensorflow::DT_INT32));
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ TfSession sess(testing::TempDir(), CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+
+ // Save to a checkpoint.
+ CheckpointOp save_checkpoint_op;
+ save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
+ save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
+ Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
+ EXPECT_THAT(save_res, Not(IsError()));
+
+ // Restore from that checkpoint.
+ CheckpointOp restore_checkpoint_op;
+ restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
+ restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
+ "filename");
+ EXPECT_THAT(
+ sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
+ Not(IsError()));
+
+ // Verify the value of variable "c" was loaded properly from the checkpoint.
+ CheckOutput<int32_t>(&sess, "c",
+ AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));
+}
+
+TEST(TfSessionTest, SaveCheckpointBytesSaveOpInTensorFormat) {
+ // Construct a TensorFlow graph with all desired operations.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
+ // Save the current value of variable "a" in a serialized checkpoint.
+ auto filename =
+ Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
+ auto save_a = Save(root.WithOpName("save"), filename, {"a"},
+ std::initializer_list<tensorflow::Input>{a});
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ TfSession sess(testing::TempDir(), CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+
+ // Ensure that attempting to save doesn't return an error even if the save op
+ // is provided in tensor format (with a trailing ":0")
+ CheckpointOp save_checkpoint_op;
+ save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save:0");
+ save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
+ Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
+ EXPECT_THAT(save_res, Not(IsError()));
+}
+
+TEST(TfSessionTest, SaveAndRestoreWithBeforeAndAfterOps) {
+ // Construct a TensorFlow graph with all desired operations.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto a = Variable(root.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
+ auto b = Variable(root, {1, 1}, tensorflow::DT_INT32);
+ auto init_a = Assign(root.WithOpName("init_a"), a,
+ Const<int32_t>(root, {{1, 2}, {3, 4}}));
+ auto init_b =
+ Assign(root.WithOpName("init_b"), b, Const<int32_t>(root, {{2}}));
+ auto mul_a = Assign(root.WithOpName("mul_a"), a, Mul(root, a, b));
+ auto inc_b = Assign(root.WithOpName("inc_b"), b,
+ Add(root, b, Const<int32_t>(root, {{1}})));
+ // Save the current value of variable "a" in a serialized checkpoint.
+ auto filename =
+ Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
+ auto save_a = Save(root.WithOpName("save"), filename, {"a"},
+ std::initializer_list<tensorflow::Input>{a});
+ // Restore the value saved in the serialized checkpoint to variable "a".
+ auto restore = Assign(root.WithOpName("restore"), a,
+ Restore(root, filename, "a", tensorflow::DT_INT32));
+
+ // Run a session using the graph constructed above.
+ TestTracingRecorder tracing_recorder;
+ TfSession sess(testing::TempDir(), CreateGraph(&root));
+ ASSERT_THAT(sess.Ready(), Not(IsError()));
+ EXPECT_THAT(sess.RunOp("init_a"), Not(IsError()));
+ EXPECT_THAT(sess.RunOp("init_b"), Not(IsError()));
+
+ // Set "a = a * b" and save that value to a checkpoint, then reset "a" to its
+ // initial state.
+ CheckpointOp save_checkpoint_op;
+ save_checkpoint_op.set_before_save_op("mul_a");
+ save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
+ save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
+ save_checkpoint_op.set_after_save_op("init_a");
+ Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
+ EXPECT_THAT(save_res, Not(IsError()));
+ // Check that the value of variable "a" has been reset to the initial value by
+ // the after_save_op.
+ CheckOutput<int32_t>(&sess, "a",
+ AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));
+
+ // Increment "b" to 3 in the before_restore_op, set "a" to the value from the
+ // checkpoint, then set "a = a * b".
+ CheckpointOp restore_checkpoint_op;
+ restore_checkpoint_op.set_before_restore_op("inc_b");
+ restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
+ restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
+ "filename");
+ restore_checkpoint_op.set_after_restore_op("mul_a");
+ EXPECT_THAT(
+ sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
+ Not(IsError()));
+ // The initial value of "a" should have been multiplied by 2 in the
+ // before_save_op and multiplied by 3 in the after_restore_op.
+ CheckOutput<int32_t>(&sess, "a",
+ AsTensor<int32_t>({6, 12, 18, 24}, TensorShape({2, 2})));
+}
+
+} // namespace fcp
diff --git a/fcp/tensorflow/tf_smoke_test.cc b/fcp/tensorflow/tf_smoke_test.cc
new file mode 100644
index 0000000..215ce1a
--- /dev/null
+++ b/fcp/tensorflow/tf_smoke_test.cc
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Runs a trivial TensorFlow session - just to know we can actually build it
+ * correctly.
+ */
+
+#include <stdint.h>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace {
+
+using ::testing::Eq;
+
+using tensorflow::ClientSession;
+using tensorflow::Scope;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+using tensorflow::ops::Const;
+using tensorflow::ops::Mul;
+
+TEST(TfSmokeTest, DoTrickyMath) {
+ Scope root = Scope::NewRootScope();
+ auto a = Const<int32_t>(root, { {1, 2}, {3, 4} });
+ auto b = Const<int32_t>(root, { {2} });
+ auto r = Mul(root.WithOpName("r"), a, b);
+ std::vector<Tensor> outputs;
+
+ ClientSession session(root);
+ TF_CHECK_OK(session.Run({r}, &outputs));
+
+ Tensor expected = AsTensor<int32_t>(
+ {2, 4, 6, 8},
+ TensorShape({2, 2}));
+
+ EXPECT_THAT(outputs.size(), Eq(1));
+ ExpectTensorEqual<int32_t>(outputs[0], expected);
+}
+
+} // namespace
diff --git a/fcp/tensorflow/tracing_schema.fbs b/fcp/tensorflow/tracing_schema.fbs
new file mode 100644
index 0000000..ddad788
--- /dev/null
+++ b/fcp/tensorflow/tracing_schema.fbs
@@ -0,0 +1,45 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table RunTfOp(tag: "TFOP", span) {
+ op: string;
+}
+
+table SaveToCheckpoint(tag: "SCPT", span) {
+ before_save_op: string;
+ save_op: string;
+ after_save_op: string;
+}
+
+table RestoreFromCheckpoint(tag: "RCPT", span) {
+ before_restore_op: string;
+ restore_op: string;
+ after_restore_op: string;
+}
+
+table RestoreFromTensors(tag: "RFTS", span) {
+ before_restore_op: string;
+ after_restore_op: string;
+}
+
+table InvalidCheckpointOp (tag: "TFCO", error) {
+ field: string; // The name of the field in the checkpoint op that is invalid.
+ message: string; // Information about why the provided value is invalid.
+}
+
+table TmpFileNotDeleted (tag: "TMPD", warning) {
+ field: string; // The name of the file that could not be deleted successfully.
+} \ No newline at end of file
diff --git a/fcp/testdata/federation_client_only_plan.pb b/fcp/testdata/federation_client_only_plan.pb
new file mode 100644
index 0000000..03a0945
--- /dev/null
+++ b/fcp/testdata/federation_client_only_plan.pb
Binary files differ
diff --git a/fcp/testdata/federation_proxy_train_examples.pb b/fcp/testdata/federation_proxy_train_examples.pb
new file mode 100644
index 0000000..3623742
--- /dev/null
+++ b/fcp/testdata/federation_proxy_train_examples.pb
Binary files differ
diff --git a/fcp/testdata/federation_test_checkpoint.client.ckp b/fcp/testdata/federation_test_checkpoint.client.ckp
new file mode 100644
index 0000000..1a8858c
--- /dev/null
+++ b/fcp/testdata/federation_test_checkpoint.client.ckp
Binary files differ
diff --git a/fcp/testdata/federation_test_select_checkpoints.pb b/fcp/testdata/federation_test_select_checkpoints.pb
new file mode 100644
index 0000000..3623742
--- /dev/null
+++ b/fcp/testdata/federation_test_select_checkpoints.pb
Binary files differ
diff --git a/fcp/testing/BUILD b/fcp/testing/BUILD
new file mode 100644
index 0000000..6039d1b
--- /dev/null
+++ b/fcp/testing/BUILD
@@ -0,0 +1,120 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("//fcp/tracing:build_defs.bzl", "tracing_schema_cc_library")
+
+package(
+ default_visibility = ["//fcp:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+tracing_schema_cc_library(
+ name = "tracing_schema",
+ srcs = ["tracing_schema.fbs"],
+)
+
+cc_library(
+ name = "result_matchers",
+ testonly = True,
+ srcs = [
+ ],
+ hdrs = [
+ "result_matchers.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base:error",
+ "//fcp/base:result",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "parse_text_proto",
+ testonly = 1,
+ hdrs = ["parse_text_proto.h"],
+ copts = FCP_COPTS,
+ deps = [
+ "//fcp/base",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "testing",
+ testonly = 1,
+ srcs = ["testing.cc"],
+ hdrs = ["testing.h"],
+ copts = FCP_COPTS,
+ deps = [
+ ":parse_text_proto",
+ ":result_matchers",
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/base:error",
+ "//fcp/base:result",
+ "//fcp/base:source_location",
+ "//fcp/tracing",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "result_matchers_test",
+ srcs = [
+ "result_matchers_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":result_matchers",
+ ":testing",
+ "//fcp/base:error",
+ "//fcp/base:result",
+ "//fcp/base:tracing_schema",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "testing_test",
+ size = "small",
+ srcs = ["testing_test.cc"],
+ args = ["--baseline_path=$(location :testdata/verify_baseline_test.baseline)"],
+ copts = FCP_COPTS,
+ data = [":testdata/verify_baseline_test.baseline"],
+ deps = [
+ ":test_messages_cc_proto",
+ ":testing",
+ "//fcp/base",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+proto_library(
+ name = "test_messages_proto",
+ srcs = ["test_messages.proto"],
+)
+
+cc_proto_library(
+ name = "test_messages_cc_proto",
+ deps = [":test_messages_proto"],
+)
diff --git a/fcp/testing/parse_text_proto.h b/fcp/testing/parse_text_proto.h
new file mode 100644
index 0000000..c3d498c
--- /dev/null
+++ b/fcp/testing/parse_text_proto.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TESTING_PARSE_TEXT_PROTO_H_
+#define FCP_TESTING_PARSE_TEXT_PROTO_H_
+
+#include <type_traits>
+
+#include "google/protobuf/text_format.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/monitoring.h"
+
+namespace fcp {
+
+// Convenience macro for parsing text formatted protos in test code.
+// The input string should include only the proto fields but not the proto
+// itself. For example:
+//
+// const MyProtoType foo = PARSE_TEXT_PROTO("foo:1 sub { bar:2 }");
+// const MyProtoType bar = PARSE_TEXT_PROTO(R"(
+// foo: 1
+// sub {
+// bar: 2
+// })");
+//
+// Note that the output of the macro has to be assigned to proper proto message
+// type in order for the parsing to work.
+#define PARSE_TEXT_PROTO(STR) ParseProtoHelper(STR)
+
+class ParseProtoHelper {
+ public:
+ explicit ParseProtoHelper(absl::string_view string_view)
+ : string_view_(string_view) {}
+
+ template <class T>
+ operator T() { // NOLINT
+ static_assert(std::is_base_of<google::protobuf::Message, T>::value &&
+ !std::is_same<google::protobuf::Message, T>::value);
+ T msg;
+ FCP_CHECK(google::protobuf::TextFormat::ParseFromString(
+ std::string(string_view_), // NOLINT(OSS)
+ &msg));
+ return msg;
+ }
+
+ private:
+ absl::string_view string_view_;
+};
+
+} // namespace fcp
+
+#endif // FCP_TESTING_PARSE_TEXT_PROTO_H_
diff --git a/fcp/testing/result_matchers.h b/fcp/testing/result_matchers.h
new file mode 100644
index 0000000..ecebea2
--- /dev/null
+++ b/fcp/testing/result_matchers.h
@@ -0,0 +1,132 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TESTING_RESULT_MATCHERS_H_
+#define FCP_TESTING_RESULT_MATCHERS_H_
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/base/error.h"
+#include "fcp/base/result.h"
+
+namespace fcp {
+
+// Allows to formulate test expectation on a result containing error as:
+// EXPECT_THAT(result, IsError());
+MATCHER(IsError, "") { return arg.is_error(); }
+
+// Allows to formulate test expectation on a non-error result with existing
+// gtest matchers (such as Eq) as:
+// EXPECT_THAT(result, HasValue(Eq(value)));
+template <typename MatcherType>
+class HasValueMatcher {
+ public:
+ explicit HasValueMatcher(MatcherType matcher)
+ : matcher_(std::move(matcher)) {}
+
+ template <typename TargetType>
+ operator testing::Matcher<TargetType>() const { // NOLINT
+ using D = std::remove_cv_t<std::remove_reference_t<TargetType>>;
+ static_assert(result_internal::ResultTraits<D>::is_result());
+ using V = typename result_internal::ResultTraits<D>::ValueType;
+ return testing::Matcher<TargetType>(
+ new Impl<V>(testing::SafeMatcherCast<V const&>(matcher_)));
+ }
+
+ private:
+ template <typename ValueType>
+ class Impl : public testing::MatcherInterface<Result<ValueType> const&> {
+ public:
+ explicit Impl(testing::Matcher<ValueType const&> matcher)
+ : concrete_matcher_(std::move(matcher)) {}
+
+ bool MatchAndExplain(
+ Result<ValueType> const& arg,
+ testing::MatchResultListener* result_listener) const override;
+
+ void DescribeTo(std::ostream* os) const override {
+ *os << FormatDescription(false);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << FormatDescription(true);
+ }
+
+ private:
+ std::string FormatDescription(bool negation) const;
+ testing::Matcher<ValueType const&> concrete_matcher_;
+ };
+
+ MatcherType matcher_;
+};
+
+template <typename MatcherType>
+HasValueMatcher<MatcherType> HasValue(MatcherType matcher) {
+ return HasValueMatcher<MatcherType>(std::move(matcher));
+}
+
+template <typename MatcherType>
+template <typename ValueType>
+bool HasValueMatcher<MatcherType>::Impl<ValueType>::MatchAndExplain(
+ Result<ValueType> const& arg,
+ testing::MatchResultListener* result_listener) const {
+ if (arg.is_error()) {
+ *result_listener << "is error";
+ return false;
+ } else {
+ ValueType const& value = arg.GetValueOrDie();
+ *result_listener << "value = " << testing::PrintToString(value);
+ return testing::ExplainMatchResult(concrete_matcher_, value,
+ result_listener);
+ }
+}
+
+template <typename MatcherType>
+template <typename ValueType>
+std::string HasValueMatcher<MatcherType>::Impl<ValueType>::FormatDescription(
+ bool negation) const {
+ std::stringstream desc;
+ if (negation) {
+ concrete_matcher_.DescribeNegationTo(&desc);
+ } else {
+ concrete_matcher_.DescribeTo(&desc);
+ }
+ return desc.str();
+}
+
+// Expect a particular status for testing failure modes of protocols.
+// Prefer ExpectOk (defined in result.h) for OK status.
+template <fcp::StatusCode Code>
+struct ExpectStatus : public ExpectBase {
+ using ExpectBase::ExpectBase;
+ constexpr explicit ExpectStatus(
+ SourceLocation loc = SourceLocation::current())
+ : ExpectBase(loc) {}
+
+ Result<Unit> operator()(const Status& s) const {
+ if (s.code() == Code) {
+ return Unit{};
+ } else {
+ return TraceUnexpectedStatus(Code, s);
+ }
+ }
+};
+
+} // namespace fcp
+
+#endif // FCP_TESTING_RESULT_MATCHERS_H_
diff --git a/fcp/testing/result_matchers_test.cc b/fcp/testing/result_matchers_test.cc
new file mode 100644
index 0000000..212835b
--- /dev/null
+++ b/fcp/testing/result_matchers_test.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/testing/result_matchers.h"
+
+#include "gtest/gtest.h"
+#include "fcp/base/error.h"
+#include "fcp/base/result.h"
+#include "fcp/base/tracing_schema.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+
+namespace fcp {
+using testing::Eq;
+using testing::Not;
+
+TEST(ExpectTest, HasValueDirect) {
+ EXPECT_THAT(Result<int>(42), HasValue(42));
+ EXPECT_THAT(Result<int>(42), Not(HasValue(24)));
+ EXPECT_THAT(Result<std::string>("foo"), HasValue("foo"));
+ EXPECT_THAT(Result<std::string>("foo"), Not(HasValue("bar")));
+}
+
+TEST(ExpectTest, HasValueEq) {
+ EXPECT_THAT(Result<int>(42), HasValue(Eq(42)));
+ EXPECT_THAT(Result<int>(42), Not(HasValue(Eq(24))));
+ EXPECT_THAT(Result<int>(42), HasValue(Not(Eq(24))));
+ EXPECT_THAT(Result<std::string>("foo"), HasValue(Eq("foo")));
+ EXPECT_THAT(Result<std::string>("foo"), Not(HasValue(Eq("bar"))));
+ EXPECT_THAT(Result<std::string>("foo"), HasValue(Not(Eq("bar"))));
+}
+
+TEST(ExpectTest, ExpectIsError) {
+ EXPECT_THAT(Result<int>(TraceTestError()), IsError());
+ EXPECT_THAT(Result<int>(42), Not(IsError()));
+}
+
+TEST(ExpectTest, ExpectStatus) {
+ TestTracingRecorder recorder;
+ EXPECT_THAT(Result<Status>(FCP_STATUS(INVALID_ARGUMENT))
+ .Then(ExpectStatus<INVALID_ARGUMENT>()),
+ HasValue(Unit{}));
+}
+
+TEST(ExpectTest, ExpectStatusReturnsError) {
+ TestTracingRecorder recorder;
+ recorder.ExpectError<ResultExpectStatusError>();
+ EXPECT_THAT(
+ Result<Status>(FCP_STATUS(OK)).Then(ExpectStatus<INVALID_ARGUMENT>()),
+ IsError());
+}
+
+} // namespace fcp
diff --git a/fcp/testing/test_messages.proto b/fcp/testing/test_messages.proto
new file mode 100644
index 0000000..fcf9275
--- /dev/null
+++ b/fcp/testing/test_messages.proto
@@ -0,0 +1,25 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.testing;
+
+message Foo {
+ string foo = 1;
+}
+
+message Bar {
+ int32 bar = 1;
+}
diff --git a/fcp/testing/testdata/verify_baseline_test.baseline b/fcp/testing/testdata/verify_baseline_test.baseline
new file mode 100644
index 0000000..f4f212a
--- /dev/null
+++ b/fcp/testing/testdata/verify_baseline_test.baseline
@@ -0,0 +1 @@
+Dies ist ein Test. \ No newline at end of file
diff --git a/fcp/testing/testing.cc b/fcp/testing/testing.cc
new file mode 100644
index 0000000..aec54ef
--- /dev/null
+++ b/fcp/testing/testing.cc
@@ -0,0 +1,264 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/testing/testing.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <filesystem>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/base_name.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/testing/tracing_schema.h"
+#include "fcp/tracing/tracing_span.h"
+
+namespace fcp {
+
+std::string TestName() {
+ auto test_info = testing::UnitTest::GetInstance()->current_test_info();
+ return absl::StrReplaceAll(test_info->name(), {{"/", "_"}});
+}
+
+std::string TestCaseName() {
+ auto test_info = testing::UnitTest::GetInstance()->current_test_info();
+ return absl::StrReplaceAll(test_info->test_case_name(), {{"/", "_"}});
+}
+
+std::string GetTestDataPath(absl::string_view relative_path) {
+ auto env = getenv("TEST_SRCDIR");
+ std::string test_srcdir = env ? env : "";
+ return ConcatPath(test_srcdir, ConcatPath("com_google_fcp", relative_path));
+}
+
+std::string TemporaryTestFile(absl::string_view suffix) {
+ return ConcatPath(StripTrailingPathSeparator(testing::TempDir()),
+ absl::StrCat(TestName(), suffix));
+}
+
+namespace {
+
+absl::Status EnsureDirExists(absl::string_view path) {
+ if (FileExists(path)) {
+ return absl::OkStatus();
+ }
+ auto path_str = std::string(path);
+ int error;
+#ifndef _WIN32
+ error = mkdir(path_str.c_str(), 0733);
+#else
+ error = _mkdir(path_str.c_str());
+#endif
+ if (error) {
+ return absl::InternalError(absl::StrCat(
+ "cannot create directory ", path_str, "(error code ", error, ")"));
+ }
+ return absl::OkStatus();
+}
+
+} // namespace
+
+bool ShouldUpdateBaseline() {
+ return getenv("FCP_UPDATE_BASELINE");
+}
+
+namespace {
+
+std::string MakeTempFileName() {
+#ifdef __APPLE__
+// Apple has marked tmpnam as deprecated. As we are compiling with -Werror,
+// turning this off for this case. Apple recommends to use mkstemp instead,
+// but because this opens a file, it's not exactly what we want, and it's not
+// portable. std::filesystem in C++17 should fix this issue.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+ return tmpnam(nullptr);
+#ifdef __APPLE__
+#pragma clang diagnostic pop
+#endif
+}
+
+absl::Status ShellCommand(absl::string_view command, std::string* stdout_result,
+ std::string* stderr_result) {
+#ifdef _WIN32
+ return absl::UnimplementedError("ShellCommand not implemented for Windows");
+#else
+ // Prepare command for output redirection.
+ std::string command_str = std::string(command);
+ std::string stdout_file;
+ if (stdout_result != nullptr) {
+ stdout_file = MakeTempFileName();
+ absl::StrAppend(&command_str, " 1>", stdout_file);
+ }
+ std::string stderr_file;
+ if (stderr_result != nullptr) {
+ stderr_file = MakeTempFileName();
+ absl::StrAppend(&command_str, " 2>", stderr_file);
+ }
+
+ // Call the command.
+ int result = std::system(command_str.c_str());
+
+ // Read and remove redirected output.
+ if (stdout_result != nullptr) {
+ auto status_or_result = ReadFileToString(stdout_file);
+ if (status_or_result.ok()) {
+ *stdout_result = status_or_result.value();
+ std::remove(stdout_file.c_str());
+ } else {
+ *stdout_result = "";
+ }
+ }
+ if (stderr_result != nullptr) {
+ auto status_or_result = ReadFileToString(stderr_file);
+ if (status_or_result.ok()) {
+ *stderr_result = status_or_result.value();
+ std::remove(stderr_file.c_str());
+ } else {
+ *stderr_result = "";
+ }
+ }
+
+ // Construct result.
+ if (result != 0) {
+ return absl::InternalError(absl::StrCat(
+ "command execution failed: ", command_str, " returns ", result));
+ } else {
+ return absl::OkStatus();
+ }
+#endif
+}
+
+} // namespace
+
+absl::StatusOr<std::string> ComputeDiff(absl::string_view baseline_file,
+ absl::string_view content) {
+ std::string diff_result;
+ std::string baseline_file_str = GetTestDataPath(baseline_file);
+ if (!FileExists(baseline_file_str)) {
+ diff_result = absl::StrCat("no recorded baseline file ", baseline_file_str);
+ } else {
+#ifndef _WIN32
+ // Expect Unix diff command to be available.
+ auto provided_file = TemporaryTestFile(".provided");
+ auto status = WriteStringToFile(provided_file, content);
+ if (!status.ok()) {
+ return status;
+ }
+ std::string std_out, std_err;
+ status = ShellCommand(
+ absl::StrCat("diff -u ", baseline_file_str, " ", provided_file),
+ &std_out, &std_err);
+ std::remove(provided_file.c_str());
+ if (status.code() != OK) {
+ if (!std_err.empty()) {
+ // Indicates a failure in diff execution itself.
+ return absl::InternalError(absl::StrCat("command failed: ", std_err));
+ }
+ diff_result = std_out;
+ }
+#else // _WIN32
+ // For now we do a simple string compare on Windows.
+ auto status_or_string = ReadFileToString(baseline_file_str);
+ if (!status_or_string.ok()) {
+ return status_or_string.status();
+ }
+ if (status_or_string.value() != content) {
+ diff_result = "baseline and actual differ (see respective files)";
+ }
+#endif
+ }
+ return diff_result;
+}
+
+StatusOr<std::string> VerifyAgainstBaseline(absl::string_view baseline_file,
+ absl::string_view content) {
+ auto status_or_diff_result = ComputeDiff(baseline_file, content);
+ if (!status_or_diff_result.ok()) {
+ return status_or_diff_result;
+ }
+ auto& diff_result = status_or_diff_result.value();
+ if (diff_result.empty()) {
+ // success
+ return status_or_diff_result;
+ }
+
+ // Determine the location where to store the new baseline.
+ std::string new_baseline_file;
+ bool auto_update = false;
+
+ if (new_baseline_file.empty() && ShouldUpdateBaseline()) {
+ new_baseline_file = GetTestDataPath(baseline_file);
+ diff_result =
+ absl::StrCat("\nAutomatically updated baseline file: ", baseline_file);
+ auto_update = true;
+ }
+
+ if (new_baseline_file.empty()) {
+ // Store new baseline file in a TMP location.
+#ifndef _WIN32
+ const char* temp_dir = "/tmp";
+#else
+ const char* temp_dir = getenv("TEMP");
+#endif
+ auto temp_output_dir =
+ ConcatPath(temp_dir, absl::StrCat("fcp_", TestCaseName()));
+ FCP_CHECK_STATUS(EnsureDirExists(temp_output_dir));
+ new_baseline_file = ConcatPath(temp_output_dir, BaseName(baseline_file));
+ absl::StrAppend(&diff_result, "\nNew baseline file: ", new_baseline_file);
+ absl::StrAppend(&diff_result, "\nTo update, use:");
+ absl::StrAppend(&diff_result, "\n\n cp ", new_baseline_file, " ",
+ baseline_file, "\n");
+ }
+
+ if (!auto_update) {
+ absl::StrAppend(&diff_result,
+ "\nTo automatically update baseline files, use");
+ absl::StrAppend(&diff_result,
+ "\nenvironment variable FCP_UPDATE_BASELINE.");
+ }
+
+ // Write the new baseline.
+ auto status = WriteStringToFile(new_baseline_file, content);
+ if (!status.ok()) {
+ return status;
+ }
+
+ // Deliver result.
+ if (auto_update) {
+ FCP_LOG(INFO) << diff_result;
+ diff_result = ""; // make test pass
+ }
+ return diff_result;
+}
+
+StatusMatcher IsCode(StatusCode code) { return StatusMatcher(code); }
+StatusMatcher IsOk() { return IsCode(OK); }
+
+Error TraceTestError(SourceLocation loc) {
+ return TraceError<TestError>(loc.file_name(), loc.line());
+}
+
+} // namespace fcp
diff --git a/fcp/testing/testing.h b/fcp/testing/testing.h
new file mode 100644
index 0000000..2fa855d
--- /dev/null
+++ b/fcp/testing/testing.h
@@ -0,0 +1,240 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TESTING_TESTING_H_
+#define FCP_TESTING_TESTING_H_
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "google/protobuf/util/message_differencer.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/error.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/base/result.h"
+#include "fcp/base/source_location.h"
+#include "fcp/testing/result_matchers.h"
+
+#include "fcp/testing/parse_text_proto.h"
+
+// This file defines platform dependent utilities for testing,
+// based on the public version of googletest.
+
+namespace fcp {
+
+// A macro for use inside a GTest test that executes the provided code as a
+// function returning a Result and asserts that the return value is not an
+// Error.
+//
+// The code provided to the macro will be much like the code one would write in
+// the body of a regular test, with the differences being that the code must
+// return Result<Unit>, and only EXPECT_* statements are allowed, not ASSERT_*.
+//
+// This makes it possible to greatly simplify the test body by using FCP_TRY(),
+// rather than having to check in the test body that every return value of
+// Result type is not an error.
+//
+// Example:
+//
+// TEST(FooTest, GetFoo) {
+// FCP_EXPECT_NO_ERROR(
+// Foo foo = FCP_TRY(GetFoo());
+// EXPECT_TRUE(foo.HasBar());
+// return Unit{};
+// );
+// }
+#define FCP_EXPECT_NO_ERROR(test_contents) \
+ auto test_fn = []() -> Result<Unit> test_contents; \
+ ASSERT_THAT(test_fn(), testing::Not(IsError()))
+
+// Convenience macros for `EXPECT_THAT(s, IsOk())`, where `s` is either
+// a `Status` or a `StatusOr<T>`.
+// Old versions of the protobuf library define EXPECT_OK as well, so we only
+// conditionally define our version.
+#if !defined(EXPECT_OK)
+#define EXPECT_OK(result) EXPECT_THAT(result, fcp::IsOk())
+#endif
+#define ASSERT_OK(result) ASSERT_THAT(result, fcp::IsOk())
+
+/** Returns the current test's name. */
+std::string TestName();
+
+/**
+ * Gets path to a test data file based on a path relative to project root.
+ */
+std::string GetTestDataPath(absl::string_view relative_path);
+
+/**
+ * Creates a temporary file name with given suffix unique for the running test.
+ */
+std::string TemporaryTestFile(absl::string_view suffix);
+
+/**
+ * Verifies a provided content against an expected stored in a baseline file.
+ * Returns an empty string if both are identical, otherwise a diagnostic
+ * message for error reports.
+ *
+ * A return status of not ok indicates an operational error which made the
+ * comparison impossible.
+ *
+ * The baseline file name must be provided relative to the project root.
+ */
+StatusOr<std::string> VerifyAgainstBaseline(absl::string_view baseline_file,
+ absl::string_view content);
+
+/**
+ * Polymorphic matchers for Status or StatusOr on status code.
+ */
+template <typename T>
+bool IsCode(StatusOr<T> const& x, StatusCode code) {
+ return x.status().code() == code;
+}
+inline bool IsCode(Status const& x, StatusCode code) {
+ return x.code() == code;
+}
+
+template <typename T>
+class StatusMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ explicit StatusMatcherImpl(StatusCode code) : code_(code) {}
+ void DescribeTo(::std::ostream* os) const override {
+ *os << "is " << absl::StatusCodeToString(code_);
+ }
+ void DescribeNegationTo(::std::ostream* os) const override {
+ *os << "is not " << absl::StatusCodeToString(code_);
+ }
+ bool MatchAndExplain(
+ T x, ::testing::MatchResultListener* listener) const override {
+ return IsCode(x, code_);
+ }
+
+ private:
+ StatusCode code_;
+};
+
+class StatusMatcher {
+ public:
+ explicit StatusMatcher(StatusCode code) : code_(code) {}
+
+ template <typename T>
+ operator testing::Matcher<T>() const { // NOLINT
+ return ::testing::MakeMatcher(new StatusMatcherImpl<T>(code_));
+ }
+
+ private:
+ StatusCode code_;
+};
+
+StatusMatcher IsCode(StatusCode code);
+StatusMatcher IsOk();
+
+template <typename T>
+class ProtoMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ explicit ProtoMatcherImpl(const google::protobuf::Message& arg)
+ : arg_(CloneMessage(arg)) {}
+
+ explicit ProtoMatcherImpl(const std::string& arg) : arg_(ParseMessage(arg)) {}
+
+ void DescribeTo(::std::ostream* os) const override {
+ *os << "is " << arg_->DebugString();
+ }
+ void DescribeNegationTo(::std::ostream* os) const override {
+ *os << "is not " << arg_->DebugString();
+ }
+ bool MatchAndExplain(
+ T x, ::testing::MatchResultListener* listener) const override {
+ if (x.GetDescriptor()->full_name() != arg_->GetDescriptor()->full_name()) {
+ *listener << "Argument proto is of type "
+ << arg_->GetDescriptor()->full_name()
+ << " but expected proto of type "
+ << x.GetDescriptor()->full_name();
+ return false;
+ }
+
+ google::protobuf::util::MessageDifferencer differencer;
+ std::string reported_differences;
+ differencer.ReportDifferencesToString(&reported_differences);
+ if (!differencer.Compare(*arg_, x)) {
+ *listener << reported_differences;
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ static std::unique_ptr<google::protobuf::Message> CloneMessage(
+ const google::protobuf::Message& message) {
+ std::unique_ptr<google::protobuf::Message> copy_of_message =
+ absl::WrapUnique(message.New());
+ copy_of_message->CopyFrom(message);
+ return copy_of_message;
+ }
+
+ static std::unique_ptr<google::protobuf::Message> ParseMessage(
+ const std::string& proto_text) {
+ using V = std::remove_cv_t<std::remove_reference_t<T>>;
+ std::unique_ptr<V> message = std::make_unique<V>();
+ *message = PARSE_TEXT_PROTO(proto_text);
+ return message;
+ }
+
+ std::unique_ptr<google::protobuf::Message> arg_;
+};
+
+template <typename T>
+class ProtoMatcher {
+ public:
+ explicit ProtoMatcher(const T& arg) : arg_(arg) {}
+
+ template <typename U>
+ operator testing::Matcher<U>() const { // NOLINT
+ using V = std::remove_cv_t<std::remove_reference_t<U>>;
+ static_assert(std::is_base_of<google::protobuf::Message, V>::value &&
+ !std::is_same<google::protobuf::Message, V>::value);
+ return ::testing::MakeMatcher(new ProtoMatcherImpl<U>(arg_));
+ }
+
+ private:
+ T arg_;
+};
+
+// Proto matcher that takes another proto message reference as an argument.
+template <class T,
+ typename std::enable_if<std::is_base_of<google::protobuf::Message, T>::value &&
+ !std::is_same<google::protobuf::Message, T>::value,
+ int>::type = 0>
+inline ProtoMatcher<T> EqualsProto(const T& arg) {
+ return ProtoMatcher<T>(arg);
+}
+
+// Proto matcher that takes a text proto as an argument.
+inline ProtoMatcher<std::string> EqualsProto(const std::string& arg) {
+ return ProtoMatcher<std::string>(arg);
+}
+
+// Utility function which creates and traces an instance of test error
+Error TraceTestError(SourceLocation loc = SourceLocation::current());
+
+} // namespace fcp
+
+#endif // FCP_TESTING_TESTING_H_
diff --git a/fcp/testing/testing_test.cc b/fcp/testing/testing_test.cc
new file mode 100644
index 0000000..24c3136
--- /dev/null
+++ b/fcp/testing/testing_test.cc
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2017 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fcp/testing/testing.h"
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/testing/test_messages.pb.h"
+
+ABSL_FLAG(std::string, baseline_path, "", "Path to baseline");
+
+namespace fcp {
+
+namespace {
+
+using ::testing::Not;
+
+TEST(TestingTest, TestName) { ASSERT_EQ(TestName(), "TestName"); }
+
+TEST(TestingTest, TestDataPath) {
+ auto path = GetTestDataPath(absl::GetFlag(FLAGS_baseline_path));
+ ASSERT_TRUE(FileExists(path));
+}
+
+TEST(TestingTest, TemporaryTestFile) {
+ auto path = TemporaryTestFile(".dat");
+ ASSERT_EQ(WriteStringToFile(path, "test").code(), OK);
+ ASSERT_EQ(ReadFileToString(path).value(), "test");
+}
+
+TEST(TestingTest, VerifyAgainstBaseline) {
+ auto status_or_diff = VerifyAgainstBaseline(
+ absl::GetFlag(FLAGS_baseline_path), "Dies ist ein Test.");
+ ASSERT_TRUE(status_or_diff.ok()) << status_or_diff.status();
+ if (!status_or_diff.value().empty()) {
+ FAIL() << status_or_diff.value();
+ }
+}
+
+TEST(TestingTest, VerifyAgainstBaselineFailure) {
+ auto status_or_diff = VerifyAgainstBaseline(
+ absl::GetFlag(FLAGS_baseline_path), "Dies ist kein Test.");
+ ASSERT_TRUE(status_or_diff.ok()) << status_or_diff.status();
+ // The actual output of the diff is much dependent on which mode we run
+ // in and on which platform. Hence only test whether *some* thing is reported.
+ ASSERT_FALSE(status_or_diff.value().empty());
+}
+
+TEST(TestingTest, EqualsProtoMessage) {
+ testing::Foo foo1;
+ foo1.set_foo("foo");
+ testing::Foo foo2;
+ foo2.set_foo("foo");
+ ASSERT_THAT(foo1, EqualsProto(foo2));
+}
+
+TEST(TestingTest, NotEqualsProtoMessage) {
+ testing::Foo foo1;
+ foo1.set_foo("foo-1");
+ testing::Foo foo2;
+ foo2.set_foo("foo-2");
+ ASSERT_THAT(foo1, Not(EqualsProto(foo2)));
+}
+
+TEST(TestingTest, NotEqualsProtoMessageType) {
+ testing::Foo foo;
+ foo.set_foo("foo");
+ testing::Bar bar;
+ bar.set_bar(1);
+ ASSERT_THAT(foo, Not(EqualsProto(bar)));
+}
+
+TEST(TestingTest, EqualsProtoMessageText) {
+ testing::Bar bar;
+ bar.set_bar(1);
+ ASSERT_THAT(bar, EqualsProto("bar: 1"));
+}
+
+TEST(TestingTest, NotEqualsProtoMessageText) {
+ testing::Bar bar;
+ bar.set_bar(1);
+ ASSERT_THAT(bar, Not(EqualsProto("bar: 2")));
+}
+
+} // namespace
+
+} // namespace fcp
diff --git a/fcp/testing/tracing_schema.fbs b/fcp/testing/tracing_schema.fbs
new file mode 100644
index 0000000..189ea72
--- /dev/null
+++ b/fcp/testing/tracing_schema.fbs
@@ -0,0 +1,20 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table TestError(tag: "TERR", error) {
+ file_name: string;
+ line: int32;
+}
diff --git a/fcp/tracing/BUILD b/fcp/tracing/BUILD
new file mode 100644
index 0000000..bf2b141
--- /dev/null
+++ b/fcp/tracing/BUILD
@@ -0,0 +1,130 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "tracing_schema_common.fbs",
+ "tracing_severity.h",
+ "tracing_tag.h",
+ "tracing_traits.h",
+])
+
+flatbuffer_cc_library(
+ name = "tracing_schema_common",
+ srcs = ["tracing_schema_common.fbs"],
+ srcs_filegroup_name = "tracing_schema_common_fbs",
+)
+
+cc_library(
+ name = "tracing",
+ srcs = [
+ "text_tracing_recorder_impl.cc",
+ "text_tracing_span_impl.cc",
+ "tracing_recorder_impl.cc",
+ "tracing_span_id.cc",
+ "tracing_span_impl.cc",
+ "tracing_span_ref.cc",
+ "tracing_traits.cc",
+ ],
+ hdrs = [
+ "scoped_tracing_recorder.h",
+ "text_tracing_recorder.h",
+ "text_tracing_recorder_impl.h",
+ "text_tracing_span_impl.h",
+ "tracing_recorder.h",
+ "tracing_recorder_impl.h",
+ "tracing_span.h",
+ "tracing_span_id.h",
+ "tracing_span_impl.h",
+ "tracing_span_ref.h",
+ "tracing_tag.h",
+ "tracing_traits.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":tracing_severity",
+ "//fcp/base",
+ "//fcp/base:error",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "test_tracing_recorder",
+ testonly = True,
+ srcs = [
+ "test_tracing_recorder.cc",
+ "test_tracing_recorder_impl.cc",
+ "test_tracing_span_impl.cc",
+ ],
+ hdrs = [
+ "test_tracing_recorder.h",
+ "test_tracing_recorder_impl.h",
+ "test_tracing_span_impl.h",
+ "tracing_traits.h",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":tracing",
+ ":tracing_severity",
+ "//fcp/base:source_location",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "tracing_severity",
+ hdrs = ["tracing_severity.h"],
+)
+
+cc_library(
+ name = "tracing_context_utils",
+ srcs = ["tracing_context_utils.cc"],
+ hdrs = ["tracing_context_utils.h"],
+ deps = [
+ ":tracing",
+ "//fcp/base",
+ "@com_google_protobuf//:protobuf",
+ "@flatbuffers",
+ ],
+)
+
+cc_test(
+ name = "tracing_tag_test",
+ srcs = [
+ "tracing_tag_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":tracing",
+ "@com_google_absl//absl/hash:hash_testing",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/tracing/build_defs.bzl b/fcp/tracing/build_defs.bzl
new file mode 100644
index 0000000..15dc227
--- /dev/null
+++ b/fcp/tracing/build_defs.bzl
@@ -0,0 +1,104 @@
+"""Build rule for tracing schemas to be used with fcp/tracing library.
+"""
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+
+def tracing_schema_cc_library(
+ name,
+ srcs,
+ includes = [],
+ visibility = None,
+ testonly = None):
+ """Rule to generate tracing-schema C++ files from .fbs file.
+
+ This macro produces the following output:
+
+ <name>: a cc_library including C++ code generated from .fbs file.
+
+ <srcs>.h: header file to be included by the code using the tracing schema.
+
+ <srcs>.bfbs: binary representation of the tracing schema.
+
+ <srcs>_generated.h: a header file produced by flatc. Note: this file is
+ already included by <srcs>.h, not to be consumed directly.
+
+ Args:
+ name: The label for the library, typically "tracing_schema".
+ srcs: Single .fbs source file, typically [ "tracing_schema.fbs" ].
+ includes: optional list of .fbs includes
+ visibility: standard visibility
+ testonly: standard testonly
+ """
+
+ # Validate only 1 src file is specified in the build rule.
+ if (len(srcs) != 1):
+ fail("Only 1 .fbs file can be specified per build rule.")
+
+ # Rule to invoke flatc on the fbs file (produces <name>_generated.h):
+ flatbuffer_cc_library(
+ name = name + "_fb",
+ srcs = srcs,
+ includes =
+ includes + ["//fcp/tracing:tracing_schema_common_fbs"],
+ flatc_args = [
+ "--gen-object-api",
+ "--gen-generated",
+ "--reflect-names",
+ "--bfbs-comments",
+ "--keep-prefix",
+ ],
+ gen_reflections = True,
+ include_paths = [".", "third_party/fcp/tracing"],
+ )
+
+ # Get generated flatbuff files from flatbuffer_cc_library rule.
+ src_bfbs = srcs[0].replace(".fbs", ".bfbs")
+ src_generated_h = srcs[0].replace(".fbs", "_generated.h")
+ src_generated_h_rootpath = "$(rootpath " + src_generated_h + ") "
+ src_bfbs_rootpath = "$(location " + src_bfbs + ") "
+ src_fbs_rootpath = "$(rootpath " + srcs[0] + ")"
+ out_header = srcs[0].replace(".fbs", ".h")
+
+ # Generating <name>.h with additional traits
+ native.genrule(
+ name = name + "_h",
+ srcs = [
+ src_bfbs,
+ src_generated_h,
+ srcs[0],
+ ],
+ outs = [out_header],
+ cmd = ("$(location //fcp/tracing/tools:tracing_traits_generator) " +
+ src_generated_h_rootpath + src_bfbs_rootpath + src_fbs_rootpath + "> $@"),
+ tools = [
+ "//fcp/tracing/tools:tracing_traits_generator",
+ ],
+ )
+
+ # Packaging everything into cc_library:
+ native.cc_library(
+ name = name,
+ hdrs = [
+ ":" + out_header,
+ "//fcp/tracing:tracing_schema_common_generated.h",
+ "//fcp/tracing:tracing_severity.h",
+ "//fcp/tracing:tracing_tag.h",
+ "//fcp/tracing:tracing_traits.h",
+ ],
+ deps = [
+ ":" + name + "_fb",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers//:flatbuffers",
+ "//fcp/base",
+ ],
+ data = [
+ srcs[0],
+ "//fcp/tracing:tracing_schema_common.fbs",
+ ],
+ copts = FCP_COPTS,
+ visibility = visibility,
+ testonly = testonly,
+ )
diff --git a/fcp/tracing/scoped_tracing_recorder.h b/fcp/tracing/scoped_tracing_recorder.h
new file mode 100644
index 0000000..1d84f34
--- /dev/null
+++ b/fcp/tracing/scoped_tracing_recorder.h
@@ -0,0 +1,46 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_SCOPED_TRACING_RECORDER_H_
+#define FCP_TRACING_SCOPED_TRACING_RECORDER_H_
+
+#include "fcp/tracing/tracing_recorder.h"
+
+namespace fcp {
+
+// This is an utility class that installs a specified tracing recorder as
+// thread local and uninstalls it automatically when going out of scope.
+class ScopedTracingRecorder {
+ public:
+ explicit ScopedTracingRecorder(TracingRecorder* tracing_recorder)
+ : tracing_recorder_(tracing_recorder) {
+ tracing_recorder_->InstallAsThreadLocal();
+ }
+
+ ~ScopedTracingRecorder() { tracing_recorder_->UninstallAsThreadLocal(); }
+
+ // This class isn't copyable or moveable and can't be created via
+ // new operator.
+ ScopedTracingRecorder(const ScopedTracingRecorder& other) = delete;
+ ScopedTracingRecorder& operator=(const ScopedTracingRecorder& other) = delete;
+ void* operator new(std::size_t) = delete;
+ void* operator new[](std::size_t) = delete;
+
+ private:
+ TracingRecorder* tracing_recorder_;
+};
+
+} // namespace fcp
+
+#endif // FCP_TRACING_SCOPED_TRACING_RECORDER_H_
diff --git a/fcp/tracing/test/BUILD b/fcp/tracing/test/BUILD
new file mode 100644
index 0000000..0f8ce30
--- /dev/null
+++ b/fcp/tracing/test/BUILD
@@ -0,0 +1,99 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+load("//fcp/tracing:build_defs.bzl", "tracing_schema_cc_library")
+
+package(
+ default_visibility = ["//fcp:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+tracing_schema_cc_library(
+ name = "tracing_schema",
+ srcs = ["tracing_schema.fbs"],
+)
+
+cc_test(
+ name = "tracing_test",
+ srcs = [
+ "tracing_test.cc",
+ ],
+ copts = FCP_COPTS,
+ deps = [
+ ":tracing_schema",
+ "//fcp/tracing:test_tracing_recorder",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers",
+ ],
+)
+
+cc_test(
+ name = "text_tracing_test",
+ srcs = ["text_tracing_test.cc"],
+ copts = FCP_COPTS,
+ data = glob([
+ "testdata/*.baseline",
+ ]),
+ deps = [
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/testing",
+ "//fcp/tracing",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+cc_test(
+ name = "tracing_context_utils_test",
+ srcs = ["tracing_context_utils_test.cc"],
+ copts = FCP_COPTS,
+ deps = [
+ ":test_api_message_cc_proto",
+ "//fcp/tracing:tracing_context_utils",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "thread_local_tracing_recorder_test",
+ srcs = ["thread_local_tracing_recorder_test.cc"],
+ copts = FCP_COPTS,
+ data = glob([
+ "testdata/*.baseline",
+ ]),
+ deps = [
+ ":tracing_schema",
+ "//fcp/base",
+ "//fcp/base:scheduler",
+ "//fcp/testing",
+ "//fcp/tracing",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_googletest//:gtest_main",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+proto_library(
+ name = "test_api_message_proto",
+ srcs = ["test_api_message.proto"],
+)
+
+cc_proto_library(
+ name = "test_api_message_cc_proto",
+ deps = [":test_api_message_proto"],
+)
diff --git a/fcp/tracing/test/test_api_message.proto b/fcp/tracing/test/test_api_message.proto
new file mode 100644
index 0000000..a11e19b
--- /dev/null
+++ b/fcp/tracing/test/test_api_message.proto
@@ -0,0 +1,41 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package fcp.tracing.test;
+
+message ApiMessageWithoutContext {
+ string contents = 1;
+}
+
+message ApiMessageWithContext {
+ string contents = 1;
+ string tracing_context = 2;
+}
+
+message ApiMessageWithContextBytes {
+ string contents = 1;
+ bytes tracing_context = 2;
+}
+
+message ApiMessageWithContextInt {
+ string contents = 1;
+ int32 tracing_context = 2;
+}
+
+message TestTracingContext {
+ int32 first = 1;
+ int32 second = 2;
+}
diff --git a/fcp/tracing/test/testdata/Basic.baseline b/fcp/tracing/test/testdata/Basic.baseline
new file mode 100644
index 0000000..d802525
--- /dev/null
+++ b/fcp/tracing/test/testdata/Basic.baseline
@@ -0,0 +1,18 @@
+${TIME} 0: BEGIN
+${TIME} 0: INFO EventFoo{ first: 10, second: 20 }
+${TIME} 1: BEGIN SpanWithId { id: 111 } parent: 0
+${TIME} 1: INFO EventFoo{ first: 222, second: 333 }
+${TIME} 1: ERROR ErrorEvent{ cause: "Oops!" }
+${TIME} 2: BEGIN SpanWithId { id: 999 } parent: 1
+${TIME} 2: INFO EventFoo{ first: 555, second: 666 }
+${TIME} 2: END SpanWithId { id: 999 }
+${TIME} 1: END SpanWithId { id: 111 }
+${TIME} 3: BEGIN SpanWithNoData { } parent: 0
+${TIME} 3: WARNING EventWithNoData{ }
+${TIME} 3: END SpanWithNoData { }
+${TIME} 4: BEGIN SpanWithNoData { } parent: 0
+${TIME} 5: BEGIN SpanWithId { id: 222 } parent: 4
+${TIME} 5: INFO EventBar{ first: 333, second: "Hello world!" }
+${TIME} 5: END SpanWithId { id: 222 }
+${TIME} 4: END SpanWithNoData { }
+${TIME} 0: END
diff --git a/fcp/tracing/test/testdata/ChangeThreadLocal1.baseline b/fcp/tracing/test/testdata/ChangeThreadLocal1.baseline
new file mode 100644
index 0000000..acedfaf
--- /dev/null
+++ b/fcp/tracing/test/testdata/ChangeThreadLocal1.baseline
@@ -0,0 +1,4 @@
+${TIME} ${ID}: BEGIN
+${TIME} ${ID}: BEGIN SpanWithId { id: 1 } parent: 0
+${TIME} ${ID}: END SpanWithId { id: 1 }
+${TIME} ${ID}: END
diff --git a/fcp/tracing/test/testdata/ChangeThreadLocal2.baseline b/fcp/tracing/test/testdata/ChangeThreadLocal2.baseline
new file mode 100644
index 0000000..ea35e9d
--- /dev/null
+++ b/fcp/tracing/test/testdata/ChangeThreadLocal2.baseline
@@ -0,0 +1,4 @@
+${TIME} ${ID}: BEGIN
+${TIME} ${ID}: BEGIN SpanWithId { id: 2 } parent: 0
+${TIME} ${ID}: END SpanWithId { id: 2 }
+${TIME} ${ID}: END
diff --git a/fcp/tracing/test/testdata/PerThread1.baseline b/fcp/tracing/test/testdata/PerThread1.baseline
new file mode 100644
index 0000000..d3589c8
--- /dev/null
+++ b/fcp/tracing/test/testdata/PerThread1.baseline
@@ -0,0 +1,9 @@
+${TIME} ${ID}: BEGIN
+${TIME} ${ID}: BEGIN SpanWithId { id: 1 } parent: 0
+${TIME} ${ID}: INFO EventFoo{ first: 11, second: 111 }
+${TIME} ${ID}: INFO EventFoo{ first: 11, second: 111 }
+${TIME} ${ID}: INFO EventFoo{ first: 11, second: 111 }
+${TIME} ${ID}: INFO EventFoo{ first: 11, second: 111 }
+${TIME} ${ID}: INFO EventFoo{ first: 11, second: 111 }
+${TIME} ${ID}: END SpanWithId { id: 1 }
+${TIME} ${ID}: END
diff --git a/fcp/tracing/test/testdata/PerThread2.baseline b/fcp/tracing/test/testdata/PerThread2.baseline
new file mode 100644
index 0000000..fc13939
--- /dev/null
+++ b/fcp/tracing/test/testdata/PerThread2.baseline
@@ -0,0 +1,9 @@
+${TIME} ${ID}: BEGIN
+${TIME} ${ID}: BEGIN SpanWithId { id: 2 } parent: 0
+${TIME} ${ID}: INFO EventFoo{ first: 22, second: 222 }
+${TIME} ${ID}: INFO EventFoo{ first: 22, second: 222 }
+${TIME} ${ID}: INFO EventFoo{ first: 22, second: 222 }
+${TIME} ${ID}: INFO EventFoo{ first: 22, second: 222 }
+${TIME} ${ID}: INFO EventFoo{ first: 22, second: 222 }
+${TIME} ${ID}: END SpanWithId { id: 2 }
+${TIME} ${ID}: END
diff --git a/fcp/tracing/test/text_tracing_test.cc b/fcp/tracing/test/text_tracing_test.cc
new file mode 100644
index 0000000..a4af38e
--- /dev/null
+++ b/fcp/tracing/test/text_tracing_test.cc
@@ -0,0 +1,99 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fstream>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/test/tracing_schema.h"
+#include "fcp/tracing/text_tracing_recorder.h"
+#include "re2/re2.h"
+
+namespace fcp {
+namespace {
+
+constexpr char kBaselineDir[] = "fcp/tracing/test/testdata";
+
+bool PostProcessOutput(std::string* input) {
+ RE2 timestamp_pattern("\\d{4}-\\d{2}-\\d{2}T[[:^blank:]]*");
+ return RE2::GlobalReplace(input, timestamp_pattern, "${TIME}") > 0;
+}
+
+TEST(Tracing, Basic) {
+ std::string out_file =
+ ConcatPath(testing::TempDir(), absl::StrCat(TestName(), ".out"));
+ {
+ TextTracingRecorder p(out_file, absl::UTCTimeZone());
+ p.InstallAsGlobal();
+ Trace<EventFoo>(10, 20);
+ {
+ TracingSpan<SpanWithId> inner(111);
+ Trace<EventFoo>(222, 333);
+ auto ignored = TraceError<ErrorEvent>("Oops!");
+ (void)ignored;
+ {
+ TracingSpan<SpanWithId> very_inner(999);
+ Trace<EventFoo>(555, 666);
+ }
+ }
+ {
+ TracingSpan<SpanWithNoData> inner;
+ Trace<EventWithNoData>();
+ }
+ {
+ auto long_running_span =
+ std::make_unique<UnscopedTracingSpan<SpanWithNoData>>(
+ TracingSpanRef::Top());
+ TracingSpan<SpanWithId> foo_inner(long_running_span->Ref(), 222);
+ Trace<EventBar>(333, "Hello world!");
+ }
+ }
+
+ // Reading out file
+ std::string report = ReadFileToString(out_file).value();
+ ASSERT_TRUE(PostProcessOutput(&report));
+ // Producing report which is expected to precisely match .baseline file.
+ std::ostringstream expected;
+ expected << "" << std::endl;
+
+ // Compare produced report with baseline.
+ std::string baseline_path =
+ ConcatPath(kBaselineDir, absl::StrCat(TestName(), ".baseline"));
+ auto status_s = VerifyAgainstBaseline(baseline_path, report);
+ ASSERT_TRUE(status_s.ok()) << status_s.status();
+ auto& diff = status_s.value();
+ if (!diff.empty()) {
+ FAIL() << diff;
+ }
+}
+
+TEST(Tracing, TimestampReplace) {
+ std::string timestamp = "2019-10-24T22:07:07.916321247+00:00";
+ ASSERT_TRUE(PostProcessOutput(&timestamp));
+ ASSERT_EQ(timestamp, "${TIME}");
+}
+
+TEST(Tracing, DefaultProvider) {
+ // This just triggers default stderr logging codepath, without verifying it
+ Trace<EventBar>(444, "Hello world!");
+ TracingSpan<SpanWithId> inner(111);
+ Trace<EventFoo>(222, 333);
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/test/thread_local_tracing_recorder_test.cc b/fcp/tracing/test/thread_local_tracing_recorder_test.cc
new file mode 100644
index 0000000..8bceaa4
--- /dev/null
+++ b/fcp/tracing/test/thread_local_tracing_recorder_test.cc
@@ -0,0 +1,191 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/base/scheduler.h"
+#include "fcp/testing/testing.h"
+#include "fcp/tracing/scoped_tracing_recorder.h"
+#include "fcp/tracing/test/tracing_schema.h"
+#include "fcp/tracing/text_tracing_recorder.h"
+#include "re2/re2.h"
+
+constexpr char kBaselineDir[] = "fcp/tracing/test/testdata";
+
+namespace fcp {
+namespace {
+
+// Replaces timestamp with ${TIME} and span ID with ${ID} in text trace output.
+// Span IDs need to be replaced because of the lack of determinism in running
+// multiple threads in parallel.
+inline bool PostProcessOutput(std::string* input) {
+ RE2 timestamp_and_id_pattern("\\d{4}-\\d{2}-\\d{2}T[[:^blank:]]*\\s\\d+");
+ return RE2::GlobalReplace(input, timestamp_and_id_pattern, "${TIME} ${ID}") >
+ 0;
+}
+
+std::string GetOutFileName(int id) {
+ return ConcatPath(testing::TempDir(), absl::StrCat(TestName(), id, ".out"));
+}
+
+std::string GetBaselineFileName(int id) {
+ return ConcatPath(kBaselineDir, absl::StrCat(TestName(), id, ".baseline"));
+}
+
+absl::StatusOr<std::string> VerifyAgainstBaseline(int id) {
+ // Reading out file
+ std::string report = ReadFileToString(GetOutFileName(id)).value();
+ EXPECT_TRUE(PostProcessOutput(&report));
+ // Producing report which is expected to precisely match .baseline file.
+ std::ostringstream expected;
+ expected << "" << std::endl;
+
+ // Compare produced report with baseline.
+ std::string baseline_path = GetBaselineFileName(id);
+ return ::fcp::VerifyAgainstBaseline(baseline_path, report);
+}
+
+// Verifies that thread local tracing recorder can be changed on the same
+// thread.
+TEST(Tracing, ChangeThreadLocal) {
+ const int kCount = 2;
+ for (int i = 0; i < kCount; i++) {
+ const int id = i + 1;
+ TextTracingRecorder local_recorder(GetOutFileName(id), absl::UTCTimeZone());
+ ScopedTracingRecorder scoped_recorder(&local_recorder);
+ TracingSpan<SpanWithId> inner(id);
+ }
+
+ for (int i = 0; i < kCount; i++) {
+ const int id = i + 1;
+ auto status_s = VerifyAgainstBaseline(id);
+ ASSERT_TRUE(status_s.ok()) << status_s.status();
+ auto& diff = status_s.value();
+ if (!diff.empty()) {
+ FAIL() << diff;
+ }
+ }
+}
+
+TEST(Tracing, PerThread) {
+ const int kThreadCount = 2;
+ auto scheduler = CreateThreadPoolScheduler(kThreadCount);
+
+ for (int i = 0; i < kThreadCount; i++) {
+ scheduler->Schedule([&, i]() {
+ const int id = i + 1;
+ TextTracingRecorder local_recorder(GetOutFileName(id),
+ absl::UTCTimeZone());
+ ScopedTracingRecorder scoped_recorder(&local_recorder);
+ TracingSpan<SpanWithId> inner(id);
+ for (int k = 0; k < 5; k++) {
+ absl::SleepFor(absl::Milliseconds(10));
+ Trace<EventFoo>(id * 11, id * 111);
+ }
+ });
+ }
+
+ scheduler->WaitUntilIdle();
+
+ for (int i = 0; i < kThreadCount; i++) {
+ const int id = i + 1;
+ auto status_s = VerifyAgainstBaseline(id);
+ ASSERT_TRUE(status_s.ok()) << status_s.status();
+ auto& diff = status_s.value();
+ if (!diff.empty()) {
+ FAIL() << diff;
+ }
+ }
+}
+
+TEST(Tracing, UninstallRequired) {
+ auto local_recorder =
+ std::make_shared<TextTracingRecorder>(absl::UTCTimeZone());
+ local_recorder->InstallAsThreadLocal();
+ ASSERT_DEATH(
+ local_recorder.reset(),
+ "Trace recorder must not be set as thread local at destruction time");
+ // Note that ASSERT_DEATH statement above runs in a separate process so it is
+ // still OK to uninstall the trace recorder here; otherwise this process
+ // would crash too on destruction of the trace recorder.
+ local_recorder->UninstallAsThreadLocal();
+}
+
+// Tests that setting the same tracing recorder is OK and that the number of
+// InstallAsThreadLocal and UninstallAsThreadLocal must be matching.
+TEST(Tracing, ReentrancySuccess) {
+ auto local_recorder =
+ std::make_shared<TextTracingRecorder>(absl::UTCTimeZone());
+ local_recorder->InstallAsThreadLocal();
+ local_recorder->InstallAsThreadLocal();
+ local_recorder->UninstallAsThreadLocal();
+ local_recorder->UninstallAsThreadLocal();
+}
+
+// Verifies that not matching the number of InstallAsThreadLocal with
+// UninstallAsThreadLocal results in a failure.
+TEST(Tracing, ReentrancyFailure) {
+ auto local_recorder =
+ std::make_shared<TextTracingRecorder>(absl::UTCTimeZone());
+ local_recorder->InstallAsThreadLocal();
+ // This simulates re-entracy by setting the same tracing recorder as
+ // thread local again.
+ local_recorder->InstallAsThreadLocal();
+ local_recorder->UninstallAsThreadLocal();
+ // At this point UninstallAsThreadLocal has been called only once, which isn't
+ // sufficient.
+ ASSERT_DEATH(
+ local_recorder.reset(),
+ "Trace recorder must not be set as thread local at destruction time");
+ // Note that ASSERT_DEATH statement above runs in a separate process so it is
+ // still necessary to uninstall the trace recorder here to make sure that
+ // the test doesn't crash in the main test process.
+ local_recorder->UninstallAsThreadLocal();
+}
+
+// Test that changing per-thread tracing recorder isn't allowed without
+// uninstalling first.
+TEST(Tracing, ChangingThreadLocalRecorderFails) {
+ TextTracingRecorder local_recorder1(absl::UTCTimeZone());
+ TextTracingRecorder local_recorder2(absl::UTCTimeZone());
+ local_recorder1.InstallAsThreadLocal();
+ ASSERT_DEATH(local_recorder2.InstallAsThreadLocal(),
+ "Only one tracing recorder instance per thread is supported");
+ // Note that ASSERT_DEATH statement above runs in a separate process so
+ // uninstalling local_recorder1 is still needed in the main test process.
+ local_recorder1.UninstallAsThreadLocal();
+}
+
+TEST(Tracing, UninstallingWrongThreadLocalRecorderFails) {
+ TextTracingRecorder local_recorder1(absl::UTCTimeZone());
+ TextTracingRecorder local_recorder2(absl::UTCTimeZone());
+ local_recorder1.InstallAsThreadLocal();
+ ASSERT_DEATH(local_recorder2.UninstallAsThreadLocal(),
+ "Attempting to uninstall thread local tracing recorder that "
+ "isn't currently installed");
+ // Note that ASSERT_DEATH statement above runs in a separate process so
+ // uninstalling local_recorder1 is still needed in the main test process.
+ local_recorder1.UninstallAsThreadLocal();
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/test/tracing_context_utils_test.cc b/fcp/tracing/test/tracing_context_utils_test.cc
new file mode 100644
index 0000000..5da5efe
--- /dev/null
+++ b/fcp/tracing/test/tracing_context_utils_test.cc
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_context_utils.h"
+
+#include <iostream>
+#include <optional>
+#include <ostream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "fcp/tracing/test/test_api_message.pb.h"
+
+namespace fcp {
+namespace {
+
+using fcp::tracing::test::ApiMessageWithContext;
+using fcp::tracing::test::ApiMessageWithContextBytes;
+using fcp::tracing::test::ApiMessageWithContextInt;
+using fcp::tracing::test::ApiMessageWithoutContext;
+using fcp::tracing::test::TestTracingContext;
+
+TEST(Tracing, SetAndRetrieveContextOnMessage) {
+ TestTracingContext original_context;
+ original_context.set_first(222);
+ original_context.set_second(333);
+
+ ApiMessageWithContext message;
+ fcp::tracing_internal::SetTracingContextOnMessage(original_context, message);
+
+ TestTracingContext context =
+ fcp::tracing_internal::GetContextFromMessage<TestTracingContext>(message);
+
+ EXPECT_EQ(context.first(), 222);
+ EXPECT_EQ(context.second(), 333);
+}
+
+TEST(Tracing, SetAndRetrieveContextBytesOnMessage) {
+ TestTracingContext original_context;
+ original_context.set_first(222);
+ original_context.set_second(333);
+
+ ApiMessageWithContextBytes message;
+ fcp::tracing_internal::SetTracingContextOnMessage(original_context, message);
+
+ TestTracingContext context =
+ fcp::tracing_internal::GetContextFromMessage<TestTracingContext>(message);
+
+ EXPECT_EQ(context.first(), 222);
+ EXPECT_EQ(context.second(), 333);
+}
+
+TEST(Tracing, MessageWithoutContext) {
+ TestTracingContext original_context;
+ original_context.set_first(222);
+ ApiMessageWithoutContext message;
+ // Setting the context on a message without it will be a no-op.
+ fcp::tracing_internal::SetTracingContextOnMessage(original_context, message);
+
+ TestTracingContext context =
+ fcp::tracing_internal::GetContextFromMessage<TestTracingContext>(message);
+
+ EXPECT_EQ(context.first(), 0);
+ EXPECT_EQ(context.second(), 0);
+}
+
+TEST(Tracing, SetTracingContextOnMessageWithIntContextCheckFailure) {
+ TestTracingContext original_context;
+ ApiMessageWithContextInt message;
+ // Setting the context on a message with the wrong context type will be a
+ // no-op.
+ EXPECT_DEATH(fcp::tracing_internal::SetTracingContextOnMessage(
+ original_context, message),
+ fcp::tracing_internal::kContextWrongTypeMessage);
+}
+
+TEST(Tracing, GetTracingContextFromMessageWithIntContextCheckFailure) {
+ ApiMessageWithContextInt message;
+ EXPECT_DEATH(
+ TestTracingContext context =
+ fcp::tracing_internal::GetContextFromMessage<TestTracingContext>(
+ message),
+ fcp::tracing_internal::kContextWrongTypeMessage);
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/test/tracing_schema.fbs b/fcp/tracing/test/tracing_schema.fbs
new file mode 100644
index 0000000..52fc15b
--- /dev/null
+++ b/fcp/tracing/test/tracing_schema.fbs
@@ -0,0 +1,83 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table SpanWithId (tag: "SWID", span) {
+ id: int32;
+}
+
+table SpanWithNoData (tag: "SWND", span) {
+}
+
+table SpanNeverLogged (tag: "SNLG", span) {
+}
+
+table EventFoo (tag: "EFOO") {
+ first: int32;
+ second: int32;
+}
+
+table EventBar (tag: "EBAR") {
+ first: int32;
+ second: string;
+}
+
+table EventWithNoData (warning, tag: "EVND") {
+}
+
+table EventNeverLogged (tag: "ENLG") {
+}
+
+table ErrorEvent (error, tag: "EERR") {
+ cause: string;
+}
+
+table DeprecatedInt (tag: "DEPI") {
+ first: int32 (deprecated);
+ second: int32;
+}
+
+table AllTypes (tag: "ALLT") {
+ fieldz: byte;
+ fieldy: ubyte;
+ fieldx: bool;
+ fieldw: short;
+ fieldv: ushort;
+ fieldu: int;
+ fieldt: uint;
+ fields: float;
+ fieldr: long;
+ fieldq: ulong;
+ fieldp: double;
+ fieldo: string;
+}
+
+table FieldOrder (tag: "FORD") {
+ fieldz: int;
+ fieldy: int;
+ fieldx: string;
+}
+
+table OrderWithIds (tag: "ORDI") {
+ fieldz: int (id: 1);
+ fieldy: int (id: 2);
+ fieldx: string (id: 0);
+}
+
+enum Color : byte { Red = 0, Green = 1, Blue = 2 }
+
+table ColorEnum (tag: "CLEN") {
+ color: Color;
+}
diff --git a/fcp/tracing/test/tracing_test.cc b/fcp/tracing/test/tracing_test.cc
new file mode 100644
index 0000000..037e966
--- /dev/null
+++ b/fcp/tracing/test/tracing_test.cc
@@ -0,0 +1,393 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <thread> // NOLINT(build/c++11)
+
+#include "gtest/gtest.h"
+#include "fcp/tracing/test/tracing_schema.h"
+#include "fcp/tracing/test_tracing_recorder.h"
+#include "fcp/tracing/tracing_severity.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+namespace {
+
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::GetRoot;
+using testing::_;
+using testing::ElementsAre;
+using testing::Eq;
+using testing::Gt;
+using testing::Not;
+using testing::SizeIs;
+using testing::UnorderedElementsAre;
+
+TEST(Tracing, TraitsTag) {
+ EXPECT_EQ(TracingTraits<SpanWithId>::kTag.str(), "SWID");
+ EXPECT_EQ(TracingTraits<EventBar>::kTag.str(), "EBAR");
+}
+
+TEST(Tracing, TracingSeverity) {
+ EXPECT_EQ(TracingTraits<SpanWithId>::kSeverity, fcp::TracingSeverity::kInfo);
+ EXPECT_EQ(TracingTraits<ErrorEvent>::kSeverity, fcp::TracingSeverity::kError);
+ EXPECT_EQ(TracingTraits<EventWithNoData>::kSeverity,
+ fcp::TracingSeverity::kWarning);
+}
+
+TEST(Tracing, TraitsCreate) {
+ FlatBufferBuilder fbb_foo;
+ fbb_foo.Finish(TracingTraits<EventFoo>::Create(222, 333, &fbb_foo));
+ auto foo = GetRoot<EventFoo>(fbb_foo.GetBufferPointer());
+ EXPECT_EQ(foo->first(), 222);
+ EXPECT_EQ(foo->second(), 333);
+
+ // Creating a flat buffer involving a string field has different codegen
+ // path, testing this as well:
+ FlatBufferBuilder fbb_bar;
+ fbb_bar.Finish(
+ TracingTraits<EventBar>::Create(444, "Hello world!", &fbb_bar));
+ auto bar = GetRoot<EventBar>(fbb_bar.GetBufferPointer());
+ EXPECT_EQ(bar->first(), 444);
+ EXPECT_EQ(bar->second()->str(), "Hello world!");
+
+ // Also make sure that a flatbuf involving a string field can be created using
+ // a std::string.
+ FlatBufferBuilder fbb_baz;
+ std::string hello_str = "Hello world!";
+ fbb_baz.Finish(TracingTraits<EventBar>::Create(444, hello_str, &fbb_baz));
+ auto baz = GetRoot<EventBar>(fbb_baz.GetBufferPointer());
+ EXPECT_EQ(baz->first(), 444);
+ EXPECT_EQ(baz->second()->str(), "Hello world!");
+}
+
+TEST(Tracing, TraitsCreateFieldOrder) {
+ int first_i = -333;
+ int second_i = 444;
+ FlatBufferBuilder fbb_foo;
+ fbb_foo.Finish(
+ TracingTraits<FieldOrder>::Create(first_i, second_i, "hello", &fbb_foo));
+ auto foo = GetRoot<FieldOrder>(fbb_foo.GetBufferPointer());
+ EXPECT_EQ(foo->fieldz(), first_i);
+ EXPECT_EQ(foo->fieldy(), second_i);
+ EXPECT_EQ(foo->fieldx()->str(), "hello");
+
+ FlatBufferBuilder fbb_bar;
+ fbb_bar.Finish(TracingTraits<OrderWithIds>::Create("hello", first_i, second_i,
+ &fbb_bar));
+ auto bar = GetRoot<OrderWithIds>(fbb_bar.GetBufferPointer());
+ EXPECT_EQ(bar->fieldz(), first_i);
+ EXPECT_EQ(bar->fieldy(), second_i);
+ EXPECT_EQ(bar->fieldx()->str(), "hello");
+}
+
+TEST(Tracing, TraitsCreateAllTypes) {
+ FlatBufferBuilder fbb;
+ std::int8_t byte = -1;
+ std::uint8_t ubyte = 1;
+ std::int16_t short_i = -256;
+ std::uint16_t ushort_i = 256;
+ int i = -333;
+ unsigned int ui = 444;
+ float f = 1.1;
+ std::int64_t li = -4294967296;
+ std::uint64_t uli = 4294967296;
+ double d = 12312318.99999999;
+ fbb.Finish(TracingTraits<AllTypes>::Create(byte, ubyte, true, short_i,
+ ushort_i, i, ui, f, li, uli, d,
+ "hello", &fbb));
+ auto foo = GetRoot<AllTypes>(fbb.GetBufferPointer());
+ EXPECT_EQ(foo->fieldz(), byte);
+ EXPECT_EQ(foo->fieldy(), ubyte);
+ EXPECT_EQ(foo->fieldx(), true);
+ EXPECT_EQ(foo->fieldw(), short_i);
+ EXPECT_EQ(foo->fieldv(), ushort_i);
+ EXPECT_EQ(foo->fieldu(), i);
+ EXPECT_EQ(foo->fieldt(), ui);
+ EXPECT_EQ(foo->fields(), f);
+ EXPECT_EQ(foo->fieldr(), li);
+ EXPECT_EQ(foo->fieldq(), uli);
+ EXPECT_EQ(foo->fieldp(), d);
+ EXPECT_EQ(foo->fieldo()->str(), "hello");
+}
+
+TEST(Tracing, TraitsCreateEnum) {
+ FlatBufferBuilder fbb;
+ fbb.Finish(TracingTraits<ColorEnum>::Create(Color_Blue, &fbb));
+ auto foo = GetRoot<ColorEnum>(fbb.GetBufferPointer());
+ EXPECT_EQ(foo->color(), Color_Blue);
+}
+
+TEST(Tracing, TraitsCreateDeprecatedField) {
+ FlatBufferBuilder fbb_foo;
+ fbb_foo.Finish(TracingTraits<DeprecatedInt>::Create(222, &fbb_foo));
+ auto foo = GetRoot<EventFoo>(fbb_foo.GetBufferPointer());
+ EXPECT_EQ(foo->second(), 222);
+}
+
+TEST(Tracing, LookupTraitByTag) {
+ EXPECT_EQ(TracingTraitsBase::Lookup(TracingTag("SWID"))->Name(),
+ "SpanWithId");
+ EXPECT_EQ(TracingTraitsBase::Lookup(TracingTag("EBAR"))->Name(), "EventBar");
+}
+
+TEST(Tracing, IntegrationTest) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithId> inner(111);
+ Trace<EventFoo>(222, 333);
+ Trace<EventBar>(444, "Hello world!");
+ Trace<EventFoo>(555, 666);
+ }
+ {
+ TracingSpan<SpanWithNoData> inner;
+ Trace<EventWithNoData>();
+ }
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ ElementsAre(AllOf(IsSpan<SpanWithId>(),
+ ElementsAre(IsEvent<EventFoo>(222, 333),
+ IsEvent<EventBar>(444, "Hello world!"),
+ IsEvent<EventFoo>(555, 666))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventWithNoData>()))))
+ << "Tracing span/events structure and content must match";
+}
+
+TEST(Tracing, UnscopedSpanIntegrationTest) {
+ TestTracingRecorder tracing_recorder;
+ auto outer = std::make_unique<UnscopedTracingSpan<SpanWithId>>(111);
+ auto inner =
+ std::make_unique<UnscopedTracingSpan<SpanWithId>>(outer->Ref(), 222);
+ {
+ TracingSpan<SpanWithNoData> child_of_inner(inner->Ref());
+ Trace<EventFoo>(333, 444);
+ }
+ {
+ TracingSpan<SpanWithNoData> another_child_of_inner(inner->Ref());
+ Trace<EventFoo>(555, 666);
+ Trace<EventBar>(inner->Ref(), 1, "Trace in unscoped span!");
+ Trace<EventBar>(another_child_of_inner.Ref(), 1,
+ "Trace in explicitly specified tracing span!");
+ }
+ TracingSpan<SpanWithNoData> unrelated_span;
+ Trace<EventBar>(777, "Hello world!");
+
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ ElementsAre(
+ AllOf(IsSpan<SpanWithId>(111),
+ ElementsAre(AllOf(
+ IsSpan<SpanWithId>(222),
+ ElementsAre(
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(333, 444))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(555, 666),
+ IsEvent<EventBar>(
+ 1,
+ "Trace in explicitly specified "
+ "tracing span!"))),
+ IsEvent<EventBar>(1, "Trace in unscoped span!"))))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventBar>(777, "Hello world!")))))
+ << "Tracing span/events structure and content must match";
+}
+
+TEST(Tracing, ThreadingUnscopedIntegrationTest) {
+ TestTracingRecorder tracing_recorder;
+ auto outer = std::make_unique<UnscopedTracingSpan<SpanWithId>>(111);
+ std::thread thread1([ref = outer->Ref()]() {
+ TracingSpan<SpanWithNoData> child_of_outer(ref);
+ Trace<EventFoo>(333, 444);
+ });
+ std::thread thread2([ref = outer->Ref()]() {
+ TracingSpan<SpanWithNoData> another_child_of_outer(ref);
+ Trace<EventFoo>(555, 666);
+ Trace<EventBar>(ref, 1, "Trace in unscoped span!");
+ Trace<EventBar>(another_child_of_outer.Ref(), 1, "Trace in local span!");
+ });
+ TracingSpan<SpanWithNoData> unrelated_span;
+ Trace<EventBar>(777, "Hello world!");
+ thread1.join();
+ thread2.join();
+
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ ElementsAre(AllOf(IsSpan<SpanWithId>(111),
+ UnorderedElementsAre(
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(333, 444))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(555, 666),
+ IsEvent<EventBar>(
+ 1, "Trace in local span!"))),
+ IsEvent<EventBar>(1, "Trace in unscoped span!"))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventBar>(777, "Hello world!")))))
+ << "Tracing span/events structure and content must match";
+}
+
+TEST(Tracing, ThreadingScopedIntegrationTest) {
+ TestTracingRecorder tracing_recorder;
+ TracingSpan<SpanWithId> outer(111);
+ std::thread thread1([ref = outer.Ref()]() {
+ TracingSpan<SpanWithNoData> child_of_outer(ref);
+ Trace<EventFoo>(333, 444);
+ });
+ std::thread thread2([ref = outer.Ref()]() {
+ TracingSpan<SpanWithNoData> another_child_of_outer(ref);
+ Trace<EventFoo>(555, 666);
+ Trace<EventBar>(ref, 1, "Trace in unscoped span!");
+ Trace<EventBar>(1, "Trace in local span!");
+ });
+ TracingSpan<SpanWithNoData> unrelated_span;
+ Trace<EventBar>(777, "Hello world!");
+ thread1.join();
+ thread2.join();
+
+ EXPECT_THAT(
+ tracing_recorder.root(),
+ ElementsAre(AllOf(
+ IsSpan<SpanWithId>(111),
+ UnorderedElementsAre(
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(333, 444))),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventFoo>(555, 666),
+ IsEvent<EventBar>(1, "Trace in local span!"))),
+ IsEvent<EventBar>(1, "Trace in unscoped span!"),
+ AllOf(IsSpan<SpanWithNoData>(),
+ ElementsAre(IsEvent<EventBar>(777, "Hello world!")))))))
+ << "Tracing span/events structure and content must match";
+}
+
+TEST(Tracing, AdvancedMatching) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithId> span(111);
+ Trace<EventBar>(222, "Hello world!");
+ }
+
+ auto span = tracing_recorder.root()[0];
+ auto event = span[0];
+ EXPECT_THAT(span, IsSpan<SpanWithId>());
+ EXPECT_THAT(event, IsEvent<EventBar>());
+ EXPECT_THAT(span,
+ AllOf(IsSpan<SpanWithId>(), ElementsAre(IsEvent<EventBar>())));
+ EXPECT_THAT(span, IsSpan<SpanWithId>(_));
+ EXPECT_THAT(span, IsSpan<SpanWithId>(Eq(111)));
+ EXPECT_THAT(span, IsSpan<SpanWithId>(Gt(100)));
+ EXPECT_THAT(event, IsEvent<EventBar>(Eq(222), Eq("Hello world!")));
+ EXPECT_THAT(event, IsEvent<EventBar>(_, Eq("Hello world!")));
+ EXPECT_THAT(event, IsEvent<EventBar>(_, _));
+ EXPECT_THAT(event, IsEvent<EventBar>(Eq(222), _));
+ EXPECT_THAT(event, IsEvent<EventBar>(Not(Eq(666)), _));
+ EXPECT_THAT(event, Not(IsEvent<EventBar>(Eq(666), _)));
+ EXPECT_THAT(event, Not(IsEvent<EventFoo>()));
+}
+
+TEST(Tracing, MultipleRecorders) {
+ // NOTE: it is not a recommended scenario to have multiple instances of
+ // TestTracingRecorder per unit test, but this code path is enforced to
+ // ensure correct behavior of cleaning up global state so it is not carried
+ // over between tests.
+ {
+ TestTracingRecorder tracing_recorder;
+ Trace<EventFoo>(222, 333);
+ EXPECT_THAT(tracing_recorder.root()[0], IsEvent<EventFoo>(222, 333));
+ }
+ {
+ TestTracingRecorder tracing_recorder;
+ Trace<EventFoo>(444, 555);
+ EXPECT_THAT(tracing_recorder.root()[0], IsEvent<EventFoo>(444, 555));
+ }
+}
+
+TEST(Tracing, TraceError) {
+ TestTracingRecorder tracing_recorder;
+ tracing_recorder.ExpectError<ErrorEvent>();
+ {
+ TracingSpan<SpanWithId> inner(111);
+ Error err = TraceError<ErrorEvent>("there was a bug");
+ (void)err;
+ }
+}
+
+TEST(Tracing, FindOnlySpan) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithNoData> outer;
+ { TracingSpan<SpanWithId> inner(111); }
+ EXPECT_EQ(tracing_recorder.FindOnlySpan<SpanWithId>().data()->id(), 111);
+ }
+}
+
+TEST(Tracing, FindAllSpans) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithNoData> outer;
+ {
+ TracingSpan<SpanWithId> inner1(111);
+ { TracingSpan<SpanWithId> inner2(222); }
+ }
+ EXPECT_THAT(tracing_recorder.FindAllSpans<SpanWithId>(),
+ ElementsAre(IsSpan<SpanWithId>(), IsSpan<SpanWithId>()));
+ EXPECT_THAT(tracing_recorder.FindAllSpans<SpanNeverLogged>(), SizeIs(0));
+ }
+}
+
+TEST(Tracing, FindOnlyEvent) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithNoData> outer;
+ { Trace<EventFoo>(111, 222); }
+ EXPECT_EQ(tracing_recorder.FindOnlyEvent<EventFoo>().data()->first(), 111);
+ EXPECT_EQ(tracing_recorder.FindOnlyEvent<EventFoo>().data()->second(), 222);
+ }
+}
+
+TEST(Tracing, FindAllEvents) {
+ TestTracingRecorder tracing_recorder;
+ {
+ TracingSpan<SpanWithNoData> outer;
+ {
+ Trace<EventFoo>(111, 222);
+ TracingSpan<SpanWithNoData> inner;
+ { Trace<EventFoo>(333, 444); }
+ }
+ EXPECT_THAT(tracing_recorder.FindAllEvents<EventFoo>(),
+ ElementsAre(IsEvent<EventFoo>(), IsEvent<EventFoo>()));
+ EXPECT_THAT(tracing_recorder.FindAllEvents<EventNeverLogged>(), SizeIs(0));
+ }
+}
+TEST(Tracing, CreateJsonString) {
+ FlatBufferBuilder fbb_foo;
+ fbb_foo.Finish(TracingTraits<EventFoo>::Create(222, 333, &fbb_foo));
+ auto foo_buf = fbb_foo.GetBufferPointer();
+ auto foo = GetRoot<EventFoo>(fbb_foo.GetBufferPointer());
+ EXPECT_EQ(foo->first(), 222);
+ EXPECT_EQ(foo->second(), 333);
+
+ TracingTraits<EventFoo> tracing_traits;
+ std::string expected = "{\n first: 222,\n second: 333\n}\n";
+ std::string json_gen = tracing_traits.JsonStringFormat(foo_buf);
+ EXPECT_EQ(expected, json_gen);
+}
+
+// TODO(team) Add Testing for when the flatbuf has a package name
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/test_tracing_recorder.cc b/fcp/tracing/test_tracing_recorder.cc
new file mode 100644
index 0000000..cc0924d
--- /dev/null
+++ b/fcp/tracing/test_tracing_recorder.cc
@@ -0,0 +1,162 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/test_tracing_recorder.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/tracing/test_tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_severity.h"
+#include "fcp/tracing/tracing_tag.h"
+#include "fcp/tracing/tracing_traits.h"
+
+namespace fcp {
+
+TestTracingRecorder::TestTracingRecorder(SourceLocation loc)
+ : loc_(loc),
+ impl_(std::shared_ptr<tracing_internal::TestTracingRecorderImpl>(
+ new tracing_internal::TestTracingRecorderImpl(this))) {
+ InstallAsGlobal();
+}
+
+TestTracingRecorder::~TestTracingRecorder() {
+ UninstallAsGlobal();
+ std::vector<std::string> unseen_error_names;
+ for (TracingTag tag : unseen_expected_errors_) {
+ unseen_error_names.push_back(TracingTraitsBase::Lookup(tag)->Name());
+ }
+ EXPECT_THAT(unseen_error_names, testing::IsEmpty())
+ << "Errors marked as expected with TestTracingRecorder::ExpectEvent<E>() "
+ << "weren't traced via TestTracingRecorder with source location: "
+ << std::endl
+ << loc_.file_name() << ":" << loc_.line() << std::endl;
+}
+
+TracingTraitsBase const* TraitsForRecord(TestTracingRecorder::Record* record) {
+ return TracingTraitsBase::Lookup(*TracingTag::FromFlatbuf(record->data));
+}
+
+std::string Format(TracingTraitsBase const* traits,
+ TestTracingRecorder::Record* record) {
+ return absl::StrCat(traits->Name(), " ", traits->TextFormat(record->data));
+}
+
+void TestTracingRecorder::OnRoot(TracingSpanId id,
+ flatbuffers::DetachedBuffer data) {
+ auto unique_record = std::make_unique<Record>(id, std::move(data));
+ root_record_ = unique_record.get();
+ {
+ absl::MutexLock locked(&map_lock_);
+ FCP_CHECK(id_to_record_map_.empty());
+ id_to_record_map_[id] = std::move(unique_record);
+ }
+}
+
+void TestTracingRecorder::OnTrace(TracingSpanId parent_id, TracingSpanId id,
+ flatbuffers::DetachedBuffer data) {
+ auto unique_record = std::make_unique<Record>(parent_id, id, std::move(data));
+ // Entries in id_to_record_map_ will never be removed or modified. Due to this
+ // invariant we can be sure that a pointer to the record will not be
+ // invalidated by a concurrent modification of the map causing destruction of
+ // the record.
+ Record* record = unique_record.get();
+ {
+ absl::MutexLock locked(&map_lock_);
+ id_to_record_map_[id] = std::move(unique_record);
+ }
+ TracingTraitsBase const* traits = TraitsForRecord(record);
+ if (traits->Severity() == TracingSeverity::kError) {
+ absl::MutexLock locked(&expected_errors_lock_);
+ // We're interested here in errors only:
+ TracingTag error_tag = *TracingTag::FromFlatbuf(record->data);
+ EXPECT_TRUE(expected_errors_.contains(error_tag))
+ << "Unexpected error " << Format(traits, record)
+ << " is traced via TestTracingRecorder with source location: "
+ << std::endl
+ << loc_.file_name() << ":" << loc_.line() << std::endl
+ << "Use TracingRecorder::ExpectError<" << traits->Name()
+ << ">() in the beginning of the unit "
+ << "test to allowlist this error as expected" << std::endl;
+ unseen_expected_errors_.erase(error_tag);
+ }
+}
+
+TestTracingRecorder::SpanOrEvent::SpanOrEvent(TestTracingRecorder* recorder,
+ Record* record)
+ : record_(record), recorder_(recorder) {
+ std::vector<Record*> children = recorder_->GetChildren(record_);
+
+ children_.reserve(children.size());
+ for (Record* c : children) {
+ children_.emplace_back(recorder_, c);
+ }
+}
+
+std::string TestTracingRecorder::SpanOrEvent::TextFormat() const {
+ return Format(traits(), record_);
+}
+
+TracingTraitsBase const* TestTracingRecorder::SpanOrEvent::traits() const {
+ return TraitsForRecord(record_);
+}
+
+void TestTracingRecorder::InstallAsGlobal() { impl_->InstallAsGlobal(); }
+
+void TestTracingRecorder::UninstallAsGlobal() { impl_->UninstallAsGlobal(); }
+
+void TestTracingRecorder::InstallAsThreadLocal() {
+ impl_->InstallAsThreadLocal();
+}
+
+void TestTracingRecorder::UninstallAsThreadLocal() {
+ impl_->UninstallAsThreadLocal();
+}
+
+struct RecordComparison {
+ bool const operator()(TestTracingRecorder::Record* r1,
+ TestTracingRecorder::Record* r2) const {
+ return r1->id < r2->id;
+ }
+};
+
+std::vector<TestTracingRecorder::Record*> TestTracingRecorder::GetChildren(
+ Record* parent) {
+ std::vector<Record*> children;
+ {
+ absl::MutexLock locked(&map_lock_);
+ // Search through the entire hashmap for children of this parent.
+ // Note that this is O(n) in terms of the total number of traces, rather
+ // than in terms of the number of children.
+ for (const auto& [id, record_unique_ptr] : id_to_record_map_) {
+ Record* record = record_unique_ptr.get();
+ if (record->parent_id.has_value() &&
+ record->parent_id.value() == parent->id) {
+ children.push_back(record);
+ }
+ }
+ }
+ // Sort in order of lowest to highest ID, which should be in order of creation
+ // time.
+ std::sort(children.begin(), children.end(), RecordComparison());
+ return children;
+}
+
+TestTracingRecorder::RootSpan TestTracingRecorder::root() {
+ return RootSpan(this, root_record_);
+}
+
+} // namespace fcp
diff --git a/fcp/tracing/test_tracing_recorder.h b/fcp/tracing/test_tracing_recorder.h
new file mode 100644
index 0000000..6e81183
--- /dev/null
+++ b/fcp/tracing/test_tracing_recorder.h
@@ -0,0 +1,471 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEST_TRACING_RECORDER_H_
+#define FCP_TRACING_TEST_TRACING_RECORDER_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/source_location.h"
+#include "fcp/tracing/test_tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_recorder.h"
+
+namespace fcp {
+
+// Tracing recorder recording all interactions for use in unit tests.
+// Automatically installs itself to be a global for its lifetime.
+//
+// Provides functionality for setting expected errors and failing if either
+// an error is seen that was not expected, or one of the expected errors is
+// never seen.
+//
+// Also provides functionality for searching for particular tracing spans or
+// events, and verifying expectations about the children of a span or event
+// (if desired, one could write a test that verifies the entire expected tree
+// of traces.) See fcp/tracing/test/tracing_test.cc for examples.
+//
+// Upon creation of a tracing span or event, a Record representing the trace is
+// added to a centrally owned hashmap (which is protected by a lock, as tracing
+// spans and events may be created from multiple threads.) Each record stores
+// a reference to its parent.
+//
+// When verifying the expected traces in a test, the tree structure of the
+// traces is reconstructed by searching through the values in the map for
+// children of a given parent ID. An ordering of siblings is re-established
+// by sorting by the tracing span ID. This strategy is due to the fact that
+// children of a single parent may be created from multiple threads (think a
+// parent spawning many fibers.)
+//
+// By not creating the tree structure until it is needed in a test, we avoid
+// additional locking on a parent node at trace creation time. This does incur
+// a performance penalty when verifying expectations, however- if one is to
+// verify the full tree of traces the overhead to reconstruct the tree will be
+// O(n^2).
+class TestTracingRecorder
+ : public TracingRecorder,
+ tracing_internal::TestTracingRecorderImpl::TraceListener {
+ public:
+ explicit TestTracingRecorder(SourceLocation loc = SourceLocation::current());
+ ~TestTracingRecorder() override;
+
+ // Allowlists a given error type as expected (non-allowlisted errors will
+ // trigger testing assertions.)
+ template <typename FlatBufferTable>
+ void ExpectError();
+
+ struct Record {
+ explicit Record(TracingSpanId parent_id, TracingSpanId id,
+ flatbuffers::DetachedBuffer data)
+ : parent_id(parent_id), id(id), data(std::move(data)) {}
+ explicit Record(TracingSpanId id, flatbuffers::DetachedBuffer data)
+ : parent_id(std::nullopt), id(id), data(std::move(data)) {}
+ std::optional<TracingSpanId> parent_id;
+ TracingSpanId id;
+ flatbuffers::DetachedBuffer data;
+ };
+
+ template <typename FlatBufferTable>
+ class Span;
+ template <typename FlatBufferTable>
+ class Event;
+
+ // Allows to dynamically access properties of spans or event, associated
+ // with an encapsulated tracing record.
+ class SpanOrEvent {
+ public:
+ explicit SpanOrEvent(TestTracingRecorder* recorder, Record* record);
+
+ // To integrate easily woth googletest/gmock, SpanOrEvent behaves like
+ // a collection of children SpanOrEvent objects. Following group of methods
+ // mimic STL collection interface by delegating calls to std::vector.
+ using value_type = SpanOrEvent;
+ using iterator = std::vector<SpanOrEvent>::iterator;
+ using const_iterator = std::vector<SpanOrEvent>::const_iterator;
+ const_iterator begin() const { return children_.begin(); }
+ const_iterator end() const { return children_.end(); }
+ const SpanOrEvent& operator[](size_t idx) { return children_.at(idx); }
+ bool empty() const { return children_.empty(); }
+ size_t size() const { return children_.size(); }
+
+ // Checks if span/event is of certain type:
+ template <typename FlatBufferTable>
+ bool HasType() const;
+
+ // Returns typed flatbuffer pointer (fails if type mismatch)
+ template <typename FlatBufferTable>
+ const FlatBufferTable* data() const;
+
+ // Creates text representation:
+ std::string TextFormat() const;
+
+ // Find all spans of given type recursively:
+ template <typename FlatBufferTable>
+ std::vector<Span<FlatBufferTable>> FindAllSpans();
+
+ // Find all events of given type recursively:
+ template <typename FlatBufferTable>
+ std::vector<Event<FlatBufferTable>> FindAllEvents();
+
+ // Find exactly one span recursively, fails if not found
+ template <typename FlatBufferTable>
+ Span<FlatBufferTable> FindOnlySpan(
+ SourceLocation loc = SourceLocation::current());
+
+ // Find exactly one event recursively, fails if not found
+ template <typename FlatBufferTable>
+ Event<FlatBufferTable> FindOnlyEvent(
+ SourceLocation loc = SourceLocation::current());
+
+ TracingTraitsBase const* traits() const;
+
+ protected:
+ Record* record_;
+
+ private:
+ std::vector<SpanOrEvent> children_;
+ TestTracingRecorder* recorder_;
+ };
+
+ // This is used to access root span
+ class RootSpan : public SpanOrEvent {
+ public:
+ explicit RootSpan(TestTracingRecorder* recorder, Record* record)
+ : SpanOrEvent(recorder, record) {}
+ };
+
+ // Strongly-typed accessor for spans
+ template <typename FlatBufferTable>
+ class Span : public SpanOrEvent {
+ public:
+ explicit Span(TestTracingRecorder* recorder, Record* record);
+
+ // Re-declaring SpanOrEvent::data() for convenience, so there's no need
+ // to specify type argument when calling it.
+ const FlatBufferTable* data() const;
+ };
+
+ // Strongly-typed accessor for events
+ template <typename FlatBufferTable>
+ class Event : public SpanOrEvent {
+ public:
+ explicit Event(TestTracingRecorder* recorder, Record* record);
+
+ // Re-declaring SpanOrEvent::data() for convenience, so there's no need
+ // to specify type argument when calling it.
+ const FlatBufferTable* data() const;
+ };
+
+ RootSpan root();
+
+ template <typename FlatBufferTable>
+ std::vector<Span<FlatBufferTable>> FindAllSpans() {
+ return root().FindAllSpans<FlatBufferTable>();
+ }
+
+ template <typename FlatBufferTable>
+ std::vector<Event<FlatBufferTable>> FindAllEvents() {
+ return root().FindAllEvents<FlatBufferTable>();
+ }
+
+ template <typename FlatBufferTable>
+ Span<FlatBufferTable> FindOnlySpan(
+ SourceLocation loc = SourceLocation::current()) {
+ return root().FindOnlySpan<FlatBufferTable>(loc);
+ }
+
+ template <typename FlatBufferTable>
+ Event<FlatBufferTable> FindOnlyEvent(
+ SourceLocation loc = SourceLocation::current()) {
+ return root().FindOnlyEvent<FlatBufferTable>(loc);
+ }
+
+ private:
+ void InstallAsGlobal() override;
+ void UninstallAsGlobal() override;
+ void InstallAsThreadLocal() override;
+ void UninstallAsThreadLocal() override;
+
+ // Retrieves all children of the trace represented by the provided record
+ // by iterating through a hashmap to identify children that reference this
+ // record's ID as their parent.
+ //
+ // parent is borrowed by this function and must live for the duration of the
+ // function.
+ //
+ // The records in the returned vector are borrowed by the caller and will live
+ // until the destruction of this TestTracingRecorder.
+ std::vector<Record*> GetChildren(Record* parent);
+ // Implements TraceListener interface to react to creation of the root span by
+ // creating a record representing the root span, saving it in a map of all
+ // traces, and saving a pointer to it.
+ void OnRoot(TracingSpanId id, flatbuffers::DetachedBuffer data) override;
+ // Implements TraceListener interface which creates a record
+ // representing the span or event identified by id, adds it to the map of all
+ // traces to make it possible to find for verifying test expectations. Also
+ // checks of this is an event that has error severity and if so, checks that
+ // this is an expected error as registered by ExpectError().
+ void OnTrace(TracingSpanId parent_id, TracingSpanId id,
+ flatbuffers::DetachedBuffer data) override;
+
+ SourceLocation loc_;
+ Record* root_record_ = nullptr;
+ absl::Mutex map_lock_;
+ // Global map for storing traces. This map will only grow during the
+ // lifetime of this TestTracingRecorder (elements will never be removed or
+ // replaced.)
+ absl::flat_hash_map<TracingSpanId, std::unique_ptr<Record>> id_to_record_map_
+ ABSL_GUARDED_BY(map_lock_);
+ std::shared_ptr<tracing_internal::TestTracingRecorderImpl> impl_;
+
+ absl::Mutex expected_errors_lock_;
+ // Expectations registered by the test for error events that should be created
+ // during test execution.
+ absl::flat_hash_set<TracingTag> expected_errors_
+ ABSL_GUARDED_BY(expected_errors_lock_);
+ // Errors that are expected but have not yet been seen. If any are present on
+ // destruction of this object, a precondition check will fail.
+ absl::flat_hash_set<TracingTag> unseen_expected_errors_
+ ABSL_GUARDED_BY(expected_errors_lock_);
+};
+
+template <typename FlatBufferTable>
+void TestTracingRecorder::ExpectError() {
+ absl::MutexLock locked(&expected_errors_lock_);
+ expected_errors_.insert(TracingTraits<FlatBufferTable>::kTag);
+ unseen_expected_errors_.insert(TracingTraits<FlatBufferTable>::kTag);
+}
+
+template <typename FlatBufferTable>
+TestTracingRecorder::Event<FlatBufferTable>::Event(
+ TestTracingRecorder* recorder, Record* record)
+ : SpanOrEvent(recorder, record) {
+ static_assert(!TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must be an event, not a span");
+}
+
+template <typename FlatBufferTable>
+const FlatBufferTable* TestTracingRecorder::Event<FlatBufferTable>::data()
+ const {
+ return SpanOrEvent::data<FlatBufferTable>();
+}
+
+template <typename FlatBufferTable>
+TestTracingRecorder::Span<FlatBufferTable>::Span(TestTracingRecorder* recorder,
+ Record* record)
+ : SpanOrEvent(recorder, record) {
+ static_assert(TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must be a span");
+}
+
+template <typename FlatBufferTable>
+const FlatBufferTable* TestTracingRecorder::Span<FlatBufferTable>::data()
+ const {
+ return SpanOrEvent::data<FlatBufferTable>();
+}
+
+template <typename FlatBufferTable>
+std::vector<TestTracingRecorder::Span<FlatBufferTable>>
+TestTracingRecorder::SpanOrEvent::FindAllSpans() {
+ static_assert(TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must be a span");
+ std::vector<TestTracingRecorder::Span<FlatBufferTable>> result;
+ if (HasType<FlatBufferTable>()) {
+ result.emplace_back(recorder_, record_);
+ }
+ for (auto& c : children_) {
+ auto child_result = c.FindAllSpans<FlatBufferTable>();
+ result.insert(result.end(), child_result.begin(), child_result.end());
+ }
+ return result;
+}
+
+template <typename FlatBufferTable>
+std::vector<TestTracingRecorder::Event<FlatBufferTable>>
+TestTracingRecorder::SpanOrEvent::FindAllEvents() {
+ static_assert(!TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must be an event not a span");
+ std::vector<TestTracingRecorder::Event<FlatBufferTable>> result;
+ if (HasType<FlatBufferTable>()) {
+ result.emplace_back(recorder_, record_);
+ }
+ for (auto& c : children_) {
+ auto child_result = c.FindAllEvents<FlatBufferTable>();
+ result.insert(result.end(), child_result.begin(), child_result.end());
+ }
+ return result;
+}
+
+template <typename FlatBufferTable>
+TestTracingRecorder::Span<FlatBufferTable>
+TestTracingRecorder::SpanOrEvent::FindOnlySpan(SourceLocation loc) {
+ auto all_spans = FindAllSpans<FlatBufferTable>();
+ EXPECT_THAT(all_spans, testing::SizeIs(1))
+ << "Expected exactly one span of type "
+ << TracingTraits<FlatBufferTable>().Name() << ". " << std::endl
+ << "Source location: " << std::endl
+ << loc.file_name() << ":" << loc.line();
+ return all_spans[0];
+}
+
+template <typename FlatBufferTable>
+TestTracingRecorder::Event<FlatBufferTable>
+TestTracingRecorder::SpanOrEvent::FindOnlyEvent(SourceLocation loc) {
+ auto all_events = FindAllEvents<FlatBufferTable>();
+ EXPECT_THAT(all_events, testing::SizeIs(1))
+ << "Expected exactly one event of type "
+ << TracingTraits<FlatBufferTable>().Name() << ". " << std::endl
+ << "Source location: " << std::endl
+ << loc.file_name() << ":" << loc.line();
+ return all_events[0];
+}
+
+template <typename FlatBufferTable>
+bool TestTracingRecorder::SpanOrEvent::HasType() const {
+ return *TracingTag::FromFlatbuf(record_->data) ==
+ TracingTraits<FlatBufferTable>::kTag;
+}
+
+template <typename FlatBufferTable>
+const FlatBufferTable* TestTracingRecorder::SpanOrEvent::data() const {
+ return flatbuffers::GetRoot<FlatBufferTable>(record_->data.data());
+}
+
+using ::testing::Matcher;
+using ::testing::MatcherInterface;
+using ::testing::MatchResultListener;
+
+inline void PrintTo(const TestTracingRecorder::SpanOrEvent& value,
+ std::ostream* os) {
+ *os << value.TextFormat();
+}
+
+// This wraps std::tuple with method .get<I>() as a member. This
+// allows to compose a tuple matcher from multiple testing::Property matchers,
+// combined with testing::AllOf.
+template <typename Tuple>
+class TupleWrapper {
+ public:
+ explicit TupleWrapper(Tuple const& tuple) : tuple_(tuple) {}
+ template <std::size_t I>
+ const typename std::tuple_element<I, Tuple>::type& get() const {
+ return std::get<I>(tuple_);
+ }
+
+ private:
+ const Tuple tuple_;
+};
+
+// Universal matcher for spans and events, checks for type and content.
+template <typename FlatBufferTable>
+class SpanOrEventTypeMatcher
+ : public MatcherInterface<const TestTracingRecorder::SpanOrEvent&> {
+ public:
+ using TupleType = typename TracingTraits<FlatBufferTable>::TupleType;
+ using ContentMatcher = testing::Matcher<TupleWrapper<TupleType>>;
+
+ explicit SpanOrEventTypeMatcher(absl::string_view kind,
+ ContentMatcher content_matcher)
+ : kind_(kind), content_matcher_(content_matcher) {}
+
+ bool MatchAndExplain(const TestTracingRecorder::SpanOrEvent& value,
+ MatchResultListener* listener) const override {
+ *listener << " { " << value.TextFormat() << " } ";
+ bool result = value.HasType<FlatBufferTable>();
+ if (result) {
+ auto content =
+ TupleWrapper<TupleType>(TracingTraits<FlatBufferTable>::MakeTuple(
+ value.data<FlatBufferTable>()));
+ result = content_matcher_.MatchAndExplain(content, listener);
+ }
+ return result;
+ }
+
+ void DescribeTo(std::ostream* os) const override {
+ *os << "Expecting " << kind_ << " of type "
+ << TracingTraits<FlatBufferTable>().Name() << " with fields ";
+ content_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "Expecting NOT " << kind_ << " of type "
+ << TracingTraits<FlatBufferTable>().Name() << " with fields ";
+ content_matcher_.DescribeNegationTo(os);
+ }
+
+ private:
+ std::string kind_;
+ ContentMatcher content_matcher_;
+};
+
+template <typename Tuple, typename... M, std::size_t... I>
+auto MatchTupleWrapper(M... m, std::index_sequence<I...>) {
+ return testing::AllOf(
+ testing::Property(&TupleWrapper<Tuple>::template get<I>, m)...);
+}
+
+template <typename Tuple, typename... M>
+testing::Matcher<TupleWrapper<Tuple>> MatchTupleElements(M... m) {
+ return testing::MatcherCast<TupleWrapper<Tuple>>(
+ MatchTupleWrapper<Tuple, M...>(m...,
+ std::make_index_sequence<sizeof...(M)>{}));
+}
+
+template <typename FlatBufferTable, typename... M>
+Matcher<const TestTracingRecorder::SpanOrEvent&> IsSpan(M... field_matchers) {
+ static_assert(TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must be a span");
+ if constexpr (sizeof...(M) != 0) {
+ constexpr size_t number_of_fields = std::tuple_size<
+ typename TracingTraits<FlatBufferTable>::TupleType>::value;
+ static_assert(
+ sizeof...(M) == number_of_fields,
+ "Matchers must be provided for every field in FlatBufferTable");
+ return MakeMatcher(new SpanOrEventTypeMatcher<FlatBufferTable>(
+ "span",
+ MatchTupleElements<typename TracingTraits<FlatBufferTable>::TupleType>(
+ field_matchers...)));
+ } else {
+ // When no field matchers provided it should match anything
+ return MakeMatcher(
+ new SpanOrEventTypeMatcher<FlatBufferTable>("span", testing::_));
+ }
+}
+
+template <typename FlatBufferTable, typename... M>
+Matcher<const TestTracingRecorder::SpanOrEvent&> IsEvent(M... field_matchers) {
+ static_assert(!TracingTraits<FlatBufferTable>::kIsSpan,
+ "FlatBufferTable must not be a span");
+ if constexpr (sizeof...(M) != 0) {
+ return MakeMatcher(new SpanOrEventTypeMatcher<FlatBufferTable>(
+ "event",
+ MatchTupleElements<typename TracingTraits<FlatBufferTable>::TupleType>(
+ field_matchers...)));
+ } else {
+ // When no field matchers provided it should match anything
+ return MakeMatcher(
+ new SpanOrEventTypeMatcher<FlatBufferTable>("event", testing::_));
+ }
+}
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TEST_TRACING_RECORDER_H_
diff --git a/fcp/tracing/test_tracing_recorder_impl.cc b/fcp/tracing/test_tracing_recorder_impl.cc
new file mode 100644
index 0000000..8256146
--- /dev/null
+++ b/fcp/tracing/test_tracing_recorder_impl.cc
@@ -0,0 +1,66 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/test_tracing_recorder_impl.h"
+
+#include <memory>
+#include <utility>
+
+#include "fcp/tracing/test_tracing_span_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+using flatbuffers::DetachedBuffer;
+using flatbuffers::FlatBufferBuilder;
+
+DetachedBuffer EmptyFlatBuffer() {
+ FlatBufferBuilder fbb;
+ fbb.Finish(fbb.CreateString("ROOT"), "ROOT");
+ return fbb.Release();
+}
+
+TestTracingRecorderImpl::TestTracingRecorderImpl(TraceListener* trace_listener)
+ : trace_listener_(trace_listener),
+ root_span_(
+ std::make_unique<TestTracingSpanImpl>(this, TracingSpanId(0))) {
+ trace_listener_->OnRoot(TracingSpanId(0), EmptyFlatBuffer());
+}
+
+void TestTracingRecorderImpl::TraceImpl(TracingSpanId id, DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ TracingSpanId new_id = TracingSpanId::NextUniqueId();
+ trace_listener_->OnTrace(id, new_id, std::move(buf));
+}
+
+TracingSpanImpl* TestTracingRecorderImpl::GetRootSpan() {
+ return root_span_.get();
+}
+
+std::unique_ptr<TracingSpanImpl> TestTracingRecorderImpl::CreateChildSpan(
+ TracingSpanId parent_span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ TracingSpanId new_id = TracingSpanId::NextUniqueId();
+ trace_listener_->OnTrace(parent_span_id, new_id, std::move(buf));
+ // NOTE: shared_from_this() is defined in a base class, so it returns
+ // std::shared_ptr<TracingRecorderImpl> and we have to (safely) cast it here:
+ auto shared_this =
+ std::static_pointer_cast<TestTracingRecorderImpl>(shared_from_this());
+ return std::make_unique<TestTracingSpanImpl>(shared_this, new_id);
+}
+
+TestTracingRecorderImpl::~TestTracingRecorderImpl() = default;
+
+} // namespace tracing_internal
+} // namespace fcp
diff --git a/fcp/tracing/test_tracing_recorder_impl.h b/fcp/tracing/test_tracing_recorder_impl.h
new file mode 100644
index 0000000..fc9a59f
--- /dev/null
+++ b/fcp/tracing/test_tracing_recorder_impl.h
@@ -0,0 +1,61 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEST_TRACING_RECORDER_IMPL_H_
+#define FCP_TRACING_TEST_TRACING_RECORDER_IMPL_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "fcp/tracing/tracing_recorder_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+class TestTracingSpanImpl;
+
+class TestTracingRecorderImpl : public TracingRecorderImpl {
+ public:
+ // Allows listening for/handling traces as they appear.
+ class TraceListener {
+ public:
+ virtual ~TraceListener() = default;
+ // Called exactly once when the root span is created.
+ virtual void OnRoot(TracingSpanId id, flatbuffers::DetachedBuffer data) = 0;
+ // Called when a new span or event is created.
+ // @param id Will only be called once per
+ // unique ID (although it may be called multiple times with the same
+ // parent span ID.) id
+ virtual void OnTrace(TracingSpanId parent_id, TracingSpanId id,
+ flatbuffers::DetachedBuffer data) = 0;
+ };
+ explicit TestTracingRecorderImpl(TraceListener* trace_listener);
+ ~TestTracingRecorderImpl() override;
+ TracingSpanImpl* GetRootSpan() override;
+ void TraceImpl(TracingSpanId id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+ std::unique_ptr<TracingSpanImpl> CreateChildSpan(
+ TracingSpanId parent_span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+
+ private:
+ TraceListener* trace_listener_;
+ std::unique_ptr<TestTracingSpanImpl> root_span_;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TEST_TRACING_RECORDER_IMPL_H_
diff --git a/fcp/tracing/test_tracing_span_impl.cc b/fcp/tracing/test_tracing_span_impl.cc
new file mode 100644
index 0000000..79b8137
--- /dev/null
+++ b/fcp/tracing/test_tracing_span_impl.cc
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/test_tracing_span_impl.h"
+
+#include <memory>
+#include <utility>
+
+namespace fcp {
+namespace tracing_internal {
+
+using flatbuffers::DetachedBuffer;
+
+TestTracingSpanImpl::TestTracingSpanImpl(
+ std::shared_ptr<TestTracingRecorderImpl> recorder, TracingSpanId id)
+ : recorder_shared_ptr_(std::move(recorder)), id_(id) {
+ recorder_ = recorder_shared_ptr_.get();
+}
+
+TestTracingSpanImpl::TestTracingSpanImpl(TestTracingRecorderImpl* recorder,
+ TracingSpanId id)
+ : recorder_(recorder), id_(id) {}
+
+void TestTracingSpanImpl::TraceImpl(DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ recorder_->TraceImpl(id_, std::move(buf), traits);
+}
+
+TestTracingSpanImpl::~TestTracingSpanImpl() = default;
+
+TracingSpanRef TestTracingSpanImpl::Ref() {
+ return TracingSpanRef(recorder_->shared_from_this(), id_);
+}
+} // namespace tracing_internal
+} // namespace fcp
diff --git a/fcp/tracing/test_tracing_span_impl.h b/fcp/tracing/test_tracing_span_impl.h
new file mode 100644
index 0000000..3985431
--- /dev/null
+++ b/fcp/tracing/test_tracing_span_impl.h
@@ -0,0 +1,49 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEST_TRACING_SPAN_IMPL_H_
+#define FCP_TRACING_TEST_TRACING_SPAN_IMPL_H_
+
+#include "fcp/tracing/test_tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_span_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+class TestTracingSpanImpl : public TracingSpanImpl {
+ public:
+ TestTracingSpanImpl(std::shared_ptr<TestTracingRecorderImpl> recorder,
+ TracingSpanId id);
+ TestTracingSpanImpl(TestTracingRecorderImpl* recorder, TracingSpanId id);
+ ~TestTracingSpanImpl() override;
+
+ void TraceImpl(flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+ TracingSpanRef Ref() override;
+
+ private:
+ TestTracingRecorderImpl* recorder_;
+
+ // For non-root span the following keeps recorder alive, if set,
+ // this holds the same value as recorder_. Since root span is owned directly
+ // by the recorder we can't store shared_ptr for it here to avoid a loop.
+ std::shared_ptr<TestTracingRecorderImpl> recorder_shared_ptr_;
+
+ TracingSpanId id_;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TEST_TRACING_SPAN_IMPL_H_
diff --git a/fcp/tracing/text_tracing_recorder.h b/fcp/tracing/text_tracing_recorder.h
new file mode 100644
index 0000000..7e36da8
--- /dev/null
+++ b/fcp/tracing/text_tracing_recorder.h
@@ -0,0 +1,54 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEXT_TRACING_RECORDER_H_
+#define FCP_TRACING_TEXT_TRACING_RECORDER_H_
+
+#include <fstream>
+#include <string>
+
+#include "fcp/tracing/text_tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_recorder.h"
+
+namespace fcp {
+
+// Entry point for usage of basic tracing API implementation that simply writes
+// to an output stream in human readable format.
+class TextTracingRecorder : public TracingRecorder {
+ public:
+ // Constructs a TextTracingRecorder which will write events to the filestream
+ // with a timestamp formatted for the provided timezone.
+ explicit TextTracingRecorder(const std::string& filename,
+ absl::TimeZone time_zone)
+ : impl_(std::make_shared<tracing_internal::TextTracingRecorderImpl>(
+ filename, time_zone)) {}
+ // Constructs a TextTracingRecorder which will write events to stderr
+ // with a timestamp formatted for the provided timezone.
+ explicit TextTracingRecorder(absl::TimeZone time_zone)
+ : impl_(std::make_shared<tracing_internal::TextTracingRecorderImpl>(
+ time_zone)) {}
+
+ void InstallAsGlobal() override { impl_->InstallAsGlobal(); }
+ void UninstallAsGlobal() override { impl_->UninstallAsGlobal(); }
+ void InstallAsThreadLocal() override { impl_->InstallAsThreadLocal(); }
+ void UninstallAsThreadLocal() override { impl_->UninstallAsThreadLocal(); }
+
+ private:
+ // The tracing recorder implementation shared between tracing spans.
+ std::shared_ptr<tracing_internal::TextTracingRecorderImpl> impl_;
+};
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TEXT_TRACING_RECORDER_H_
diff --git a/fcp/tracing/text_tracing_recorder_impl.cc b/fcp/tracing/text_tracing_recorder_impl.cc
new file mode 100644
index 0000000..dbc1d2c
--- /dev/null
+++ b/fcp/tracing/text_tracing_recorder_impl.cc
@@ -0,0 +1,100 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/text_tracing_recorder_impl.h"
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "absl/time/clock.h"
+#include "fcp/tracing/text_tracing_span_impl.h"
+#include "flatbuffers/minireflect.h"
+
+namespace fcp::tracing_internal {
+
+using std::endl;
+
+TextTracingRecorderImpl::TextTracingRecorderImpl(const std::string& filename,
+ absl::TimeZone time_zone)
+ : fstream_(filename), time_zone_(time_zone) {
+ stream_ = &fstream_.value();
+ root_span_ = std::make_unique<TextTracingSpanImpl>(this);
+}
+
+TextTracingRecorderImpl::TextTracingRecorderImpl(absl::TimeZone time_zone)
+ : stream_(&std::cerr), time_zone_(time_zone) {
+ root_span_ = std::make_unique<TextTracingSpanImpl>(this);
+}
+
+// TODO(team): Ensure traces from different threads are not interleaved.
+void TextTracingRecorderImpl::TraceImpl(TracingSpanId id,
+ flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ LogTime();
+ *stream_ << id << ": " << TracingTraitsBase::SeverityString(traits.Severity())
+ << " " << traits.Name() << traits.TextFormat(buf) << endl;
+}
+
+void TextTracingRecorderImpl::BeginSpan(TracingSpanId id,
+ TracingSpanId parent_id,
+ absl::string_view name,
+ absl::string_view text_format) {
+ LogSpan(/* begin = */ true, id, parent_id, name, text_format);
+}
+
+void TextTracingRecorderImpl::EndSpan(TracingSpanId id, absl::string_view name,
+ absl::string_view text_format) {
+ LogSpan(/* begin = */ false, id, TracingSpanId(0), name, text_format);
+}
+
+void TextTracingRecorderImpl::LogSpan(bool begin, TracingSpanId id,
+ TracingSpanId parent_id,
+ absl::string_view name,
+ absl::string_view text_format) {
+ LogTime();
+ *stream_ << id << (begin ? ": BEGIN" : ": END");
+ if (name.length() > 0) {
+ *stream_ << " " << name << " " << text_format;
+ }
+ if (begin && id != TracingSpanId(0)) {
+ *stream_ << " parent: " << parent_id;
+ }
+ *stream_ << endl;
+}
+
+void TextTracingRecorderImpl::LogTime() {
+ *stream_ << absl::FormatTime(absl::Now(), time_zone_) << " ";
+}
+
+TracingSpanImpl* TextTracingRecorderImpl::GetRootSpan() {
+ return root_span_.get();
+}
+
+std::unique_ptr<TracingSpanImpl> TextTracingRecorderImpl::CreateChildSpan(
+ TracingSpanId parent_span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ // NOTE: shared_from_this() is defined in a base class, so it returns
+ // std::shared_ptr<TracingRecorderImpl> and we have to (safely) cast it here:
+ auto shared_this =
+ std::static_pointer_cast<TextTracingRecorderImpl>(shared_from_this());
+ return std::make_unique<TextTracingSpanImpl>(
+ shared_this, std::move(buf), traits, TracingSpanId::NextUniqueId(),
+ parent_span_id);
+}
+
+TextTracingRecorderImpl::~TextTracingRecorderImpl() = default;
+
+} // namespace fcp::tracing_internal
diff --git a/fcp/tracing/text_tracing_recorder_impl.h b/fcp/tracing/text_tracing_recorder_impl.h
new file mode 100644
index 0000000..5c99334
--- /dev/null
+++ b/fcp/tracing/text_tracing_recorder_impl.h
@@ -0,0 +1,91 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEXT_TRACING_RECORDER_IMPL_H_
+#define FCP_TRACING_TEXT_TRACING_RECORDER_IMPL_H_
+
+#include <flatbuffers/flatbuffers.h>
+
+#include <fstream>
+#include <iostream>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/time/time.h"
+#include "fcp/tracing/tracing_recorder.h"
+#include "fcp/tracing/tracing_recorder_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+class TextTracingSpanImpl;
+// Basic tracing API implementation that writes begin span, end span, and log
+// events to a stream in a human-readable text format.
+class TextTracingRecorderImpl : public TracingRecorderImpl {
+ public:
+ TextTracingRecorderImpl(const std::string& filename,
+ absl::TimeZone time_zone);
+
+ ~TextTracingRecorderImpl() override;
+
+ // Constructs a recorder implementation that writes tracing events to stderr.
+ explicit TextTracingRecorderImpl(absl::TimeZone time_zone);
+
+ // Creates a root span from which child tracing spans can be created.
+ TracingSpanImpl* GetRootSpan() override;
+
+ // Trace an event represented by the flatbuffer.
+ void TraceImpl(TracingSpanId span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+
+ // Log that the tracing span represented by the flatbuffer is starting.
+ void BeginSpan(TracingSpanId id, TracingSpanId parent_id,
+ absl::string_view name, absl::string_view text_format);
+
+ // Log that the tracing span represented by the provided flatbuf is finished.
+ void EndSpan(TracingSpanId id, absl::string_view name,
+ absl::string_view text_format);
+
+ // Creates instance of the child tracing span with the parent span ID and
+ // tracing data provided
+ std::unique_ptr<TracingSpanImpl> CreateChildSpan(
+ TracingSpanId parent_span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+
+ private:
+ // Log a timestamp of the current time to the filestream.
+ void LogTime();
+ // Common method for logging begin or end of a span.
+ void LogSpan(bool begin, TracingSpanId id, TracingSpanId parent_id,
+ absl::string_view name, absl::string_view text_format);
+
+ // File stream which is present only if this tracing recorder was constructed
+ // to write to a file. Should not be written to directly. This field is
+ // present because this class must own the filestream, since an instance of
+ // this class can be shared by many tracing spans, some of which may outlive
+ // the function that originally created the root tracing span.
+ std::optional<std::ofstream> fstream_;
+
+ // Pointer to an output stream to which tracing events are written.
+ std::ostream* stream_;
+
+ absl::TimeZone time_zone_;
+ std::unique_ptr<TextTracingSpanImpl> root_span_;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TEXT_TRACING_RECORDER_IMPL_H_
diff --git a/fcp/tracing/text_tracing_span_impl.cc b/fcp/tracing/text_tracing_span_impl.cc
new file mode 100644
index 0000000..7ab4d47
--- /dev/null
+++ b/fcp/tracing/text_tracing_span_impl.cc
@@ -0,0 +1,42 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/text_tracing_span_impl.h"
+
+#include <iostream>
+#include <memory>
+#include <utility>
+
+#include "flatbuffers/minireflect.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+using flatbuffers::DetachedBuffer;
+
+void TextTracingSpanImpl::TraceImpl(DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) {
+ recorder_->TraceImpl(id_, std::move(buf), traits);
+}
+
+TextTracingSpanImpl::~TextTracingSpanImpl() {
+ recorder_->EndSpan(id_, buf_text_.name, buf_text_.text_format);
+}
+
+TracingSpanRef TextTracingSpanImpl::Ref() {
+ return TracingSpanRef(recorder_->shared_from_this(), id_);
+}
+
+} // namespace tracing_internal
+} // namespace fcp
diff --git a/fcp/tracing/text_tracing_span_impl.h b/fcp/tracing/text_tracing_span_impl.h
new file mode 100644
index 0000000..f6dc174
--- /dev/null
+++ b/fcp/tracing/text_tracing_span_impl.h
@@ -0,0 +1,82 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TEXT_TRACING_SPAN_IMPL_H_
+#define FCP_TRACING_TEXT_TRACING_SPAN_IMPL_H_
+
+#include <atomic>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "fcp/tracing/text_tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_span_impl.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+// Basic tracing span implementation that owns the flatbuf representing the
+// current span.
+class TextTracingSpanImpl : public TracingSpanImpl {
+ public:
+ // Constructs a TextTracingSpanImpl for a child span from a serialized flatbuf
+ // and TracingTraitsBase which provides more context about the flatbuf table.
+ TextTracingSpanImpl(std::shared_ptr<TextTracingRecorderImpl> recorder,
+ const flatbuffers::DetachedBuffer& buf,
+ const TracingTraitsBase& traits, TracingSpanId id,
+ TracingSpanId parent_id)
+ : id_(id),
+ recorder_shared_ptr_(std::move(recorder)),
+ buf_text_{traits.Name(), traits.TextFormat(buf)} {
+ recorder_ = recorder_shared_ptr_.get();
+ recorder_->BeginSpan(id, parent_id, buf_text_.name, buf_text_.text_format);
+ }
+
+ // Constructs a TextTracingSpanImpl for root span.
+ explicit TextTracingSpanImpl(TextTracingRecorderImpl* recorder)
+ : id_(0), recorder_(recorder), buf_text_{} {
+ recorder_->BeginSpan(TracingSpanId(0), TracingSpanId(0), buf_text_.name,
+ buf_text_.text_format);
+ }
+ ~TextTracingSpanImpl() override;
+
+ // Logs an event in the current tracing span.
+ void TraceImpl(flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) override;
+
+ TracingSpanRef Ref() override;
+
+ private:
+ struct SpanText {
+ std::string name;
+ std::string text_format;
+ };
+ TracingSpanId id_;
+ TextTracingRecorderImpl* recorder_;
+
+ // For non-root span the following keeps recorder alive, if set,
+ // this holds the same value as recorder_. Since root span is owned directly
+ // by the recorder we can't store shared_ptr for it here to avoid a loop.
+ std::shared_ptr<TextTracingRecorderImpl> recorder_shared_ptr_;
+
+ // Human readable text describing the flatbuffer representing the current
+ // span; empty if this is the root span.
+ SpanText buf_text_;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TEXT_TRACING_SPAN_IMPL_H_
diff --git a/fcp/tracing/tools/BUILD b/fcp/tracing/tools/BUILD
new file mode 100644
index 0000000..cf4bf9d
--- /dev/null
+++ b/fcp/tracing/tools/BUILD
@@ -0,0 +1,56 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//fcp:config.bzl", "FCP_COPTS")
+
+package(
+ default_visibility = ["//fcp:internal"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_binary(
+ name = "tracing_traits_generator",
+ srcs = ["tracing_traits_generator.cc"],
+ deps = [
+ "//fcp/tracing:tracing_severity",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@flatbuffers",
+ ],
+)
+
+cc_test(
+ name = "tracing_traits_generator_test",
+ srcs = ["tracing_traits_generator_test.cc"],
+ args = ["--codegen_tool_path=$(location :test_codegen_runner.sh)"],
+ copts = FCP_COPTS,
+ data = glob([
+ "testdata/*.baseline",
+ "testdata/*.fbs",
+ ]) + [
+ ":test_codegen_runner.sh",
+ ":tracing_traits_generator",
+ "//fcp/tracing:tracing_schema_common.fbs",
+ "@flatbuffers//:flatc",
+ ],
+ deps = [
+ "//fcp/base",
+ "//fcp/testing",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/fcp/tracing/tools/README.md b/fcp/tracing/tools/README.md
new file mode 100644
index 0000000..38e6a04
--- /dev/null
+++ b/fcp/tracing/tools/README.md
@@ -0,0 +1,16 @@
+# Tracing Traits Generator
+
+This directory contains an implementation of a simple codegen tool
+`tracing_traits_generator` which uses FlatBuffers reflection to inspect
+user-defined tracing schema and produces a C++ header file containing additional
+traits needed for `TracingSpan::Log<T>()` and `TracingSpan::CreateChild<T>()` to
+function properly.
+
+These traits allow for compile-time lookup of user-defined tags, additional
+attributes needed for tracking backend components to handle the data.
+
+This tool is automatically invoked during a build process with the help of
+`tracing_schema_cc_library` rule in `build/build_defs.bzl`.
+
+Examples of codegen output for various input FlatBuffers schemas can be found in
+`testdata/*.baseline` files.
diff --git a/fcp/tracing/tools/test_codegen_runner.sh b/fcp/tracing/tools/test_codegen_runner.sh
new file mode 100755
index 0000000..1bdae86
--- /dev/null
+++ b/fcp/tracing/tools/test_codegen_runner.sh
@@ -0,0 +1,32 @@
+#!/bin/bash -eu
+#
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -o errexit
+set -o nounset
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 <input.fbs> <output-dir> <header-output-dir>"
+ exit 1;
+fi
+FBS_FILE=$1
+OUT_DIR=$2
+HEADER_OUT_DIR=$3
+# Running flatc to parse fbs and generate bfbs
+external/flatbuffers/flatc -b --schema -o ${OUT_DIR} -I "." ${FBS_FILE} 1>&2
+# Flatc should have produced the following files:
+BFBS_FILE=$OUT_DIR/$(basename ${FBS_FILE%.fbs}).bfbs
+GENERATED_FILE=$HEADER_OUT_DIR/$(basename ${FBS_FILE%.fbs})_generated.h
+# Generate header file from bfbs (to stdout)
+fcp/tracing/tools/tracing_traits_generator ${GENERATED_FILE} ${BFBS_FILE} ${FBS_FILE}
diff --git a/fcp/tracing/tools/testdata/AllTypes.baseline b/fcp/tracing/tools/testdata/AllTypes.baseline
new file mode 100644
index 0000000..a0b311b
--- /dev/null
+++ b/fcp/tracing/tools/testdata/AllTypes.baseline
@@ -0,0 +1,79 @@
+============== AllTypes.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table AllTypes (tag: "ALLT") {
+ fieldz: byte;
+ fieldy: ubyte;
+ fieldx: bool;
+ fieldw: short;
+ fieldv: ushort;
+ fieldu: int;
+ fieldt: uint;
+ fields: float;
+ fieldr: long;
+ fieldq: ulong;
+ fieldp: double;
+ fieldo: string;
+}
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ALLTYPES_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ALLTYPES_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/AllTypes_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<AllTypes>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("ALLT");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "AllTypes"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), AllTypesTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/AllTypes.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("AllTypes");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<AllTypes> Create(std::int8_t fieldz, std::uint8_t fieldy, bool fieldx, std::int16_t fieldw, std::uint16_t fieldv, std::int32_t fieldu, std::uint32_t fieldt, float fields, std::int64_t fieldr, std::uint64_t fieldq, double fieldp, absl::string_view fieldo, flatbuffers::FlatBufferBuilder* fbb) {
+ auto fieldo__ = fbb->CreateString(fieldo.data(), fieldo.size());
+ return CreateAllTypes(*fbb, fieldz, fieldy, fieldx, fieldw, fieldv, fieldu, fieldt, fields, fieldr, fieldq, fieldp, fieldo__);
+ }
+ using TupleType = std::tuple<std::int8_t, std::uint8_t, bool, std::int16_t, std::uint16_t, std::int32_t, std::uint32_t, float, std::int64_t, std::uint64_t, double, std::string>;
+ static TupleType MakeTuple(const AllTypes* table) {
+ return std::make_tuple(table->fieldz(), table->fieldy(), table->fieldx(), table->fieldw(), table->fieldv(), table->fieldu(), table->fieldt(), table->fields(), table->fieldr(), table->fieldq(), table->fieldp(), table->fieldo()->str());
+ }
+};
+static internal::TracingTraitsRegistrar<AllTypes> registrar_AllTypes;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ALLTYPES_H
+
diff --git a/fcp/tracing/tools/testdata/AllTypes.fbs b/fcp/tracing/tools/testdata/AllTypes.fbs
new file mode 100644
index 0000000..5e9acf1
--- /dev/null
+++ b/fcp/tracing/tools/testdata/AllTypes.fbs
@@ -0,0 +1,16 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table AllTypes (tag: "ALLT") {
+ fieldz: byte;
+ fieldy: ubyte;
+ fieldx: bool;
+ fieldw: short;
+ fieldv: ushort;
+ fieldu: int;
+ fieldt: uint;
+ fields: float;
+ fieldr: long;
+ fieldq: ulong;
+ fieldp: double;
+ fieldo: string;
+} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/DeprecatedField.baseline b/fcp/tracing/tools/testdata/DeprecatedField.baseline
new file mode 100644
index 0000000..df508f1
--- /dev/null
+++ b/fcp/tracing/tools/testdata/DeprecatedField.baseline
@@ -0,0 +1,109 @@
+============== DeprecatedField.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table DeprecatedInt (tag: "DEPI") {
+ field1: int32 (deprecated);
+ field2: int32;
+}
+
+table DeprecatedString (tag: "DEPS") {
+ field1: string (deprecated);
+ field2: int32;
+}
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_DEPRECATEDFIELD_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_DEPRECATEDFIELD_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/DeprecatedField_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<DeprecatedInt>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("DEPI");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "DeprecatedInt"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), DeprecatedIntTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/DeprecatedField.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("DeprecatedInt");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<DeprecatedInt> Create(std::int32_t field2, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateDeprecatedInt(*fbb, field2);
+ }
+ using TupleType = std::tuple<std::int32_t>;
+ static TupleType MakeTuple(const DeprecatedInt* table) {
+ return std::make_tuple(table->field2());
+ }
+};
+static internal::TracingTraitsRegistrar<DeprecatedInt> registrar_DeprecatedInt;
+template<> class TracingTraits<DeprecatedString>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("DEPS");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "DeprecatedString"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), DeprecatedStringTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/DeprecatedField.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("DeprecatedString");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<DeprecatedString> Create(std::int32_t field2, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateDeprecatedString(*fbb, field2);
+ }
+ using TupleType = std::tuple<std::int32_t>;
+ static TupleType MakeTuple(const DeprecatedString* table) {
+ return std::make_tuple(table->field2());
+ }
+};
+static internal::TracingTraitsRegistrar<DeprecatedString> registrar_DeprecatedString;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_DEPRECATEDFIELD_H
+
diff --git a/fcp/tracing/tools/testdata/DeprecatedField.fbs b/fcp/tracing/tools/testdata/DeprecatedField.fbs
new file mode 100644
index 0000000..e61680c
--- /dev/null
+++ b/fcp/tracing/tools/testdata/DeprecatedField.fbs
@@ -0,0 +1,11 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table DeprecatedInt (tag: "DEPI") {
+ field1: int32 (deprecated);
+ field2: int32;
+}
+
+table DeprecatedString (tag: "DEPS") {
+ field1: string (deprecated);
+ field2: int32;
+} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/DuplicateTags.baseline b/fcp/tracing/tools/testdata/DuplicateTags.baseline
new file mode 100644
index 0000000..e6714b7
--- /dev/null
+++ b/fcp/tracing/tools/testdata/DuplicateTags.baseline
@@ -0,0 +1,17 @@
+============== DuplicateTags.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table TableWithTag (tag: "TTT0") {
+ field1: int32;
+ field2: int32;
+}
+
+table TableWithSameTag (tag: "TTT0") {
+ field1: int32;
+ field2: string;
+}
+============== diagnosis ============
+ERROR: ${DIR}/tools/testdata/DuplicateTags.fbs contains table TableWithTag with tag TTT0 which is already present in the schema. All tags must be unique.
+
+============== result ============
+
diff --git a/fcp/tracing/tools/testdata/DuplicateTags.fbs b/fcp/tracing/tools/testdata/DuplicateTags.fbs
new file mode 100644
index 0000000..ee3cc81
--- /dev/null
+++ b/fcp/tracing/tools/testdata/DuplicateTags.fbs
@@ -0,0 +1,11 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table TableWithTag (tag: "TTT0") {
+ field1: int32;
+ field2: int32;
+}
+
+table TableWithSameTag (tag: "TTT0") {
+ field1: int32;
+ field2: string;
+} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/EmptyTable.baseline b/fcp/tracing/tools/testdata/EmptyTable.baseline
new file mode 100644
index 0000000..121100d
--- /dev/null
+++ b/fcp/tracing/tools/testdata/EmptyTable.baseline
@@ -0,0 +1,65 @@
+============== EmptyTable.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table DoesntContainFields (tag: "EMPT", span) {}
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_EMPTYTABLE_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_EMPTYTABLE_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/EmptyTable_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<DoesntContainFields>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("EMPT");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = true;
+ const char* Name() const override { return "DoesntContainFields"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), DoesntContainFieldsTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/EmptyTable.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("DoesntContainFields");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<DoesntContainFields> Create(flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateDoesntContainFields(*fbb);
+ }
+ using TupleType = std::tuple<>;
+ static TupleType MakeTuple(const DoesntContainFields* table) {
+ return std::make_tuple();
+ }
+};
+static internal::TracingTraitsRegistrar<DoesntContainFields> registrar_DoesntContainFields;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_EMPTYTABLE_H
+
diff --git a/fcp/tracing/tools/testdata/EmptyTable.fbs b/fcp/tracing/tools/testdata/EmptyTable.fbs
new file mode 100644
index 0000000..43d08b9
--- /dev/null
+++ b/fcp/tracing/tools/testdata/EmptyTable.fbs
@@ -0,0 +1,3 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table DoesntContainFields (tag: "EMPT", span) {} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/EnumType.baseline b/fcp/tracing/tools/testdata/EnumType.baseline
new file mode 100644
index 0000000..d5cd58d
--- /dev/null
+++ b/fcp/tracing/tools/testdata/EnumType.baseline
@@ -0,0 +1,71 @@
+============== EnumType.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+enum Color : byte { Red = 0, Green = 1, Blue = 2 }
+
+table Monster(tag: "MONS") {
+ hp: int32;
+ color : Color;
+}
+
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ENUMTYPE_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ENUMTYPE_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/EnumType_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<Monster>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("MONS");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "Monster"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), MonsterTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/EnumType.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("Monster");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<Monster> Create(std::int32_t hp, Color color, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateMonster(*fbb, hp, color);
+ }
+ using TupleType = std::tuple<std::int32_t, Color>;
+ static TupleType MakeTuple(const Monster* table) {
+ return std::make_tuple(table->hp(), table->color());
+ }
+};
+static internal::TracingTraitsRegistrar<Monster> registrar_Monster;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ENUMTYPE_H
+
diff --git a/fcp/tracing/tools/testdata/EnumType.fbs b/fcp/tracing/tools/testdata/EnumType.fbs
new file mode 100644
index 0000000..7319cb3
--- /dev/null
+++ b/fcp/tracing/tools/testdata/EnumType.fbs
@@ -0,0 +1,8 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+enum Color : byte { Red = 0, Green = 1, Blue = 2 }
+
+table Monster(tag: "MONS") {
+ hp: int32;
+ color : Color;
+}
diff --git a/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.baseline b/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.baseline
new file mode 100644
index 0000000..a4b5bad
--- /dev/null
+++ b/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.baseline
@@ -0,0 +1,111 @@
+============== FieldsOfDifferentTypes.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table IntegersOnly (tag: "TTT0") {
+ field1: int32;
+ field2: int32;
+}
+
+table IntegersWithString (tag: "TTT1") {
+ field1: int32;
+ field2: string;
+}
+
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_FIELDSOFDIFFERENTTYPES_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_FIELDSOFDIFFERENTTYPES_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/FieldsOfDifferentTypes_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<IntegersOnly>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("TTT0");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "IntegersOnly"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), IntegersOnlyTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/FieldsOfDifferentTypes.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("IntegersOnly");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<IntegersOnly> Create(std::int32_t field1, std::int32_t field2, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateIntegersOnly(*fbb, field1, field2);
+ }
+ using TupleType = std::tuple<std::int32_t, std::int32_t>;
+ static TupleType MakeTuple(const IntegersOnly* table) {
+ return std::make_tuple(table->field1(), table->field2());
+ }
+};
+static internal::TracingTraitsRegistrar<IntegersOnly> registrar_IntegersOnly;
+template<> class TracingTraits<IntegersWithString>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("TTT1");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "IntegersWithString"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), IntegersWithStringTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/FieldsOfDifferentTypes.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("IntegersWithString");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<IntegersWithString> Create(std::int32_t field1, absl::string_view field2, flatbuffers::FlatBufferBuilder* fbb) {
+ auto field2__ = fbb->CreateString(field2.data(), field2.size());
+ return CreateIntegersWithString(*fbb, field1, field2__);
+ }
+ using TupleType = std::tuple<std::int32_t, std::string>;
+ static TupleType MakeTuple(const IntegersWithString* table) {
+ return std::make_tuple(table->field1(), table->field2()->str());
+ }
+};
+static internal::TracingTraitsRegistrar<IntegersWithString> registrar_IntegersWithString;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_FIELDSOFDIFFERENTTYPES_H
+
diff --git a/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.fbs b/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.fbs
new file mode 100644
index 0000000..ef4a4cb
--- /dev/null
+++ b/fcp/tracing/tools/testdata/FieldsOfDifferentTypes.fbs
@@ -0,0 +1,11 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table IntegersOnly (tag: "TTT0") {
+ field1: int32;
+ field2: int32;
+}
+
+table IntegersWithString (tag: "TTT1") {
+ field1: int32;
+ field2: string;
+}
diff --git a/fcp/tracing/tools/testdata/NoAttributes.baseline b/fcp/tracing/tools/testdata/NoAttributes.baseline
new file mode 100644
index 0000000..a832996
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NoAttributes.baseline
@@ -0,0 +1,9 @@
+============== NoAttributes.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table NoAttributes {}
+============== diagnosis ============
+ERROR: ${DIR}/tools/testdata/NoAttributes.fbs contains table NoAttributes without a tag. All tables must have a tag defined.
+
+============== result ============
+
diff --git a/fcp/tracing/tools/testdata/NoAttributes.fbs b/fcp/tracing/tools/testdata/NoAttributes.fbs
new file mode 100644
index 0000000..5898a99
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NoAttributes.fbs
@@ -0,0 +1,3 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table NoAttributes {} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/NoTag.baseline b/fcp/tracing/tools/testdata/NoTag.baseline
new file mode 100644
index 0000000..2eff814
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NoTag.baseline
@@ -0,0 +1,9 @@
+============== NoTag.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table NoTag (error) {}
+============== diagnosis ============
+ERROR: ${DIR}/tools/testdata/NoTag.fbs contains table NoTag without a tag. All tables must have a tag defined.
+
+============== result ============
+
diff --git a/fcp/tracing/tools/testdata/NoTag.fbs b/fcp/tracing/tools/testdata/NoTag.fbs
new file mode 100644
index 0000000..f43801d
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NoTag.fbs
@@ -0,0 +1,3 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table NoTag (error) {} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.baseline b/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.baseline
new file mode 100644
index 0000000..d6ea3bd
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.baseline
@@ -0,0 +1,78 @@
+============== NonTableObjectsAreSkipped.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+attribute "priority";
+
+enum Color : byte { Red, Green, Blue }
+
+struct Vec3 {
+ x: float;
+ y: float;
+ z: float;
+}
+
+table Monster(tag: "MONS") {
+ hp: int32;
+}
+
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_NONTABLEOBJECTSARESKIPPED_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_NONTABLEOBJECTSARESKIPPED_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/NonTableObjectsAreSkipped_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<Monster>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("MONS");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "Monster"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), MonsterTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/NonTableObjectsAreSkipped.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("Monster");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<Monster> Create(std::int32_t hp, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateMonster(*fbb, hp);
+ }
+ using TupleType = std::tuple<std::int32_t>;
+ static TupleType MakeTuple(const Monster* table) {
+ return std::make_tuple(table->hp());
+ }
+};
+static internal::TracingTraitsRegistrar<Monster> registrar_Monster;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_NONTABLEOBJECTSARESKIPPED_H
+
diff --git a/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.fbs b/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.fbs
new file mode 100644
index 0000000..5640cf7
--- /dev/null
+++ b/fcp/tracing/tools/testdata/NonTableObjectsAreSkipped.fbs
@@ -0,0 +1,15 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+attribute "priority";
+
+enum Color : byte { Red, Green, Blue }
+
+struct Vec3 {
+ x: float;
+ y: float;
+ z: float;
+}
+
+table Monster(tag: "MONS") {
+ hp: int32;
+}
diff --git a/fcp/tracing/tools/testdata/OrderWithIds.baseline b/fcp/tracing/tools/testdata/OrderWithIds.baseline
new file mode 100644
index 0000000..1b78d99
--- /dev/null
+++ b/fcp/tracing/tools/testdata/OrderWithIds.baseline
@@ -0,0 +1,69 @@
+============== OrderWithIds.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table OrderWithIds (tag: "ORDI") {
+ fieldz: int (id: 1);
+ fieldy: int (id: 2);
+ fieldx: int (id: 0);
+}
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ORDERWITHIDS_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ORDERWITHIDS_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/OrderWithIds_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<OrderWithIds>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("ORDI");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = false;
+ const char* Name() const override { return "OrderWithIds"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), OrderWithIdsTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/OrderWithIds.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("OrderWithIds");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<OrderWithIds> Create(std::int32_t fieldx, std::int32_t fieldz, std::int32_t fieldy, flatbuffers::FlatBufferBuilder* fbb) {
+ return CreateOrderWithIds(*fbb, fieldx, fieldz, fieldy);
+ }
+ using TupleType = std::tuple<std::int32_t, std::int32_t, std::int32_t>;
+ static TupleType MakeTuple(const OrderWithIds* table) {
+ return std::make_tuple(table->fieldx(), table->fieldz(), table->fieldy());
+ }
+};
+static internal::TracingTraitsRegistrar<OrderWithIds> registrar_OrderWithIds;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_ORDERWITHIDS_H
+
diff --git a/fcp/tracing/tools/testdata/OrderWithIds.fbs b/fcp/tracing/tools/testdata/OrderWithIds.fbs
new file mode 100644
index 0000000..c15045d
--- /dev/null
+++ b/fcp/tracing/tools/testdata/OrderWithIds.fbs
@@ -0,0 +1,7 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table OrderWithIds (tag: "ORDI") {
+ fieldz: int (id: 1);
+ fieldy: int (id: 2);
+ fieldx: int (id: 0);
+} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/TableWithNamespace.baseline b/fcp/tracing/tools/testdata/TableWithNamespace.baseline
new file mode 100644
index 0000000..1842f84
--- /dev/null
+++ b/fcp/tracing/tools/testdata/TableWithNamespace.baseline
@@ -0,0 +1,72 @@
+============== TableWithNamespace.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+namespace foo.bar;
+
+enum Color : int { Red = 0, Green = 1, Blue = 2 }
+
+table TestTable (tag: "TETB", span) {
+ id: int;
+ color: Color;
+}
+
+============== diagnosis ============
+
+============== result ============
+// Autogenerated by tracing_traits_generator, do not edit
+
+#ifndef THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_TABLEWITHNAMESPACE_H
+#define THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_TABLEWITHNAMESPACE_H
+
+#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_
+#endif
+#include "${DIR}/tools/testdata/TableWithNamespace_generated.h"
+#include "absl/strings/string_view.h"
+#include "${DIR}/tracing_severity.h"
+#include "${DIR}/tracing_traits.h"
+#include "flatbuffers/minireflect.h"
+#include "flatbuffers/idl.h"
+#include "${BASE}/platform.h"
+
+namespace fcp {
+
+template<> class TracingTraits<foo::bar::TestTable>: public TracingTraitsBase {
+ public:
+ static constexpr TracingTag kTag = TracingTag("TETB");
+ static constexpr TracingSeverity kSeverity = fcp::TracingSeverity::kInfo;
+ static constexpr bool kIsSpan = true;
+ const char* Name() const override { return "foo::bar::TestTable"; }
+ TracingSeverity Severity() const override {
+ return fcp::TracingSeverity::kInfo;
+ }
+ std::string TextFormat(const flatbuffers::DetachedBuffer& buf) const override {
+ return flatbuffers::FlatBufferToString(buf.data(), foo::bar::TestTableTypeTable());
+ }
+ std::string JsonStringFormat(const uint8_t* flatbuf_bytes) const override {
+ flatbuffers::Parser parser;
+ std::string schema_file;
+ std::string fbs_file = "${RUNFILE_PATH}/tools/testdata/TableWithNamespace.fbs";
+ flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), true, &schema_file);
+ std::string schema_path_common = GetDataPath("${DIR}/tracing_schema_common.fbs");
+ std::string directory_common = schema_path_common.substr(0, schema_path_common.find("${DIR}/tracing_schema_common.fbs"));
+ const char *include_directories[] = {
+ directory_common.c_str(), nullptr};
+ parser.Parse(schema_file.c_str(), include_directories);
+ std::string jsongen;
+ parser.SetRootType("foo::bar::TestTable");
+ GenerateText(parser, flatbuf_bytes, &jsongen);
+ return jsongen;
+ }
+ static flatbuffers::Offset<foo::bar::TestTable> Create(std::int32_t id, foo::bar::Color color, flatbuffers::FlatBufferBuilder* fbb) {
+ return foo::bar::CreateTestTable(*fbb, id, color);
+ }
+ using TupleType = std::tuple<std::int32_t, foo::bar::Color>;
+ static TupleType MakeTuple(const foo::bar::TestTable* table) {
+ return std::make_tuple(table->id(), table->color());
+ }
+};
+static internal::TracingTraitsRegistrar<foo::bar::TestTable> registrar_TestTable;
+} // namespace fcp
+
+#endif // THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA_TABLEWITHNAMESPACE_H
+
diff --git a/fcp/tracing/tools/testdata/TableWithNamespace.fbs b/fcp/tracing/tools/testdata/TableWithNamespace.fbs
new file mode 100644
index 0000000..55b5361
--- /dev/null
+++ b/fcp/tracing/tools/testdata/TableWithNamespace.fbs
@@ -0,0 +1,9 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+namespace foo.bar;
+
+enum Color : int { Red = 0, Green = 1, Blue = 2 }
+
+table TestTable (tag: "TETB", span) {
+ id: int;
+ color: Color;
+}
diff --git a/fcp/tracing/tools/testdata/TagTooLong.baseline b/fcp/tracing/tools/testdata/TagTooLong.baseline
new file mode 100644
index 0000000..930c65b
--- /dev/null
+++ b/fcp/tracing/tools/testdata/TagTooLong.baseline
@@ -0,0 +1,9 @@
+============== TagTooLong.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+table TagTooLong (tag: "LONGT") {}
+============== diagnosis ============
+ERROR: ${DIR}/tools/testdata/TagTooLong.fbs contains table TagTooLong with tag LONGT of length 5. All tables must have a tag of length 4.
+
+============== result ============
+
diff --git a/fcp/tracing/tools/testdata/TagTooLong.fbs b/fcp/tracing/tools/testdata/TagTooLong.fbs
new file mode 100644
index 0000000..561f2ea
--- /dev/null
+++ b/fcp/tracing/tools/testdata/TagTooLong.fbs
@@ -0,0 +1,3 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+table TagTooLong (tag: "LONGT") {} \ No newline at end of file
diff --git a/fcp/tracing/tools/testdata/UnsupportedType.baseline b/fcp/tracing/tools/testdata/UnsupportedType.baseline
new file mode 100644
index 0000000..0177f51
--- /dev/null
+++ b/fcp/tracing/tools/testdata/UnsupportedType.baseline
@@ -0,0 +1,18 @@
+============== UnsupportedType.fbs ============
+include "${DIR}/tracing_schema_common.fbs";
+
+struct Vec3 {
+ x: float;
+ y: float;
+ z: float;
+}
+
+table Monster(tag: "MONS") {
+ hp: int32;
+ pos: Vec3;
+}
+============== diagnosis ============
+ERROR: ${DIR}/tools/testdata/UnsupportedType.fbs contains unsupported type Obj for field pos in table Monster
+
+============== result ============
+
diff --git a/fcp/tracing/tools/testdata/UnsupportedType.fbs b/fcp/tracing/tools/testdata/UnsupportedType.fbs
new file mode 100644
index 0000000..c547750
--- /dev/null
+++ b/fcp/tracing/tools/testdata/UnsupportedType.fbs
@@ -0,0 +1,12 @@
+include "fcp/tracing/tracing_schema_common.fbs";
+
+struct Vec3 {
+ x: float;
+ y: float;
+ z: float;
+}
+
+table Monster(tag: "MONS") {
+ hp: int32;
+ pos: Vec3;
+} \ No newline at end of file
diff --git a/fcp/tracing/tools/tracing_traits_generator.cc b/fcp/tracing/tools/tracing_traits_generator.cc
new file mode 100644
index 0000000..1790a18
--- /dev/null
+++ b/fcp/tracing/tools/tracing_traits_generator.cc
@@ -0,0 +1,424 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "fcp/tracing/tracing_severity.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/idl.h"
+#include "flatbuffers/reflection_generated.h"
+#include "flatbuffers/util.h"
+
+using ::reflection::BaseType;
+
+namespace fcp {
+
+struct TypeInfo {
+ BaseType flatbuf_type;
+ std::string cpp_type;
+};
+
+static std::string severity_string(const TracingSeverity tracing_severity) {
+ switch (tracing_severity) {
+ case TracingSeverity::kInfo:
+ return "fcp::TracingSeverity::kInfo";
+ case TracingSeverity::kWarning:
+ return "fcp::TracingSeverity::kWarning";
+ case TracingSeverity::kError:
+ return "fcp::TracingSeverity::kError";
+ }
+}
+
+static std::string gen_header_guard(absl::string_view output_filename) {
+ std::string header_guard = absl::StrReplaceAll(
+ output_filename, {{"_generated", ""}, {"/", "_"}, {".", "_"}});
+ std::transform(header_guard.begin(), header_guard.end(), header_guard.begin(),
+ [](unsigned char c) { return std::toupper(c); });
+ return header_guard;
+}
+
+static std::string gen_fbs_filename(absl::string_view output_filename) {
+ return absl::StrReplaceAll(output_filename, {{".h", ".fbs"}});
+}
+
+static absl::string_view gen_table_name(
+ absl::string_view fully_qualified_table_name) {
+ auto pos = fully_qualified_table_name.find_last_of(':');
+ if (pos != std::string::npos) {
+ return absl::ClippedSubstr(fully_qualified_table_name, pos + 1);
+ }
+ return fully_qualified_table_name;
+}
+
+static absl::string_view gen_table_namespace(
+ absl::string_view fully_qualified_table_name) {
+ auto pos = fully_qualified_table_name.find_last_of(':');
+ if (pos != std::string::npos) {
+ return absl::ClippedSubstr(fully_qualified_table_name, 0, pos + 1);
+ }
+ return "";
+}
+} // namespace fcp
+
+// For codegen examples, see fcp/tracing/tools/testdata.
+int main(int argc, const char** argv) {
+ if (argc != 4) {
+ std::cerr << "Usage: tracing_traits_generator "
+ "<runtime/path/to/tracing_schema_generated.h> "
+ "<full/path/to/tracing_schema.bfbs>"
+ "<full/path/to/tracing_schema.fbs>"
+ << std::endl;
+ return 1;
+ }
+ const char* generated_filename = argv[1];
+ const char* bfbs_filename = argv[2];
+ const char* fbs_filename = argv[3];
+
+ // Loading binary schema file
+ std::string bfbs_file;
+ if (!flatbuffers::LoadFile(bfbs_filename, true, &bfbs_file)) {
+ std::cerr << "Error loading FlatBuffers binary schema (bfbs) file"
+ << std::endl;
+ return 2;
+ }
+
+ // Verify it, just in case:
+ flatbuffers::Verifier verifier(
+ reinterpret_cast<const uint8_t*>(bfbs_file.c_str()), bfbs_file.length());
+ if (!reflection::VerifySchemaBuffer(verifier)) {
+ std::cerr << "Error loading bfbs file" << std::endl;
+ return 3;
+ }
+
+ std::cout << "// Autogenerated by tracing_traits_generator, do not edit"
+ << std::endl;
+ std::cout << std::endl;
+
+ std::string output_filename =
+ absl::StrReplaceAll(generated_filename, {{"_generated", ""}});
+ std::string header_guard = fcp::gen_header_guard(output_filename);
+ std::cout << "#ifndef " << header_guard << std::endl;
+ std::cout << "#define " << header_guard << std::endl;
+ std::cout << std::endl;
+
+ // Workaround for inability of flatc to generate unique (path-dependent)
+ // include guards. Undefining the include guard below allows
+ // to include. Since all the flatc-generated schema files are wrapped
+ // by the guards above, it still remains protected against multiple includes.
+ std::cout << "#ifdef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_" << std::endl;
+ std::cout << "#undef FLATBUFFERS_GENERATED_TRACINGSCHEMA_H_" << std::endl;
+ std::cout << "#endif" << std::endl;
+
+ std::cout << "#include \"" << generated_filename << "\"" << std::endl;
+ std::cout << "#include \"absl/strings/string_view.h\""
+ << std::endl;
+ std::cout << "#include \"fcp/tracing/tracing_severity.h\""
+ << std::endl;
+ std::cout << "#include \"fcp/tracing/tracing_traits.h\""
+ << std::endl;
+ std::cout << "#include "
+ "\"flatbuffers/minireflect.h\""
+ << std::endl;
+ std::cout << "#include "
+ "\"flatbuffers/idl.h\""
+ << std::endl;
+ std::cout << "#include "
+ "\"fcp/base/platform.h\""
+ << std::endl;
+ std::cout << std::endl;
+
+ // Reflecting over schema and enumerating tables
+ auto& schema = *reflection::GetSchema(bfbs_file.c_str());
+ std::cout << "namespace fcp {" << std::endl;
+ std::cout << std::endl;
+
+ absl::flat_hash_map<BaseType, std::string> type_map = {
+ {BaseType::String, "absl::string_view"},
+ {BaseType::Byte, "std::int8_t"},
+ {BaseType::UByte, "std::uint8_t"},
+ {BaseType::Bool, "bool"},
+ {BaseType::Short, "std::int16_t"},
+ {BaseType::UShort, "std::uint16_t"},
+ {BaseType::Int, "std::int32_t"},
+ {BaseType::UInt, "std::uint32_t"},
+ {BaseType::Float, "float"},
+ {BaseType::Long, "std::int64_t"},
+ {BaseType::ULong, "std::uint64_t"},
+ {BaseType::Double, "double"}};
+ absl::flat_hash_set<std::string> tags;
+ for (const reflection::Object* const o : *schema.objects()) {
+ if (o->is_struct()) continue;
+ std::string fully_qualified_table_name =
+ absl::StrReplaceAll(o->name()->c_str(), {{".", "::"}});
+ absl::string_view table_name =
+ fcp::gen_table_name(fully_qualified_table_name);
+ absl::string_view table_namespace =
+ fcp::gen_table_namespace(fully_qualified_table_name);
+
+ // The fields are sorted in alphabetical order, rather than the order in
+ // which they should be passed to the Create method. Sort them by ID which
+ // determines the order in which the generated Create method accepts them.
+ // ID will be the order in which fields are declared in the table if it is
+ // not explicitly specified for each field.
+ std::vector<const reflection::Field*> fields_sorted;
+ fields_sorted.resize(o->fields()->size());
+ for (const reflection::Field* const f : *o->fields()) {
+ // FlatBuffers field IDs are guaranteed to be dense:
+ assert(f->id() < o->fields()->size());
+ fields_sorted[f->id()] = f;
+ }
+
+ std::vector<std::pair<std::string, fcp::TypeInfo>> fields;
+ for (const reflection::Field* const f : fields_sorted) {
+ // Filter out deprecated fields since the Create method no longer takes
+ // them as parameters.
+ if (f->deprecated()) continue;
+ BaseType flatbuf_type = f->type()->base_type();
+ auto type_map_entry = type_map.find(flatbuf_type);
+ if (type_map_entry == type_map.end()) {
+ std::cerr
+ << absl::StreamFormat(
+ "ERROR: %s contains unsupported type %s for field %s in "
+ "table %s",
+ fcp::gen_fbs_filename(output_filename),
+ reflection::EnumNameBaseType(flatbuf_type),
+ f->name()->c_str(), fully_qualified_table_name)
+ << std::endl;
+ return 4;
+ }
+ if (f->type()->index() != -1) {
+ // If the index of the type is set, it means this is a more complex
+ // type, and we can learn more about the type by indexing into one of
+ // the toplevel fields in the schema - either "objects" or "enums".
+ // Since we do not currently support base_type of kind Union, UnionType,
+ // or Object, if the index is anything other than -1, this type must be
+ // an integer derived from an enum, and we can determine more
+ // information by indexing into "enums". See
+ // https://groups.google.com/g/flatbuffers/c/nAi8MQu3A-U.
+ const reflection::Enum* enum_type =
+ schema.enums()->Get(f->type()->index());
+ fields.emplace_back(
+ f->name()->c_str(),
+ fcp::TypeInfo{
+ flatbuf_type,
+ // Replace '.' with '::' in the fully qualified enum name for
+ // C++ compatibility.
+ absl::StrReplaceAll(enum_type->name()->str(), {{".", "::"}})});
+ } else {
+ fields.emplace_back(
+ f->name()->c_str(),
+ fcp::TypeInfo{flatbuf_type, type_map_entry->second});
+ }
+ }
+
+ std::cout << "template<> class TracingTraits<" << fully_qualified_table_name
+ << ">: public TracingTraitsBase {" << std::endl;
+ std::cout << " public:" << std::endl;
+
+ fcp::TracingSeverity severity = fcp::TracingSeverity::kInfo;
+ std::string tag = "";
+ bool is_span = false;
+ if (o->attributes() == nullptr) {
+ std::cerr
+ << absl::StreamFormat(
+ "ERROR: %s contains table %s without a tag. All tables must "
+ "have a tag defined.",
+ fcp::gen_fbs_filename(output_filename),
+ fully_qualified_table_name)
+ << std::endl;
+ return 5;
+ }
+ for (const reflection::KeyValue* a : *o->attributes()) {
+ if (a->key()->str() == "tag") {
+ tag = a->value()->str();
+ if (tag.size() != 4) {
+ std::cerr
+ << absl::StreamFormat(
+ "ERROR: %s contains table %s with tag %s of length %d. "
+ "All tables must have a tag of length 4.",
+ fcp::gen_fbs_filename(output_filename),
+ fully_qualified_table_name, tag, tag.size())
+ << std::endl;
+ return 6;
+ }
+ }
+ if (a->key()->str() == "warning") {
+ severity = fcp::TracingSeverity::kWarning;
+ }
+ if (a->key()->str() == "error") {
+ severity = fcp::TracingSeverity::kError;
+ }
+ if (a->key()->str() == "span") {
+ is_span = true;
+ }
+ }
+ if (tag.empty()) {
+ std::cerr
+ << absl::StreamFormat(
+ "ERROR: %s contains table %s without a tag. All tables must "
+ "have a tag defined.",
+ fcp::gen_fbs_filename(output_filename),
+ fully_qualified_table_name)
+ << std::endl;
+ return 7;
+ }
+
+ if (!tags.insert(tag).second) {
+ std::cerr
+ << absl::StreamFormat(
+ "ERROR: %s contains table %s with tag %s which is already "
+ "present in the schema. All tags must be unique.",
+ fcp::gen_fbs_filename(output_filename),
+ fully_qualified_table_name, tag)
+ << std::endl;
+ return 8;
+ }
+
+ std::cout << " static constexpr TracingTag kTag = TracingTag(\"" << tag
+ << "\");" << std::endl;
+
+ std::cout << " static constexpr TracingSeverity kSeverity = "
+ << fcp::severity_string(severity) << ";" << std::endl;
+
+ std::cout << " static constexpr bool kIsSpan = "
+ << (is_span ? "true" : "false") << ";" << std::endl;
+
+ std::cout << " const char* Name() const override { return \""
+ << fully_qualified_table_name << "\"; }" << std::endl;
+
+ std::cout << " TracingSeverity Severity() const override {" << std::endl;
+ std::cout << " return " << fcp::severity_string(severity) << ";"
+ << std::endl;
+ std::cout << " }" << std::endl;
+ std::cout
+ << " std::string TextFormat(const flatbuffers::DetachedBuffer& buf) "
+ "const override {"
+ << std::endl;
+ std::cout << " return flatbuffers::FlatBufferToString(buf.data(), "
+ << fully_qualified_table_name << "TypeTable());" << std::endl;
+ std::cout << " }" << std::endl;
+ std::cout << " std::string JsonStringFormat(const uint8_t* flatbuf_bytes) "
+ "const override {"
+ << std::endl;
+ std::cout << " flatbuffers::Parser parser;" << std::endl;
+ std::cout << " std::string schema_file;" << std::endl;
+ std::cout << " std::string fbs_file = \"" << fbs_filename << "\";"
+ << std::endl;
+ std::cout << " flatbuffers::LoadFile(GetDataPath(fbs_file).c_str(), "
+ "true, &schema_file);"
+ << std::endl;
+ // Finds the directory in which the flatbuf class should look for
+ // dependencies of the .fbs file
+ // TODO(team) pass in tracing_schema_common to the script instead of
+ // hardcoding it.
+ std::cout << " std::string schema_path_common = "
+ "GetDataPath(\"fcp/tracing/"
+ "tracing_schema_common.fbs\");"
+ << std::endl;
+ std::cout
+ << " std::string directory_common = schema_path_common.substr(0, "
+ "schema_path_common.find(\"fcp/tracing/"
+ "tracing_schema_common.fbs\"));"
+ << std::endl;
+ // Parser.parse() requires the directories passed in to have a nullptr
+ std::cout << " const char *include_directories[] = {" << std::endl;
+ std::cout << " directory_common.c_str(), nullptr};"
+ << std::endl;
+ // Parse takes in the schema file and populates the FlatBufferBuilder from
+ // the unique schema.
+ std::cout << " parser.Parse(schema_file.c_str(), include_directories);"
+ << std::endl;
+ std::cout << " std::string jsongen;" << std::endl;
+ // The root sets the particular table from the Flatbuffer, since flatbuffers
+ // can have different tables.
+ std::cout << " parser.SetRootType(\"" << fully_qualified_table_name
+ << "\");" << std::endl;
+ std::cout << " GenerateText(parser, flatbuf_bytes, &jsongen);"
+ << std::endl;
+ std::cout << " return jsongen;" << std::endl;
+ std::cout << " }" << std::endl;
+ std::cout << " static flatbuffers::Offset<" << fully_qualified_table_name;
+ std::cout << "> Create(";
+ for (const auto& [name, type] : fields) {
+ std::cout << type.cpp_type << " " << name << ", ";
+ }
+ std::cout << "flatbuffers::FlatBufferBuilder* fbb) {" << std::endl;
+
+ // Strings require special handling because the Create method takes an
+ // Offset<String>. Copy each provided string view into an Offset<String>.
+ for (const auto& [name, type] : fields) {
+ if (type.flatbuf_type == BaseType::String) {
+ std::cout << " auto " << name << "__ = fbb->CreateString(" << name
+ << ".data(), " << name << ".size()"
+ << ");" << std::endl;
+ }
+ }
+
+ std::cout << " return " << table_namespace << "Create" << table_name
+ << "(";
+ std::cout << "*fbb";
+ for (const auto& [name, type] : fields) {
+ const char* suffix = (type.flatbuf_type == BaseType::String) ? "__" : "";
+ std::cout << ", " << name << suffix;
+ }
+ std::cout << ");" << std::endl;
+ std::cout << " }" << std::endl;
+
+ // MakeTuple helper, which allows to generate std::tuple from a table.
+ std::string tuple_type = "std::tuple<";
+ std::string make_tuple_args;
+ for (auto [name, type] : fields) {
+ if (!make_tuple_args.empty()) {
+ tuple_type += ", ";
+ make_tuple_args += ", ";
+ }
+ make_tuple_args += "table->" + name + "()";
+ if (type.flatbuf_type == BaseType::String) {
+ tuple_type += "std::string";
+ make_tuple_args += "->str()";
+ } else {
+ tuple_type += std::string(type.cpp_type);
+ }
+ }
+ tuple_type += ">";
+
+ std::cout << " using TupleType = " << tuple_type << ";" << std::endl;
+ std::cout << " static TupleType MakeTuple(const "
+ << fully_qualified_table_name << "* table) {" << std::endl;
+ std::cout << " return std::make_tuple(" << make_tuple_args << ");"
+ << std::endl;
+ std::cout << " }" << std::endl;
+
+ std::cout << "};" << std::endl;
+ std::cout << "static internal::TracingTraitsRegistrar<"
+ << fully_qualified_table_name << "> registrar_" << table_name
+ << ";" << std::endl;
+ }
+ std::cout << "} // namespace fcp" << std::endl;
+ std::cout << std::endl;
+ std::cout << "#endif // " << header_guard << std::endl;
+ return 0;
+}
diff --git a/fcp/tracing/tools/tracing_traits_generator_test.cc b/fcp/tracing/tools/tracing_traits_generator_test.cc
new file mode 100644
index 0000000..a6dfade
--- /dev/null
+++ b/fcp/tracing/tools/tracing_traits_generator_test.cc
@@ -0,0 +1,125 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO(team): switch to re2 library
+#include <regex> // NOLINT
+#include <string>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/base/platform.h"
+#include "fcp/testing/testing.h"
+
+ABSL_FLAG(std::string, codegen_tool_path, "", "Path to codegen tool script");
+
+namespace fcp {
+namespace {
+
+const char* kBaselineDir = "fcp/tracing/tools/testdata";
+
+std::string PostProcessOutput(const std::string& input) {
+ std::regex header_guard_simplifier(
+ "(THIRD_PARTY_)?FCP_TRACING_TOOLS_TESTDATA");
+ std::string header_guard_replaced = std::regex_replace(
+ input, header_guard_simplifier, "THIRD_PARTY_FCP_TRACING_TOOLS_TESTDATA");
+ std::regex runfiles_path_simplifier("\".*runfiles.*fcp/tracing");
+ // replaces the runfile directory with {RUNFILE_PATH} for testing purposes
+ std::string runfiles_replaced = std::regex_replace(
+ header_guard_replaced, runfiles_path_simplifier, "\"${RUNFILE_PATH}");
+ std::regex path_simplifier_pattern("( |\")(\\w+/)*fcp/tracing");
+ // replaces the directory of the .fbs for testing purposes
+ std::string directory_replaced = std::regex_replace(
+ runfiles_replaced, path_simplifier_pattern, "$1${DIR}");
+ std::regex fcp_base_path_simplifier("(\\w+/)?fcp/base");
+ return std::regex_replace(directory_replaced, fcp_base_path_simplifier,
+ "${BASE}");
+}
+
+void DoTest() {
+ std::string source_file = absl::StrCat(TestName(), ".fbs");
+ std::string source_path =
+ GetTestDataPath(ConcatPath(kBaselineDir, source_file));
+
+ // Read fsb source file derived from the test name:
+ StatusOr<std::string> source_s = ReadFileToString(source_path);
+ ASSERT_THAT(source_s, IsOk()) << "Can't read " << source_path;
+ std::string source = source_s.value();
+
+ std::string out_file =
+ ConcatPath(testing::TempDir(), absl::StrCat(TestName(), ".out"));
+ std::string err_file =
+ ConcatPath(testing::TempDir(), absl::StrCat(TestName(), ".err"));
+
+ // Run codegen script, redirecting stdout to out_file and stderr to err_file
+ int exit_code = system(
+ absl::StrCat(GetTestDataPath(absl::GetFlag(FLAGS_codegen_tool_path)), " ",
+ source_path, " ", testing::TempDir(), " ", kBaselineDir,
+ " 1> ", out_file, " 2> ", err_file)
+ .c_str());
+
+ // Reading error and out files
+ std::string out = ReadFileToString(out_file).value();
+ std::string err = ReadFileToString(err_file).value();
+
+ if (exit_code != 0) {
+ // Codegen failed. This might be expected depending on the test.
+ // In the case of failure we're not interested in capturing possible partial
+ // output in baseline file.
+ out.clear();
+ if (err.empty()) {
+ // If error is not empty it already contains relevant diagnostics,
+ // otherwise adding information about exit code.
+ err = absl::StrCat("Exit code ", exit_code);
+ }
+ }
+
+ // Producing report which is expected to precisely match .baseline file.
+ std::ostringstream report;
+ report << "============== " << source_file << " ============" << std::endl;
+ report << PostProcessOutput(source) << std::endl;
+ report << "============== diagnosis ============" << std::endl;
+ report << PostProcessOutput(err) << std::endl;
+ report << "============== result ============" << std::endl;
+ report << PostProcessOutput(out) << std::endl;
+
+ // Compare produced report with baseline.
+ std::string baseline_path =
+ ConcatPath(kBaselineDir, absl::StrCat(TestName(), ".baseline"));
+ auto status_s = VerifyAgainstBaseline(baseline_path, report.str());
+ ASSERT_TRUE(status_s.ok()) << status_s.status();
+ auto& diff = status_s.value();
+ if (!diff.empty()) {
+ FAIL() << diff;
+ }
+}
+
+TEST(Codegen, EmptyTable) { DoTest(); }
+TEST(Codegen, FieldsOfDifferentTypes) { DoTest(); }
+TEST(Codegen, DeprecatedField) { DoTest(); }
+TEST(Codegen, NonTableObjectsAreSkipped) { DoTest(); }
+TEST(Codegen, AllTypes) { DoTest(); }
+TEST(Codegen, OrderWithIds) { DoTest(); }
+TEST(Codegen, NoTag) { DoTest(); }
+TEST(Codegen, NoAttributes) { DoTest(); }
+TEST(Codegen, TagTooLong) { DoTest(); }
+TEST(Codegen, DuplicateTags) { DoTest(); }
+TEST(Codegen, UnsupportedType) { DoTest(); }
+TEST(Codegen, TableWithNamespace) { DoTest(); }
+TEST(Codegen, EnumType) { DoTest(); }
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/tracing_context_utils.cc b/fcp/tracing/tracing_context_utils.cc
new file mode 100644
index 0000000..7e40094
--- /dev/null
+++ b/fcp/tracing/tracing_context_utils.cc
@@ -0,0 +1,40 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_context_utils.h"
+
+#include <optional>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "fcp/base/monitoring.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp::tracing_internal {
+
+void SetTracingContextOnMessage(const google::protobuf::Message& context,
+ google::protobuf::Message& message) {
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ message.GetDescriptor()->FindFieldByName(kContextFieldName);
+ if (field_descriptor == nullptr) {
+ return;
+ }
+ FCP_CHECK(field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES ||
+ field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING)
+ << kContextWrongTypeMessage;
+ message.GetReflection()->SetString(&message, field_descriptor,
+ context.SerializeAsString());
+}
+
+} // namespace fcp::tracing_internal
diff --git a/fcp/tracing/tracing_context_utils.h b/fcp/tracing/tracing_context_utils.h
new file mode 100644
index 0000000..94df509
--- /dev/null
+++ b/fcp/tracing/tracing_context_utils.h
@@ -0,0 +1,70 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_CONTEXT_UTILS_H_
+#define FCP_TRACING_TRACING_CONTEXT_UTILS_H_
+
+#include <string>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/tracing/tracing_traits.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp::tracing_internal {
+
+constexpr char kContextFieldName[] = "tracing_context";
+constexpr char kContextWrongTypeMessage[] =
+ "Type of tracing_context field should be bytes or string";
+
+// Given a proto message, checks to see if it has a tracing_context field and
+// if so, serializes the provided `context` message and sets the field contents
+// to the serialized context.
+// If no tracing_context field is present on `message` or it is of a type other
+// than string or bytes, this will be a no-op.
+void SetTracingContextOnMessage(const google::protobuf::Message& context,
+ google::protobuf::Message& message);
+
+// Given a proto `message` which may have a field tracing_context which contains
+// a serialized tracing context proto of type ContextT, uses reflection to
+// extract and parse the serialized proto to return a ContextT.
+// If the proto `message` does not have a tracing_context field, or it is empty,
+// returns the default value of ContextT.
+// If the tracing_context field is of a type other than string or bytes, this
+// will fail.
+template <class ContextT>
+ContextT GetContextFromMessage(const google::protobuf::Message& message) {
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ message.GetDescriptor()->FindFieldByName(kContextFieldName);
+ ContextT context;
+ if (field_descriptor == nullptr) {
+ return context;
+ }
+ FCP_CHECK(field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES ||
+ field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING)
+ << kContextWrongTypeMessage;
+ std::string serialized_context =
+ message.GetReflection()->GetString(message, field_descriptor);
+ if (serialized_context.empty()) {
+ return context;
+ }
+ FCP_CHECK(context.ParseFromString(serialized_context));
+
+ return context;
+}
+
+} // namespace fcp::tracing_internal
+
+#endif // FCP_TRACING_TRACING_CONTEXT_UTILS_H_
diff --git a/fcp/tracing/tracing_recorder.h b/fcp/tracing/tracing_recorder.h
new file mode 100644
index 0000000..fdc6632
--- /dev/null
+++ b/fcp/tracing/tracing_recorder.h
@@ -0,0 +1,62 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_RECORDER_H_
+#define FCP_TRACING_TRACING_RECORDER_H_
+
+#include "fcp/tracing/tracing_span.h"
+namespace fcp {
+
+// Interface to be implemented by tracing recorders, which are responsible
+// for implementation behind the TracingSpan API.
+// A tracing recorder provides ability to create root span.
+class TracingRecorder {
+ public:
+ // TracingRecorder is neither copyable nor movable.
+ TracingRecorder(const TracingRecorder&) = delete;
+ TracingRecorder& operator=(const TracingRecorder&) = delete;
+
+ TracingRecorder() = default;
+
+ // It is OK to destruct this facade API object anytime, since underlying
+ // implementation lifetime is independent from the facade and automatically
+ // prolonged by active tracing span (in this or other threads)
+ virtual ~TracingRecorder() = default;
+
+ // Installs tracing recorder as global instance.
+ // It uninstalls automatically upon destruction of underlying implementation.
+ // Only one instance can be installed and this operation will fail if other
+ // recorder is installed as global.
+ virtual void InstallAsGlobal() = 0;
+
+ // Uninstalls tracing recorder as global instance. Allowed to be called only
+ // if InstallAsGlobal() was called.
+ // NOTE: if some concurrent threads have active tracing spans on their stacks,
+ // they can continue tracing with the tracing recorder even after uninstalling
+ // it as global.
+ virtual void UninstallAsGlobal() = 0;
+
+ // Installs tracing recorder as thread local instance.
+ // Only one instance can be installed per thread, and this operation will fail
+ // if other recorder is installed for the current thread.
+ virtual void InstallAsThreadLocal() = 0;
+
+ // Uninstalls tracing recorder as thread local instance. Allowed to be called
+ // only if InstallAsThreadLocal has been called.
+ virtual void UninstallAsThreadLocal() = 0;
+};
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_RECORDER_H_
diff --git a/fcp/tracing/tracing_recorder_impl.cc b/fcp/tracing/tracing_recorder_impl.cc
new file mode 100644
index 0000000..cb9c7be
--- /dev/null
+++ b/fcp/tracing/tracing_recorder_impl.cc
@@ -0,0 +1,146 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_recorder_impl.h"
+
+#include <memory>
+
+#include "absl/synchronization/mutex.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/tracing/text_tracing_recorder_impl.h"
+
+namespace fcp::tracing_internal {
+
+class TracingState {
+ absl::Mutex mutex_;
+ bool using_thread_local_state_ = false;
+ TracingRecorderImpl* global_tracing_recorder_ = nullptr;
+
+ struct ThreadLocalState {
+ TracingRecorderImpl* tracing_recorder = nullptr;
+ // Ref count is used to track the number of times the same
+ // TracingRecorderImpl has been set in case of re-entrancy.
+ int ref_count = 0;
+ };
+
+ static ThreadLocalState& GetThreadLocalState() {
+ thread_local static ThreadLocalState instance;
+ return instance;
+ }
+
+ static std::shared_ptr<TracingRecorderImpl> GetOrCreateDefaultRecorder() {
+ static auto lazy_init_instance =
+ new std::shared_ptr<TextTracingRecorderImpl>(
+ new TextTracingRecorderImpl(absl::LocalTimeZone()));
+ return *lazy_init_instance;
+ }
+
+ public:
+ std::shared_ptr<TracingRecorderImpl> GetRecorderImpl() {
+ absl::ReaderMutexLock lock(&mutex_);
+ TracingRecorderImpl* tracing_recorder =
+ using_thread_local_state_ ? GetThreadLocalState().tracing_recorder
+ : global_tracing_recorder_;
+ return tracing_recorder ? tracing_recorder->shared_from_this()
+ : GetOrCreateDefaultRecorder();
+ }
+
+ void SetGlobalRecorderImpl(TracingRecorderImpl* impl) {
+ absl::WriterMutexLock lock(&mutex_);
+ FCP_CHECK(!using_thread_local_state_)
+ << "Global and thread local tracing recorders can't be used at the "
+ "same time";
+ FCP_CHECK(global_tracing_recorder_ == nullptr || impl == nullptr)
+ << "Only one global tracing recorder instance is supported";
+ FCP_LOG(INFO) << "Setting global";
+ global_tracing_recorder_ = impl;
+ }
+
+ void SetThreadLocalRecorderImpl(TracingRecorderImpl* impl) {
+ FCP_CHECK(impl != nullptr);
+ absl::WriterMutexLock lock(&mutex_);
+ auto& thread_local_state = GetThreadLocalState();
+ FCP_CHECK(global_tracing_recorder_ == nullptr)
+ << "Global and thread local tracing recorders can't be used at the "
+ "same time";
+ FCP_CHECK(thread_local_state.tracing_recorder == nullptr ||
+ thread_local_state.tracing_recorder == impl)
+ << "Only one tracing recorder instance per thread is supported";
+ thread_local_state.tracing_recorder = impl;
+ thread_local_state.ref_count++;
+ using_thread_local_state_ = true;
+ }
+
+ void ResetThreadLocalRecorderImpl(TracingRecorderImpl* impl) {
+ FCP_CHECK(impl != nullptr);
+ absl::WriterMutexLock lock(&mutex_);
+ auto& thread_local_state = GetThreadLocalState();
+ FCP_CHECK(thread_local_state.tracing_recorder == impl &&
+ thread_local_state.ref_count > 0)
+ << "Attempting to uninstall thread local tracing recorder that isn't "
+ "currently installed";
+ if (--thread_local_state.ref_count == 0) {
+ thread_local_state.tracing_recorder = nullptr;
+ }
+ }
+
+ void EnsureNotSet(TracingRecorderImpl* impl) {
+ absl::WriterMutexLock lock(&mutex_);
+ FCP_CHECK(global_tracing_recorder_ != impl)
+ << "Trace recorder must not be set as global at destruction time";
+ if (using_thread_local_state_) {
+ FCP_CHECK(GetThreadLocalState().tracing_recorder != impl)
+ << "Trace recorder must not be set as thread local at destruction "
+ "time";
+ }
+ }
+
+ static TracingState& GetInstance() {
+ static TracingState* instance = new TracingState();
+ return *instance;
+ }
+};
+
+std::shared_ptr<TracingRecorderImpl> TracingRecorderImpl::GetCurrent() {
+ return TracingState::GetInstance().GetRecorderImpl();
+}
+
+void TracingRecorderImpl::InstallAsGlobal() {
+ FCP_CHECK(!is_global_);
+ TracingState::GetInstance().SetGlobalRecorderImpl(this);
+ is_global_ = true;
+}
+
+void TracingRecorderImpl::UninstallAsGlobal() {
+ FCP_CHECK(is_global_);
+ TracingState::GetInstance().SetGlobalRecorderImpl(nullptr);
+ is_global_ = false;
+}
+
+void TracingRecorderImpl::InstallAsThreadLocal() {
+ TracingState::GetInstance().SetThreadLocalRecorderImpl(this);
+}
+
+void TracingRecorderImpl::UninstallAsThreadLocal() {
+ TracingState::GetInstance().ResetThreadLocalRecorderImpl(this);
+}
+
+TracingRecorderImpl::~TracingRecorderImpl() {
+ if (is_global_) {
+ UninstallAsGlobal();
+ }
+ TracingState::GetInstance().EnsureNotSet(this);
+}
+
+} // namespace fcp::tracing_internal
diff --git a/fcp/tracing/tracing_recorder_impl.h b/fcp/tracing/tracing_recorder_impl.h
new file mode 100644
index 0000000..4d7968c
--- /dev/null
+++ b/fcp/tracing/tracing_recorder_impl.h
@@ -0,0 +1,68 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_RECORDER_IMPL_H_
+#define FCP_TRACING_TRACING_RECORDER_IMPL_H_
+
+#include <memory>
+
+#include "fcp/tracing/tracing_span_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+class TracingRecorderImpl
+ : public std::enable_shared_from_this<TracingRecorderImpl> {
+ public:
+ TracingRecorderImpl() = default;
+ virtual ~TracingRecorderImpl();
+
+ // TracingRecorderImpl is neither copyable nor movable.
+ TracingRecorderImpl(const TracingRecorderImpl&) = delete;
+ TracingRecorderImpl& operator=(const TracingRecorderImpl&) = delete;
+
+ // Trace an event represented by the flatbuffer.
+ virtual void TraceImpl(TracingSpanId span_id,
+ flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) = 0;
+ virtual TracingSpanImpl* GetRootSpan() = 0;
+
+ // Creates child span.
+ virtual std::unique_ptr<TracingSpanImpl> CreateChildSpan(
+ TracingSpanId parent_span_id, flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) = 0;
+
+ // Installs this tracing recorder as global singleton instance.
+ void InstallAsGlobal();
+
+ // Uninstalls this tracing recorder as global instance. Automatically
+ // called upon destruction.
+ void UninstallAsGlobal();
+
+ // Installs this tracing recorder as thread local singleton instance.
+ void InstallAsThreadLocal();
+
+ // Uninstalls this tracing recorder as thread local singleton instance.
+ void UninstallAsThreadLocal();
+
+ // Gets the current thread local tracing recorder if set; otherwise gets
+ // the current global tracing recorder.
+ static std::shared_ptr<TracingRecorderImpl> GetCurrent();
+ bool is_global_ = false;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_RECORDER_IMPL_H_
diff --git a/fcp/tracing/tracing_schema_common.fbs b/fcp/tracing/tracing_schema_common.fbs
new file mode 100644
index 0000000..a9cb503
--- /dev/null
+++ b/fcp/tracing/tracing_schema_common.fbs
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Common tracing schema definitions:
+
+// The 4 letter tag for this event or operation. All tracing events and
+// operations must have a tag and the tag must be unique.
+attribute "tag";
+// Log messages for this table should have severity ERROR. Default is INFO.
+attribute "error";
+// Log messages for this table should have severity WARNING. Default is INFO.
+attribute "warning";
+// This table represents a tracing span within which child spans or events can
+// be created. The default is that the table represents an event that can be
+// logged but which can not have child spans or events.
+attribute "span";
+
diff --git a/fcp/tracing/tracing_schema_common_generated.h b/fcp/tracing/tracing_schema_common_generated.h
new file mode 100644
index 0000000..76d2eec
--- /dev/null
+++ b/fcp/tracing/tracing_schema_common_generated.h
@@ -0,0 +1,16 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TRACINGSCHEMACOMMON_H_
+#define FLATBUFFERS_GENERATED_TRACINGSCHEMACOMMON_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 2 &&
+ FLATBUFFERS_VERSION_MINOR == 0 &&
+ FLATBUFFERS_VERSION_REVISION == 7,
+ "Non-compatible flatbuffers version included");
+
+#endif // FLATBUFFERS_GENERATED_TRACINGSCHEMACOMMON_H_ \ No newline at end of file
diff --git a/fcp/tracing/tracing_severity.h b/fcp/tracing/tracing_severity.h
new file mode 100644
index 0000000..2ddf13b
--- /dev/null
+++ b/fcp/tracing/tracing_severity.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TRACING_TRACING_SEVERITY_H_
+#define FCP_TRACING_TRACING_SEVERITY_H_
+
+namespace fcp {
+enum class TracingSeverity { kInfo, kWarning, kError };
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_SEVERITY_H_
diff --git a/fcp/tracing/tracing_span.h b/fcp/tracing/tracing_span.h
new file mode 100644
index 0000000..943e2a7
--- /dev/null
+++ b/fcp/tracing/tracing_span.h
@@ -0,0 +1,250 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_SPAN_H_
+#define FCP_TRACING_TRACING_SPAN_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "fcp/base/error.h"
+#include "fcp/tracing/tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_span_impl.h"
+#include "fcp/tracing/tracing_span_ref.h"
+#include "fcp/tracing/tracing_traits.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+
+namespace tracing_internal {
+
+template <class FlatBufferTable>
+void AssertIsNotSpan() {
+ static_assert(TracingTraits<FlatBufferTable>::kIsSpan == false,
+ "Trace can only be called on an event table, not a span. To "
+ "convert the table to an event remove 'span' from the table "
+ "attributes.");
+}
+
+template <class FlatBufferTable>
+void AssertIsError() {
+ static_assert(
+ TracingTraits<FlatBufferTable>::kSeverity == TracingSeverity::kError,
+ "TraceError can only be called on table of severity Error. To "
+ "convert the table to error severity add 'error' to the "
+ "table attributes.");
+}
+
+// Generic method to build FlatBuff tables from a set of args.
+template <class FlatBufferTable, class... Arg>
+flatbuffers::DetachedBuffer BuildFlatBuffer(Arg&&... args) {
+ flatbuffers::FlatBufferBuilder fbb;
+ flatbuffers::Offset<FlatBufferTable> root =
+ TracingTraits<FlatBufferTable>::Create(std::forward<Arg>(args)..., &fbb);
+ constexpr TracingTag tag = TracingTraits<FlatBufferTable>::kTag;
+ constexpr char signature[5] = {tag.data[0], tag.data[1], tag.data[2],
+ tag.data[3], 0};
+ fbb.Finish(root, signature);
+ return fbb.Release();
+}
+
+/*
+ * Un-templatized version of the tracing span which stores the raw
+ * definition of the tracing span and returns a reference to it.
+ *
+ * This class should not/cannot be publicly initialized, instead it should be
+ * used as a base class for both UnscopedTracingSpan and TracingSpan.
+ *
+ * For Java based tracing spans, all native spans will be cast to this common
+ * structure to pass along the JNI boundary.
+ */
+class TracingSpanBase {
+ public:
+ // Returns a reference to this span.
+ inline TracingSpanRef Ref() const { return impl_->Ref(); }
+ virtual ~TracingSpanBase() = default;
+
+ protected:
+ explicit TracingSpanBase() = default;
+ // Set by the child classes after an instance of TracingSpanImpl is created
+ // from the flatbuf definitions.
+ std::unique_ptr<TracingSpanImpl> impl_;
+};
+
+} // namespace tracing_internal
+
+// Unscoped tracing span carries tracing context for some logical activity.
+// The primary purpose of this class is to keep the tracing context of a
+// long-running asynchronous activity, typically associated with the lifetime of
+// some long-lived object.
+//
+// This span is not scoped to a block/function and need not necessarily be
+// instantiated as a local variable on the stack. It is OK to initialise this
+// span on a heap and/or as a field member of a class. For scoped variant,
+// prefer using TracingSpan instead.
+//
+// To record a Trace within an UnscopedTracingSpan, the caller must explicitly
+// pass a reference to a span e.g Trace<FlatBuff>(span, args...).
+//
+// The class is NOT thread-safe, but OK to use from different threads if invoked
+// sequentially (i.e. with external synchronisation providing sequential
+// consistency memory ordering).
+//
+// Recommended usage is to create new child span for every sub-activity or
+// operation.
+template <class FlatBufferTable>
+class UnscopedTracingSpan : public tracing_internal::TracingSpanBase {
+ public:
+ // UnscopedTracingSpan is neither copyable nor movable.
+ UnscopedTracingSpan(const UnscopedTracingSpan&) = delete;
+ UnscopedTracingSpan& operator=(const UnscopedTracingSpan&) = delete;
+ // Public constructors allow creating new tracing span.
+ // A parent span reference can be optionally provided as a first argument.
+ // By default current span is used as a parent.
+ template <class... Arg>
+ explicit UnscopedTracingSpan(Arg&&... args);
+ template <class... Arg>
+ explicit UnscopedTracingSpan(TracingSpanRef parent, Arg&&... args);
+
+ private:
+ template <class... Arg>
+ void Create(TracingSpanRef parent, Arg&&... args);
+};
+
+template <class FlatBufferTable>
+template <class... Arg>
+UnscopedTracingSpan<FlatBufferTable>::UnscopedTracingSpan(Arg&&... args) {
+ Create(TracingSpanRef::Top(), std::forward<Arg>(args)...);
+}
+
+template <class FlatBufferTable>
+template <class... Arg>
+UnscopedTracingSpan<FlatBufferTable>::UnscopedTracingSpan(TracingSpanRef parent,
+ Arg&&... args) {
+ Create(parent, std::forward<Arg>(args)...);
+}
+
+template <class FlatBufferTable>
+template <class... Arg>
+void UnscopedTracingSpan<FlatBufferTable>::Create(TracingSpanRef parent,
+ Arg&&... args) {
+ static_assert(
+ TracingTraits<FlatBufferTable>::kIsSpan == true,
+ "UnscopedTracingSpan can only be created from a span table, not "
+ "an event. To convert the table to an span add 'span' "
+ "to the table attributes.");
+ flatbuffers::DetachedBuffer trace_data =
+ tracing_internal::BuildFlatBuffer<FlatBufferTable>(
+ std::forward<Arg>(args)...);
+ impl_ = parent.recorder()->CreateChildSpan(parent.span_id(),
+ std::move(trace_data),
+ TracingTraits<FlatBufferTable>());
+}
+
+// Tracing span, carrying abstract tracing context for some logical activity
+// performed by the application. Provides ability to log various events
+// which are automatically associated with such a context, thus allowing logging
+// essential information only.
+//
+// This class uses RAII-style mechanism of entering/existing the tracing span
+// for the duration of a scoped block or a function, this class must be
+// instantiated as a local variable on the stack, in a similar manner as
+// std::lock_guard.
+//
+// The class is NOT thread-safe, but OK to use from different threads if invoked
+// sequentially (i.e. with external synchronisation providing sequential
+// consistency memory ordering).
+//
+// Recommended usage is to create new child span for every sub-activity or
+// operation.
+//
+// For a more general variant that is not necessarily tied to a scoped block,
+// prefer using UnscopedTracingSpan directly.
+template <class FlatBufferTable>
+class TracingSpan final : public UnscopedTracingSpan<FlatBufferTable> {
+ public:
+ // Since this manipulates TLS/FLS in RAII fashion, this is intended to be
+ // used as a stack local (no heap alloc allowed):
+ void* operator new(std::size_t) = delete;
+ void* operator new[](std::size_t) = delete;
+ // Public constructors allow creating new tracing span.
+ template <class... Arg>
+ explicit TracingSpan(Arg&&... args)
+ : UnscopedTracingSpan<FlatBufferTable>(std::forward<Arg>(args)...) {
+ UnscopedTracingSpan<FlatBufferTable>::impl_->Push();
+ }
+ template <class... Arg>
+ explicit TracingSpan(TracingSpanRef parent, Arg&&... args)
+ : UnscopedTracingSpan<FlatBufferTable>(parent,
+ std::forward<Arg>(args)...) {
+ UnscopedTracingSpan<FlatBufferTable>::impl_->Push();
+ }
+
+ // Destructor closes the span
+ ~TracingSpan() override {
+ UnscopedTracingSpan<FlatBufferTable>::impl_->Pop();
+ }
+};
+
+// Writes a trace with the specified args under the topmost span located on
+// TLS. If no span exists on TLS, then root span is fetched from the global
+// recorder.
+template <class FlatBufferTable, class... Arg>
+void Trace(Arg&&... args) {
+ tracing_internal::AssertIsNotSpan<FlatBufferTable>();
+ flatbuffers::DetachedBuffer trace_data =
+ tracing_internal::BuildFlatBuffer<FlatBufferTable>(
+ std::forward<Arg>(args)...);
+ // Now discover what tracing span to log that with:
+ tracing_internal::TracingSpanImpl* top =
+ tracing_internal::TracingSpanImpl::Top();
+ if (top != nullptr) {
+ // Fast path: getting top tracing span from TLS/FCB:
+ top->TraceImpl(std::move(trace_data), TracingTraits<FlatBufferTable>());
+ } else {
+ // Slow path, finding root span from global recorder. This
+ // involves increasing its reference counter:
+ std::shared_ptr<tracing_internal::TracingRecorderImpl> recorder =
+ tracing_internal::TracingRecorderImpl::GetCurrent();
+ recorder->GetRootSpan()->TraceImpl(std::move(trace_data),
+ TracingTraits<FlatBufferTable>());
+ // now, since we done with using root span, it is safe to release shared_ptr
+ }
+}
+
+// Writes a trace under the specified span with the given args.
+template <class FlatBufferTable, class... Arg>
+void Trace(TracingSpanRef span, Arg&&... args) {
+ tracing_internal::AssertIsNotSpan<FlatBufferTable>();
+ flatbuffers::DetachedBuffer trace_data =
+ tracing_internal::BuildFlatBuffer<FlatBufferTable>(
+ std::forward<Arg>(args)...);
+ span.recorder()->TraceImpl(span.span_id(), std::move(trace_data),
+ TracingTraits<FlatBufferTable>());
+}
+
+// Writes an error trace with the specified args under the topmost span located
+// on TLS. If no span exists on TLS, then root span is fetched from the global
+// recorder.
+template <class FlatBufferTable, class... Arg>
+ABSL_MUST_USE_RESULT Error TraceError(Arg&&... args) {
+ tracing_internal::AssertIsError<FlatBufferTable>();
+ Trace<FlatBufferTable>(args...);
+ return Error(Error::ConstructorAccess{});
+}
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_SPAN_H_
diff --git a/fcp/tracing/tracing_span_id.cc b/fcp/tracing/tracing_span_id.cc
new file mode 100644
index 0000000..59d08d1
--- /dev/null
+++ b/fcp/tracing/tracing_span_id.cc
@@ -0,0 +1,26 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_span_id.h"
+
+namespace fcp {
+
+std::atomic<std::int64_t> TracingSpanId::id_source = 1;
+
+TracingSpanId TracingSpanId::NextUniqueId() {
+ std::int64_t new_id = id_source.fetch_add(1, std::memory_order_seq_cst);
+ return TracingSpanId(new_id);
+}
+
+} // namespace fcp
diff --git a/fcp/tracing/tracing_span_id.h b/fcp/tracing/tracing_span_id.h
new file mode 100644
index 0000000..b737ccc
--- /dev/null
+++ b/fcp/tracing/tracing_span_id.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FCP_TRACING_TRACING_SPAN_ID_H_
+#define FCP_TRACING_TRACING_SPAN_ID_H_
+
+#include <atomic>
+#include <cstdint>
+#include <ostream>
+
+namespace fcp {
+
+// Uniquely identifies tracing span within a process.
+struct TracingSpanId {
+ std::int64_t value;
+
+ // Generates next unique id
+ static TracingSpanId NextUniqueId();
+ explicit constexpr TracingSpanId(int id) : value(id) {}
+
+ private:
+ static std::atomic<std::int64_t> id_source;
+};
+
+inline bool operator==(const TracingSpanId& a, const TracingSpanId& b) {
+ return a.value == b.value;
+}
+
+inline bool operator!=(const TracingSpanId& a, const TracingSpanId& b) {
+ return a.value != b.value;
+}
+
+// Overload comparison operators to make it possible to sort by order of ID
+// generation.
+inline bool operator<(const TracingSpanId& a, const TracingSpanId& b) {
+ return a.value < b.value;
+}
+
+inline std::ostream& operator<<(std::ostream& s, const TracingSpanId& id) {
+ return s << id.value;
+}
+
+template <typename H>
+H AbslHashValue(H h, const TracingSpanId& id) {
+ return H::combine(std::move(h), id.value);
+}
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_SPAN_ID_H_
diff --git a/fcp/tracing/tracing_span_impl.cc b/fcp/tracing/tracing_span_impl.cc
new file mode 100644
index 0000000..6c1b20d
--- /dev/null
+++ b/fcp/tracing/tracing_span_impl.cc
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_span_impl.h"
+
+#include "fcp/tracing/tracing_recorder_impl.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+thread_local TracingSpanImpl* TracingSpanImpl::top_tracing_span_ = nullptr;
+
+TracingSpanImpl* TracingSpanImpl::Top() { return top_tracing_span_; }
+
+void TracingSpanImpl::Pop() {
+ FCP_CHECK(top_tracing_span_ == this);
+ top_tracing_span_ = prev_;
+}
+
+void TracingSpanImpl::Push() {
+ FCP_CHECK(prev_ == nullptr);
+ prev_ = top_tracing_span_;
+ top_tracing_span_ = this;
+}
+
+} // namespace tracing_internal
+} // namespace fcp
diff --git a/fcp/tracing/tracing_span_impl.h b/fcp/tracing/tracing_span_impl.h
new file mode 100644
index 0000000..2a50148
--- /dev/null
+++ b/fcp/tracing/tracing_span_impl.h
@@ -0,0 +1,68 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_SPAN_IMPL_H_
+#define FCP_TRACING_TRACING_SPAN_IMPL_H_
+
+#include "fcp/base/monitoring.h"
+#include "fcp/tracing/tracing_span_ref.h"
+#include "fcp/tracing/tracing_traits.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+namespace tracing_internal {
+
+class TracingSpanImpl {
+ public:
+ // TracingSpanImpl is neither copyable nor movable:
+ TracingSpanImpl(const TracingSpanImpl&) = delete;
+ TracingSpanImpl& operator=(const TracingSpanImpl&) = delete;
+ // Destructor closes the span
+ virtual ~TracingSpanImpl() = default;
+ // Internal logging implementation, to be used by tracing recorder only.
+ virtual void TraceImpl(flatbuffers::DetachedBuffer&& buf,
+ const TracingTraitsBase& traits) = 0;
+
+ // Pushes current tracing span to be the top one on the current thread/fiber:
+ void Push();
+ // Pops current tracing span
+ void Pop();
+ // Returns top tracing span for the current thread/fiber
+ static TracingSpanImpl* Top();
+
+ // Returns reference to this tracing span
+ virtual TracingSpanRef Ref() = 0;
+
+ protected:
+ // TracingSpanImpl can't be directly constructed, use CreateChild():
+ TracingSpanImpl() = default;
+
+ private:
+ // Optional pointer to the previous tracing span which was a Top() one before
+ // this span was pushed.
+ // This is used so we can restore the top one with Pop().
+ // NOTE: while this is frequently points to the parent span, it doesn't have
+ // to be the parent span, since a span might be constructed with arbitrary
+ // parent, which doesn't have to be the current Top() one. Example: when a new
+ // fiber is started the parent is on a different stack.
+ TracingSpanImpl* prev_ = nullptr;
+
+ // TODO(team): this assumes 1:1 fiber-thread relationship, use FCB:
+ thread_local static tracing_internal::TracingSpanImpl* top_tracing_span_;
+};
+
+} // namespace tracing_internal
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_SPAN_IMPL_H_
diff --git a/fcp/tracing/tracing_span_ref.cc b/fcp/tracing/tracing_span_ref.cc
new file mode 100644
index 0000000..b93a658
--- /dev/null
+++ b/fcp/tracing/tracing_span_ref.cc
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_span_ref.h"
+
+#include "fcp/tracing/tracing_recorder_impl.h"
+#include "fcp/tracing/tracing_span_impl.h"
+
+namespace fcp {
+
+TracingSpanRef fcp::TracingSpanRef::Top() {
+ tracing_internal::TracingSpanImpl* top =
+ tracing_internal::TracingSpanImpl::Top();
+ return top ? top->Ref()
+ : tracing_internal::TracingRecorderImpl::GetCurrent()
+ ->GetRootSpan()
+ ->Ref();
+}
+} // namespace fcp
diff --git a/fcp/tracing/tracing_span_ref.h b/fcp/tracing/tracing_span_ref.h
new file mode 100644
index 0000000..4670886
--- /dev/null
+++ b/fcp/tracing/tracing_span_ref.h
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_SPAN_REF_H_
+#define FCP_TRACING_TRACING_SPAN_REF_H_
+
+#include <memory>
+#include <utility>
+
+#include "fcp/tracing/tracing_span_id.h"
+
+namespace fcp {
+
+namespace tracing_internal {
+class TracingRecorderImpl;
+}
+
+// Reference to a tracing span.
+class TracingSpanRef {
+ // Reference to the tracing recorder this reference was issued by:
+ std::shared_ptr<fcp::tracing_internal::TracingRecorderImpl> recorder_;
+ // Identifier of the span
+ TracingSpanId span_id_;
+
+ public:
+ TracingSpanRef(
+ std::shared_ptr<fcp::tracing_internal::TracingRecorderImpl> provider,
+ TracingSpanId span_id)
+ : recorder_(std::move(provider)), span_id_(span_id) {}
+
+ std::shared_ptr<tracing_internal::TracingRecorderImpl> recorder() {
+ return recorder_;
+ }
+
+ TracingSpanId span_id() const { return span_id_; }
+
+ // Returns reference to the top tracing span on the current
+ // thread/fiber. If there's no tracing span established, a
+ // reference to the root span of global tracing recorder is returned.
+ static TracingSpanRef Top();
+};
+
+} // namespace fcp
+#endif // FCP_TRACING_TRACING_SPAN_REF_H_
diff --git a/fcp/tracing/tracing_tag.h b/fcp/tracing/tracing_tag.h
new file mode 100644
index 0000000..19fbd4a
--- /dev/null
+++ b/fcp/tracing/tracing_tag.h
@@ -0,0 +1,70 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_TAG_H_
+#define FCP_TRACING_TRACING_TAG_H_
+
+#include <stdint.h>
+
+#include <iosfwd>
+#include <ostream>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+
+// 4-character tag uniquely identifying FlatBuffers tables used in tracing.
+union TracingTag {
+ char data[4];
+ uint32_t value;
+ // Constructing from 4-char string literals. Supplied char array
+ // has a length of 5, to accommodate terminating 0, which is NOT stored in the
+ // tag.
+ explicit constexpr TracingTag(char const (&fourCharsLiteral)[5])
+ : data{fourCharsLiteral[0], fourCharsLiteral[1], fourCharsLiteral[2],
+ fourCharsLiteral[3]} {}
+
+ ABSL_MUST_USE_RESULT std::string str() const {
+ return std::string{data[0], data[1], data[2], data[3]};
+ }
+ TracingTag() = delete;
+
+ static inline const TracingTag* FromFlatbuf(
+ const flatbuffers::DetachedBuffer& buf) {
+ return reinterpret_cast<const TracingTag*>(&buf.data()[4]);
+ }
+};
+
+inline std::ostream& operator<<(std::ostream& out, const TracingTag& tag) {
+ out << tag.str();
+ return out;
+}
+
+inline bool operator==(TracingTag const& a, TracingTag const& b) {
+ return a.value == b.value;
+}
+
+inline bool operator!=(TracingTag const& a, TracingTag const& b) {
+ return a.value != b.value;
+}
+
+template <typename H>
+H AbslHashValue(H h, const fcp::TracingTag& t) {
+ return H::combine(std::move(h), t.value);
+}
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_TAG_H_
diff --git a/fcp/tracing/tracing_tag_test.cc b/fcp/tracing/tracing_tag_test.cc
new file mode 100644
index 0000000..2167612
--- /dev/null
+++ b/fcp/tracing/tracing_tag_test.cc
@@ -0,0 +1,42 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_tag.h"
+
+#include "gtest/gtest.h"
+#include "absl/hash/hash_testing.h"
+
+namespace fcp {
+namespace {
+
+TEST(TracingTag, Construction) { EXPECT_EQ(TracingTag("ABCD").str(), "ABCD"); }
+
+TEST(TracingTag, SupportsAbslHash) {
+ EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
+ {TracingTag("AAAA"), TracingTag("AAAA"), TracingTag("ABCD"),
+ TracingTag("DCBA")}));
+}
+
+TEST(TracingTag, Comparison) {
+ TracingTag a1("AAAA");
+ TracingTag a2("AAAA");
+ TracingTag b("BBBB");
+ EXPECT_TRUE(a1 == a2);
+ EXPECT_FALSE(a1 != a2);
+ EXPECT_FALSE(a1 == b);
+ EXPECT_TRUE(a1 != b);
+}
+
+} // namespace
+} // namespace fcp
diff --git a/fcp/tracing/tracing_traits.cc b/fcp/tracing/tracing_traits.cc
new file mode 100644
index 0000000..6e58d7d
--- /dev/null
+++ b/fcp/tracing/tracing_traits.cc
@@ -0,0 +1,63 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fcp/tracing/tracing_traits.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/base/attributes.h"
+#include "absl/container/flat_hash_map.h"
+
+namespace fcp {
+
+class UnknownTracingTraits : public TracingTraitsBase {
+ public:
+ ABSL_MUST_USE_RESULT const char* Name() const override { return "?Unknown?"; }
+ ABSL_MUST_USE_RESULT TracingSeverity Severity() const override {
+ return TracingSeverity::kWarning;
+ };
+ ABSL_MUST_USE_RESULT std::string TextFormat(
+ const flatbuffers::DetachedBuffer& buf) const override {
+ return "";
+ };
+ ABSL_MUST_USE_RESULT std::string JsonStringFormat(
+ const uint8_t* flatbuf_bytes) const override {
+ return "";
+ };
+};
+
+absl::flat_hash_map<TracingTag, std::unique_ptr<TracingTraitsBase>>&
+GetTracingTraitsRegistry() {
+ static auto tracing_traits_registry =
+ new absl::flat_hash_map<TracingTag, std::unique_ptr<TracingTraitsBase>>();
+ return *tracing_traits_registry;
+}
+
+TracingTraitsBase const* TracingTraitsBase::Lookup(TracingTag tag) {
+ auto it = GetTracingTraitsRegistry().find(tag);
+ if (it == GetTracingTraitsRegistry().end()) {
+ static auto unknown_tracing_traits = new UnknownTracingTraits();
+ return unknown_tracing_traits;
+ }
+ return it->second.get();
+}
+
+void TracingTraitsBase::Register(TracingTag tag,
+ std::unique_ptr<TracingTraitsBase> traits) {
+ GetTracingTraitsRegistry()[tag] = std::move(traits);
+}
+
+} // namespace fcp
diff --git a/fcp/tracing/tracing_traits.h b/fcp/tracing/tracing_traits.h
new file mode 100644
index 0000000..d3e7173
--- /dev/null
+++ b/fcp/tracing/tracing_traits.h
@@ -0,0 +1,91 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FCP_TRACING_TRACING_TRAITS_H_
+#define FCP_TRACING_TRACING_TRAITS_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "fcp/tracing/tracing_severity.h"
+#include "fcp/tracing/tracing_tag.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace fcp {
+
+// Base class for tracing traits, allows working with generated tracing traits
+// (see below) in polymorphic way at runtime without knowledge of a type.
+class TracingTraitsBase {
+ public:
+ // Returns printable name of the FlatBuffers table
+ ABSL_MUST_USE_RESULT virtual const char* Name() const = 0;
+ // Returns printable severity of the event represented by this table.
+ ABSL_MUST_USE_RESULT virtual TracingSeverity Severity() const = 0;
+ // Formats a serialized flatbuffer into human readable format.
+ ABSL_MUST_USE_RESULT virtual std::string TextFormat(
+ const flatbuffers::DetachedBuffer& buf) const = 0;
+ // Formats a serialized flatbuffer into a Json string.
+ ABSL_MUST_USE_RESULT virtual std::string JsonStringFormat(
+ const uint8_t* flatbuf_bytes) const = 0;
+
+ // Allows to lookup FlatBuffers table traits by its 4-character tag.
+ // For unknown tag returns a stub.
+ static TracingTraitsBase const* Lookup(TracingTag tag);
+
+ // Registers runtime trait information (to be used for generated code only).
+ // Multiple compilation units are allowed to register the same traits,
+ // last registration wins.
+ static void Register(TracingTag tag,
+ std::unique_ptr<TracingTraitsBase> trait);
+
+ static std::string SeverityString(const TracingSeverity tracing_severity) {
+ switch (tracing_severity) {
+ case TracingSeverity::kInfo:
+ return "INFO";
+ case TracingSeverity::kWarning:
+ return "WARNING";
+ case TracingSeverity::kError:
+ return "ERROR";
+ }
+ }
+
+ virtual ~TracingTraitsBase() = default;
+};
+
+// Specializations of TracingTraits used by TracingSpan are typically included
+// into user code together with the definitions of concrete FlatBufferTable via
+// "tracing_schema.h" header. The latter is auto-generated by
+// tracing_traits_generator tool from "tracing_schema.fbs".
+template <class FlatBufferTable>
+class TracingTraits;
+
+namespace internal {
+
+// Helper class to for automatic registration of traits for runtime use.
+// This is intended to be used from autogenerated code only.
+template<class FlatBufferTable>
+struct TracingTraitsRegistrar {
+ TracingTraitsRegistrar() {
+ TracingTraitsBase::Register(
+ TracingTraits<FlatBufferTable>::kTag,
+ std::make_unique<TracingTraits<FlatBufferTable>>());
+ }
+};
+
+} // namespace internal
+
+} // namespace fcp
+
+#endif // FCP_TRACING_TRACING_TRAITS_H_
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e43bd4f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+# Requirements for the Federated Compute Python development environment.
+#
+# For compatibility with TensorFlow and TensorFlow Federated, most Python
+# dependencies should be managed by pip, not Bazel.
+#
+# * For packages that have a stable release, we use a version that is
+# compatible with that release (e.g. `~=x.y`). See
+# https://peps.python.org/pep-0440/#compatible-release for more information.
+# * For packages that do not have a stable release, we use a version that
+# matches a release that has been tested (e.g. `==x.y.z`). See
+# https://peps.python.org/pep-0440/#version-matching for more information.
+#
+# Note: There is bug in `pip` when multiple packages use the compatible release
+# operator `~=` to specify a version and one of those versions ends in `0`. See
+# https://github.com/pypa/pip/issues/9613 for more information. In this case,
+# use the equivalent clause `>=x.0,==x.*` instead of `~=x.0`.
+#
+# This assumes that the packages follow Semantic Versioning, see
+# https://semver.org/. If a package follows a different versioning scheme or
+# requires unique handling, we use a different version specifier and comment the
+# versioning scheme or reasoning.
+
+absl-py>=1.0,==1.*
+protobuf~=3.20
+# The TensorFlow version should match what's specified in the WORKSPACE file for
+# C++ targets.
+tensorflow==2.12.0
+tensorflow-federated~=0.53
diff --git a/third_party/BUILD b/third_party/BUILD
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/third_party/BUILD
diff --git a/third_party/curl.BUILD.bzl b/third_party/curl.BUILD.bzl
new file mode 100644
index 0000000..3ee9700
--- /dev/null
+++ b/third_party/curl.BUILD.bzl
@@ -0,0 +1,618 @@
+"""
+A custom build file for curl
+"""
+
+licenses(["notice"])
+
+exports_files(["COPYING"])
+
+cc_library(
+ name = "curl",
+ srcs = [
+ "include/curl_config.h",
+ "lib/altsvc.c",
+ "lib/altsvc.h",
+ "lib/amigaos.c",
+ "lib/amigaos.h",
+ "lib/arpa_telnet.h",
+ "lib/asyn-ares.c",
+ "lib/asyn.h",
+ "lib/base64.c",
+ "lib/bufref.c",
+ "lib/bufref.h",
+ "lib/c-hyper.c",
+ "lib/c-hyper.h",
+ "lib/config-amigaos.h",
+ "lib/config-dos.h",
+ "lib/config-mac.h",
+ "lib/config-os400.h",
+ "lib/config-plan9.h",
+ "lib/config-riscos.h",
+ "lib/config-win32.h",
+ "lib/config-win32ce.h",
+ "lib/conncache.c",
+ "lib/conncache.h",
+ "lib/connect.c",
+ "lib/connect.h",
+ "lib/content_encoding.c",
+ "lib/content_encoding.h",
+ "lib/cookie.c",
+ "lib/cookie.h",
+ "lib/curl_addrinfo.c",
+ "lib/curl_addrinfo.h",
+ "lib/curl_base64.h",
+ "lib/curl_ctype.c",
+ "lib/curl_ctype.h",
+ "lib/curl_des.c",
+ "lib/curl_des.h",
+ "lib/curl_endian.c",
+ "lib/curl_endian.h",
+ "lib/curl_fnmatch.c",
+ "lib/curl_fnmatch.h",
+ "lib/curl_get_line.c",
+ "lib/curl_get_line.h",
+ "lib/curl_gethostname.c",
+ "lib/curl_gethostname.h",
+ "lib/curl_gssapi.c",
+ "lib/curl_gssapi.h",
+ "lib/curl_hmac.h",
+ "lib/curl_krb5.h",
+ "lib/curl_ldap.h",
+ "lib/curl_md4.h",
+ "lib/curl_md5.h",
+ "lib/curl_memory.h",
+ "lib/curl_memrchr.c",
+ "lib/curl_memrchr.h",
+ "lib/curl_multibyte.c",
+ "lib/curl_multibyte.h",
+ "lib/curl_ntlm_core.c",
+ "lib/curl_ntlm_core.h",
+ "lib/curl_ntlm_wb.c",
+ "lib/curl_ntlm_wb.h",
+ "lib/curl_path.c",
+ "lib/curl_path.h",
+ "lib/curl_printf.h",
+ "lib/curl_range.c",
+ "lib/curl_range.h",
+ "lib/curl_rtmp.c",
+ "lib/curl_rtmp.h",
+ "lib/curl_sasl.c",
+ "lib/curl_sasl.h",
+ "lib/curl_setup.h",
+ "lib/curl_setup_once.h",
+ "lib/curl_sha256.h",
+ "lib/curl_sspi.c",
+ "lib/curl_sspi.h",
+ "lib/curl_threads.c",
+ "lib/curl_threads.h",
+ "lib/curlx.h",
+ "lib/dict.c",
+ "lib/dict.h",
+ "lib/doh.c",
+ "lib/doh.h",
+ "lib/dotdot.c",
+ "lib/dotdot.h",
+ "lib/dynbuf.c",
+ "lib/dynbuf.h",
+ "lib/easy.c",
+ "lib/easygetopt.c",
+ "lib/easyif.h",
+ "lib/easyoptions.c",
+ "lib/easyoptions.h",
+ "lib/easy_lock.h",
+ "lib/escape.c",
+ "lib/escape.h",
+ "lib/file.c",
+ "lib/file.h",
+ "lib/fileinfo.c",
+ "lib/fileinfo.h",
+ "lib/fopen.c",
+ "lib/fopen.h",
+ "lib/formdata.c",
+ "lib/formdata.h",
+ "lib/ftp.c",
+ "lib/ftp.h",
+ "lib/ftplistparser.c",
+ "lib/ftplistparser.h",
+ "lib/getenv.c",
+ "lib/getinfo.c",
+ "lib/getinfo.h",
+ "lib/gopher.c",
+ "lib/gopher.h",
+ "lib/h2h3.c",
+ "lib/h2h3.h",
+ "lib/hash.c",
+ "lib/hash.h",
+ "lib/headers.c",
+ "lib/headers.h",
+ "lib/hmac.c",
+ "lib/hostasyn.c",
+ "lib/hostip.c",
+ "lib/hostip.h",
+ "lib/hostip4.c",
+ "lib/hostip6.c",
+ "lib/hostsyn.c",
+ "lib/hsts.c",
+ "lib/hsts.h",
+ "lib/http.c",
+ "lib/http.h",
+ "lib/http2.c",
+ "lib/http2.h",
+ "lib/http_chunks.c",
+ "lib/http_chunks.h",
+ "lib/http_digest.c",
+ "lib/http_digest.h",
+ "lib/http_negotiate.c",
+ "lib/http_negotiate.h",
+ "lib/http_ntlm.c",
+ "lib/http_ntlm.h",
+ "lib/http_proxy.c",
+ "lib/http_proxy.h",
+ "lib/http_aws_sigv4.c",
+ "lib/http_aws_sigv4.h",
+ "lib/idn_win32.c",
+ "lib/if2ip.c",
+ "lib/if2ip.h",
+ "lib/imap.c",
+ "lib/imap.h",
+ "lib/inet_ntop.h",
+ "lib/inet_pton.c",
+ "lib/inet_pton.h",
+ "lib/krb5.c",
+ "lib/ldap.c",
+ "lib/llist.c",
+ "lib/llist.h",
+ "lib/md4.c",
+ "lib/md5.c",
+ "lib/memdebug.c",
+ "lib/memdebug.h",
+ "lib/mime.c",
+ "lib/mime.h",
+ "lib/mprintf.c",
+ "lib/mqtt.c",
+ "lib/mqtt.h",
+ "lib/multi.c",
+ "lib/multihandle.h",
+ "lib/multiif.h",
+ "lib/netrc.c",
+ "lib/netrc.h",
+ "lib/nonblock.c",
+ "lib/nonblock.h",
+ "lib/openldap.c",
+ "lib/parsedate.c",
+ "lib/parsedate.h",
+ "lib/pingpong.c",
+ "lib/pingpong.h",
+ "lib/pop3.c",
+ "lib/pop3.h",
+ "lib/progress.c",
+ "lib/progress.h",
+ "lib/psl.c",
+ "lib/psl.h",
+ "lib/quic.h",
+ "lib/rand.c",
+ "lib/rand.h",
+ "lib/rename.c",
+ "lib/rename.h",
+ "lib/rtsp.c",
+ "lib/rtsp.h",
+ "lib/select.c",
+ "lib/select.h",
+ "lib/sendf.c",
+ "lib/sendf.h",
+ "lib/setopt.c",
+ "lib/setopt.h",
+ "lib/setup-vms.h",
+ "lib/sha256.c",
+ "lib/share.c",
+ "lib/share.h",
+ "lib/sigpipe.h",
+ "lib/slist.c",
+ "lib/slist.h",
+ "lib/smb.c",
+ "lib/smb.h",
+ "lib/smtp.c",
+ "lib/smtp.h",
+ "lib/sockaddr.h",
+ "lib/socketpair.c",
+ "lib/socketpair.h",
+ "lib/socks.c",
+ "lib/socks.h",
+ "lib/socks_gssapi.c",
+ "lib/socks_sspi.c",
+ "lib/speedcheck.c",
+ "lib/speedcheck.h",
+ "lib/splay.c",
+ "lib/splay.h",
+ "lib/strcase.c",
+ "lib/strcase.h",
+ "lib/strdup.c",
+ "lib/strdup.h",
+ "lib/strerror.c",
+ "lib/strerror.h",
+ "lib/strtok.c",
+ "lib/strtok.h",
+ "lib/strtoofft.c",
+ "lib/strtoofft.h",
+ "lib/system_win32.h",
+ "lib/telnet.c",
+ "lib/telnet.h",
+ "lib/tftp.c",
+ "lib/tftp.h",
+ "lib/timediff.c",
+ "lib/timediff.h",
+ "lib/timeval.c",
+ "lib/timeval.h",
+ "lib/transfer.c",
+ "lib/transfer.h",
+ "lib/url.c",
+ "lib/url.h",
+ "lib/urldata.h",
+ "lib/urlapi-int.h",
+ "lib/urlapi.c",
+ "lib/version.c",
+ "lib/version_win32.c",
+ "lib/version_win32.h",
+ "lib/warnless.c",
+ "lib/warnless.h",
+ "lib/wildcard.c",
+ "lib/wildcard.h",
+ "lib/vauth/cleartext.c",
+ "lib/vauth/cram.c",
+ "lib/vauth/digest.c",
+ "lib/vauth/digest.h",
+ "lib/vauth/digest_sspi.c",
+ "lib/vauth/krb5_gssapi.c",
+ "lib/vauth/krb5_sspi.c",
+ "lib/vauth/ntlm.c",
+ "lib/vauth/ntlm.h",
+ "lib/vauth/ntlm_sspi.c",
+ "lib/vauth/oauth2.c",
+ "lib/vauth/spnego_sspi.c",
+ "lib/vauth/vauth.c",
+ "lib/vauth/vauth.h",
+ "lib/vquic/msh3.c",
+ "lib/vquic/msh3.h",
+ "lib/vquic/ngtcp2.c",
+ "lib/vquic/ngtcp2.h",
+ "lib/vquic/quiche.c",
+ "lib/vquic/quiche.h",
+ "lib/vquic/vquic.c",
+ "lib/vquic/vquic.h",
+ "lib/vssh/libssh.c",
+ "lib/vssh/libssh2.c",
+ "lib/vssh/ssh.h",
+ "lib/vssh/wolfssh.c",
+ "lib/vtls/bearssl.c",
+ "lib/vtls/bearssl.h",
+ "lib/vtls/gskit.c",
+ "lib/vtls/gskit.h",
+ "lib/vtls/gtls.c",
+ "lib/vtls/gtls.h",
+ "lib/vtls/hostcheck.c",
+ "lib/vtls/hostcheck.h",
+ "lib/vtls/keylog.c",
+ "lib/vtls/keylog.h",
+ "lib/vtls/mbedtls.c",
+ "lib/vtls/mbedtls.h",
+ "lib/vtls/mbedtls_threadlock.c",
+ "lib/vtls/mbedtls_threadlock.h",
+ "lib/vtls/nss.c",
+ "lib/vtls/nssg.h",
+ "lib/vtls/openssl.c",
+ "lib/vtls/openssl.h",
+ "lib/vtls/rustls.c",
+ "lib/vtls/rustls.h",
+ "lib/vtls/schannel.c",
+ "lib/vtls/schannel.h",
+ "lib/vtls/schannel_verify.c",
+ "lib/vtls/sectransp.h",
+ "lib/vtls/vtls.c",
+ "lib/vtls/vtls.h",
+ "lib/vtls/wolfssl.c",
+ "lib/vtls/wolfssl.h",
+ "lib/vtls/x509asn1.c",
+ "lib/vtls/x509asn1.h",
+ ],
+ hdrs = [
+ "include/curl/curl.h",
+ "include/curl/curlver.h",
+ "include/curl/easy.h",
+ "include/curl/header.h",
+ "include/curl/mprintf.h",
+ "include/curl/multi.h",
+ "include/curl/options.h",
+ "include/curl/stdcheaders.h",
+ "include/curl/system.h",
+ "include/curl/typecheck-gcc.h",
+ "include/curl/urlapi.h",
+ ],
+ copts = [
+ "-Iexternal/curl/lib",
+ "-D_GNU_SOURCE",
+ "-DBUILDING_LIBCURL",
+ "-DHAVE_CONFIG_H",
+ "-DCURL_DISABLE_FTP",
+ "-DCURL_DISABLE_NTLM",
+ "-DHAVE_LIBZ",
+ "-DHAVE_ZLIB_H",
+ "-Wno-string-plus-int",
+ "-DCURL_MAX_WRITE_SIZE=65536",
+ ],
+ defines = ["CURL_STATICLIB"],
+ includes = ["include", "lib"],
+ linkopts = [
+ "-lrt",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@zlib",
+ "@boringssl//:ssl",
+ ],
+)
+
+genrule(
+ name = "configure",
+ outs = ["include/curl_config.h"],
+ cmd = "\n".join([
+ "cat <<'EOF' >$@",
+ "#ifndef EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
+ "#define EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
+ "",
+ "#if !defined(_WIN32) && !defined(__APPLE__)",
+ "# include <openssl/opensslv.h>",
+ "# if defined(OPENSSL_IS_BORINGSSL)",
+ "# define HAVE_BORINGSSL 1",
+ "# endif",
+ "#endif",
+ "",
+ "#if defined(_WIN32)",
+ "# include \"lib/config-win32.h\"",
+ "# define BUILDING_LIBCURL 1",
+ "# define CURL_DISABLE_CRYPTO_AUTH 1",
+ "# define CURL_DISABLE_DICT 1",
+ "# define CURL_DISABLE_FILE 1",
+ "# define CURL_DISABLE_GOPHER 1",
+ "# define CURL_DISABLE_IMAP 1",
+ "# define CURL_DISABLE_LDAP 1",
+ "# define CURL_DISABLE_LDAPS 1",
+ "# define CURL_DISABLE_POP3 1",
+ "# define CURL_PULL_WS2TCPIP_H 1",
+ "# define CURL_DISABLE_SMTP 1",
+ "# define CURL_DISABLE_TELNET 1",
+ "# define CURL_DISABLE_TFTP 1",
+ "# define CURL_PULL_WS2TCPIP_H 1",
+ "# define USE_WINDOWS_SSPI 1",
+ "# define USE_WIN32_IDN 1",
+ "# define USE_SCHANNEL 1",
+ "# define WANT_IDN_PROTOTYPES 1",
+ "#elif defined(__APPLE__)",
+ "# define HAVE_FSETXATTR_6 1",
+ "# define HAVE_SETMODE 1",
+ "# define HAVE_SYS_FILIO_H 1",
+ "# define HAVE_SYS_SOCKIO_H 1",
+ "# define OS \"x86_64-apple-darwin15.5.0\"",
+ "# define USE_SECTRANSP 1",
+ "#else",
+ "# define CURL_CA_BUNDLE \"/etc/ssl/certs/ca-certificates.crt\"",
+ "# define GETSERVBYPORT_R_ARGS 6",
+ "# define GETSERVBYPORT_R_BUFSIZE 4096",
+ "# define HAVE_BORINGSSL 1",
+ "# define HAVE_CLOCK_GETTIME_MONOTONIC 1",
+ "# define HAVE_CRYPTO_CLEANUP_ALL_EX_DATA 1",
+ "# define HAVE_FSETXATTR_5 1",
+ "# define HAVE_GETHOSTBYADDR_R 1",
+ "# define HAVE_GETHOSTBYADDR_R_8 1",
+ "# define HAVE_GETHOSTBYNAME_R 1",
+ "# define HAVE_GETHOSTBYNAME_R_6 1",
+ "# define HAVE_GETSERVBYPORT_R 1",
+ "# define HAVE_LIBSSL 1",
+ "# define HAVE_MALLOC_H 1",
+ "# define HAVE_MSG_NOSIGNAL 1",
+ "# define HAVE_OPENSSL_CRYPTO_H 1",
+ "# define HAVE_OPENSSL_ERR_H 1",
+ "# define HAVE_OPENSSL_PEM_H 1",
+ "# define HAVE_OPENSSL_PKCS12_H 1",
+ "# define HAVE_OPENSSL_RSA_H 1",
+ "# define HAVE_OPENSSL_SSL_H 1",
+ "# define HAVE_OPENSSL_X509_H 1",
+ "# define HAVE_RAND_EGD 1",
+ "# define HAVE_RAND_STATUS 1",
+ "# define HAVE_SSL_GET_SHUTDOWN 1",
+ "# define HAVE_TERMIOS_H 1",
+ "# define OS \"x86_64-pc-linux-gnu\"",
+ "# define RANDOM_FILE \"/dev/urandom\"",
+ "# define USE_OPENSSL 1",
+ "#endif",
+ "",
+ "#if !defined(_WIN32)",
+ "# define CURL_DISABLE_DICT 1",
+ "# define CURL_DISABLE_FILE 1",
+ "# define CURL_DISABLE_GOPHER 1",
+ "# define CURL_DISABLE_IMAP 1",
+ "# define CURL_DISABLE_LDAP 1",
+ "# define CURL_DISABLE_LDAPS 1",
+ "# define CURL_DISABLE_POP3 1",
+ "# define CURL_DISABLE_SMTP 1",
+ "# define CURL_DISABLE_TELNET 1",
+ "# define CURL_DISABLE_TFTP 1",
+ "# define CURL_EXTERN_SYMBOL __attribute__ ((__visibility__ (\"default\")))",
+ "# define ENABLE_IPV6 1",
+ "# define GETHOSTNAME_TYPE_ARG2 size_t",
+ "# define GETNAMEINFO_QUAL_ARG1 const",
+ "# define GETNAMEINFO_TYPE_ARG1 struct sockaddr *",
+ "# define GETNAMEINFO_TYPE_ARG2 socklen_t",
+ "# define GETNAMEINFO_TYPE_ARG46 socklen_t",
+ "# define GETNAMEINFO_TYPE_ARG7 int",
+ "# define HAVE_ALARM 1",
+ "# define HAVE_ALLOCA_H 1",
+ "# define HAVE_ARPA_INET_H 1",
+ "# define HAVE_ARPA_TFTP_H 1",
+ "# define HAVE_ASSERT_H 1",
+ "# define HAVE_BASENAME 1",
+ "# define HAVE_BOOL_T 1",
+ "# define HAVE_CONNECT 1",
+ "# define HAVE_DLFCN_H 1",
+ "# define HAVE_ERRNO_H 1",
+ "# define HAVE_FCNTL 1",
+ "# define HAVE_FCNTL_H 1",
+ "# define HAVE_FCNTL_O_NONBLOCK 1",
+ "# define HAVE_FDOPEN 1",
+ "# define HAVE_FORK 1",
+ "# define HAVE_FREEADDRINFO 1",
+ "# define HAVE_FREEIFADDRS 1",
+ "# if !defined(__ANDROID__)",
+ "# define HAVE_FSETXATTR 1",
+ "# endif",
+ "# define HAVE_FTRUNCATE 1",
+ "# define HAVE_GAI_STRERROR 1",
+ "# define HAVE_GETADDRINFO 1",
+ "# define HAVE_GETADDRINFO_THREADSAFE 1",
+ "# define HAVE_GETEUID 1",
+ "# define HAVE_GETHOSTBYADDR 1",
+ "# define HAVE_GETHOSTBYNAME 1",
+ "# define HAVE_GETHOSTNAME 1",
+ "# if !defined(__ANDROID__)",
+ "# define HAVE_GETIFADDRS 1",
+ "# endif",
+ "# define HAVE_GETNAMEINFO 1",
+ "# define HAVE_GETPPID 1",
+ "# define HAVE_GETPROTOBYNAME 1",
+ "# define HAVE_GETPWUID 1",
+ "# if !defined(__ANDROID__)",
+ "# define HAVE_GETPWUID_R 1",
+ "# endif",
+ "# define HAVE_GETRLIMIT 1",
+ "# define HAVE_GETTIMEOFDAY 1",
+ "# define HAVE_GMTIME_R 1",
+ "# if !defined(__ANDROID__)",
+ "# define HAVE_IFADDRS_H 1",
+ "# endif",
+ "# define HAVE_IF_NAMETOINDEX 1",
+ "# define HAVE_INET_ADDR 1",
+ "# define HAVE_INET_NTOP 1",
+ "# define HAVE_INET_PTON 1",
+ "# define HAVE_INTTYPES_H 1",
+ "# define HAVE_IOCTL 1",
+ "# define HAVE_IOCTL_FIONBIO 1",
+ "# define HAVE_IOCTL_SIOCGIFADDR 1",
+ "# define HAVE_LIBGEN_H 1",
+ "# define HAVE_LIBZ 1",
+ "# define HAVE_LIMITS_H 1",
+ "# define HAVE_LL 1",
+ "# define HAVE_LOCALE_H 1",
+ "# define HAVE_LOCALTIME_R 1",
+ "# define HAVE_LONGLONG 1",
+ "# define HAVE_MEMORY_H 1",
+ "# define HAVE_NETDB_H 1",
+ "# define HAVE_NETINET_IN_H 1",
+ "# define HAVE_NETINET_TCP_H 1",
+ "# define HAVE_NET_IF_H 1",
+ "# define HAVE_PERROR 1",
+ "# define HAVE_PIPE 1",
+ "# define HAVE_POLL 1",
+ "# define HAVE_POLL_FINE 1",
+ "# define HAVE_POLL_H 1",
+ "# define HAVE_POSIX_STRERROR_R 1",
+ "# define HAVE_PWD_H 1",
+ "# define HAVE_RECV 1",
+ "# define HAVE_SELECT 1",
+ "# define HAVE_SEND 1",
+ "# define HAVE_SETJMP_H 1",
+ "# define HAVE_SETLOCALE 1",
+ "# define HAVE_SETRLIMIT 1",
+ "# define HAVE_SETSOCKOPT 1",
+ "# define HAVE_SGTTY_H 1",
+ "# define HAVE_SIGACTION 1",
+ "# define HAVE_SIGINTERRUPT 1",
+ "# define HAVE_SIGNAL 1",
+ "# define HAVE_SIGNAL_H 1",
+ "# define HAVE_SIGSETJMP 1",
+ "# define HAVE_SIG_ATOMIC_T 1",
+ "# define HAVE_SOCKADDR_IN6_SIN6_SCOPE_ID 1",
+ "# define HAVE_SOCKET 1",
+ "# define HAVE_SOCKETPAIR 1",
+ "# define HAVE_STDBOOL_H 1",
+ "# define HAVE_STDINT_H 1",
+ "# define HAVE_STDIO_H 1",
+ "# define HAVE_STDLIB_H 1",
+ "# define HAVE_STRCASECMP 1",
+ "# define HAVE_STRDUP 1",
+ "# define HAVE_STRERROR_R 1",
+ "# define HAVE_STRINGS_H 1",
+ "# define HAVE_STRING_H 1",
+ "# define HAVE_STRNCASECMP 1",
+ "# define HAVE_STRSTR 1",
+ "# define HAVE_STRTOK_R 1",
+ "# define HAVE_STRTOLL 1",
+ "# define HAVE_STRUCT_SOCKADDR_STORAGE 1",
+ "# define HAVE_STRUCT_TIMEVAL 1",
+ "# define HAVE_SYS_IOCTL_H 1",
+ "# define HAVE_SYS_PARAM_H 1",
+ "# define HAVE_SYS_POLL_H 1",
+ "# define HAVE_SYS_RESOURCE_H 1",
+ "# define HAVE_SYS_SELECT_H 1",
+ "# define HAVE_SYS_SOCKET_H 1",
+ "# define HAVE_SYS_STAT_H 1",
+ "# define HAVE_SYS_TIME_H 1",
+ "# define HAVE_SYS_TYPES_H 1",
+ "# define HAVE_SYS_UIO_H 1",
+ "# define HAVE_SYS_UN_H 1",
+ "# define HAVE_SYS_WAIT_H 1",
+ "# define HAVE_SYS_XATTR_H 1",
+ "# define HAVE_TIME_H 1",
+ "# define HAVE_UNAME 1",
+ "# define HAVE_UNISTD_H 1",
+ "# define HAVE_UTIME 1",
+ "# define HAVE_UTIME_H 1",
+ "# define HAVE_VARIADIC_MACROS_C99 1",
+ "# define HAVE_VARIADIC_MACROS_GCC 1",
+ "# define HAVE_WRITABLE_ARGV 1",
+ "# define HAVE_WRITEV 1",
+ "# define HAVE_ZLIB_H 1",
+ "# define LT_OBJDIR \".libs/\"",
+ "# define PACKAGE \"curl\"",
+ "# define PACKAGE_BUGREPORT \"a suitable curl mailing list: https://curl.haxx.se/mail/\"",
+ "# define PACKAGE_NAME \"curl\"",
+ "# define PACKAGE_STRING \"curl -\"",
+ "# define PACKAGE_TARNAME \"curl\"",
+ "# define PACKAGE_URL \"\"",
+ "# define PACKAGE_VERSION \"-\"",
+ "# define RECV_TYPE_ARG1 int",
+ "# define RECV_TYPE_ARG2 void *",
+ "# define RECV_TYPE_ARG3 size_t",
+ "# define RECV_TYPE_ARG4 int",
+ "# define RECV_TYPE_RETV ssize_t",
+ "# define RETSIGTYPE void",
+ "# define SELECT_QUAL_ARG5",
+ "# define SELECT_TYPE_ARG1 int",
+ "# define SELECT_TYPE_ARG234 fd_set *",
+ "# define SELECT_TYPE_ARG5 struct timeval *",
+ "# define SELECT_TYPE_RETV int",
+ "# define SEND_QUAL_ARG2 const",
+ "# define SEND_TYPE_ARG1 int",
+ "# define SEND_TYPE_ARG2 void *",
+ "# define SEND_TYPE_ARG3 size_t",
+ "# define SEND_TYPE_ARG4 int",
+ "# define SEND_TYPE_RETV ssize_t",
+ "# define SIZEOF_INT 4",
+ "# define SIZEOF_LONG 8",
+ "# define SIZEOF_OFF_T 8",
+ "# define SIZEOF_CURL_OFF_T 8",
+ "# define SIZEOF_SHORT 2",
+ "# define SIZEOF_SIZE_T 8",
+ "# define SIZEOF_TIME_T 8",
+ "# define SIZEOF_VOIDP 8",
+ "# define STDC_HEADERS 1",
+ "# define STRERROR_R_TYPE_ARG3 size_t",
+ "# define TIME_WITH_SYS_TIME 1",
+ "# define VERSION \"-\"",
+ "# ifndef _DARWIN_USE_64_BIT_INODE",
+ "# define _DARWIN_USE_64_BIT_INODE 1",
+ "# endif",
+ "#endif",
+ "",
+ "#endif // EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
+ "EOF",
+ ]),
+)