From 2d61b104e1c601c8763136fa9bfe706d22f40f47 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Thu, 16 Apr 2026 09:54:33 -0700 Subject: [PATCH] Fix scoped bindings for expanded parallel function calls --- crates/lib/dag-builder/src/expansion.rs | 290 +++++++++++++++++++++++- crates/lib/runner/src/executor.rs | 94 +++++++- 2 files changed, 379 insertions(+), 5 deletions(-) diff --git a/crates/lib/dag-builder/src/expansion.rs b/crates/lib/dag-builder/src/expansion.rs index eb65a5bb0..14e1c50fc 100644 --- a/crates/lib/dag-builder/src/expansion.rs +++ b/crates/lib/dag-builder/src/expansion.rs @@ -8,6 +8,7 @@ use super::converter::DAGConverter; use waymark_dag::AssignmentNode; use waymark_dag::DAGNode; use waymark_dag::{DAG, DAGEdge, EdgeType}; +use waymark_proto::ast as ir; /// Inline function calls and remap expansion edges. impl DAGConverter { @@ -122,6 +123,10 @@ impl DAGConverter { let is_entry_function = id_prefix.is_none(); let fn_node_ids: HashSet = fn_nodes.keys().cloned().collect(); let ordered_nodes = self.get_topo_order(unexpanded, &fn_node_ids); + let scope_var_map = id_prefix + .as_ref() + .map(|prefix| self.build_scope_var_map(&fn_nodes, prefix)) + .unwrap_or_default(); let mut id_map: HashMap = HashMap::new(); let mut first_real_node: Option = None; @@ -178,13 +183,25 @@ impl DAGConverter { && let Some(io) = &fn_def.io && !io.inputs.is_empty() { - let kwarg_exprs = fn_node.kwarg_exprs.clone(); + let child_scope_var_map = self.build_scope_var_map( + &unexpanded.get_nodes_for_function(&called_fn), + &child_prefix, + ); + let mut kwarg_exprs = fn_node.kwarg_exprs.clone(); + for expr in kwarg_exprs.values_mut() { + Self::rewrite_expr_scope(expr, &scope_var_map); + } for (idx, input_name) in io.inputs.iter().enumerate() { if let Some(expr) = kwarg_exprs.get(input_name) { let bind_id = format!("{child_prefix}:bind_{input_name}_{idx}"); let bind_node = AssignmentNode::new( bind_id.clone(), - vec![input_name.clone()], + vec![ + child_scope_var_map + .get(input_name) + .cloned() + .unwrap_or_else(|| input_name.clone()), + ], None, Some(expr.clone()), None, @@ -242,9 +259,24 @@ impl DAGConverter { } let fn_call_targets = if let Some(targets) = &fn_node.targets { - Some(targets.clone()) + Some( + targets + .iter() + .map(|target| { + scope_var_map + .get(target) + .cloned() + .unwrap_or_else(|| target.clone()) + }) + .collect(), + ) } else if let Some(target) = &fn_node.target { - Some(vec![target.clone()]) + Some(vec![ + scope_var_map + .get(target) + .cloned() + .unwrap_or_else(|| target.clone()), + ]) } else { None }; @@ -305,6 +337,9 @@ impl DAGConverter { agg_node.aggregates_from = format!("{prefix}:{}", agg_node.aggregates_from); } } + if !scope_var_map.is_empty() { + self.rewrite_node_scope_vars(&mut cloned, &scope_var_map); + } id_map.insert(old_id.clone(), new_id.clone()); @@ -356,6 +391,9 @@ impl DAGConverter { let mut cloned_edge = edge.clone(); cloned_edge.source = new_source; cloned_edge.target = new_target; + if !scope_var_map.is_empty() { + Self::rewrite_edge_scope_vars(&mut cloned_edge, &scope_var_map); + } target.add_edge(cloned_edge); } @@ -434,6 +472,250 @@ impl DAGConverter { DAGNode::Expression(node) => node.id = new_id.to_string(), } } + + fn build_scope_var_map( + &self, + fn_nodes: &HashMap, + prefix: &str, + ) -> HashMap { + let mut vars = HashSet::new(); + for node in fn_nodes.values() { + if let DAGNode::Input(node) = node { + vars.extend(node.io_vars.iter().cloned()); + } + if let DAGNode::Output(node) = node { + vars.extend(node.io_vars.iter().cloned()); + } + vars.extend(Self::targets_for_node(node)); + } + + vars.into_iter() + .map(|name| (name.clone(), Self::scoped_var_name(prefix, &name))) + .collect() + } + + fn scoped_var_name(prefix: &str, var_name: &str) -> String { + format!("{prefix}::{var_name}") + } + + fn rewrite_node_scope_vars(&self, node: &mut DAGNode, scope_var_map: &HashMap) { + match node { + DAGNode::Input(node) => { + Self::rewrite_names(&mut node.io_vars, scope_var_map); + } + DAGNode::Output(node) => { + Self::rewrite_names(&mut node.io_vars, scope_var_map); + } + DAGNode::Assignment(node) => { + Self::rewrite_names(&mut node.targets, scope_var_map); + Self::rewrite_optional_name(&mut node.target, scope_var_map); + if let Some(expr) = &mut node.assign_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + } + DAGNode::ActionCall(node) => { + if let Some(targets) = &mut node.targets { + Self::rewrite_names(targets, scope_var_map); + } + Self::rewrite_optional_name(&mut node.target, scope_var_map); + Self::rewrite_optional_name(&mut node.spread_loop_var, scope_var_map); + for expr in node.kwarg_exprs.values_mut() { + Self::rewrite_expr_scope(expr, scope_var_map); + } + if let Some(expr) = &mut node.spread_collection_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + node.kwargs = node + .kwarg_exprs + .iter() + .map(|(name, expr)| (name.clone(), self.expr_to_string(expr))) + .collect(); + } + DAGNode::FnCall(node) => { + if let Some(targets) = &mut node.targets { + Self::rewrite_names(targets, scope_var_map); + } + Self::rewrite_optional_name(&mut node.target, scope_var_map); + for expr in node.kwarg_exprs.values_mut() { + Self::rewrite_expr_scope(expr, scope_var_map); + } + if let Some(expr) = &mut node.assign_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + node.kwargs = node + .kwarg_exprs + .iter() + .map(|(name, expr)| (name.clone(), self.expr_to_string(expr))) + .collect(); + } + DAGNode::Aggregator(node) => { + if let Some(targets) = &mut node.targets { + Self::rewrite_names(targets, scope_var_map); + } + Self::rewrite_optional_name(&mut node.target, scope_var_map); + } + DAGNode::Join(node) => { + if let Some(targets) = &mut node.targets { + Self::rewrite_names(targets, scope_var_map); + } + Self::rewrite_optional_name(&mut node.target, scope_var_map); + } + DAGNode::Return(node) => { + if let Some(expr) = &mut node.assign_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + if let Some(targets) = &mut node.targets { + Self::rewrite_names(targets, scope_var_map); + } + Self::rewrite_optional_name(&mut node.target, scope_var_map); + } + DAGNode::Sleep(node) => { + if let Some(expr) = &mut node.duration_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + } + DAGNode::Parallel(_) + | DAGNode::Branch(_) + | DAGNode::Break(_) + | DAGNode::Continue(_) + | DAGNode::Expression(_) => {} + } + } + + fn rewrite_names(names: &mut Vec, scope_var_map: &HashMap) { + for name in names { + if let Some(scoped) = scope_var_map.get(name) { + *name = scoped.clone(); + } + } + } + + fn rewrite_optional_name(name: &mut Option, scope_var_map: &HashMap) { + if let Some(value) = name.as_mut() + && let Some(scoped) = scope_var_map.get(value) + { + *value = scoped.clone(); + } + } + + fn rewrite_edge_scope_vars(edge: &mut DAGEdge, scope_var_map: &HashMap) { + if let Some(expr) = &mut edge.guard_expr { + Self::rewrite_expr_scope(expr, scope_var_map); + } + Self::rewrite_optional_name(&mut edge.variable, scope_var_map); + } + + fn rewrite_expr_scope(expr: &mut ir::Expr, scope_var_map: &HashMap) { + let Some(kind) = expr.kind.as_mut() else { + return; + }; + + match kind { + ir::expr::Kind::Literal(_) => {} + ir::expr::Kind::Variable(var) => { + if let Some(scoped) = scope_var_map.get(&var.name) { + var.name = scoped.clone(); + } + } + ir::expr::Kind::BinaryOp(op) => { + if let Some(left) = op.left.as_mut() { + Self::rewrite_expr_scope(left, scope_var_map); + } + if let Some(right) = op.right.as_mut() { + Self::rewrite_expr_scope(right, scope_var_map); + } + } + ir::expr::Kind::UnaryOp(op) => { + if let Some(operand) = op.operand.as_mut() { + Self::rewrite_expr_scope(operand, scope_var_map); + } + } + ir::expr::Kind::List(list) => { + for element in &mut list.elements { + Self::rewrite_expr_scope(element, scope_var_map); + } + } + ir::expr::Kind::Dict(dict_expr) => { + for entry in &mut dict_expr.entries { + if let Some(key) = entry.key.as_mut() { + Self::rewrite_expr_scope(key, scope_var_map); + } + if let Some(value) = entry.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + ir::expr::Kind::Index(index) => { + if let Some(object) = index.object.as_mut() { + Self::rewrite_expr_scope(object, scope_var_map); + } + if let Some(index_expr) = index.index.as_mut() { + Self::rewrite_expr_scope(index_expr, scope_var_map); + } + } + ir::expr::Kind::Dot(dot) => { + if let Some(object) = dot.object.as_mut() { + Self::rewrite_expr_scope(object, scope_var_map); + } + } + ir::expr::Kind::FunctionCall(call) => { + for arg in &mut call.args { + Self::rewrite_expr_scope(arg, scope_var_map); + } + for kwarg in &mut call.kwargs { + if let Some(value) = kwarg.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + ir::expr::Kind::ActionCall(action) => { + for kwarg in &mut action.kwargs { + if let Some(value) = kwarg.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + ir::expr::Kind::ParallelExpr(parallel) => { + for call in &mut parallel.calls { + match call.kind.as_mut() { + Some(ir::call::Kind::Action(action)) => { + for kwarg in &mut action.kwargs { + if let Some(value) = kwarg.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + Some(ir::call::Kind::Function(function)) => { + for arg in &mut function.args { + Self::rewrite_expr_scope(arg, scope_var_map); + } + for kwarg in &mut function.kwargs { + if let Some(value) = kwarg.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + None => {} + } + } + } + ir::expr::Kind::SpreadExpr(spread) => { + if let Some(scoped) = scope_var_map.get(&spread.loop_var) { + spread.loop_var = scoped.clone(); + } + if let Some(collection) = spread.collection.as_mut() { + Self::rewrite_expr_scope(collection, scope_var_map); + } + if let Some(action) = spread.action.as_mut() { + for kwarg in &mut action.kwargs { + if let Some(value) = kwarg.value.as_mut() { + Self::rewrite_expr_scope(value, scope_var_map); + } + } + } + } + } + } } #[cfg(test)] diff --git a/crates/lib/runner/src/executor.rs b/crates/lib/runner/src/executor.rs index ed5d1b1ef..7a8048e8b 100644 --- a/crates/lib/runner/src/executor.rs +++ b/crates/lib/runner/src/executor.rs @@ -1510,7 +1510,10 @@ mod tests { use waymark_dag_builder::convert_to_dag; use waymark_ir_parser::parse_program; use waymark_proto::ast as ir; - use waymark_runner_state::{ExecutionEdge, ExecutionNode, NodeStatus, RunnerState}; + use waymark_runner_state::value_visitor::ValueExpr; + use waymark_runner_state::{ + ExecutionEdge, ExecutionNode, LiteralValue, NodeStatus, RunnerState, + }; fn variable(name: &str) -> ir::Expr { ir::Expr { @@ -2948,6 +2951,95 @@ fn main(input: [], output: [done]): assert!(agg_nodes[0].assignments.contains_key("results")); } + #[test] + fn test_increment_resolves_parallel_nested_function_kwargs() { + let dag = dag_from_ir_source( + r#" +fn main(input: [payload], output: [result]): + @tests.fixtures.ticket.start() + alpha_result, beta_result = parallel: + run_alpha(payload.alpha) + run_beta(payload.beta) + result = [alpha_result, beta_result] + return result + +fn run_alpha(input: [payload], output: [alpha_plan]): + alpha_plan = @tests.fixtures.ticket.alpha_prepare(items=payload.items) + return alpha_plan + +fn run_beta(input: [payload], output: [beta_plan]): + beta_plan = @tests.fixtures.ticket.beta_prepare(count=payload.config.count, flag=payload.config.flag) + return beta_plan +"#, + ); + + let mut state = RunnerState::from_dag(Arc::clone(&dag)); + state + .record_assignment_value( + vec!["payload".to_string()], + ValueExpr::Literal(LiteralValue { + value: serde_json::json!({ + "label": "demo", + "alpha": {"items": ["a", "b"]}, + "beta": {"config": {"count": 2, "flag": true}}, + }), + }), + None, + Some("input payload".to_string()), + ) + .expect("record payload assignment"); + let entry_template = dag.entry_node.as_ref().expect("dag entry node"); + let entry_exec = state + .queue_template_node(entry_template, None) + .expect("queue entry node"); + + let mut executor = + RunnerExecutor::without_updates_collection(Arc::clone(&dag), state, HashMap::new()); + + let step1 = executor + .increment(&[entry_exec.node_id]) + .expect("increment initial action"); + assert_eq!(step1.actions.len(), 1); + + executor.set_action_result( + step1.actions[0].node_id, + UncheckedExecutionResult(Value::Bool(true)), + ); + + let step2 = executor + .increment(&[step1.actions[0].node_id]) + .expect("increment parallel function calls"); + assert_eq!(step2.actions.len(), 2); + + let mut kwargs_by_action = HashMap::new(); + for action in &step2.actions { + let spec = action.action.as_ref().expect("action spec"); + let kwargs = executor + .resolve_action_kwargs(action.node_id, spec) + .expect("resolve scoped action kwargs"); + kwargs_by_action.insert(spec.action_name.clone(), kwargs); + } + + assert_eq!( + kwargs_by_action + .get("alpha_prepare") + .and_then(|kwargs| kwargs.get("items")), + Some(&serde_json::json!(["a", "b"])), + ); + assert_eq!( + kwargs_by_action + .get("beta_prepare") + .and_then(|kwargs| kwargs.get("count")), + Some(&Value::Number(2.into())), + ); + assert_eq!( + kwargs_by_action + .get("beta_prepare") + .and_then(|kwargs| kwargs.get("flag")), + Some(&Value::Bool(true)), + ); + } + #[test] fn test_rehydrate_timeline_ordering_preserved() { let mut dag = DAG::default();