From 2793d1167bf495be91c5bd379ce350545cc7972c Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Thu, 28 May 2026 10:00:06 -0700 Subject: [PATCH] Add time domain to input_spec proto and implement modular timestamp bucketing in input_encoding. PiperOrigin-RevId: 922833191 --- willow/api/client.cc | 24 +- willow/input_encoding/BUILD | 46 +- .../{codec_factory.cc => codec.cc} | 343 +++++-- willow/input_encoding/codec.h | 60 ++ willow/input_encoding/codec_bindings.cc | 31 +- willow/input_encoding/codec_bindings_test.py | 21 +- willow/input_encoding/codec_factory.h | 38 +- willow/input_encoding/codec_test.cc | 966 ++++++++++++++++++ willow/input_encoding/explicit_codec_test.cc | 518 ---------- willow/input_encoding/time_utils.cc | 167 +++ willow/input_encoding/time_utils.h | 64 ++ willow/input_encoding/time_utils_test.cc | 159 +++ willow/proto/willow/BUILD | 4 + willow/proto/willow/input_spec.proto | 70 ++ 14 files changed, 1853 insertions(+), 658 deletions(-) rename willow/input_encoding/{codec_factory.cc => codec.cc} (53%) create mode 100644 willow/input_encoding/codec_test.cc delete mode 100644 willow/input_encoding/explicit_codec_test.cc create mode 100644 willow/input_encoding/time_utils.cc create mode 100644 willow/input_encoding/time_utils.h create mode 100644 willow/input_encoding/time_utils_test.cc diff --git a/willow/api/client.cc b/willow/api/client.cc index b36ed42..51c21d0 100644 --- a/willow/api/client.cc +++ b/willow/api/client.cc @@ -14,6 +14,7 @@ #include "willow/api/client.h" +#include #include #include #include @@ -22,7 +23,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "ffi_utils/cxx_utils.h" #include "ffi_utils/status_macros.h" @@ -45,24 +45,18 @@ absl::StatusOr CreateAggregationConfig( config_proto.set_max_number_of_decryptors(max_number_of_decryptors); config_proto.set_max_decryptor_dropouts(max_decryptor_dropouts); config_proto.set_key_id(std::string(key_id)); - // All metrics have same vector length, corresponding to the Cartesian product - // of group-by domains. - int64_t flattened_domain_size = 1; - for (const auto& group_by_spec : input_spec_proto.group_by_vector_specs()) { - if (group_by_spec.domain_spec().string_values().values_size() == 0) { - return absl::InvalidArgumentError(absl::StrCat( - "Missing domain, invalid domain type (must be StringValues), or " - "empty string_values for group by vector: ", - group_by_spec.vector_name())); - } - flattened_domain_size *= - group_by_spec.domain_spec().string_values().values_size(); - } + + // Validate the input spec and create the codec (for vector length). + SECAGG_ASSIGN_OR_RETURN( + auto codec, willow::Codec::CreateFlatHistogramCodec(input_spec_proto)); + // Build VectorConfig (length and bound) for each metric. for (const auto& metric_spec : input_spec_proto.metric_vector_specs()) { auto& vector_config = (*config_proto.mutable_vector_configs())[metric_spec.vector_name()]; - vector_config.set_length(flattened_domain_size); + SECAGG_ASSIGN_OR_RETURN(size_t length, codec->GetEncodedVectorLength( + metric_spec.vector_name())); + vector_config.set_length(length); if (metric_spec.has_domain_spec() && metric_spec.domain_spec().has_interval()) { vector_config.set_bound( diff --git a/willow/input_encoding/BUILD b/willow/input_encoding/BUILD index 48bd377..b631ba9 100644 --- a/willow/input_encoding/BUILD +++ b/willow/input_encoding/BUILD @@ -23,40 +23,80 @@ package( default_visibility = ["//visibility:public"], ) +cc_library( + name = "time_utils", + srcs = [ + "time_utils.cc", + ], + hdrs = [ + "time_utils.h", + ], + deps = [ + "@protobuf//:duration_cc_proto", + "@protobuf//:timestamp_cc_proto", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/time", + "//willow/proto/willow:input_spec_cc_proto", + ], +) + cc_library( name = "codec", srcs = [ - "codec_factory.cc", + "codec.cc", ], hdrs = [ "codec.h", "codec_factory.h", ], deps = [ + ":time_utils", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/memory", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/time", + "//ffi_utils:status_macros", "//willow/proto/willow:input_spec_cc_proto", ], ) cc_test( - name = "explicit_codec_test", - srcs = ["explicit_codec_test.cc"], + name = "codec_test", + srcs = ["codec_test.cc"], deps = [ ":codec", + ":time_utils", + "@protobuf//:duration_cc_proto", "@googletest//:gtest_main", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/status", + "@abseil-cpp//absl/time", "//ffi_utils:status_matchers", "//willow/proto/willow:input_spec_cc_proto", "//willow/testing_utils:testing_utils_cc", ], ) +cc_test( + name = "time_utils_test", + srcs = ["time_utils_test.cc"], + deps = [ + ":time_utils", + "@protobuf//:duration_cc_proto", + "@protobuf//:timestamp_cc_proto", + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/time", + "//ffi_utils:status_matchers", + "//willow/proto/willow:input_spec_cc_proto", + ], +) + pytype_pybind_extension( name = "codec_bindings", srcs = ["codec_bindings.cc"], diff --git a/willow/input_encoding/codec_factory.cc b/willow/input_encoding/codec.cc similarity index 53% rename from willow/input_encoding/codec_factory.cc rename to willow/input_encoding/codec.cc index b6e8ec3..eb90f65 100644 --- a/willow/input_encoding/codec_factory.cc +++ b/willow/input_encoding/codec.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "willow/input_encoding/codec_factory.h" +#include "willow/input_encoding/codec.h" #include #include #include +#include #include +#include #include #include #include @@ -28,7 +30,10 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "willow/input_encoding/codec.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "ffi_utils/status_macros.h" +#include "willow/input_encoding/time_utils.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { @@ -53,13 +58,17 @@ struct GroupDomainKey { } }; -// WillowInputExplicitEncoder must be instantiated through the factory class -// CodecFactory. -class ExplicitCodecImpl : public Codec { +// FlatHistogramCodecImpl implements a Codec that encodes data into a dense, +// flat 1D histogram representing the Cartesian product of the group-by +// domains. +// +// It must be instantiated through the factory function +// Codec::CreateFlatHistogramCodec. +class FlatHistogramCodecImpl : public Codec { public: - ExplicitCodecImpl(const ExplicitCodecImpl&) = delete; - ExplicitCodecImpl& operator=(const ExplicitCodecImpl&) = delete; - ~ExplicitCodecImpl() override = default; + FlatHistogramCodecImpl(const FlatHistogramCodecImpl&) = delete; + FlatHistogramCodecImpl& operator=(const FlatHistogramCodecImpl&) = delete; + ~FlatHistogramCodecImpl() override = default; absl::StatusOr Encode( const GroupData& group_by_data, @@ -72,13 +81,17 @@ class ExplicitCodecImpl : public Codec { const absl::flat_hash_map& query_output_specs) const override; + absl::StatusOr GetEncodedVectorLength( + absl::string_view metric_name) const override; + private: InputSpec input_spec_; // Map of group-by key names to their specs. absl::flat_hash_map group_by_spec_map_; // Map of metric key names to their specs. absl::flat_hash_map metric_spec_map_; - std::int64_t flattened_domain_size_; + // The size of the Cartesian product of all group-by domains. + std::int64_t flat_histogram_bin_count_; // The names of the group-by keys, sorted. std::vector group_by_keys_; // The size of the string domains for each group-by key. The order is the same @@ -86,6 +99,9 @@ class ExplicitCodecImpl : public Codec { std::vector group_by_domain_sizes_; // The indices within their respective domain of each group-by key. absl::flat_hash_map group_by_domain_indices_; + std::optional reference_time_; + // Map of group-by key names to their TimeDomainInfo if they are time domains. + absl::flat_hash_map time_domains_; absl::Status ValidateData(const GroupData& group_by_data, const MetricData& metric_data) const; @@ -94,43 +110,37 @@ class ExplicitCodecImpl : public Codec { size_t GetCombinedIndex(const std::vector& indices) const; - explicit ExplicitCodecImpl( + int EncodeGroupValue(const std::string& group_name, + const std::string& value) const; + + absl::StatusOr DecodeGroupValue(absl::string_view group_name, + int group_domain_index) const; + + explicit FlatHistogramCodecImpl( InputSpec input_spec, absl::flat_hash_map group_by_spec_map, - absl::flat_hash_map metric_spec_map) + absl::flat_hash_map metric_spec_map, + std::optional reference_time, + absl::flat_hash_map time_domains, + std::vector group_by_keys, + std::vector group_by_domain_sizes, + absl::flat_hash_map group_by_domain_indices, + size_t flat_histogram_bin_count) : input_spec_(std::move(input_spec)), group_by_spec_map_(std::move(group_by_spec_map)), - metric_spec_map_(std::move(metric_spec_map)) { - group_by_keys_.reserve(group_by_spec_map_.size()); - // Compute sorted group-by keys. - for (const auto& [key, spec] : group_by_spec_map_) { - group_by_keys_.push_back(spec->vector_name()); - // Precompute indices into domains to allow efficient lookups. - const auto& domain = spec->domain_spec().string_values(); - for (int i = 0; i < domain.values_size(); ++i) { - group_by_domain_indices_[GroupDomainKey{spec->vector_name(), - domain.values(i)}] = i; - } - } - std::sort(group_by_keys_.begin(), group_by_keys_.end()); - // Compute the sizes of the string domains for each group-by key. - flattened_domain_size_ = 1; - group_by_domain_sizes_.reserve(group_by_keys_.size()); - for (const auto& key : group_by_keys_) { - auto spec_it = group_by_spec_map_.find(key); - CHECK(spec_it != - group_by_spec_map_.end()); // We expect the key to be present. - int domain_size = - spec_it->second->domain_spec().string_values().values_size(); - group_by_domain_sizes_.push_back(domain_size); - flattened_domain_size_ *= domain_size; - } - } - friend class CodecFactory; + metric_spec_map_(std::move(metric_spec_map)), + flat_histogram_bin_count_( + static_cast(flat_histogram_bin_count)), + group_by_keys_(std::move(group_by_keys)), + group_by_domain_sizes_(std::move(group_by_domain_sizes)), + group_by_domain_indices_(std::move(group_by_domain_indices)), + reference_time_(reference_time), + time_domains_(std::move(time_domains)) {} + friend class Codec; }; -absl::Status ExplicitCodecImpl::ValidateData( +absl::Status FlatHistogramCodecImpl::ValidateData( const GroupData& group_by_data, const MetricData& metric_data) const { // Check that all vectors in metric_data and group_by_data are present in // metric_spec_map and group_by_spec_map_, respectively. This ensures that the @@ -154,8 +164,8 @@ absl::Status ExplicitCodecImpl::ValidateData( // input_spec have the same length, and that the length is greater than 0. // 2. all metric vector names in metric_data are present in input_spec. // 3. all group-by vector names in group_by_data are present in input_spec. - // 4. metric vectors have type INT64. - // 5. group-by vectors have type STRING. + // 4. metric vectors have the same type as specified in the input_spec. + // 5. group-by vectors have the same type as specified in the input_spec. // 6. all values in group_by_data are in the domain provided in the // input_spec. if (metric_data.empty()) { @@ -168,6 +178,8 @@ absl::Status ExplicitCodecImpl::ValidateData( return absl::InvalidArgumentError( "All input vectors must have length > 0."); } + + constexpr InputSpec::DataType kDataMetricType = InputSpec::INT64; for (const auto& [name, data] : metric_data) { if (data.size() != vector_size) { return absl::InvalidArgumentError( @@ -179,15 +191,15 @@ absl::Status ExplicitCodecImpl::ValidateData( "Key ", name, " found in metric_data but not in input_spec.")); } const auto& spec = it->second; - if (spec->data_type() != InputSpec::INT64) { - return absl::InvalidArgumentError( - absl::StrCat("Type mismatch for key ", name, - ": metric_data type is int64_t but input_spec type " - "is not INT64, it is ", - spec->data_type())); + if (spec->data_type() != kDataMetricType) { + return absl::InvalidArgumentError(absl::StrCat( + "Type mismatch for key ", name, ": metric_data type is ", + InputSpec::DataType_Name(kDataMetricType), " but input_spec type is ", + InputSpec::DataType_Name(spec->data_type()))); } } + constexpr InputSpec::DataType kDataGroupByType = InputSpec::STRING; for (const auto& [name, data] : group_by_data) { if (data.size() != vector_size) { return absl::InvalidArgumentError( @@ -199,19 +211,31 @@ absl::Status ExplicitCodecImpl::ValidateData( "Key ", name, " found in group_by_data but not in input_spec.")); } const auto& spec = it->second; - if (spec->data_type() != InputSpec::STRING) { - return absl::InvalidArgumentError( - absl::StrCat("Type mismatch for key ", name, - ": group_by_data type is string but input_spec type is " - "not STRING.")); + if (spec->data_type() != kDataGroupByType) { + return absl::InvalidArgumentError(absl::StrCat( + "Type mismatch for key ", name, ": group_by_data type is ", + InputSpec::DataType_Name(kDataGroupByType), + " but input_spec type is ", + InputSpec::DataType_Name(spec->data_type()))); } // Check that all values in group_by_data are in the domain provided in the // input_spec. for (const auto& d : data) { - if (!group_by_domain_indices_.contains(GroupDomainKey{name, d})) { - return absl::InvalidArgumentError( - absl::StrCat("Domain mismatch for key ", name, - ": group_by_data value ", d, " not found in domain.")); + if (spec->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(name); + absl::Time t; + std::string err; + if (!absl::ParseTime(ts_info.format, d, ts_info.timezone, &t, &err)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse timestamp value '", d, "' for key ", + name, ": ", err)); + } + } else { + if (!group_by_domain_indices_.contains(GroupDomainKey{name, d})) { + return absl::InvalidArgumentError(absl::StrCat( + "Domain mismatch for key ", name, ": group_by_data value ", d, + " not found in domain.")); + } } } } @@ -221,7 +245,7 @@ absl::Status ExplicitCodecImpl::ValidateData( // Returns the indices of elements in individual domains/vectors of size `sizes` // that correspond to the global index `global_index` of an element of their // cartesian product. -std::vector ExplicitCodecImpl::GetIndices(int global_index) const { +std::vector FlatHistogramCodecImpl::GetIndices(int global_index) const { if (group_by_domain_sizes_.empty()) { return {}; } @@ -236,12 +260,12 @@ std::vector ExplicitCodecImpl::GetIndices(int global_index) const { } // Returns the index of an element in the cartesian product of domains of size -// `sizes`, given the indices of the elements the individual domains. +// `sizes`, given the indices of the elements in the individual domains. // E.g., if sizes = {2, 3} and indices = {1, 0}, this implies we have two // domains of size 2 and 3 respectively, and we want to find overall index of // an element that has index 1 in the first domain and index 0 in the second // domain. The function will return 1 * 3 + 0 = 3. -size_t ExplicitCodecImpl::GetCombinedIndex( +size_t FlatHistogramCodecImpl::GetCombinedIndex( const std::vector& indices) const { int64_t combined_index = 0; for (int i = 0; i < indices.size(); ++i) { @@ -251,11 +275,63 @@ size_t ExplicitCodecImpl::GetCombinedIndex( return combined_index; } -absl::StatusOr ExplicitCodecImpl::Encode( +// Encodes a single group-by value to its unique group-by domain index within +// its own flattened domain. +int FlatHistogramCodecImpl::EncodeGroupValue(const std::string& group_name, + const std::string& value) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + if (spec_it->second->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(group_name); + absl::Time t; + CHECK( + absl::ParseTime(ts_info.format, value, ts_info.timezone, &t, nullptr)); + return EncodeTime(t, ts_info, *reference_time_); + } else { + auto it = group_by_domain_indices_.find(GroupDomainKey{group_name, value}); + CHECK(it != group_by_domain_indices_.end()); + return it->second; + } +} + +// Decodes a group-by domain index to its original value within the group-by +// domain. Must be called on group_names known to be in the input spec. +absl::StatusOr FlatHistogramCodecImpl::DecodeGroupValue( + absl::string_view group_name, int group_domain_index) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + + if (spec_it->second->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(group_name); + absl::Time reconstructed_time; + if (group_domain_index == ts_info.num_periods) { + reconstructed_time = DefaultTimestamp(); + } else { + SECAGG_ASSIGN_OR_RETURN( + reconstructed_time, + DecodeTime(group_domain_index, ts_info, *reference_time_)); + } + return absl::FormatTime(ts_info.format, reconstructed_time, + ts_info.timezone); + } else { + const auto& domain = spec_it->second->domain_spec().string_values(); + if (group_domain_index < 0 || group_domain_index >= domain.values_size()) { + return absl::InvalidArgumentError( + absl::StrCat("Index ", group_domain_index, " for key ", group_name, + " is out of bounds [0, ", domain.values_size(), ").")); + } + return domain.values(group_domain_index); + } +} + +absl::StatusOr FlatHistogramCodecImpl::Encode( const GroupData& group_by_data, const MetricData& metric_data) const { - if (absl::Status status = ValidateData(group_by_data, metric_data); - !status.ok()) { - return status; + SECAGG_RETURN_IF_ERROR(ValidateData(group_by_data, metric_data)); + + if (!time_domains_.empty() && !reference_time_.has_value()) { + return absl::FailedPreconditionError( + "reference_time is required in the constructor for encoding time " + "domains"); } absl::flat_hash_map> result; @@ -266,18 +342,15 @@ absl::StatusOr ExplicitCodecImpl::Encode( } } else { for (const auto& [metric_name, values] : metric_data) { - result[metric_name] = std::vector(flattened_domain_size_, 0); + result[metric_name] = std::vector(flat_histogram_bin_count_, 0); for (int i = 0; i < values.size(); ++i) { std::vector indices; indices.reserve(group_by_keys_.size()); for (const auto& group_name : group_by_keys_) { auto it_data = group_by_data.find(group_name); CHECK(it_data != group_by_data.end()); - auto key = it_data->second[i]; - auto it = - group_by_domain_indices_.find(GroupDomainKey{group_name, key}); - // ValidateData ensures the key exists in the domain. - indices.push_back(it->second); + const std::string& group_value = it_data->second[i]; + indices.push_back(EncodeGroupValue(group_name, group_value)); } result[metric_name][GetCombinedIndex(indices)] = values[i]; } @@ -286,15 +359,20 @@ absl::StatusOr ExplicitCodecImpl::Encode( return result; } -absl::StatusOr ExplicitCodecImpl::Decode( +absl::StatusOr FlatHistogramCodecImpl::Decode( const EncodedData& encoded_data) const { + if (!time_domains_.empty() && !reference_time_.has_value()) { + return absl::FailedPreconditionError( + "reference_time is required in the constructor for decoding time " + "domains"); + } DecodedData decoded_data; if (group_by_keys_.empty()) { // No group-by, so decoded metrics are just the encoded data. decoded_data.metric_data = encoded_data; - // Check if all encoded vectors have the same size. if (encoded_data.empty()) return decoded_data; + // Check if all encoded vectors have the same size. size_t expected_size = encoded_data.begin()->second.size(); for (const auto& [name, values] : encoded_data) { if (values.size() != expected_size) { @@ -310,14 +388,14 @@ absl::StatusOr ExplicitCodecImpl::Decode( return absl::InvalidArgumentError(absl::StrCat( "Key ", name, " found in encoded_data but not in input_spec.")); } - if (values.size() != flattened_domain_size_) { + if (values.size() != flat_histogram_bin_count_) { return absl::InvalidArgumentError(absl::StrCat( "Encoded data for metric ", name, " has wrong size: expected ", - flattened_domain_size_, ", got ", values.size())); + flat_histogram_bin_count_, ", got ", values.size())); } } - for (int i = 0; i < flattened_domain_size_; ++i) { + for (int i = 0; i < flat_histogram_bin_count_; ++i) { bool has_nonzero_metric = false; for (const auto& [metric_name, values] : encoded_data) { if (values[i] != 0) { @@ -329,16 +407,9 @@ absl::StatusOr ExplicitCodecImpl::Decode( std::vector indices = GetIndices(i); for (int j = 0; j < group_by_keys_.size(); ++j) { const auto& key_name = group_by_keys_[j]; - auto spec_it = group_by_spec_map_.find(key_name); - CHECK(spec_it != group_by_spec_map_.end()); - const auto& domain = spec_it->second->domain_spec().string_values(); - // Check that the index is in the domain. - if (indices[j] >= domain.values_size()) { - return absl::InvalidArgumentError(absl::StrCat( - "Index ", indices[j], " for key ", key_name, " is out of bounds ", - "[0, ", domain.values_size(), ").")); - } - decoded_data.group_data[key_name].push_back(domain.values(indices[j])); + SECAGG_ASSIGN_OR_RETURN(std::string decoded_val, + DecodeGroupValue(key_name, indices[j])); + decoded_data.group_data[key_name].push_back(decoded_val); } for (const auto& [metric_name, values] : encoded_data) { decoded_data.metric_data[metric_name].push_back(values[i]); @@ -348,7 +419,7 @@ absl::StatusOr ExplicitCodecImpl::Decode( return decoded_data; } -absl::Status ExplicitCodecImpl::ValidateExampleQuery( +absl::Status FlatHistogramCodecImpl::ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const { for (const auto& [name, type] : query_output_specs) { @@ -373,26 +444,29 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery( return absl::OkStatus(); } -absl::Status CodecFactory::ValidateExplicitCodecInputSpec( - const InputSpec& input_spec, size_t max_flattened_domain_size) { - size_t flattened_domain_size = 1; - for (const auto& spec : input_spec.group_by_vector_specs()) { - flattened_domain_size *= spec.domain_spec().string_values().values_size(); - if (max_flattened_domain_size < flattened_domain_size) { - return absl::InvalidArgumentError( - "Global output domain size exceeds maximum threshold."); - } - } - return absl::OkStatus(); -} - -absl::StatusOr> CodecFactory::CreateExplicitCodec( - InputSpec input_spec) { +absl::StatusOr> Codec::CreateFlatHistogramCodec( + InputSpec input_spec, std::optional reference_time) { // Check that specs include at least one metric vector. if (input_spec.metric_vector_specs().empty()) { return absl::InvalidArgumentError( "input_spec must include at least one metric vector."); } + for (const auto& spec : input_spec.group_by_vector_specs()) { + if (spec.domain_spec().has_time()) { + if (spec.data_type() != InputSpec::STRING) { + return absl::InvalidArgumentError( + "Time domain can only be used with STRING data type."); + } + } else if (spec.domain_spec().has_string_values()) { + if (spec.domain_spec().string_values().values_size() == 0) { + return absl::InvalidArgumentError("String domain cannot be empty."); + } + } else { + return absl::InvalidArgumentError( + "Unsupported domain type for group-by vector"); + } + } + // Construct maps of vector names to specs, and checks for duplicates. absl::flat_hash_map group_by_spec_map; for (const auto& spec : input_spec.group_by_vector_specs()) { @@ -408,9 +482,72 @@ absl::StatusOr> CodecFactory::CreateExplicitCodec( absl::StrCat("Duplicate vector name: ", spec.vector_name())); } } - return absl::WrapUnique(new ExplicitCodecImpl(std::move(input_spec), - std::move(group_by_spec_map), - std::move(metric_spec_map))); + + // Parse and validate time domains. + absl::flat_hash_map time_domains; + for (const auto& spec : input_spec.group_by_vector_specs()) { + if (spec.domain_spec().has_time()) { + SECAGG_ASSIGN_OR_RETURN(TimeDomainInfo ts_info, + ParseTimeDomain(spec.domain_spec().time())); + time_domains.insert({spec.vector_name(), std::move(ts_info)}); + } + } + + // Compute sorted group-by keys. + std::vector group_by_keys; + group_by_keys.reserve(group_by_spec_map.size()); + absl::flat_hash_map group_by_domain_indices; + for (const auto& [key, spec] : group_by_spec_map) { + group_by_keys.push_back(spec->vector_name()); + if (spec->domain_spec().has_string_values()) { + // Precompute indices into domains to allow efficient lookups. + const auto& domain = spec->domain_spec().string_values(); + for (int i = 0; i < domain.values_size(); ++i) { + group_by_domain_indices[GroupDomainKey{spec->vector_name(), + domain.values(i)}] = i; + } + } + } + std::sort(group_by_keys.begin(), group_by_keys.end()); + + // Compute the sizes of the string domains for each group-by key, and the + // total number of bins. + int64_t flattened_domain_size = 1; + std::vector group_by_domain_sizes; + group_by_domain_sizes.reserve(group_by_keys.size()); + for (const auto& key : group_by_keys) { + auto spec_it = group_by_spec_map.find(key); + CHECK(spec_it != + group_by_spec_map.end()); // We expect the key to be present. + int domain_size = 0; + if (spec_it->second->domain_spec().has_time()) { + domain_size = time_domains.at(key).num_periods + 1; + } else { + domain_size = + spec_it->second->domain_spec().string_values().values_size(); + } + if (flattened_domain_size > + std::numeric_limits::max() / domain_size) { + return absl::InvalidArgumentError("Flat histogram bin count overflow."); + } + group_by_domain_sizes.push_back(domain_size); + flattened_domain_size *= domain_size; + } + + return absl::WrapUnique(new FlatHistogramCodecImpl( + std::move(input_spec), std::move(group_by_spec_map), + std::move(metric_spec_map), reference_time, std::move(time_domains), + std::move(group_by_keys), std::move(group_by_domain_sizes), + std::move(group_by_domain_indices), flattened_domain_size)); +} + +absl::StatusOr FlatHistogramCodecImpl::GetEncodedVectorLength( + absl::string_view metric_name) const { + if (!metric_spec_map_.contains(metric_name)) { + return absl::InvalidArgumentError( + absl::StrCat("Metric ", metric_name, " not found in input spec.")); + } + return static_cast(flat_histogram_bin_count_); } } // namespace willow diff --git a/willow/input_encoding/codec.h b/willow/input_encoding/codec.h index 540494f..c442b18 100644 --- a/willow/input_encoding/codec.h +++ b/willow/input_encoding/codec.h @@ -15,17 +15,29 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_ +#include #include +#include +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "ffi_utils/status_macros.h" +#include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { namespace willow { +// The maximum number of bins in a flat histogram, which is the maximum size of +// the Cartesian product of domains for string features. +constexpr size_t kMaxFlatHistogramBins = 1000000; + using MetricData = absl::flat_hash_map>; using GroupData = absl::flat_hash_map>; using EncodedData = absl::flat_hash_map>; @@ -69,6 +81,54 @@ class Codec { virtual absl::Status ValidateExampleQuery( const absl::flat_hash_map& query_output_specs) const = 0; + + // Returns the length of the encoded vector for the given metric. + // Returns an InvalidArgument error if the metric is not found in the spec. + virtual absl::StatusOr GetEncodedVectorLength( + absl::string_view metric_name) const = 0; + + // Creates an instance of FlatHistogramCodec. + // + // `reference_time` must be set if the input spec contains any TimeDomain: + // - If the codec is used for encoding, `reference_time` represents the client + // encoding time used to evaluate staleness. + // - If the codec is used for decoding, `reference_time` represents the + // decoding anchor time (e.g., start of the current aggregation window) + // used to align the modular periods back to absolute timestamps. + static absl::StatusOr> CreateFlatHistogramCodec( + ::secure_aggregation::willow::InputSpec input_spec, + std::optional reference_time = std::nullopt); + + // Deprecated aliases for backward compatibility + [[deprecated("Use CreateFlatHistogramCodec instead")]] + static absl::StatusOr> CreateExplicitCodec( + ::secure_aggregation::willow::InputSpec input_spec, + std::optional reference_time = std::nullopt) { + return CreateFlatHistogramCodec(std::move(input_spec), reference_time); + } + + [[deprecated( + "Use CreateFlatHistogramCodec and GetEncodedVectorLength instead")]] + static absl::Status ValidateExplicitCodecInputSpec( + const ::secure_aggregation::willow::InputSpec& input_spec, + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + // Creating a codec allocates memory, compared to just validating the input + // spec, but the memory allocation is just proportional to the + // size of the input spec. + SECAGG_ASSIGN_OR_RETURN(std::unique_ptr codec, + CreateFlatHistogramCodec(input_spec)); + if (!input_spec.metric_vector_specs().empty()) { + std::string first_metric = + input_spec.metric_vector_specs(0).vector_name(); + SECAGG_ASSIGN_OR_RETURN(size_t length, + codec->GetEncodedVectorLength(first_metric)); + if (length > max_flattened_domain_size) { + return absl::InvalidArgumentError( + "Flat histogram bin count exceeds maximum threshold."); + } + } + return absl::OkStatus(); + } }; } // namespace willow diff --git a/willow/input_encoding/codec_bindings.cc b/willow/input_encoding/codec_bindings.cc index d076527..c856e68 100644 --- a/willow/input_encoding/codec_bindings.cc +++ b/willow/input_encoding/codec_bindings.cc @@ -23,7 +23,6 @@ #include "pybind11/stl.h" #include "pybind11_abseil/status_casters.h" #include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" #include "willow/proto/willow/input_spec.pb.h" namespace py = pybind11; @@ -44,8 +43,28 @@ PYBIND11_MODULE(codec_bindings, m) { item.second.cast(); } return self.ValidateExampleQuery(cpp_map); - }); + }) + .def("GetEncodedVectorLength", &Codec::GetEncodedVectorLength, + py::arg("metric_name")); + // Creates a FlatHistogramCodec from a serialized InputSpec. + // + // Remark: 'reference_time' is not exposed to Python because these bindings + // are currently only used to validate example queries (ValidateExampleQuery), + // which does not require encoding or decoding. + m.def( + "CreateFlatHistogramCodec", + [](const std::string& serialized_input_spec) + -> absl::StatusOr> { + InputSpec input_spec; + if (!input_spec.ParseFromString(serialized_input_spec)) { + return absl::InvalidArgumentError("Failed to parse InputSpec"); + } + return Codec::CreateFlatHistogramCodec(input_spec); + }, + py::arg("serialized_input_spec")); + + // DEPRECATED: Use CreateFlatHistogramCodec instead. m.def( "CreateExplicitCodec", [](const std::string& serialized_input_spec) @@ -54,10 +73,12 @@ PYBIND11_MODULE(codec_bindings, m) { if (!input_spec.ParseFromString(serialized_input_spec)) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } - return CodecFactory::CreateExplicitCodec(input_spec); + return Codec::CreateFlatHistogramCodec(input_spec); }, py::arg("serialized_input_spec")); + // DEPRECATED: Use CreateFlatHistogramCodec and Codec.GetEncodedVectorLength + // instead. m.def( "ValidateExplicitCodecInputSpec", [](const std::string& serialized_input_spec, @@ -67,9 +88,9 @@ PYBIND11_MODULE(codec_bindings, m) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } if (max_flattened_domain_size == 0) { - return CodecFactory::ValidateExplicitCodecInputSpec(input_spec); + return Codec::ValidateExplicitCodecInputSpec(input_spec); } else { - return CodecFactory::ValidateExplicitCodecInputSpec( + return Codec::ValidateExplicitCodecInputSpec( input_spec, max_flattened_domain_size); } }, diff --git a/willow/input_encoding/codec_bindings_test.py b/willow/input_encoding/codec_bindings_test.py index 7d77fc3..959e83d 100644 --- a/willow/input_encoding/codec_bindings_test.py +++ b/willow/input_encoding/codec_bindings_test.py @@ -31,11 +31,19 @@ def setUp(self): ], group_by_vector_specs=[ input_spec_pb2.InputSpec.InputVectorSpec( - vector_name="group1", data_type=input_spec_pb2.InputSpec.STRING + vector_name="group1", + data_type=input_spec_pb2.InputSpec.STRING, + domain_spec=input_spec_pb2.InputSpec.DomainSpec( + string_values=input_spec_pb2.InputSpec.StringValues( + values=["a"] + ) + ), ) ], ) - self.codec = codec_bindings.CreateExplicitCodec(spec.SerializeToString()) + self.codec = codec_bindings.CreateFlatHistogramCodec( + spec.SerializeToString() + ) def test_validate_example_query_success(self): query_specs = { @@ -64,6 +72,15 @@ def test_validate_example_query_vector_not_found(self): ): self.codec.ValidateExampleQuery(query_specs) + def test_get_encoded_vector_length_success(self): + self.assertEqual(self.codec.GetEncodedVectorLength("metric1"), 1) + + def test_get_encoded_vector_length_metric_not_found(self): + with self.assertRaisesRegex( + py_status.StatusNotOk, "not found in input spec" + ): + self.codec.GetEncodedVectorLength("unknown_metric") + if __name__ == "__main__": unittest.main() diff --git a/willow/input_encoding/codec_factory.h b/willow/input_encoding/codec_factory.h index 5ef4bb6..d4b874b 100644 --- a/willow/input_encoding/codec_factory.h +++ b/willow/input_encoding/codec_factory.h @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,33 +14,47 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ + #include #include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/time/time.h" #include "willow/input_encoding/codec.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { namespace willow { -// The maximum size of the Cartesian product of domains for string features. -constexpr size_t kMaxGlobalOutputDomainSize = 1000000; - -// Factory class that constructs non-copyable instances of children classes of -// Codec. -class CodecFactory { +class [[deprecated("Use Codec class static methods instead")]] CodecFactory { public: - // Creates an instance of ExplicitCodec. + // Creates an instance of FlatHistogramCodec. + // + // `reference_time` is used as the time reference for time domains: + // - If the codec is used for encoding, `reference_time` represents the client + // encoding time used to evaluate staleness. Must be set if the input spec + // contains any TimeDomain and the codec is used for encoding. + // - If the codec is used for decoding, `reference_time` represents the + // decoding anchor time (e.g., start of the current aggregation window) + // used to align the modular periods back to absolute timestamps. Must be + // set if the input spec contains any TimeDomain and the codec is used for + // decoding. static absl::StatusOr> CreateExplicitCodec( - InputSpec input_spec); + InputSpec input_spec, + std::optional reference_time = std::nullopt) { + return Codec::CreateFlatHistogramCodec(std::move(input_spec), + reference_time); + } - // Check that the combined size of the string domains is less than the - // maximum allowed size. static absl::Status ValidateExplicitCodecInputSpec( const InputSpec& input_spec, - size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize); + size_t max_flattened_domain_size = kMaxFlatHistogramBins) { + return Codec::ValidateExplicitCodecInputSpec(input_spec, + max_flattened_domain_size); + } }; } // namespace willow diff --git a/willow/input_encoding/codec_test.cc b/willow/input_encoding/codec_test.cc new file mode 100644 index 0000000..5733047 --- /dev/null +++ b/willow/input_encoding/codec_test.cc @@ -0,0 +1,966 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/input_encoding/codec.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "ffi_utils/status_matchers.h" +#include "gmock/gmock.h" +#include "google/protobuf/duration.pb.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/input_spec.pb.h" +#include "willow/testing_utils/testing_utils.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +using ::secure_aggregation::secagg_internal::IsOkAndHolds; +using ::secure_aggregation::secagg_internal::StatusIs; + +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +TEST(CodecTest, ValidateInputAndSpecLengthMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1, 2, 3}; + GroupData group_by_data; + group_by_data["feature1"] = {"a", "b", "a"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + // Missing group_by_vector_specs for "feature1" + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Key feature1 found in group_by_data but not " + "in input_spec."))); +} + +TEST(CodecTest, ValidateInputAndSpecTypeMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1, 2, 3}; + GroupData group_by_data; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::STRING); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type mismatch for key metric1"))); +} + +TEST(CodecTest, ValidateInputAndSpecEmptyInputData) { + MetricData metric_data; + GroupData group_by_data; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Metric data cannot be empty."))); +} + +TEST(CodecTest, ValidateInputAndSpecDomainValueNotFound) { + MetricData metric_data; + metric_data["metric1"] = {1}; + GroupData group_by_data; + group_by_data["feature1"] = {"c"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "b"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Domain mismatch for key feature1"))); +} + +TEST(CodecTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1, 2, 3}; + metric_data["metric2"] = {1, 2}; + GroupData group_by_data; + InputSpec input_spec; + auto* metric_spec1 = input_spec.add_metric_vector_specs(); + metric_spec1->set_vector_name("metric1"); + metric_spec1->set_data_type(InputSpec::INT64); + auto* metric_spec2 = input_spec.add_metric_vector_specs(); + metric_spec2->set_vector_name("metric2"); + metric_spec2->set_data_type(InputSpec::INT64); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must have the same length"))); +} + +TEST(CodecTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1, 2, 3}; + GroupData group_by_data; + group_by_data["feature1"] = {"a", "b"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "b"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must have the same length"))); +} + +TEST(CodecTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1, 2, 3}; + GroupData group_by_data; + group_by_data["feature1"] = {"a", "b", "c"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "b"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Domain mismatch for key feature1: " + "group_by_data value c not found in domain"))); +} + +TEST(CodecTest, ValidateInputAndSpecInputKeyNotInSpec) { + MetricData metric_data; + metric_data["metric1"] = {1}; + metric_data["metric2"] = {2}; + GroupData group_by_data; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); + // Missing metric_vector_specs for "metric2" + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Key metric2 found in metric_data but not in " + "input_spec."))); +} + +TEST(CodecTest, ValidateInputAndSpecGroupByKeyNotInSpec) { + MetricData metric_data; + metric_data["metric1"] = {1}; + GroupData group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); + // Missing group_by_vector_specs for "feature1" + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Key feature1 found in group_by_data but not " + "in input_spec."))); +} + +TEST(CodecTest, ValidateInputAndSpecGroupByTypeMismatch) { + MetricData metric_data; + metric_data["metric1"] = {1}; + GroupData group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::INT64); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "y"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type mismatch for key feature1"))); +} + +TEST(CodecTest, ValidateInputAndSpecMaxFlatHistogramBinsExceeded) { + MetricData metric_data; + metric_data["metric1"] = {1}; + GroupData group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "1, 2, 3"); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + for (int i = 0; i < 1000000; ++i) { + input_spec.mutable_group_by_vector_specs(0) + ->mutable_domain_spec() + ->mutable_string_values() + ->add_values(std::to_string(i)); + } + + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Flat histogram bin count exceeds"))); +} + +TEST(CodecTest, ValidateInputAndSpecCustomMaxFlatHistogramBins) { + MetricData metric_data; + metric_data["metric1"] = {1}; + GroupData group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("feature1"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "b"); + // Domain size is 2. + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec, 1), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Flat histogram bin count exceeds"))); + SECAGG_EXPECT_OK(Codec::ValidateExplicitCodecInputSpec(input_spec, 2)); +} + +TEST(CodecTest, ValidateInputSpecFlattenedDomainSize) { + InputSpec input_spec; + + // Metric 1 + auto* metric_spec1 = input_spec.add_metric_vector_specs(); + metric_spec1->set_vector_name("metric1"); + metric_spec1->set_data_type(InputSpec::INT64); + + // Metric 2 + auto* metric_spec2 = input_spec.add_metric_vector_specs(); + metric_spec2->set_vector_name("metric2"); + metric_spec2->set_data_type(InputSpec::INT64); + + // Group-by 1: Time Domain (size = 5 + 1 = 6) + auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); + group_by_spec1->set_vector_name("time"); + group_by_spec1->set_data_type(InputSpec::STRING); + auto* time_domain = group_by_spec1->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(5); // 5 days + time_domain->set_timezone("UTC"); + time_domain->set_format(absl::RFC3339_full); + + // Group-by 2: String values domain (size = 3) + auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); + group_by_spec2->set_vector_name("country"); + group_by_spec2->set_data_type(InputSpec::STRING); + auto* string_values = + group_by_spec2->mutable_domain_spec()->mutable_string_values(); + string_values->add_values("US"); + string_values->add_values("CA"); + string_values->add_values("MX"); + + // Total flattened domain size is (num_periods + 1) * string_values_size = 6 * + // 3 = 18. Passing a smaller limit should fail. + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec, 17), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Flat histogram bin count exceeds"))); + SECAGG_EXPECT_OK(Codec::ValidateExplicitCodecInputSpec(input_spec, 18)); +} + +TEST(CodecTest, EncodeSimpleGroupBy) { + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data = CreateTestMetricData(); + GroupData group_by_data = CreateTestGroupData(); + + // group_by keys are sorted: "country", "lang" + // value_to_index_maps["country"]: {"CA":0, "GB":1, "MX":2, "US":3} + // value_to_index_maps["lang"]: {"en":0, "es":1} + + // Row 0: country=US(3), lang=en(0). metric1=10. + // combo_index = 3*2 + 0 = 6 + // Row 1: country=CA(0), lang=es(1). metric1=20. + // combo_index = 0*2 + 1 = 1 + // Row 2: country=US(3), lang=es(1). metric1=5. + // combo_index = 3*2 + 1 = 7 + + // Expected histogram for metric1: + // Index 0 (CA, en): 0 + // Index 1 (CA, es): 20 + // Index 2 (GB, en): 0 + // Index 3 (GB, es): 0 + // Index 4 (MX, en): 0 + // Index 5 (MX, es): 0 + // Index 6 (US, en): 10 + // Index 7 (US, es): 5 + // Result: [0, 20, 0, 0, 0, 0, 10, 5] + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + IsOkAndHolds(UnorderedElementsAre( + Pair("metric1", ElementsAre(0, 20, 0, 0, 0, 0, 10, 5))))); +} + +TEST(CodecTest, EncodeTwoMetricsOneGroupBy) { + MetricData metric_data; + metric_data["metric1"] = {10, 20}; + metric_data["metric2"] = {100, 200}; + GroupData group_by_data; + group_by_data["country"] = {"US", "CA"}; + InputSpec input_spec; + auto* metric_spec1 = input_spec.add_metric_vector_specs(); + metric_spec1->set_vector_name("metric1"); + metric_spec1->set_data_type(InputSpec::INT64); + auto* metric_spec2 = input_spec.add_metric_vector_specs(); + metric_spec2->set_vector_name("metric2"); + metric_spec2->set_data_type(InputSpec::INT64); + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("country"); + group_by_spec->set_data_type(InputSpec::STRING); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "CA"); + group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( + "US"); + + // group_by keys are sorted: "country" + // value_to_index_maps["country"]: {"CA":0, "US":1} + // combinations: {0}->0, {1}->1 + + // Row 0: country=US(1), metric1=10, metric2=100. + // combo_index for {1} is 1. + // result["metric1"][1]=10, result["metric2"][1]=100 + // Row 1: country=CA(0), metric1=20, metric2=200. + // combo_index for {0} is 0. + // result["metric1"][0]=20, result["metric2"][0]=200 + + // Expected: + // metric1: [20, 10] + // metric2: [200, 100] + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + IsOkAndHolds(UnorderedElementsAre( + Pair("metric1", ElementsAre(20, 10)), + Pair("metric2", ElementsAre(200, 100))))); +} + +TEST(CodecTest, EncodeThenDecode) { + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data = CreateTestMetricData(); + GroupData group_by_data = CreateTestGroupData(); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, + encoder->Encode(group_by_data, metric_data)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, + encoder->Decode(encoded_data)); + + const GroupData& decoded_groups = decoded_data.group_data; + const MetricData& decoded_metrics = decoded_data.metric_data; + + // The decoded output is sparse and only contains rows with non-zero metrics. + // The order depends on iteration over dense vector. + // metric1 values for combo indices 1,6,7 are 20,10,5. + // The decoded result should contain 3 rows in order of combination index. + // combo 1: CA, es, metric1=20 + // combo 6: US, en, metric1=10 + // combo 7: US, es, metric1=5 + EXPECT_THAT(decoded_metrics, + UnorderedElementsAre(Pair("metric1", ElementsAre(20, 10, 5)))); + EXPECT_THAT( + decoded_groups, + UnorderedElementsAre(Pair("country", ElementsAre("CA", "US", "US")), + Pair("lang", ElementsAre("es", "en", "es")))); +} + +TEST(CodecTest, EncodeThenDecodeDataOrderDoesNotMatter) { + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data1 = CreateTestMetricData(); + GroupData group_by_data1 = CreateTestGroupData(); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder1, + Codec::CreateFlatHistogramCodec(input_spec)); + + // Note that the order of metric_data2 and group_by_data2 is different from + // metric_data1 and group_by_data1. The decoded result should be the same. + MetricData metric_data2; + metric_data2["metric1"] = {20, 10, 5}; + GroupData group_by_data2; + group_by_data2["lang"] = {"es", "en", "es"}; + group_by_data2["country"] = {"CA", "US", "US"}; + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder2, + Codec::CreateFlatHistogramCodec(input_spec)); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto encoded_data, + encoder1->Encode(group_by_data1, metric_data1)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, + encoder2->Decode(encoded_data)); + + const auto& decoded_groups = decoded_data.group_data; + const auto& decoded_metrics = decoded_data.metric_data; + + EXPECT_THAT(decoded_metrics, + UnorderedElementsAre(Pair("metric1", ElementsAre(20, 10, 5)))); + EXPECT_THAT( + decoded_groups, + UnorderedElementsAre(Pair("country", ElementsAre("CA", "US", "US")), + Pair("lang", ElementsAre("es", "en", "es")))); +} + +TEST(CodecTest, EncodeThenDecodeNoGroupBy) { + MetricData metric_data; + metric_data["metric1"] = {10, 20, 5}; + MetricData expected_metric_data = metric_data; + GroupData group_by_data; + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, + encoder->Encode(group_by_data, metric_data)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, + encoder->Decode(encoded_data)); + + const GroupData& decoded_groups = decoded_data.group_data; + const MetricData& decoded_metrics = decoded_data.metric_data; + + EXPECT_EQ(decoded_groups.size(), 0); + EXPECT_THAT(decoded_metrics, + UnorderedElementsAre(Pair("metric1", ElementsAre(10, 20, 5)))); +} + +TEST(CodecTest, EncodeWithDomainValueNotFound) { + MetricData metric_data; + metric_data["metric1"] = {10}; + GroupData group_by_data; + group_by_data["country"] = {"US"}; + group_by_data["lang"] = {"en"}; + group_by_data["feature1"] = {"c"}; // 'c' is not in the domain + + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); + group_by_spec1->set_vector_name("country"); + group_by_spec1->set_data_type(InputSpec::STRING); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "US"); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "CA"); + auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); + group_by_spec2->set_vector_name("lang"); + group_by_spec2->set_data_type(InputSpec::STRING); + group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( + "en"); + group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( + "fr"); + auto* group_by_spec3 = input_spec.add_group_by_vector_specs(); + group_by_spec3->set_vector_name("feature1"); + group_by_spec3->set_data_type(InputSpec::STRING); + group_by_spec3->mutable_domain_spec()->mutable_string_values()->add_values( + "a"); + group_by_spec3->mutable_domain_spec()->mutable_string_values()->add_values( + "b"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Domain mismatch for key feature1: " + "group_by_data value c not found in domain"))); +} + +TEST(CodecTest, CreateCodecUnsupportedDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("group1"); + group_by_spec->set_data_type(InputSpec::STRING); + + // Set unsupported interval domain on a group-by vector + group_by_spec->mutable_domain_spec()->mutable_interval()->set_min(0); + group_by_spec->mutable_domain_spec()->mutable_interval()->set_max(10); + + EXPECT_THAT(Codec::CreateFlatHistogramCodec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported domain type"))); +} + +TEST(CodecTest, GetEncodedVectorLengthSuccess) { + InputSpec input_spec = CreateTestInputSpecProto(); + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr codec, + Codec::CreateFlatHistogramCodec(input_spec)); + EXPECT_THAT(codec->GetEncodedVectorLength("metric1"), IsOkAndHolds(8)); + EXPECT_THAT(codec->GetEncodedVectorLength("unknown_metric"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in input spec"))); +} + +TEST(CodecTest, CreateCodecOverflow) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + // Add 8 domains of size 2^10. (2^10)^8 = 2^80 > 2^64, which overflows + // int64_t. + for (int j = 0; j < 8; ++j) { + auto* spec = input_spec.add_group_by_vector_specs(); + spec->set_vector_name("group" + std::to_string(j)); + spec->set_data_type(InputSpec::STRING); + for (int i = 0; i < 1024; ++i) { + spec->mutable_domain_spec()->mutable_string_values()->add_values( + std::to_string(i)); + } + } + + EXPECT_THAT( + Codec::CreateFlatHistogramCodec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("overflow"))); +} + +TEST(CodecTest, ValidateInputSpecWithTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + + // Valid spec + SECAGG_EXPECT_OK(Codec::ValidateExplicitCodecInputSpec(input_spec)); + + // Valid spec with lookback_window + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + SECAGG_EXPECT_OK(Codec::ValidateExplicitCodecInputSpec(input_spec)); + + // Invalid lookback_window <= 0 + time_domain->mutable_lookback_window()->set_seconds(0); + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("lookback_window must be > 0"))); + time_domain->clear_lookback_window(); + + // Invalid period_duration <= 0 + time_domain->mutable_period_duration()->set_seconds(0); + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("period_duration must be > 0"))); + time_domain->mutable_period_duration()->set_seconds(86400); + + // Invalid num_periods <= 0 + time_domain->set_num_periods(0); + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("num_periods must be > 0"))); + time_domain->set_num_periods(6); + + // Invalid timezone + time_domain->set_timezone("Invalid/Timezone"); + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid timezone"))); + time_domain->set_timezone("UTC"); + + // Invalid data type (must be STRING) + group_by_spec->set_data_type(InputSpec::INT64); + EXPECT_THAT(Codec::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Time domain can only be used with STRING"))); +} + +TEST(CodecTest, EncodeTimeDomainFailsWithoutEncodingTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec)); + + MetricData metric_data; + metric_data["metric1"] = {10}; + GroupData group_by_data; + group_by_data["time"] = {"2026-01-08T12:00:00Z"}; + + // Encode without reference_time should fail + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("reference_time is required"))); +} + +TEST(CodecTest, EncodeTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + time_domain->set_format(absl::RFC3339_full); + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Encoding time = Day 8 start + absl::Time encoding_time; + std::string err; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "1970-01-09T00:00:00Z", + &encoding_time, &err)) + << err; + + // Event times + std::string t1_str = + "1970-01-08T12:00:00Z"; // Valid (Day 7, 12h ago relative to Day 8 start) + std::string t2_str = "1970-01-06T12:00:00Z"; // Stale (Day 5, 2.5d ago) + std::string t3_str = "1970-01-10T12:00:00Z"; // Future (Day 9, 1.5d future) + + MetricData metric_data; + metric_data["metric1"] = {10, 20, 30}; + GroupData group_by_data; + group_by_data["time"] = {t1_str, t2_str, t3_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder, + Codec::CreateFlatHistogramCodec(input_spec, encoding_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, + encoder->Encode(group_by_data, metric_data)); + + EXPECT_THAT(encoded_data.at("metric1"), ElementsAre(0, 10, 0, 0, 0, 0, 30)); +} + +TEST(CodecTest, DecodeTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + time_domain->set_format("%Y-%m-%d"); + + // decoding_anchor_time = 1970-01-09 00:00:00 UTC (Day 8 start) + absl::Time decoding_anchor_time; + std::string err; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "1970-01-09T00:00:00Z", + &decoding_anchor_time, &err)) + << err; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + Codec::CreateFlatHistogramCodec(input_spec, decoding_anchor_time)); + + // Create a histogram with 7 buckets (num_periods + 1). + EncodedData encoded_data; + encoded_data["metric1"] = std::vector(7, 0); + encoded_data["metric1"][2] = 100; + encoded_data["metric1"][6] = 200; + + // Decode and check the timestamps. + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, + decoder->Decode(encoded_data)); + EXPECT_THAT(decoded_data.metric_data.at("metric1"), ElementsAre(100, 200)); + + // Bucket 2 maps to Day 8 since since 8 % 6 = 2 and 8 <= 8 < 8 + 6. + std::string expected_time_2 = "1970-01-09"; + // Bucket 6 is an invalid timestamp, so it maps back to the default. + std::string expected_time_6 = "1970-01-01"; + EXPECT_THAT(decoded_data.group_data.at("time"), + ElementsAre(expected_time_2, expected_time_6)); +} + +TEST(CodecTest, EncodeThenDecodeLocalTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + // Unset timezone defaults to UTC. + time_domain->set_format("%Y-%m-%d"); // No timezone in output, just civil day + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Scenario: Two local events occurring in different physical timezones, but + // representing the exact same civil day on their respective devices. + + // --- Client 1 (New York, EDT/UTC-4) --- + absl::TimeZone ny_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/New_York", &ny_tz)); + // Device encodes at local 17:34:56. + absl::Time encoding_time1; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T17:34:56-04:00", + &encoding_time1, nullptr)); + + // Event occurred 15 minutes prior to encoding. + absl::Time event_time_ny = encoding_time1 - absl::Minutes(15); + // Format using local format "%Y-%m-%d" -> "2026-05-12". + std::string t1_str = absl::FormatTime("%Y-%m-%d", event_time_ny, ny_tz); + MetricData md1; + md1["metric1"] = {10}; + GroupData gd1; + gd1["time"] = {t1_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder1, + Codec::CreateFlatHistogramCodec(input_spec, encoding_time1)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed1, encoder1->Encode(gd1, md1)); + + // --- Client 2 (Los Angeles, PDT/UTC-7) --- + absl::TimeZone la_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &la_tz)); + // Device encodes at local 17:39:12. + absl::Time encoding_time2; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T17:39:12-07:00", + &encoding_time2, nullptr)); + + // Event occurred 3 hours and 12 minutes prior to encoding + absl::Time event_time_la = + encoding_time2 - absl::Hours(3) - absl::Minutes(12); + // Format using local format "%Y-%m-%d" -> "2026-05-12". + std::string t2_str = absl::FormatTime("%Y-%m-%d", event_time_la, la_tz); + MetricData md2; + md2["metric1"] = {20}; + GroupData gd2; + gd2["time"] = {t2_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder2, + Codec::CreateFlatHistogramCodec(input_spec, encoding_time2)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed2, encoder2->Encode(gd2, md2)); + + // Decode using decoding_anchor_time = 2026-05-12T00:00:00Z UTC. + absl::Time decoding_anchor_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T00:00:00Z", + &decoding_anchor_time, nullptr)); + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + Codec::CreateFlatHistogramCodec(input_spec, decoding_anchor_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_1, decoder->Decode(ed1)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_2, decoder->Decode(ed2)); + + // The bucket maps back to 2026-05-12. + std::string expected_decoded_time = "2026-05-12"; + + EXPECT_THAT(decoded_data_1.metric_data.at("metric1"), ElementsAre(10)); + EXPECT_THAT(decoded_data_1.group_data.at("time"), + ElementsAre(expected_decoded_time)); + + EXPECT_THAT(decoded_data_2.metric_data.at("metric1"), ElementsAre(20)); + EXPECT_THAT(decoded_data_2.group_data.at("time"), + ElementsAre(expected_decoded_time)); +} + +TEST(CodecTest, EncodeThenDecodeAbsoluteTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("America/Los_Angeles"); // Output TimeZone + time_domain->set_format(absl::RFC3339_full); + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Scenario: Two local events occurring in different civil days but at the + // same absolute time. Both devices also encode at the exact same absolute + // physical time. Just past midnight ET on May 12, but still May 11 in LA. + absl::Time encoding_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T00:34:56-04:00", + &encoding_time, nullptr)); + + // Event occurred 3 minutes prior to encoding. + absl::Time event_time = encoding_time - absl::Minutes(3); + + // --- Client 1 (New York, EDT/UTC-4) --- + absl::TimeZone ny_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/New_York", &ny_tz)); + std::string t1_str = absl::FormatTime(absl::RFC3339_full, event_time, ny_tz); + MetricData md1; + md1["metric1"] = {10}; + GroupData gd1; + gd1["time"] = {t1_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder1, + Codec::CreateFlatHistogramCodec(input_spec, encoding_time)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed1, encoder1->Encode(gd1, md1)); + + // --- Client 2 (Los Angeles, PDT/UTC-7) --- + absl::TimeZone la_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &la_tz)); + std::string t2_str = absl::FormatTime(absl::RFC3339_full, event_time, la_tz); + MetricData md2; + md2["metric1"] = {20}; + GroupData gd2; + gd2["time"] = {t2_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder2, + Codec::CreateFlatHistogramCodec(input_spec, encoding_time)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed2, encoder2->Encode(gd2, md2)); + + // Decode using anchor time before the events. + absl::Time decoding_anchor_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-10T00:00:00Z", + &decoding_anchor_time, nullptr)); + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + Codec::CreateFlatHistogramCodec(input_spec, decoding_anchor_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_1, decoder->Decode(ed1)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_2, decoder->Decode(ed2)); + + // Reconstructed absolute time is 2026-05-12T00:00:00Z UTC. It's the start of + // the day in UTC because the default origin time is in UTC, and the period is + // 1 day. Formatted back using the spec's configured timezone + // (America/Los_Angeles, PDT UTC-7) and full RFC3339 format: + // 2026-05-12T00:00:00Z UTC -> "2026-05-11T17:00:00-07:00". + std::string expected_decoded_time = "2026-05-11T17:00:00-07:00"; + + EXPECT_THAT(decoded_data_1.metric_data.at("metric1"), ElementsAre(10)); + EXPECT_THAT(decoded_data_1.group_data.at("time"), + ElementsAre(expected_decoded_time)); + + EXPECT_THAT(decoded_data_2.metric_data.at("metric1"), ElementsAre(20)); + EXPECT_THAT(decoded_data_2.group_data.at("time"), + ElementsAre(expected_decoded_time)); +} + +} // namespace +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/input_encoding/explicit_codec_test.cc b/willow/input_encoding/explicit_codec_test.cc deleted file mode 100644 index eb46406..0000000 --- a/willow/input_encoding/explicit_codec_test.cc +++ /dev/null @@ -1,518 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "ffi_utils/status_matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "willow/input_encoding/codec.h" -#include "willow/input_encoding/codec_factory.h" -#include "willow/proto/willow/input_spec.pb.h" -#include "willow/testing_utils/testing_utils.h" - -namespace secure_aggregation { -namespace willow { -namespace { - -using ::secure_aggregation::secagg_internal::IsOkAndHolds; -using ::secure_aggregation::secagg_internal::StatusIs; -using ::testing::ElementsAre; -using ::testing::HasSubstr; -using ::testing::Pair; -using ::testing::UnorderedElementsAre; - -TEST(CodecFactoryTest, ValidateInputAndSpecLengthMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1, 2, 3}; - GroupData group_by_data; - group_by_data["feature1"] = {"a", "b", "a"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - // Missing group_by_vector_specs for "feature1" - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Key feature1 found in group_by_data but not " - "in input_spec."))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecTypeMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1, 2, 3}; - GroupData group_by_data; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::STRING); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type mismatch for key metric1"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecEmptyInputData) { - MetricData metric_data; - GroupData group_by_data; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Metric data cannot be empty."))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecDomainValueNotFound) { - MetricData metric_data; - metric_data["metric1"] = {1}; - GroupData group_by_data; - group_by_data["feature1"] = {"c"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "b"); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Domain mismatch for key feature1"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecInputDataVectorLengthMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1, 2, 3}; - metric_data["metric2"] = {1, 2}; - GroupData group_by_data; - InputSpec input_spec; - auto* metric_spec1 = input_spec.add_metric_vector_specs(); - metric_spec1->set_vector_name("metric1"); - metric_spec1->set_data_type(InputSpec::INT64); - auto* metric_spec2 = input_spec.add_metric_vector_specs(); - metric_spec2->set_vector_name("metric2"); - metric_spec2->set_data_type(InputSpec::INT64); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("must have the same length"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByDataVectorLengthMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1, 2, 3}; - GroupData group_by_data; - group_by_data["feature1"] = {"a", "b"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "b"); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("must have the same length"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecDomainSizeVectorLengthMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1, 2, 3}; - GroupData group_by_data; - group_by_data["feature1"] = {"a", "b", "c"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "b"); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Domain mismatch for key feature1: " - "group_by_data value c not found in domain"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { - MetricData metric_data; - metric_data["metric1"] = {1}; - metric_data["metric2"] = {2}; - GroupData group_by_data; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); - // Missing metric_vector_specs for "metric2" - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Key metric2 found in metric_data but not in " - "input_spec."))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { - MetricData metric_data; - metric_data["metric1"] = {1}; - GroupData group_by_data; - group_by_data["feature1"] = {"a"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); - // Missing group_by_vector_specs for "feature1" - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Key feature1 found in group_by_data but not " - "in input_spec."))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { - MetricData metric_data; - metric_data["metric1"] = {1}; - GroupData group_by_data; - group_by_data["feature1"] = {"a"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - metric_spec->mutable_domain_spec()->mutable_string_values()->add_values("x"); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::INT64); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "y"); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type mismatch for key feature1"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) { - MetricData metric_data; - metric_data["metric1"] = {1}; - GroupData group_by_data; - group_by_data["feature1"] = {"a"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - metric_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "1, 2, 3"); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - for (int i = 0; i < 1000000; ++i) { - input_spec.mutable_group_by_vector_specs(0) - ->mutable_domain_spec() - ->mutable_string_values() - ->add_values(std::to_string(i)); - } - - EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); -} - -TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { - MetricData metric_data; - metric_data["metric1"] = {1}; - GroupData group_by_data; - group_by_data["feature1"] = {"a"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("feature1"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "b"); - // Domain size is 2. - EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 1), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Global output domain size exceeds"))); - SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2)); -} - -TEST(CodecFactoryTest, EncodeSimpleGroupBy) { - InputSpec input_spec = CreateTestInputSpecProto(); - MetricData metric_data = CreateTestMetricData(); - GroupData group_by_data = CreateTestGroupData(); - - // group_by keys are sorted: "country", "lang" - // value_to_index_maps["country"]: {"CA":0, "GB":1, "MX":2, "US":3} - // value_to_index_maps["lang"]: {"en":0, "es":1} - - // Row 0: country=US(3), lang=en(0). metric1=10. - // combo_index = 3*2 + 0 = 6 - // Row 1: country=CA(0), lang=es(1). metric1=20. - // combo_index = 0*2 + 1 = 1 - // Row 2: country=US(3), lang=es(1). metric1=5. - // combo_index = 3*2 + 1 = 7 - - // Expected histogram for metric1: - // Index 0 (CA, en): 0 - // Index 1 (CA, es): 20 - // Index 2 (GB, en): 0 - // Index 3 (GB, es): 0 - // Index 4 (MX, en): 0 - // Index 5 (MX, es): 0 - // Index 6 (US, en): 10 - // Index 7 (US, es): 5 - // Result: [0, 20, 0, 0, 0, 0, 10, 5] - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - IsOkAndHolds(UnorderedElementsAre( - Pair("metric1", ElementsAre(0, 20, 0, 0, 0, 0, 10, 5))))); -} - -TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { - MetricData metric_data; - metric_data["metric1"] = {10, 20}; - metric_data["metric2"] = {100, 200}; - GroupData group_by_data; - group_by_data["country"] = {"US", "CA"}; - InputSpec input_spec; - auto* metric_spec1 = input_spec.add_metric_vector_specs(); - metric_spec1->set_vector_name("metric1"); - metric_spec1->set_data_type(InputSpec::INT64); - auto* metric_spec2 = input_spec.add_metric_vector_specs(); - metric_spec2->set_vector_name("metric2"); - metric_spec2->set_data_type(InputSpec::INT64); - auto* group_by_spec = input_spec.add_group_by_vector_specs(); - group_by_spec->set_vector_name("country"); - group_by_spec->set_data_type(InputSpec::STRING); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); - - // group_by keys are sorted: "country" - // value_to_index_maps["country"]: {"CA":0, "US":1} - // combinations: {0}->0, {1}->1 - - // Row 0: country=US(1), metric1=10, metric2=100. - // combo_index for {1} is 1. - // result["metric1"][1]=10, result["metric2"][1]=100 - // Row 1: country=CA(0), metric1=20, metric2=200. - // combo_index for {0} is 0. - // result["metric1"][0]=20, result["metric2"][0]=200 - - // Expected: - // metric1: [20, 10] - // metric2: [200, 100] - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - IsOkAndHolds(UnorderedElementsAre( - Pair("metric1", ElementsAre(20, 10)), - Pair("metric2", ElementsAre(200, 100))))); -} - -TEST(CodecFactoryTest, EncodeThenDecode) { - InputSpec input_spec = CreateTestInputSpecProto(); - MetricData metric_data = CreateTestMetricData(); - GroupData group_by_data = CreateTestGroupData(); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - - SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, - encoder->Encode(group_by_data, metric_data)); - SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, - encoder->Decode(encoded_data)); - - const GroupData& decoded_groups = decoded_data.group_data; - const MetricData& decoded_metrics = decoded_data.metric_data; - - // The decoded output is sparse and only contains rows with non-zero metrics. - // The order depends on iteration over dense vector. - // metric1 values for combo indices 1,6,7 are 20,10,5. - // The decoded result should contain 3 rows in order of combination index. - // combo 1: CA, es, metric1=20 - // combo 6: US, en, metric1=10 - // combo 7: US, es, metric1=5 - EXPECT_THAT(decoded_metrics, - UnorderedElementsAre(Pair("metric1", ElementsAre(20, 10, 5)))); - EXPECT_THAT( - decoded_groups, - UnorderedElementsAre(Pair("country", ElementsAre("CA", "US", "US")), - Pair("lang", ElementsAre("es", "en", "es")))); -} - -TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { - InputSpec input_spec = CreateTestInputSpecProto(); - MetricData metric_data1 = CreateTestMetricData(); - GroupData group_by_data1 = CreateTestGroupData(); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder1, - CodecFactory::CreateExplicitCodec(input_spec)); - - // Note that the order of metric_data2 and group_by_data2 is different from - // metric_data1 and group_by_data1. The decoded result should be the same. - MetricData metric_data2; - metric_data2["metric1"] = {20, 10, 5}; - GroupData group_by_data2; - group_by_data2["lang"] = {"es", "en", "es"}; - group_by_data2["country"] = {"CA", "US", "US"}; - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder2, - CodecFactory::CreateExplicitCodec(input_spec)); - - SECAGG_ASSERT_OK_AND_ASSIGN(auto encoded_data, - encoder1->Encode(group_by_data1, metric_data1)); - SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, - encoder2->Decode(encoded_data)); - - const auto& decoded_groups = decoded_data.group_data; - const auto& decoded_metrics = decoded_data.metric_data; - - EXPECT_THAT(decoded_metrics, - UnorderedElementsAre(Pair("metric1", ElementsAre(20, 10, 5)))); - EXPECT_THAT( - decoded_groups, - UnorderedElementsAre(Pair("country", ElementsAre("CA", "US", "US")), - Pair("lang", ElementsAre("es", "en", "es")))); -} - -TEST(CodecFactoryTest, EncodeThenDecodeNoGroupBy) { - MetricData metric_data; - metric_data["metric1"] = {10, 20, 5}; - MetricData expected_metric_data = metric_data; - GroupData group_by_data; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - - SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, - encoder->Encode(group_by_data, metric_data)); - SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, - encoder->Decode(encoded_data)); - - const GroupData& decoded_groups = decoded_data.group_data; - const MetricData& decoded_metrics = decoded_data.metric_data; - - EXPECT_EQ(decoded_groups.size(), 0); - EXPECT_THAT(decoded_metrics, - UnorderedElementsAre(Pair("metric1", ElementsAre(10, 20, 5)))); -} - -TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { - MetricData metric_data; - metric_data["metric1"] = {10}; - GroupData group_by_data; - group_by_data["country"] = {"US"}; - group_by_data["lang"] = {"en"}; - group_by_data["feature1"] = {"c"}; // 'c' is not in the domain - - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); - group_by_spec1->set_vector_name("country"); - group_by_spec1->set_data_type(InputSpec::STRING); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); - group_by_spec2->set_vector_name("lang"); - group_by_spec2->set_data_type(InputSpec::STRING); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "en"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "fr"); - auto* group_by_spec3 = input_spec.add_group_by_vector_specs(); - group_by_spec3->set_vector_name("feature1"); - group_by_spec3->set_data_type(InputSpec::STRING); - group_by_spec3->mutable_domain_spec()->mutable_string_values()->add_values( - "a"); - group_by_spec3->mutable_domain_spec()->mutable_string_values()->add_values( - "b"); - - SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, - CodecFactory::CreateExplicitCodec(input_spec)); - - EXPECT_THAT(encoder->Encode(group_by_data, metric_data), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Domain mismatch for key feature1: " - "group_by_data value c not found in domain"))); -} - -} // namespace -} // namespace willow -} // namespace secure_aggregation diff --git a/willow/input_encoding/time_utils.cc b/willow/input_encoding/time_utils.cc new file mode 100644 index 0000000..c91c2f0 --- /dev/null +++ b/willow/input_encoding/time_utils.cc @@ -0,0 +1,167 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/input_encoding/time_utils.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +absl::Duration DurationFromProto(const google::protobuf::Duration& proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +absl::Time TimeFromProto(const google::protobuf::Timestamp& proto) { + return absl::FromUnixSeconds(proto.seconds()) + + absl::Nanoseconds(proto.nanos()); +} + +int64_t GetPeriodIndex(absl::Time t, absl::Time origin_time, + absl::Duration period_duration) { + // Calculate the exact duration elapsed since the time origin. + absl::Duration elapsed = t - origin_time; + + // Round down to the nearest period multiple, even for negative elapsed times. + // e.g. Floor(-12h, 24h) = -24h. + absl::Duration floored_elapsed = absl::Floor(elapsed, period_duration); + + // Calculate the number of elapsed periods using integer division, since + // floored_elapsed is an exact multiple of period_duration. + absl::Duration unused_remainder; + int64_t period_index = + absl::IDivDuration(floored_elapsed, period_duration, &unused_remainder); + + return period_index; +} + +} // namespace + +absl::StatusOr ParseTimeDomain( + const InputSpec::TimeDomain& proto) { + absl::TimeZone tz; + if (proto.timezone().empty()) { + tz = absl::UTCTimeZone(); + } else if (!absl::LoadTimeZone(proto.timezone(), &tz)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid timezone: ", proto.timezone())); + } + + absl::Duration period_duration = DurationFromProto(proto.period_duration()); + if (period_duration <= absl::ZeroDuration()) { + return absl::InvalidArgumentError("period_duration must be > 0"); + } + + int32_t num_periods = proto.num_periods(); + if (num_periods <= 0) { + return absl::InvalidArgumentError("num_periods must be > 0"); + } + + std::string format; + if (proto.format().empty()) { + format = absl::RFC3339_full; + } else { + format = proto.format(); + } + + absl::Time origin_time = absl::UnixEpoch(); + if (proto.has_origin_time()) { + origin_time = TimeFromProto(proto.origin_time()); + } + + absl::Duration lookback_window; + if (proto.has_lookback_window()) { + lookback_window = DurationFromProto(proto.lookback_window()); + if (lookback_window <= absl::ZeroDuration()) { + return absl::InvalidArgumentError("lookback_window must be > 0"); + } + } else { + lookback_window = period_duration * proto.num_periods() / 2; + } + + return TimeDomainInfo{ + .period_duration = period_duration, + .num_periods = num_periods, + .format = format, + .timezone = tz, + .origin_time = origin_time, + .lookback_window = lookback_window, + }; +} + +int64_t EncodeTime(absl::Time t, const TimeDomainInfo& info, + absl::Time encoding_time) { + // How many full periods have elapsed since origin_time, with handling for + // negative times. + int64_t period_index = + GetPeriodIndex(t, info.origin_time, info.period_duration); + + // Correct negative remainder to positive modulo: bucket in [0, num_periods). + int64_t bucket = period_index % info.num_periods; + if (bucket < 0) { + bucket += info.num_periods; + } + + // Future events are always invalid. + if (t > encoding_time) { + return info.num_periods; // Invalid bucket index (last index in domain) + } + + // Any event older than lookback_window relative to now is marked invalid. + // This filters out obsolete data and prevents wraparound collisions + if (t < encoding_time - info.lookback_window) { + return info.num_periods; // Invalid bucket index + } + + return bucket; +} + +absl::StatusOr DecodeTime(int64_t bucket_index, + const TimeDomainInfo& info, + absl::Time decoding_anchor_time) { + if (bucket_index < 0 || bucket_index >= info.num_periods) { + return absl::InvalidArgumentError( + absl::StrCat("Bucket index out of range: ", bucket_index, + ", must be in [0, ", info.num_periods, ")")); + } + + // anchor_period is the period index of anchor_time relative to origin_time. + int64_t anchor_period = GetPeriodIndex(decoding_anchor_time, info.origin_time, + info.period_duration); + + // We want to find k in [0, num_periods) such that: + // (anchor_period + k) mod num_periods = bucket_index + int64_t k = (bucket_index - anchor_period) % info.num_periods; + if (k < 0) { + k += info.num_periods; // Turn remainder into positive modulo + } + + // Unique period start time t in [anchor_time, anchor_time + info.num_periods + // * info.period_duration) that maps to bucket_index. + return info.origin_time + (anchor_period + k) * info.period_duration; +} + +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/input_encoding/time_utils.h b/willow/input_encoding/time_utils.h new file mode 100644 index 0000000..c85ded2 --- /dev/null +++ b/willow/input_encoding/time_utils.h @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ +#define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { + +// The default timestamp used to represent invalid or stale timestamps during +// decoding. +inline absl::Time DefaultTimestamp() { return absl::UnixEpoch(); } + +// Struct containing the parsed time domain parameters. +struct TimeDomainInfo { + absl::Duration period_duration; + int num_periods; + std::string format; + absl::TimeZone timezone; + absl::Time origin_time; + absl::Duration lookback_window; +}; + +// Parses a TimeDomain protobuf spec into a validated TimeDomainInfo struct. +// Returns an InvalidArgument status if spec fields (duration, periods, +// timezone) are invalid. +absl::StatusOr ParseTimeDomain( + const InputSpec::TimeDomain& proto); + +// Encodes a parsed absl::Time into a modular bucket index in [0, num_periods). +// Returns num_periods (the invalid/stale bucket) if the event is in the future +// or older than the lookback window relative to `encoding_time`. +int64_t EncodeTime(absl::Time t, const TimeDomainInfo& info, + absl::Time encoding_time); + +// Reconstructs the absolute start-of-period absl::Time from a bucket index. +// Expects bucket_index in [0, num_periods). Returns an InvalidArgument status +// if the bucket index is out of bounds. +absl::StatusOr DecodeTime(int64_t bucket_index, + const TimeDomainInfo& info, + absl::Time decoding_anchor_time); + +} // namespace willow +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ diff --git a/willow/input_encoding/time_utils_test.cc b/willow/input_encoding/time_utils_test.cc new file mode 100644 index 0000000..59d9983 --- /dev/null +++ b/willow/input_encoding/time_utils_test.cc @@ -0,0 +1,159 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/input_encoding/time_utils.h" + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "ffi_utils/status_matchers.h" +#include "gmock/gmock.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +using ::secure_aggregation::secagg_internal::StatusIs; +using ::testing::HasSubstr; + +TEST(TimeDomainHelperTest, ParseTimeDomainWithLookback) { + InputSpec::TimeDomain proto; + proto.mutable_period_duration()->set_seconds(86400); // 1 day + proto.set_num_periods(6); // 6 days + proto.set_timezone("America/Los_Angeles"); + proto.set_format("%Y-%m-%d"); + proto.mutable_origin_time()->set_seconds(1234567890); + proto.mutable_lookback_window()->set_seconds(7 * 24 * 3600); // 7 days + + SECAGG_ASSERT_OK_AND_ASSIGN(TimeDomainInfo info, ParseTimeDomain(proto)); + + EXPECT_EQ(info.period_duration, absl::Hours(24)); + EXPECT_EQ(info.num_periods, 6); + EXPECT_EQ(info.format, "%Y-%m-%d"); + + absl::TimeZone expected_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &expected_tz)); + EXPECT_EQ(info.timezone, expected_tz); + EXPECT_EQ(info.origin_time, absl::FromUnixSeconds(1234567890)); + EXPECT_EQ(info.lookback_window, absl::Hours(7 * 24)); +} + +TEST(TimeDomainHelperTest, ParseTimeDomainDefaultLookback) { + InputSpec::TimeDomain proto; + proto.mutable_period_duration()->set_seconds(86400); // 1 day + proto.set_num_periods(6); // 6 days + proto.set_timezone("UTC"); + + SECAGG_ASSERT_OK_AND_ASSIGN(TimeDomainInfo info, ParseTimeDomain(proto)); + // Default lookback is 6 * 1 day / 2 = 3 days (72 hours). + EXPECT_EQ(info.lookback_window, absl::Hours(3 * 24)); +} + +TEST(TimeDomainHelperTest, ParseTimeDomainInvalid) { + InputSpec::TimeDomain proto; + // Missing period_duration (default 0) + proto.set_num_periods(6); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("period_duration must be > 0"))); + + proto.mutable_period_duration()->set_seconds(86400); + proto.set_num_periods(0); // Invalid periods + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("num_periods must be > 0"))); + + proto.set_num_periods(6); + proto.set_timezone("Invalid/Timezone"); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid timezone"))); + proto.set_timezone("UTC"); + + // Invalid lookback_window <= 0 + proto.mutable_lookback_window()->set_seconds(0); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("lookback_window must be > 0"))); +} + +TEST(TimeDomainHelperTest, EncodeTimeWithCustomLookback) { + TimeDomainInfo info{ + .period_duration = absl::Hours(24), + .num_periods = 6, + .format = "%Y-%m-%d", + .timezone = absl::UTCTimeZone(), + .origin_time = absl::UnixEpoch(), + .lookback_window = + absl::Hours(24 * 2), // Events older than 2 days are stale + }; + + absl::Time encoding_time = + absl::UnixEpoch() + absl::Hours(24 * 10); // Day 10 (1970-01-11) + + // Valid event: 1 day ago + absl::Time t1 = encoding_time - absl::Hours(24 * 1); + EXPECT_EQ(EncodeTime(t1, info, encoding_time), 3); + + // Stale event: 3 days ago + absl::Time t2 = encoding_time - absl::Hours(24 * 3); + EXPECT_EQ(EncodeTime(t2, info, encoding_time), 6); // Invalid bucket + + // Future event: 1 day in future. + absl::Time t3 = encoding_time + absl::Hours(24 * 1); + EXPECT_EQ(EncodeTime(t3, info, encoding_time), 6); // Invalid bucket +} + +TEST(TimeDomainHelperTest, DecodeTime) { + TimeDomainInfo info{ + .period_duration = absl::Hours(24), + .num_periods = 6, + .format = "%Y-%m-%d", + .timezone = absl::UTCTimeZone(), + .origin_time = absl::UnixEpoch(), + }; + + // anchor_time = Day 8 + absl::Time anchor_time = absl::UnixEpoch() + absl::Hours(24 * 8); + + // Decode bucket 3, expect Day 9 because 9 mod 6 = 3 and 8 <= 9 < 8 + 6 + absl::Time expected1 = absl::UnixEpoch() + absl::Hours(24 * 9); + + SECAGG_ASSERT_OK_AND_ASSIGN(absl::Time decoded1, + DecodeTime(3, info, anchor_time)); + EXPECT_EQ(decoded1, expected1); + EXPECT_GE(decoded1, anchor_time); + EXPECT_LT(decoded1, anchor_time + absl::Hours(24 * 6)); + + // Decode bucket 1, expect Day 13 because 13 mod 6 = 1 and 8 <= 13 < 8 + 6. + absl::Time expected2 = absl::UnixEpoch() + absl::Hours(24 * 13); + + SECAGG_ASSERT_OK_AND_ASSIGN(absl::Time decoded2, + DecodeTime(1, info, anchor_time)); + EXPECT_EQ(decoded2, expected2); + EXPECT_GE(decoded2, anchor_time); + EXPECT_LT(decoded2, anchor_time + absl::Hours(24 * 6)); + + // Invalid bucket index raises an error. + EXPECT_THAT(DecodeTime(6, info, anchor_time), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Bucket index out of range"))); +} + +} // namespace +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index f543f23..12672b3 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -84,6 +84,10 @@ rust_proto_library( proto_library( name = "input_spec_proto", srcs = ["input_spec.proto"], + deps = [ + "@protobuf//:duration_proto", + "@protobuf//:timestamp_proto", + ], ) cc_proto_library( diff --git a/willow/proto/willow/input_spec.proto b/willow/proto/willow/input_spec.proto index 2d7013a..e24c00f 100644 --- a/willow/proto/willow/input_spec.proto +++ b/willow/proto/willow/input_spec.proto @@ -16,6 +16,9 @@ syntax = "proto3"; package secure_aggregation.willow; +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + option java_multiple_files = true; option java_outer_classname = "InputSpecProto"; @@ -89,6 +92,70 @@ message InputSpec { repeated string values = 1; } + // A domain for timestamps represented in a modular way, yielding a small + // closed domain instead of an unbounded number of timestamps. + // + // Input timestamps t (as strings) are parsed and mapped to a period index in + // [0, num_periods) using: + // period_index = ((t - origin_time) / period_duration) mod num_periods + // Stale events (older than lookback_window relative to Now()) or future + // events are mapped to a special invalid bucket (index num_periods). During + // decoding, timestamps are reconstructed relative to a dynamic anchor_time. + // + // Inputs encoded more than (num_periods * period_duration - lookback_window) + // apart might collide, i.e. contain different timestamps mapping to the same + // bucket. Shorter lookback windows allow for longer retention before + // aggregation and decoding. + // + // Example: inputs from Device 1 and Device 3 should not be decoded together. + // + // <----- num_periods * period_duration -----> + // --+------|------|------|------|------|------|------|------|------|--> Time + // | B3 | B4 | B5 | B0 | B1 | B2 | B3 | B4 | B5 | Bucket + // | Day 1| Day 2| Day 3| Day 4| Day 5| Day 6| Day 7| Day 8| Day 9| Day + // +------+------+------+------+------+------+------+------+------+ + // ^ ^ + // anchor_time anchor_time + num_periods * period_duration + // + // Device 1: + // t_1 + // v + // [=============] (lookback_window = 2 days) + // + // Device 2: + // t_2 + // v + // [=============] + // Device 3: + // t_3 + // v + // [=============] + // + + message TimeDomain { + // The duration of each period. + google.protobuf.Duration period_duration = 1; + // The number of periods. + int32 num_periods = 2; + // The format string used for parsing and formatting (absl::ParseTime + // compatible). If empty, RFC 3339 is used. + string format = 3; + // The timezone to use if the timestamp string does not contain timezone + // info (e.g., "America/Los_Angeles", "UTC"). Must be absl::LoadTimeZone + // compatible, i.e. a valid IANA timezone name. If not set, "UTC" is used. + // If the timestamp string contains a UTC offset specification (e.g., %z), + // the timezone offset in the string takes precedence. + string timezone = 4; + // The origin time from which the periods are calculated. + // If not set, the Unix Epoch (1970-01-01T00:00:00Z) is used. + google.protobuf.Timestamp origin_time = 5; + // The lookback window defining how far back in time from the encoding + // time we accept events. Events older than this duration, or + // events in the future, will be mapped to the invalid bucket. + // If not set, defaults to `num_periods * period_duration / 2`. + google.protobuf.Duration lookback_window = 6; + } + // A new message type to represent the domain specification. message DomainSpec { oneof domain_type { @@ -97,6 +164,9 @@ message InputSpec { // Defines a domain as an interval of values. Interval interval = 2; + + // Defines a domain for timestamps. + TimeDomain time = 3; } }