aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaya Spivak <mspivak@google.com>2024-04-09 09:49:22 -0700
committerCopybara-Service <copybara-worker@google.com>2024-04-09 09:50:19 -0700
commit099841e1eb49d720e6dafe34e678f211864f80b8 (patch)
tree5925b8f3ec686b4ef936df1e90689ee6e6cca247
parent5fb2bf43a96ef4ba83ed28ad218eb6e036dc55f4 (diff)
downloadfederated-compute-099841e1eb49d720e6dafe34e678f211864f80b8.tar.gz
Implement serialization and deserialization for the FederatedMean aggregator.
PiperOrigin-RevId: 623197173
-rw-r--r--fcp/aggregation/core/BUILD1
-rw-r--r--fcp/aggregation/core/agg_core.proto7
-rw-r--r--fcp/aggregation/core/federated_mean.cc101
-rw-r--r--fcp/aggregation/core/federated_mean_test.cc92
4 files changed, 172 insertions, 29 deletions
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD
index 841de08..6ad82bf 100644
--- a/fcp/aggregation/core/BUILD
+++ b/fcp/aggregation/core/BUILD
@@ -252,6 +252,7 @@ cc_test(
srcs = ["federated_mean_test.cc"],
copts = FCP_COPTS,
deps = [
+ ":agg_core_cc_proto",
":aggregation_cores",
":aggregator",
":tensor",
diff --git a/fcp/aggregation/core/agg_core.proto b/fcp/aggregation/core/agg_core.proto
index 0ce048e..6efbed5 100644
--- a/fcp/aggregation/core/agg_core.proto
+++ b/fcp/aggregation/core/agg_core.proto
@@ -7,3 +7,10 @@ message AggVectorAggregatorState {
uint64 num_inputs = 1;
bytes vector_data = 2;
}
+
+// Internal state representation of a FederatedMeanAggregator.
+message FederatedMeanAggregatorState {
+ uint64 num_inputs = 1;
+ bytes weights_sum = 2;
+ bytes weighted_values_sum = 3;
+}
diff --git a/fcp/aggregation/core/federated_mean.cc b/fcp/aggregation/core/federated_mean.cc
index 5e3d35c..b7c21d3 100644
--- a/fcp/aggregation/core/federated_mean.cc
+++ b/fcp/aggregation/core/federated_mean.cc
@@ -16,19 +16,21 @@
#include <cstddef>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/input_tensor_list.h"
#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
-#include "fcp/aggregation/core/tensor_data.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/base/monitoring.h"
@@ -42,13 +44,29 @@ constexpr char kFederatedWeightedMeanUri[] = "federated_weighted_mean";
template <typename V, typename W>
class FederatedMean final : public TensorAggregator {
public:
- explicit FederatedMean(DataType dtype, TensorShape shape,
- MutableVectorData<V>* weighted_values_sum)
- : weighted_values_sum_(*weighted_values_sum),
- result_tensor_(
- Tensor::Create(dtype, shape,
- std::unique_ptr<TensorData>(weighted_values_sum))
- .value()) {}
+ explicit FederatedMean(
+ DataType dtype, TensorShape shape,
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum)
+ : FederatedMean(dtype, shape, std::move(weighted_values_sum), 0, 0) {}
+
+ FederatedMean(DataType dtype, TensorShape shape,
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum,
+ W weights_sum, int num_inputs)
+ : dtype_(dtype),
+ shape_(std::move(shape)),
+ weighted_values_sum_(std::move(weighted_values_sum)),
+ weights_sum_(weights_sum),
+ num_inputs_(num_inputs) {}
+
+ StatusOr<std::string> Serialize() && override {
+ FederatedMeanAggregatorState aggregator_state;
+ aggregator_state.set_num_inputs(num_inputs_);
+ *(aggregator_state.mutable_weighted_values_sum()) =
+ weighted_values_sum_->EncodeContent();
+ *(aggregator_state.mutable_weights_sum()) = std::string(
+ reinterpret_cast<char*>(&weights_sum_), sizeof(weights_sum_));
+ return aggregator_state.SerializeAsString();
+ }
private:
Status MergeWith(TensorAggregator&& other) override {
@@ -61,16 +79,16 @@ class FederatedMean final : public TensorAggregator {
}
FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
- std::pair<std::vector<V>, W> other_internal_state =
+ std::pair<std::unique_ptr<MutableVectorData<V>>, W> other_internal_state =
other_ptr->GetInternalState();
- if (other_internal_state.first.size() != weighted_values_sum_.size()) {
+ if (other_internal_state.first->size() != weighted_values_sum_->size()) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "FederatedMean::MergeWith: Can only merge weighted value sum "
"tensors of equal length.";
}
- for (int i = 0; i < weighted_values_sum_.size(); ++i) {
- weighted_values_sum_[i] += other_internal_state.first[i];
+ for (int i = 0; i < weighted_values_sum_->size(); ++i) {
+ (*weighted_values_sum_)[i] += (*other_internal_state.first)[i];
}
weights_sum_ += other_internal_state.second;
num_inputs_ += other_ptr->GetNumInputs();
@@ -101,12 +119,12 @@ class FederatedMean final : public TensorAggregator {
"weights are allowed.";
}
for (auto value : values) {
- weighted_values_sum_[value.index] += value.value * weight;
+ (*weighted_values_sum_)[value.index] += value.value * weight;
}
weights_sum_ += weight;
} else {
for (auto value : values) {
- weighted_values_sum_[value.index] += value.value;
+ (*weighted_values_sum_)[value.index] += value.value;
}
}
num_inputs_++;
@@ -126,29 +144,32 @@ class FederatedMean final : public TensorAggregator {
// Produce the final weighted mean values by dividing the weighted values
// sum by the weights sum (tracked by weights_sum_ in the weighted case and
// num_inputs_ in the non-weighted case).
- for (int i = 0; i < weighted_values_sum_.size(); ++i) {
- weighted_values_sum_[i] /=
+ for (int i = 0; i < weighted_values_sum_->size(); ++i) {
+ (*weighted_values_sum_)[i] /=
(weights_sum_ > 0 ? weights_sum_ : num_inputs_);
}
OutputTensorList outputs = std::vector<Tensor>();
- outputs.push_back(std::move(result_tensor_));
+ outputs.push_back(
+ Tensor::Create(dtype_, shape_, std::move(weighted_values_sum_))
+ .value());
return outputs;
}
int GetNumInputs() const override { return num_inputs_; }
- std::pair<std::vector<V>, W> GetInternalState() {
+ std::pair<std::unique_ptr<MutableVectorData<V>>, W> GetInternalState() {
output_consumed_ = true;
return std::make_pair(std::move(weighted_values_sum_), weights_sum_);
}
bool output_consumed_ = false;
- std::vector<V>& weighted_values_sum_;
+ DataType dtype_;
+ TensorShape shape_;
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum_;
// In the weighted case, use the weights_sum_ variable to track the total
// weight. Otherwise, just rely on the num_inputs_ variable.
- W weights_sum_ = 0;
- Tensor result_tensor_;
- int num_inputs_ = 0;
+ W weights_sum_;
+ int num_inputs_;
};
// Factory class for the FederatedMean.
@@ -162,6 +183,24 @@ class FederatedMeanFactory final : public TensorAggregatorFactory {
StatusOr<std::unique_ptr<TensorAggregator>> Create(
const Intrinsic& intrinsic) const override {
+ return CreateInternal(intrinsic, nullptr);
+ }
+
+ StatusOr<std::unique_ptr<TensorAggregator>> Deserialize(
+ const Intrinsic& intrinsic, std::string serialized_state) const override {
+ FederatedMeanAggregatorState aggregator_state;
+ if (!aggregator_state.ParseFromString(serialized_state)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "FederatedMeanFactory::Deserialize: Failed to parse "
+ "FederatedMeanAggregatorState.";
+ }
+ return CreateInternal(intrinsic, &aggregator_state);
+ }
+
+ private:
+ StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const FederatedMeanAggregatorState* aggregator_state) const {
// Check that the configuration is valid.
if (kFederatedMeanUri == intrinsic.uri) {
if (intrinsic.inputs.size() != 1) {
@@ -232,13 +271,29 @@ class FederatedMeanFactory final : public TensorAggregatorFactory {
}
std::unique_ptr<TensorAggregator> aggregator;
+ if (aggregator_state == nullptr) {
+ FLOATING_ONLY_DTYPE_CASES(
+ input_value_type, V,
+ NUMERICAL_ONLY_DTYPE_CASES(
+ input_weight_type, W,
+ aggregator = (std::make_unique<FederatedMean<V, W>>(
+ input_value_type, input_value_spec.shape(),
+ std::make_unique<MutableVectorData<V>>(
+ value_num_elements.value())))));
+ return aggregator;
+ }
+
FLOATING_ONLY_DTYPE_CASES(
input_value_type, V,
NUMERICAL_ONLY_DTYPE_CASES(
input_weight_type, W,
aggregator = (std::make_unique<FederatedMean<V, W>>(
input_value_type, input_value_spec.shape(),
- new MutableVectorData<V>(value_num_elements.value())))));
+ MutableVectorData<V>::CreateFromEncodedContent(
+ aggregator_state->weighted_values_sum()),
+ *(reinterpret_cast<const W*>(
+ aggregator_state->weights_sum().data())),
+ aggregator_state->num_inputs()))));
return aggregator;
}
};
diff --git a/fcp/aggregation/core/federated_mean_test.cc b/fcp/aggregation/core/federated_mean_test.cc
index 0e1ef24..9e67312 100644
--- a/fcp/aggregation/core/federated_mean_test.cc
+++ b/fcp/aggregation/core/federated_mean_test.cc
@@ -15,10 +15,12 @@
*/
#include <memory>
+#include <string>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
@@ -38,8 +40,11 @@ using testing::Eq;
using testing::HasSubstr;
using testing::IsFalse;
using testing::IsTrue;
+using testing::TestWithParam;
-TEST(FederatedMeanTest, ScalarAggregation_Succeeds) {
+using FederatedMeanTest = TestWithParam<bool>;
+
+TEST_P(FederatedMeanTest, ScalarAggregation_Succeeds) {
Intrinsic federated_mean_intrinsic{"federated_mean",
{TensorSpec{"foo", DT_FLOAT, {}}},
{TensorSpec{"foo_out", DT_FLOAT, {}}},
@@ -51,6 +56,14 @@ TEST(FederatedMeanTest, ScalarAggregation_Succeeds) {
Tensor v3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({3})).value();
EXPECT_THAT(aggregator->Accumulate(v1), IsOk());
EXPECT_THAT(aggregator->Accumulate(v2), IsOk());
+
+ if (GetParam()) {
+ auto serialized_state = std::move(*aggregator).Serialize();
+ aggregator = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator->Accumulate(v3), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -64,7 +77,7 @@ TEST(FederatedMeanTest, ScalarAggregation_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor<float>({}, {2}));
}
-TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) {
+TEST_P(FederatedMeanTest, WeightedScalarAggregation_Succeeds) {
Intrinsic federated_mean_intrinsic{
"federated_weighted_mean",
{TensorSpec{"foo", DT_FLOAT, {}}, TensorSpec{"bar", DT_FLOAT, {}}},
@@ -80,6 +93,14 @@ TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) {
Tensor w3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({5})).value();
EXPECT_THAT(aggregator->Accumulate({&v1, &w1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&v2, &w2}), IsOk());
+
+ if (GetParam()) {
+ auto serialized_state = std::move(*aggregator).Serialize();
+ aggregator = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&v3, &w3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -95,7 +116,7 @@ TEST(FederatedMeanTest, WeightedScalarAggregation_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor<float>({}, {expected_value}));
}
-TEST(FederatedMeanTest, DenseAggregation_Succeeds) {
+TEST_P(FederatedMeanTest, DenseAggregation_Succeeds) {
Intrinsic federated_mean_intrinsic{"federated_mean",
{TensorSpec{"foo", DT_FLOAT, {4}}},
{TensorSpec{"foo_out", DT_FLOAT, {4}}},
@@ -113,6 +134,14 @@ TEST(FederatedMeanTest, DenseAggregation_Succeeds) {
.value();
EXPECT_THAT(aggregator->Accumulate(v1), IsOk());
EXPECT_THAT(aggregator->Accumulate(v2), IsOk());
+
+ if (GetParam()) {
+ auto serialized_state = std::move(*aggregator).Serialize();
+ aggregator = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator->Accumulate(v3), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -127,7 +156,7 @@ TEST(FederatedMeanTest, DenseAggregation_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) {
+TEST_P(FederatedMeanTest, WeightedDenseAggregation_Succeeds) {
Intrinsic federated_mean_intrinsic{
"federated_weighted_mean",
{TensorSpec{"foo", DT_FLOAT, {4}}, TensorSpec{"bar", DT_FLOAT, {}}},
@@ -149,6 +178,14 @@ TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) {
Tensor w3 = Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({5})).value();
EXPECT_THAT(aggregator->Accumulate({&v1, &w1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&v2, &w2}), IsOk());
+
+ if (GetParam()) {
+ auto serialized_state = std::move(*aggregator).Serialize();
+ aggregator = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&v3, &w3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -167,7 +204,7 @@ TEST(FederatedMeanTest, WeightedDenseAggregation_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(FederatedMeanTest, Merge_Succeeds) {
+TEST_P(FederatedMeanTest, Merge_Succeeds) {
Intrinsic federated_mean_intrinsic{"federated_mean",
{TensorSpec{"foo", DT_FLOAT, {}}},
{TensorSpec{"foo_out", DT_FLOAT, {}}},
@@ -182,6 +219,17 @@ TEST(FederatedMeanTest, Merge_Succeeds) {
EXPECT_THAT(aggregator2->Accumulate(v2), IsOk());
EXPECT_THAT(aggregator2->Accumulate(v3), IsOk());
+ if (GetParam()) {
+ auto serialized_state1 = std::move(*aggregator1).Serialize();
+ auto serialized_state2 = std::move(*aggregator2).Serialize();
+ aggregator1 = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state1.value())
+ .value();
+ aggregator2 = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state2.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator1->MergeWith(std::move(*aggregator2)), IsOk());
EXPECT_THAT(aggregator2->CanReport(), IsFalse());
EXPECT_THAT(aggregator1->CanReport(), IsTrue());
@@ -193,7 +241,7 @@ TEST(FederatedMeanTest, Merge_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor<float>({}, {2}));
}
-TEST(FederatedMeanTest, WeightedDenseMerge_Succeeds) {
+TEST_P(FederatedMeanTest, WeightedDenseMerge_Succeeds) {
Intrinsic federated_mean_intrinsic{
"federated_weighted_mean",
{TensorSpec{"foo", DT_FLOAT, {4}}, TensorSpec{"bar", DT_FLOAT, {}}},
@@ -218,6 +266,17 @@ TEST(FederatedMeanTest, WeightedDenseMerge_Succeeds) {
EXPECT_THAT(aggregator2->Accumulate({&v2, &w2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&v3, &w3}), IsOk());
+ if (GetParam()) {
+ auto serialized_state1 = std::move(*aggregator1).Serialize();
+ auto serialized_state2 = std::move(*aggregator2).Serialize();
+ aggregator1 = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state1.value())
+ .value();
+ aggregator2 = DeserializeTensorAggregator(federated_mean_intrinsic,
+ serialized_state2.value())
+ .value();
+ }
+
EXPECT_THAT(aggregator1->MergeWith(std::move(*aggregator2)), IsOk());
EXPECT_THAT(aggregator2->CanReport(), IsFalse());
EXPECT_THAT(aggregator1->CanReport(), IsTrue());
@@ -399,6 +458,27 @@ TEST(FederatedMeanTest, Create_UnsupportedNestedIntrinsic) {
EXPECT_THAT(s.message(), HasSubstr("Expected no nested intrinsics"));
}
+TEST(FederatedMeanTest, Deserialize_FailToParseProto) {
+ Intrinsic federated_mean_intrinsic{"federated_mean",
+ {TensorSpec{"foo", DT_FLOAT, {}}},
+ {TensorSpec{"foo_out", DT_FLOAT, {}}},
+ {},
+ {}};
+ std::string invalid_state("invalid_state");
+ Status s =
+ DeserializeTensorAggregator(federated_mean_intrinsic, invalid_state)
+ .status();
+ EXPECT_THAT(s, IsCode(INVALID_ARGUMENT));
+ EXPECT_THAT(s.message(), HasSubstr("Failed to parse"));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ FederatedMeanTestInstantiation, FederatedMeanTest,
+ testing::ValuesIn<bool>({false, true}),
+ [](const testing::TestParamInfo<FederatedMeanTest::ParamType>& info) {
+ return info.param ? "SerializeDeserialize" : "None";
+ });
+
} // namespace
} // namespace aggregation
} // namespace fcp