diff options
author | Maya Spivak <mspivak@google.com> | 2024-04-09 09:49:22 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2024-04-09 09:50:19 -0700 |
commit | 099841e1eb49d720e6dafe34e678f211864f80b8 (patch) | |
tree | 5925b8f3ec686b4ef936df1e90689ee6e6cca247 | |
parent | 5fb2bf43a96ef4ba83ed28ad218eb6e036dc55f4 (diff) | |
download | federated-compute-099841e1eb49d720e6dafe34e678f211864f80b8.tar.gz |
Implement serialization and deserialization for the FederatedMean aggregator.
PiperOrigin-RevId: 623197173
-rw-r--r-- | fcp/aggregation/core/BUILD | 1 | ||||
-rw-r--r-- | fcp/aggregation/core/agg_core.proto | 7 | ||||
-rw-r--r-- | fcp/aggregation/core/federated_mean.cc | 101 | ||||
-rw-r--r-- | fcp/aggregation/core/federated_mean_test.cc | 92 |
4 files changed, 172 insertions, 29 deletions
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD index 841de08..6ad82bf 100644 --- a/fcp/aggregation/core/BUILD +++ b/fcp/aggregation/core/BUILD @@ -252,6 +252,7 @@ cc_test( srcs = ["federated_mean_test.cc"], copts = FCP_COPTS, deps = [ + ":agg_core_cc_proto", ":aggregation_cores", ":aggregator", ":tensor", diff --git a/fcp/aggregation/core/agg_core.proto b/fcp/aggregation/core/agg_core.proto index 0ce048e..6efbed5 100644 --- a/fcp/aggregation/core/agg_core.proto +++ b/fcp/aggregation/core/agg_core.proto @@ -7,3 +7,10 @@ message AggVectorAggregatorState { uint64 num_inputs = 1; bytes vector_data = 2; } + +// Internal state representation of a FederatedMeanAggregator. +message FederatedMeanAggregatorState { + uint64 num_inputs = 1; + bytes weights_sum = 2; + bytes weighted_values_sum = 3; +} diff --git a/fcp/aggregation/core/federated_mean.cc b/fcp/aggregation/core/federated_mean.cc index 5e3d35c..b7c21d3 100644 --- a/fcp/aggregation/core/federated_mean.cc +++ b/fcp/aggregation/core/federated_mean.cc @@ -16,19 +16,21 @@ #include <cstddef> #include <memory> +#include <string> #include <utility> #include <vector> +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/input_tensor_list.h" #include "fcp/aggregation/core/intrinsic.h" #include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/tensor.h" +#include "fcp/aggregation/core/tensor.pb.h" #include "fcp/aggregation/core/tensor_aggregator.h" #include "fcp/aggregation/core/tensor_aggregator_factory.h" #include "fcp/aggregation/core/tensor_aggregator_registry.h" -#include "fcp/aggregation/core/tensor_data.h" #include "fcp/aggregation/core/tensor_shape.h" #include "fcp/aggregation/core/tensor_spec.h" #include "fcp/base/monitoring.h" @@ -42,13 +44,29 @@ constexpr char kFederatedWeightedMeanUri[] = "federated_weighted_mean"; template <typename V, typename W> class FederatedMean final : public TensorAggregator { public: - explicit FederatedMean(DataType dtype, TensorShape shape, - MutableVectorData<V>* weighted_values_sum) - : weighted_values_sum_(*weighted_values_sum), - result_tensor_( - Tensor::Create(dtype, shape, - std::unique_ptr<TensorData>(weighted_values_sum)) - .value()) {} + explicit FederatedMean( + DataType dtype, TensorShape shape, + std::unique_ptr<MutableVectorData<V>> weighted_values_sum) + : FederatedMean(dtype, shape, std::move(weighted_values_sum), 0, 0) {} + + FederatedMean(DataType dtype, TensorShape shape, + std::unique_ptr<MutableVectorData<V>> weighted_values_sum, + W weights_sum, int num_inputs) + : dtype_(dtype), + shape_(std::move(shape)), + weighted_values_sum_(std::move(weighted_values_sum)), + weights_sum_(weights_sum), + num_inputs_(num_inputs) {} + + StatusOr<std::string> Serialize() && override { + FederatedMeanAggregatorState aggregator_state; + aggregator_state.set_num_inputs(num_inputs_); + *(aggregator_state.mutable_weighted_values_sum()) = + weighted_values_sum_->EncodeContent(); + *(aggregator_state.mutable_weights_sum()) = std::string( + reinterpret_cast<char*>(&weights_sum_), sizeof(weights_sum_)); + return aggregator_state.SerializeAsString(); + } private: Status MergeWith(TensorAggregator&& other) override { @@ -61,16 +79,16 @@ class FederatedMean final : public TensorAggregator { } FCP_RETURN_IF_ERROR((*other_ptr).CheckValid()); - std::pair<std::vector<V>, W> other_internal_state = + std::pair<std::unique_ptr<MutableVectorData<V>>, W> other_internal_state = other_ptr->GetInternalState(); - if (other_internal_state.first.size() != weighted_values_sum_.size()) { + if (other_internal_state.first->size() != weighted_values_sum_->size()) { return FCP_STATUS(INVALID_ARGUMENT) << "FederatedMean::MergeWith: Can only merge weighted value sum " "tensors of equal length."; } - for (int i = 0; i < weighted_values_sum_.size(); ++i) { - weighted_values_sum_[i] += other_internal_state.first[i]; + for (int i = 0; i < weighted_values_sum_->size(); ++i) { + (*weighted_values_sum_)[i] += (*other_internal_state.first)[i]; } weights_sum_ += other_internal_state.second; num_inputs_ += other_ptr->GetNumInputs(); @@ -101,12 +119,12 @@ class FederatedMean final : public TensorAggregator { "weights are allowed."; } for (auto value : values) { - weighted_values_sum_[value.index] += value.value * weight; + (*weighted_values_sum_)[value.index] += value.value * weight; } weights_sum_ += weight; } else { for (auto value : values) { - weighted_values_sum_[value.index] += value.value; + (*weighted_values_sum_)[value.index] += value.value; } } num_inputs_++; @@ -126,29 +144,32 @@ class FederatedMean final : public TensorAggregator { // Produce the final weighted mean values by dividing the weighted values // sum by the weights sum (tracked by weights_sum_ in the weighted case and // num_inputs_ in the non-weighted case). - for (int i = 0; i < weighted_values_sum_.size(); ++i) { - weighted_values_sum_[i] /= + for (int i = 0; i < weighted_values_sum_->size(); ++i) { + (*weighted_values_sum_)[i] /= (weights_sum_ > 0 ? weights_sum_ : num_inputs_); } OutputTensorList outputs = std::vector<Tensor>(); - outputs.push_back(std::move(result_tensor_)); + outputs.push_back( + Tensor::Create(dtype_, shape_, std::move(weighted_values_sum_)) + .value()); return outputs; } int GetNumInputs() const override { return num_inputs_; } - std::pair<std::vector<V>, W> GetInternalState() { + std::pair<std::unique_ptr<MutableVectorData<V>>, W> GetInternalState() { output_consumed_ = true; return std::make_pair(std::move(weighted_values_sum_), weights_sum_); } bool output_consumed_ = false; - std::vector<V>& weighted_values_sum_; + DataType dtype_; + TensorShape shape_; + std::unique_ptr<MutableVectorData<V>> weighted_values_sum_; // In the weighted case, use the weights_sum_ variable to track the total // weight. Otherwise, just rely on the num_inputs_ variable. - W weights_sum_ = 0; - Tensor result_tensor_; - int num_inputs_ = 0; + W weights_sum_; + int num_inputs_; }; // Factory class for the FederatedMean. @@ -162,6 +183,24 @@ class FederatedMeanFactory final : public TensorAggregatorFactory { StatusOr<std::unique_ptr<TensorAggregator>> Create( const Intrinsic& intrinsic) const override { + return CreateInternal(intrinsic, nullptr); + } + + StatusOr<std::unique_ptr<TensorAggregator>> Deserialize( + const Intrinsic& intrinsic, std::string serialized_state) const override { + FederatedMeanAggregatorState aggregator_state; + if (!aggregator_state.ParseFromString(serialized_state)) { + return FCP_STATUS(INVALID_ARGUMENT) + << "FederatedMeanFactory::Deserialize: Failed to parse " + "FederatedMeanAggregatorState."; + } + return CreateInternal(intrinsic, &aggregator_state); + } + + private: + StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal( + const Intrinsic& intrinsic, + const FederatedMeanAggregatorState* aggregator_state) const { // Check that the configuration is valid. if (kFederatedMeanUri == intrinsic.uri) { if (intrinsic.inputs.size() != 1) { @@ -232,13 +271,29 @@ class FederatedMeanFactory final : public TensorAggregatorFactory { } std::unique_ptr<TensorAggregator> aggregator; + if (aggregator_state == nullptr) { + FLOATING_ONLY_DTYPE_CASES( + input_value_type, V, + NUMERICAL_ONLY_DTYPE_CASES( + input_weight_type, W, + aggregator = (std::make_unique<FederatedMean<V, W>>( + input_value_type, input_value_spec.shape(), + std::make_unique<MutableVectorData<V>>( + value_num_elements.value()))))); + return aggregator; + } + FLOATING_ONLY_DTYPE_CASES( input_value_type, V, NUMERICAL_ONLY_DTYPE_CASES( input_weight_type, W, aggregator = (std::make_unique<FederatedMean<V, W>>( input_value_type, input_value_spec.shape(), - new MutableVectorData<V>(value_num_elements.value()))))); + MutableVectorData<V>::CreateFromEncodedContent( + aggregator_state->weighted_values_sum()), + *(reinterpret_cast<const W*>( + aggregator_state->weights_sum().data())), + aggregator_state->num_inputs())))); return aggregator; } }; diff --git a/fcp/aggregation/core/federated_mean_test.cc b/fcp/aggregation/core/federated_mean_test.cc index 0e1ef24..9e67312 100644 --- a/fcp/aggregation/core/federated_mean_test.cc +++ b/fcp/aggregation/core/federated_mean_test.cc @@ -15,10 +15,12 @@ */ #include <memory> +#include <string> #include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/intrinsic.h" #include "fcp/aggregation/core/tensor.h" #include "fcp/aggregation/core/tensor_aggregator_factory.h" @@ -38,8 +40,11 @@ using testing::Eq; using testing::HasSubstr; using testing::IsFalse; using testing::IsTrue; +using testing::TestWithParam; -TEST(FederatedMeanTest, ScalarAggregation_Succeeds) { +using FederatedMeanTest = TestWithParam<bool>; + +TEST_P(FederatedMeanTest, ScalarAggregation_Succeeds) { Intrinsic federated_mean_intrinsic{"federated_mean", {TensorSpec{"foo", DT_FLOAT, {}}}, {TensorSpec{"foo_out", DT_FLOAT, {}}}, @@ -51,6 +56,14 @@ TEST(FederatedMeanTest, ScalarAggregation_Succeeds) { Tensor v3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({3})).value(); EXPECT_THAT(aggregator->Accumulate(v1), IsOk()); EXPECT_THAT(aggregator->Accumulate(v2), IsOk()); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize(); + aggregator = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state.value()) + .value(); + } + EXPECT_THAT(aggregator->Accumulate(v3), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -64,7 +77,7 @@ TEST(FederatedMeanTest, ScalarAggregation_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor<float>({}, {2})); } -TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) { +TEST_P(FederatedMeanTest, WeightedScalarAggregation_Succeeds) { Intrinsic federated_mean_intrinsic{ "federated_weighted_mean", {TensorSpec{"foo", DT_FLOAT, {}}, TensorSpec{"bar", DT_FLOAT, {}}}, @@ -80,6 +93,14 @@ TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) { Tensor w3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({5})).value(); EXPECT_THAT(aggregator->Accumulate({&v1, &w1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&v2, &w2}), IsOk()); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize(); + aggregator = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state.value()) + .value(); + } + EXPECT_THAT(aggregator->Accumulate({&v3, &w3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -95,7 +116,7 @@ TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor<float>({}, {expected_value})); } -TEST(FederatedMeanTest, DenseAggregation_Succeeds) { +TEST_P(FederatedMeanTest, DenseAggregation_Succeeds) { Intrinsic federated_mean_intrinsic{"federated_mean", {TensorSpec{"foo", DT_FLOAT, {4}}}, {TensorSpec{"foo_out", DT_FLOAT, {4}}}, @@ -113,6 +134,14 @@ TEST(FederatedMeanTest, DenseAggregation_Succeeds) { .value(); EXPECT_THAT(aggregator->Accumulate(v1), IsOk()); EXPECT_THAT(aggregator->Accumulate(v2), IsOk()); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize(); + aggregator = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state.value()) + .value(); + } + EXPECT_THAT(aggregator->Accumulate(v3), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -127,7 +156,7 @@ TEST(FederatedMeanTest, DenseAggregation_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) { +TEST_P(FederatedMeanTest, WeightedDenseAggregation_Succeeds) { Intrinsic federated_mean_intrinsic{ "federated_weighted_mean", {TensorSpec{"foo", DT_FLOAT, {4}}, TensorSpec{"bar", DT_FLOAT, {}}}, @@ -149,6 +178,14 @@ TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) { Tensor w3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({5})).value(); EXPECT_THAT(aggregator->Accumulate({&v1, &w1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&v2, &w2}), IsOk()); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize(); + aggregator = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state.value()) + .value(); + } + EXPECT_THAT(aggregator->Accumulate({&v3, &w3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -167,7 +204,7 @@ TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(FederatedMeanTest, Merge_Succeeds) { +TEST_P(FederatedMeanTest, Merge_Succeeds) { Intrinsic federated_mean_intrinsic{"federated_mean", {TensorSpec{"foo", DT_FLOAT, {}}}, {TensorSpec{"foo_out", DT_FLOAT, {}}}, @@ -182,6 +219,17 @@ TEST(FederatedMeanTest, Merge_Succeeds) { EXPECT_THAT(aggregator2->Accumulate(v2), IsOk()); EXPECT_THAT(aggregator2->Accumulate(v3), IsOk()); + if (GetParam()) { + auto serialized_state1 = std::move(*aggregator1).Serialize(); + auto serialized_state2 = std::move(*aggregator2).Serialize(); + aggregator1 = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state1.value()) + .value(); + aggregator2 = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state2.value()) + .value(); + } + EXPECT_THAT(aggregator1->MergeWith(std::move(*aggregator2)), IsOk()); EXPECT_THAT(aggregator2->CanReport(), IsFalse()); EXPECT_THAT(aggregator1->CanReport(), IsTrue()); @@ -193,7 +241,7 @@ TEST(FederatedMeanTest, Merge_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor<float>({}, {2})); } -TEST(FederatedMeanTest, WeightedDenseMerge_Succeeds) { +TEST_P(FederatedMeanTest, WeightedDenseMerge_Succeeds) { Intrinsic federated_mean_intrinsic{ "federated_weighted_mean", {TensorSpec{"foo", DT_FLOAT, {4}}, TensorSpec{"bar", DT_FLOAT, {}}}, @@ -218,6 +266,17 @@ TEST(FederatedMeanTest, WeightedDenseMerge_Succeeds) { EXPECT_THAT(aggregator2->Accumulate({&v2, &w2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&v3, &w3}), IsOk()); + if (GetParam()) { + auto serialized_state1 = std::move(*aggregator1).Serialize(); + auto serialized_state2 = std::move(*aggregator2).Serialize(); + aggregator1 = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state1.value()) + .value(); + aggregator2 = DeserializeTensorAggregator(federated_mean_intrinsic, + serialized_state2.value()) + .value(); + } + EXPECT_THAT(aggregator1->MergeWith(std::move(*aggregator2)), IsOk()); EXPECT_THAT(aggregator2->CanReport(), IsFalse()); EXPECT_THAT(aggregator1->CanReport(), IsTrue()); @@ -399,6 +458,27 @@ TEST(FederatedMeanTest, Create_UnsupportedNestedIntrinsic) { EXPECT_THAT(s.message(), HasSubstr("Expected no nested intrinsics")); } +TEST(FederatedMeanTest, Deserialize_FailToParseProto) { + Intrinsic federated_mean_intrinsic{"federated_mean", + {TensorSpec{"foo", DT_FLOAT, {}}}, + {TensorSpec{"foo_out", DT_FLOAT, {}}}, + {}, + {}}; + std::string invalid_state("invalid_state"); + Status s = + DeserializeTensorAggregator(federated_mean_intrinsic, invalid_state) + .status(); + EXPECT_THAT(s, IsCode(INVALID_ARGUMENT)); + EXPECT_THAT(s.message(), HasSubstr("Failed to parse")); +} + +INSTANTIATE_TEST_SUITE_P( + FederatedMeanTestInstantiation, FederatedMeanTest, + testing::ValuesIn<bool>({false, true}), + [](const testing::TestParamInfo<FederatedMeanTest::ParamType>& info) { + return info.param ? "SerializeDeserialize" : "None"; + }); + } // namespace } // namespace aggregation } // namespace fcp |