[llm][kv][7/N] Prefill/decode token load aware request routing #64400

Draft
jeffreywang88 wants to merge 2 commits into request-lifecycle-tracking from load-tracking

@jeffreywang88

Description

Briefly describe what this PR accomplishes and why it's needed.

Related issues

Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234".

Additional information

Optional: Add implementation details, API changes, usage examples, screenshots, etc.

Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
… token load

Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>

gemini-code-assist (bot) left a comment

Code Review

This pull request implements request lifecycle tracking and active load booking for KV-aware routing in Ray LLM. It introduces a RequestLifecycle dataclass to track in-flight request states and updates KVRouterActor to book reservations, update decode progress, and free reservations in the selection service. Additionally, token_tracking.py is updated to pass the expected output tokens, and extensive tests are added to validate these behaviors. The review feedback highlights three key areas for improvement: addressing a potential memory leak in _effective_prefill_tokens_by_request when requests are cancelled or fail prematurely, improving robustness in token tracking to prevent exceptions from disrupting the core engine stream, and ensuring _get_decay_fraction safely handles negative or zero values for expected output tokens.

Comment on lines +292 to +294
self._effective_prefill_tokens_by_request[request_id] = selection[
"effective_prefill_tokens"
]

high

Potential Memory Leak in _effective_prefill_tokens_by_request

When a request is routed via select_worker, its request_id and effective_prefill_tokens are stored in self._effective_prefill_tokens_by_request. Under normal circumstances, this entry is popped during on_request_added or on_request_completed.

However, if a request is cancelled, times out, or fails before it actually starts executing on the selected replica, on_request_added and on_request_completed will never be called on that replica. This causes the request_id to remain in self._effective_prefill_tokens_by_request indefinitely, leading to a memory leak in the long-running router actor.

To prevent this, bound the size of the dictionary with a simple FIFO eviction policy. Python dicts preserve insertion order, so the first key returned by iteration is the oldest entry.

        self._effective_prefill_tokens_by_request[request_id] = selection[
            "effective_prefill_tokens"
        ]
        if len(self._effective_prefill_tokens_by_request) > 10000:
            oldest_key = next(iter(self._effective_prefill_tokens_by_request))
            self._effective_prefill_tokens_by_request.pop(oldest_key, None)
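
The same bound can be kept in one place instead of inline at the call site. The following is a minimal sketch of that idea, not code from this PR: the class name `BoundedPrefillTokenMap`, its methods, and the `max_entries` parameter are all hypothetical, and the default cap of 10,000 simply mirrors the number in the suggestion above.

```python
from collections import OrderedDict
from typing import Optional


class BoundedPrefillTokenMap:
    """Hypothetical helper: request_id -> effective prefill tokens, capped in size."""

    def __init__(self, max_entries: int = 10_000) -> None:
        if max_entries <= 0:
            raise ValueError("max_entries must be positive")
        self._max_entries = max_entries
        self._entries: "OrderedDict[str, int]" = OrderedDict()

    def record(self, request_id: str, effective_prefill_tokens: int) -> None:
        # Re-recording an existing id moves it to the newest position.
        self._entries.pop(request_id, None)
        self._entries[request_id] = effective_prefill_tokens
        # Evict the oldest entries (FIFO) until the cap is respected.
        while len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)

    def pop(self, request_id: str) -> Optional[int]:
        # Returns None when the entry was never recorded or was already evicted.
        return self._entries.pop(request_id, None)

    def __len__(self) -> int:
        return len(self._entries)
```

One trade-off to keep in mind: size-based eviction can drop the entry of a request that is still in flight if the cap is smaller than the number of concurrently routed requests, so callers have to tolerate `pop` returning `None`.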

Comment on lines +152 to +156
# The request's own output cap is its expected length; weights
# the selection service's decode-block decay.
# TODO(jeffreywang): Use an agent-provided expected-OSL hint for
# more accurate decode-load estimation.
sampling_params.max_tokens,

high

Robustness Improvement for Token Tracking

If _get_prompt_token_ids(prompt) or RequestTokenTracker initialization raises an exception (for example, if prompt is not in the expected format or sampling_params is missing attributes), it will propagate and disrupt the engine's output stream. This violates the design goal of isolating token tracking from the core engine execution.

To ensure that any failure in the token tracker never disrupts the user's request, consider wrapping the tracker initialization and progress reporting in try-except blocks, falling back gracefully to direct stream consumption if tracking fails:

            lifecycle_request_id = get_serve_request_id() or request_id
            tracker = None
            try:
                prompt_token_ids = _get_prompt_token_ids(prompt)
                tracker = RequestTokenTracker(
                    forwarder,
                    lifecycle_request_id,
                    prompt_token_ids,
                    sampling_params.max_tokens,
                )
            except Exception as e:
                logger.warning("Failed to initialize token tracker: %s", e)

            try:
                async for output in stream:
                    if tracker is not None:
                        try:
                            tracker.on_output(output)
                        except Exception as e:
                            logger.warning("Error in token tracker on_output: %s", e)
                    yield output
            finally:
                if tracker is not None:
                    try:
                        tracker.finish()
                    except Exception as e:
                        logger.warning("Error in token tracker finish: %s", e)
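
To illustrate the isolation property this suggestion is after, here is a small self-contained sketch that can be run outside Ray. Everything in it is a stand-in: `FailingTracker`, `engine_stream`, and `tracked_stream` are hypothetical names, and the tracker is deliberately written to raise on every call so the effect on the stream is visible.

```python
import asyncio
import logging
from typing import AsyncIterator, List

logger = logging.getLogger(__name__)


class FailingTracker:
    """Hypothetical tracker whose callbacks always raise."""

    def on_output(self, output: str) -> None:
        raise RuntimeError("tracking backend unavailable")

    def finish(self) -> None:
        raise RuntimeError("tracking backend unavailable")


async def engine_stream() -> AsyncIterator[str]:
    # Stand-in for the engine's output stream.
    for chunk in ("Hello", ", ", "world"):
        yield chunk


async def tracked_stream(stream: AsyncIterator[str], tracker) -> AsyncIterator[str]:
    try:
        async for output in stream:
            if tracker is not None:
                try:
                    tracker.on_output(output)
                except Exception as e:
                    # Tracking errors are logged and swallowed.
                    logger.warning("Error in token tracker on_output: %s", e)
            yield output
    finally:
        if tracker is not None:
            try:
                tracker.finish()
            except Exception as e:
                logger.warning("Error in token tracker finish: %s", e)


async def collect() -> List[str]:
    return [chunk async for chunk in tracked_stream(engine_stream(), FailingTracker())]


if __name__ == "__main__":
    print(asyncio.run(collect()))  # prints ['Hello', ', ', 'world']
```

Every chunk still reaches the consumer even though each tracker call fails; the failures only show up as warnings in the log.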

Comment on lines +383 to +388
    def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
        """Fraction of output still expected, or ``None`` without an estimate;
        weights each decode block by how much generation remains."""
        if not state.expected_output_tokens:
            return None
        return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)

medium

Robustness Improvement for _get_decay_fraction

The existing `not state.expected_output_tokens` check already guards against None and 0, so a division by zero cannot occur. A negative value, however, passes the check and produces a decay fraction of 1.0 or more: with output_tokens = 5 and expected_output_tokens = -10 the result is 1.5.

To make this logic more robust against invalid or negative values, explicitly check that expected_output_tokens is strictly positive.

Suggested change
-    def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
-        """Fraction of output still expected, or ``None`` without an estimate;
-        weights each decode block by how much generation remains."""
-        if not state.expected_output_tokens:
-            return None
-        return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)
+    def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
+        """Fraction of output still expected, or None without an estimate;
+        weights each decode block by how much generation remains."""
+        if state.expected_output_tokens is None or state.expected_output_tokens <= 0:
+            return None
+        return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)
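
The difference between the two guards can be checked in isolation. The sketch below lifts both variants out of the class as free functions; `LifecycleState` is a hypothetical stand-in carrying only the two fields the helper reads, not the PR's `RequestLifecycle`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LifecycleState:
    """Hypothetical stand-in with only the fields the helper reads."""

    output_tokens: int = 0
    expected_output_tokens: Optional[int] = None


def decay_fraction_truthy(state: LifecycleState) -> Optional[float]:
    # Original guard: rejects None and 0, but lets negative values through.
    if not state.expected_output_tokens:
        return None
    return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)


def decay_fraction_strict(state: LifecycleState) -> Optional[float]:
    # Suggested guard: rejects None, 0, and negative values.
    expected = state.expected_output_tokens
    if expected is None or expected <= 0:
        return None
    return max(0.0, 1.0 - state.output_tokens / expected)


if __name__ == "__main__":
    bad = LifecycleState(output_tokens=5, expected_output_tokens=-10)
    print(decay_fraction_truthy(bad))  # prints 1.5
    print(decay_fraction_strict(bad))  # prints None
```

For valid inputs the two variants agree: 25 of 100 expected tokens gives 0.75 in both, and overshooting the estimate clamps to 0.0.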
