From 109f9ad8c421a960507c7de5425de20c45f97d52 Mon Sep 17 00:00:00 2001 From: zatchbell1311-wq Date: Fri, 26 Jun 2026 15:29:30 +0530 Subject: [PATCH 1/3] perf(got): batch aggregation pairs to reduce scheduler round-trips --- inferlets/graph-of-thought/src/lib.rs | 137 +++++++++++++++----------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/inferlets/graph-of-thought/src/lib.rs b/inferlets/graph-of-thought/src/lib.rs index 41c4b3967..dd08d1856 100644 --- a/inferlets/graph-of-thought/src/lib.rs +++ b/inferlets/graph-of-thought/src/lib.rs @@ -1,12 +1,23 @@ //! Demonstrates Graph-of-Thought (GoT) for hierarchical aggregation. //! //! This example generates multiple initial proposals concurrently, then -//! progressively aggregates them in pairs across multiple levels. The streaming -//! nature allows aggregation to begin as soon as pairs of proposals are ready, -//! maximizing parallelism. +//! progressively aggregates them in pairs across multiple levels. +//! +//! # Batched aggregation +//! +//! The original implementation processed proposals one at a time in a +//! `while let Some` loop, which meant aggregation pairs were launched +//! sequentially. This prevented the runtime scheduler from seeing all +//! aggregation requests simultaneously and coalescing them into a single +//! batched forward pass. +//! +//! This version collects all proposals first via `future::join_all`, then +//! launches all aggregation pairs in one shot. The scheduler now receives +//! all N/2 aggregation `flush()` calls concurrently and can batch them +//! together, reducing the number of GPU kernel invocations per level from +//! N/2 sequential firings to ideally 1 batched firing. 
-use futures::stream::FuturesUnordered; -use futures::{StreamExt, future}; +use futures::future; use inferlet::{ Context, sample::Sampler, model::Model, runtime, Result, @@ -39,6 +50,39 @@ const AGGREGATE_PROMPT: &str = "\ Please compare the following solution with the one you just provided \ and aggregate their ideas into a single, improved solution:\n"; +/// Pair up a flat list of (text, context) results into aggregation tasks, +/// launching all pairs simultaneously so the scheduler can batch them. +/// +/// If the input has an odd number of items, the last unpaired item is +/// dropped (its context is released immediately, freeing KV pages). +fn launch_aggregation_pairs( + items: Vec<(String, Context)>, + aggregation_tokens: usize, +) -> Result>>> { + let mut tasks = Vec::new(); + let mut iter = items.into_iter(); + + while let Some((text_a, _ctx_a)) = iter.next() { + if let Some((_text_b, mut ctx_b)) = iter.next() { + // ctx_a is dropped here — its unique KV pages are freed immediately. + // ctx_b carries the shared prefix forward into the aggregation. + let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a); + ctx_b.user(&prompt); + ctx_b.cue(); + tasks.push(async move { + let text = ctx_b + .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) + .max_tokens(aggregation_tokens) + .collect_text() + .await?; + Ok((text, ctx_b)) + }); + } + // Odd item out: context dropped, KV pages freed. + } + + Ok(tasks) +} /// Main logic for running the hierarchical aggregation workflow. 
async fn run_hierarchical_aggregation( @@ -47,12 +91,12 @@ async fn run_hierarchical_aggregation( proposal_tokens: Vec, aggregation_tokens: usize, ) -> Result> { - // --- Stage 1: Generate Initial Proposals --- + // --- Stage 1: Generate Initial Proposals (all concurrent) --- let propose_prompt = PROPOSAL_PROMPT_TEMPLATE.replace("{}", question); base_context.user(&propose_prompt); base_context.flush().await?; - let mut proposal_tasks = proposal_tokens + let proposal_futures = proposal_tokens .into_iter() .map(|max_tokens| { let mut ctx = base_context.fork()?; @@ -66,62 +110,35 @@ async fn run_hierarchical_aggregation( Ok::<_, String>((proposal_text, ctx)) }) }) - .collect::>>()? + .collect::>>()?; + + // Collect ALL proposals before pairing. This lets Stage 2 launch all + // aggregation pairs simultaneously rather than as each proposal dribbles in. + let proposals: Vec<(String, Context)> = future::join_all(proposal_futures) + .await .into_iter() - .collect::>(); - - // --- Stage 2: First-Level Aggregation (Pairing Proposals) --- - let mut first_aggregation_tasks = FuturesUnordered::new(); - let mut pending_proposal: Option<(String, Context)> = None; - - while let Some(result) = proposal_tasks.next().await { - let (proposal_text, mut proposal_ctx) = result?; - if pending_proposal.is_none() { - pending_proposal = Some((proposal_text, proposal_ctx)); - } else { - let (previous_proposal_text, _) = pending_proposal.take().unwrap(); - let aggregation_prompt = format!("{}{}", AGGREGATE_PROMPT, previous_proposal_text); - proposal_ctx.user(&aggregation_prompt); - proposal_ctx.cue(); - - first_aggregation_tasks.push(async move { - let aggregation_text = proposal_ctx - .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) - .max_tokens(aggregation_tokens) - .collect_text() - .await?; - Ok::<_, String>((aggregation_text, proposal_ctx)) - }); - } - } + .collect::>()?; + + // --- Stage 2: First-Level Aggregation (all pairs launched at once) --- + // All N/2 aggregation flush() 
calls hit the scheduler simultaneously, + // which can coalesce them into a single batched forward pass. + let first_agg_futures = launch_aggregation_pairs(proposals, aggregation_tokens)?; + let first_aggregations: Vec<(String, Context)> = future::join_all(first_agg_futures) + .await + .into_iter() + .collect::>()?; - // --- Stage 3: Second-Level Aggregation (Pairing Aggregations) --- - let mut second_aggregation_tasks = Vec::new(); - let mut pending_aggregation: Option<(String, Context)> = None; - - while let Some(result) = first_aggregation_tasks.next().await { - let (aggregation_text, mut aggregation_ctx) = result?; - if pending_aggregation.is_none() { - pending_aggregation = Some((aggregation_text, aggregation_ctx)); - } else { - let (previous_aggregation_text, _) = pending_aggregation.take().unwrap(); - let final_prompt = format!("{}{}", AGGREGATE_PROMPT, previous_aggregation_text); - aggregation_ctx.user(&final_prompt); - aggregation_ctx.cue(); - - second_aggregation_tasks.push(async move { - aggregation_ctx - .generate(Sampler::TopP { temperature: 0.6, p: 0.95 }) - .max_tokens(aggregation_tokens) - .collect_text() - .await - }); - } - } + // --- Stage 3: Second-Level Aggregation (all pairs launched at once) --- + let second_agg_futures = launch_aggregation_pairs(first_aggregations, aggregation_tokens)?; + let final_results: Vec = future::join_all(second_agg_futures) + .await + .into_iter() + .collect::>>()? 
+ .into_iter() + .map(|(text, _ctx)| text) + .collect(); - // --- Stage 4: Collect Final Results --- - let results = future::join_all(second_aggregation_tasks).await; - results.into_iter().collect::>>() + Ok(final_results) } #[inferlet::main] From da3242221506835931f676466087f8bfea6c78a2 Mon Sep 17 00:00:00 2001 From: Dhruv Dubey Date: Sat, 27 Jun 2026 07:50:14 +0000 Subject: [PATCH 2/3] fix(got): include question in aggregation prompt to prevent context loss --- inferlets/graph-of-thought/src/lib.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/inferlets/graph-of-thought/src/lib.rs b/inferlets/graph-of-thought/src/lib.rs index 41c4b3967..00882b465 100644 --- a/inferlets/graph-of-thought/src/lib.rs +++ b/inferlets/graph-of-thought/src/lib.rs @@ -31,13 +31,11 @@ fn default_aggregation_tokens() -> usize { 256 } const SYSTEM_PROMPT: &str = "You are a helpful, respectful and honest assistant."; const PROPOSAL_PROMPT_TEMPLATE: &str = "\ -Could you suggest a method or approach to solve the following question? \ -Please provide a high-level plan without doing the actual calculation. \ -Keep it concise, around 80 words. Question: {}"; +Solve the following math problem step by step. Show your work and give the final numeric answer. \ +End your response with: ANSWER: \n\ +Question: {}"; -const AGGREGATE_PROMPT: &str = "\ -Please compare the following solution with the one you just provided \ -and aggregate their ideas into a single, improved solution:\n"; +const AGGREGATE_PROMPT: &str = ""; // NOTE(review): not actually unused, and this patch adds no format_aggregate_prompt - the first-level prompt is built inline with format!, while the second-level aggregation still prepends this (now empty) constant /// Main logic for running the hierarchical aggregation workflow. 
@@ -80,7 +78,7 @@ async fn run_hierarchical_aggregation( pending_proposal = Some((proposal_text, proposal_ctx)); } else { let (previous_proposal_text, _) = pending_proposal.take().unwrap(); - let aggregation_prompt = format!("{}{}", AGGREGATE_PROMPT, previous_proposal_text); + let aggregation_prompt = format!("The math problem is: {}\n\nHere is another solution:\n{}\n\nNow give the correct final answer. End with: ANSWER: ",question, previous_proposal_text); proposal_ctx.user(&aggregation_prompt); proposal_ctx.cue(); From 612a8b5554f0e982600ecaff37da1a6733626b9d Mon Sep 17 00:00:00 2001 From: zatchbell1311-wq Date: Sat, 27 Jun 2026 19:13:37 +0530 Subject: [PATCH 3/3] fix(got): add guard for minimum 4 proposals required by two aggregation levels --- inferlets/graph-of-thought/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/inferlets/graph-of-thought/src/lib.rs b/inferlets/graph-of-thought/src/lib.rs index dd08d1856..ae444a5db 100644 --- a/inferlets/graph-of-thought/src/lib.rs +++ b/inferlets/graph-of-thought/src/lib.rs @@ -157,6 +157,10 @@ async fn main(input: Input) -> Result { proposal_tokens, aggregation_tokens ); + if proposal_tokens.len() < 4 { + return Err("GoT requires at least 4 proposals for two aggregation levels".to_string()); + } + let models = runtime::models(); let model_name = models.first().ok_or("No models available")?; let model = Model::load(model_name)?;