Skip to content
208 changes: 123 additions & 85 deletions lading_payload/src/common/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
/// equal to `+0.0` under IEEE-754 numeric ordering.
const NEG_ZERO_AS_BITS: u32 = 0x8000_0000;

/// Error returned when a value cannot be turned into a [`Probability`].
/// Error returned when a value cannot be turned into a [`BoundedProbability`].
#[derive(Debug, thiserror::Error, Clone, Copy)]
pub enum ProbabilityError {
/// Value is [`f32::NAN`], [`f32::INFINITY`], or [`f32::NEG_INFINITY`].
Expand Down Expand Up @@ -147,42 +147,59 @@ pub enum ProbabilityError {
/// before storage. This canonical-bit-pattern guarantee is what makes hashing
/// on `value.to_bits()` consistent with numeric equality.
///
/// Two type aliases are provided for the bounds that actually occur in lading
/// payload configuration today; callers should prefer them over spelling the
/// bit pattern at the use site. Define additional aliases as new bounds appear.
///
/// # Example
///
/// ```
/// use lading_payload::common::config::Probability;
/// use lading_payload::common::config::{BoundedProbability, Probability};
///
/// type AtLeastHalf = Probability<{ f32::to_bits(0.5) }>;
/// let p = AtLeastHalf::try_new(0.75).expect("0.75 is in [0.5, 1.0]");
/// // For the common `[0.0, 1.0]` case, use the `Probability` alias.
/// let p = Probability::try_new(0.75).expect("0.75 is in [0.0, 1.0]");
/// assert_eq!(p.get(), 0.75);
///
/// // For other lower bounds, parameterize `BoundedProbability` directly.
/// type AtLeastHalf = BoundedProbability<{ f32::to_bits(0.5) }>;
/// let q = AtLeastHalf::try_new(0.75).expect("0.75 is in [0.5, 1.0]");
/// assert_eq!(q.get(), 0.75);
/// ```
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
#[serde(into = "f32", try_from = "f32")]
pub struct Probability<const MIN_AS_BITS: u32> {
pub struct BoundedProbability<const MIN_AS_BITS: u32> {
value: f32,
}

impl<const MIN_AS_BITS: u32> TryFrom<f32> for Probability<MIN_AS_BITS> {
/// A probability in the closed unit interval `[0.0, 1.0]`. The most common bound.
///
/// This is [`BoundedProbability`] with its `MIN_AS_BITS` parameter set to the
/// bit pattern of `+0.0`; the bound is spelled via `f32::to_bits` because the
/// parameter is a `u32` const generic.
pub type Probability = BoundedProbability<{ f32::to_bits(0.0) }>;

/// A probability or ratio in `[0.01, 1.0]`. Use for fields such as
/// `unique_tag_ratio` that must avoid extreme low values but admit
/// in-the-wild values below `0.1`.
///
/// This is [`BoundedProbability`] with its `MIN_AS_BITS` parameter set to
/// `f32::to_bits(0.01)`, so the effective lower bound is the `f32` that the
/// literal `0.01` decodes to.
pub type AtLeastOneHundredth = BoundedProbability<{ f32::to_bits(0.01) }>;

impl<const MIN_AS_BITS: u32> TryFrom<f32> for BoundedProbability<MIN_AS_BITS> {
type Error = ProbabilityError;

fn try_from(value: f32) -> Result<Self, Self::Error> {
Self::try_new(value)
}
}

impl<const MIN_AS_BITS: u32> From<Probability<MIN_AS_BITS>> for f32 {
fn from(p: Probability<MIN_AS_BITS>) -> Self {
impl<const MIN_AS_BITS: u32> From<BoundedProbability<MIN_AS_BITS>> for f32 {
fn from(p: BoundedProbability<MIN_AS_BITS>) -> Self {
p.value
}
}

impl<const MIN_AS_BITS: u32> fmt::Display for Probability<MIN_AS_BITS> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.value, f)
/// Equality is plain `f32` numeric equality on the stored value.
impl<const MIN_AS_BITS: u32> PartialEq for BoundedProbability<MIN_AS_BITS> {
    /// Returns `true` when both sides hold numerically equal values.
    fn eq(&self, rhs: &Self) -> bool {
        PartialEq::eq(&self.value, &rhs.value)
    }
}

impl<const MIN_AS_BITS: u32> Probability<MIN_AS_BITS> {
impl<const MIN_AS_BITS: u32> BoundedProbability<MIN_AS_BITS> {
/// The lower bound decoded from `MIN_AS_BITS`.
///
/// The `assert!`s here run at const-evaluation time for every
Expand All @@ -208,7 +225,7 @@ impl<const MIN_AS_BITS: u32> Probability<MIN_AS_BITS> {
/// `[MIN, +1.0]` and is not [`f32::NAN`], [`f32::INFINITY`], or
/// [`f32::NEG_INFINITY`]. A `-0.0` input is normalized to `+0.0`.
///
/// This is a `const fn`, so callers can build a [`Probability`] in a
/// This is a `const fn`, so callers can build a [`BoundedProbability`] in a
/// `const` context by matching on the returned [`Result`]; the validation
/// then runs at compile time.
///
Expand Down Expand Up @@ -245,14 +262,34 @@ impl<const MIN_AS_BITS: u32> Probability<MIN_AS_BITS> {
}
}

/// Generate a uniformly-distributed-over-bit-patterns value in `[MIN, +1.0]`
/// by sampling a `u32` in `[MIN_AS_BITS, f32::to_bits(+1.0)]` and decoding it.
///
/// This works because the f32 <-> u32 ordering (documented on the type) is
/// monotonic for non-negative finite values, so every bit pattern in that
/// range decodes to a valid stored value. `-0.0`'s bit pattern is
/// `0x8000_0000`, far above `f32::to_bits(+1.0) = 0x3f80_0000`, so it can
/// never be generated.
#[cfg(feature = "arbitrary")]
impl<'a, const MIN_AS_BITS: u32> arbitrary::Arbitrary<'a> for BoundedProbability<MIN_AS_BITS> {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let bits = u.int_in_range(MIN_AS_BITS..=f32::to_bits(Self::MAX))?;
let value = f32::from_bits(bits);
// Routing through `try_new` fires the per-monomorphization const-eval
// bound check on `Self::MIN` and forwards any future invariant added
// to the constructor. The `expect` is safe by the argument above.
Ok(Self::try_new(value).expect("bits in [MIN_AS_BITS, MAX_AS_BITS] always valid"))
}
Comment on lines +265 to +282
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this trait implementation would actually get used unless we were fuzzing this type for some reason. I can't think of a scenario where we would actually do that. I think this block of code could be deleted without any real consequence.

}

#[cfg(test)]
mod probability_tests {
use super::{NEG_ZERO_AS_BITS, Probability, ProbabilityError};
use super::{BoundedProbability, NEG_ZERO_AS_BITS, ProbabilityError};
use proptest::prelude::*;

type ZeroOrMore = Probability<{ f32::to_bits(0.0) }>;
type AtLeastHalf = Probability<{ f32::to_bits(0.5) }>;
type AtLeastOne = Probability<{ f32::to_bits(1.0) }>;
type ZeroOrMore = BoundedProbability<{ f32::to_bits(0.0) }>;
type AtLeastHalf = BoundedProbability<{ f32::to_bits(0.5) }>;
type AtLeastOne = BoundedProbability<{ f32::to_bits(1.0) }>;

// ===== Unit tests: constants =====

Expand Down Expand Up @@ -308,6 +345,17 @@ mod probability_tests {
}
}

// ===== Unit tests: equality =====

#[test]
fn equality_holds_for_same_bit_pattern() {
    // Two values built from the same input must compare equal; a value built
    // from a different in-range input must not.
    let [a, b, c] = [0.75, 0.75, 0.875].map(|v| AtLeastHalf::try_new(v).expect("valid"));
    assert_eq!(a, b);
    assert_ne!(a, c);
}

// ===== Unit tests: wire-format pins =====

#[test]
Expand Down Expand Up @@ -359,48 +407,45 @@ mod probability_tests {
// ===== Property-test helpers (generic over MIN_AS_BITS) =====

fn check_accepts_in_range<const MIN_AS_BITS: u32>(v: f32) {
let p = Probability::<MIN_AS_BITS>::try_new(v).expect("v should be valid by construction");
let p = BoundedProbability::<MIN_AS_BITS>::try_new(v)
.expect("v should be valid by construction");
assert_eq!(p.get().to_bits(), v.to_bits());
}

fn check_rejects_below_min<const MIN_AS_BITS: u32>(v: f32) {
let err = Probability::<MIN_AS_BITS>::try_new(v).expect_err("v should be below MIN");
let err = BoundedProbability::<MIN_AS_BITS>::try_new(v).expect_err("v should be below MIN");
match err {
ProbabilityError::BelowMin { min, value } => {
assert_eq!(min.to_bits(), Probability::<MIN_AS_BITS>::MIN.to_bits());
assert_eq!(
min.to_bits(),
BoundedProbability::<MIN_AS_BITS>::MIN.to_bits()
);
assert_eq!(value.to_bits(), v.to_bits());
}
other => panic!("expected BelowMin, got {other:?}"),
}
}

/// Asserts that formatting the wrapper with `{}` yields the same text as
/// formatting the raw `f32` it was built from.
fn check_display_matches<const MIN_AS_BITS: u32>(v: f32) {
    let p = Probability::<MIN_AS_BITS>::try_new(v).expect("valid v");
    let rendered_wrapper = format!("{p}");
    let rendered_inner = format!("{v}");
    assert_eq!(rendered_wrapper, rendered_inner);
}

/// Asserts that a precision of `n` digits in the format spec is applied to the
/// wrapper exactly as it is to the raw `f32`.
fn check_display_precision<const MIN_AS_BITS: u32>(v: f32, n: usize) {
    let p = Probability::<MIN_AS_BITS>::try_new(v).expect("valid v");
    let (rendered_wrapper, rendered_inner) = (format!("{p:.n$}"), format!("{v:.n$}"));
    assert_eq!(rendered_wrapper, rendered_inner);
}

fn check_serde_json_round_trip<const MIN_AS_BITS: u32>(v: f32) {
let p = Probability::<MIN_AS_BITS>::try_new(v).expect("valid v");
let p = BoundedProbability::<MIN_AS_BITS>::try_new(v).expect("valid v");
let json = serde_json::to_string(&p).expect("serialize");
let back: Probability<MIN_AS_BITS> = serde_json::from_str(&json).expect("deserialize");
let back: BoundedProbability<MIN_AS_BITS> =
serde_json::from_str(&json).expect("deserialize");
assert_eq!(back.get().to_bits(), v.to_bits());
}

fn check_serde_yaml_round_trip<const MIN_AS_BITS: u32>(v: f32) {
let p = Probability::<MIN_AS_BITS>::try_new(v).expect("valid v");
let p = BoundedProbability::<MIN_AS_BITS>::try_new(v).expect("valid v");
let yaml = serde_yaml::to_string(&p).expect("serialize");
let back: Probability<MIN_AS_BITS> = serde_yaml::from_str(&yaml).expect("deserialize");
let back: BoundedProbability<MIN_AS_BITS> =
serde_yaml::from_str(&yaml).expect("deserialize");
assert_eq!(back.get().to_bits(), v.to_bits());
}

fn check_serde_json_rejects_below_min<const MIN_AS_BITS: u32>(v: f32) {
let json = serde_json::to_string(&v).expect("serialize raw f32");
let err = serde_json::from_str::<Probability<MIN_AS_BITS>>(&json).expect_err("v < MIN");
let err =
serde_json::from_str::<BoundedProbability<MIN_AS_BITS>>(&json).expect_err("v < MIN");
assert!(
err.to_string().contains("below lower bound"),
"unexpected error: {err}"
Expand All @@ -409,7 +454,8 @@ mod probability_tests {

fn check_serde_yaml_rejects_below_min<const MIN_AS_BITS: u32>(v: f32) {
let yaml = serde_yaml::to_string(&v).expect("serialize raw f32");
let err = serde_yaml::from_str::<Probability<MIN_AS_BITS>>(&yaml).expect_err("v < MIN");
let err =
serde_yaml::from_str::<BoundedProbability<MIN_AS_BITS>>(&yaml).expect_err("v < MIN");
assert!(
err.to_string().contains("below lower bound"),
"unexpected error: {err}"
Expand Down Expand Up @@ -474,55 +520,6 @@ mod probability_tests {
}
}

// ===== Property tests: Display =====

// Display property tests, run once per lower bound (0.0, 0.5, 1.0).
// Each case draws `v` from `valid_value_strategy(<alias>::MIN)` and passes it
// to a helper instantiated with the matching bit pattern.
// NOTE(review): `valid_value_strategy` is not visible in this hunk; presumably
// it yields only values accepted for the given bound -- confirm.
proptest! {
// `{}` output of the wrapper must equal `{}` output of the raw f32.
#[test]
fn display_matches_inner_f32_for_valid_values_zero_or_more(
v in valid_value_strategy(ZeroOrMore::MIN),
) {
check_display_matches::<{ f32::to_bits(0.0) }>(v);
}

#[test]
fn display_matches_inner_f32_for_valid_values_at_least_half(
v in valid_value_strategy(AtLeastHalf::MIN),
) {
check_display_matches::<{ f32::to_bits(0.5) }>(v);
}

#[test]
fn display_matches_inner_f32_for_valid_values_at_least_one(
v in valid_value_strategy(AtLeastOne::MIN),
) {
check_display_matches::<{ f32::to_bits(1.0) }>(v);
}

// Same comparison, but with an explicit precision `n` from 0 through 10.
#[test]
fn display_propagates_precision_zero_or_more(
v in valid_value_strategy(ZeroOrMore::MIN),
n in 0_usize..=10,
) {
check_display_precision::<{ f32::to_bits(0.0) }>(v, n);
}

#[test]
fn display_propagates_precision_at_least_half(
v in valid_value_strategy(AtLeastHalf::MIN),
n in 0_usize..=10,
) {
check_display_precision::<{ f32::to_bits(0.5) }>(v, n);
}

#[test]
fn display_propagates_precision_at_least_one(
v in valid_value_strategy(AtLeastOne::MIN),
n in 0_usize..=10,
) {
check_display_precision::<{ f32::to_bits(1.0) }>(v, n);
}
}

// ===== Property tests: serde round-trips =====

proptest! {
Expand Down Expand Up @@ -634,4 +631,45 @@ mod probability_tests {
);
}
}

// ===== Property tests: Arbitrary impl (feature = "arbitrary") =====

/// Feeds `bytes` to the `Arbitrary` impl and, if a value is produced, asserts
/// that it is finite, is not `-0.0`, and lies within `[MIN, MAX]`.
#[cfg(feature = "arbitrary")]
fn check_arbitrary_produces_valid<const MIN_AS_BITS: u32>(bytes: &[u8]) {
    use arbitrary::{Arbitrary, Unstructured};

    // An `Err` from `arbitrary` is not a failure for this helper: only a
    // successfully generated value is required to satisfy the invariants.
    let mut source = Unstructured::new(bytes);
    let Ok(generated) = BoundedProbability::<MIN_AS_BITS>::arbitrary(&mut source) else {
        return;
    };

    let v = generated.get();
    assert!(v.is_finite());
    assert_ne!(v.to_bits(), NEG_ZERO_AS_BITS);
    assert!(v >= BoundedProbability::<MIN_AS_BITS>::MIN);
    assert!(v <= BoundedProbability::<MIN_AS_BITS>::MAX);
}

// Property tests for the `Arbitrary` impl; compiled only when the `arbitrary`
// feature is enabled. Each case generates a byte vector of length 4 to 31
// (the `4..32` size range) and hands it to `check_arbitrary_produces_valid`
// for one lower bound: 0.0, 0.5, and 1.0 respectively.
#[cfg(feature = "arbitrary")]
proptest! {
#[test]
fn arbitrary_produces_valid_zero_or_more(
bytes in prop::collection::vec(any::<u8>(), 4..32),
) {
check_arbitrary_produces_valid::<{ f32::to_bits(0.0) }>(&bytes);
}

#[test]
fn arbitrary_produces_valid_at_least_half(
bytes in prop::collection::vec(any::<u8>(), 4..32),
) {
check_arbitrary_produces_valid::<{ f32::to_bits(0.5) }>(&bytes);
}

// With MIN == 1.0 the lower and upper bounds coincide.
#[test]
fn arbitrary_produces_valid_at_least_one(
bytes in prop::collection::vec(any::<u8>(), 4..32),
) {
check_arbitrary_produces_valid::<{ f32::to_bits(1.0) }>(&bytes);
}
}
}
Loading