From a767503518bfb2f2b954f021bd979e2c1967666d Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Wed, 27 May 2026 13:55:08 -0700 Subject: [PATCH] Refactor the Codec into a single class, and rename the concrete factory. PiperOrigin-RevId: 922329519 --- willow/input_encoding/BUILD | 6 +- .../{codec_factory.cc => codec.cc} | 55 ++++++------ willow/input_encoding/codec.h | 32 +++++++ willow/input_encoding/codec_bindings.cc | 41 +++++++-- willow/input_encoding/codec_bindings_test.py | 4 +- willow/input_encoding/codec_factory.h | 22 +++-- .../{explicit_codec_test.cc => codec_test.cc} | 84 +++++++++---------- 7 files changed, 155 insertions(+), 89 deletions(-) rename willow/input_encoding/{codec_factory.cc => codec.cc} (90%) rename willow/input_encoding/{explicit_codec_test.cc => codec_test.cc} (86%) diff --git a/willow/input_encoding/BUILD b/willow/input_encoding/BUILD index 48bd377..00d5d8c 100644 --- a/willow/input_encoding/BUILD +++ b/willow/input_encoding/BUILD @@ -26,7 +26,7 @@ package( cc_library( name = "codec", srcs = [ - "codec_factory.cc", + "codec.cc", ], hdrs = [ "codec.h", @@ -44,8 +44,8 @@ cc_library( ) cc_test( - name = "explicit_codec_test", - srcs = ["explicit_codec_test.cc"], + name = "codec_test", + srcs = ["codec_test.cc"], deps = [ ":codec", "@googletest//:gtest_main", diff --git a/willow/input_encoding/codec_factory.cc b/willow/input_encoding/codec.cc similarity index 90% rename from willow/input_encoding/codec_factory.cc rename to willow/input_encoding/codec.cc index b6e8ec3..534b4f9 100644 --- a/willow/input_encoding/codec_factory.cc +++ b/willow/input_encoding/codec.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "willow/input_encoding/codec_factory.h" +#include "willow/input_encoding/codec.h" #include #include @@ -28,7 +28,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "willow/input_encoding/codec.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { @@ -53,13 +52,17 @@ struct GroupDomainKey { } }; -// WillowInputExplicitEncoder must be instantiated through the factory class -// CodecFactory. -class ExplicitCodecImpl : public Codec { +// FlatHistogramCodecImpl implements a Codec that encodes data into a dense, +// flat 1D histogram representing the Cartesian product of the group-by +// domains. +// +// It must be instantiated through the factory function +// Codec::CreateFlatHistogramCodec. +class FlatHistogramCodecImpl : public Codec { public: - ExplicitCodecImpl(const ExplicitCodecImpl&) = delete; - ExplicitCodecImpl& operator=(const ExplicitCodecImpl&) = delete; - ~ExplicitCodecImpl() override = default; + FlatHistogramCodecImpl(const FlatHistogramCodecImpl&) = delete; + FlatHistogramCodecImpl& operator=(const FlatHistogramCodecImpl&) = delete; + ~FlatHistogramCodecImpl() override = default; absl::StatusOr Encode( const GroupData& group_by_data, @@ -94,7 +97,7 @@ class ExplicitCodecImpl : public Codec { size_t GetCombinedIndex(const std::vector& indices) const; - explicit ExplicitCodecImpl( + explicit FlatHistogramCodecImpl( InputSpec input_spec, absl::flat_hash_map group_by_spec_map, @@ -127,10 +130,10 @@ class ExplicitCodecImpl : public Codec { flattened_domain_size_ *= domain_size; } } - friend class CodecFactory; + friend class Codec; }; -absl::Status ExplicitCodecImpl::ValidateData( +absl::Status FlatHistogramCodecImpl::ValidateData( const GroupData& group_by_data, const MetricData& metric_data) const { // Check that all vectors in metric_data and group_by_data are present in // metric_spec_map and group_by_spec_map_, respectively. This ensures that the @@ -221,7 +224,7 @@ absl::Status ExplicitCodecImpl::ValidateData( // Returns the indices of elements in individual domains/vectors of size `sizes` // that correspond to the global index `global_index` of an element of their // cartesian product. -std::vector ExplicitCodecImpl::GetIndices(int global_index) const { +std::vector FlatHistogramCodecImpl::GetIndices(int global_index) const { if (group_by_domain_sizes_.empty()) { return {}; } @@ -241,7 +244,7 @@ std::vector ExplicitCodecImpl::GetIndices(int global_index) const { // domains of size 2 and 3 respectively, and we want to find overall index of // an element that has index 1 in the first domain and index 0 in the second // domain. The function will return 1 * 3 + 0 = 3. -size_t ExplicitCodecImpl::GetCombinedIndex( +size_t FlatHistogramCodecImpl::GetCombinedIndex( const std::vector& indices) const { int64_t combined_index = 0; for (int i = 0; i < indices.size(); ++i) { @@ -251,7 +254,7 @@ size_t ExplicitCodecImpl::GetCombinedIndex( return combined_index; } -absl::StatusOr ExplicitCodecImpl::Encode( +absl::StatusOr FlatHistogramCodecImpl::Encode( const GroupData& group_by_data, const MetricData& metric_data) const { if (absl::Status status = ValidateData(group_by_data, metric_data); !status.ok()) { @@ -286,7 +289,7 @@ absl::StatusOr ExplicitCodecImpl::Encode( return result; } -absl::StatusOr ExplicitCodecImpl::Decode( +absl::StatusOr FlatHistogramCodecImpl::Decode( const EncodedData& encoded_data) const { DecodedData decoded_data; @@ -348,7 +351,7 @@ absl::StatusOr ExplicitCodecImpl::Decode( return decoded_data; } -absl::Status ExplicitCodecImpl::ValidateExampleQuery( +absl::Status FlatHistogramCodecImpl::ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const { for (const auto& [name, type] : query_output_specs) { @@ -373,20 +376,20 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery( return absl::OkStatus(); } -absl::Status CodecFactory::ValidateExplicitCodecInputSpec( - const InputSpec& input_spec, size_t max_flattened_domain_size) { - size_t flattened_domain_size = 1; +absl::Status Codec::ValidateInputSpec(const InputSpec& input_spec, + size_t max_flat_histogram_bins) { + size_t flat_histogram_bins = 1; for (const auto& spec : input_spec.group_by_vector_specs()) { - flattened_domain_size *= spec.domain_spec().string_values().values_size(); - if (max_flattened_domain_size < flattened_domain_size) { + flat_histogram_bins *= spec.domain_spec().string_values().values_size(); + if (max_flat_histogram_bins < flat_histogram_bins) { return absl::InvalidArgumentError( - "Global output domain size exceeds maximum threshold."); + "Flat histogram bin count exceeds maximum threshold."); } } return absl::OkStatus(); } -absl::StatusOr> CodecFactory::CreateExplicitCodec( +absl::StatusOr> Codec::CreateFlatHistogramCodec( InputSpec input_spec) { // Check that specs include at least one metric vector. if (input_spec.metric_vector_specs().empty()) { @@ -408,9 +411,9 @@ absl::StatusOr> CodecFactory::CreateExplicitCodec( absl::StrCat("Duplicate vector name: ", spec.vector_name())); } } - return absl::WrapUnique(new ExplicitCodecImpl(std::move(input_spec), - std::move(group_by_spec_map), - std::move(metric_spec_map))); + return absl::WrapUnique(new FlatHistogramCodecImpl( + std::move(input_spec), std::move(group_by_spec_map), + std::move(metric_spec_map))); } } // namespace willow diff --git a/willow/input_encoding/codec.h b/willow/input_encoding/codec.h index 540494f..6489def 100644 --- a/willow/input_encoding/codec.h +++ b/willow/input_encoding/codec.h @@ -15,17 +15,25 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ +#include #include +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { namespace willow { +// The maximum number of bins in a flat histogram, which is the maximum size of +// the Cartesian product of domains for string features. +constexpr size_t kMaxFlatHistogramBins = 1000000; + using MetricData = absl::flat_hash_map>; using GroupData = absl::flat_hash_map>; using EncodedData = absl::flat_hash_map>; @@ -69,6 +77,30 @@ class Codec { virtual absl::Status ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const = 0; + + // Creates an instance of FlatHistogramCodec. + static absl::StatusOr> CreateFlatHistogramCodec( + ::secure_aggregation::willow::InputSpec input_spec); + + // Check that the combined size of the string domains is less than the + // maximum allowed size. + static absl::Status ValidateInputSpec( + const ::secure_aggregation::willow::InputSpec& input_spec, + size_t max_flat_histogram_bins = kMaxFlatHistogramBins); + + // Deprecated aliases for backward compatibility + [[deprecated("Use CreateFlatHistogramCodec instead")]] + static absl::StatusOr> CreateExplicitCodec( + ::secure_aggregation::willow::InputSpec input_spec) { + return CreateFlatHistogramCodec(std::move(input_spec)); + } + + [[deprecated("Use ValidateInputSpec instead")]] + static absl::Status ValidateExplicitCodecInputSpec( + const ::secure_aggregation::willow::InputSpec& input_spec, + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + return ValidateInputSpec(input_spec, max_flattened_domain_size); + } }; } // namespace willow diff --git a/willow/input_encoding/codec_bindings.cc b/willow/input_encoding/codec_bindings.cc index d076527..3e0edaa 100644 --- a/willow/input_encoding/codec_bindings.cc +++ b/willow/input_encoding/codec_bindings.cc @@ -23,7 +23,6 @@ #include "pybind11/stl.h" #include "pybind11_abseil/status_casters.h" #include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" #include "willow/proto/willow/input_spec.pb.h" namespace py = pybind11; @@ -46,6 +45,19 @@ PYBIND11_MODULE(codec_bindings, m) { return self.ValidateExampleQuery(cpp_map); }); + m.def( + "CreateFlatHistogramCodec", + [](const std::string& serialized_input_spec) + -> absl::StatusOr> { + InputSpec input_spec; + if (!input_spec.ParseFromString(serialized_input_spec)) { + return absl::InvalidArgumentError("Failed to parse InputSpec"); + } + return Codec::CreateFlatHistogramCodec(input_spec); + }, + py::arg("serialized_input_spec")); + + // DEPRECATED: Use CreateFlatHistogramCodec instead. m.def( "CreateExplicitCodec", [](const std::string& serialized_input_spec) @@ -54,10 +66,29 @@ PYBIND11_MODULE(codec_bindings, m) { if (!input_spec.ParseFromString(serialized_input_spec)) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } - return CodecFactory::CreateExplicitCodec(input_spec); + return Codec::CreateFlatHistogramCodec(input_spec); }, py::arg("serialized_input_spec")); + m.def( + "ValidateInputSpec", + [](const std::string& serialized_input_spec, + size_t max_flattened_domain_size) -> absl::Status { + InputSpec input_spec; + if (!input_spec.ParseFromString(serialized_input_spec)) { + return absl::InvalidArgumentError("Failed to parse InputSpec"); + } + if (max_flattened_domain_size == 0) { + return Codec::ValidateInputSpec(input_spec); + } else { + return Codec::ValidateInputSpec(input_spec, + max_flattened_domain_size); + } + }, + py::arg("serialized_input_spec"), + py::arg("max_flattened_domain_size") = 0); + + // DEPRECATED: Use ValidateInputSpec instead. m.def( "ValidateExplicitCodecInputSpec", [](const std::string& serialized_input_spec, @@ -67,10 +98,10 @@ PYBIND11_MODULE(codec_bindings, m) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } if (max_flattened_domain_size == 0) { - return CodecFactory::ValidateExplicitCodecInputSpec(input_spec); + return Codec::ValidateInputSpec(input_spec); } else { - return CodecFactory::ValidateExplicitCodecInputSpec( - input_spec, max_flattened_domain_size); + return Codec::ValidateInputSpec(input_spec, + max_flattened_domain_size); } }, py::arg("serialized_input_spec"), diff --git a/willow/input_encoding/codec_bindings_test.py b/willow/input_encoding/codec_bindings_test.py index 7d77fc3..50442f2 100644 --- a/willow/input_encoding/codec_bindings_test.py +++ b/willow/input_encoding/codec_bindings_test.py @@ -35,7 +35,9 @@ def setUp(self): ) ], ) - self.codec = codec_bindings.CreateExplicitCodec(spec.SerializeToString()) + self.codec = codec_bindings.CreateFlatHistogramCodec( + spec.SerializeToString() + ) def test_validate_example_query_success(self): query_specs = { diff --git a/willow/input_encoding/codec_factory.h b/willow/input_encoding/codec_factory.h index 5ef4bb6..316f6b2 100644 --- a/willow/input_encoding/codec_factory.h +++ b/willow/input_encoding/codec_factory.h @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,8 +14,10 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ + #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -25,22 +27,18 @@ namespace secure_aggregation { namespace willow { -// The maximum size of the Cartesian product of domains for string features. -constexpr size_t kMaxGlobalOutputDomainSize = 1000000; - -// Factory class that constructs non-copyable instances of children classes of -// Codec. -class CodecFactory { +class [[deprecated("Use Codec class static methods instead")]] CodecFactory { public: - // Creates an instance of ExplicitCodec. static absl::StatusOr> CreateExplicitCodec( - InputSpec input_spec); + InputSpec input_spec) { + return Codec::CreateFlatHistogramCodec(std::move(input_spec)); + } - // Check that the combined size of the string domains is less than the - // maximum allowed size. static absl::Status ValidateExplicitCodecInputSpec( const InputSpec& input_spec, - size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize); + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + return Codec::ValidateInputSpec(input_spec, max_flattened_domain_size); + } }; } // namespace willow diff --git a/willow/input_encoding/explicit_codec_test.cc b/willow/input_encoding/codec_test.cc similarity index 86% rename from willow/input_encoding/explicit_codec_test.cc rename to willow/input_encoding/codec_test.cc index eb46406..890a3e1 100644 --- a/willow/input_encoding/explicit_codec_test.cc +++ b/willow/input_encoding/codec_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "willow/input_encoding/codec.h" + #include #include #include @@ -21,8 +23,6 @@ #include "ffi_utils/status_matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" #include "willow/proto/willow/input_spec.pb.h" #include "willow/testing_utils/testing_utils.h" @@ -37,7 +37,7 @@ using ::testing::HasSubstr; using ::testing::Pair; using ::testing::UnorderedElementsAre; -TEST(CodecFactoryTest, ValidateInputAndSpecLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -49,14 +49,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecLengthMismatch) { // Missing group_by_vector_specs for "feature1" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key feature1 found in group_by_data but not " "in input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecTypeMismatch) { +TEST(CodecTest, ValidateInputAndSpecTypeMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -66,13 +66,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecTypeMismatch) { metric_spec->set_data_type(InputSpec::STRING); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Type mismatch for key metric1"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecEmptyInputData) { +TEST(CodecTest, ValidateInputAndSpecEmptyInputData) { MetricData metric_data; GroupData group_by_data; InputSpec input_spec; @@ -84,13 +84,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecEmptyInputData) { group_by_spec->set_data_type(InputSpec::STRING); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Metric data cannot be empty."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecDomainValueNotFound) { +TEST(CodecTest, ValidateInputAndSpecDomainValueNotFound) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -108,13 +108,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecDomainValueNotFound) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Domain mismatch for key feature1"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; metric_data["metric2"] = {1, 2}; @@ -128,13 +128,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { metric_spec2->set_data_type(InputSpec::INT64); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("must have the same length"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -152,13 +152,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("must have the same length"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -177,14 +177,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Domain mismatch for key feature1: " "group_by_data value c not found in domain"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { +TEST(CodecTest, ValidateInputAndSpecInputKeyNotInSpec) { MetricData metric_data; metric_data["metric1"] = {1}; metric_data["metric2"] = {2}; @@ -197,14 +197,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { // Missing metric_vector_specs for "metric2" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key metric2 found in metric_data but not in " "input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { +TEST(CodecTest, ValidateInputAndSpecGroupByKeyNotInSpec) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -217,14 +217,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { // Missing group_by_vector_specs for "feature1" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key feature1 found in group_by_data but not " "in input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { +TEST(CodecTest, ValidateInputAndSpecGroupByTypeMismatch) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -241,13 +241,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { "y"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Type mismatch for key feature1"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) { +TEST(CodecTest, ValidateInputAndSpecMaxFlatHistogramBinsExceeded) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -270,12 +270,12 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) { ->add_values(std::to_string(i)); } - EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + EXPECT_THAT(Codec::ValidateInputSpec(input_spec), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); + HasSubstr("Flat histogram bin count exceeds"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { +TEST(CodecTest, ValidateInputAndSpecCustomMaxFlatHistogramBins) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -292,13 +292,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( "b"); // Domain size is 2. - EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 1), + EXPECT_THAT(Codec::ValidateInputSpec(input_spec, 1), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); - SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2)); + HasSubstr("Flat histogram bin count exceeds"))); + SECAGG_EXPECT_OK(Codec::ValidateInputSpec(input_spec, 2)); } -TEST(CodecFactoryTest, EncodeSimpleGroupBy) { +TEST(CodecTest, EncodeSimpleGroupBy) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data = CreateTestMetricData(); GroupData group_by_data = CreateTestGroupData(); @@ -326,14 +326,14 @@ TEST(CodecFactoryTest, EncodeSimpleGroupBy) { // Result: [0, 20, 0, 0, 0, 0, 10, 5] SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), IsOkAndHolds(UnorderedElementsAre( Pair("metric1", ElementsAre(0, 20, 0, 0, 0, 0, 10, 5))))); } -TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { +TEST(CodecTest, EncodeTwoMetricsOneGroupBy) { MetricData metric_data; metric_data["metric1"] = {10, 20}; metric_data["metric2"] = {100, 200}; @@ -370,7 +370,7 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { // metric2: [200, 100] SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), IsOkAndHolds(UnorderedElementsAre( @@ -378,13 +378,13 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { Pair("metric2", ElementsAre(200, 100))))); } -TEST(CodecFactoryTest, EncodeThenDecode) { +TEST(CodecTest, EncodeThenDecode) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data = CreateTestMetricData(); GroupData group_by_data = CreateTestGroupData(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, encoder->Encode(group_by_data, metric_data)); @@ -409,13 +409,13 @@ TEST(CodecFactoryTest, EncodeThenDecode) { Pair("lang", ElementsAre("es", "en", "es")))); } -TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { +TEST(CodecTest, EncodeThenDecodeDataOrderDoesNotMatter) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data1 = CreateTestMetricData(); GroupData group_by_data1 = CreateTestGroupData(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder1, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); // Note that the order of metric_data2 and group_by_data2 is different from // metric_data1 and group_by_data1. The decoded result should be the same. @@ -426,7 +426,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { group_by_data2["country"] = {"CA", "US", "US"}; SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder2, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(auto encoded_data, encoder1->Encode(group_by_data1, metric_data1)); @@ -444,7 +444,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { Pair("lang", ElementsAre("es", "en", "es")))); } -TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { +TEST(CodecTest, EncodeThenDecodeNoGroupBy) { MetricData metric_data; metric_data["metric1"] = {10, 20, 5}; MetricData expected_metric_data = metric_data; @@ -455,7 +455,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { metric_spec->set_data_type(InputSpec::INT64); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, encoder->Encode(group_by_data, metric_data)); @@ -470,7 +470,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { UnorderedElementsAre(Pair("metric1", ElementsAre(10, 20, 5)))); } -TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { +TEST(CodecTest, EncodeWithDomainValueNotFound) { MetricData metric_data; metric_data["metric1"] = {10}; GroupData group_by_data; @@ -505,7 +505,7 @@ TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument,