Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,10 @@ For day-to-day workflow and commands, see [`CLAUDE.md`](../CLAUDE.md) and
so recursive `@cache` functions still cache. GIL builds use a single GIL-serialized flag;
free-threaded builds use a per-thread set of active function addresses. The shared backend is
unaffected (key comparison is by serialized bytes, never Python `__eq__`).
- **A raising `__eq__` must propagate, not be swallowed (issue #36).** `PyObject_RichCompareBool`
returns -1 with a Python exception set when a key's `__eq__` raises. `key.rs::rich_compare_eq`
reports -1 as "not equal" but leaves the exception set (any further comparison in the same
probe sees the pending exception and also answers "not equal" without re-entering Python), and every
memory-backend lookup site (`__call__` read + write-double-check, `get`, `_probe`) calls
`PyErr::take` after the lookup and returns the error instead of recomputing / returning `Ok`
with an exception pending (which PyO3 turns into a masking `SystemError`). Any new lookup site
that can run `__eq__` must do the same check.
32 changes: 27 additions & 5 deletions src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ impl PartialEq for CacheKey {
// serializes that via its reentrancy guard (try_enter, see issue #30), so
// no aliasing/reentrant shard guard is taken during this comparison.
// This is the same direct C API call that lru_cache uses.
unsafe {
ffi::PyObject_RichCompareBool(self.key_obj.as_ptr(), other.key_obj.as_ptr(), ffi::Py_EQ)
== 1
}
unsafe { rich_compare_eq(self.key_obj.as_ptr(), other.key_obj.as_ptr()) }
}
}

Expand Down Expand Up @@ -99,6 +96,31 @@ impl hashbrown::Equivalent<CacheKey> for BorrowedArgs {
// call stack) and `key.key_obj` is an owned reference in the map. The
// arbitrary Python __eq__ this runs may re-enter; CachedFunction's
// reentrancy guard prevents a second, aliasing shard guard (issue #30).
unsafe { ffi::PyObject_RichCompareBool(self.ptr, key.key_obj.as_ptr(), ffi::Py_EQ) == 1 }
unsafe { rich_compare_eq(self.ptr, key.key_obj.as_ptr()) }
}
}

/// Python-level `a == b` for the `PartialEq`/`Equivalent` impls, which have no
/// way to return a `Result`.
///
/// `PyObject_RichCompareBool` yields 1 (equal), 0 (not equal) or -1 (the key's
/// `__eq__` raised and a Python exception is now set). Only 1 maps to `true`.
/// A -1 is reported as "not equal", but the exception is deliberately left set
/// rather than cleared (issue #36): the caller is expected to fetch the pending
/// exception once the lookup returns and propagate it.
///
/// When an exception is already pending on entry — e.g. an earlier comparison
/// in the same lookup raised — this answers `false` without calling into Python
/// at all, because re-entering the interpreter with a live exception would
/// clobber the original error.
///
/// # Safety
/// Both pointers must be valid live Python objects and the GIL must be held.
#[inline(always)]
unsafe fn rich_compare_eq(a: *mut ffi::PyObject, b: *mut ffi::PyObject) -> bool {
    // `&&` short-circuits: with an error already pending, the comparison on the
    // right is never evaluated, so no Python code runs on top of that error.
    let no_error_pending = ffi::PyErr_Occurred().is_null();
    // 1 = equal. Both 0 (not equal) and -1 (`__eq__` raised, exception now set)
    // come back as `false`; in the -1 case the error stays set for the caller.
    no_error_pending && ffi::PyObject_RichCompareBool(a, b, ffi::Py_EQ) == 1
}
23 changes: 23 additions & 0 deletions src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ impl CachedFunction {
None => false,
};

// A key's __eq__ may have raised during the probe above (RichCompareBool
// returned -1, leaving the exception set and the probe reported as a miss).
// Surface it instead of recomputing with an exception pending (#36).
if let Some(err) = PyErr::take(py) {
return Err(err);
}

// Cache miss (or reentrant bypass): call the wrapped function (no lock held)
let result = self.fn_obj.bind(py).call(args, kwargs.as_ref())?.unbind();

Expand Down Expand Up @@ -530,6 +537,12 @@ impl CachedFunction {
None => true,
};

// The double-check lookup runs __eq__ too; if it raised, propagate
// rather than inserting and returning with an exception pending (#36).
if let Some(err) = PyErr::take(py) {
return Err(err);
}

