aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaya Spivak <mspivak@google.com>2024-04-15 11:30:25 -0700
committerCopybara-Service <copybara-worker@google.com>2024-04-15 11:32:16 -0700
commit7935e1833e9646b2903a32a5afea364403255a2f (patch)
tree13203e02df47ad21f55106fb05abc88cb2e9fcad
parent2cbdea4ed4f4903761395697ea20f8f27676e1fb (diff)
downloadfederated-compute-7935e1833e9646b2903a32a5afea364403255a2f.tar.gz
Implement Serialize/Deserialize for GroupingFederatedSum and DPGroupingFederatedSum aggregators.
PiperOrigin-RevId: 625027133
-rw-r--r--fcp/aggregation/core/BUILD1
-rw-r--r--fcp/aggregation/core/agg_core.proto6
-rw-r--r--fcp/aggregation/core/dp_grouping_federated_sum.cc58
-rw-r--r--fcp/aggregation/core/dp_grouping_federated_sum_test.cc67
-rw-r--r--fcp/aggregation/core/grouping_federated_sum.cc47
-rw-r--r--fcp/aggregation/core/grouping_federated_sum_test.cc129
-rw-r--r--fcp/aggregation/core/one_dim_grouping_aggregator.h64
-rw-r--r--fcp/aggregation/core/one_dim_grouping_aggregator_test.cc151
8 files changed, 454 insertions, 69 deletions
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD
index 6ad82bf..0ea03d4 100644
--- a/fcp/aggregation/core/BUILD
+++ b/fcp/aggregation/core/BUILD
@@ -360,6 +360,7 @@ cc_test(
name = "one_dim_grouping_aggregator_test",
srcs = ["one_dim_grouping_aggregator_test.cc"],
deps = [
+ ":agg_core_cc_proto",
":aggregation_cores",
":tensor",
"//fcp/aggregation/testing",
diff --git a/fcp/aggregation/core/agg_core.proto b/fcp/aggregation/core/agg_core.proto
index 6efbed5..83fe60f 100644
--- a/fcp/aggregation/core/agg_core.proto
+++ b/fcp/aggregation/core/agg_core.proto
@@ -14,3 +14,9 @@ message FederatedMeanAggregatorState {
bytes weights_sum = 2;
bytes weighted_values_sum = 3;
}
+
+// Internal state representation of a OneDimGroupingAggregator.
+message OneDimGroupingAggregatorState {
+ uint64 num_inputs = 1;
+ bytes vector_data = 2;
+}
diff --git a/fcp/aggregation/core/dp_grouping_federated_sum.cc b/fcp/aggregation/core/dp_grouping_federated_sum.cc
index c1bc8bc..98c02d2 100644
--- a/fcp/aggregation/core/dp_grouping_federated_sum.cc
+++ b/fcp/aggregation/core/dp_grouping_federated_sum.cc
@@ -16,17 +16,19 @@
#include <cmath>
#include <cstdint>
#include <memory>
+#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/intrinsic.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/one_dim_grouping_aggregator.h"
#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
-#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/base/monitoring.h"
@@ -58,6 +60,15 @@ class DPGroupingFederatedSum final
l1_bound_(l1_bound),
l2_bound_(l2_bound) {}
+ DPGroupingFederatedSum(InputT linfinity_bound, double l1_bound,
+ double l2_bound,
+ std::unique_ptr<MutableVectorData<OutputT>> data,
+ int num_inputs)
+ : OneDimGroupingAggregator<InputT, OutputT>(std::move(data), num_inputs),
+ linfinity_bound_(linfinity_bound),
+ l1_bound_(l1_bound),
+ l2_bound_(l2_bound) {}
+
private:
// The following method clamps the input value to the linfinity bound.
inline InputT Clamp(const InputT& input_value) {
@@ -151,9 +162,9 @@ class DPGroupingFederatedSum final
inline void AggregateValue(int64_t i, OutputT value) { data()[i] += value; }
OutputT GetDefaultValue() override { return OutputT{0}; }
- InputT linfinity_bound_;
- double l1_bound_;
- double l2_bound_;
+ const InputT linfinity_bound_;
+ const double l1_bound_;
+ const double l2_bound_;
};
// The following function creates a DPGFS object with a numerical input type.
@@ -161,7 +172,8 @@ class DPGroupingFederatedSum final
// When the input type is floating point, the output type is always double.
template <typename InputT>
StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum(
- InputT linfinity_bound, double l1_bound, double l2_bound) {
+ InputT linfinity_bound, double l1_bound, double l2_bound,
+ const OneDimGroupingAggregatorState* aggregator_state) {
if (internal::TypeTraits<InputT>::type_kind != internal::TypeKind::kNumeric) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "DPGroupingFederatedSum only supports numeric datatypes.";
@@ -176,14 +188,24 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum(
switch (input_type) {
case DT_INT32:
case DT_INT64:
- return std::unique_ptr<TensorAggregator>(
- new DPGroupingFederatedSum<InputT, int64_t>(linfinity_bound, l1_bound,
- l2_bound));
+ return aggregator_state == nullptr
+ ? std::make_unique<DPGroupingFederatedSum<InputT, int64_t>>(
+ linfinity_bound, l1_bound, l2_bound)
+ : std::make_unique<DPGroupingFederatedSum<InputT, int64_t>>(
+ linfinity_bound, l1_bound, l2_bound,
+ MutableVectorData<int64_t>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs());
case DT_FLOAT:
case DT_DOUBLE:
- return std::unique_ptr<TensorAggregator>(
- new DPGroupingFederatedSum<InputT, double>(linfinity_bound, l1_bound,
- l2_bound));
+ return aggregator_state == nullptr
+ ? std::make_unique<DPGroupingFederatedSum<InputT, double>>(
+ linfinity_bound, l1_bound, l2_bound)
+ : std::make_unique<DPGroupingFederatedSum<InputT, double>>(
+ linfinity_bound, l1_bound, l2_bound,
+ MutableVectorData<double>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs());
default:
return FCP_STATUS(INVALID_ARGUMENT)
<< "DPGroupingFederatedSumFactory does not support "
@@ -193,7 +215,8 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum(
template <>
StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum(
- string_view linfinity_bound, double l1_bound, double l2_bound) {
+ string_view linfinity_bound, double l1_bound, double l2_bound,
+ const OneDimGroupingAggregatorState* aggregator_state) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "DPGroupingFederatedSum does not support DT_STRING.";
}
@@ -201,7 +224,8 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateDPGroupingFederatedSum(
// A factory class for the GroupingFederatedSum.
// Permits parameters in the DPGroupingFederatedSum intrinsic,
// unlike GroupingFederatedSumFactory.
-class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory {
+class DPGroupingFederatedSumFactory final
+ : public OneDimBaseGroupingAggregatorFactory {
public:
DPGroupingFederatedSumFactory() = default;
@@ -210,8 +234,10 @@ class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory {
DPGroupingFederatedSumFactory& operator=(
const DPGroupingFederatedSumFactory&) = delete;
- StatusOr<std::unique_ptr<TensorAggregator>> Create(
- const Intrinsic& intrinsic) const override {
+ private:
+ StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const OneDimGroupingAggregatorState* aggregator_state) const override {
FCP_CHECK(kGoogleSqlDPSumUri == intrinsic.uri)
<< "DPGroupingFederatedSumFactory: Expected intrinsic URI "
<< kGoogleSqlDPSumUri << " but got uri " << intrinsic.uri;
@@ -296,7 +322,7 @@ class DPGroupingFederatedSumFactory final : public TensorAggregatorFactory {
StatusOr<std::unique_ptr<TensorAggregator>> aggregator;
DTYPE_CASES(input_type, T,
aggregator = CreateDPGroupingFederatedSum<T>(
- linfinity_param.AsScalar<T>(), l1, l2));
+ linfinity_param.AsScalar<T>(), l1, l2, aggregator_state));
return aggregator;
}
};
diff --git a/fcp/aggregation/core/dp_grouping_federated_sum_test.cc b/fcp/aggregation/core/dp_grouping_federated_sum_test.cc
index 7f0960d..9c2e103 100644
--- a/fcp/aggregation/core/dp_grouping_federated_sum_test.cc
+++ b/fcp/aggregation/core/dp_grouping_federated_sum_test.cc
@@ -87,6 +87,9 @@ namespace {
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsTrue;
+using testing::TestWithParam;
+
+using DPGroupingFederatedSumTest = TestWithParam<bool>;
TensorSpec CreateTensorSpec(std::string name, DataType dtype) {
return TensorSpec(name, dtype, {-1});
@@ -511,7 +514,7 @@ TEST_F(ContributionBoundingTester, AllBoundingSucceeds) {
}
// Test merge w/ scalar input (duplicated from grouping_federated_sum_test.cc).
-TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) {
+TEST_P(DPGroupingFederatedSumTest, ScalarMergeSucceeds) {
auto aggregator1 = CreateTensorAggregator(CreateDefaultIntrinsic()).value();
auto aggregator2 = CreateTensorAggregator(CreateDefaultIntrinsic()).value();
Tensor ordinal =
@@ -523,6 +526,21 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(CreateDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(CreateDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(CreateDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -544,7 +562,7 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeSucceeds) {
}
// Test merge w/ scalar input ignores norm bounding.
-TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) {
+TEST_P(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) {
Intrinsic intrinsic_with_norm_bounding =
Intrinsic{"GoogleSQL:dp_sum",
{CreateTensorSpec("value", DT_INT32)},
@@ -564,6 +582,23 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(intrinsic_with_norm_bounding.uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 =
+ factory->FromProto(intrinsic_with_norm_bounding, state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 =
+ factory->FromProto(intrinsic_with_norm_bounding, state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -585,7 +620,7 @@ TEST(DPGroupingFederatedSumTest, ScalarMergeIgnoresNormBounding) {
}
// Test merge w/ vector input
-TEST(DPGroupingFederatedSumTest, VectorMergeSucceeds) {
+TEST_P(DPGroupingFederatedSumTest, VectorMergeSucceeds) {
auto aggregator1 = CreateTensorAggregator(CreateDefaultIntrinsic()).value();
Tensor alice_ordinal =
Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({0, 1, 2, 1}))
@@ -612,6 +647,21 @@ TEST(DPGroupingFederatedSumTest, VectorMergeSucceeds) {
// After accumulating Cindy's data: [5, -5, 0, 11]
EXPECT_THAT(aggregator2->Accumulate({&cindy_ordinal, &cindy_values}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(CreateDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(CreateDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(CreateDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -751,6 +801,17 @@ TEST(DPGroupingFederatedSumTest, CatchUnsupportedNumericType) {
EXPECT_THAT(s.message(), HasSubstr("does not support"));
}
+TEST(DPGroupingFederatedSumTest, Deserialize_Unimplemented) {
+ Status s = DeserializeTensorAggregator(CreateDefaultIntrinsic(), "").status();
+ EXPECT_THAT(s, IsCode(UNIMPLEMENTED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ DPGroupingFederatedSumTestInstantiation, DPGroupingFederatedSumTest,
+ testing::ValuesIn<bool>({false, true}),
+ [](const testing::TestParamInfo<DPGroupingFederatedSumTest::ParamType>&
+ info) { return info.param ? "SerializeDeserialize" : "None"; });
+
} // namespace
} // namespace aggregation
} // namespace fcp
diff --git a/fcp/aggregation/core/grouping_federated_sum.cc b/fcp/aggregation/core/grouping_federated_sum.cc
index a9a4731..e565b79 100644
--- a/fcp/aggregation/core/grouping_federated_sum.cc
+++ b/fcp/aggregation/core/grouping_federated_sum.cc
@@ -18,9 +18,11 @@
#include <memory>
#include <string>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/intrinsic.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/one_dim_grouping_aggregator.h"
#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
@@ -89,26 +91,33 @@ class GroupingFederatedSum final
};
template <typename OutputT>
-StatusOr<std::unique_ptr<TensorAggregator>> CreateGroupingFederatedSum() {
+StatusOr<std::unique_ptr<TensorAggregator>> CreateGroupingFederatedSum(
+ const OneDimGroupingAggregatorState* aggregator_state) {
if (internal::TypeTraits<OutputT>::type_kind !=
internal::TypeKind::kNumeric) {
// Ensure the type is numeric in case new non-numeric types are added.
return FCP_STATUS(INVALID_ARGUMENT)
<< "GroupingFederatedSum is only supported for numeric datatypes.";
}
- return std::unique_ptr<TensorAggregator>(
- new GroupingFederatedSum<OutputT, OutputT>());
+ return aggregator_state == nullptr
+ ? std::make_unique<GroupingFederatedSum<OutputT, OutputT>>()
+ : std::make_unique<GroupingFederatedSum<OutputT, OutputT>>(
+ MutableVectorData<OutputT>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs());
}
template <>
StatusOr<std::unique_ptr<TensorAggregator>>
-CreateGroupingFederatedSum<string_view>() {
+CreateGroupingFederatedSum<string_view>(
+ const OneDimGroupingAggregatorState* aggregator_state) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "GroupingFederatedSum isn't supported for DT_STRING datatype.";
}
// Factory class for the GroupingFederatedSum.
-class GroupingFederatedSumFactory final : public TensorAggregatorFactory {
+class GroupingFederatedSumFactory final
+ : public OneDimBaseGroupingAggregatorFactory {
public:
GroupingFederatedSumFactory() = default;
@@ -117,8 +126,10 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory {
GroupingFederatedSumFactory& operator=(const GroupingFederatedSumFactory&) =
delete;
- StatusOr<std::unique_ptr<TensorAggregator>> Create(
- const Intrinsic& intrinsic) const override {
+ private:
+ StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const OneDimGroupingAggregatorState* aggregator_state) const override {
if (kGoogleSqlSumUri != intrinsic.uri) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "GroupingFederatedSumFactory: Expected intrinsic URI "
@@ -166,12 +177,20 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory {
// case.
if (input_spec.dtype() == DataType::DT_INT32 &&
output_spec.dtype() == DataType::DT_INT64) {
- return std::unique_ptr<TensorAggregator>(
- new GroupingFederatedSum<int32_t, int64_t>());
+ return aggregator_state == nullptr
+ ? std::make_unique<GroupingFederatedSum<int32_t, int64_t>>()
+ : std::make_unique<GroupingFederatedSum<int32_t, int64_t>>(
+ MutableVectorData<int64_t>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs());
} else if (input_spec.dtype() == DataType::DT_FLOAT &&
output_spec.dtype() == DataType::DT_DOUBLE) {
- return std::unique_ptr<TensorAggregator>(
- new GroupingFederatedSum<float, double>());
+ return aggregator_state == nullptr
+ ? std::make_unique<GroupingFederatedSum<float, double>>()
+ : std::make_unique<GroupingFederatedSum<float, double>>(
+ MutableVectorData<double>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs());
} else {
return FCP_STATUS(INVALID_ARGUMENT)
<< "GroupingFederatedSumFactory: Input and output tensors have "
@@ -181,9 +200,11 @@ class GroupingFederatedSumFactory final : public TensorAggregatorFactory {
<< DataType_Name(output_spec.dtype());
}
}
+
StatusOr<std::unique_ptr<TensorAggregator>> aggregator;
- DTYPE_CASES(output_spec.dtype(), OutputT,
- aggregator = CreateGroupingFederatedSum<OutputT>());
+ DTYPE_CASES(
+ output_spec.dtype(), OutputT,
+ aggregator = CreateGroupingFederatedSum<OutputT>(aggregator_state));
return aggregator;
}
};
diff --git a/fcp/aggregation/core/grouping_federated_sum_test.cc b/fcp/aggregation/core/grouping_federated_sum_test.cc
index df6b9bb..7c777f1 100644
--- a/fcp/aggregation/core/grouping_federated_sum_test.cc
+++ b/fcp/aggregation/core/grouping_federated_sum_test.cc
@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include "gmock/gmock.h"
@@ -38,6 +40,9 @@ namespace {
using ::testing::Eq;
using testing::HasSubstr;
using ::testing::IsTrue;
+using testing::TestWithParam;
+
+using GroupingFederatedSumTest = TestWithParam<bool>;
Intrinsic GetDefaultIntrinsic() {
// One "GoogleSQL:sum" intrinsic with a single int32 tensor of unknown size.
@@ -48,7 +53,7 @@ Intrinsic GetDefaultIntrinsic() {
{}};
}
-TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
+TEST_P(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
@@ -57,6 +62,17 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
@@ -67,7 +83,7 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6}));
}
-TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
+TEST_P(GroupingFederatedSumTest, DenseAggregationSucceeds) {
TensorShape shape{4};
auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinals =
@@ -81,6 +97,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -94,15 +121,14 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
+TEST_P(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
TensorShape shape{4};
- auto aggregator =
- CreateTensorAggregator(Intrinsic{"GoogleSQL:sum",
- {TensorSpec{"foo", DT_INT32, {-1}}},
- {TensorSpec{"foo_out", DT_INT64, {-1}}},
- {},
- {}})
- .value();
+ Intrinsic intrinsic{"GoogleSQL:sum",
+ {TensorSpec{"foo", DT_INT32, {-1}}},
+ {TensorSpec{"foo_out", DT_INT64, {-1}}},
+ {},
+ {}};
+ auto aggregator = CreateTensorAggregator(intrinsic).value();
Tensor ordinals =
Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
.value();
@@ -114,6 +140,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(intrinsic, state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -127,15 +164,15 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
+TEST_P(GroupingFederatedSumTest,
+ DenseAggregationCastToLargerFloatTypeSucceeds) {
TensorShape shape{4};
- auto aggregator =
- CreateTensorAggregator(Intrinsic{"GoogleSQL:sum",
- {TensorSpec{"foo", DT_FLOAT, {-1}}},
- {TensorSpec{"foo_out", DT_DOUBLE, {-1}}},
- {},
- {}})
- .value();
+ Intrinsic intrinsic{"GoogleSQL:sum",
+ {TensorSpec{"foo", DT_FLOAT, {-1}}},
+ {TensorSpec{"foo_out", DT_DOUBLE, {-1}}},
+ {},
+ {}};
+ auto aggregator = CreateTensorAggregator(intrinsic).value();
Tensor ordinals =
Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
.value();
@@ -150,6 +187,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
.value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(intrinsic, state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -163,7 +211,7 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, MergeSucceeds) {
+TEST_P(GroupingFederatedSumTest, MergeSucceeds) {
auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
@@ -175,6 +223,21 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -194,7 +257,7 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) {
EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6}));
}
-TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
+TEST_P(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
@@ -206,6 +269,21 @@ TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -343,6 +421,17 @@ TEST(GroupingFederatedSumTest, CreateUnsupportedStringDataType) {
HasSubstr("GroupingFederatedSum isn't supported for DT_STRING datatype"));
}
+TEST(GroupingFederatedSumTest, Deserialize_Unimplemented) {
+ Status s = DeserializeTensorAggregator(GetDefaultIntrinsic(), "").status();
+ EXPECT_THAT(s, IsCode(UNIMPLEMENTED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ GroupingFederatedSumTestInstantiation, GroupingFederatedSumTest,
+ testing::ValuesIn<bool>({false, true}),
+ [](const testing::TestParamInfo<GroupingFederatedSumTest::ParamType>&
+ info) { return info.param ? "SaveIntermediateState" : "None"; });
+
} // namespace
} // namespace aggregation
} // namespace fcp
diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator.h b/fcp/aggregation/core/one_dim_grouping_aggregator.h
index 926f4f4..d7fe498 100644
--- a/fcp/aggregation/core/one_dim_grouping_aggregator.h
+++ b/fcp/aggregation/core/one_dim_grouping_aggregator.h
@@ -20,15 +20,19 @@
#include <cstddef>
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/input_tensor_list.h"
+#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "fcp/base/monitoring.h"
@@ -50,6 +54,15 @@ class OneDimBaseGroupingAggregator : public TensorAggregator {
public:
Status MergeWith(TensorAggregator&& other) override;
+ StatusOr<std::string> Serialize() && override {
+ // OneDimBaseGroupingAggregators are always nested within an outer
+ // aggregator. Use ToProto to get intermediate state and then serialize the
+ // outer aggregator state instead.
+ return FCP_STATUS(UNIMPLEMENTED)
+ << "OneDimBaseGroupingAggregator::Serialize is not supported. Use "
+ "ToProto to store intermediate state.";
+ }
+
// Merges intermediate aggregates contained in the tensors param into the
// current Aggregator instance. Expects a tensors param of size 2, where the
// first tensor contains ordinals and the second tensor contains values. The
@@ -66,11 +79,47 @@ class OneDimBaseGroupingAggregator : public TensorAggregator {
// derived class.
virtual Status MergeTensors(InputTensorList tensors, int num_inputs) = 0;
+ // Stores the intermediate state of the OneDimBaseGroupingAggregator as a
+ // proto.
+ virtual OneDimGroupingAggregatorState ToProto() = 0;
+
protected:
// Checks that the input tensors param is valid.
Status ValidateTensorInputs(const InputTensorList& tensors);
};
+class OneDimBaseGroupingAggregatorFactory : public TensorAggregatorFactory {
+ public:
+ StatusOr<std::unique_ptr<TensorAggregator>> Create(
+ const Intrinsic& intrinsic) const override {
+ return CreateInternal(intrinsic, nullptr);
+ }
+
+ StatusOr<std::unique_ptr<TensorAggregator>> Deserialize(
+ const Intrinsic& intrinsic, std::string serialized_state) const override {
+ OneDimGroupingAggregatorState aggregator_state;
+ // OneDimGroupingAggregators are always nested within an outer aggregator.
+ // Use FromProto to create the aggregator from intermediate state stored by
+ // the outer aggregator.
+ return FCP_STATUS(UNIMPLEMENTED)
+ << "OneDimBaseGroupingAggregatorFactory::Deserialize is not "
+ "supported. Use FromProto to create an aggregator from "
+ "intermediate state.";
+ }
+
+ // Creates a OneDimBaseGroupingAggregator from intermediate state.
+ StatusOr<std::unique_ptr<TensorAggregator>> FromProto(
+ const Intrinsic& intrinsic,
+ const OneDimGroupingAggregatorState& aggregator_state) const {
+ return CreateInternal(intrinsic, &aggregator_state);
+ }
+
+ private:
+ virtual StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const OneDimGroupingAggregatorState* aggregator_state) const = 0;
+};
+
// OneDimGroupingAggregator class is a specialization of
// OneDimBaseGroupingAggregator.
//
@@ -88,8 +137,12 @@ class OneDimGroupingAggregator : public OneDimBaseGroupingAggregator {
// to the ordinal tensor) should be known in advance and thus this constructor
// should take in a shape with a single unknown dimension.
OneDimGroupingAggregator()
- : data_vector_(std::make_unique<MutableVectorData<OutputT>>()),
- num_inputs_(0) {}
+ : OneDimGroupingAggregator(std::make_unique<MutableVectorData<OutputT>>(),
+ 0) {}
+
+ OneDimGroupingAggregator(std::unique_ptr<MutableVectorData<OutputT>> data,
+ int num_inputs)
+ : data_vector_(std::move(data)), num_inputs_(num_inputs) {}
// Implementation of the tensor merge operation.
Status MergeTensors(InputTensorList tensors, int num_inputs) override {
@@ -121,6 +174,13 @@ class OneDimGroupingAggregator : public OneDimBaseGroupingAggregator {
return FCP_STATUS(OK);
}
+ OneDimGroupingAggregatorState ToProto() override {
+ OneDimGroupingAggregatorState aggregator_state;
+ aggregator_state.set_num_inputs(num_inputs_);
+ *(aggregator_state.mutable_vector_data()) = data_vector_->EncodeContent();
+ return aggregator_state;
+ }
+
protected:
// Provides mutable access to the aggregator data as a vector<T>
inline std::vector<OutputT>& data() { return *data_vector_; }
diff --git a/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc
index 6095d88..fc44cce 100644
--- a/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc
+++ b/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc
@@ -17,11 +17,14 @@
#include <climits>
#include <cstdint>
+#include <memory>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "fcp/aggregation/testing/test_data.h"
@@ -37,6 +40,9 @@ using testing::Eq;
using testing::HasSubstr;
using testing::IsFalse;
using testing::IsTrue;
+using testing::TestWithParam;
+
+using OneDimGroupingAggregatorTest = TestWithParam<bool>;
// A simple Sum Aggregator
template <typename InputT, typename OutputT = InputT>
@@ -46,6 +52,14 @@ class SumGroupingAggregator final
using OneDimGroupingAggregator<InputT, OutputT>::OneDimGroupingAggregator;
using OneDimGroupingAggregator<InputT, OutputT>::data;
+ static SumGroupingAggregator<InputT, OutputT> FromProto(
+ const OneDimGroupingAggregatorState& aggregator_state) {
+ return SumGroupingAggregator<InputT, OutputT>(
+ MutableVectorData<OutputT>::CreateFromEncodedContent(
+ aggregator_state.vector_data()),
+ aggregator_state.num_inputs());
+ }
+
private:
void AggregateVectorByOrdinals(
const AggVector<int64_t>& ordinals_vector,
@@ -94,6 +108,14 @@ class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> {
using OneDimGroupingAggregator<int32_t>::OneDimGroupingAggregator;
using OneDimGroupingAggregator<int32_t>::data;
+ static MinGroupingAggregator FromProto(
+ const OneDimGroupingAggregatorState& aggregator_state) {
+ return MinGroupingAggregator(
+ MutableVectorData<int32_t>::CreateFromEncodedContent(
+ aggregator_state.vector_data()),
+ aggregator_state.num_inputs());
+ }
+
private:
void AggregateVectorByOrdinals(
const AggVector<int64_t>& ordinals_vector,
@@ -132,15 +154,21 @@ class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> {
int32_t GetDefaultValue() override { return INT_MAX; }
};
-TEST(OneDimGroupingAggregatorTest, EmptyReport) {
+TEST_P(OneDimGroupingAggregatorTest, EmptyReport) {
SumGroupingAggregator<int32_t> aggregator;
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = SumGroupingAggregator<int32_t>::FromProto(state);
+ }
+
auto result = std::move(aggregator).Report();
EXPECT_THAT(result, IsOk());
EXPECT_THAT(result->size(), Eq(1));
EXPECT_THAT(result.value()[0], IsTensor<int32_t>({0}, {}));
}
-TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) {
SumGroupingAggregator<int32_t> aggregator;
Tensor ordinal =
Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
@@ -149,6 +177,12 @@ TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) {
Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
EXPECT_THAT(aggregator.Accumulate({&ordinal, &t1}), IsOk());
EXPECT_THAT(aggregator.Accumulate({&ordinal, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = SumGroupingAggregator<int32_t>::FromProto(state);
+ }
+
EXPECT_THAT(aggregator.Accumulate({&ordinal, &t3}), IsOk());
EXPECT_THAT(aggregator.CanReport(), IsTrue());
@@ -159,7 +193,7 @@ TEST(OneDimGroupingAggregatorTest, ScalarAggregation_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
}
-TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) {
const TensorShape shape = {4};
SumGroupingAggregator<int32_t> aggregator;
Tensor ordinals =
@@ -173,6 +207,12 @@ TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) {
Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
EXPECT_THAT(aggregator.Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator.Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = SumGroupingAggregator<int32_t>::FromProto(state);
+ }
+
EXPECT_THAT(aggregator.Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator.CanReport(), IsTrue());
EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
@@ -186,7 +226,7 @@ TEST(OneDimGroupingAggregatorTest, DenseAggregation_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
const TensorShape shape = {4};
SumGroupingAggregator<int32_t> aggregator;
Tensor t1_ordinals =
@@ -202,6 +242,12 @@ TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
Tensor t2 =
Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = SumGroupingAggregator<int32_t>::FromProto(state);
+ }
+
// Totals: [32, 11, 15, 4, 2]
Tensor t3_ordinals =
Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({2, 2, 5, 1}))
@@ -222,7 +268,7 @@ TEST(OneDimGroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
SumGroupingAggregator<int32_t> aggregator;
Tensor t1_ordinals =
Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
@@ -236,6 +282,12 @@ TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, 5}))
.value();
EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = SumGroupingAggregator<int32_t>::FromProto(state);
+ }
+
// Totals: [13, 23, 17, 4, 2]
Tensor t3_ordinals =
Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
@@ -256,8 +308,8 @@ TEST(OneDimGroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest,
- DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest,
+ DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) {
// Use a MinGroupingAggregator which has a non-zero default value so we can
// test that when the output grows, elements are set to the default value.
MinGroupingAggregator aggregator;
@@ -273,6 +325,12 @@ TEST(OneDimGroupingAggregatorTest,
Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50}))
.value();
EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto state = std::move(aggregator).ToProto();
+ aggregator = MinGroupingAggregator::FromProto(state);
+ }
+
// Totals: [-50, INT_MAX, 17, INT_MAX, 2]
Tensor t3_ordinals =
Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
@@ -293,7 +351,7 @@ TEST(OneDimGroupingAggregatorTest,
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, Merge_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
Tensor ordinal =
@@ -305,6 +363,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) {
EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
int aggregator2_num_inputs = aggregator2.GetNumInputs();
EXPECT_THAT(aggregator2_num_inputs, Eq(2));
auto aggregator2_result = std::move(aggregator2).Report();
@@ -324,10 +389,17 @@ TEST(OneDimGroupingAggregatorTest, Merge_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
}
-TEST(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
// Merge the two empty aggregators together.
auto empty_ordinals =
Tensor::Create(DT_INT64, {0}, CreateTestData<int64_t>({})).value();
@@ -346,7 +418,7 @@ TEST(OneDimGroupingAggregatorTest, Merge_BothEmpty_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor<int32_t>({0}, {}));
}
-TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
@@ -365,6 +437,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk());
// aggregator2 totals: [32, 11, 15, 4, 2]
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
// Merge aggregator2 into aggregator1 which has not received any inputs.
int aggregator2_num_inputs = aggregator2.GetNumInputs();
auto aggregator2_result =
@@ -388,7 +467,7 @@ TEST(OneDimGroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
@@ -407,6 +486,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
// aggregator1 totals: [32, 11, 15, 4, 2]
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
// Merge with aggregator2 which has not received any inputs.
auto empty_ordinals =
Tensor::Create(DT_INT64, {0}, CreateTestData<int64_t>({})).value();
@@ -428,7 +514,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest,
+ Merge_OtherOutputHasFewerElements_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
@@ -453,6 +540,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) {
EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
// aggregator2 totals: [0, 0, 14]
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
int aggregator2_num_inputs = aggregator2.GetNumInputs();
auto aggregator2_result =
std::move(std::move(aggregator2).Report().value()[0]);
@@ -474,7 +568,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest,
+ Merge_OtherOutputHasMoreElements_Succeeds) {
SumGroupingAggregator<int32_t> aggregator1;
SumGroupingAggregator<int32_t> aggregator2;
@@ -501,6 +596,13 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) {
EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
// aggregator2 totals: [0, 20, 14, 0, 0, 7]
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = SumGroupingAggregator<int32_t>::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = SumGroupingAggregator<int32_t>::FromProto(state2);
+ }
+
int aggregator2_num_inputs = aggregator2.GetNumInputs();
auto aggregator2_result =
std::move(std::move(aggregator2).Report().value()[0]);
@@ -523,8 +625,8 @@ TEST(OneDimGroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(OneDimGroupingAggregatorTest,
- Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) {
+TEST_P(OneDimGroupingAggregatorTest,
+ Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) {
// Use a MinGroupingAggregator which has a non-zero default value so we can
// test that when the output grows, elements are set to the default value.
MinGroupingAggregator aggregator1;
@@ -551,6 +653,13 @@ TEST(OneDimGroupingAggregatorTest,
EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
// aggregator2 totals: [-50, 7, 11, INT_MAX, 2]
+ if (GetParam()) {
+ auto state1 = std::move(aggregator1).ToProto();
+ aggregator1 = MinGroupingAggregator::FromProto(state1);
+ auto state2 = std::move(aggregator2).ToProto();
+ aggregator2 = MinGroupingAggregator::FromProto(state2);
+ }
+
int aggregator2_num_inputs = aggregator2.GetNumInputs();
auto aggregator2_result =
std::move(std::move(aggregator2).Report().value()[0]);
@@ -657,6 +766,18 @@ TEST(OneDimGroupingAggregatorTest, FailsAfterBeingConsumed) {
IsCode(FAILED_PRECONDITION));
}
+TEST(OneDimGroupingAggregatorTest, Serialize_Unimplmeneted) {
+ SumGroupingAggregator<int32_t> aggregator;
+ Status s = std::move(aggregator).Serialize().status();
+ EXPECT_THAT(s, IsCode(UNIMPLEMENTED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OneDimGroupingAggregatorTestInstantiation, OneDimGroupingAggregatorTest,
+ testing::ValuesIn<bool>({false, true}),
+ [](const testing::TestParamInfo<OneDimGroupingAggregatorTest::ParamType>&
+ info) { return info.param ? "SaveIntermediateState" : "None"; });
+
} // namespace
} // namespace aggregation
} // namespace fcp