141 changes: 81 additions & 60 deletions inferlets/graph-of-thought/src/lib.rs
@@ -1,12 +1,23 @@
 //! Demonstrates Graph-of-Thought (GoT) for hierarchical aggregation.
 //!
 //! This example generates multiple initial proposals concurrently, then
-//! progressively aggregates them in pairs across multiple levels. The streaming
-//! nature allows aggregation to begin as soon as pairs of proposals are ready,
-//! maximizing parallelism.
+//! progressively aggregates them in pairs across multiple levels.
+//!
+//! # Batched aggregation
+//!
+//! The original implementation processed proposals one at a time in a
+//! `while let Some` loop, which meant aggregation pairs were launched
+//! sequentially. This prevented the runtime scheduler from seeing all
+//! aggregation requests simultaneously and coalescing them into a single
+//! batched forward pass.
+//!
+//! This version collects all proposals first via `future::join_all`, then
+//! launches all aggregation pairs in one shot. The scheduler now receives
+//! all N/2 aggregation `flush()` calls concurrently and can batch them
+//! together, reducing the number of GPU kernel invocations per level from
+//! N/2 sequential firings to ideally 1 batched firing.

-use futures::stream::FuturesUnordered;
-use futures::{StreamExt, future};
+use futures::future;
 use inferlet::{
     Context, sample::Sampler, model::Model,
     runtime, Result,
@@ -39,6 +50,39 @@ const AGGREGATE_PROMPT: &str = "\
 Please compare the following solution with the one you just provided \
 and aggregate their ideas into a single, improved solution:\n";

+/// Pair up a flat list of (text, context) results into aggregation tasks,
+/// launching all pairs simultaneously so the scheduler can batch them.
+///
+/// If the input has an odd number of items, the last unpaired item is
+/// dropped (its context is released immediately, freeing KV pages).
+fn launch_aggregation_pairs(
+    items: Vec<(String, Context)>,
+    aggregation_tokens: usize,
+) -> Result<Vec<impl std::future::Future<Output = Result<(String, Context)>>>> {
+    let mut tasks = Vec::new();
+    let mut iter = items.into_iter();
+
+    while let Some((text_a, _ctx_a)) = iter.next() {
+        if let Some((_text_b, mut ctx_b)) = iter.next() {
+            // ctx_a is dropped here — its unique KV pages are freed immediately.
+            // ctx_b carries the shared prefix forward into the aggregation.
+            let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a);
+            ctx_b.user(&prompt);
+            ctx_b.cue();
+            tasks.push(async move {
+                let text = ctx_b
+                    .generate(Sampler::TopP { temperature: 0.6, p: 0.95 })
+                    .max_tokens(aggregation_tokens)
+                    .collect_text()
+                    .await?;
+                Ok((text, ctx_b))
+            });
+        }
+        // Odd item out: context dropped, KV pages freed.
+    }
Comment on lines +65 to +82

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Inspect Context lifecycle/drop behavior without executing repository code.
rg -n -C4 'impl\s+Drop\s+for\s+Context|fn\s+drop\(|pub fn destroy|pub fn suspend' sdk/rust/inferlet/src/context.rs sdk/rust/inferlet/src

Repository: pie-project/pie

Length of output: 3738


🏁 Script executed:

#!/bin/bash
# Verify Context definition and Drop implementation
# 1. Find the struct definition of Context
echo "=== Searching Context struct definition ==="
rg -n -B5 -A20 '^pub struct Context' sdk/rust/inferlet/src/context.rs

echo -e "\n=== Searching for 'impl.*Context' blocks ==="
rg -n -A3 'pub impl.*Context|impl.*Context' sdk/rust/inferlet/src/context.rs | head -50

echo -e "\n=== Checking if Context is an alias or has derived Drop ==="
rg -n -B2 -A10 'pub type Context|struct Context' sdk/rust/inferlet/src/context.rs

Repository: pie-project/pie

Length of output: 2383


🏁 Script executed:

#!/bin/bash
# Check RawContext Drop implementation
echo "=== Searching for RawContext struct and Drop ==="
rg -n -B2 -A10 'pub struct RawContext|impl\s+Drop\s+for\s+RawContext' sdk/rust/inferlet/src/context.rs

echo -e "\n=== Searching for 'destroy' implementation in RawContext or inner ==="
rg -n -B2 -A5 'fn\s+destroy\s*\(' sdk/rust/inferlet/src/context.rs | head -30

Repository: pie-project/pie

Length of output: 504


Use explicit Context::destroy() for contexts intentionally discarded.

The Context struct has no custom Drop implementation that frees KV pages when a value goes out of scope, so the code comments claiming "KV pages are freed immediately" describe behavior that the default drop does not guarantee. The API provides destroy(self) specifically for deterministic, immediate release of resources.

Call destroy() explicitly on ctx_a and on any odd trailing context to guarantee the intended lifecycle.
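The difference between relying on scope exit and calling a consuming `destroy(self)` can be sketched with a stand-in type. Everything here is hypothetical and only mirrors the shape of the pairing loop: `Ctx`, `live_pages`, `pair_and_release`, and `demo` are illustration names, not part of the inferlet SDK, and the stand-in deliberately has no `Drop` impl, which is the situation this finding assumes for the real `Context`.

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Hypothetical stand-in for a context that owns some KV pages.
/// `live_pages` is a shared counter used only to observe releases.
struct Ctx {
    pages: usize,
    live_pages: Rc<Cell<usize>>,
}

impl Ctx {
    fn new(pages: usize, live_pages: Rc<Cell<usize>>) -> Self {
        live_pages.set(live_pages.get() + pages);
        Ctx { pages, live_pages }
    }

    /// Explicit, consuming release. There is no `Drop` impl on this
    /// stand-in, so pages are only returned when this method is called.
    fn destroy(self) {
        self.live_pages.set(self.live_pages.get() - self.pages);
    }
}

/// Pairs items like the aggregation loop: keeps `text_a` with `ctx_b`,
/// and explicitly destroys every `ctx_a`, including an odd trailing one.
fn pair_and_release(items: Vec<(String, Ctx)>) -> Vec<(String, Ctx)> {
    let mut kept = Vec::new();
    let mut iter = items.into_iter();
    while let Some((text_a, ctx_a)) = iter.next() {
        // Released in both the paired and the odd-item-out case.
        ctx_a.destroy();
        if let Some((_text_b, ctx_b)) = iter.next() {
            kept.push((text_a, ctx_b));
        }
    }
    kept
}

/// Builds five two-page contexts, pairs them, and reports
/// (contexts kept, pages still live).
fn demo() -> (usize, usize) {
    let live = Rc::new(Cell::new(0));
    let items: Vec<(String, Ctx)> = (0..5)
        .map(|i| (format!("proposal {}", i), Ctx::new(2, Rc::clone(&live))))
        .collect();
    let kept = pair_and_release(items);
    (kept.len(), live.get())
}
```

With five inputs, the first context of each pair and the unpaired fifth are destroyed, so two contexts survive and only their four pages remain counted. Because `destroy` takes `self` by value, the compiler also rejects any later use of a released context.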

♻️ Proposed lifecycle cleanup
-    while let Some((text_a, _ctx_a)) = iter.next() {
+    while let Some((text_a, ctx_a)) = iter.next() {
         if let Some((_text_b, mut ctx_b)) = iter.next() {
-            // ctx_a is dropped here — its unique KV pages are freed immediately.
+            ctx_a.destroy();
             // ctx_b carries the shared prefix forward into the aggregation.
             let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a);
             ctx_b.user(&prompt);
             ctx_b.cue();
@@
-        }
-        // Odd item out: context dropped, KV pages freed.
+        } else {
+            ctx_a.destroy();
+        }
     }
📝 Committable suggestion

Suggested change
-    while let Some((text_a, _ctx_a)) = iter.next() {
-        if let Some((_text_b, mut ctx_b)) = iter.next() {
-            // ctx_a is dropped here — its unique KV pages are freed immediately.
-            // ctx_b carries the shared prefix forward into the aggregation.
-            let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a);
-            ctx_b.user(&prompt);
-            ctx_b.cue();
-            tasks.push(async move {
-                let text = ctx_b
-                    .generate(Sampler::TopP { temperature: 0.6, p: 0.95 })
-                    .max_tokens(aggregation_tokens)
-                    .collect_text()
-                    .await?;
-                Ok((text, ctx_b))
-            });
-        }
-        // Odd item out: context dropped, KV pages freed.
-    }
+    while let Some((text_a, ctx_a)) = iter.next() {
+        if let Some((_text_b, mut ctx_b)) = iter.next() {
+            ctx_a.destroy();
+            // ctx_b carries the shared prefix forward into the aggregation.
+            let prompt = format!("{}{}", AGGREGATE_PROMPT, text_a);
+            ctx_b.user(&prompt);
+            ctx_b.cue();
+            tasks.push(async move {
+                let text = ctx_b
+                    .generate(Sampler::TopP { temperature: 0.6, p: 0.95 })
+                    .max_tokens(aggregation_tokens)
+                    .collect_text()
+                    .await?;
+                Ok((text, ctx_b))
+            });
+        } else {
+            ctx_a.destroy();
+        }
+    }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@inferlets/graph-of-thought/src/lib.rs` around lines 65 - 82, The aggregation
loop in lib.rs assumes contexts are freed immediately when dropped, but Context
relies on destroy(self) for deterministic KV-page release. Update the logic
around the while let Some((text_a, _ctx_a)) pattern so discarded contexts are
explicitly destroyed via Context::destroy(), including the unused ctx_a and any
odd trailing context, while keeping ctx_b alive only for the async task in the
aggregation flow.


+    Ok(tasks)
+}

 /// Main logic for running the hierarchical aggregation workflow.
 async fn run_hierarchical_aggregation(
@@ -47,12 +91,12 @@ async fn run_hierarchical_aggregation(
     proposal_tokens: Vec<usize>,
     aggregation_tokens: usize,
 ) -> Result<Vec<String>> {
-    // --- Stage 1: Generate Initial Proposals ---
+    // --- Stage 1: Generate Initial Proposals (all concurrent) ---
     let propose_prompt = PROPOSAL_PROMPT_TEMPLATE.replace("{}", question);
     base_context.user(&propose_prompt);
     base_context.flush().await?;

-    let mut proposal_tasks = proposal_tokens
+    let proposal_futures = proposal_tokens
Comment thread
coderabbitai[bot] marked this conversation as resolved.
         .into_iter()
         .map(|max_tokens| {
             let mut ctx = base_context.fork()?;
@@ -66,62 +110,35 @@ async fn run_hierarchical_aggregation(
                 Ok::<_, String>((proposal_text, ctx))
             })
         })
-        .collect::<Result<Vec<_>>>()?
+        .collect::<Result<Vec<_>>>()?;
+
+    // Collect ALL proposals before pairing. This lets Stage 2 launch all
+    // aggregation pairs simultaneously rather than as each proposal dribbles in.
+    let proposals: Vec<(String, Context)> = future::join_all(proposal_futures)
+        .await
         .into_iter()
-        .collect::<FuturesUnordered<_>>();
-
-    // --- Stage 2: First-Level Aggregation (Pairing Proposals) ---
-    let mut first_aggregation_tasks = FuturesUnordered::new();
-    let mut pending_proposal: Option<(String, Context)> = None;
-
-    while let Some(result) = proposal_tasks.next().await {
-        let (proposal_text, mut proposal_ctx) = result?;
-        if pending_proposal.is_none() {
-            pending_proposal = Some((proposal_text, proposal_ctx));
-        } else {
-            let (previous_proposal_text, _) = pending_proposal.take().unwrap();
-            let aggregation_prompt = format!("{}{}", AGGREGATE_PROMPT, previous_proposal_text);
-            proposal_ctx.user(&aggregation_prompt);
-            proposal_ctx.cue();
-
-            first_aggregation_tasks.push(async move {
-                let aggregation_text = proposal_ctx
-                    .generate(Sampler::TopP { temperature: 0.6, p: 0.95 })
-                    .max_tokens(aggregation_tokens)
-                    .collect_text()
-                    .await?;
-                Ok::<_, String>((aggregation_text, proposal_ctx))
-            });
-        }
-    }
+        .collect::<Result<_>>()?;
+
+    // --- Stage 2: First-Level Aggregation (all pairs launched at once) ---
+    // All N/2 aggregation flush() calls hit the scheduler simultaneously,
+    // which can coalesce them into a single batched forward pass.
+    let first_agg_futures = launch_aggregation_pairs(proposals, aggregation_tokens)?;
+    let first_aggregations: Vec<(String, Context)> = future::join_all(first_agg_futures)
+        .await
+        .into_iter()
+        .collect::<Result<_>>()?;

-    // --- Stage 3: Second-Level Aggregation (Pairing Aggregations) ---
-    let mut second_aggregation_tasks = Vec::new();
-    let mut pending_aggregation: Option<(String, Context)> = None;
-
-    while let Some(result) = first_aggregation_tasks.next().await {
-        let (aggregation_text, mut aggregation_ctx) = result?;
-        if pending_aggregation.is_none() {
-            pending_aggregation = Some((aggregation_text, aggregation_ctx));
-        } else {
-            let (previous_aggregation_text, _) = pending_aggregation.take().unwrap();
-            let final_prompt = format!("{}{}", AGGREGATE_PROMPT, previous_aggregation_text);
-            aggregation_ctx.user(&final_prompt);
-            aggregation_ctx.cue();
-
-            second_aggregation_tasks.push(async move {
-                aggregation_ctx
-                    .generate(Sampler::TopP { temperature: 0.6, p: 0.95 })
-                    .max_tokens(aggregation_tokens)
-                    .collect_text()
-                    .await
-            });
-        }
-    }
+    // --- Stage 3: Second-Level Aggregation (all pairs launched at once) ---
+    let second_agg_futures = launch_aggregation_pairs(first_aggregations, aggregation_tokens)?;
+    let final_results: Vec<String> = future::join_all(second_agg_futures)
+        .await
+        .into_iter()
+        .collect::<Result<Vec<(String, Context)>>>()?
+        .into_iter()
+        .map(|(text, _ctx)| text)
+        .collect();

-    // --- Stage 4: Collect Final Results ---
-    let results = future::join_all(second_aggregation_tasks).await;
-    results.into_iter().collect::<Result<Vec<_>>>()
+    Ok(final_results)
 }

 #[inferlet::main]
@@ -140,6 +157,10 @@ async fn main(input: Input) -> Result<String> {
         proposal_tokens, aggregation_tokens
     );

+    if proposal_tokens.len() < 4 {
+        return Err("GoT requires at least 4 proposals for two aggregation levels".to_string());
+    }
+
     let models = runtime::models();
     let model_name = models.first().ok_or("No models available")?;
     let model = Model::load(model_name)?;
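
The `proposal_tokens.len() < 4` guard follows from the pairing arithmetic: each level halves the number of items and drops an odd one out, so two levels need at least four proposals to yield any final result. A minimal sketch of that arithmetic, assuming the drop-the-odd-item rule of `launch_aggregation_pairs` (the helper name `level_sizes` is hypothetical and not part of this PR):

```rust
/// Number of outputs produced at each aggregation level when items are
/// paired and an unpaired trailing item is discarded.
fn level_sizes(proposals: usize, levels: usize) -> Vec<usize> {
    let mut sizes = Vec::with_capacity(levels);
    let mut n = proposals;
    for _ in 0..levels {
        // Integer division models dropping the odd item out.
        n /= 2;
        sizes.push(n);
    }
    sizes
}
```

For example, `level_sizes(4, 2)` gives `[2, 1]` (one final answer), `level_sizes(7, 2)` gives `[3, 1]` (the seventh proposal and the third first-level aggregation are discarded), and `level_sizes(3, 2)` gives `[1, 0]`, which is the empty-result case the guard rejects.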