Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,20 @@ def to_dict_vectorized(
return result


def _rng_eq(a: np.random.Generator | None, b: np.random.Generator | None) -> bool:
if a is None and b is None:
return True
if a is None or b is None:
return False
return a.bit_generator.state == b.bit_generator.state


def _rng_hash(rng: np.random.Generator | None) -> int:
if rng is None:
return 0
return hash(tuple(rng.bit_generator.state["state"]["key"]))


@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
Expand Down Expand Up @@ -375,6 +389,28 @@ def to_dict(self) -> dict[str, JSON]:
},
}

def __eq__(self, other: object) -> bool:
    """Compare this codec with ``other`` field by field.

    Returns ``NotImplemented`` when ``other`` is not a ``ShardingCodec``.
    The ``rng`` field is compared through ``_rng_eq`` rather than ``==``;
    all other fields use ordinary equality.
    """
    if not isinstance(other, ShardingCodec):
        return NotImplemented

    # Fields that can be compared directly, in the order they are checked.
    plain_pairs = (
        (self.chunk_shape, other.chunk_shape),
        (self.codecs, other.codecs),
        (self.index_codecs, other.index_codecs),
        (self.index_location, other.index_location),
        (self.subchunk_write_order, other.subchunk_write_order),
    )
    if not all(mine == theirs for mine, theirs in plain_pairs):
        return False

    # The generator is only inspected once every other field matches.
    return _rng_eq(self.rng, other.rng)

def __hash__(self) -> int:
    """Hash the codec from the same fields that ``__eq__`` compares.

    The ``rng`` field contributes through ``_rng_hash`` instead of being
    hashed directly.
    """
    key = (
        self.chunk_shape,
        self.codecs,
        self.index_codecs,
        self.index_location,
        self.subchunk_write_order,
        _rng_hash(self.rng),
    )
    return hash(key)

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
shard_spec = self._get_chunk_spec(array_spec)
evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,32 @@ def test_pickle() -> None:
assert pickle.loads(pickle.dumps(codec)) == codec


def test_sharding_codec_equality_with_rng() -> None:
    """Check that ShardingCodec equality takes the ``rng`` field into account."""
    # One generator object shared by both codecs.
    shared_rng = np.random.default_rng(42)
    shared_first = ShardingCodec(chunk_shape=(8, 8), rng=shared_rng)
    shared_second = ShardingCodec(chunk_shape=(8, 8), rng=shared_rng)
    assert shared_first == shared_second

    # Distinct generator objects built from the same seed compare equal.
    seeded = ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(12345))
    seeded_twin = ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(12345))
    assert seeded == seeded_twin

    # A generator built from another seed makes the codecs differ.
    other_seed = ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(99999))
    assert seeded != other_seed

    # Codecs without a generator equal each other but not one that has a generator.
    without_rng = ShardingCodec(chunk_shape=(8, 8), rng=None)
    without_rng_twin = ShardingCodec(chunk_shape=(8, 8), rng=None)
    assert without_rng == without_rng_twin
    assert without_rng != shared_first


@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])
@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
Expand Down
Loading