Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion python/ray/serve/_private/replica_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,80 @@ async def __anext__(self):
return await asyncio.wrap_future(fut)

def add_done_callback(self, callback: Callable):
    """Register ``callback``, invoked when the underlying RPC completes.

    The actor transport invokes done-callbacks with the request's result
    or a ``RayError`` (see ``ActorReplicaResult.add_done_callback``).
    ``grpc.aio`` instead invokes them with the raw ``grpc.aio.Call``, which
    is opaque to consumers that rely on the actor-transport contract -- most
    importantly the router's request-completion handler, which invalidates
    the queue-length cache when a replica becomes unavailable (see
    ``Router._process_finished_request``). Without normalization a failed
    gRPC request is silently ignored, so a dead replica keeps getting
    selected by power-of-two-choices routing until another code path probes
    it (https://github.com/ray-project/ray/issues/63261).

    Normalize a failed call into the same ``ActorUnavailableError`` the data
    path raises (see ``_process_grpc_response`` / ``get_rejection_response``)
    so behavior is consistent across transports.

    Only the normalizing wrapper is registered on the underlying call, so
    ``callback`` is never additionally invoked with the un-normalized call
    on the normal path.

    Args:
        callback: Called with the value returned by
            ``_normalize_done_result`` for the completed call. If the
            call's event loop refuses new work (``RuntimeError``), it is
            called with the raw call object instead, so the notification
            is not dropped.
    """

    def _on_done(call: "grpc.aio.Call"):
        # The status code is only available via a coroutine, so it must be
        # resolved on an event loop before ``callback`` can be invoked.
        async def _invoke():
            callback(await self._normalize_done_result(call))

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None

        # Fast path (from review): grpc.aio normally fires done-callbacks on
        # the call's own loop, so schedule directly there and skip the
        # thread-safe queue/locking overhead of run_coroutine_threadsafe.
        # get_running_loop() only ever returns a running loop, so identity
        # with the call's loop is sufficient; the explicit None check keeps
        # "no running loop" from ever matching an unset call loop.
        if current_loop is not None and current_loop is self._grpc_call_loop:
            current_loop.create_task(_invoke())
            return

        # Called from a different thread, or with no running loop (e.g.
        # interpreter shutdown): hop onto the call's loop. If that loop is
        # gone, fall back to handing over the raw call rather than dropping
        # the callback entirely.
        coro = _invoke()
        try:
            run_coroutine_threadsafe(coro, self._grpc_call_loop)
        except RuntimeError:
            # Close the never-scheduled coroutine to avoid a
            # "coroutine was never awaited" warning.
            coro.close()
            callback(call)

    self._call.add_done_callback(_on_done)

async def _normalize_done_result(self, call: grpc.aio.Call) -> Any:
    """Translate a finished gRPC call into the value given to done-callbacks.

    Only an ``UNAVAILABLE`` status is rewritten, into an
    ``ActorUnavailableError`` carrying this result's actor id. Every other
    outcome -- including a failure to read the status at all -- hands back
    ``call`` untouched, which is what callbacks received before
    normalization existed.
    """
    try:
        status = await call.code()
    except Exception:
        # An unknown status must not be reported as a replica failure.
        return call

    # Anything other than UNAVAILABLE passes through. CANCELLED is left
    # alone on purpose: in the done-callback path it is most often a
    # client-initiated cancellation, not a replica failure.
    if status != grpc.StatusCode.UNAVAILABLE:
        return call

    # UNAVAILABLE: the replica's gRPC server could not be reached, so the
    # request never completed on a live replica. Surface it like a
    # RayActorError so the router invalidates its cache and reroutes; whether
    # the replica is actually dead is left to the router's active probing.
    actor_id_bytes = self._actor_id.binary()
    return ActorUnavailableError("Actor is unavailable.", actor_id_bytes)

def cancel(self):
    """Cancel the underlying gRPC call; its ``cancel()`` result is discarded."""
    underlying_call = self._call
    underlying_call.cancel()
Expand Down
88 changes: 88 additions & 0 deletions python/ray/serve/tests/unit/test_grpc_replica_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import sys
import threading

import grpc
import pytest

