Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 288 additions & 4 deletions crates/lib/dag-builder/src/expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::converter::DAGConverter;
use waymark_dag::AssignmentNode;
use waymark_dag::DAGNode;
use waymark_dag::{DAG, DAGEdge, EdgeType};
use waymark_proto::ast as ir;

/// Inline function calls and remap expansion edges.
impl DAGConverter {
Expand Down Expand Up @@ -122,6 +123,10 @@ impl DAGConverter {
let is_entry_function = id_prefix.is_none();
let fn_node_ids: HashSet<String> = fn_nodes.keys().cloned().collect();
let ordered_nodes = self.get_topo_order(unexpanded, &fn_node_ids);
let scope_var_map = id_prefix
.as_ref()
.map(|prefix| self.build_scope_var_map(&fn_nodes, prefix))
.unwrap_or_default();

let mut id_map: HashMap<String, String> = HashMap::new();
let mut first_real_node: Option<String> = None;
Expand Down Expand Up @@ -178,13 +183,25 @@ impl DAGConverter {
&& let Some(io) = &fn_def.io
&& !io.inputs.is_empty()
{
let kwarg_exprs = fn_node.kwarg_exprs.clone();
let child_scope_var_map = self.build_scope_var_map(
&unexpanded.get_nodes_for_function(&called_fn),
&child_prefix,
);
let mut kwarg_exprs = fn_node.kwarg_exprs.clone();
for expr in kwarg_exprs.values_mut() {
Self::rewrite_expr_scope(expr, &scope_var_map);
}
for (idx, input_name) in io.inputs.iter().enumerate() {
if let Some(expr) = kwarg_exprs.get(input_name) {
let bind_id = format!("{child_prefix}:bind_{input_name}_{idx}");
let bind_node = AssignmentNode::new(
bind_id.clone(),
vec![input_name.clone()],
vec![
child_scope_var_map
.get(input_name)
.cloned()
.unwrap_or_else(|| input_name.clone()),
],
None,
Some(expr.clone()),
None,
Expand Down Expand Up @@ -242,9 +259,26 @@ impl DAGConverter {
}

let fn_call_targets = if let Some(targets) = &fn_node.targets {
Some(targets.clone())
Some(
targets
.iter()
.map(|target| {
scope_var_map
.get(target)
.cloned()
.unwrap_or_else(|| target.clone())
})
.collect(),
)
} else {
fn_node.target.clone().map(|target| vec![target])
fn_node.target.as_ref().map(|target| {
vec![
scope_var_map
.get(target)
.cloned()
.unwrap_or_else(|| target.clone()),
]
})
};

if let Some(targets) = fn_call_targets {
Expand Down Expand Up @@ -303,6 +337,9 @@ impl DAGConverter {
agg_node.aggregates_from = format!("{prefix}:{}", agg_node.aggregates_from);
}
}
if !scope_var_map.is_empty() {
self.rewrite_node_scope_vars(&mut cloned, &scope_var_map);
}

id_map.insert(old_id.clone(), new_id.clone());

Expand Down Expand Up @@ -354,6 +391,9 @@ impl DAGConverter {
let mut cloned_edge = edge.clone();
cloned_edge.source = new_source;
cloned_edge.target = new_target;
if !scope_var_map.is_empty() {
Self::rewrite_edge_scope_vars(&mut cloned_edge, &scope_var_map);
}
target.add_edge(cloned_edge);
}

Expand Down Expand Up @@ -432,6 +472,250 @@ impl DAGConverter {
DAGNode::Expression(node) => node.id = new_id.to_string(),
}
}

fn build_scope_var_map(
&self,
fn_nodes: &HashMap<String, DAGNode>,
prefix: &str,
) -> HashMap<String, String> {
let mut vars = HashSet::new();
for node in fn_nodes.values() {
if let DAGNode::Input(node) = node {
vars.extend(node.io_vars.iter().cloned());
}
if let DAGNode::Output(node) = node {
vars.extend(node.io_vars.iter().cloned());
}
vars.extend(Self::targets_for_node(node));
}

vars.into_iter()
.map(|name| (name.clone(), Self::scoped_var_name(prefix, &name)))
.collect()
}

fn scoped_var_name(prefix: &str, var_name: &str) -> String {
format!("{prefix}::{var_name}")
}

fn rewrite_node_scope_vars(&self, node: &mut DAGNode, scope_var_map: &HashMap<String, String>) {
match node {
DAGNode::Input(node) => {
Self::rewrite_names(&mut node.io_vars, scope_var_map);
}
DAGNode::Output(node) => {
Self::rewrite_names(&mut node.io_vars, scope_var_map);
}
DAGNode::Assignment(node) => {
Self::rewrite_names(&mut node.targets, scope_var_map);
Self::rewrite_optional_name(&mut node.target, scope_var_map);
if let Some(expr) = &mut node.assign_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
}
DAGNode::ActionCall(node) => {
if let Some(targets) = &mut node.targets {
Self::rewrite_names(targets, scope_var_map);
}
Self::rewrite_optional_name(&mut node.target, scope_var_map);
Self::rewrite_optional_name(&mut node.spread_loop_var, scope_var_map);
for expr in node.kwarg_exprs.values_mut() {
Self::rewrite_expr_scope(expr, scope_var_map);
}
if let Some(expr) = &mut node.spread_collection_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
node.kwargs = node
.kwarg_exprs
.iter()
.map(|(name, expr)| (name.clone(), self.expr_to_string(expr)))
.collect();
}
DAGNode::FnCall(node) => {
if let Some(targets) = &mut node.targets {
Self::rewrite_names(targets, scope_var_map);
}
Self::rewrite_optional_name(&mut node.target, scope_var_map);
for expr in node.kwarg_exprs.values_mut() {
Self::rewrite_expr_scope(expr, scope_var_map);
}
if let Some(expr) = &mut node.assign_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
node.kwargs = node
.kwarg_exprs
.iter()
.map(|(name, expr)| (name.clone(), self.expr_to_string(expr)))
.collect();
}
DAGNode::Aggregator(node) => {
if let Some(targets) = &mut node.targets {
Self::rewrite_names(targets, scope_var_map);
}
Self::rewrite_optional_name(&mut node.target, scope_var_map);
}
DAGNode::Join(node) => {
if let Some(targets) = &mut node.targets {
Self::rewrite_names(targets, scope_var_map);
}
Self::rewrite_optional_name(&mut node.target, scope_var_map);
}
DAGNode::Return(node) => {
if let Some(expr) = &mut node.assign_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
if let Some(targets) = &mut node.targets {
Self::rewrite_names(targets, scope_var_map);
}
Self::rewrite_optional_name(&mut node.target, scope_var_map);
}
DAGNode::Sleep(node) => {
if let Some(expr) = &mut node.duration_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
}
DAGNode::Parallel(_)
| DAGNode::Branch(_)
| DAGNode::Break(_)
| DAGNode::Continue(_)
| DAGNode::Expression(_) => {}
}
}

fn rewrite_names(names: &mut Vec<String>, scope_var_map: &HashMap<String, String>) {
for name in names {
if let Some(scoped) = scope_var_map.get(name) {
*name = scoped.clone();
}
}
}

fn rewrite_optional_name(name: &mut Option<String>, scope_var_map: &HashMap<String, String>) {
if let Some(value) = name.as_mut()
&& let Some(scoped) = scope_var_map.get(value)
{
*value = scoped.clone();
}
}

fn rewrite_edge_scope_vars(edge: &mut DAGEdge, scope_var_map: &HashMap<String, String>) {
if let Some(expr) = &mut edge.guard_expr {
Self::rewrite_expr_scope(expr, scope_var_map);
}
Self::rewrite_optional_name(&mut edge.variable, scope_var_map);
}

fn rewrite_expr_scope(expr: &mut ir::Expr, scope_var_map: &HashMap<String, String>) {
let Some(kind) = expr.kind.as_mut() else {
return;
};

match kind {
ir::expr::Kind::Literal(_) => {}
ir::expr::Kind::Variable(var) => {
if let Some(scoped) = scope_var_map.get(&var.name) {
var.name = scoped.clone();
}
}
ir::expr::Kind::BinaryOp(op) => {
if let Some(left) = op.left.as_mut() {
Self::rewrite_expr_scope(left, scope_var_map);
}
if let Some(right) = op.right.as_mut() {
Self::rewrite_expr_scope(right, scope_var_map);
}
}
ir::expr::Kind::UnaryOp(op) => {
if let Some(operand) = op.operand.as_mut() {
Self::rewrite_expr_scope(operand, scope_var_map);
}
}
ir::expr::Kind::List(list) => {
for element in &mut list.elements {
Self::rewrite_expr_scope(element, scope_var_map);
}
}
ir::expr::Kind::Dict(dict_expr) => {
for entry in &mut dict_expr.entries {
if let Some(key) = entry.key.as_mut() {
Self::rewrite_expr_scope(key, scope_var_map);
}
if let Some(value) = entry.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
ir::expr::Kind::Index(index) => {
if let Some(object) = index.object.as_mut() {
Self::rewrite_expr_scope(object, scope_var_map);
}
if let Some(index_expr) = index.index.as_mut() {
Self::rewrite_expr_scope(index_expr, scope_var_map);
}
}
ir::expr::Kind::Dot(dot) => {
if let Some(object) = dot.object.as_mut() {
Self::rewrite_expr_scope(object, scope_var_map);
}
}
ir::expr::Kind::FunctionCall(call) => {
for arg in &mut call.args {
Self::rewrite_expr_scope(arg, scope_var_map);
}
for kwarg in &mut call.kwargs {
if let Some(value) = kwarg.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
ir::expr::Kind::ActionCall(action) => {
for kwarg in &mut action.kwargs {
if let Some(value) = kwarg.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
ir::expr::Kind::ParallelExpr(parallel) => {
for call in &mut parallel.calls {
match call.kind.as_mut() {
Some(ir::call::Kind::Action(action)) => {
for kwarg in &mut action.kwargs {
if let Some(value) = kwarg.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
Some(ir::call::Kind::Function(function)) => {
for arg in &mut function.args {
Self::rewrite_expr_scope(arg, scope_var_map);
}
for kwarg in &mut function.kwargs {
if let Some(value) = kwarg.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
None => {}
}
}
}
ir::expr::Kind::SpreadExpr(spread) => {
if let Some(scoped) = scope_var_map.get(&spread.loop_var) {
spread.loop_var = scoped.clone();
}
if let Some(collection) = spread.collection.as_mut() {
Self::rewrite_expr_scope(collection, scope_var_map);
}
if let Some(action) = spread.action.as_mut() {
for kwarg in &mut action.kwargs {
if let Some(value) = kwarg.value.as_mut() {
Self::rewrite_expr_scope(value, scope_var_map);
}
}
}
}
}
}
}

#[cfg(test)]
Expand Down
Loading
Loading