Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions willow/input_encoding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ package(
cc_library(
name = "codec",
srcs = [
"codec_factory.cc",
"codec.cc",
],
hdrs = [
"codec.h",
Expand All @@ -44,8 +44,8 @@ cc_library(
)

cc_test(
name = "explicit_codec_test",
srcs = ["explicit_codec_test.cc"],
name = "codec_test",
srcs = ["codec_test.cc"],
deps = [
":codec",
"@googletest//:gtest_main",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "willow/input_encoding/codec_factory.h"
#include "willow/input_encoding/codec.h"

#include <algorithm>
#include <cstddef>
Expand All @@ -28,7 +28,6 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "willow/input_encoding/codec.h"
#include "willow/proto/willow/input_spec.pb.h"

namespace secure_aggregation {
Expand All @@ -53,13 +52,17 @@ struct GroupDomainKey {
}
};

// WillowInputExplicitEncoder must be instantiated through the factory class
// CodecFactory.
class ExplicitCodecImpl : public Codec {
// FlatHistogramCodecImpl implements a Codec that encodes data into a dense,
// flat 1D histogram representing the Cartesian product of the group-by
// domains.
//
// It must be instantiated through the factory function
// Codec::CreateFlatHistogramCodec.
class FlatHistogramCodecImpl : public Codec {
public:
ExplicitCodecImpl(const ExplicitCodecImpl&) = delete;
ExplicitCodecImpl& operator=(const ExplicitCodecImpl&) = delete;
~ExplicitCodecImpl() override = default;
FlatHistogramCodecImpl(const FlatHistogramCodecImpl&) = delete;
FlatHistogramCodecImpl& operator=(const FlatHistogramCodecImpl&) = delete;
~FlatHistogramCodecImpl() override = default;

absl::StatusOr<EncodedData> Encode(
const GroupData& group_by_data,
Expand Down Expand Up @@ -94,7 +97,7 @@ class ExplicitCodecImpl : public Codec {

size_t GetCombinedIndex(const std::vector<int>& indices) const;

explicit ExplicitCodecImpl(
explicit FlatHistogramCodecImpl(
InputSpec input_spec,
absl::flat_hash_map<std::string, const InputVectorSpec*>
group_by_spec_map,
Expand Down Expand Up @@ -127,10 +130,10 @@ class ExplicitCodecImpl : public Codec {
flattened_domain_size_ *= domain_size;
}
}
friend class CodecFactory;
friend class Codec;
};

absl::Status ExplicitCodecImpl::ValidateData(
absl::Status FlatHistogramCodecImpl::ValidateData(
const GroupData& group_by_data, const MetricData& metric_data) const {
// Check that all vectors in metric_data and group_by_data are present in
// metric_spec_map and group_by_spec_map_, respectively. This ensures that the
Expand Down Expand Up @@ -221,7 +224,7 @@ absl::Status ExplicitCodecImpl::ValidateData(
// Returns the indices of elements in individual domains/vectors of size `sizes`
// that correspond to the global index `global_index` of an element of their
// cartesian product.
std::vector<int> ExplicitCodecImpl::GetIndices(int global_index) const {
std::vector<int> FlatHistogramCodecImpl::GetIndices(int global_index) const {
if (group_by_domain_sizes_.empty()) {
return {};
}
Expand All @@ -241,7 +244,7 @@ std::vector<int> ExplicitCodecImpl::GetIndices(int global_index) const {
// domains of size 2 and 3 respectively, and we want to find overall index of
// an element that has index 1 in the first domain and index 0 in the second
// domain. The function will return 1 * 3 + 0 = 3.
size_t ExplicitCodecImpl::GetCombinedIndex(
size_t FlatHistogramCodecImpl::GetCombinedIndex(
const std::vector<int>& indices) const {
int64_t combined_index = 0;
for (int i = 0; i < indices.size(); ++i) {
Expand All @@ -251,7 +254,7 @@ size_t ExplicitCodecImpl::GetCombinedIndex(
return combined_index;
}

absl::StatusOr<EncodedData> ExplicitCodecImpl::Encode(
absl::StatusOr<EncodedData> FlatHistogramCodecImpl::Encode(
const GroupData& group_by_data, const MetricData& metric_data) const {
if (absl::Status status = ValidateData(group_by_data, metric_data);
!status.ok()) {
Expand Down Expand Up @@ -286,7 +289,7 @@ absl::StatusOr<EncodedData> ExplicitCodecImpl::Encode(
return result;
}

absl::StatusOr<DecodedData> ExplicitCodecImpl::Decode(
absl::StatusOr<DecodedData> FlatHistogramCodecImpl::Decode(
const EncodedData& encoded_data) const {
DecodedData decoded_data;

Expand Down Expand Up @@ -348,7 +351,7 @@ absl::StatusOr<DecodedData> ExplicitCodecImpl::Decode(
return decoded_data;
}

absl::Status ExplicitCodecImpl::ValidateExampleQuery(
absl::Status FlatHistogramCodecImpl::ValidateExampleQuery(
const absl::flat_hash_map<std::string, std::string>& query_output_specs)
const {
for (const auto& [name, type] : query_output_specs) {
Expand All @@ -373,20 +376,20 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery(
return absl::OkStatus();
}

// Checks that the flat histogram implied by `input_spec` is not too large.
//
// The bin count is the product of the string-domain sizes of all group-by
// vectors in `input_spec`. Returns InvalidArgumentError as soon as the running
// product would exceed `max_flat_histogram_bins`, and OkStatus otherwise (a
// bin count exactly equal to the limit is accepted).
absl::Status Codec::ValidateInputSpec(const InputSpec& input_spec,
                                      size_t max_flat_histogram_bins) {
  size_t flat_histogram_bins = 1;
  for (const auto& spec : input_spec.group_by_vector_specs()) {
    const size_t domain_size =
        static_cast<size_t>(spec.domain_spec().string_values().values_size());
    // Compare by division instead of multiplying first: the product could
    // otherwise wrap around in size_t (large caller-supplied limit, or a
    // 32-bit size_t) and slip under the threshold. For non-overflowing inputs
    // this is equivalent to `flat_histogram_bins * domain_size > limit`.
    // An empty domain (size 0) never trips the check, as before.
    if (domain_size != 0 &&
        flat_histogram_bins > max_flat_histogram_bins / domain_size) {
      return absl::InvalidArgumentError(
          "Flat histogram bin count exceeds maximum threshold.");
    }
    flat_histogram_bins *= domain_size;
  }
  return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<Codec>> CodecFactory::CreateExplicitCodec(
absl::StatusOr<std::unique_ptr<Codec>> Codec::CreateFlatHistogramCodec(
InputSpec input_spec) {
// Check that specs include at least one metric vector.
if (input_spec.metric_vector_specs().empty()) {
Expand All @@ -408,9 +411,9 @@ absl::StatusOr<std::unique_ptr<Codec>> CodecFactory::CreateExplicitCodec(
absl::StrCat("Duplicate vector name: ", spec.vector_name()));
}
}
return absl::WrapUnique(new ExplicitCodecImpl(std::move(input_spec),
std::move(group_by_spec_map),
std::move(metric_spec_map)));
return absl::WrapUnique(new FlatHistogramCodecImpl(
std::move(input_spec), std::move(group_by_spec_map),
std::move(metric_spec_map)));
}

} // namespace willow
Expand Down
32 changes: 32 additions & 0 deletions willow/input_encoding/codec.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@
#ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_
#define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "willow/proto/willow/input_spec.pb.h"

namespace secure_aggregation {
namespace willow {

// The maximum number of bins in a flat histogram, which is the maximum size of
// the Cartesian product of domains for string features.
//
// Declared `inline` so that every translation unit including this header
// shares one definition of the constant.
inline constexpr size_t kMaxFlatHistogramBins = 1000000;

using MetricData = absl::flat_hash_map<std::string, std::vector<int64_t>>;
using GroupData = absl::flat_hash_map<std::string, std::vector<std::string>>;
using EncodedData = absl::flat_hash_map<std::string, std::vector<int64_t>>;
Expand Down Expand Up @@ -69,6 +77,30 @@ class Codec {
virtual absl::Status ValidateExampleQuery(
const absl::flat_hash_map<std::string, std::string>& query_output_specs)
const = 0;

// Creates a flat-histogram Codec for `input_spec`.
//
// `input_spec` is taken by value; the implementation in codec.cc moves it into
// the returned codec. Construction-time validation failures (for example a
// duplicate vector name in the spec) are reported through the returned status
// instead of a codec.
static absl::StatusOr<std::unique_ptr<Codec>> CreateFlatHistogramCodec(
::secure_aggregation::willow::InputSpec input_spec);

// Checks that the number of flat-histogram bins implied by `input_spec`, i.e.
// the product of the string-domain sizes of its group-by vectors, does not
// exceed `max_flat_histogram_bins` (default: kMaxFlatHistogramBins).
//
// Returns InvalidArgumentError if the limit is exceeded and OkStatus
// otherwise; a bin count exactly equal to the limit is accepted.
static absl::Status ValidateInputSpec(
const ::secure_aggregation::willow::InputSpec& input_spec,
size_t max_flat_histogram_bins = kMaxFlatHistogramBins);

// Deprecated alias kept for backward compatibility: moves `input_spec`
// straight into CreateFlatHistogramCodec and returns its result unchanged.
[[deprecated("Use CreateFlatHistogramCodec instead")]]
static absl::StatusOr<std::unique_ptr<Codec>> CreateExplicitCodec(
::secure_aggregation::willow::InputSpec input_spec) {
return CreateFlatHistogramCodec(std::move(input_spec));
}

// Deprecated alias kept for backward compatibility: forwards both arguments
// unchanged to ValidateInputSpec and returns its status.
[[deprecated("Use ValidateInputSpec instead")]]
static absl::Status ValidateExplicitCodecInputSpec(
const ::secure_aggregation::willow::InputSpec& input_spec,
size_t max_flattened_domain_size = kMaxFlatHistogramBins) {
return ValidateInputSpec(input_spec, max_flattened_domain_size);
}
};

} // namespace willow
Expand Down
41 changes: 36 additions & 5 deletions willow/input_encoding/codec_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "pybind11/stl.h"
#include "pybind11_abseil/status_casters.h"
#include "willow/input_encoding/codec.h"
#include "willow/input_encoding/codec_factory.h"
#include "willow/proto/willow/input_spec.pb.h"

namespace py = pybind11;
Expand All @@ -46,6 +45,19 @@ PYBIND11_MODULE(codec_bindings, m) {
return self.ValidateExampleQuery(cpp_map);
});

// Python binding: parses a serialized InputSpec proto and builds a
// flat-histogram codec from it. A payload that fails to parse is reported as
// InvalidArgumentError.
m.def(
    "CreateFlatHistogramCodec",
    [](const std::string& serialized_input_spec)
        -> absl::StatusOr<std::unique_ptr<Codec>> {
      InputSpec input_spec;
      if (!input_spec.ParseFromString(serialized_input_spec)) {
        return absl::InvalidArgumentError("Failed to parse InputSpec");
      }
      // The factory takes its argument by value; the parsed proto is not used
      // again here, so move it instead of copying it.
      return Codec::CreateFlatHistogramCodec(std::move(input_spec));
    },
    py::arg("serialized_input_spec"));

// DEPRECATED: Use CreateFlatHistogramCodec instead.
m.def(
"CreateExplicitCodec",
[](const std::string& serialized_input_spec)
Expand All @@ -54,10 +66,29 @@ PYBIND11_MODULE(codec_bindings, m) {
if (!input_spec.ParseFromString(serialized_input_spec)) {
return absl::InvalidArgumentError("Failed to parse InputSpec");
}
return CodecFactory::CreateExplicitCodec(input_spec);
return Codec::CreateFlatHistogramCodec(input_spec);
},
py::arg("serialized_input_spec"));

// Python binding for Codec::ValidateInputSpec. Takes a serialized InputSpec
// proto and an optional bin limit; a payload that fails to parse is reported
// as InvalidArgumentError.
m.def(
    "ValidateInputSpec",
    [](const std::string& serialized_spec, size_t max_bins) -> absl::Status {
      InputSpec parsed_spec;
      const bool parsed_ok = parsed_spec.ParseFromString(serialized_spec);
      if (!parsed_ok) {
        return absl::InvalidArgumentError("Failed to parse InputSpec");
      }
      // A limit of 0 (the Python-side default) selects the C++ default limit.
      const bool use_default_limit = (max_bins == 0);
      if (use_default_limit) {
        return Codec::ValidateInputSpec(parsed_spec);
      }
      return Codec::ValidateInputSpec(parsed_spec, max_bins);
    },
    py::arg("serialized_input_spec"),
    py::arg("max_flattened_domain_size") = 0);

// DEPRECATED: Use ValidateInputSpec instead.
m.def(
"ValidateExplicitCodecInputSpec",
[](const std::string& serialized_input_spec,
Expand All @@ -67,10 +98,10 @@ PYBIND11_MODULE(codec_bindings, m) {
return absl::InvalidArgumentError("Failed to parse InputSpec");
}
if (max_flattened_domain_size == 0) {
return CodecFactory::ValidateExplicitCodecInputSpec(input_spec);
return Codec::ValidateInputSpec(input_spec);
} else {
return CodecFactory::ValidateExplicitCodecInputSpec(
input_spec, max_flattened_domain_size);
return Codec::ValidateInputSpec(input_spec,
max_flattened_domain_size);
}
},
py::arg("serialized_input_spec"),
Expand Down
4 changes: 3 additions & 1 deletion willow/input_encoding/codec_bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def setUp(self):
)
],
)
self.codec = codec_bindings.CreateExplicitCodec(spec.SerializeToString())
self.codec = codec_bindings.CreateFlatHistogramCodec(
spec.SerializeToString()
)

def test_validate_example_query_success(self):
query_specs = {
Expand Down
22 changes: 10 additions & 12 deletions willow/input_encoding/codec_factory.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2025 Google LLC
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -14,8 +14,10 @@

#ifndef SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_
#define SECURE_AGGREGATION_WILLOW_INPUT_ENCODING_CODEC_FACTORY_H_

#include <cstddef>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -25,22 +27,18 @@
namespace secure_aggregation {
namespace willow {

// The maximum size of the Cartesian product of domains for string features.
constexpr size_t kMaxGlobalOutputDomainSize = 1000000;

// Factory class that constructs non-copyable instances of children classes of
// Codec.
class CodecFactory {
class [[deprecated("Use Codec class static methods instead")]] CodecFactory {
public:
// Creates an instance of ExplicitCodec.
static absl::StatusOr<std::unique_ptr<Codec>> CreateExplicitCodec(
InputSpec input_spec);
InputSpec input_spec) {
return Codec::CreateFlatHistogramCodec(std::move(input_spec));
}

// Check that the combined size of the string domains is less than the
// maximum allowed size.
static absl::Status ValidateExplicitCodecInputSpec(
const InputSpec& input_spec,
size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize);
size_t max_flattened_domain_size = kMaxFlatHistogramBins) {
return Codec::ValidateInputSpec(input_spec, max_flattened_domain_size);
}
};

} // namespace willow
Expand Down
Loading