from ray import ActorID, cloudpickle
from ray._common.test_utils import wait_for_condition
from ray.exceptions import ActorUnavailableError
from ray.serve._private.common import RequestMetadata
from ray.serve._private.replica_result import gRPCReplicaResult
from ray.serve.generated import serve_pb2
Expand Down Expand Up @@ -387,5 +389,91 @@ async def test_streaming_error_async(self, create_asyncio_event_loop_in_thread):
await replica_result.__anext__()


class FakegRPCCallWithStatus:
    """Stand-in for a finished ``grpc.aio.Call`` in the done-callback tests.

    Mirrors the two ``grpc.aio`` behaviors those tests depend on: every
    done-callback is handed the call object itself, and the final status is
    reachable only through the awaitable ``code()``. Instances must be
    created while an event loop is running.
    """

    def __init__(self, code: grpc.StatusCode):
        # Raises RuntimeError when constructed outside a running event loop.
        self._loop = asyncio.get_running_loop()
        self._done_callbacks = []
        self._code = code

    def add_done_callback(self, cb):
        """Queue ``cb`` to be fired by ``complete()``."""
        self._done_callbacks.append(cb)

    async def code(self) -> grpc.StatusCode:
        """Return the status code supplied at construction."""
        return self._code

    def complete(self):
        """Fire every registered callback, in registration order."""
        for done_callback in self._done_callbacks:
            # Same shape as grpc: the argument is the call object itself.
            done_callback(self)


@pytest.mark.asyncio
class TestDoneCallbackNormalization:
    """Done-callbacks on a ``gRPCReplicaResult`` must see actor-transport shapes.

    A failed call has to reach the callback in the same form the actor
    transport would deliver, so the router's completion handler can
    invalidate its queue-length cache for the dead replica. See
    https://github.com/ray-project/ray/issues/63261.
    """

    def make_result(self, code: grpc.StatusCode):
        """Build a result wrapping a fake call that finishes with ``code``.

        Returns the ``(gRPCReplicaResult, fake call)`` pair; both are bound
        to the currently running event loop.
        """
        call = FakegRPCCallWithStatus(code)
        request_metadata = RequestMetadata(
            request_id="",
            internal_request_id="",
            is_streaming=False,
            _on_separate_loop=False,
        )
        replica_result = gRPCReplicaResult(
            call,
            metadata=request_metadata,
            actor_id=ActorID(b"2" * 16),
            loop=asyncio.get_running_loop(),
        )
        return replica_result, call

    async def _fire_and_capture(self, code: grpc.StatusCode):
        """Complete a fake call with ``code`` and return what the callback got.

        Returns ``(first value delivered to the callback, fake call)`` and
        fails the test if no callback arrives within two seconds.
        """
        replica_result, call = self.make_result(code)
        delivered = []
        callback_ran = asyncio.Event()

        def on_done(value):
            delivered.append(value)
            callback_ran.set()

        replica_result.add_done_callback(on_done)
        call.complete()
        # The status code is resolved asynchronously on the call's loop, so
        # block on the event instead of assuming the callback already ran.
        try:
            await asyncio.wait_for(callback_ran.wait(), timeout=2.0)
        except asyncio.TimeoutError:
            pytest.fail("done-callback was never invoked")
        return delivered[0], call

    async def test_unavailable_normalized_to_actor_unavailable(self):
        """UNAVAILABLE reaches the callback as ``ActorUnavailableError``, so
        the router can invalidate its cache rather than miss the failure."""
        delivered, _ = await self._fire_and_capture(grpc.StatusCode.UNAVAILABLE)
        assert isinstance(delivered, ActorUnavailableError)

    async def test_ok_passes_through_call(self):
        """OK keeps the earlier behavior: the callback gets the call object
        itself, not a synthesized error."""
        delivered, call = await self._fire_and_capture(grpc.StatusCode.OK)
        assert delivered is call

    async def test_cancelled_not_treated_as_failure(self):
        """CANCELLED (usually a client-side cancellation) must not become a
        retryable ``ActorUnavailableError``."""
        delivered, call = await self._fire_and_capture(grpc.StatusCode.CANCELLED)
        assert not isinstance(delivered, ActorUnavailableError)
        assert delivered is call


if __name__ == "__main__":
    # Running the module directly hands it to pytest (verbose, output not
    # captured) and exits with pytest's return code.
    pytest_args = ["-v", "-s", __file__]
    sys.exit(pytest.main(pytest_args))
Loading