Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/4006.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed `ShardingCodec` equality and pickling: codecs that differ only in `subchunk_write_order` now compare unequal, codecs seeded with identical `rng` values now compare equal, and both `subchunk_write_order` and `rng` survive a pickle round-trip (previously `subchunk_write_order` silently reverted to `morton`).
47 changes: 46 additions & 1 deletion src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation:
return parse_enum(data, ShardingCodecIndexLocation)


def _rng_state(rng: np.random.Generator | None) -> Mapping[str, Any] | None:
"""Return a value-comparable snapshot of a Generator, or None.

numpy Generators have no value equality, so two Generators seeded identically are not
`==`. Their `bit_generator.state` is a plain dict that does compare by value, which lets
us compare codecs by configured seed rather than object identity.
"""
return None if rng is None else rng.bit_generator.state


@dataclass(frozen=True)
class _ShardingByteGetter(ByteGetter):
shard_dict: ShardMapping
Expand Down Expand Up @@ -340,7 +350,13 @@ def __init__(

# todo: typedict return type
def __getstate__(self) -> dict[str, Any]:
    """Return the state used to pickle this codec.

    The state is the codec metadata produced by `to_dict()` plus two extra top-level
    keys, `"rng"` and `"subchunk_write_order"`.

    Returns
    -------
    A dict holding `rng`, `subchunk_write_order` and every key of `to_dict()`.
    """
    # `subchunk_write_order` and `rng` are not part of the codec metadata (`to_dict`),
    # so they must be carried explicitly to survive a pickle round-trip.
    # `to_dict()` is unpacked last, so its keys win if a name ever collides.
    return {
        "rng": self.rng,
        "subchunk_write_order": self.subchunk_write_order,
        **self.to_dict(),
    }

def __setstate__(self, state: dict[str, Any]) -> None:
config = state["configuration"]
Expand All @@ -349,12 +365,41 @@ def __setstate__(self, state: dict[str, Any]) -> None:
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))
object.__setattr__(self, "rng", state["rng"])
object.__setattr__(self, "subchunk_write_order", state["subchunk_write_order"])

# Use instance-local lru_cache to avoid memory leaks
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

def __eq__(self, other: object) -> bool:
    """Compare two sharding codecs by value.

    Returns `NotImplemented` when `other` is not a `ShardingCodec`. Otherwise the codecs
    are equal when every plain field matches and both `rng` snapshots match.
    """
    if not isinstance(other, ShardingCodec):
        return NotImplemented

    # Plain fields compare with their own `==`, checked in declaration order and
    # stopping at the first mismatch.
    plain_fields = (
        "chunk_shape",
        "codecs",
        "index_codecs",
        "index_location",
        "subchunk_write_order",
    )
    for field_name in plain_fields:
        if not (getattr(self, field_name) == getattr(other, field_name)):
            return False

    # numpy Generators have no value equality, so the dataclass-generated __eq__ would
    # compare `rng` by identity. Comparing the state snapshots instead makes two codecs
    # seeded identically equal.
    return _rng_state(self.rng) == _rng_state(other.rng)

def __hash__(self) -> int:
    """Hash the codec from every field except `rng`."""
    # `rng` is left out because its state snapshot is a dict and therefore unhashable.
    # That is still consistent with `__eq__`: equal codecs agree on every field hashed
    # here, so they hash equal; codecs differing only in `rng` merely collide.
    hashed_fields = (
        self.chunk_shape,
        self.codecs,
        self.index_codecs,
        self.index_location,
        self.subchunk_write_order,
    )
    return hash(hashed_fields)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
Expand Down
61 changes: 61 additions & 0 deletions tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,67 @@ async def test_unordered_can_be_seeded() -> None:
assert all(orders[0] == o for o in orders)


@pytest.mark.parametrize(
    ("left", "right", "expected_equal"),
    [
        # Same write order on both sides.
        pytest.param(
            ShardingCodec(chunk_shape=(4, 4), subchunk_write_order="lexicographic"),
            ShardingCodec(chunk_shape=(4, 4), subchunk_write_order="lexicographic"),
            True,
            id="same-order",
        ),
        # Write order is the only difference.
        pytest.param(
            ShardingCodec(chunk_shape=(4, 4), subchunk_write_order="morton"),
            ShardingCodec(chunk_shape=(4, 4), subchunk_write_order="lexicographic"),
            False,
            id="different-order",
        ),
        # Distinct Generator objects built from the same seed.
        pytest.param(
            ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(seed=0)),
            ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(seed=0)),
            True,
            id="same-rng-seed",
        ),
        # Generators built from different seeds.
        pytest.param(
            ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(seed=0)),
            ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(seed=1)),
            False,
            id="different-rng-seed",
        ),
        # A seeded codec against one constructed without an rng argument.
        pytest.param(
            ShardingCodec(chunk_shape=(4, 4), rng=np.random.default_rng(seed=0)),
            ShardingCodec(chunk_shape=(4, 4)),
            False,
            id="rng-vs-no-rng",
        ),
    ],
)
def test_eq(left: ShardingCodec, right: ShardingCodec, expected_equal: bool) -> None:
    """Codec equality takes ``subchunk_write_order`` into account and treats ``rng`` by
    seed state: numpy Generators define no value equality, yet codecs holding
    identically-seeded Generators must still compare equal."""
    outcome = left == right
    assert outcome is expected_equal


def test_pickle_preserves_subchunk_write_order() -> None:
    """A pickle round-trip must keep ``subchunk_write_order`` instead of falling back to
    the default, since that setting is not stored in the codec metadata."""
    codec = ShardingCodec(chunk_shape=(8, 8), subchunk_write_order="lexicographic")
    payload = pickle.dumps(codec)
    restored = pickle.loads(payload)
    assert restored.subchunk_write_order == "lexicographic"
    assert restored == codec


def test_pickle_preserves_seeded_rng() -> None:
    """A pickle round-trip must keep a seeded rng, so that unordered writes stay
    reproducible across process boundaries."""
    seeded = np.random.default_rng(seed=0)
    codec = ShardingCodec(chunk_shape=(8, 8), subchunk_write_order="unordered", rng=seeded)
    payload = pickle.dumps(codec)
    restored = pickle.loads(payload)
    assert restored == codec


@pytest.mark.parametrize(
"subchunk_write_order",
get_args(SubchunkWriteOrder),
Expand Down
Loading