diff options
author | Zoe Gong <zgong@google.com> | 2024-04-17 15:09:23 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2024-04-17 15:10:39 -0700 |
commit | 6661e5623ba9d6ffe34d138462890c5d18076642 (patch) | |
tree | ced91f013a39609b385527ca548775e7d13d4ee3 | |
parent | 68bf429181493bac9813da66eff3ce7d68d98e44 (diff) | |
download | federated-compute-6661e5623ba9d6ffe34d138462890c5d18076642.tar.gz |
Add Create method for CheckpointAggregator that takes a vector of Intrinsics rather than a Configuration.
PiperOrigin-RevId: 625819529
-rw-r--r-- | fcp/aggregation/protocol/BUILD | 1 | ||||
-rw-r--r-- | fcp/aggregation/protocol/checkpoint_aggregator.cc | 7 | ||||
-rw-r--r-- | fcp/aggregation/protocol/checkpoint_aggregator.h | 12 | ||||
-rw-r--r-- | fcp/aggregation/protocol/checkpoint_aggregator_test.cc | 110 |
4 files changed, 128 insertions, 2 deletions
diff --git a/fcp/aggregation/protocol/BUILD b/fcp/aggregation/protocol/BUILD index 050e676..d6fb4f4 100644 --- a/fcp/aggregation/protocol/BUILD +++ b/fcp/aggregation/protocol/BUILD @@ -155,6 +155,7 @@ cc_test( ":configuration_cc_proto", "//fcp/aggregation/core:aggregator", "//fcp/aggregation/core:tensor", + "//fcp/aggregation/core:tensor_cc_proto", "//fcp/aggregation/testing", "//fcp/aggregation/testing:mocks", "//fcp/aggregation/testing:test_data", diff --git a/fcp/aggregation/protocol/checkpoint_aggregator.cc b/fcp/aggregation/protocol/checkpoint_aggregator.cc index b11bc7f..dc8e270 100644 --- a/fcp/aggregation/protocol/checkpoint_aggregator.cc +++ b/fcp/aggregation/protocol/checkpoint_aggregator.cc @@ -67,11 +67,14 @@ absl::Status CheckpointAggregator::ValidateConfig( absl::StatusOr<std::unique_ptr<CheckpointAggregator>> CheckpointAggregator::Create(const Configuration& configuration) { - FCP_RETURN_IF_ERROR(ValidateConfig(configuration)); - FCP_ASSIGN_OR_RETURN(std::vector<Intrinsic> intrinsics, ParseFromConfig(configuration)); + return Create(std::move(intrinsics)); +} + +absl::StatusOr<std::unique_ptr<CheckpointAggregator>> +CheckpointAggregator::Create(std::vector<Intrinsic> intrinsics) { std::vector<std::unique_ptr<TensorAggregator>> aggregators; for (const Intrinsic& intrinsic : intrinsics) { FCP_ASSIGN_OR_RETURN(std::unique_ptr<TensorAggregator> aggregator, diff --git a/fcp/aggregation/protocol/checkpoint_aggregator.h b/fcp/aggregation/protocol/checkpoint_aggregator.h index b7b9617..2f0502b 100644 --- a/fcp/aggregation/protocol/checkpoint_aggregator.h +++ b/fcp/aggregation/protocol/checkpoint_aggregator.h @@ -49,10 +49,22 @@ class CheckpointAggregator { // Returns INVALID_ARGUMENT if the configuration is invalid. static absl::Status ValidateConfig(const Configuration& configuration); + // Validates the Intrinsics that will subsequently be used to create an + // instance of CheckpointAggregator. + // Returns INVALID_ARGUMENT if the configuration is invalid. + static absl::Status ValidateIntrinsics( + const std::vector<Intrinsic>& intrinsics); + // Creates an instance of CheckpointAggregator. static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Create( const Configuration& configuration); + // Creates an instance of CheckpointAggregator. + // The `intrinsics` are expected to be created using `ParseFromConfig` which + // validates the configuration. + static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Create( + std::vector<Intrinsic> intrinsics); + // Accumulates a checkpoint via nested tensor aggregators. The tensors are // provided by the CheckpointParser instance. absl::Status Accumulate(CheckpointParser& checkpoint_parser); diff --git a/fcp/aggregation/protocol/checkpoint_aggregator_test.cc b/fcp/aggregation/protocol/checkpoint_aggregator_test.cc index b9c33c3..9ced962 100644 --- a/fcp/aggregation/protocol/checkpoint_aggregator_test.cc +++ b/fcp/aggregation/protocol/checkpoint_aggregator_test.cc @@ -21,6 +21,7 @@ #include <functional> #include <memory> #include <utility> +#include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -32,10 +33,12 @@ #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/intrinsic.h" #include "fcp/aggregation/core/tensor.h" +#include "fcp/aggregation/core/tensor.pb.h" #include "fcp/aggregation/core/tensor_aggregator.h" #include "fcp/aggregation/core/tensor_aggregator_factory.h" #include "fcp/aggregation/core/tensor_aggregator_registry.h" #include "fcp/aggregation/core/tensor_shape.h" +#include "fcp/aggregation/core/tensor_spec.h" #include "fcp/aggregation/protocol/configuration.pb.h" #include "fcp/aggregation/testing/mocks.h" #include "fcp/aggregation/testing/test_data.h" @@ -124,6 +127,113 @@ std::unique_ptr<CheckpointAggregator> CreateWithDefaultFedSqlConfig() { return Create(default_fedsql_configuration()); } +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsSuccess) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"federated_sum", + {TensorSpec("foo", DT_INT32, {})}, + {TensorSpec("foo_out", DT_INT32, {})}, + {}, + {}}); + EXPECT_OK(CheckpointAggregator::Create(std::move(intrinsics))); +} + +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedNumberOfInputs) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back( + {"federated_sum", + {TensorSpec("foo", DT_INT32, {}), TensorSpec("bar", DT_INT32, {})}, + {TensorSpec("foo_out", DT_INT32, {})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedNumberOfOutputs) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"federated_sum", + {TensorSpec("foo", DT_INT32, {})}, + {TensorSpec("foo_out", DT_INT32, {}), + TensorSpec("bar_out", DT_INT32, {})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedInputType) { + std::vector<Intrinsic> intrinsics; + Tensor parameter = + Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({42})).value(); + Intrinsic intrinsic{"federated_sum", + {TensorSpec("foo", DT_INT32, {})}, + {TensorSpec("foo_out", DT_INT32, {})}, + {}, + {}}; + intrinsic.parameters.push_back(std::move(parameter)); + intrinsics.push_back(std::move(intrinsic)); + Configuration config_message = PARSE_TEXT_PROTO(R"pb( + intrinsic_configs { + intrinsic_uri: "federated_sum" + intrinsic_args { parameter {} } + output_tensors { + name: "foo_out" + dtype: DT_INT32 + shape {} + } + } + )pb"); + + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedIntrinsicUri) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"unsupported_xyz", + {TensorSpec("foo", DT_INT32, {})}, + {TensorSpec("foo_out", DT_INT32, {})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(NOT_FOUND)); +} + +TEST(CheckpointAggregatorTest, CreateFromIntrinsicsUnsupportedInputSpec) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"federated_sum", + {TensorSpec("foo", DT_INT32, {-1})}, + {TensorSpec("foo_out", DT_INT32, {})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + +TEST(CheckpointAggregatorTest, + CreateFromIntrinsicsMismatchingInputAndOutputDataType) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"federated_sum", + {TensorSpec("foo", DT_INT32, {})}, + {TensorSpec("foo_out", DT_FLOAT, {})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + +TEST(CheckpointAggregatorTest, + CreateFromIntrinsicsMismatchingInputAndOutputShape) { + std::vector<Intrinsic> intrinsics; + intrinsics.push_back({"federated_sum", + {TensorSpec("foo", DT_INT32, {1})}, + {TensorSpec("foo_out", DT_INT32, {2})}, + {}, + {}}); + EXPECT_THAT(CheckpointAggregator::Create(std::move(intrinsics)), + IsCode(INVALID_ARGUMENT)); +} + TEST(CheckpointAggregatorTest, CreateSuccess) { EXPECT_OK(CheckpointAggregator::Create(default_configuration())); } |