[llm][kv][7/N] Prefill/decode token load aware request routing #64400
jeffreywang88 wants to merge 2 commits into
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
… token load Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
Code Review
This pull request implements request lifecycle tracking and active load booking for KV-aware routing in Ray LLM. It introduces a RequestLifecycle dataclass to track in-flight request states and updates KVRouterActor to book reservations, update decode progress, and free reservations in the selection service. Additionally, token_tracking.py is updated to pass the expected output tokens, and extensive tests are added to validate these behaviors. The review feedback highlights three key areas for improvement: addressing a potential memory leak in _effective_prefill_tokens_by_request when requests are cancelled or fail prematurely, improving robustness in token tracking to prevent exceptions from disrupting the core engine stream, and ensuring _get_decay_fraction safely handles negative or zero values for expected output tokens.
    self._effective_prefill_tokens_by_request[request_id] = selection[
        "effective_prefill_tokens"
    ]
Potential Memory Leak in _effective_prefill_tokens_by_request
When a request is routed via select_worker, its request_id and effective_prefill_tokens are stored in self._effective_prefill_tokens_by_request. Under normal circumstances, this entry is popped during on_request_added or on_request_completed.
However, if a request is cancelled, times out, or fails before it actually starts executing on the selected replica, on_request_added and on_request_completed will never be called on that replica. This causes the request_id to remain in self._effective_prefill_tokens_by_request indefinitely, leading to a memory leak in the long-running router actor.
To prevent this, we can implement a simple FIFO eviction policy to bound the size of the dictionary.
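A standalone sketch of that policy may help show what it does and does not guarantee. The helper name `record_pending` and the tiny cap `MAX_PENDING = 3` are assumptions made for illustration only; the suggestion itself uses a cap of 10000 and operates on the router's own dictionary.

```python
from typing import Dict

# Hypothetical cap, kept small so the eviction is visible in a short demo.
MAX_PENDING = 3


def record_pending(pending: Dict[str, int], request_id: str, tokens: int) -> None:
    """Store tokens for request_id, evicting the oldest entry once over the cap."""
    pending[request_id] = tokens
    if len(pending) > MAX_PENDING:
        # Python dicts preserve insertion order, so the first key is the
        # oldest inserted one. Re-assigning an existing key keeps its position.
        oldest_key = next(iter(pending))
        pending.pop(oldest_key, None)


pending: Dict[str, int] = {}
for i in range(5):
    record_pending(pending, f"req-{i}", 100 + i)

print(list(pending))
```

Note the trade-off: the cap bounds memory, but if more requests than the cap are legitimately pending at once, the oldest live entry is evicted too, so the lookup that would normally pop it has to tolerate a missing key. The suggested change to the router itself follows.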
    self._effective_prefill_tokens_by_request[request_id] = selection[
        "effective_prefill_tokens"
    ]
    if len(self._effective_prefill_tokens_by_request) > 10000:
        oldest_key = next(iter(self._effective_prefill_tokens_by_request))
        self._effective_prefill_tokens_by_request.pop(oldest_key, None)

    # The request's own output cap is its expected length; weights
    # the selection service's decode-block decay.
    # TODO(jeffreywang): Use an agent-provided expected-OSL hint for
    # more accurate decode-load estimation.
    sampling_params.max_tokens,
Robustness Improvement for Token Tracking
If _get_prompt_token_ids(prompt) or RequestTokenTracker initialization raises an exception (for example, if prompt is not in the expected format or sampling_params is missing attributes), it will propagate and disrupt the engine's output stream. This violates the design goal of isolating token tracking from the core engine execution.
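The difference between an unguarded tracker call and an isolated one can be sketched with a toy async generator. Everything here is hypothetical scaffolding (`FlakyTracker`, `fake_stream`, `unguarded`, `guarded`); it is not the PR's `RequestTokenTracker` or the engine's real stream, only a minimal model of the control flow.

```python
import asyncio
import logging
from typing import AsyncIterator, List, Optional

logger = logging.getLogger(__name__)


class FlakyTracker:
    """Hypothetical tracker whose hooks always raise."""

    def on_output(self, output: str) -> None:
        raise RuntimeError("tracker bug")

    def finish(self) -> None:
        raise RuntimeError("tracker bug")


async def fake_stream() -> AsyncIterator[str]:
    """Stand-in for the engine's output stream."""
    for chunk in ("a", "b", "c"):
        yield chunk


async def unguarded(stream: AsyncIterator[str], tracker: FlakyTracker) -> AsyncIterator[str]:
    async for output in stream:
        tracker.on_output(output)  # an exception here reaches the consumer
        yield output


async def guarded(stream: AsyncIterator[str], tracker: FlakyTracker) -> AsyncIterator[str]:
    try:
        async for output in stream:
            try:
                tracker.on_output(output)
            except Exception as e:
                logger.warning("Error in tracker on_output: %s", e)
            yield output  # the output is forwarded regardless
    finally:
        try:
            tracker.finish()
        except Exception as e:
            logger.warning("Error in tracker finish: %s", e)


async def collect(gen: AsyncIterator[str]) -> List[str]:
    return [item async for item in gen]


def run_unguarded() -> Optional[str]:
    """Return the error message that escapes the unguarded wrapper, if any."""
    try:
        asyncio.run(collect(unguarded(fake_stream(), FlakyTracker())))
    except RuntimeError as e:
        return str(e)
    return None


guarded_outputs = asyncio.run(collect(guarded(fake_stream(), FlakyTracker())))
unguarded_error = run_unguarded()
```

In this sketch the guarded wrapper still delivers every chunk while logging the tracker failures, whereas the unguarded one surfaces the tracker's exception to whoever consumes the stream.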
To ensure that any failure in the token tracker never disrupts the user's request, consider wrapping the tracker initialization and progress reporting in try-except blocks, falling back gracefully to direct stream consumption if tracking fails:
    lifecycle_request_id = get_serve_request_id() or request_id
    tracker = None
    try:
        prompt_token_ids = _get_prompt_token_ids(prompt)
        tracker = RequestTokenTracker(
            forwarder,
            lifecycle_request_id,
            prompt_token_ids,
            sampling_params.max_tokens,
        )
    except Exception as e:
        logger.warning("Failed to initialize token tracker: %s", e)
    try:
        async for output in stream:
            if tracker is not None:
                try:
                    tracker.on_output(output)
                except Exception as e:
                    logger.warning("Error in token tracker on_output: %s", e)
            yield output
    finally:
        if tracker is not None:
            try:
                tracker.finish()
            except Exception as e:
                logger.warning("Error in token tracker finish: %s", e)

    def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
        """Fraction of output still expected, or ``None`` without an estimate;
        weights each decode block by how much generation remains."""
        if not state.expected_output_tokens:
            return None
        return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)
Robustness Improvement for _get_decay_fraction
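If state.expected_output_tokens is negative, the current check lets it through and the function returns a decay fraction greater than 1.0 (for example, 5 output tokens against an expected -10 gives 1.5). Zero and None are already caught by the existing falsy check (not state.expected_output_tokens), so a division-by-zero error cannot occur here.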
To make this logic more robust against invalid or negative values, explicitly check that expected_output_tokens is strictly positive.
    - def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
    -     """Fraction of output still expected, or ``None`` without an estimate;
    -     weights each decode block by how much generation remains."""
    -     if not state.expected_output_tokens:
    -         return None
    -     return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)
    + def _get_decay_fraction(self, state: RequestLifecycle) -> Optional[float]:
    +     """Fraction of output still expected, or None without an estimate;
    +     weights each decode block by how much generation remains."""
    +     if state.expected_output_tokens is None or state.expected_output_tokens <= 0:
    +         return None
    +     return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)
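A worked comparison of the two guards, as a minimal sketch: the `RequestLifecycle` below is a hypothetical stand-in carrying only the two fields the function reads (the PR's dataclass presumably has more), and the two free functions mirror the original and suggested method bodies.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RequestLifecycle:
    """Hypothetical stand-in with only the fields used by the decay logic."""

    output_tokens: int = 0
    expected_output_tokens: Optional[int] = None


def decay_fraction_truthy(state: RequestLifecycle) -> Optional[float]:
    """Original guard: rejects None and 0, but accepts negative values."""
    if not state.expected_output_tokens:
        return None
    return max(0.0, 1.0 - state.output_tokens / state.expected_output_tokens)


def decay_fraction_strict(state: RequestLifecycle) -> Optional[float]:
    """Suggested guard: requires a strictly positive estimate."""
    expected = state.expected_output_tokens
    if expected is None or expected <= 0:
        return None
    return max(0.0, 1.0 - state.output_tokens / expected)


cases = [
    RequestLifecycle(output_tokens=25, expected_output_tokens=100),   # mid-generation
    RequestLifecycle(output_tokens=150, expected_output_tokens=100),  # overshoot
    RequestLifecycle(output_tokens=5, expected_output_tokens=-10),    # invalid estimate
    RequestLifecycle(output_tokens=0, expected_output_tokens=0),      # no estimate
    RequestLifecycle(output_tokens=0, expected_output_tokens=None),   # no estimate
]

for state in cases:
    print(state, decay_fraction_truthy(state), decay_fraction_strict(state))
```

The two versions agree on every valid input; they differ only for a negative estimate, where the original returns 1.5 and the strict version returns None.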