From 6d9777205f0c3f2da77a15170127055719327b91 Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Thu, 28 May 2026 15:42:02 -0700 Subject: [PATCH] Make FlatHistogramCodec implementation more modular and improve test coverage. PiperOrigin-RevId: 923017209 --- willow/api/client.cc | 24 +- willow/input_encoding/BUILD | 7 +- .../{codec_factory.cc => codec.cc} | 257 +++++++++++------- willow/input_encoding/codec.h | 49 ++++ willow/input_encoding/codec_bindings.cc | 26 +- willow/input_encoding/codec_bindings_test.py | 21 +- willow/input_encoding/codec_factory.h | 23 +- .../{explicit_codec_test.cc => codec_test.cc} | 138 +++++++--- 8 files changed, 368 insertions(+), 177 deletions(-) rename willow/input_encoding/{codec_factory.cc => codec.cc} (61%) rename willow/input_encoding/{explicit_codec_test.cc => codec_test.cc} (79%) diff --git a/willow/api/client.cc b/willow/api/client.cc index b36ed42..51c21d0 100644 --- a/willow/api/client.cc +++ b/willow/api/client.cc @@ -14,6 +14,7 @@ #include "willow/api/client.h" +#include #include #include #include @@ -22,7 +23,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "ffi_utils/cxx_utils.h" #include "ffi_utils/status_macros.h" @@ -45,24 +45,18 @@ absl::StatusOr CreateAggregationConfig( config_proto.set_max_number_of_decryptors(max_number_of_decryptors); config_proto.set_max_decryptor_dropouts(max_decryptor_dropouts); config_proto.set_key_id(std::string(key_id)); - // All metrics have same vector length, corresponding to the Cartesian product - // of group-by domains. 
- int64_t flattened_domain_size = 1; - for (const auto& group_by_spec : input_spec_proto.group_by_vector_specs()) { - if (group_by_spec.domain_spec().string_values().values_size() == 0) { - return absl::InvalidArgumentError(absl::StrCat( - "Missing domain, invalid domain type (must be StringValues), or " - "empty string_values for group by vector: ", - group_by_spec.vector_name())); - } - flattened_domain_size *= - group_by_spec.domain_spec().string_values().values_size(); - } + + // Validate the input spec and create the codec (for vector length). + SECAGG_ASSIGN_OR_RETURN( + auto codec, willow::Codec::CreateFlatHistogramCodec(input_spec_proto)); + // Build VectorConfig (length and bound) for each metric. for (const auto& metric_spec : input_spec_proto.metric_vector_specs()) { auto& vector_config = (*config_proto.mutable_vector_configs())[metric_spec.vector_name()]; - vector_config.set_length(flattened_domain_size); + SECAGG_ASSIGN_OR_RETURN(size_t length, codec->GetEncodedVectorLength( + metric_spec.vector_name())); + vector_config.set_length(length); if (metric_spec.has_domain_spec() && metric_spec.domain_spec().has_interval()) { vector_config.set_bound( diff --git a/willow/input_encoding/BUILD b/willow/input_encoding/BUILD index 48bd377..203f680 100644 --- a/willow/input_encoding/BUILD +++ b/willow/input_encoding/BUILD @@ -26,7 +26,7 @@ package( cc_library( name = "codec", srcs = [ - "codec_factory.cc", + "codec.cc", ], hdrs = [ "codec.h", @@ -39,13 +39,14 @@ cc_library( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", + "//ffi_utils:status_macros", "//willow/proto/willow:input_spec_cc_proto", ], ) cc_test( - name = "explicit_codec_test", - srcs = ["explicit_codec_test.cc"], + name = "codec_test", + srcs = ["codec_test.cc"], deps = [ ":codec", "@googletest//:gtest_main", diff --git a/willow/input_encoding/codec_factory.cc b/willow/input_encoding/codec.cc similarity index 61% rename from 
willow/input_encoding/codec_factory.cc rename to willow/input_encoding/codec.cc index b6e8ec3..9755ded 100644 --- a/willow/input_encoding/codec_factory.cc +++ b/willow/input_encoding/codec.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "willow/input_encoding/codec_factory.h" +#include "willow/input_encoding/codec.h" #include #include #include +#include #include #include #include @@ -28,7 +29,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "willow/input_encoding/codec.h" +#include "absl/strings/string_view.h" +#include "ffi_utils/status_macros.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { @@ -53,13 +55,17 @@ struct GroupDomainKey { } }; -// WillowInputExplicitEncoder must be instantiated through the factory class -// CodecFactory. -class ExplicitCodecImpl : public Codec { +// FlatHistogramCodecImpl implements a Codec that encodes data into a dense, +// flat 1D histogram representing the Cartesian product of the group-by +// domains. +// +// It must be instantiated through the factory function +// Codec::CreateFlatHistogramCodec. +class FlatHistogramCodecImpl : public Codec { public: - ExplicitCodecImpl(const ExplicitCodecImpl&) = delete; - ExplicitCodecImpl& operator=(const ExplicitCodecImpl&) = delete; - ~ExplicitCodecImpl() override = default; + FlatHistogramCodecImpl(const FlatHistogramCodecImpl&) = delete; + FlatHistogramCodecImpl& operator=(const FlatHistogramCodecImpl&) = delete; + ~FlatHistogramCodecImpl() override = default; absl::StatusOr Encode( const GroupData& group_by_data, @@ -72,13 +78,17 @@ class ExplicitCodecImpl : public Codec { const absl::flat_hash_map& query_output_specs) const override; + absl::StatusOr GetEncodedVectorLength( + absl::string_view metric_name) const override; + private: InputSpec input_spec_; // Map of group-by key names to their specs. 
absl::flat_hash_map group_by_spec_map_; // Map of metric key names to their specs. absl::flat_hash_map metric_spec_map_; - std::int64_t flattened_domain_size_; + // The size of the Cartesian product of all group-by domains. + std::int64_t flat_histogram_bin_count_; // The names of the group-by keys, sorted. std::vector group_by_keys_; // The size of the string domains for each group-by key. The order is the same @@ -94,43 +104,33 @@ class ExplicitCodecImpl : public Codec { size_t GetCombinedIndex(const std::vector& indices) const; - explicit ExplicitCodecImpl( + int EncodeGroupValue(const std::string& group_name, + const std::string& value) const; + + absl::StatusOr DecodeGroupValue(absl::string_view group_name, + int group_domain_index) const; + + explicit FlatHistogramCodecImpl( InputSpec input_spec, absl::flat_hash_map group_by_spec_map, - absl::flat_hash_map metric_spec_map) + absl::flat_hash_map metric_spec_map, + std::vector group_by_keys, + std::vector group_by_domain_sizes, + absl::flat_hash_map group_by_domain_indices, + size_t flat_histogram_bin_count) : input_spec_(std::move(input_spec)), group_by_spec_map_(std::move(group_by_spec_map)), - metric_spec_map_(std::move(metric_spec_map)) { - group_by_keys_.reserve(group_by_spec_map_.size()); - // Compute sorted group-by keys. - for (const auto& [key, spec] : group_by_spec_map_) { - group_by_keys_.push_back(spec->vector_name()); - // Precompute indices into domains to allow efficient lookups. - const auto& domain = spec->domain_spec().string_values(); - for (int i = 0; i < domain.values_size(); ++i) { - group_by_domain_indices_[GroupDomainKey{spec->vector_name(), - domain.values(i)}] = i; - } - } - std::sort(group_by_keys_.begin(), group_by_keys_.end()); - // Compute the sizes of the string domains for each group-by key. 
- flattened_domain_size_ = 1; - group_by_domain_sizes_.reserve(group_by_keys_.size()); - for (const auto& key : group_by_keys_) { - auto spec_it = group_by_spec_map_.find(key); - CHECK(spec_it != - group_by_spec_map_.end()); // We expect the key to be present. - int domain_size = - spec_it->second->domain_spec().string_values().values_size(); - group_by_domain_sizes_.push_back(domain_size); - flattened_domain_size_ *= domain_size; - } - } - friend class CodecFactory; + metric_spec_map_(std::move(metric_spec_map)), + flat_histogram_bin_count_( + static_cast(flat_histogram_bin_count)), + group_by_keys_(std::move(group_by_keys)), + group_by_domain_sizes_(std::move(group_by_domain_sizes)), + group_by_domain_indices_(std::move(group_by_domain_indices)) {} + friend class Codec; }; -absl::Status ExplicitCodecImpl::ValidateData( +absl::Status FlatHistogramCodecImpl::ValidateData( const GroupData& group_by_data, const MetricData& metric_data) const { // Check that all vectors in metric_data and group_by_data are present in // metric_spec_map and group_by_spec_map_, respectively. This ensures that the @@ -154,8 +154,8 @@ absl::Status ExplicitCodecImpl::ValidateData( // input_spec have the same length, and that the length is greater than 0. // 2. all metric vector names in metric_data are present in input_spec. // 3. all group-by vector names in group_by_data are present in input_spec. - // 4. metric vectors have type INT64. - // 5. group-by vectors have type STRING. + // 4. metric vectors have the same type as specified in the input_spec. + // 5. group-by vectors have the same type as specified in the input_spec. // 6. all values in group_by_data are in the domain provided in the // input_spec. 
if (metric_data.empty()) { @@ -168,6 +168,8 @@ absl::Status ExplicitCodecImpl::ValidateData( return absl::InvalidArgumentError( "All input vectors must have length > 0."); } + + constexpr InputSpec::DataType kDataMetricType = InputSpec::INT64; for (const auto& [name, data] : metric_data) { if (data.size() != vector_size) { return absl::InvalidArgumentError( @@ -179,15 +181,15 @@ absl::Status ExplicitCodecImpl::ValidateData( "Key ", name, " found in metric_data but not in input_spec.")); } const auto& spec = it->second; - if (spec->data_type() != InputSpec::INT64) { - return absl::InvalidArgumentError( - absl::StrCat("Type mismatch for key ", name, - ": metric_data type is int64_t but input_spec type " - "is not INT64, it is ", - spec->data_type())); + if (spec->data_type() != kDataMetricType) { + return absl::InvalidArgumentError(absl::StrCat( + "Type mismatch for key ", name, ": metric_data type is ", + InputSpec::DataType_Name(kDataMetricType), " but input_spec type is ", + InputSpec::DataType_Name(spec->data_type()))); } } + constexpr InputSpec::DataType kDataGroupByType = InputSpec::STRING; for (const auto& [name, data] : group_by_data) { if (data.size() != vector_size) { return absl::InvalidArgumentError( @@ -199,11 +201,12 @@ absl::Status ExplicitCodecImpl::ValidateData( "Key ", name, " found in group_by_data but not in input_spec.")); } const auto& spec = it->second; - if (spec->data_type() != InputSpec::STRING) { - return absl::InvalidArgumentError( - absl::StrCat("Type mismatch for key ", name, - ": group_by_data type is string but input_spec type is " - "not STRING.")); + if (spec->data_type() != kDataGroupByType) { + return absl::InvalidArgumentError(absl::StrCat( + "Type mismatch for key ", name, ": group_by_data type is ", + InputSpec::DataType_Name(kDataGroupByType), + " but input_spec type is ", + InputSpec::DataType_Name(spec->data_type()))); } // Check that all values in group_by_data are in the domain provided in the // input_spec. 
@@ -221,7 +224,7 @@ absl::Status ExplicitCodecImpl::ValidateData( // Returns the indices of elements in individual domains/vectors of size `sizes` // that correspond to the global index `global_index` of an element of their // cartesian product. -std::vector ExplicitCodecImpl::GetIndices(int global_index) const { +std::vector FlatHistogramCodecImpl::GetIndices(int global_index) const { if (group_by_domain_sizes_.empty()) { return {}; } @@ -236,12 +239,12 @@ std::vector ExplicitCodecImpl::GetIndices(int global_index) const { } // Returns the index of an element in the cartesian product of domains of size -// `sizes`, given the indices of the elements the individual domains. +// `sizes`, given the indices of the elements in the individual domains. // E.g., if sizes = {2, 3} and indices = {1, 0}, this implies we have two // domains of size 2 and 3 respectively, and we want to find overall index of // an element that has index 1 in the first domain and index 0 in the second // domain. The function will return 1 * 3 + 0 = 3. -size_t ExplicitCodecImpl::GetCombinedIndex( +size_t FlatHistogramCodecImpl::GetCombinedIndex( const std::vector& indices) const { int64_t combined_index = 0; for (int i = 0; i < indices.size(); ++i) { @@ -251,12 +254,37 @@ size_t ExplicitCodecImpl::GetCombinedIndex( return combined_index; } -absl::StatusOr ExplicitCodecImpl::Encode( - const GroupData& group_by_data, const MetricData& metric_data) const { - if (absl::Status status = ValidateData(group_by_data, metric_data); - !status.ok()) { - return status; +// Encodes a single group-by value to its unique group-by domain index within +// its own flattened domain. 
+int FlatHistogramCodecImpl::EncodeGroupValue(const std::string& group_name, + const std::string& value) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + + auto it = group_by_domain_indices_.find(GroupDomainKey{group_name, value}); + CHECK(it != group_by_domain_indices_.end()); + return it->second; +} + +// Decodes a group-by domain index to its original value within the group-by +// domain. Must be called on group_names known to be in the input spec. +absl::StatusOr FlatHistogramCodecImpl::DecodeGroupValue( + absl::string_view group_name, int group_domain_index) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + + const auto& domain = spec_it->second->domain_spec().string_values(); + if (group_domain_index < 0 || group_domain_index >= domain.values_size()) { + return absl::InvalidArgumentError( + absl::StrCat("Index ", group_domain_index, " for key ", group_name, + " is out of bounds [0, ", domain.values_size(), ").")); } + return domain.values(group_domain_index); +} + +absl::StatusOr FlatHistogramCodecImpl::Encode( + const GroupData& group_by_data, const MetricData& metric_data) const { + SECAGG_RETURN_IF_ERROR(ValidateData(group_by_data, metric_data)); absl::flat_hash_map> result; if (group_by_keys_.empty()) { @@ -266,18 +294,15 @@ absl::StatusOr ExplicitCodecImpl::Encode( } } else { for (const auto& [metric_name, values] : metric_data) { - result[metric_name] = std::vector(flattened_domain_size_, 0); + result[metric_name] = std::vector(flat_histogram_bin_count_, 0); for (int i = 0; i < values.size(); ++i) { std::vector indices; indices.reserve(group_by_keys_.size()); for (const auto& group_name : group_by_keys_) { auto it_data = group_by_data.find(group_name); CHECK(it_data != group_by_data.end()); - auto key = it_data->second[i]; - auto it = - group_by_domain_indices_.find(GroupDomainKey{group_name, key}); - // ValidateData ensures the key 
exists in the domain. - indices.push_back(it->second); + const std::string& group_value = it_data->second[i]; + indices.push_back(EncodeGroupValue(group_name, group_value)); } result[metric_name][GetCombinedIndex(indices)] = values[i]; } @@ -286,7 +311,7 @@ absl::StatusOr ExplicitCodecImpl::Encode( return result; } -absl::StatusOr ExplicitCodecImpl::Decode( +absl::StatusOr FlatHistogramCodecImpl::Decode( const EncodedData& encoded_data) const { DecodedData decoded_data; @@ -310,14 +335,14 @@ absl::StatusOr ExplicitCodecImpl::Decode( return absl::InvalidArgumentError(absl::StrCat( "Key ", name, " found in encoded_data but not in input_spec.")); } - if (values.size() != flattened_domain_size_) { + if (values.size() != flat_histogram_bin_count_) { return absl::InvalidArgumentError(absl::StrCat( "Encoded data for metric ", name, " has wrong size: expected ", - flattened_domain_size_, ", got ", values.size())); + flat_histogram_bin_count_, ", got ", values.size())); } } - for (int i = 0; i < flattened_domain_size_; ++i) { + for (int i = 0; i < flat_histogram_bin_count_; ++i) { bool has_nonzero_metric = false; for (const auto& [metric_name, values] : encoded_data) { if (values[i] != 0) { @@ -329,16 +354,9 @@ absl::StatusOr ExplicitCodecImpl::Decode( std::vector indices = GetIndices(i); for (int j = 0; j < group_by_keys_.size(); ++j) { const auto& key_name = group_by_keys_[j]; - auto spec_it = group_by_spec_map_.find(key_name); - CHECK(spec_it != group_by_spec_map_.end()); - const auto& domain = spec_it->second->domain_spec().string_values(); - // Check that the index is in the domain. 
- if (indices[j] >= domain.values_size()) { - return absl::InvalidArgumentError(absl::StrCat( - "Index ", indices[j], " for key ", key_name, " is out of bounds ", - "[0, ", domain.values_size(), ").")); - } - decoded_data.group_data[key_name].push_back(domain.values(indices[j])); + SECAGG_ASSIGN_OR_RETURN(std::string decoded_val, + DecodeGroupValue(key_name, indices[j])); + decoded_data.group_data[key_name].push_back(decoded_val); } for (const auto& [metric_name, values] : encoded_data) { decoded_data.metric_data[metric_name].push_back(values[i]); @@ -348,7 +366,7 @@ absl::StatusOr ExplicitCodecImpl::Decode( return decoded_data; } -absl::Status ExplicitCodecImpl::ValidateExampleQuery( +absl::Status FlatHistogramCodecImpl::ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const { for (const auto& [name, type] : query_output_specs) { @@ -373,26 +391,23 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery( return absl::OkStatus(); } -absl::Status CodecFactory::ValidateExplicitCodecInputSpec( - const InputSpec& input_spec, size_t max_flattened_domain_size) { - size_t flattened_domain_size = 1; - for (const auto& spec : input_spec.group_by_vector_specs()) { - flattened_domain_size *= spec.domain_spec().string_values().values_size(); - if (max_flattened_domain_size < flattened_domain_size) { - return absl::InvalidArgumentError( - "Global output domain size exceeds maximum threshold."); - } - } - return absl::OkStatus(); -} - -absl::StatusOr> CodecFactory::CreateExplicitCodec( +absl::StatusOr> Codec::CreateFlatHistogramCodec( InputSpec input_spec) { // Check that specs include at least one metric vector. 
if (input_spec.metric_vector_specs().empty()) { return absl::InvalidArgumentError( "input_spec must include at least one metric vector."); } + for (const auto& spec : input_spec.group_by_vector_specs()) { + if (!spec.domain_spec().has_string_values()) { + return absl::InvalidArgumentError( + "Unsupported domain type for group-by vector"); + } + if (spec.domain_spec().string_values().values_size() == 0) { + return absl::InvalidArgumentError("String domain cannot be empty."); + } + } + // Construct maps of vector names to specs, and checks for duplicates. absl::flat_hash_map group_by_spec_map; for (const auto& spec : input_spec.group_by_vector_specs()) { @@ -408,9 +423,55 @@ absl::StatusOr> CodecFactory::CreateExplicitCodec( absl::StrCat("Duplicate vector name: ", spec.vector_name())); } } - return absl::WrapUnique(new ExplicitCodecImpl(std::move(input_spec), - std::move(group_by_spec_map), - std::move(metric_spec_map))); + + // Compute sorted group-by keys. + std::vector group_by_keys; + group_by_keys.reserve(group_by_spec_map.size()); + absl::flat_hash_map group_by_domain_indices; + for (const auto& [key, spec] : group_by_spec_map) { + group_by_keys.push_back(spec->vector_name()); + // Precompute indices into domains to allow efficient lookups. + const auto& domain = spec->domain_spec().string_values(); + for (int i = 0; i < domain.values_size(); ++i) { + group_by_domain_indices[GroupDomainKey{spec->vector_name(), + domain.values(i)}] = i; + } + } + std::sort(group_by_keys.begin(), group_by_keys.end()); + + // Compute the sizes of the string domains for each group-by key, and the + // total number of bins. + int64_t flattened_domain_size = 1; + std::vector group_by_domain_sizes; + group_by_domain_sizes.reserve(group_by_keys.size()); + for (const auto& key : group_by_keys) { + auto spec_it = group_by_spec_map.find(key); + CHECK(spec_it != + group_by_spec_map.end()); // We expect the key to be present. 
+ int domain_size = + spec_it->second->domain_spec().string_values().values_size(); + if (flattened_domain_size > + std::numeric_limits::max() / domain_size) { + return absl::InvalidArgumentError("Flat histogram bin count overflow."); + } + group_by_domain_sizes.push_back(domain_size); + flattened_domain_size *= domain_size; + } + + return absl::WrapUnique(new FlatHistogramCodecImpl( + std::move(input_spec), std::move(group_by_spec_map), + std::move(metric_spec_map), std::move(group_by_keys), + std::move(group_by_domain_sizes), std::move(group_by_domain_indices), + flattened_domain_size)); +} + +absl::StatusOr FlatHistogramCodecImpl::GetEncodedVectorLength( + absl::string_view metric_name) const { + if (!metric_spec_map_.contains(metric_name)) { + return absl::InvalidArgumentError( + absl::StrCat("Metric ", metric_name, " not found in input spec.")); + } + return static_cast(flat_histogram_bin_count_); } } // namespace willow diff --git a/willow/input_encoding/codec.h b/willow/input_encoding/codec.h index 540494f..9f5df94 100644 --- a/willow/input_encoding/codec.h +++ b/willow/input_encoding/codec.h @@ -15,17 +15,27 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ +#include #include +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "ffi_utils/status_macros.h" +#include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { namespace willow { +// The maximum number of bins in a flat histogram, which is the maximum size of +// the Cartesian product of domains for string features. 
+constexpr size_t kMaxFlatHistogramBins = 1000000; + using MetricData = absl::flat_hash_map>; using GroupData = absl::flat_hash_map>; using EncodedData = absl::flat_hash_map>; @@ -69,6 +79,45 @@ class Codec { virtual absl::Status ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const = 0; + + // Returns the length of the encoded vector for the given metric. + // Returns an InvalidArgument error if the metric is not found in the spec. + virtual absl::StatusOr GetEncodedVectorLength( + absl::string_view metric_name) const = 0; + + // Creates an instance of FlatHistogramCodec. + static absl::StatusOr> CreateFlatHistogramCodec( + ::secure_aggregation::willow::InputSpec input_spec); + + // Deprecated aliases for backward compatibility + [[deprecated("Use CreateFlatHistogramCodec instead")]] + static absl::StatusOr> CreateExplicitCodec( + ::secure_aggregation::willow::InputSpec input_spec) { + return CreateFlatHistogramCodec(std::move(input_spec)); + } + + [[deprecated( + "Use CreateFlatHistogramCodec and GetEncodedVectorLength instead")]] + static absl::Status ValidateExplicitCodecInputSpec( + const ::secure_aggregation::willow::InputSpec& input_spec, + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + // Unlike a pure validation pass, creating a codec allocates memory, but + // that allocation is only proportional to the size of the input spec, + // not to the flat histogram bin count.
+ SECAGG_ASSIGN_OR_RETURN(std::unique_ptr codec, + CreateFlatHistogramCodec(input_spec)); + if (!input_spec.metric_vector_specs().empty()) { + std::string first_metric = + input_spec.metric_vector_specs(0).vector_name(); + SECAGG_ASSIGN_OR_RETURN(size_t length, + codec->GetEncodedVectorLength(first_metric)); + if (length > max_flattened_domain_size) { + return absl::InvalidArgumentError( + "Flat histogram bin count exceeds maximum threshold."); + } + } + return absl::OkStatus(); + } }; } // namespace willow diff --git a/willow/input_encoding/codec_bindings.cc b/willow/input_encoding/codec_bindings.cc index d076527..a71e296 100644 --- a/willow/input_encoding/codec_bindings.cc +++ b/willow/input_encoding/codec_bindings.cc @@ -23,7 +23,6 @@ #include "pybind11/stl.h" #include "pybind11_abseil/status_casters.h" #include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" #include "willow/proto/willow/input_spec.pb.h" namespace py = pybind11; @@ -44,8 +43,23 @@ PYBIND11_MODULE(codec_bindings, m) { item.second.cast(); } return self.ValidateExampleQuery(cpp_map); - }); + }) + .def("GetEncodedVectorLength", &Codec::GetEncodedVectorLength, + py::arg("metric_name")); + m.def( + "CreateFlatHistogramCodec", + [](const std::string& serialized_input_spec) + -> absl::StatusOr> { + InputSpec input_spec; + if (!input_spec.ParseFromString(serialized_input_spec)) { + return absl::InvalidArgumentError("Failed to parse InputSpec"); + } + return Codec::CreateFlatHistogramCodec(input_spec); + }, + py::arg("serialized_input_spec")); + + // DEPRECATED: Use CreateFlatHistogramCodec instead. 
m.def( "CreateExplicitCodec", [](const std::string& serialized_input_spec) @@ -54,10 +68,12 @@ PYBIND11_MODULE(codec_bindings, m) { if (!input_spec.ParseFromString(serialized_input_spec)) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } - return CodecFactory::CreateExplicitCodec(input_spec); + return Codec::CreateFlatHistogramCodec(input_spec); }, py::arg("serialized_input_spec")); + // DEPRECATED: Use CreateFlatHistogramCodec and Codec.GetEncodedVectorLength + // instead. m.def( "ValidateExplicitCodecInputSpec", [](const std::string& serialized_input_spec, @@ -67,9 +83,9 @@ PYBIND11_MODULE(codec_bindings, m) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } if (max_flattened_domain_size == 0) { - return CodecFactory::ValidateExplicitCodecInputSpec(input_spec); + return Codec::ValidateExplicitCodecInputSpec(input_spec); } else { - return CodecFactory::ValidateExplicitCodecInputSpec( + return Codec::ValidateExplicitCodecInputSpec( input_spec, max_flattened_domain_size); } }, diff --git a/willow/input_encoding/codec_bindings_test.py b/willow/input_encoding/codec_bindings_test.py index 7d77fc3..959e83d 100644 --- a/willow/input_encoding/codec_bindings_test.py +++ b/willow/input_encoding/codec_bindings_test.py @@ -31,11 +31,19 @@ def setUp(self): ], group_by_vector_specs=[ input_spec_pb2.InputSpec.InputVectorSpec( - vector_name="group1", data_type=input_spec_pb2.InputSpec.STRING + vector_name="group1", + data_type=input_spec_pb2.InputSpec.STRING, + domain_spec=input_spec_pb2.InputSpec.DomainSpec( + string_values=input_spec_pb2.InputSpec.StringValues( + values=["a"] + ) + ), ) ], ) - self.codec = codec_bindings.CreateExplicitCodec(spec.SerializeToString()) + self.codec = codec_bindings.CreateFlatHistogramCodec( + spec.SerializeToString() + ) def test_validate_example_query_success(self): query_specs = { @@ -64,6 +72,15 @@ def test_validate_example_query_vector_not_found(self): ): self.codec.ValidateExampleQuery(query_specs) + def 
test_get_encoded_vector_length_success(self): + self.assertEqual(self.codec.GetEncodedVectorLength("metric1"), 1) + + def test_get_encoded_vector_length_metric_not_found(self): + with self.assertRaisesRegex( + py_status.StatusNotOk, "not found in input spec" + ): + self.codec.GetEncodedVectorLength("unknown_metric") + if __name__ == "__main__": unittest.main() diff --git a/willow/input_encoding/codec_factory.h b/willow/input_encoding/codec_factory.h index 5ef4bb6..d8c3bad 100644 --- a/willow/input_encoding/codec_factory.h +++ b/willow/input_encoding/codec_factory.h @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,8 +14,10 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ + #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -25,22 +27,19 @@ namespace secure_aggregation { namespace willow { -// The maximum size of the Cartesian product of domains for string features. -constexpr size_t kMaxGlobalOutputDomainSize = 1000000; - -// Factory class that constructs non-copyable instances of children classes of -// Codec. -class CodecFactory { +class [[deprecated("Use Codec class static methods instead")]] CodecFactory { public: - // Creates an instance of ExplicitCodec. static absl::StatusOr> CreateExplicitCodec( - InputSpec input_spec); + InputSpec input_spec) { + return Codec::CreateFlatHistogramCodec(std::move(input_spec)); + } - // Check that the combined size of the string domains is less than the - // maximum allowed size. 
static absl::Status ValidateExplicitCodecInputSpec( const InputSpec& input_spec, - size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize); + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + return Codec::ValidateExplicitCodecInputSpec(input_spec, + max_flattened_domain_size); + } }; } // namespace willow diff --git a/willow/input_encoding/explicit_codec_test.cc b/willow/input_encoding/codec_test.cc similarity index 79% rename from willow/input_encoding/explicit_codec_test.cc rename to willow/input_encoding/codec_test.cc index eb46406..784a2be 100644 --- a/willow/input_encoding/explicit_codec_test.cc +++ b/willow/input_encoding/codec_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "willow/input_encoding/codec.h" + #include #include #include @@ -21,8 +23,6 @@ #include "ffi_utils/status_matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" #include "willow/proto/willow/input_spec.pb.h" #include "willow/testing_utils/testing_utils.h" @@ -37,7 +37,7 @@ using ::testing::HasSubstr; using ::testing::Pair; using ::testing::UnorderedElementsAre; -TEST(CodecFactoryTest, ValidateInputAndSpecLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -49,14 +49,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecLengthMismatch) { // Missing group_by_vector_specs for "feature1" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key feature1 found in group_by_data but not " "in input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecTypeMismatch) { 
+TEST(CodecTest, ValidateInputAndSpecTypeMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -66,13 +66,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecTypeMismatch) { metric_spec->set_data_type(InputSpec::STRING); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Type mismatch for key metric1"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecEmptyInputData) { +TEST(CodecTest, ValidateInputAndSpecEmptyInputData) { MetricData metric_data; GroupData group_by_data; InputSpec input_spec; @@ -82,15 +82,17 @@ TEST(CodecFactoryTest, ValidateInputAndSpecEmptyInputData) { auto* group_by_spec = input_spec.add_group_by_vector_specs(); group_by_spec->set_vector_name("feature1"); group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Metric data cannot be empty."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecDomainValueNotFound) { +TEST(CodecTest, ValidateInputAndSpecDomainValueNotFound) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -108,13 +110,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecDomainValueNotFound) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Domain mismatch for key feature1"))); } 
-TEST(CodecFactoryTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; metric_data["metric2"] = {1, 2}; @@ -128,13 +130,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { metric_spec2->set_data_type(InputSpec::INT64); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("must have the same length"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -152,13 +154,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("must have the same length"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { +TEST(CodecTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { MetricData metric_data; metric_data["metric1"] = {1, 2, 3}; GroupData group_by_data; @@ -177,14 +179,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Domain mismatch for key feature1: " "group_by_data value c not found in 
domain"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { +TEST(CodecTest, ValidateInputAndSpecInputKeyNotInSpec) { MetricData metric_data; metric_data["metric1"] = {1}; metric_data["metric2"] = {2}; @@ -197,14 +199,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { // Missing metric_vector_specs for "metric2" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key metric2 found in metric_data but not in " "input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { +TEST(CodecTest, ValidateInputAndSpecGroupByKeyNotInSpec) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -217,14 +219,14 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { // Missing group_by_vector_specs for "feature1" SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Key feature1 found in group_by_data but not " "in input_spec."))); } -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { +TEST(CodecTest, ValidateInputAndSpecGroupByTypeMismatch) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -241,13 +243,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { "y"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Type mismatch for key feature1"))); } -TEST(CodecFactoryTest, 
ValidateInputAndSpecGlobalDomainSizeExceeded) { +TEST(CodecTest, ValidateInputAndSpecMaxFlatHistogramBinsExceeded) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -270,12 +272,12 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) { ->add_values(std::to_string(i)); } - EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); + HasSubstr("Flat histogram bin count exceeds"))); } -TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { +TEST(CodecTest, ValidateInputAndSpecCustomMaxFlatHistogramBins) { MetricData metric_data; metric_data["metric1"] = {1}; GroupData group_by_data; @@ -292,13 +294,13 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( "b"); // Domain size is 2. 
- EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 1), + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec, 1), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); - SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2)); + HasSubstr("Flat histogram bin count exceeds"))); + SECAGG_EXPECT_OK(Codec::ValidateExplicitCodecInputSpec(input_spec, 2)); } -TEST(CodecFactoryTest, EncodeSimpleGroupBy) { +TEST(CodecTest, EncodeSimpleGroupBy) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data = CreateTestMetricData(); GroupData group_by_data = CreateTestGroupData(); @@ -326,14 +328,14 @@ TEST(CodecFactoryTest, EncodeSimpleGroupBy) { // Result: [0, 20, 0, 0, 0, 0, 10, 5] SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), IsOkAndHolds(UnorderedElementsAre( Pair("metric1", ElementsAre(0, 20, 0, 0, 0, 0, 10, 5))))); } -TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { +TEST(CodecTest, EncodeTwoMetricsOneGroupBy) { MetricData metric_data; metric_data["metric1"] = {10, 20}; metric_data["metric2"] = {100, 200}; @@ -370,7 +372,7 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { // metric2: [200, 100] SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), IsOkAndHolds(UnorderedElementsAre( @@ -378,13 +380,13 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { Pair("metric2", ElementsAre(200, 100))))); } -TEST(CodecFactoryTest, EncodeThenDecode) { +TEST(CodecTest, EncodeThenDecode) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data = CreateTestMetricData(); GroupData group_by_data = CreateTestGroupData(); 
SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, encoder->Encode(group_by_data, metric_data)); @@ -409,13 +411,13 @@ TEST(CodecFactoryTest, EncodeThenDecode) { Pair("lang", ElementsAre("es", "en", "es")))); } -TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { +TEST(CodecTest, EncodeThenDecodeDataOrderDoesNotMatter) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data1 = CreateTestMetricData(); GroupData group_by_data1 = CreateTestGroupData(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder1, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); // Note that the order of metric_data2 and group_by_data2 is different from // metric_data1 and group_by_data1. The decoded result should be the same. @@ -426,7 +428,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { group_by_data2["country"] = {"CA", "US", "US"}; SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder2, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(auto encoded_data, encoder1->Encode(group_by_data1, metric_data1)); @@ -444,7 +446,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { Pair("lang", ElementsAre("es", "en", "es")))); } -TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { +TEST(CodecTest, EncodeThenDecodeNoGroupBy) { MetricData metric_data; metric_data["metric1"] = {10, 20, 5}; MetricData expected_metric_data = metric_data; @@ -455,7 +457,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { metric_spec->set_data_type(InputSpec::INT64); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, 
encoder->Encode(group_by_data, metric_data)); @@ -470,7 +472,7 @@ TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { UnorderedElementsAre(Pair("metric1", ElementsAre(10, 20, 5)))); } -TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { +TEST(CodecTest, EncodeWithDomainValueNotFound) { MetricData metric_data; metric_data["metric1"] = {10}; GroupData group_by_data; @@ -505,7 +507,7 @@ TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { "b"); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); + Codec::CreateFlatHistogramCodec(input_spec)); EXPECT_THAT(encoder->Encode(group_by_data, metric_data), StatusIs(absl::StatusCode::kInvalidArgument, @@ -513,6 +515,58 @@ TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { "group_by_data value c not found in domain"))); } +TEST(CodecTest, CreateCodecUnsupportedDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("group1"); + group_by_spec->set_data_type(InputSpec::STRING); + + // Set unsupported interval domain on a group-by vector + group_by_spec->mutable_domain_spec()->mutable_interval()->set_min(0); + group_by_spec->mutable_domain_spec()->mutable_interval()->set_max(10); + + EXPECT_THAT(Codec::CreateFlatHistogramCodec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported domain type"))); +} + +TEST(CodecTest, GetEncodedVectorLengthSuccess) { + InputSpec input_spec = CreateTestInputSpecProto(); + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr codec, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(codec->GetEncodedVectorLength("metric1"), IsOkAndHolds(8)); + EXPECT_THAT(codec->GetEncodedVectorLength("unknown_metric"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in input 
spec"))); +} + +TEST(CodecTest, CreateCodecOverflow) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + // Add 8 domains of size 2^10. (2^10)^8 = 2^80 > 2^64, so the product of + // domain sizes overflows any 64-bit integer type (signed or unsigned). + for (int j = 0; j < 8; ++j) { + auto* spec = input_spec.add_group_by_vector_specs(); + spec->set_vector_name("group" + std::to_string(j)); + spec->set_data_type(InputSpec::STRING); + for (int i = 0; i < 1024; ++i) { + spec->mutable_domain_spec()->mutable_string_values()->add_values( + std::to_string(i)); + } + } + + EXPECT_THAT( + Codec::CreateFlatHistogramCodec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("overflow"))); +} + } // namespace } // namespace willow } // namespace secure_aggregation