diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index a4d17cf..8f80fd3 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -96,3 +96,10 @@ For day-to-day workflow and commands, see [`CLAUDE.md`](../CLAUDE.md) and so recursive `@cache` functions still cache. GIL builds use a single GIL-serialized flag; free-threaded builds use a per-thread set of active function addresses. The shared backend is unaffected (key comparison is by serialized bytes, never Python `__eq__`). +- **A raising `__eq__` must propagate, not be swallowed (issue #36).** `PyObject_RichCompareBool` + returns -1 with a Python exception set when a key's `__eq__` raises. `key.rs::rich_compare_eq` + reports -1 as "not equal" (hashbrown may keep probing, but later comparisons in that lookup short-circuit to "not equal" while the exception is pending) but leaves the exception set, and every + memory-backend lookup site (`__call__` read + write-double-check, `get`, `_probe`) calls + `PyErr::take` after the lookup and returns the error instead of recomputing / returning `Ok` + with an exception pending (which PyO3 turns into a masking `SystemError`). Any new lookup site + that can run `__eq__` must do the same check. diff --git a/src/key.rs b/src/key.rs index 803c5e9..1082a2a 100644 --- a/src/key.rs +++ b/src/key.rs @@ -51,10 +51,7 @@ impl PartialEq for CacheKey { // serializes that via its reentrancy guard (try_enter, see issue #30), so // no aliasing/reentrant shard guard is taken during this comparison. // This is the same direct C API call that lru_cache uses. - unsafe { - ffi::PyObject_RichCompareBool(self.key_obj.as_ptr(), other.key_obj.as_ptr(), ffi::Py_EQ) - == 1 - } + unsafe { rich_compare_eq(self.key_obj.as_ptr(), other.key_obj.as_ptr()) } } } @@ -99,6 +96,31 @@ impl hashbrown::Equivalent for BorrowedArgs { // call stack) and `key.key_obj` is an owned reference in the map. The // arbitrary Python __eq__ this runs may re-enter; CachedFunction's // reentrancy guard prevents a second, aliasing shard guard (issue #30). 
- unsafe { ffi::PyObject_RichCompareBool(self.ptr, key.key_obj.as_ptr(), ffi::Py_EQ) == 1 } + unsafe { rich_compare_eq(self.ptr, key.key_obj.as_ptr()) } + } +} + +/// `a == b` via Python's rich comparison, for use from `PartialEq`/`Equivalent` +/// (which can't return `Result`). +/// +/// `PyObject_RichCompareBool` returns -1 and leaves a Python exception set when a +/// key's `__eq__` raises. We must NOT map that to `true`/`false` and silently drop +/// the exception (issue #36): callers fetch the pending exception after the lookup +/// and propagate it. Here, -1 is reported as "not equal" (hashbrown may still probe +/// further slots), with the exception left set for the caller. +/// +/// If an exception is already pending (an earlier comparison in the same lookup +/// raised), we return `false` without calling into Python again — re-entering the +/// interpreter with a live exception would clobber the original error. +/// +/// # Safety +/// Both pointers must be valid live Python objects and the GIL must be held. +#[inline(always)] +unsafe fn rich_compare_eq(a: *mut ffi::PyObject, b: *mut ffi::PyObject) -> bool { + if !ffi::PyErr_Occurred().is_null() { + return false; } + // 1 = equal; 0 = not equal; -1 = __eq__ raised (exception now set) — the latter + // two are both "not equal" for the probe, but -1 leaves the error for the caller. + ffi::PyObject_RichCompareBool(a, b, ffi::Py_EQ) == 1 } diff --git a/src/store.rs b/src/store.rs index ed739cd..f3334e5 100644 --- a/src/store.rs +++ b/src/store.rs @@ -493,6 +493,13 @@ impl CachedFunction { None => false, }; + // A key's __eq__ may have raised during the probe above (RichCompareBool + // returned -1, leaving the exception set and the probe reported as a miss). + // Surface it instead of recomputing with an exception pending (#36). 
+ if let Some(err) = PyErr::take(py) { + return Err(err); + } + // Cache miss (or reentrant bypass): call the wrapped function (no lock held) let result = self.fn_obj.bind(py).call(args, kwargs.as_ref())?.unbind(); @@ -530,6 +537,12 @@ impl CachedFunction { None => true, }; + // The double-check lookup runs __eq__ too; if it raised, propagate + // rather than inserting and returning with an exception pending (#36). + if let Some(err) = PyErr::take(py) { + return Err(err); + } + if needs_insert { // Remove expired entry from map if present (order cleaned lazily) shard.map.remove(&cache_key); @@ -588,6 +601,11 @@ impl CachedFunction { } } + // A raising __eq__ during the probe leaves the exception set (#36). + if let Some(err) = PyErr::take(py) { + return Err(err); + } + self.misses.fetch_add(1, Ordering::Relaxed); Ok(None) } @@ -623,6 +641,11 @@ impl CachedFunction { } } + // A raising __eq__ during the probe leaves the exception set (#36). + if let Some(err) = PyErr::take(py) { + return Err(err); + } + self.misses.fetch_add(1, Ordering::Relaxed); Ok((false, py.None())) } diff --git a/tests/test_raising_eq.py b/tests/test_raising_eq.py new file mode 100644 index 0000000..e13b75c --- /dev/null +++ b/tests/test_raising_eq.py @@ -0,0 +1,103 @@ +"""Regression tests for issue #36. + +When a cache key's ``__eq__`` raises during a lookup, ``PyObject_RichCompareBool`` +returns -1 with a Python exception set. The old code compared the result ``== 1``, +mapping -1 to "not equal" and silently dropping the exception. The user's real +error was lost and PyO3 later surfaced a confusing +``SystemError: ... returned a result with an exception set`` (and, in collision +cases, a spurious recompute). The fix fetches the pending exception after the +lookup and propagates it. 
+""" + +import pytest + +from warp_cache import cache + + +class RaisingEq: + """Constant hash so all instances collide into one slot — this forces + hashbrown to invoke ``__eq__`` during probing — and ``__eq__`` always raises.""" + + def __hash__(self): + return 0 + + def __eq__(self, other): + raise RuntimeError("boom from __eq__") + + +def test_raising_eq_propagates_on_call(): + @cache(max_size=128) + def f(key): + return 1 + + f(RaisingEq()) # prime: empty bucket, no comparison yet + # Second call collides (same hash) -> __eq__ runs and raises. The original + # RuntimeError must propagate, not a masked SystemError. + with pytest.raises(RuntimeError, match="boom from __eq__"): + f(RaisingEq()) + + +def test_raising_eq_propagates_on_get(): + @cache(max_size=128) + def f(key): + return 1 + + f(RaisingEq()) # prime + with pytest.raises(RuntimeError, match="boom from __eq__"): + f.get(RaisingEq()) + + +def test_raising_eq_propagates_on_probe(): + @cache(max_size=128) + def f(key): + return 1 + + f(RaisingEq()) # prime + with pytest.raises(RuntimeError, match="boom from __eq__"): + f._probe(RaisingEq()) + + +def test_cache_usable_after_raising_eq(): + """A raising __eq__ must not leave a dangling exception that poisons the + next, unrelated call.""" + + @cache(max_size=128) + def f(key): + return 1 + + f(RaisingEq()) + with pytest.raises(RuntimeError): + f(RaisingEq()) + + # Different key type, no collision, no raising — must work cleanly. 
+ assert f("ok") == 1 + assert f("ok") == 1 # cached hit, no lingering error + + +def test_non_raising_collision_still_caches(): + """Guard against over-eager error detection: keys that collide by hash but + compare cleanly (return False) must still cache independently.""" + + class CleanCollide: + def __init__(self, tag): + self.tag = tag + + def __hash__(self): + return 0 # force collisions + + def __eq__(self, other): + return isinstance(other, CleanCollide) and self.tag == other.tag + + calls = {"n": 0} + + @cache(max_size=8) + def f(key): + calls["n"] += 1 + return key.tag + + a, b = CleanCollide(1), CleanCollide(2) + assert f(a) == 1 + assert f(b) == 2 # collides with a but != a -> distinct entry + assert f(a) == 1 # hit + assert f(b) == 2 # hit + assert calls["n"] == 2