aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZoe Gong <zgong@google.com>2024-04-17 15:09:23 -0700
committerCopybara-Service <copybara-worker@google.com>2024-04-17 15:10:39 -0700
commit6661e5623ba9d6ffe34d138462890c5d18076642 (patch)
treeced91f013a39609b385527ca548775e7d13d4ee3
parent68bf429181493bac9813da66eff3ce7d68d98e44 (diff)
downloadfederated-compute-6661e5623ba9d6ffe34d138462890c5d18076642.tar.gz
Add Create method for CheckpointAggregator that takes a vector of Intrinsics rather than a Configuration.
PiperOrigin-RevId: 625819529
-rw-r--r--fcp/aggregation/protocol/BUILD1
-rw-r--r--fcp/aggregation/protocol/checkpoint_aggregator.cc7
-rw-r--r--fcp/aggregation/protocol/checkpoint_aggregator.h12
-rw-r--r--fcp/aggregation/protocol/checkpoint_aggregator_test.cc110
4 files changed, 128 insertions, 2 deletions
diff --git a/fcp/aggregation/protocol/BUILD b/fcp/aggregation/protocol/BUILD
index 050e676..d6fb4f4 100644
--- a/fcp/aggregation/protocol/BUILD
+++ b/fcp/aggregation/protocol/BUILD
@@ -155,6 +155,7 @@ cc_test(
":configuration_cc_proto",
"//fcp/aggregation/core:aggregator",
"//fcp/aggregation/core:tensor",
+ "//fcp/aggregation/core:tensor_cc_proto",
"//fcp/aggregation/testing",
"//fcp/aggregation/testing:mocks",
"//fcp/aggregation/testing:test_data",
diff --git a/fcp/aggregation/protocol/checkpoint_aggregator.cc b/fcp/aggregation/protocol/checkpoint_aggregator.cc
index b11bc7f..dc8e270 100644
--- a/fcp/aggregation/protocol/checkpoint_aggregator.cc
+++ b/fcp/aggregation/protocol/checkpoint_aggregator.cc
@@ -67,11 +67,14 @@ absl::Status CheckpointAggregator::ValidateConfig(
absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Create(const Configuration& configuration) {
- FCP_RETURN_IF_ERROR(ValidateConfig(configuration));
-
FCP_ASSIGN_OR_RETURN(std::vector<Intrinsic> intrinsics,
ParseFromConfig(configuration));
+ return Create(std::move(intrinsics));
+}
+
+absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
+CheckpointAggregator::Create(std::vector<Intrinsic> intrinsics) {
std::vector<std::unique_ptr<TensorAggregator>> aggregators;
for (const Intrinsic& intrinsic : intrinsics) {
FCP_ASSIGN_OR_RETURN(std::unique_ptr<TensorAggregator> aggregator,
diff --git a/fcp/aggregation/protocol/checkpoint_aggregator.h b/fcp/aggregation/protocol/checkpoint_aggregator.h
index b7b9617..2f0502b 100644
--- a/fcp/aggregation/protocol/checkpoint_aggregator.h
+++ b/fcp/aggregation/protocol/checkpoint_aggregator.h
@@ -49,10 +49,22 @@ class CheckpointAggregator {
// Returns INVALID_ARGUMENT if the configuration is invalid.
static absl::Status ValidateConfig(const Configuration& configuration);
+ // Validates the Intrinsics that will subsequently be used to create an
+ // instance of CheckpointAggregator.
+ // Returns INVALID_ARGUMENT if the configuration is invalid.
+ static absl::Status ValidateIntrinsics(
+ const std::vector<Intrinsic>& intrinsics);
+
// Creates an instance of CheckpointAggregator.
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Create(
const Configuration& configuration);
+ // Creates an instance of CheckpointAggregator.
+ // The `intrinsics` are expected to be created using `ParseFromConfig` which
+ // validates the configuration.
+ static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Create(
+ std::vector<Intrinsic> intrinsics);
+
// Accumulates a checkpoint via nested tensor aggregators. The tensors are
// provided by the CheckpointParser instance.
absl::Status Accumulate(CheckpointParser& checkpoint_parser);
diff --git a/fcp/aggregation/protocol/checkpoint_aggregator_test.cc b/fcp/aggregation/protocol/checkpoint_aggregator_test.cc
index b9c33c3..9ced962 100644
--- a/fcp/aggregation/protocol/checkpoint_aggregator_test.cc
+++ b/fcp/aggregation/protocol/checkpoint_aggregator_test.cc
@@ -21,6 +21,7 @@
#include <functional>
#include <memory>
#include <utility>
+#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -32,10 +33,12 @@
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
#include "fcp/aggregation/core/tensor_shape.h"
+#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/aggregation/protocol/configuration.pb.h"
#include "fcp/aggregation/testing/mocks.h"
#include "fcp/aggregation/testing/test_data.h"
@@ -124,6 +127,113 @@ std::unique_ptr<CheckpointAggregator> CreateWithDefaultFedSqlConfig() {
return Create(default_fedsql_configuration());
}
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsSuccess) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"federated_sum",
+ {TensorSpec("foo", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_INT32, {})},
+ {},
+ {}});
+ EXPECT_OK(CheckpointAggregator::Create(std::move(intrinsics)));
+}
+
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedNumberOfInputs) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back(
+ {"federated_sum",
+ {TensorSpec("foo", DT_INT32, {}), TensorSpec("bar", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_INT32, {})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedNumberOfOutputs) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"federated_sum",
+ {TensorSpec("foo", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_INT32, {}),
+ TensorSpec("bar_out", DT_INT32, {})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedInputType) {
+ std::vector<Intrinsic> intrinsics;
+ Tensor parameter =
+ Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({42})).value();
+ Intrinsic intrinsic{"federated_sum",
+ {TensorSpec("foo", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_INT32, {})},
+ {},
+ {}};
+ intrinsic.parameters.push_back(std::move(parameter));
+ intrinsics.push_back(std::move(intrinsic));
+ Configuration config_message = PARSE_TEXT_PROTO(R"pb(
+ intrinsic_configs {
+ intrinsic_uri: "federated_sum"
+ intrinsic_args { parameter {} }
+ output_tensors {
+ name: "foo_out"
+ dtype: DT_INT32
+ shape {}
+ }
+ }
+ )pb");
+
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedIntrinsicUri) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"unsupported_xyz",
+ {TensorSpec("foo", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_INT32, {})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(NOT_FOUND));
+}
+
+TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedInputSpec) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"federated_sum",
+ {TensorSpec("foo", DT_INT32, {-1})},
+ {TensorSpec("foo_out", DT_INT32, {})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CheckpointAggregatorTest,
+ CreateFromIntrinsicsMismatchingInputAndOutputDataType) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"federated_sum",
+ {TensorSpec("foo", DT_INT32, {})},
+ {TensorSpec("foo_out", DT_FLOAT, {})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
+TEST(CheckpointAggregatorTest,
+ CreateFromIntrinsicsMismatchingInputAndOutputShape) {
+ std::vector<Intrinsic> intrinsics;
+ intrinsics.push_back({"federated_sum",
+ {TensorSpec("foo", DT_INT32, {1})},
+ {TensorSpec("foo_out", DT_INT32, {2})},
+ {},
+ {}});
+ EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)),
+ IsCode(INVALID_ARGUMENT));
+}
+
TEST(CheckpointAggregatorTest, CreateSuccess) {
EXPECT_OK(CheckpointAggregator::Create(default_configuration()));
}