diff --git a/willow/api/BUILD b/willow/api/BUILD index f54b9d8..e4b477f 100644 --- a/willow/api/BUILD +++ b/willow/api/BUILD @@ -195,6 +195,7 @@ cc_test( srcs = ["client_test.cc"], deps = [ ":client_cc", + "@protobuf//:duration_cc_proto", "@googletest//:gtest_main", "@abseil-cpp//absl/status", "//ffi_utils:status_matchers", diff --git a/willow/api/client.cc b/willow/api/client.cc index b36ed42..20c9f66 100644 --- a/willow/api/client.cc +++ b/willow/api/client.cc @@ -29,6 +29,7 @@ #include "include/cxx.h" #include "willow/api/client.rs.h" #include "willow/input_encoding/codec.h" +#include "willow/input_encoding/codec_factory.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/input_spec.pb.h" @@ -49,14 +50,19 @@ absl::StatusOr CreateAggregationConfig( // of group-by domains. int64_t flattened_domain_size = 1; for (const auto& group_by_spec : input_spec_proto.group_by_vector_specs()) { - if (group_by_spec.domain_spec().string_values().values_size() == 0) { - return absl::InvalidArgumentError(absl::StrCat( - "Missing domain, invalid domain type (must be StringValues), or " - "empty string_values for group by vector: ", - group_by_spec.vector_name())); + if (!group_by_spec.domain_spec().has_time() && + !group_by_spec.domain_spec().has_string_values()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported domain type for group-by vector: ", + group_by_spec.vector_name())); } - flattened_domain_size *= - group_by_spec.domain_spec().string_values().values_size(); + int domain_size = willow::CodecFactory::GetDomainSize(group_by_spec); + if (domain_size <= 0) { + return absl::InvalidArgumentError( + absl::StrCat("Empty or invalid domain size for group by vector: ", + group_by_spec.vector_name())); + } + flattened_domain_size *= domain_size; } // Build VectorConfig (length and bound) for each metric. 
for (const auto& metric_spec : input_spec_proto.metric_vector_specs()) { diff --git a/willow/api/client_test.cc b/willow/api/client_test.cc index 05f5956..4b06356 100644 --- a/willow/api/client_test.cc +++ b/willow/api/client_test.cc @@ -23,6 +23,7 @@ #include "absl/status/status.h" #include "ffi_utils/status_matchers.h" #include "gmock/gmock.h" +#include "google/protobuf/duration.pb.h" #include "gtest/gtest.h" #include "willow/input_encoding/codec.h" #include "willow/input_encoding/codec_factory.h" @@ -185,6 +186,33 @@ TEST(WillowShellClientTest, CreateAggregationConfigDefaultBound) { EXPECT_EQ(vector_configs.at("metric1").bound(), default_bound); } +TEST(WillowShellClientTest, CreateAggregationConfigSuccessWithTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_interval()->set_max(100); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); + time_domain->set_timezone("UTC"); + + SECAGG_ASSERT_OK_AND_ASSIGN( + AggregationConfigProto config, + CreateAggregationConfig(input_spec, /*key_id=*/"test", + /*max_number_of_clients=*/10)); + + // Expected flattened_domain_size = num_periods + 1 = 7 + const auto& vector_configs = config.vector_configs(); + ASSERT_TRUE(vector_configs.contains("metric1")); + EXPECT_EQ(vector_configs.at("metric1").length(), 7); + EXPECT_EQ(vector_configs.at("metric1").bound(), 100); +} + TEST(WillowShellClientTest, CreateAggregationConfigFailsOnEmptyDomain) { InputSpec input_spec = CreateTestInputSpecProto(); ASSERT_GT(input_spec.group_by_vector_specs_size(), 0); diff --git 
a/willow/input_encoding/BUILD b/willow/input_encoding/BUILD index 48bd377..78784d8 100644 --- a/willow/input_encoding/BUILD +++ b/willow/input_encoding/BUILD @@ -23,6 +23,25 @@ package( default_visibility = ["//visibility:public"], ) +cc_library( + name = "time_utils", + srcs = [ + "time_utils.cc", + ], + hdrs = [ + "time_utils.h", + ], + deps = [ + "@protobuf//:duration_cc_proto", + "@protobuf//:timestamp_cc_proto", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/time", + "//willow/proto/willow:input_spec_cc_proto", + ], +) + cc_library( name = "codec", srcs = [ @@ -33,12 +52,15 @@ cc_library( "codec_factory.h", ], deps = [ + ":time_utils", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/memory", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/time", + "//ffi_utils:status_macros", "//willow/proto/willow:input_spec_cc_proto", ], ) @@ -48,15 +70,33 @@ cc_test( srcs = ["explicit_codec_test.cc"], deps = [ ":codec", + ":time_utils", + "@protobuf//:duration_cc_proto", "@googletest//:gtest_main", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/status", + "@abseil-cpp//absl/time", "//ffi_utils:status_matchers", "//willow/proto/willow:input_spec_cc_proto", "//willow/testing_utils:testing_utils_cc", ], ) +cc_test( + name = "time_utils_test", + srcs = ["time_utils_test.cc"], + deps = [ + ":time_utils", + "@protobuf//:duration_cc_proto", + "@protobuf//:timestamp_cc_proto", + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/time", + "//ffi_utils:status_matchers", + "//willow/proto/willow:input_spec_cc_proto", + ], +) + pytype_pybind_extension( name = "codec_bindings", srcs = ["codec_bindings.cc"], diff --git a/willow/input_encoding/codec_bindings.cc b/willow/input_encoding/codec_bindings.cc index d076527..3be6caa 100644 --- 
a/willow/input_encoding/codec_bindings.cc +++ b/willow/input_encoding/codec_bindings.cc @@ -46,6 +46,12 @@ PYBIND11_MODULE(codec_bindings, m) { return self.ValidateExampleQuery(cpp_map); }); + // Creates an ExplicitCodec from a serialized InputSpec. + // + // Remark: 'encoding_time' and 'decoding_anchor_time' are not exposed to + // Python because these bindings are currently only used to validate example + // queries (ValidateExampleQuery), which does not require encoding or + // decoding. m.def( "CreateExplicitCodec", [](const std::string& serialized_input_spec) diff --git a/willow/input_encoding/codec_factory.cc b/willow/input_encoding/codec_factory.cc index b6e8ec3..167d20a 100644 --- a/willow/input_encoding/codec_factory.cc +++ b/willow/input_encoding/codec_factory.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,10 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "ffi_utils/status_macros.h" #include "willow/input_encoding/codec.h" +#include "willow/input_encoding/time_utils.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { @@ -87,6 +91,12 @@ class ExplicitCodecImpl : public Codec { // The indices within their respective domain of each group-by key. absl::flat_hash_map group_by_domain_indices_; + std::optional encoding_time_; + std::optional decoding_anchor_time_; + + // Map of group-by key names to their TimeDomainInfo if they are time domains. 
+ absl::flat_hash_map time_domains_; + absl::Status ValidateData(const GroupData& group_by_data, const MetricData& metric_data) const; @@ -94,35 +104,52 @@ class ExplicitCodecImpl : public Codec { size_t GetCombinedIndex(const std::vector& indices) const; + int EncodeGroupValue(const std::string& group_name, + const std::string& value) const; + + absl::StatusOr DecodeGroupValue(const std::string& group_name, + int bucket_index) const; + explicit ExplicitCodecImpl( InputSpec input_spec, absl::flat_hash_map group_by_spec_map, - absl::flat_hash_map metric_spec_map) + absl::flat_hash_map metric_spec_map, + std::optional encoding_time, + std::optional decoding_anchor_time) : input_spec_(std::move(input_spec)), group_by_spec_map_(std::move(group_by_spec_map)), - metric_spec_map_(std::move(metric_spec_map)) { + metric_spec_map_(std::move(metric_spec_map)), + encoding_time_(encoding_time), + decoding_anchor_time_(decoding_anchor_time) { group_by_keys_.reserve(group_by_spec_map_.size()); - // Compute sorted group-by keys. + // Compute sorted group-by keys and initialize domains. for (const auto& [key, spec] : group_by_spec_map_) { group_by_keys_.push_back(spec->vector_name()); - // Precompute indices into domains to allow efficient lookups. - const auto& domain = spec->domain_spec().string_values(); - for (int i = 0; i < domain.values_size(); ++i) { - group_by_domain_indices_[GroupDomainKey{spec->vector_name(), - domain.values(i)}] = i; + if (spec->domain_spec().has_time()) { + // Validation in CodecFactory ensures this succeeds. + auto ts_info_or = ParseTimeDomain(spec->domain_spec().time()); + CHECK(ts_info_or.ok()); + time_domains_[spec->vector_name()] = *std::move(ts_info_or); + } else { + // Precompute indices into domains to allow efficient lookups for string + // domains. 
+ const auto& domain = spec->domain_spec().string_values(); + for (int i = 0; i < domain.values_size(); ++i) { + group_by_domain_indices_[GroupDomainKey{spec->vector_name(), + domain.values(i)}] = i; + } } } std::sort(group_by_keys_.begin(), group_by_keys_.end()); - // Compute the sizes of the string domains for each group-by key. + // Compute the sizes of the string/timestamp domains for each group-by key. flattened_domain_size_ = 1; group_by_domain_sizes_.reserve(group_by_keys_.size()); for (const auto& key : group_by_keys_) { auto spec_it = group_by_spec_map_.find(key); CHECK(spec_it != group_by_spec_map_.end()); // We expect the key to be present. - int domain_size = - spec_it->second->domain_spec().string_values().values_size(); + int domain_size = CodecFactory::GetDomainSize(*spec_it->second); group_by_domain_sizes_.push_back(domain_size); flattened_domain_size_ *= domain_size; } @@ -131,6 +158,7 @@ class ExplicitCodecImpl : public Codec { }; absl::Status ExplicitCodecImpl::ValidateData( + const GroupData& group_by_data, const MetricData& metric_data) const { // Check that all vectors in metric_data and group_by_data are present in // metric_spec_map and group_by_spec_map_, respectively. This ensures that the @@ -207,11 +235,24 @@ absl::Status ExplicitCodecImpl::ValidateData( } // Check that all values in group_by_data are in the domain provided in the // input_spec. 
- for (const auto& d : data) { - if (!group_by_domain_indices_.contains(GroupDomainKey{name, d})) { - return absl::InvalidArgumentError( - absl::StrCat("Domain mismatch for key ", name, - ": group_by_data value ", d, " not found in domain.")); + if (spec->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(name); + for (const auto& d : data) { + absl::Time t; + std::string err; + if (!absl::ParseTime(ts_info.format, d, ts_info.timezone, &t, &err)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse timestamp ", d, " with format ", + ts_info.format, ": ", err)); + } + } + } else { + for (const auto& d : data) { + if (!group_by_domain_indices_.contains(GroupDomainKey{name, d})) { + return absl::InvalidArgumentError(absl::StrCat( + "Domain mismatch for key ", name, ": group_by_data value ", d, + " not found in domain.")); + } } } } @@ -251,11 +292,67 @@ size_t ExplicitCodecImpl::GetCombinedIndex( return combined_index; } +// Helper function to encode a single group-by value into the corresponding +// index in its domain. +int ExplicitCodecImpl::EncodeGroupValue(const std::string& group_name, + const std::string& value) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + if (spec_it->second->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(group_name); + absl::Time t; + // Already validated in ValidateData. + CHECK( + absl::ParseTime(ts_info.format, value, ts_info.timezone, &t, nullptr)); + // encoding_time_ is guaranteed to have a value if there is a time domain + // due to the validation check at the start of Encode(). + return EncodeTime(t, ts_info, *encoding_time_); + } else { + auto it = group_by_domain_indices_.find(GroupDomainKey{group_name, value}); + // ValidateData ensures the key exists in the domain. 
+ CHECK(it != group_by_domain_indices_.end()); + return it->second; + } +} + +// Helper function to decode a bucket index back to its original string +// representation. +absl::StatusOr ExplicitCodecImpl::DecodeGroupValue( + const std::string& group_name, int bucket_index) const { + auto spec_it = group_by_spec_map_.find(group_name); + CHECK(spec_it != group_by_spec_map_.end()); + + if (spec_it->second->domain_spec().has_time()) { + const auto& ts_info = time_domains_.at(group_name); + if (bucket_index == ts_info.num_periods) { + return std::string(kInvalidTimestamp); + } + // decoding_anchor_time_ is guaranteed to have value if time_domains_ is not + // empty. + SECAGG_ASSIGN_OR_RETURN( + auto reconstructed_time, + DecodeTime(bucket_index, ts_info, *decoding_anchor_time_)); + return absl::FormatTime(ts_info.format, reconstructed_time, + ts_info.timezone); + } else { + const auto& domain = spec_it->second->domain_spec().string_values(); + if (bucket_index < 0 || bucket_index >= domain.values_size()) { + return absl::InvalidArgumentError( + absl::StrCat("Index ", bucket_index, " for key ", group_name, + " is out of bounds [0, ", domain.values_size(), ").")); + } + return domain.values(bucket_index); + } +} + absl::StatusOr ExplicitCodecImpl::Encode( const GroupData& group_by_data, const MetricData& metric_data) const { - if (absl::Status status = ValidateData(group_by_data, metric_data); - !status.ok()) { - return status; + SECAGG_RETURN_IF_ERROR(ValidateData(group_by_data, metric_data)); + + if (!time_domains_.empty() && !encoding_time_.has_value()) { + return absl::FailedPreconditionError( + "encoding_time is required in the constructor for encoding time " + "domains"); } absl::flat_hash_map> result; @@ -273,11 +370,8 @@ absl::StatusOr ExplicitCodecImpl::Encode( for (const auto& group_name : group_by_keys_) { auto it_data = group_by_data.find(group_name); CHECK(it_data != group_by_data.end()); - auto key = it_data->second[i]; - auto it = - 
group_by_domain_indices_.find(GroupDomainKey{group_name, key}); - // ValidateData ensures the key exists in the domain. - indices.push_back(it->second); + const std::string& group_value = it_data->second[i]; + indices.push_back(EncodeGroupValue(group_name, group_value)); } result[metric_name][GetCombinedIndex(indices)] = values[i]; } @@ -288,6 +382,13 @@ absl::StatusOr ExplicitCodecImpl::Encode( absl::StatusOr ExplicitCodecImpl::Decode( const EncodedData& encoded_data) const { + // If the input spec defines any time domain, we must have a decoding anchor + // time to reconstruct the absolute timestamps during decoding. + if (!time_domains_.empty() && !decoding_anchor_time_.has_value()) { + return absl::FailedPreconditionError( + "decoding_anchor_time is required in the constructor for decoding time " + "domains"); + } DecodedData decoded_data; if (group_by_keys_.empty()) { @@ -329,16 +430,9 @@ absl::StatusOr ExplicitCodecImpl::Decode( std::vector indices = GetIndices(i); for (int j = 0; j < group_by_keys_.size(); ++j) { const auto& key_name = group_by_keys_[j]; - auto spec_it = group_by_spec_map_.find(key_name); - CHECK(spec_it != group_by_spec_map_.end()); - const auto& domain = spec_it->second->domain_spec().string_values(); - // Check that the index is in the domain. 
- if (indices[j] >= domain.values_size()) { - return absl::InvalidArgumentError(absl::StrCat( - "Index ", indices[j], " for key ", key_name, " is out of bounds ", - "[0, ", domain.values_size(), ").")); - } - decoded_data.group_data[key_name].push_back(domain.values(indices[j])); + SECAGG_ASSIGN_OR_RETURN(std::string decoded_val, + DecodeGroupValue(key_name, indices[j])); + decoded_data.group_data[key_name].push_back(decoded_val); } for (const auto& [metric_name, values] : encoded_data) { decoded_data.metric_data[metric_name].push_back(values[i]); @@ -377,7 +471,20 @@ absl::Status CodecFactory::ValidateExplicitCodecInputSpec( const InputSpec& input_spec, size_t max_flattened_domain_size) { size_t flattened_domain_size = 1; for (const auto& spec : input_spec.group_by_vector_specs()) { - flattened_domain_size *= spec.domain_spec().string_values().values_size(); + if (spec.domain_spec().has_time()) { + if (spec.data_type() != InputSpec::STRING) { + return absl::InvalidArgumentError( + "Time domain can only be used with STRING data type."); + } + SECAGG_ASSIGN_OR_RETURN(auto ts_info, + ParseTimeDomain(spec.domain_spec().time())); + flattened_domain_size *= (ts_info.num_periods + 1); // +1 for invalid + } else if (spec.domain_spec().has_string_values()) { + flattened_domain_size *= spec.domain_spec().string_values().values_size(); + } else { + return absl::InvalidArgumentError( + "Unsupported domain type for group-by vector"); + } if (max_flattened_domain_size < flattened_domain_size) { return absl::InvalidArgumentError( "Global output domain size exceeds maximum threshold."); @@ -387,7 +494,8 @@ absl::Status CodecFactory::ValidateExplicitCodecInputSpec( } absl::StatusOr> CodecFactory::CreateExplicitCodec( - InputSpec input_spec) { + InputSpec input_spec, std::optional encoding_time, + std::optional decoding_anchor_time) { // Check that specs include at least one metric vector. 
if (input_spec.metric_vector_specs().empty()) { return absl::InvalidArgumentError( @@ -408,9 +516,28 @@ absl::StatusOr> CodecFactory::CreateExplicitCodec( absl::StrCat("Duplicate vector name: ", spec.vector_name())); } } - return absl::WrapUnique(new ExplicitCodecImpl(std::move(input_spec), - std::move(group_by_spec_map), - std::move(metric_spec_map))); + // Validate input spec (including time domains if present) + for (const auto& spec : input_spec.group_by_vector_specs()) { + if (spec.domain_spec().has_time()) { + if (spec.data_type() != InputSpec::STRING) { + return absl::InvalidArgumentError( + "Time domain can only be used with STRING data type."); + } + SECAGG_RETURN_IF_ERROR( + ParseTimeDomain(spec.domain_spec().time()).status()); + } + } + + return absl::WrapUnique(new ExplicitCodecImpl( + std::move(input_spec), std::move(group_by_spec_map), + std::move(metric_spec_map), encoding_time, decoding_anchor_time)); +} + +int CodecFactory::GetDomainSize(const InputSpec::InputVectorSpec& spec) { + if (spec.domain_spec().has_time()) { + return spec.domain_spec().time().num_periods() + 1; + } + return spec.domain_spec().string_values().values_size(); } } // namespace willow diff --git a/willow/input_encoding/codec_factory.h b/willow/input_encoding/codec_factory.h index 5ef4bb6..a7c00be 100644 --- a/willow/input_encoding/codec_factory.h +++ b/willow/input_encoding/codec_factory.h @@ -15,11 +15,16 @@ #ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ #define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_ #include +#include #include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/time/time.h" #include "willow/input_encoding/codec.h" +#include "willow/input_encoding/time_utils.h" #include "willow/proto/willow/input_spec.pb.h" namespace secure_aggregation { @@ -33,14 +38,29 @@ constexpr size_t kMaxGlobalOutputDomainSize = 1000000; class CodecFactory { public: // Creates an instance of 
ExplicitCodec. + // + // `encoding_time` is used during encoding to evaluate staleness for + // time domains. Must be set if the input spec contains any TimeDomain and the + // codec is used for encoding. + // + // `decoding_anchor_time` is used during decoding to align the modular periods + // back to absolute timestamps. It represents the start time of the current + // aggregation window. Must be set if the input spec contains any TimeDomain + // and the codec is used for decoding. static absl::StatusOr> CreateExplicitCodec( - InputSpec input_spec); + InputSpec input_spec, + std::optional encoding_time = std::nullopt, + std::optional decoding_anchor_time = std::nullopt); // Check that the combined size of the string domains is less than the // maximum allowed size. static absl::Status ValidateExplicitCodecInputSpec( const InputSpec& input_spec, size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize); + + // Returns the domain size of a group-by vector spec. If it is a TimeDomain, + // returns num_periods + 1. Otherwise, returns the number of string values. + static int GetDomainSize(const InputSpec::InputVectorSpec& spec); }; } // namespace willow diff --git a/willow/input_encoding/explicit_codec_test.cc b/willow/input_encoding/explicit_codec_test.cc index eb46406..9c3c37d 100644 --- a/willow/input_encoding/explicit_codec_test.cc +++ b/willow/input_encoding/explicit_codec_test.cc @@ -12,17 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/time/time.h" #include "ffi_utils/status_matchers.h" #include "gmock/gmock.h" +#include "google/protobuf/duration.pb.h" #include "gtest/gtest.h" #include "willow/input_encoding/codec.h" #include "willow/input_encoding/codec_factory.h" +#include "willow/input_encoding/time_utils.h" #include "willow/proto/willow/input_spec.pb.h" #include "willow/testing_utils/testing_utils.h" @@ -32,6 +37,7 @@ namespace { using ::secure_aggregation::secagg_internal::IsOkAndHolds; using ::secure_aggregation::secagg_internal::StatusIs; + using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::Pair; @@ -298,6 +304,48 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2)); } +TEST(CodecFactoryTest, ValidateInputSpecFlattenedDomainSize) { + InputSpec input_spec; + + // Metric 1 + auto* metric_spec1 = input_spec.add_metric_vector_specs(); + metric_spec1->set_vector_name("metric1"); + metric_spec1->set_data_type(InputSpec::INT64); + + // Metric 2 + auto* metric_spec2 = input_spec.add_metric_vector_specs(); + metric_spec2->set_vector_name("metric2"); + metric_spec2->set_data_type(InputSpec::INT64); + + // Group-by 1: Time Domain (size = 5 + 1 = 6) + auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); + group_by_spec1->set_vector_name("time"); + group_by_spec1->set_data_type(InputSpec::STRING); + auto* time_domain = group_by_spec1->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(5); // 5 days + time_domain->set_timezone("UTC"); + time_domain->set_format(absl::RFC3339_full); + + // Group-by 2: String values domain (size = 3) + auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); + group_by_spec2->set_vector_name("country"); + 
group_by_spec2->set_data_type(InputSpec::STRING); + auto* string_values = + group_by_spec2->mutable_domain_spec()->mutable_string_values(); + string_values->add_values("US"); + string_values->add_values("CA"); + string_values->add_values("MX"); + + // Total flattened domain size is (num_periods + 1) * string_values_size = 6 * + // 3 = 18. Passing a smaller limit should fail. + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 17), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Global output domain size exceeds"))); + SECAGG_EXPECT_OK( + CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 18)); +} + TEST(CodecFactoryTest, EncodeSimpleGroupBy) { InputSpec input_spec = CreateTestInputSpecProto(); MetricData metric_data = CreateTestMetricData(); @@ -513,6 +561,356 @@ TEST(CodecFactoryTest, EncodeWithDomainValueNotFound) { "group_by_data value c not found in domain"))); } +TEST(CodecFactoryTest, ValidateInputSpecWithTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + + // Valid spec + SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec)); + + // Valid spec with lookback_window + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec)); + + // Invalid lookback_window <= 0 + time_domain->mutable_lookback_window()->set_seconds(0); + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + 
StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("lookback_window must be > 0"))); + time_domain->clear_lookback_window(); + + // Invalid period_duration <= 0 + time_domain->mutable_period_duration()->set_seconds(0); + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("period_duration must be > 0"))); + time_domain->mutable_period_duration()->set_seconds(86400); + + // Invalid num_periods <= 0 + time_domain->set_num_periods(0); + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("num_periods must be > 0"))); + time_domain->set_num_periods(6); + + // Invalid timezone + time_domain->set_timezone("Invalid/Timezone"); + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid timezone"))); + time_domain->set_timezone("UTC"); + + // Invalid data type (must be STRING) + group_by_spec->set_data_type(InputSpec::INT64); + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Time domain can only be used with STRING"))); +} + +TEST(CodecFactoryTest, EncodeTimeDomainFailsWithoutEncodingTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + + SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, + CodecFactory::CreateExplicitCodec(input_spec)); + + 
MetricData metric_data; + metric_data["metric1"] = {10}; + GroupData group_by_data; + group_by_data["time"] = {"2026-01-08T12:00:00Z"}; + + // Encode without encoding_time should fail + EXPECT_THAT(encoder->Encode(group_by_data, metric_data), + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("encoding_time is required"))); +} + +TEST(CodecFactoryTest, EncodeTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + time_domain->set_format(absl::RFC3339_full); + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Encoding time = Day 8 start + absl::Time encoding_time; + std::string err; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "1970-01-09T00:00:00Z", + &encoding_time, &err)) + << err; + + // Event times + std::string t1_str = + "1970-01-08T12:00:00Z"; // Valid (Day 7, 12h ago relative to Day 8 start) + std::string t2_str = "1970-01-06T12:00:00Z"; // Stale (Day 5, 2.5d ago) + std::string t3_str = "1970-01-10T12:00:00Z"; // Future (Day 9, 1.5d future) + + MetricData metric_data; + metric_data["metric1"] = {10, 20, 30}; + GroupData group_by_data; + group_by_data["time"] = {t1_str, t2_str, t3_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder, + CodecFactory::CreateExplicitCodec(input_spec, encoding_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, + encoder->Encode(group_by_data, metric_data)); + + EXPECT_THAT(encoded_data.at("metric1"), ElementsAre(0, 10, 0, 0, 0, 0, 30)); +} + 
+TEST(CodecFactoryTest, DecodeTimeDomain) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("UTC"); + time_domain->set_format("%Y-%m-%d"); + + // decoding_anchor_time = 1970-01-09 00:00:00 UTC (Day 8 start) + absl::Time decoding_anchor_time; + std::string err; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "1970-01-09T00:00:00Z", + &decoding_anchor_time, &err)) + << err; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + CodecFactory::CreateExplicitCodec( + input_spec, /*encoding_time=*/std::nullopt, decoding_anchor_time)); + + // Create a histogram with 7 buckets (num_periods + 1). + EncodedData encoded_data; + encoded_data["metric1"] = std::vector(7, 0); + encoded_data["metric1"][2] = 100; + encoded_data["metric1"][6] = 200; + + // Decode and check the timestamps. + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data, + decoder->Decode(encoded_data)); + EXPECT_THAT(decoded_data.metric_data.at("metric1"), ElementsAre(100, 200)); + + // Bucket 2 maps to Day 8 since 8 % 6 = 2 and 8 <= 8 < 8 + 6. + // Bucket 6 is an invalid timestamp. 
+ std::string expected_time_str = "1970-01-09"; + EXPECT_THAT(decoded_data.group_data.at("time"), + ElementsAre(expected_time_str, kInvalidTimestamp)); +} + +TEST(CodecFactoryTest, EncodeThenDecodeLocalTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + // Unset timezone defaults to UTC. + time_domain->set_format("%Y-%m-%d"); // No timezone in output, just civil day + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Scenario: Two local events occurring in different physical timezones, but + // representing the exact same civil day on their respective devices. + + // --- Client 1 (New York, EDT/UTC-4) --- + absl::TimeZone ny_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/New_York", &ny_tz)); + // Device encodes at local 17:34:56. + absl::Time encoding_time1; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T17:34:56-04:00", + &encoding_time1, nullptr)); + + // Event occurred 15 minutes prior to encoding. + absl::Time event_time_ny = encoding_time1 - absl::Minutes(15); + // Format using local format "%Y-%m-%d" -> "2026-05-12". 
+ std::string t1_str = absl::FormatTime("%Y-%m-%d", event_time_ny, ny_tz); + MetricData md1; + md1["metric1"] = {10}; + GroupData gd1; + gd1["time"] = {t1_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder1, + CodecFactory::CreateExplicitCodec(input_spec, encoding_time1)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed1, encoder1->Encode(gd1, md1)); + + // --- Client 2 (Los Angeles, PDT/UTC-7) --- + absl::TimeZone la_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &la_tz)); + // Device encodes at local 17:39:12. + absl::Time encoding_time2; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T17:39:12-07:00", + &encoding_time2, nullptr)); + + // Event occurred 3 hours and 12 minutes prior to encoding + absl::Time event_time_la = + encoding_time2 - absl::Hours(3) - absl::Minutes(12); + // Format using local format "%Y-%m-%d" -> "2026-05-12". + std::string t2_str = absl::FormatTime("%Y-%m-%d", event_time_la, la_tz); + MetricData md2; + md2["metric1"] = {20}; + GroupData gd2; + gd2["time"] = {t2_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder2, + CodecFactory::CreateExplicitCodec(input_spec, encoding_time2)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed2, encoder2->Encode(gd2, md2)); + + // Decode using decoding_anchor_time = 2026-05-12T00:00:00Z UTC. + absl::Time decoding_anchor_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T00:00:00Z", + &decoding_anchor_time, nullptr)); + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + CodecFactory::CreateExplicitCodec( + input_spec, /*encoding_time=*/std::nullopt, decoding_anchor_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_1, decoder->Decode(ed1)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_2, decoder->Decode(ed2)); + + // The bucket maps back to 2026-05-12. 
+ std::string expected_decoded_time = "2026-05-12"; + + EXPECT_THAT(decoded_data_1.metric_data.at("metric1"), ElementsAre(10)); + EXPECT_THAT(decoded_data_1.group_data.at("time"), + ElementsAre(expected_decoded_time)); + + EXPECT_THAT(decoded_data_2.metric_data.at("metric1"), ElementsAre(20)); + EXPECT_THAT(decoded_data_2.group_data.at("time"), + ElementsAre(expected_decoded_time)); +} + +TEST(CodecFactoryTest, EncodeThenDecodeAbsoluteTime) { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + + auto* group_by_spec = input_spec.add_group_by_vector_specs(); + group_by_spec->set_vector_name("time"); + group_by_spec->set_data_type(InputSpec::STRING); + + auto* time_domain = group_by_spec->mutable_domain_spec()->mutable_time(); + time_domain->mutable_period_duration()->set_seconds(86400); // 1 day + time_domain->set_num_periods(6); // 6 days + time_domain->set_timezone("America/Los_Angeles"); // Output TimeZone + time_domain->set_format(absl::RFC3339_full); + time_domain->mutable_lookback_window()->set_seconds(86400 * 2); // 2 days + + // Scenario: Two local events occurring in different civil days but at the + // same absolute time. Both devices also encode at the exact same absolute + // physical time. Just past midnight ET on May 12, but still May 11 in LA. + absl::Time encoding_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-12T00:34:56-04:00", + &encoding_time, nullptr)); + + // Event occurred 3 minutes prior to encoding. 
+ absl::Time event_time = encoding_time - absl::Minutes(3); + + // --- Client 1 (New York, EDT/UTC-4) --- + absl::TimeZone ny_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/New_York", &ny_tz)); + std::string t1_str = absl::FormatTime(absl::RFC3339_full, event_time, ny_tz); + MetricData md1; + md1["metric1"] = {10}; + GroupData gd1; + gd1["time"] = {t1_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder1, + CodecFactory::CreateExplicitCodec(input_spec, encoding_time)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed1, encoder1->Encode(gd1, md1)); + + // --- Client 2 (Los Angeles, PDT/UTC-7) --- + absl::TimeZone la_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &la_tz)); + std::string t2_str = absl::FormatTime(absl::RFC3339_full, event_time, la_tz); + MetricData md2; + md2["metric1"] = {20}; + GroupData gd2; + gd2["time"] = {t2_str}; + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder2, + CodecFactory::CreateExplicitCodec(input_spec, encoding_time)); + SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData ed2, encoder2->Encode(gd2, md2)); + + // Decode using anchor time before the events. + absl::Time decoding_anchor_time; + ASSERT_TRUE(absl::ParseTime(absl::RFC3339_full, "2026-05-10T00:00:00Z", + &decoding_anchor_time, nullptr)); + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::unique_ptr decoder, + CodecFactory::CreateExplicitCodec( + input_spec, /*encoding_time=*/std::nullopt, decoding_anchor_time)); + + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_1, decoder->Decode(ed1)); + SECAGG_ASSERT_OK_AND_ASSIGN(DecodedData decoded_data_2, decoder->Decode(ed2)); + + // Reconstructed absolute time is 2026-05-12T00:00:00Z UTC. It's the start of + // the day in UTC because the default origin time is in UTC, and the period is + // 1 day. Formatted back using the spec's configured timezone + // (America/Los_Angeles, PDT UTC-7) and full RFC3339 format: + // 2026-05-12T00:00:00Z UTC -> "2026-05-11T17:00:00-07:00". 
+ std::string expected_decoded_time = "2026-05-11T17:00:00-07:00"; + + EXPECT_THAT(decoded_data_1.metric_data.at("metric1"), ElementsAre(10)); + EXPECT_THAT(decoded_data_1.group_data.at("time"), + ElementsAre(expected_decoded_time)); + + EXPECT_THAT(decoded_data_2.metric_data.at("metric1"), ElementsAre(20)); + EXPECT_THAT(decoded_data_2.group_data.at("time"), + ElementsAre(expected_decoded_time)); +} + } // namespace } // namespace willow } // namespace secure_aggregation diff --git a/willow/input_encoding/time_utils.cc b/willow/input_encoding/time_utils.cc new file mode 100644 index 0000000..c91c2f0 --- /dev/null +++ b/willow/input_encoding/time_utils.cc @@ -0,0 +1,167 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "willow/input_encoding/time_utils.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +absl::Duration DurationFromProto(const google::protobuf::Duration& proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +absl::Time TimeFromProto(const google::protobuf::Timestamp& proto) { + return absl::FromUnixSeconds(proto.seconds()) + + absl::Nanoseconds(proto.nanos()); +} + +int64_t GetPeriodIndex(absl::Time t, absl::Time origin_time, + absl::Duration period_duration) { + // Calculate the exact duration elapsed since the time origin. + absl::Duration elapsed = t - origin_time; + + // Round down to the nearest period multiple, even for negative elapsed times. + // e.g. Floor(-12h, 24h) = -24h. + absl::Duration floored_elapsed = absl::Floor(elapsed, period_duration); + + // Calculate the number of elapsed periods using integer division, since + // floored_elapsed is an exact multiple of period_duration. 
+ absl::Duration unused_remainder; + int64_t period_index = + absl::IDivDuration(floored_elapsed, period_duration, &unused_remainder); + + return period_index; +} + +} // namespace + +absl::StatusOr ParseTimeDomain( + const InputSpec::TimeDomain& proto) { + absl::TimeZone tz; + if (proto.timezone().empty()) { + tz = absl::UTCTimeZone(); + } else if (!absl::LoadTimeZone(proto.timezone(), &tz)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid timezone: ", proto.timezone())); + } + + absl::Duration period_duration = DurationFromProto(proto.period_duration()); + if (period_duration <= absl::ZeroDuration()) { + return absl::InvalidArgumentError("period_duration must be > 0"); + } + + int32_t num_periods = proto.num_periods(); + if (num_periods <= 0) { + return absl::InvalidArgumentError("num_periods must be > 0"); + } + + std::string format; + if (proto.format().empty()) { + format = absl::RFC3339_full; + } else { + format = proto.format(); + } + + absl::Time origin_time = absl::UnixEpoch(); + if (proto.has_origin_time()) { + origin_time = TimeFromProto(proto.origin_time()); + } + + absl::Duration lookback_window; + if (proto.has_lookback_window()) { + lookback_window = DurationFromProto(proto.lookback_window()); + if (lookback_window <= absl::ZeroDuration()) { + return absl::InvalidArgumentError("lookback_window must be > 0"); + } + } else { + lookback_window = period_duration * proto.num_periods() / 2; + } + + return TimeDomainInfo{ + .period_duration = period_duration, + .num_periods = num_periods, + .format = format, + .timezone = tz, + .origin_time = origin_time, + .lookback_window = lookback_window, + }; +} + +int64_t EncodeTime(absl::Time t, const TimeDomainInfo& info, + absl::Time encoding_time) { + // How many full periods have elapsed since origin_time, with handling for + // negative times. 
+ int64_t period_index = + GetPeriodIndex(t, info.origin_time, info.period_duration); + + // Correct negative remainder to positive modulo: bucket in [0, num_periods). + int64_t bucket = period_index % info.num_periods; + if (bucket < 0) { + bucket += info.num_periods; + } + + // Future events are always invalid. + if (t > encoding_time) { + return info.num_periods; // Invalid bucket index (last index in domain) + } + + // Any event older than lookback_window relative to now is marked invalid. + // This filters out obsolete data and prevents wraparound collisions. + if (t < encoding_time - info.lookback_window) { + return info.num_periods; // Invalid bucket index + } + + return bucket; +} + +absl::StatusOr DecodeTime(int64_t bucket_index, + const TimeDomainInfo& info, + absl::Time decoding_anchor_time) { + if (bucket_index < 0 || bucket_index >= info.num_periods) { + return absl::InvalidArgumentError( + absl::StrCat("Bucket index out of range: ", bucket_index, + ", must be in [0, ", info.num_periods, ")")); + } + + // anchor_period is the period index of anchor_time relative to origin_time. + int64_t anchor_period = GetPeriodIndex(decoding_anchor_time, info.origin_time, + info.period_duration); + + // We want to find k in [0, num_periods) such that: + // (anchor_period + k) mod num_periods = bucket_index + int64_t k = (bucket_index - anchor_period) % info.num_periods; + if (k < 0) { + k += info.num_periods; // Turn remainder into positive modulo + } + + // Start of the unique period whose index lies in [anchor_period, + // anchor_period + info.num_periods) and maps to bucket_index. 
+ return info.origin_time + (anchor_period + k) * info.period_duration; +} + +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/input_encoding/time_utils.h b/willow/input_encoding/time_utils.h new file mode 100644 index 0000000..78fa6a1 --- /dev/null +++ b/willow/input_encoding/time_utils.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ +#define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { + +// A special string representation returned when decoding an invalid/stale +// modular timestamp bucket index. +inline constexpr absl::string_view kInvalidTimestamp = "INVALID_TIMESTAMP"; + +// Struct containing the parsed time domain parameters. +struct TimeDomainInfo { + absl::Duration period_duration; + int num_periods; + std::string format; + absl::TimeZone timezone; + absl::Time origin_time; + absl::Duration lookback_window; +}; + +// Parses a TimeDomain protobuf spec into a validated TimeDomainInfo struct. +// Returns an InvalidArgument status if spec fields (duration, periods, +// timezone) are invalid. 
+absl::StatusOr ParseTimeDomain( + const InputSpec::TimeDomain& proto); + +// Encodes a parsed absl::Time into a modular bucket index in [0, num_periods). +// Returns num_periods (the invalid/stale bucket) if the event is in the future +// or older than the lookback window relative to `encoding_time`. +int64_t EncodeTime(absl::Time t, const TimeDomainInfo& info, + absl::Time encoding_time); + +// Reconstructs the absolute start-of-period absl::Time from a bucket index. +// Expects bucket_index in [0, num_periods). Returns an InvalidArgument status +// if the bucket index is out of bounds. +absl::StatusOr DecodeTime(int64_t bucket_index, + const TimeDomainInfo& info, + absl::Time decoding_anchor_time); + +} // namespace willow +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_TIME_UTILS_H_ diff --git a/willow/input_encoding/time_utils_test.cc b/willow/input_encoding/time_utils_test.cc new file mode 100644 index 0000000..59d9983 --- /dev/null +++ b/willow/input_encoding/time_utils_test.cc @@ -0,0 +1,159 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "willow/input_encoding/time_utils.h" + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "ffi_utils/status_matchers.h" +#include "gmock/gmock.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/input_spec.pb.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +using ::secure_aggregation::secagg_internal::StatusIs; +using ::testing::HasSubstr; + +TEST(TimeDomainHelperTest, ParseTimeDomainWithLookback) { + InputSpec::TimeDomain proto; + proto.mutable_period_duration()->set_seconds(86400); // 1 day + proto.set_num_periods(6); // 6 days + proto.set_timezone("America/Los_Angeles"); + proto.set_format("%Y-%m-%d"); + proto.mutable_origin_time()->set_seconds(1234567890); + proto.mutable_lookback_window()->set_seconds(7 * 24 * 3600); // 7 days + + SECAGG_ASSERT_OK_AND_ASSIGN(TimeDomainInfo info, ParseTimeDomain(proto)); + + EXPECT_EQ(info.period_duration, absl::Hours(24)); + EXPECT_EQ(info.num_periods, 6); + EXPECT_EQ(info.format, "%Y-%m-%d"); + + absl::TimeZone expected_tz; + ASSERT_TRUE(absl::LoadTimeZone("America/Los_Angeles", &expected_tz)); + EXPECT_EQ(info.timezone, expected_tz); + EXPECT_EQ(info.origin_time, absl::FromUnixSeconds(1234567890)); + EXPECT_EQ(info.lookback_window, absl::Hours(7 * 24)); +} + +TEST(TimeDomainHelperTest, ParseTimeDomainDefaultLookback) { + InputSpec::TimeDomain proto; + proto.mutable_period_duration()->set_seconds(86400); // 1 day + proto.set_num_periods(6); // 6 days + proto.set_timezone("UTC"); + + SECAGG_ASSERT_OK_AND_ASSIGN(TimeDomainInfo info, ParseTimeDomain(proto)); + // Default lookback is 6 * 1 day / 2 = 3 days (72 hours). 
+ EXPECT_EQ(info.lookback_window, absl::Hours(3 * 24)); +} + +TEST(TimeDomainHelperTest, ParseTimeDomainInvalid) { + InputSpec::TimeDomain proto; + // Missing period_duration (default 0) + proto.set_num_periods(6); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("period_duration must be > 0"))); + + proto.mutable_period_duration()->set_seconds(86400); + proto.set_num_periods(0); // Invalid periods + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("num_periods must be > 0"))); + + proto.set_num_periods(6); + proto.set_timezone("Invalid/Timezone"); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid timezone"))); + proto.set_timezone("UTC"); + + // Invalid lookback_window <= 0 + proto.mutable_lookback_window()->set_seconds(0); + EXPECT_THAT(ParseTimeDomain(proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("lookback_window must be > 0"))); +} + +TEST(TimeDomainHelperTest, EncodeTimeWithCustomLookback) { + TimeDomainInfo info{ + .period_duration = absl::Hours(24), + .num_periods = 6, + .format = "%Y-%m-%d", + .timezone = absl::UTCTimeZone(), + .origin_time = absl::UnixEpoch(), + .lookback_window = + absl::Hours(24 * 2), // Events older than 2 days are stale + }; + + absl::Time encoding_time = + absl::UnixEpoch() + absl::Hours(24 * 10); // Day 10 (1970-01-11) + + // Valid event: 1 day ago + absl::Time t1 = encoding_time - absl::Hours(24 * 1); + EXPECT_EQ(EncodeTime(t1, info, encoding_time), 3); + + // Stale event: 3 days ago + absl::Time t2 = encoding_time - absl::Hours(24 * 3); + EXPECT_EQ(EncodeTime(t2, info, encoding_time), 6); // Invalid bucket + + // Future event: 1 day in future. 
+ absl::Time t3 = encoding_time + absl::Hours(24 * 1); + EXPECT_EQ(EncodeTime(t3, info, encoding_time), 6); // Invalid bucket +} + +TEST(TimeDomainHelperTest, DecodeTime) { + TimeDomainInfo info{ + .period_duration = absl::Hours(24), + .num_periods = 6, + .format = "%Y-%m-%d", + .timezone = absl::UTCTimeZone(), + .origin_time = absl::UnixEpoch(), + }; + + // anchor_time = Day 8 + absl::Time anchor_time = absl::UnixEpoch() + absl::Hours(24 * 8); + + // Decode bucket 3, expect Day 9 because 9 mod 6 = 3 and 8 <= 9 < 8 + 6 + absl::Time expected1 = absl::UnixEpoch() + absl::Hours(24 * 9); + + SECAGG_ASSERT_OK_AND_ASSIGN(absl::Time decoded1, + DecodeTime(3, info, anchor_time)); + EXPECT_EQ(decoded1, expected1); + EXPECT_GE(decoded1, anchor_time); + EXPECT_LT(decoded1, anchor_time + absl::Hours(24 * 6)); + + // Decode bucket 1, expect Day 13 because 13 mod 6 = 1 and 8 <= 13 < 8 + 6. + absl::Time expected2 = absl::UnixEpoch() + absl::Hours(24 * 13); + + SECAGG_ASSERT_OK_AND_ASSIGN(absl::Time decoded2, + DecodeTime(1, info, anchor_time)); + EXPECT_EQ(decoded2, expected2); + EXPECT_GE(decoded2, anchor_time); + EXPECT_LT(decoded2, anchor_time + absl::Hours(24 * 6)); + + // Invalid bucket index raises an error. 
+ EXPECT_THAT(DecodeTime(6, info, anchor_time), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Bucket index out of range"))); +} + +} // namespace +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index f543f23..12672b3 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -84,6 +84,10 @@ rust_proto_library( proto_library( name = "input_spec_proto", srcs = ["input_spec.proto"], + deps = [ + "@protobuf//:duration_proto", + "@protobuf//:timestamp_proto", + ], ) cc_proto_library( diff --git a/willow/proto/willow/input_spec.proto b/willow/proto/willow/input_spec.proto index 2d7013a..e24c00f 100644 --- a/willow/proto/willow/input_spec.proto +++ b/willow/proto/willow/input_spec.proto @@ -16,6 +16,9 @@ syntax = "proto3"; package secure_aggregation.willow; +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + option java_multiple_files = true; option java_outer_classname = "InputSpecProto"; @@ -89,6 +92,70 @@ message InputSpec { repeated string values = 1; } + // A domain for timestamps represented in a modular way, yielding a small + // closed domain instead of an unbounded number of timestamps. + // + // Input timestamps t (as strings) are parsed and mapped to a period index in + // [0, num_periods) using: + // period_index = ((t - origin_time) / period_duration) mod num_periods + // Stale events (older than lookback_window relative to Now()) or future + // events are mapped to a special invalid bucket (index num_periods). During + // decoding, timestamps are reconstructed relative to a dynamic anchor_time. + // + // Inputs encoded more than (num_periods * period_duration - lookback_window) + // apart might collide, i.e. contain different timestamps mapping to the same + // bucket. Shorter lookback windows allow for longer retention before + // aggregation and decoding. 
+ // + // Example: inputs from Device 1 and Device 3 should not be decoded together. + // + // <----- num_periods * period_duration -----> + // --+------|------|------|------|------|------|------|------|------|--> Time + // | B3 | B4 | B5 | B0 | B1 | B2 | B3 | B4 | B5 | Bucket + // | Day 1| Day 2| Day 3| Day 4| Day 5| Day 6| Day 7| Day 8| Day 9| Day + // +------+------+------+------+------+------+------+------+------+ + // ^ ^ + // anchor_time anchor_time + num_periods * period_duration + // + // Device 1: + // t_1 + // v + // [=============] (lookback_window = 2 days) + // + // Device 2: + // t_2 + // v + // [=============] + // Device 3: + // t_3 + // v + // [=============] + // + + message TimeDomain { + // The duration of each period. + google.protobuf.Duration period_duration = 1; + // The number of periods. + int32 num_periods = 2; + // The format string used for parsing and formatting (absl::ParseTime + // compatible). If empty, RFC 3339 is used. + string format = 3; + // The timezone to use if the timestamp string does not contain timezone + // info (e.g., "America/Los_Angeles", "UTC"). Must be absl::LoadTimeZone + // compatible, i.e. a valid IANA timezone name. If not set, "UTC" is used. + // If the timestamp string contains a UTC offset specification (e.g., %z), + // the timezone offset in the string takes precedence. + string timezone = 4; + // The origin time from which the periods are calculated. + // If not set, the Unix Epoch (1970-01-01T00:00:00Z) is used. + google.protobuf.Timestamp origin_time = 5; + // The lookback window defining how far back in time from the encoding + // time we accept events. Events older than this duration, or + // events in the future, will be mapped to the invalid bucket. + // If not set, defaults to `num_periods * period_duration / 2`. + google.protobuf.Duration lookback_window = 6; + } + // A new message type to represent the domain specification. 
message DomainSpec { oneof domain_type { @@ -97,6 +164,9 @@ message InputSpec { // Defines a domain as an interval of values. Interval interval = 2; + + // Defines a domain for timestamps. + TimeDomain time = 3; } }