diff options
author | Maya Spivak <mspivak@google.com> | 2024-04-15 11:30:25 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2024-04-15 11:32:16 -0700 |
commit | 7935e1833e9646b2903a32a5afea364403255a2f (patch) | |
tree | 13203e02df47ad21f55106fb05abc88cb2e9fcad | |
parent | 2cbdea4ed4f4903761395697ea20f8f27676e1fb (diff) | |
download | federated-compute-7935e1833e9646b2903a32a5afea364403255a2f.tar.gz |
Implement Serialize/Deserialize for GroupingFederatedSum and DPGroupingFederatedSum aggregators.
PiperOrigin-RevId: 625027133
-rw-r--r-- | fcp/aggregation/core/BUILD | 1 | ||||
-rw-r--r-- | fcp/aggregation/core/agg_core.proto | 6 | ||||
-rw-r--r-- | fcp/aggregation/core/dp_grouping_federated_sum.cc | 58 | ||||
-rw-r--r-- | fcp/aggregation/core/dp_grouping_federated_sum_test.cc | 67 | ||||
-rw-r--r-- | fcp/aggregation/core/grouping_federated_sum.cc | 47 | ||||
-rw-r--r-- | fcp/aggregation/core/grouping_federated_sum_test.cc | 129 | ||||
-rw-r--r-- | fcp/aggregation/core/one_dim_grouping_aggregator.h | 64 | ||||
-rw-r--r-- | fcp/aggregation/core/one_dim_grouping_aggregator_test.cc | 151 |
8 files changed, 454 insertions, 69 deletions
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD index 6ad82bf..0ea03d4 100644 --- a/fcp/aggregation/core/BUILD +++ b/fcp/aggregation/core/BUILD @@ -360,6 +360,7 @@ cc_test( name = "one_dim_grouping_aggregator_test", srcs = ["one_dim_grouping_aggregator_test.cc"], deps = [ + ":agg_core_cc_proto", ":aggregation_cores", ":tensor", "//fcp/aggregation/testing", diff --git a/fcp/aggregation/core/agg_core.proto b/fcp/aggregation/core/agg_core.proto index 6efbed5..83fe60f 100644 --- a/fcp/aggregation/core/agg_core.proto +++ b/fcp/aggregation/core/agg_core.proto @@ -14,3 +14,9 @@ message FederatedMeanAggregatorState { bytes weights_sum = 2; bytes weighted_values_sum = 3; } + +// Internal state representation of a OneDimGroupingAggregator. +message OneDimGroupingAggregatorState { + uint64 num_inputs = 1; + bytes vector_data = 2; +} diff --git a/fcp/aggregation/core/dp_grouping_federated_sum.cc b/fcp/aggregation/core/dp_grouping_federated_sum.cc index c1bc8bc..98c02d2 100644 --- a/fcp/aggregation/core/dp_grouping_federated_sum.cc +++ b/fcp/aggregation/core/dp_grouping_federated_sum.cc @@ -16,17 +16,19 @@ #include <cmath> #include <cstdint> #include <memory> +#include <string> #include <vector> #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/intrinsic.h" +#include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/one_dim_grouping_aggregator.h" #include "fcp/aggregation/core/tensor.pb.h" #include "fcp/aggregation/core/tensor_aggregator.h" -#include "fcp/aggregation/core/tensor_aggregator_factory.h" #include "fcp/aggregation/core/tensor_aggregator_registry.h" #include "fcp/aggregation/core/tensor_spec.h" #include "fcp/base/monitoring.h" @@ -58,6 +60,15 @@ class DPGroupingFederatedSum final l1_bound_(l1_bound), l2_bound_(l2_bound) {} + DPGroupingFederatedSum(InputT linfinity_bound, double l1_bound, + double l2_bound, + std::unique_ptr<MutableVectorData<OutputT>> data, + int num_inputs) + : OneDimGroupingAggregator<InputT, OutputT>(std::move(data), num_inputs), + linfinity_bound_(linfinity_bound), + l1_bound_(l1_bound), + l2_bound_(l2_bound) {} + private: // The following method clamps the input value to the linfinity bound. inline InputT Clamp(const InputT& input_value) { @@ -151,9 +162,9 @@ class DPGroupingFederatedSum final inline void AggregateValue(int64_t i, OutputT value) { data()[i] += value; } OutputT GetDefaultValue() override { return OutputT{0}; } - InputT linfinity_bound_; - double l1_bound_; - double l2_bound_; + const InputT linfinity_bound_; + const double l1_bound_; + const double l2_bound_; }; // The following function creates a DPGFS object with a numerical input type. @@ -161,7 +172,8 @@ class DPGroupingFederatedSum final // When the input type is floating point, the output type is always double. template <typename InputT> StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum( - InputT linfinity_bound, double l1_bound, double l2_bound) { + InputT linfinity_bound, double l1_bound, double l2_bound, + const OneDimGroupingAggregatorState* aggregator_state) { if (internal::TypeTraits<InputT>::type_kind != internal::TypeKind::kNumeric) { return FCP_STATUS(INVALID_ARGUMENT) << "DPGroupingFederatedSum only supports numeric datatypes."; @@ -176,14 +188,24 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum( switch (input_type) { case DT_INT32: case DT_INT64: - return std::unique_ptr<TensorAggregator>( - new DPGroupingFederatedSum<InputT, int64_t>(linfinity_bound, l1_bound, - l2_bound)); + return aggregator_state == nullptr + ? std::make_unique<DPGroupingFederatedSum<InputT, int64_t>>( + linfinity_bound, l1_bound, l2_bound) + : std::make_unique<DPGroupingFederatedSum<InputT, int64_t>>( + linfinity_bound, l1_bound, l2_bound, + MutableVectorData<int64_t>::CreateFromEncodedContent( + aggregator_state->vector_data()), + aggregator_state->num_inputs()); case DT_FLOAT: case DT_DOUBLE: - return std::unique_ptr<TensorAggregator>( - new DPGroupingFederatedSum<InputT, double>(linfinity_bound, l1_bound, - l2_bound)); + return aggregator_state == nullptr + ? std::make_unique<DPGroupingFederatedSum<InputT, double>>( + linfinity_bound, l1_bound, l2_bound) + : std::make_unique<DPGroupingFederatedSum<InputT, double>>( + linfinity_bound, l1_bound, l2_bound, + MutableVectorData<double>::CreateFromEncodedContent( + aggregator_state->vector_data()), + aggregator_state->num_inputs()); default: return FCP_STATUS(INVALID_ARGUMENT) << "DPGroupingFederatedSumFactory does not support " @@ -193,7 +215,8 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum( template <> StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum( - string_view linfinity_bound, double l1_bound, double l2_bound) { + string_view linfinity_bound, double l1_bound, double l2_bound, + const OneDimGroupingAggregatorState* aggregator_state) { return FCP_STATUS(INVALID_ARGUMENT) << "DPGroupingFederatedSum does not support DT_STRING."; } @@ -201,7 +224,8 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum( // A factory class for the GroupingFederatedSum. // Permits parameters in the DPGroupingFederatedSum intrinsic, // unlike GroupingFederatedSumFactory. -class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory { +class DPGroupingFederatedSumFactory final + : public OneDimBaseGroupingAggregatorFactory { public: DPGroupingFederatedSumFactory() = default; @@ -210,8 +234,10 @@ class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory { DPGroupingFederatedSumFactory& operator=( const DPGroupingFederatedSumFactory&) = delete; - StatusOr<std::unique_ptr<TensorAggregator>> Create( - const Intrinsic& intrinsic) const override { + private: + StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal( + const Intrinsic& intrinsic, + const OneDimGroupingAggregatorState* aggregator_state) const override { FCP_CHECK(kGoogleSqlDPSumUri == intrinsic.uri) << "DPGroupingFederatedSumFactory: Expected intrinsic URI " << kGoogleSqlDPSumUri << " but got uri " << intrinsic.uri; @@ -296,7 +322,7 @@ class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory { StatusOr<std::unique_ptr<TensorAggregator>> aggregator; DTYPE_CASES(input_type, T, aggregator = CreateDPGroupingFederatedSum<T>( - linfinity_param.AsScalar<T>(), l1, l2)); + linfinity_param.AsScalar<T>(), l1, l2, aggregator_state)); return aggregator; } }; diff --git a/fcp/aggregation/core/dp_grouping_federated_sum_test.cc b/fcp/aggregation/core/dp_grouping_federated_sum_test.cc index 7f0960d..9c2e103 100644 --- a/fcp/aggregation/core/dp_grouping_federated_sum_test.cc +++ b/fcp/aggregation/core/dp_grouping_federated_sum_test.cc @@ -87,6 +87,9 @@ namespace { using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsTrue; +using testing::TestWithParam; + +using DPGroupingFederatedSumTest = TestWithParam<bool>; TensorSpec CreateTensorSpec(std::string name, DataType dtype) { return TensorSpec(name, dtype, {-1}); @@ -511,7 +514,7 @@ TEST_F(ContributionBoundingTester, AllBoundingSucceeds) { } // Test merge w/ scalar input (duplicated from grouping_federated_sum_test.cc). -TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) { +TEST_P(DPGroupingFederatedSumTest, ScalarMergeSucceeds) { auto aggregator1 = CreateTensorAggregator(CreateDefaultIntrinsic()).value(); auto aggregator2 = CreateTensorAggregator(CreateDefaultIntrinsic()).value(); Tensor ordinal = @@ -523,6 +526,21 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(CreateDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(CreateDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(CreateDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -544,7 +562,7 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) { } // Test merge w/ scalar input ignores norm bounding. -TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) { +TEST_P(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) { Intrinsic intrinsic_with_norm_bounding = Intrinsic{"GoogleSQL:dp_sum", {CreateTensorSpec("value", DT_INT32)}, @@ -564,6 +582,23 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(intrinsic_with_norm_bounding.uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = + factory->FromProto(intrinsic_with_norm_bounding, state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = + factory->FromProto(intrinsic_with_norm_bounding, state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -585,7 +620,7 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) { } // Test merge w/ vector input -TEST(DPGroupingFederatedSumTest, VectorMergeSucceeds) { +TEST_P(DPGroupingFederatedSumTest, VectorMergeSucceeds) { auto aggregator1 = CreateTensorAggregator(CreateDefaultIntrinsic()).value(); Tensor alice_ordinal = Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({0, 1, 2, 1})) @@ -612,6 +647,21 @@ TEST(DPGroupingFederatedSumTest, VectorMergeSucceeds) { // After accumulating Cindy's data: [5, -5, 0, 11] EXPECT_THAT(aggregator2->Accumulate({&cindy_ordinal, &cindy_values}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(CreateDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(CreateDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(CreateDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -751,6 +801,17 @@ TEST(DPGroupingFederatedSumTest, CatchUnsupportedNumericType) { EXPECT_THAT(s.message(), HasSubstr("does not support")); } +TEST(DPGroupingFederatedSumTest, Deserialize_Unimplemented) { + Status s = DeserializeTensorAggregator(CreateDefaultIntrinsic(), "").status(); + EXPECT_THAT(s, IsCode(UNIMPLEMENTED)); +} + +INSTANTIATE_TEST_SUITE_P( + DPGroupingFederatedSumTestInstantiation, DPGroupingFederatedSumTest, + testing::ValuesIn<bool>({false, true}), + [](const testing::TestParamInfo<DPGroupingFederatedSumTest::ParamType>& + info) { return info.param ? "SerializeDeserialize" : "None"; }); + } // namespace } // namespace aggregation } // namespace fcp diff --git a/fcp/aggregation/core/grouping_federated_sum.cc b/fcp/aggregation/core/grouping_federated_sum.cc index a9a4731..e565b79 100644 --- a/fcp/aggregation/core/grouping_federated_sum.cc +++ b/fcp/aggregation/core/grouping_federated_sum.cc @@ -18,9 +18,11 @@ #include <memory> #include <string> +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/intrinsic.h" +#include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/one_dim_grouping_aggregator.h" #include "fcp/aggregation/core/tensor.pb.h" #include "fcp/aggregation/core/tensor_aggregator.h" @@ -89,26 +91,33 @@ class GroupingFederatedSum final }; template <typename OutputT> -StatusOr<std::unique_ptr<TensorAggregator>> CreateGroupingFederatedSum() { +StatusOr<std::unique_ptr<TensorAggregator>> CreateGroupingFederatedSum( + const OneDimGroupingAggregatorState* aggregator_state) { if (internal::TypeTraits<OutputT>::type_kind != internal::TypeKind::kNumeric) { // Ensure the type is numeric in case new non-numeric types are added. return FCP_STATUS(INVALID_ARGUMENT) << "GroupingFederatedSum is only supported for numeric datatypes."; } - return std::unique_ptr<TensorAggregator>( - new GroupingFederatedSum<OutputT, OutputT>()); + return aggregator_state == nullptr + ? std::make_unique<GroupingFederatedSum<OutputT, OutputT>>() + : std::make_unique<GroupingFederatedSum<OutputT, OutputT>>( + MutableVectorData<OutputT>::CreateFromEncodedContent( + aggregator_state->vector_data()), + aggregator_state->num_inputs()); } template <> StatusOr<std::unique_ptr<TensorAggregator>> -CreateGroupingFederatedSum<string_view>() { +CreateGroupingFederatedSum<string_view>( + const OneDimGroupingAggregatorState* aggregator_state) { return FCP_STATUS(INVALID_ARGUMENT) << "GroupingFederatedSum isn't supported for DT_STRING datatype."; } // Factory class for the GroupingFederatedSum. -class GroupingFederatedSumFactory final : public TensorAggregatorFactory { +class GroupingFederatedSumFactory final + : public OneDimBaseGroupingAggregatorFactory { public: GroupingFederatedSumFactory() = default; @@ -117,8 +126,10 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory { GroupingFederatedSumFactory& operator=(const GroupingFederatedSumFactory&) = delete; - StatusOr<std::unique_ptr<TensorAggregator>> Create( - const Intrinsic& intrinsic) const override { + private: + StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal( + const Intrinsic& intrinsic, + const OneDimGroupingAggregatorState* aggregator_state) const override { if (kGoogleSqlSumUri != intrinsic.uri) { return FCP_STATUS(INVALID_ARGUMENT) << "GroupingFederatedSumFactory: Expected intrinsic URI " @@ -166,12 +177,20 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory { // case. if (input_spec.dtype() == DataType::DT_INT32 && output_spec.dtype() == DataType::DT_INT64) { - return std::unique_ptr<TensorAggregator>( - new GroupingFederatedSum<int32_t, int64_t>()); + return aggregator_state == nullptr + ? std::make_unique<GroupingFederatedSum<int32_t, int64_t>>() + : std::make_unique<GroupingFederatedSum<int32_t, int64_t>>( + MutableVectorData<int64_t>::CreateFromEncodedContent( + aggregator_state->vector_data()), + aggregator_state->num_inputs()); } else if (input_spec.dtype() == DataType::DT_FLOAT && output_spec.dtype() == DataType::DT_DOUBLE) { - return std::unique_ptr<TensorAggregator>( - new GroupingFederatedSum<float, double>()); + return aggregator_state == nullptr + ? std::make_unique<GroupingFederatedSum<float, double>>() + : std::make_unique<GroupingFederatedSum<float, double>>( + MutableVectorData<double>::CreateFromEncodedContent( + aggregator_state->vector_data()), + aggregator_state->num_inputs()); } else { return FCP_STATUS(INVALID_ARGUMENT) << "GroupingFederatedSumFactory: Input and output tensors have " @@ -181,9 +200,11 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory { << DataType_Name(output_spec.dtype()); } } + StatusOr<std::unique_ptr<TensorAggregator>> aggregator; - DTYPE_CASES(output_spec.dtype(), OutputT, - aggregator = CreateGroupingFederatedSum<OutputT>()); + DTYPE_CASES( + output_spec.dtype(), OutputT, + aggregator = CreateGroupingFederatedSum<OutputT>(aggregator_state)); return aggregator; } }; diff --git a/fcp/aggregation/core/grouping_federated_sum_test.cc b/fcp/aggregation/core/grouping_federated_sum_test.cc index df6b9bb..7c777f1 100644 --- a/fcp/aggregation/core/grouping_federated_sum_test.cc +++ b/fcp/aggregation/core/grouping_federated_sum_test.cc @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include <cstdint> #include <memory> +#include <string> #include <utility> #include "gmock/gmock.h" @@ -38,6 +40,9 @@ namespace { using ::testing::Eq; using testing::HasSubstr; using ::testing::IsTrue; +using testing::TestWithParam; + +using GroupingFederatedSumTest = TestWithParam<bool>; Intrinsic GetDefaultIntrinsic() { // One "GoogleSQL:sum" intrinsic with a single int32 tensor of unknown size. @@ -48,7 +53,7 @@ Intrinsic GetDefaultIntrinsic() { {}}; } -TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { +TEST_P(GroupingFederatedSumTest, ScalarAggregationSucceeds) { auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value(); @@ -57,6 +62,17 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinal, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinal, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinal, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); @@ -67,7 +83,7 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6})); } -TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { +TEST_P(GroupingFederatedSumTest, DenseAggregationSucceeds) { TensorShape shape{4}; auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinals = @@ -81,6 +97,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -94,15 +121,14 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { +TEST_P(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { TensorShape shape{4}; - auto aggregator = - CreateTensorAggregator(Intrinsic{"GoogleSQL:sum", - {TensorSpec{"foo", DT_INT32, {-1}}}, - {TensorSpec{"foo_out", DT_INT64, {-1}}}, - {}, - {}}) - .value(); + Intrinsic intrinsic{"GoogleSQL:sum", + {TensorSpec{"foo", DT_INT32, {-1}}}, + {TensorSpec{"foo_out", DT_INT64, {-1}}}, + {}, + {}}; + auto aggregator = CreateTensorAggregator(intrinsic).value(); Tensor ordinals = Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3})) .value(); @@ -114,6 +140,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(intrinsic, state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -127,15 +164,15 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { +TEST_P(GroupingFederatedSumTest, + DenseAggregationCastToLargerFloatTypeSucceeds) { TensorShape shape{4}; - auto aggregator = - CreateTensorAggregator(Intrinsic{"GoogleSQL:sum", - {TensorSpec{"foo", DT_FLOAT, {-1}}}, - {TensorSpec{"foo_out", DT_DOUBLE, {-1}}}, - {}, - {}}) - .value(); + Intrinsic intrinsic{"GoogleSQL:sum", + {TensorSpec{"foo", DT_FLOAT, {-1}}}, + {TensorSpec{"foo_out", DT_DOUBLE, {-1}}}, + {}, + {}}; + auto aggregator = CreateTensorAggregator(intrinsic).value(); Tensor ordinals = Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3})) .value(); @@ -150,6 +187,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { .value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(intrinsic, state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -163,7 +211,7 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, MergeSucceeds) { +TEST_P(GroupingFederatedSumTest, MergeSucceeds) { auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = @@ -175,6 +223,21 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -194,7 +257,7 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) { EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6})); } -TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { +TEST_P(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = @@ -206,6 +269,21 @@ TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -343,6 +421,17 @@ TEST(GroupingFederatedSumTest, CreateUnsupportedStringDataType) { HasSubstr("GroupingFederatedSum isn't supported for DT_STRING datatype")); } +TEST(GroupingFederatedSumTest, Deserialize_Unimplemented) { + Status s = DeserializeTensorAggregator(GetDefaultIntrinsic(), "").status(); + EXPECT_THAT(s, IsCode(UNIMPLEMENTED)); +} + +INSTANTIATE_TEST_SUITE_P( + GroupingFederatedSumTestInstantiation, GroupingFederatedSumTest, + testing::ValuesIn<bool>({false, true}), + [](const testing::TestParamInfo<GroupingFederatedSumTest::ParamType>& + info) { return info.param ? "SaveIntermediateState" : "None"; }); + } // namespace } // namespace aggregation } // namespace fcp diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator.h b/fcp/aggregation/core/one_dim_grouping_aggregator.h index 926f4f4..d7fe498 100644 --- a/fcp/aggregation/core/one_dim_grouping_aggregator.h +++ b/fcp/aggregation/core/one_dim_grouping_aggregator.h @@ -20,15 +20,19 @@ #include <cstddef> #include <cstdint> #include <memory> +#include <string> #include <utility> #include <vector> +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/input_tensor_list.h" +#include "fcp/aggregation/core/intrinsic.h" #include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/tensor.h" #include "fcp/aggregation/core/tensor_aggregator.h" +#include "fcp/aggregation/core/tensor_aggregator_factory.h" #include "fcp/aggregation/core/tensor_shape.h" #include "fcp/base/monitoring.h" @@ -50,6 +54,15 @@ class OneDimBaseGroupingAggregator : public TensorAggregator { public: Status MergeWith(TensorAggregator&& other) override; + StatusOr<std::string> Serialize() && override { + // OneDimBaseGroupingAggregators are always nested within an outer + // aggregator. Use ToProto to get intermediate state and then serialize the + // outer aggregator state instead. + return FCP_STATUS(UNIMPLEMENTED) + << "OneDimBaseGroupingAggregator::Serialize is not supported. Use " + "ToProto to store intermediate state."; + } + // Merges intermediate aggregates contained in the tensors param into the // current Aggregator instance. Expects a tensors param of size 2, where the // first tensor contains ordinals and the second tensor contains values. The @@ -66,11 +79,47 @@ class OneDimBaseGroupingAggregator : public TensorAggregator { // derived class. virtual Status MergeTensors(InputTensorList tensors, int num_inputs) = 0; + // Stores the intermediate state of the OneDimBaseGroupingAggregator as a + // proto. + virtual OneDimGroupingAggregatorState ToProto() = 0; + protected: // Checks that the input tensors param is valid. Status ValidateTensorInputs(const InputTensorList& tensors); }; +class OneDimBaseGroupingAggregatorFactory : public TensorAggregatorFactory { + public: + StatusOr<std::unique_ptr<TensorAggregator>> Create( + const Intrinsic& intrinsic) const override { + return CreateInternal(intrinsic, nullptr); + } + + StatusOr<std::unique_ptr<TensorAggregator>> Deserialize( + const Intrinsic& intrinsic, std::string serialized_state) const override { + OneDimGroupingAggregatorState aggregator_state; + // OneDimGroupingAggregators are always nested within an outer aggregator. + // Use FromProto to create the aggregator from intermediate state stored by + // the outer aggregator. + return FCP_STATUS(UNIMPLEMENTED) + << "OneDimBaseGroupingAggregatorFactory::Deserialize is not " + "supported. Use FromProto to create an aggregator from " + "intermediate state."; + } + + // Creates a OneDimBaseGroupingAggregator from intermediate state. + StatusOr<std::unique_ptr<TensorAggregator>> FromProto( + const Intrinsic& intrinsic, + const OneDimGroupingAggregatorState& aggregator_state) const { + return CreateInternal(intrinsic, &aggregator_state); + } + + private: + virtual StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal( + const Intrinsic& intrinsic, + const OneDimGroupingAggregatorState* aggregator_state) const = 0; +}; + // OneDimGroupingAggregator class is a specialization of // OneDimBaseGroupingAggregator. // @@ -88,8 +137,12 @@ class OneDimGroupingAggregator : public OneDimBaseGroupingAggregator { // to the ordinal tensor) should be known in advance and thus this constructor // should take in a shape with a single unknown dimension. OneDimGroupingAggregator() - : data_vector_(std::make_unique<MutableVectorData<OutputT>>()), - num_inputs_(0) {} + : OneDimGroupingAggregator(std::make_unique<MutableVectorData<OutputT>>(), + 0) {} + + OneDimGroupingAggregator(std::unique_ptr<MutableVectorData<OutputT>> data, + int num_inputs) + : data_vector_(std::move(data)), num_inputs_(num_inputs) {} // Implementation of the tensor merge operation. Status MergeTensors(InputTensorList tensors, int num_inputs) override { @@ -121,6 +174,13 @@ class OneDimGroupingAggregator : public OneDimBaseGroupingAggregator { return FCP_STATUS(OK); } + OneDimGroupingAggregatorState ToProto() override { + OneDimGroupingAggregatorState aggregator_state; + aggregator_state.set_num_inputs(num_inputs_); + *(aggregator_state.mutable_vector_data()) = data_vector_->EncodeContent(); + return aggregator_state; + } + protected: // Provides mutable access to the aggregator data as a vector<T> inline std::vector<OutputT>& data() { return *data_vector_; } diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc index 6095d88..fc44cce 100644 --- a/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc +++ b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc @@ -17,11 +17,14 @@ #include <climits> #include <cstdint> +#include <memory> #include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" +#include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/tensor.h" #include "fcp/aggregation/core/tensor_shape.h" #include "fcp/aggregation/testing/test_data.h" @@ -37,6 +40,9 @@ using testing::Eq; using testing::HasSubstr; using testing::IsFalse; using testing::IsTrue; +using testing::TestWithParam; + +using OneDimGroupingAggregatorTest = TestWithParam<bool>; // A simple Sum Aggregator template <typename InputT, typename OutputT = InputT> @@ -46,6 +52,14 @@ class SumGroupingAggregator final using OneDimGroupingAggregator<InputT, OutputT>::OneDimGroupingAggregator; using OneDimGroupingAggregator<InputT, OutputT>::data; + static SumGroupingAggregator<InputT, OutputT> FromProto( + const OneDimGroupingAggregatorState& aggregator_state) { + return SumGroupingAggregator<InputT, OutputT>( + MutableVectorData<OutputT>::CreateFromEncodedContent( + aggregator_state.vector_data()), + aggregator_state.num_inputs()); + } + private: void AggregateVectorByOrdinals( const AggVector<int64_t>& ordinals_vector, @@ -94,6 +108,14 @@ class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> { using OneDimGroupingAggregator<int32_t>::OneDimGroupingAggregator; using OneDimGroupingAggregator<int32_t>::data; + static MinGroupingAggregator FromProto( + const OneDimGroupingAggregatorState& aggregator_state) { + return MinGroupingAggregator( + MutableVectorData<int32_t>::CreateFromEncodedContent( + aggregator_state.vector_data()), + aggregator_state.num_inputs()); + } + private: void AggregateVectorByOrdinals( const AggVector<int64_t>& ordinals_vector, @@ -132,15 +154,21 @@ class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> { int32_t GetDefaultValue() override { return INT_MAX; } }; -TEST(OneDimGroupingAggregatorTest, EmptyReport) { +TEST_P(OneDimGroupingAggregatorTest, EmptyReport) { SumGroupingAggregator<int32_t> aggregator; + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = SumGroupingAggregator<int32_t>::FromProto(state); + } + auto result = std::move(aggregator).Report(); EXPECT_THAT(result, IsOk()); EXPECT_THAT(result->size(), Eq(1)); EXPECT_THAT(result.value()[0], IsTensor<int32_t>({0}, {})); } -TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) { SumGroupingAggregator<int32_t> aggregator; Tensor ordinal = Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value(); @@ -149,6 +177,12 @@ TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) { Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value(); EXPECT_THAT(aggregator.Accumulate({&ordinal, &t1}), IsOk()); EXPECT_THAT(aggregator.Accumulate({&ordinal, &t2}), IsOk()); + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = SumGroupingAggregator<int32_t>::FromProto(state); + } + EXPECT_THAT(aggregator.Accumulate({&ordinal, &t3}), IsOk()); EXPECT_THAT(aggregator.CanReport(), IsTrue()); @@ -159,7 +193,7 @@ TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor({1}, {6})); } -TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) { const TensorShape shape = {4}; SumGroupingAggregator<int32_t> aggregator; Tensor ordinals = @@ -173,6 +207,12 @@ TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) { Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value(); EXPECT_THAT(aggregator.Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator.Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = SumGroupingAggregator<int32_t>::FromProto(state); + } + EXPECT_THAT(aggregator.Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator.CanReport(), IsTrue()); EXPECT_THAT(aggregator.GetNumInputs(), Eq(3)); @@ -186,7 +226,7 @@ TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) { const TensorShape shape = {4}; SumGroupingAggregator<int32_t> aggregator; Tensor t1_ordinals = @@ -202,6 +242,12 @@ TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) { Tensor t2 = Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value(); EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = SumGroupingAggregator<int32_t>::FromProto(state); + } + // Totals: [32, 11, 15, 4, 2] Tensor t3_ordinals = Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({2, 2, 5, 1})) @@ -222,7 +268,7 @@ TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) { SumGroupingAggregator<int32_t> aggregator; Tensor t1_ordinals = Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value(); @@ -236,6 +282,12 @@ TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) { Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, 5})) .value(); EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = SumGroupingAggregator<int32_t>::FromProto(state); + } + // Totals: [13, 23, 17, 4, 2] Tensor t3_ordinals = Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4})) @@ -256,8 +308,8 @@ TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, - DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, + DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) { // Use a MinGroupingAggregator which has a non-zero default value so we can // test that when the output grows, elements are set to the default value. MinGroupingAggregator aggregator; @@ -273,6 +325,12 @@ TEST(OneDimGroupingAggregatorTest, Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50})) .value(); EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto state = std::move(aggregator).ToProto(); + aggregator = MinGroupingAggregator::FromProto(state); + } + // Totals: [-50, INT_MAX, 17, INT_MAX, 2] Tensor t3_ordinals = Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4})) @@ -293,7 +351,7 @@ TEST(OneDimGroupingAggregatorTest, EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, Merge_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; Tensor ordinal = @@ -305,6 +363,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) { EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + int aggregator2_num_inputs = aggregator2.GetNumInputs(); EXPECT_THAT(aggregator2_num_inputs, Eq(2)); auto aggregator2_result = std::move(aggregator2).Report(); @@ -324,10 +389,17 @@ TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor({1}, {6})); } -TEST(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + // Merge the two empty aggregators together. auto empty_ordinals = Tensor::Create(DT_INT64, {0}, CreateTestData<int64_t>({})).value(); @@ -346,7 +418,7 @@ TEST(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) { EXPECT_THAT(result.value()[0], IsTensor<int32_t>({0}, {})); } -TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; @@ -365,6 +437,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) { EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk()); // aggregator2 totals: [32, 11, 15, 4, 2] + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + // Merge aggregator2 into aggregator1 which has not received any inputs. int aggregator2_num_inputs = aggregator2.GetNumInputs(); auto aggregator2_result = @@ -388,7 +467,7 @@ TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; @@ -407,6 +486,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) { EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk()); // aggregator1 totals: [32, 11, 15, 4, 2] + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + // Merge with aggregator2 which has not received any inputs. auto empty_ordinals = Tensor::Create(DT_INT64, {0}, CreateTestData<int64_t>({})).value(); @@ -428,7 +514,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, + Merge_OtherOutputHasFewerElements_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; @@ -453,6 +540,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) { EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk()); // aggregator2 totals: [0, 0, 14] + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + int aggregator2_num_inputs = aggregator2.GetNumInputs(); auto aggregator2_result = std::move(std::move(aggregator2).Report().value()[0]); @@ -474,7 +568,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, + Merge_OtherOutputHasMoreElements_Succeeds) { SumGroupingAggregator<int32_t> aggregator1; SumGroupingAggregator<int32_t> aggregator2; @@ -501,6 +596,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) { EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk()); // aggregator2 totals: [0, 20, 14, 0, 0, 7] + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2); + } + int aggregator2_num_inputs = aggregator2.GetNumInputs(); auto aggregator2_result = std::move(std::move(aggregator2).Report().value()[0]); @@ -523,8 +625,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(OneDimGroupingAggregatorTest, - Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) { +TEST_P(OneDimGroupingAggregatorTest, + Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) { // Use a MinGroupingAggregator which has a non-zero default value so we can // test that when the output grows, elements are set to the default value. MinGroupingAggregator aggregator1; @@ -551,6 +653,13 @@ TEST(OneDimGroupingAggregatorTest, EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk()); // aggregator2 totals: [-50, 7, 11, INT_MAX, 2] + if (GetParam()) { + auto state1 = std::move(aggregator1).ToProto(); + aggregator1 = MinGroupingAggregator::FromProto(state1); + auto state2 = std::move(aggregator2).ToProto(); + aggregator2 = MinGroupingAggregator::FromProto(state2); + } + int aggregator2_num_inputs = aggregator2.GetNumInputs(); auto aggregator2_result = std::move(std::move(aggregator2).Report().value()[0]); @@ -657,6 +766,18 @@ TEST(OneDimGroupingAggregatorTest, FailsAfterBeingConsumed) { IsCode(FAILED_PRECONDITION)); } +TEST(OneDimGroupingAggregatorTest, Serialize_Unimplmeneted) { + SumGroupingAggregator<int32_t> aggregator; + Status s = std::move(aggregator).Serialize().status(); + EXPECT_THAT(s, IsCode(UNIMPLEMENTED)); +} + +INSTANTIATE_TEST_SUITE_P( + OneDimGroupingAggregatorTestInstantiation, OneDimGroupingAggregatorTest, + testing::ValuesIn<bool>({false, true}), + [](const testing::TestParamInfo<OneDimGroupingAggregatorTest::ParamType>& + info) { return info.param ? "SaveIntermediateState" : "None"; }); + } // namespace } // namespace aggregation } // namespace fcp |