if needs_insert {
// Remove expired entry from map if present (order cleaned lazily)
shard.map.remove(&cache_key);
Expand Down Expand Up @@ -588,6 +601,11 @@ impl CachedFunction {
}
}

// A raising __eq__ during the probe leaves the exception set (#36).
if let Some(err) = PyErr::take(py) {
return Err(err);
}

self.misses.fetch_add(1, Ordering::Relaxed);
Ok(None)
}
Expand Down Expand Up @@ -623,6 +641,11 @@ impl CachedFunction {
}
}

// A raising __eq__ during the probe leaves the exception set (#36).
if let Some(err) = PyErr::take(py) {
return Err(err);
}

self.misses.fetch_add(1, Ordering::Relaxed);
Ok((false, py.None()))
}
Expand Down
103 changes: 103 additions & 0 deletions tests/test_raising_eq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Regression tests for issue #36.

When a cache key's ``__eq__`` raises during a lookup, ``PyObject_RichCompareBool``
returns -1 with a Python exception set. The old code compared the result ``== 1``,
mapping -1 to "not equal" and silently dropping the exception. The user's real
error was lost and PyO3 later surfaced a confusing
``SystemError: ... returned a result with an exception set`` (and, in collision
cases, a spurious recompute). The fix fetches the pending exception after the
lookup and propagates it.
"""

import pytest

from warp_cache import cache


class RaisingEq:
    """Key whose equality check always fails loudly.

    Constant hash so all instances collide into one slot — this forces
    hashbrown to invoke ``__eq__`` during probing — and ``__eq__`` always raises.
    """

    def __hash__(self):
        # Every instance hashes identically, so two distinct instances always
        # need an equality comparison to be told apart.
        return 0

    def __eq__(self, other):
        # Unconditional failure: any comparison involving an instance raises.
        raise RuntimeError("boom from __eq__")


def test_raising_eq_propagates_on_call():
    """Calling the cached function must surface the key's own ``__eq__`` error."""

    @cache(max_size=128)
    def f(key):
        return 1

    f(RaisingEq())  # prime: empty bucket, no comparison yet
    # Second call collides (same hash) -> __eq__ runs and raises. The original
    # RuntimeError must propagate, not a masked SystemError.
    with pytest.raises(RuntimeError, match="boom from __eq__"):
        f(RaisingEq())


def test_raising_eq_propagates_on_get():
    """``f.get`` must surface the key's own ``__eq__`` error."""

    @cache(max_size=128)
    def f(key):
        return 1

    f(RaisingEq())  # prime
    # The lookup collides with the primed entry, so __eq__ runs and raises.
    with pytest.raises(RuntimeError, match="boom from __eq__"):
        f.get(RaisingEq())


def test_raising_eq_propagates_on_probe():
    """``f._probe`` must surface the key's own ``__eq__`` error."""

    @cache(max_size=128)
    def f(key):
        return 1

    f(RaisingEq())  # prime
    # The lookup collides with the primed entry, so __eq__ runs and raises.
    with pytest.raises(RuntimeError, match="boom from __eq__"):
        f._probe(RaisingEq())


def test_cache_usable_after_raising_eq():
    """A raising __eq__ must not leave a dangling exception that poisons the
    next, unrelated call."""

    @cache(max_size=128)
    def f(key):
        return 1

    f(RaisingEq())  # prime
    with pytest.raises(RuntimeError):
        f(RaisingEq())

    # Different key type, no collision, no raising — must work cleanly.
    assert f("ok") == 1
    assert f("ok") == 1  # cached hit, no lingering error


def test_non_raising_collision_still_caches():
    """Guard against over-eager error detection: keys that collide by hash but
    compare cleanly (return False) must still cache independently."""

    class CleanCollide:
        """Key with a constant hash whose equality is decided by ``tag``."""

        def __init__(self, tag):
            self.tag = tag

        def __hash__(self):
            return 0  # force collisions

        def __eq__(self, other):
            return isinstance(other, CleanCollide) and self.tag == other.tag

    # Counts how many times the wrapped function body actually runs.
    calls = {"n": 0}

    @cache(max_size=8)
    def f(key):
        calls["n"] += 1
        return key.tag

    a, b = CleanCollide(1), CleanCollide(2)
    assert f(a) == 1
    assert f(b) == 2  # collides with a but != a -> distinct entry
    assert f(a) == 1  # hit
    assert f(b) == 2  # hit
    # Only the two first-time calls may have executed the function body.
    assert calls["n"] == 2
Loading