perf(got): batch aggregation pairs to reduce scheduler round-trips#449
perf(got): batch aggregation pairs to reduce scheduler round-trips#449zatchbell1311-wq wants to merge 4 commits into
Conversation
WalkthroughThe GoT example now generates proposals and both aggregation levels in batches with ChangesGraph-of-Thought batched aggregation
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
inferlets/graph-of-thought/src/lib.rs (1)
117-120: 🩺 Stability & Availability | 🔵 TrivialUse
try_join_allto short-circuit on the first error.The
futures = "0.3"dependency is available, which supportsfutures::future::try_join_all. The current pattern ofjoin_all(...).await.into_iter().collect()waits for all futures to complete even if one fails, potentially wasting resources and delaying error reporting.Replace the manual collection with
try_join_allto:
- Preserve concurrent execution of all futures.
- Return immediately upon the first error.
- Cancel remaining pending futures (if their
Futureimplementation supports cancellation).let proposal_futures = proposal_tokens .into_iter() .map(|max_tokens| { // ... }); - let proposals: Vec<(String, Context)> = future::join_all(proposal_futures) - .await - .into_iter() - .collect::<Result<_>>()?; + let proposals: Vec<(String, Context)> = future::try_join_all(proposal_futures).await?;Also replace the occurrences in:
- Lines 126–129 (
first_aggregations)- Lines 133–136 (
final_results)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@inferlets/graph-of-thought/src/lib.rs` around lines 117 - 120, Replace the manual `future::join_all(...).await.into_iter().collect::<Result<_>>()?` patterns in `lib.rs` with `future::try_join_all` so the async work short-circuits on the first error instead of waiting for every future to finish. Update the `proposals`, `first_aggregations`, and `final_results` joins in the relevant helper flow to use `try_join_all` while keeping the same concurrent execution and result types.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@inferlets/graph-of-thought/src/lib.rs`:
- Around line 94-99: The proposal pipeline in graph-of-thought can succeed with
an empty result when there are fewer than four proposal branches, because the
fixed aggregation stages in the flow leading from proposal_futures can never
form a second-level pair. Add an upfront validation in the main execution path
before starting Stage 1 or the later aggregation steps to reject insufficient
proposal counts, and return an error instead of Ok(vec![]) so main cannot report
success without a final solution.
- Around line 65-82: The aggregation loop in lib.rs assumes contexts are freed
immediately when dropped, but Context relies on destroy(self) for deterministic
KV-page release. Update the logic around the while let Some((text_a, _ctx_a))
pattern so discarded contexts are explicitly destroyed via Context::destroy(),
including the unused ctx_a and any odd trailing context, while keeping ctx_b
alive only for the async task in the aggregation flow.
---
Nitpick comments:
In `@inferlets/graph-of-thought/src/lib.rs`:
- Around line 117-120: Replace the manual
`future::join_all(...).await.into_iter().collect::<Result<_>>()?` patterns in
`lib.rs` with `future::try_join_all` so the async work short-circuits on the
first error instead of waiting for every future to finish. Update the
`proposals`, `first_aggregations`, and `final_results` joins in the relevant
helper flow to use `try_join_all` while keeping the same concurrent execution
and result types.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: f4808318-76eb-4f0b-b6fb-ee2aff171833
📒 Files selected for processing (1)
inferlets/graph-of-thought/src/lib.rs
| while let Some((text_a, _ctx_a)) = iter.next() { | ||
| if let Some((_text_b, mut ctx_b)) = iter.next() { | ||
| // ctx_a is dropped here — its unique KV pages are freed immediately. | ||
| // ctx_b carries the shared prefix forward into the aggregation. | ||
| let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a); | ||
| ctx_b.user(&prompt); | ||
| ctx_b.cue(); | ||
| tasks.push(async move { | ||
| let text = ctx_b | ||
| .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) | ||
| .max_tokens(aggregation_tokens) | ||
| .collect_text() | ||
| .await?; | ||
| Ok((text, ctx_b)) | ||
| }); | ||
| } | ||
| // Odd item out: context dropped, KV pages freed. | ||
| } |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Inspect Context lifecycle/drop behavior without executing repository code.
rg -n -C4 'impl\s+Drop\s+for\s+Context|fn\s+drop\(|pub fn destroy|pub fn suspend' sdk/rust/inferlet/src/context.rs sdk/rust/inferlet/srcRepository: pie-project/pie
Length of output: 3738
🏁 Script executed:
#!/bin/bash
# Verify Context definition and Drop implementation
# 1. Find the struct definition of Context
echo "=== Searching Context struct definition ==="
rg -n -B5 -A20 '^pub struct Context' sdk/rust/inferlet/src/context.rs
echo -e "\n=== Searching for 'impl.*Context' blocks ==="
rg -n -A3 'pub impl.*Context|impl.*Context' sdk/rust/inferlet/src/context.rs | head -50
echo -e "\n=== Checking if Context is an alias or has derived Drop ==="
rg -n -B2 -A10 'pub type Context|struct Context' sdk/rust/inferlet/src/context.rsRepository: pie-project/pie
Length of output: 2383
🏁 Script executed:
#!/bin/bash
# Check RawContext Drop implementation
echo "=== Searching for RawContext struct and Drop ==="
rg -n -B2 -A10 'pub struct RawContext|impl\s+Drop\s+for\s+RawContext' sdk/rust/inferlet/src/context.rs
echo -e "\n=== Searching for 'destroy' implementation in RawContext or inner ==="
rg -n -B2 -A5 'fn\s+destroy\s*\(' sdk/rust/inferlet/src/context.rs | head -30Repository: pie-project/pie
Length of output: 504
Use explicit Context::destroy() for contexts intentionally discarded.
The Context struct does not implement a custom Drop trait to immediately free KV pages upon scope exit. Relying on the default drop behavior makes the claim that "KV pages are freed immediately" uncertain. The API provides destroy(self) specifically for deterministic, immediate release of resources.
Explicitly call destroy() for ctx_a and any odd trailing context to guarantee the intended lifecycle behavior.
♻️ Proposed lifecycle cleanup
- while let Some((text_a, _ctx_a)) = iter.next() {
+ while let Some((text_a, ctx_a)) = iter.next() {
if let Some((_text_b, mut ctx_b)) = iter.next() {
- // ctx_a is dropped here — its unique KV pages are freed immediately.
+ ctx_a.destroy();
// ctx_b carries the shared prefix forward into the aggregation.
let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a);
ctx_b.user(&prompt);
ctx_b.cue();
@@
- }
- // Odd item out: context dropped, KV pages freed.
+ } else {
+ ctx_a.destroy();
+ }
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| while let Some((text_a, _ctx_a)) = iter.next() { | |
| if let Some((_text_b, mut ctx_b)) = iter.next() { | |
| // ctx_a is dropped here — its unique KV pages are freed immediately. | |
| // ctx_b carries the shared prefix forward into the aggregation. | |
| let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a); | |
| ctx_b.user(&prompt); | |
| ctx_b.cue(); | |
| tasks.push(async move { | |
| let text = ctx_b | |
| .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) | |
| .max_tokens(aggregation_tokens) | |
| .collect_text() | |
| .await?; | |
| Ok((text, ctx_b)) | |
| }); | |
| } | |
| // Odd item out: context dropped, KV pages freed. | |
| } | |
| while let Some((text_a, ctx_a)) = iter.next() { | |
| if let Some((_text_b, mut ctx_b)) = iter.next() { | |
| ctx_a.destroy(); | |
| // ctx_b carries the shared prefix forward into the aggregation. | |
| let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a); | |
| ctx_b.user(&prompt); | |
| ctx_b.cue(); | |
| tasks.push(async move { | |
| let text = ctx_b | |
| .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) | |
| .max_tokens(aggregation_tokens) | |
| .collect_text() | |
| .await?; | |
| Ok((text, ctx_b)) | |
| }); | |
| } else { | |
| ctx_a.destroy(); | |
| } | |
| } |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@inferlets/graph-of-thought/src/lib.rs` around lines 65 - 82, The aggregation
loop in lib.rs assumes contexts are freed immediately when dropped, but Context
relies on destroy(self) for deterministic KV-page release. Update the logic
around the while let Some((text_a, _ctx_a)) pattern so discarded contexts are
explicitly destroyed via Context::destroy(), including the unused ctx_a and any
odd trailing context, while keeping ctx_b alive only for the async task in the
aggregation flow.
Problem
Two issues in the original GoT implementation:
Sequential pairing — proposals were processed one at a time in a
while let Someloop, preventing the scheduler from seeing all aggregation requests simultaneously and batching them into a single forward pass.Question context loss — during aggregation, the forked context carries full conversation history, so the model loses track of what problem it's solving, causing incorrect answers.
Fix
future::join_allbefore pairing, then launch all aggregation pairs simultaneously. The scheduler can now coalesce all N/2flush()calls into a single batched forward pass per level.launch_aggregation_pairs()as a reusable helper used across all aggregation levels.Benchmark (5 GSM8K samples, Qwen3-0.6B, CUDA, RTX 2050)
GoT uses 3x more tokens and is 1.8x slower on this sample. The accuracy gap traces to the aggregation context issue addressed here. Further improvement requires restructuring aggregation to use fresh contexts rather than forked ones.
Closes #418 (partially — documents the context loss root cause)