From cb4d230bd82c721866954360554940fabbf2de0f Mon Sep 17 00:00:00 2001 From: NIK-TIGER-BILL Date: Mon, 25 May 2026 23:15:31 +0000 Subject: [PATCH] fix: make ShardingCodec equality work with rng field (#4005) ShardingCodec gained a field in a recent commit. Because does not implement , the auto-generated dataclass always returned when was not . This patch adds custom and methods that compare Generators by their , restoring deterministic equality for otherwise-identical codec instances. Closes #4005 Signed-off-by: NIK-TIGER-BILL --- src/zarr/codecs/sharding.py | 36 ++++++++++++++++++++++++++++++ tests/test_codecs/test_sharding.py | 26 +++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 33c8602ecb..a784fbca03 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -290,6 +290,20 @@ def to_dict_vectorized( return result +def _rng_eq(a: np.random.Generator | None, b: np.random.Generator | None) -> bool: + if a is None and b is None: + return True + if a is None or b is None: + return False + return a.bit_generator.state == b.bit_generator.state + + +def _rng_hash(rng: np.random.Generator | None) -> int: + if rng is None: + return 0 + return hash(tuple(rng.bit_generator.state["state"]["key"])) + + @dataclass(frozen=True) class ShardingCodec( ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin @@ -375,6 +389,28 @@ def to_dict(self) -> dict[str, JSON]: }, } + def __eq__(self, other: object) -> bool: + if not isinstance(other, ShardingCodec): + return NotImplemented + return ( + self.chunk_shape == other.chunk_shape + and self.codecs == other.codecs + and self.index_codecs == other.index_codecs + and self.index_location == other.index_location + and self.subchunk_write_order == other.subchunk_write_order + and _rng_eq(self.rng, other.rng) + ) + + def __hash__(self) -> int: + return hash(( + self.chunk_shape, + self.codecs, + self.index_codecs, + self.index_location, + self.subchunk_write_order, + _rng_hash(self.rng), + )) + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: shard_spec = self._get_chunk_spec(array_spec) evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 74e4a7e0d5..a051db8a98 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -673,6 +673,32 @@ def test_pickle() -> None: assert pickle.loads(pickle.dumps(codec)) == codec +def test_sharding_codec_equality_with_rng() -> None: + """Test that two ShardingCodec instances with identical rng are considered equal.""" + rng = np.random.default_rng(42) + codec1 = ShardingCodec(chunk_shape=(8, 8), rng=rng) + codec2 = ShardingCodec(chunk_shape=(8, 8), rng=rng) + assert codec1 == codec2 + + # Two separate generators with the same seed should also be equal + rng_a = np.random.default_rng(12345) + rng_b = np.random.default_rng(12345) + codec_a = ShardingCodec(chunk_shape=(4, 4), rng=rng_a) + codec_b = ShardingCodec(chunk_shape=(4, 4), rng=rng_b) + assert codec_a == codec_b + + # Different seed => not equal + rng_c = np.random.default_rng(99999) + codec_c = ShardingCodec(chunk_shape=(4, 4), rng=rng_c) + assert codec_a != codec_c + + # None rng should work too + codec_none1 = ShardingCodec(chunk_shape=(8, 8), rng=None) + codec_none2 = ShardingCodec(chunk_shape=(8, 8), rng=None) + assert codec_none1 == codec_none2 + assert codec_none1 != codec1 + + @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize( "index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]