diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index f107e2caac..1af788c74b 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -8,6 +8,7 @@ from dace.dtypes import deduplicate import dace.serialize from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union +from ordered_set import OrderedSet class NodeNotFoundError(Exception): @@ -215,7 +216,7 @@ def __getitem__(self, node: NodeT) -> Iterable[NodeT]: def all_edges(self, *nodes: NodeT) -> Iterable[Edge[EdgeT]]: """Returns an iterable to incoming and outgoing Edge objects.""" - result = set() + result = OrderedSet() for node in nodes: result.update(self.in_edges(node)) result.update(self.out_edges(node)) diff --git a/dace/sdfg/scope.py b/dace/sdfg/scope.py index cd139aaa17..f5a397d96b 100644 --- a/dace/sdfg/scope.py +++ b/dace/sdfg/scope.py @@ -8,6 +8,7 @@ from dace.config import Config from dace.sdfg import nodes as nd from dace.sdfg.state import StateSubgraphView +from ordered_set import OrderedSet ScopeDictType = Dict[nd.Node, List[nd.Node]] @@ -62,16 +63,16 @@ def _scope_subgraph(graph, entry_node, include_entry, include_exit) -> ScopeSubg raise TypeError("Received {}: should be dace.nodes.EntryNode".format(type(entry_node).__name__)) node_to_children = graph.scope_children() if include_exit: - children_nodes = set(node_to_children[entry_node]) + children_nodes = OrderedSet(node_to_children[entry_node]) else: - children_nodes = set(n for n in node_to_children[entry_node] if not isinstance(n, nd.ExitNode)) + children_nodes = OrderedSet(n for n in node_to_children[entry_node] if not isinstance(n, nd.ExitNode)) map_nodes = [node for node in children_nodes if isinstance(node, nd.EntryNode)] while len(map_nodes) > 0: next_map_nodes = [] # Traverse children map nodes for map_node in map_nodes: # Get child map subgraph (1 level) - more_nodes = set(node_to_children[map_node]) + more_nodes = OrderedSet(node_to_children[map_node]) # Unionize children_nodes with new nodes children_nodes |= more_nodes # Add nodes of the next level to next_map_nodes diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c3596e8f4f..9403f567d4 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1790,9 +1790,13 @@ def add_nested_sdfg( sdfg.update_cfg_list([]) # Make dictionary of autodetect connector types from set - if isinstance(inputs, (set, collections.abc.KeysView)): + # TODO(tehrengruber): Using sets here leads to a situation where self._nodes has a different + # ordering, but to_json from_json restores the order again. Investigate. + if isinstance(inputs, set) or isinstance(outputs, set): + warnings.warn("Using sets as inputs is discouraged as it leads to indeterministic behavior.") + if isinstance(inputs, (set, collections.abc.KeysView, collections.abc.Set)): inputs = {k: None for k in inputs} - if isinstance(outputs, (set, collections.abc.KeysView)): + if isinstance(outputs, (set, collections.abc.KeysView, collections.abc.Set)): outputs = {k: None for k in outputs} s = nd.NestedSDFG( diff --git a/dace/transformation/dataflow/map_fusion_vertical.py b/dace/transformation/dataflow/map_fusion_vertical.py index 89883425cb..92b0b109d5 100644 --- a/dace/transformation/dataflow/map_fusion_vertical.py +++ b/dace/transformation/dataflow/map_fusion_vertical.py @@ -8,6 +8,7 @@ from dace.sdfg import SDFG, SDFGState, graph, nodes, propagation from dace.transformation.dataflow import map_fusion_helper as mfhelper from dace.sdfg.type_inference import infer_expr_type +from ordered_set import OrderedSet @properties.make_properties @@ -405,9 +406,9 @@ def partition_first_outputs( param_repl: Dict[str, str], ) -> Union[ Tuple[ - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], + OrderedSet[graph.MultiConnectorEdge[dace.Memlet]], + OrderedSet[graph.MultiConnectorEdge[dace.Memlet]], + OrderedSet[graph.MultiConnectorEdge[dace.Memlet]], ], None, ]: @@ -447,9 +448,9 @@ def partition_first_outputs( `require_all_intermediates` and by `self.require_exclusive_intermediates`. """ # The three outputs set. - pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = OrderedSet() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = OrderedSet() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = OrderedSet() # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() @@ -703,7 +704,7 @@ def partition_first_outputs( def handle_intermediate_set( self, - intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + intermediate_outputs: OrderedSet[graph.MultiConnectorEdge[dace.Memlet]], state: dace.SDFGState, sdfg: SDFG, first_map_exit: nodes.MapExit, @@ -870,7 +871,7 @@ def handle_intermediate_set( # the input connectors on the MapEntry, such that we know where we # have to reroute inside the Map. # NOTE: Assumes that Map (if connected is the direct neighbour). - conn_names: Set[str] = set() + conn_names: OrderedSet[str] = OrderedSet() for inter_node_out_edge in state.out_edges(inter_node): if inter_node_out_edge.dst == second_map_entry: assert inter_node_out_edge.dst_conn.startswith("IN_") diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 75530224d0..32391a7f90 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -865,10 +865,9 @@ def isolate_nested_sdfg( # These are the nodes that belongs to the Post State. There are two reasons why a # node belongs to the set of post nodes. # The first is that the node does not belong to any other set. - post_nodes: Set[nodes.Node] = { - node - for node in state.nodes() if (node not in pre_nodes) and (node not in middle_nodes) - } + post_nodes: list[nodes.Node] = [ + node for node in state.nodes() if (node not in pre_nodes) and (node not in middle_nodes) + ] # The second reason, are read dependencies, for this we have to look at the incoming # edges and add any node that we need. @@ -881,7 +880,7 @@ def isolate_nested_sdfg( if test_if_applicable: return False raise ValueError("Can not replicate non non-View AccessNodes into the post state.") - post_nodes.add(node) + post_nodes.append(node) if test_if_applicable: return True diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 8785979004..464f9e3703 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -55,13 +55,13 @@ class Pass: CATEGORY: str = 'Helper' - def depends_on(self) -> Set[Union[Type['Pass'], 'Pass']]: + def depends_on(self) -> List[Union[Type['Pass'], 'Pass']]: """ If in the context of a ``Pipeline``, which other Passes need to run first. :return: A set of Pass subclasses or objects that need to run prior to this Pass. """ - return set() + return [] def modifies(self) -> Modifies: """ @@ -412,7 +412,7 @@ class Pipeline(Pass): def __init__(self, passes: List[Pass]): self.passes = [] - self._pass_names = set(type(p).__name__ for p in passes) + self._pass_names = set(type(p).__name__ for p in passes) # todo sort this? self.passes.extend(passes) # Add missing Pass dependencies @@ -482,11 +482,8 @@ def modifies(self) -> Modifies: def should_reapply(self, modified: Modifies) -> bool: return any(p.should_reapply(modified) for p in self.passes) - def depends_on(self) -> Set[Type[Pass]]: - result = set() - for p in self.passes: - result.update(p.depends_on()) - return result + def depends_on(self) -> List[Type[Pass]]: + return list(dict.fromkeys([p.depends_on() for p in self.passes])) def _make_dependency_graph(self) -> gr.OrderedDiGraph: """ diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index d3e9c872ce..e8853b06ba 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -16,6 +16,7 @@ from typing import Dict, Iterable, List, Set, Tuple, Any, Optional, Union import networkx as nx from networkx.algorithms import shortest_paths as nxsp +from ordered_set import OrderedSet from dace.transformation.passes.analysis import loop_analysis @@ -41,9 +42,9 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG def depends_on(self): - return {ControlFlowBlockReachability} + return [ControlFlowBlockReachability] - def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, OrderedSet[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. """ @@ -52,9 +53,9 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) else: cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] - reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} + reachable: Dict[int, Dict[SDFGState, OrderedSet[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + result: Dict[SDFGState, OrderedSet[SDFGState]] = defaultdict(OrderedSet) for state in sdfg.states(): for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: if isinstance(reached, SDFGState): @@ -88,10 +89,10 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def _region_closure( self, region: ControlFlowRegion, - block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]], - cached_closures: dict[int, Set[ControlFlowBlock]], + block_reach: Dict[int, Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]]], + cached_closures: dict[int, OrderedSet[ControlFlowBlock]], ) -> Set[ControlFlowBlock]: - closure: Set[ControlFlowBlock] = set() + closure: Set[ControlFlowBlock] = OrderedSet() if isinstance(region, LoopRegion): # Any point inside the loop may reach any other point inside the loop again. # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. @@ -114,7 +115,7 @@ def _region_closure( pivot = pivot.parent_graph return closure - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]]]: """ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable control flow blocks. @@ -122,13 +123,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ top_sdfg.reset_cfg_list() single_level_reachable: Dict[int, Dict[ControlFlowBlock, - Set[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) + OrderedSet[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) for cfg in top_sdfg.all_control_flow_regions(recursive=True): # In networkx this is currently implemented naively for directed graphs. # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) for n, v in reachable_nodes(cfg.nx): - reach = set() + reach = OrderedSet() for nd in v: reach.add(nd) if isinstance(nd, AbstractControlFlowRegion): @@ -140,11 +141,11 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ if self.contain_to_single_level: return single_level_reachable - reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} - cached_closures: dict[int, Set[ControlFlowBlock]] = {} + reachable: Dict[int, Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]]] = {} + cached_closures: dict[int, OrderedSet[ControlFlowBlock]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): for cfg in sdfg.all_control_flow_regions(): - result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) + result: Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]] = defaultdict(OrderedSet) for block in cfg.nodes(): for reached in single_level_reachable[block.parent_graph.cfg_id][block]: if isinstance(reached, AbstractControlFlowRegion): @@ -179,11 +180,11 @@ def _single_shortest_path_length_no_self(adj, source): seen = {} # level (number of hops) when seen in BFS level = 0 # the current level - nextlevel = set(firstlevel) # set of nodes to check at next level + nextlevel = OrderedSet(firstlevel) # set of nodes to check at next level n = len(adj) while nextlevel: thislevel = nextlevel # advance to next level - nextlevel = set() # and start a new set (fringe) + nextlevel = OrderedSet() # and start a new set (fringe) found = [] for v in thislevel: if v not in seen: @@ -225,9 +226,9 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes def apply(self, region: ControlFlowRegion, - _) -> Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]: + _) -> Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[OrderedSet[str], OrderedSet[str]]]: adesc = set(region.sdfg.arrays.keys()) - result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + result: Dict[ControlFlowBlock, Tuple[OrderedSet[str], OrderedSet[str]]] = {} for block in region.nodes(): # No symbols may be written to inside blocks. result[block] = (block.free_symbols, set()) @@ -254,7 +255,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes - def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[str]: + def _get_loop_region_readset(self, loop: LoopRegion, arrays: OrderedSet[str]) -> OrderedSet[str]: readset = set() exprs = {loop.loop_condition.as_string} update_stmt = loop_analysis.get_update_assignment(loop) @@ -267,15 +268,15 @@ def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[st readset |= (symbolic.free_symbols_and_functions(expr) | symbolic.arrays(expr)) & arrays return readset - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[OrderedSet[str], OrderedSet[str]]]: """ :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. """ - result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + result: Dict[ControlFlowBlock, Tuple[OrderedSet[str], OrderedSet[str]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - arrays: Set[str] = set(sdfg.arrays.keys()) + arrays: OrderedSet[str] = OrderedSet(sdfg.arrays.keys()) for block in sdfg.all_control_flow_blocks(): - readset, writeset = set(), set() + readset, writeset = OrderedSet(), OrderedSet() if isinstance(block, SDFGState): for anode in block.data_nodes(): if block.in_degree(anode) > 0: @@ -302,7 +303,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str] # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() for e in sdfg.all_interstate_edges(): - fsyms = e.data.free_symbols & anames + fsyms = sorted(e.data.free_symbols & anames) if fsyms: result[e.src][0].update(fsyms) result[e.dst][0].update(fsyms) @@ -325,14 +326,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, OrderedSet[SDFGState]]]: """ :return: A dictionary mapping each data descriptor name to states where it can be found in. """ - top_result: Dict[int, Dict[str, Set[SDFGState]]] = {} + top_result: Dict[int, Dict[str, OrderedSet[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[str, Set[SDFGState]] = defaultdict(set) + result: Dict[str, OrderedSet[SDFGState]] = defaultdict(OrderedSet) for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) @@ -340,9 +341,9 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() for e in sdfg.all_interstate_edges(): - fsyms = e.data.free_symbols & anames + fsyms = sorted(e.data.free_symbols & anames) for access in fsyms: - result[access].update({e.src, e.dst}) + result[access].update((e.src, e.dst)) top_result[sdfg.cfg_id] = result return top_result @@ -374,18 +375,18 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.AccessNodes & ppl.Modifies.CFG - def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFG, Set[str]]: + def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFG, OrderedSet[str]]: """ :return: A dictionary mapping SDFGs to a `set` of strings containing the name of the data descriptors that are only used once. """ # TODO(pschaad): Should we index on cfg or the SDFG itself. - exclusive_data: Dict[SDFG, Set[str]] = {} + exclusive_data: Dict[SDFG, OrderedSet[str]] = {} for nsdfg in sdfg.all_sdfgs_recursive(): exclusive_data[nsdfg] = self._find_single_use_data_in_sdfg(nsdfg) return exclusive_data - def _find_single_use_data_in_sdfg(self, sdfg: SDFG) -> Set[str]: + def _find_single_use_data_in_sdfg(self, sdfg: SDFG) -> OrderedSet[str]: """Scans an SDFG and computes the data that is only used once in the SDFG. The rules used to classify data descriptors are outlined above. The function @@ -396,8 +397,8 @@ def _find_single_use_data_in_sdfg(self, sdfg: SDFG) -> Set[str]: # If we encounter a data descriptor for the first time we immediately # classify it as single use. We will undo this decision as soon as # learn that it is used somewhere else. - single_use_data: Set[str] = set() - previously_seen: Set[str] = set() + single_use_data: OrderedSet[str] = OrderedSet() + previously_seen: OrderedSet[str] = OrderedSet() for state in sdfg.states(): for dnode in state.data_nodes(): @@ -437,17 +438,19 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, - _) -> Dict[int, Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]]]: + def apply_pass( + self, top_sdfg: SDFG, + _) -> Dict[int, Dict[str, Dict[SDFGState, Tuple[OrderedSet[nd.AccessNode], OrderedSet[nd.AccessNode]]]]]: """ :return: A dictionary mapping each data descriptor name to a dictionary keyed by states with all access nodes that use that data descriptor. """ - top_result: Dict[int, Dict[str, Set[nd.AccessNode]]] = dict() + top_result: Dict[int, Dict[str, OrderedSet[nd.AccessNode]]] = dict() for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( - lambda: defaultdict(lambda: [set(), set()])) + result: Dict[str, Dict[SDFGState, + Tuple[OrderedSet[nd.AccessNode], OrderedSet[nd.AccessNode]]]] = defaultdict( + lambda: defaultdict(lambda: [OrderedSet(), OrderedSet()])) for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: @@ -475,7 +478,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.Symbols | ppl.Modifies.CFG | ppl.Modifies.Edges | ppl.Modifies.Nodes def depends_on(self): - return {SymbolAccessSets, ControlFlowBlockReachability} + return [SymbolAccessSets, ControlFlowBlockReachability] def _find_dominating_write(self, sym: str, read: Union[ControlFlowBlock, Edge[InterstateEdge]], block_idom: Dict[ControlFlowBlock, ControlFlowBlock]) -> Optional[Edge[InterstateEdge]]: @@ -510,15 +513,16 @@ def _find_dominating_write(self, sym: str, read: Union[ControlFlowBlock, Edge[In return write_isedge def apply(self, region, pipeline_results) -> SymbolScopeDict: - result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: set())) + result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: OrderedSet())) idom = nx.immediate_dominators(region.nx, region.start_block) all_doms = cfg_analysis.all_dominators(region, idom) - b_reach: Dict[ControlFlowBlock, - Set[ControlFlowBlock]] = pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id] + b_reach: Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]] = pipeline_results[ + ControlFlowBlockReachability.__name__][region.cfg_id] symbol_access_sets: Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], - Tuple[Set[str], Set[str]]] = pipeline_results[SymbolAccessSets.__name__][region.cfg_id] + Tuple[OrderedSet[str], + OrderedSet[str]]] = pipeline_results[SymbolAccessSets.__name__][region.cfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): for sym in reads: @@ -552,7 +556,7 @@ def apply(self, region, pipeline_results) -> SymbolScopeDict: other_accesses.update(accesses) other_accesses.add(write) to_remove.add((sym, write)) - result[sym][write] = set() + result[sym][write] = OrderedSet() for sym, write in to_remove: del result[sym][write] @@ -577,15 +581,16 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.States def depends_on(self): - return {AccessSets, FindAccessNodes, ControlFlowBlockReachability} + return [AccessSets, FindAccessNodes, ControlFlowBlockReachability] def _find_dominating_write(self, desc: str, block: ControlFlowBlock, read: Union[nd.AccessNode, InterstateEdge], - access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], + access_nodes: Dict[SDFGState, Tuple[OrderedSet[nd.AccessNode], + OrderedSet[nd.AccessNode]]], idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]], - access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]], + access_sets: Dict[ControlFlowBlock, Tuple[OrderedSet[str], OrderedSet[str]]], no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): state: SDFGState = block @@ -652,16 +657,18 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i """ top_result: Dict[int, WriteScopeDict] = dict() - access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[AccessSets.__name__] + access_sets: Dict[ControlFlowBlock, Tuple[OrderedSet[str], + OrderedSet[str]]] = pipeline_results[AccessSets.__name__] for sdfg in top_sdfg.all_sdfgs_recursive(): - result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: set())) + result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: OrderedSet())) idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} - all_doms_transitive: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(lambda: set()) + all_doms_transitive: Dict[ControlFlowBlock, + OrderedSet[ControlFlowBlock]] = defaultdict(lambda: OrderedSet()) for cfg in sdfg.all_control_flow_regions(): if isinstance(cfg, ConditionalBlock): idom_dict[cfg] = {b: b for _, b in cfg.branches} - all_doms = {b: set([b]) for _, b in cfg.branches} + all_doms = {b: OrderedSet([b]) for _, b in cfg.branches} else: idom_dict[cfg] = nx.immediate_dominators(cfg.nx, cfg.start_block) all_doms = cfg_analysis.all_dominators(cfg, idom_dict[cfg]) @@ -673,15 +680,16 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i all_doms_transitive[k].add(cfg) all_doms_transitive[k].update(all_doms_transitive[cfg]) - access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[ - FindAccessNodes.__name__][sdfg.cfg_id] + access_nodes: Dict[str, Dict[SDFGState, Tuple[OrderedSet[nd.AccessNode], + OrderedSet[nd.AccessNode]]]] = pipeline_results[ + FindAccessNodes.__name__][sdfg.cfg_id] block_reach: Dict[ControlFlowBlock, - Set[ControlFlowBlock]] = pipeline_results[ControlFlowBlockReachability.__name__] + OrderedSet[ControlFlowBlock]] = pipeline_results[ControlFlowBlockReachability.__name__] anames = sdfg.arrays.keys() for desc in sdfg.arrays: - desc_states_with_nodes = set(access_nodes[desc].keys()) + desc_states_with_nodes = OrderedSet(access_nodes[desc].keys()) for state in desc_states_with_nodes: for read_node in access_nodes[desc][state][0]: write = self._find_dominating_write(desc, state, read_node, access_nodes, idom_dict, @@ -713,7 +721,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, # then merge A and its scope into B's scope. - to_remove = set() + to_remove = OrderedSet() for write, accesses in result[desc].items(): if write is None: continue @@ -730,7 +738,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i other_accesses.update(accesses) other_accesses.add(write) to_remove.add(write) - result[desc][write] = set() + result[desc][write] = OrderedSet() for write in to_remove: del result[desc][write] top_result[sdfg.cfg_id] = result @@ -752,14 +760,14 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.Memlets - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, OrderedSet[Memlet]]]: """ :return: A dictionary mapping each data descriptor name to a set of memlets. """ - top_result: Dict[int, Dict[str, Set[Memlet]]] = dict() + top_result: Dict[int, Dict[str, OrderedSet[Memlet]]] = dict() for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[str, Set[Memlet]] = defaultdict(set) + result: Dict[str, OrderedSet[Memlet]] = defaultdict(OrderedSet) for state in sdfg.states(): for anode in state.data_nodes(): for e in state.all_edges(anode): @@ -794,17 +802,17 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.Memlets - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, OrderedSet[Union[Memlet, nd.CodeNode]]]]: """ :return: A dictionary mapping each data descriptor name to a set of memlets. """ - top_result: Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]] = dict() + top_result: Dict[int, Dict[str, OrderedSet[Union[Memlet, nd.CodeNode]]]] = dict() for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[str, Set[Memlet]] = defaultdict(set) - reference_descs = set(k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference)) + result: Dict[str, OrderedSet[Memlet]] = defaultdict(OrderedSet) + reference_descs = OrderedSet(k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference)) for state in sdfg.states(): - code_sources: Dict[str, Set[nd.CodeNode]] = defaultdict(set) + code_sources: Dict[str, OrderedSet[nd.CodeNode]] = defaultdict(OrderedSet) for anode in state.data_nodes(): if anode.data not in reference_descs: continue @@ -883,21 +891,22 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.Everything - def _derive_parameter_datasize_constraints(self, sdfg: SDFG, invariants: Dict[str, Set[str]]) -> None: - handled = set() + def _derive_parameter_datasize_constraints(self, sdfg: SDFG, invariants: Dict[str, OrderedSet[str]]) -> None: + handled = OrderedSet() for arr in sdfg.arrays.values(): for dim in arr.shape: if isinstance(dim, symbolic.symbol) and not dim in handled: ds = str(dim) if ds not in invariants: - invariants[ds] = set() + invariants[ds] = OrderedSet() invariants[ds].add(f'{ds} > 0') if self.assume_max_data_size is not None: invariants[ds].add(f'{ds} <= {self.assume_max_data_size}') handled.add(ds) - def apply_pass(self, sdfg: SDFG, _) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]], Dict[str, Set[str]]]: - invariants: Dict[str, Set[str]] = {} + def apply_pass(self, sdfg: SDFG, + _) -> Tuple[Dict[str, OrderedSet[str]], Dict[str, OrderedSet[str]], Dict[str, OrderedSet[str]]]: + invariants: Dict[str, OrderedSet[str]] = {} self._derive_parameter_datasize_constraints(sdfg, invariants) return {}, invariants, {} @@ -925,11 +934,11 @@ def __init__(self): self.apply_to_conditionals = True def depends_on(self): - return {ControlFlowBlockReachability} + return [ControlFlowBlockReachability] - def _propagate_in_cfg(self, cfg: ControlFlowRegion, reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]], + def _propagate_in_cfg(self, cfg: ControlFlowRegion, reachable: Dict[ControlFlowBlock, OrderedSet[ControlFlowBlock]], starting_executions: int, starting_dynamic_executions: bool): - visited_blocks: Set[ControlFlowBlock] = set() + visited_blocks: OrderedSet[ControlFlowBlock] = OrderedSet() traversal_q: deque[Tuple[ControlFlowBlock, int, bool, List[str]]] = deque() traversal_q.append((cfg.start_block, starting_executions, starting_dynamic_executions, [])) while traversal_q: @@ -1079,13 +1088,13 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG def depends_on(self): - return {} + return [] - def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Set[nd.AccessNode]: + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> OrderedSet[nd.AccessNode]: """ :return: A set of access nodes, which are unique writes in conditional blocks. """ - cond_unique = set() + cond_unique = OrderedSet() for cfb in top_sdfg.all_control_flow_blocks(recursive=True): if not isinstance(cfb, ConditionalBlock): continue @@ -1101,12 +1110,15 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Set[nd.AccessNode]: for st in br.all_states(): for an in st.data_nodes(): array_name = an.data - write_subsets = set(e.data.dst_subset for e in st.in_edges(an)) + write_subsets = OrderedSet(e.data.dst_subset for e in st.in_edges(an)) wss = str(write_subsets) if array_name not in access_write_branch: access_write_branch[array_name] = {} if wss not in access_write_branch[array_name]: - access_write_branch[array_name][wss] = {"branches": set(), "access_nodes": set()} + access_write_branch[array_name][wss] = { + "branches": OrderedSet(), + "access_nodes": OrderedSet() + } access_write_branch[array_name][wss]["branches"].add(br) access_write_branch[array_name][wss]["access_nodes"].add(an) diff --git a/dace/transformation/passes/analysis/scope_data_and_symbol_analysis.py b/dace/transformation/passes/analysis/scope_data_and_symbol_analysis.py index 16b3b7285b..6824776bf8 100644 --- a/dace/transformation/passes/analysis/scope_data_and_symbol_analysis.py +++ b/dace/transformation/passes/analysis/scope_data_and_symbol_analysis.py @@ -52,7 +52,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.CFG | ppl.Modifies.SDFG | ppl.Modifies.Nodes) def depends_on(self): - return {} + return [] def apply_pass(self, sdfg: dace.SDFG, pipeline_res: Dict) -> Dict[str, ScopeAnalysis]: """ diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index 3e503dfd22..7fa23c41e7 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -13,6 +13,7 @@ SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView) from dace.transformation.passes import analysis as ap from dace.transformation.transformation import SingleStateTransformation +from ordered_set import OrderedSet @properties.make_properties @@ -32,7 +33,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.AccessNodes def depends_on(self): - return {ap.StateReachability, ap.FindAccessStates} + return [ap.StateReachability, ap.FindAccessStates] def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]: """ @@ -56,7 +57,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S return None for state in reversed(state_order): # Find all data descriptors that will no longer be used after this state - removable_data: Set[str] = set( + removable_data: OrderedSet[str] = OrderedSet( s for s in access_sets if state in access_sets[s] and not (access_sets[s] & reachable[state]) - {state}) # Find duplicate access nodes as an ordered list diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index ccc9aef418..8610c1ec03 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -45,7 +45,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.CFG) def depends_on(self) -> Set[Type[ppl.Pass]]: - return {ap.ControlFlowBlockReachability, ap.AccessSets} + return [ap.ControlFlowBlockReachability, ap.AccessSets] def apply(self, region, pipeline_results): """ diff --git a/dace/transformation/passes/full_map_fusion.py b/dace/transformation/passes/full_map_fusion.py index d6b9e02340..749a726f36 100644 --- a/dace/transformation/passes/full_map_fusion.py +++ b/dace/transformation/passes/full_map_fusion.py @@ -162,7 +162,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Scopes | ppl.Modifies.AccessNodes | ppl.Modifies.Memlets | ppl.Modifies.States) def depends_on(self): - return {ap.FindSingleUseData} + return [ap.FindSingleUseData] def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[int]: """ diff --git a/dace/transformation/passes/lift_struct_views.py b/dace/transformation/passes/lift_struct_views.py index bd82fcc7ac..8c912aa6eb 100644 --- a/dace/transformation/passes/lift_struct_views.py +++ b/dace/transformation/passes/lift_struct_views.py @@ -355,7 +355,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.AccessNodes & ppl.Modifies.Tasklets & ppl.Modifies.Memlets def depends_on(self): - return {} + return [] def _lift_control_flow_region_access(self, cfg: ControlFlowRegion, result: Dict[str, Set[str]]) -> bool: lifted_something = False diff --git a/dace/transformation/passes/loop_local_memory_reduction.py b/dace/transformation/passes/loop_local_memory_reduction.py index b68a2d937a..ee8f57e2c9 100644 --- a/dace/transformation/passes/loop_local_memory_reduction.py +++ b/dace/transformation/passes/loop_local_memory_reduction.py @@ -114,7 +114,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified != ppl.Modifies.Nothing def depends_on(self): - return {StateReachability, FindAccessStates, ConditionUniqueWrites} + return [StateReachability, FindAccessStates, ConditionUniqueWrites] def apply_pass(self, sdfg: sd.SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]: self.num_applications = 0 diff --git a/dace/transformation/passes/pattern_matching.py b/dace/transformation/passes/pattern_matching.py index 9f557527f0..2b72a1ced8 100644 --- a/dace/transformation/passes/pattern_matching.py +++ b/dace/transformation/passes/pattern_matching.py @@ -13,7 +13,7 @@ from dace.sdfg.state import ControlFlowRegion import networkx as nx from networkx.algorithms import isomorphism as iso -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union from dace.sdfg.validation import InvalidSDFGError from dace.transformation import transformation as xf, pass_pipeline as ppl @@ -79,11 +79,8 @@ def __init__(self, self.print_report = print_report self.progress = progress - def depends_on(self) -> Set[Type[ppl.Pass]]: - result = set() - for p in self.transformations: - result.update(p.depends_on()) - return result + def depends_on(self) -> List[Type[ppl.Pass]]: + return list(dict.fromkeys([p.depends_on() for p in self.transformations])) def modifies(self) -> ppl.Modifies: result = ppl.Modifies.Nothing diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index ade8be45f4..60fc07189b 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -26,7 +26,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.AccessNodes def depends_on(self): - return {ap.FindAccessStates, ap.FindReferenceSources} + return [ap.FindAccessStates, ap.FindReferenceSources] def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]: """ diff --git a/dace/transformation/passes/scalar_fission.py b/dace/transformation/passes/scalar_fission.py index 8d88f2752b..76d2324f77 100644 --- a/dace/transformation/passes/scalar_fission.py +++ b/dace/transformation/passes/scalar_fission.py @@ -23,7 +23,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.AccessNodes def depends_on(self): - return {ap.ScalarWriteShadowScopes} + return [ap.ScalarWriteShadowScopes] def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]: """ diff --git a/dace/transformation/passes/split_tasklets.py b/dace/transformation/passes/split_tasklets.py index b77c07504d..dcf008fcba 100644 --- a/dace/transformation/passes/split_tasklets.py +++ b/dace/transformation/passes/split_tasklets.py @@ -105,7 +105,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.Tasklets def depends_on(self): - return {} + return [] tmp_access_identifier = "_split_" diff --git a/dace/transformation/passes/symbol_ssa.py b/dace/transformation/passes/symbol_ssa.py index da0d1cdbb1..ce30981563 100644 --- a/dace/transformation/passes/symbol_ssa.py +++ b/dace/transformation/passes/symbol_ssa.py @@ -23,7 +23,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.Symbols | ppl.Modifies.Edges | ppl.Modifies.Nodes | ppl.Modifies.States def depends_on(self): - return {ap.SymbolWriteScopes} + return [ap.SymbolWriteScopes] def apply(self, region, pipeline_results) -> Optional[Dict[str, Set[str]]]: """ diff --git a/requirements.txt b/requirements.txt index 4ee823a761..eebd1d7c80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ fparser==0.1.4 mpmath==1.3.0 networkx==3.4.2 numpy==1.26.4 +ordered-set==4.1.0 packaging==24.1 ply==3.11 PyYAML==6.0.2 diff --git a/setup.py b/setup.py index 63678f89aa..dcc6aa934a 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ include_package_data=True, install_requires=[ 'numpy', 'networkx >= 2.5, <= 3.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'fparser >= 0.1.3', 'dill', - 'pyreadline;platform_system=="Windows"', 'packaging', 'typing-extensions' + 'ordered-set', 'pyreadline;platform_system=="Windows"', 'packaging', 'typing-extensions' ] + cmake_requires, extras_require={ 'ml': ['onnx', 'torch', 'onnxsim', 'onnxscript', 'onnxruntime', 'protobuf', 'ninja'], diff --git a/tests/passes/pipeline_test.py b/tests/passes/pipeline_test.py index 05ec48d845..602ad708d1 100644 --- a/tests/passes/pipeline_test.py +++ b/tests/passes/pipeline_test.py @@ -42,7 +42,7 @@ def test_pipeline_with_dependencies(): class PassA(MyPass): def depends_on(self): - return {MyPass} + return [MyPass] def apply_pass(self, sdfg, pipeline_results): res = super().apply_pass(sdfg, pipeline_results) @@ -70,7 +70,7 @@ def modifies(self) -> ppl.Modifies: class PassA(MyPass): def depends_on(self): - return {MyAnalysis} + return [MyAnalysis] def modifies(self) -> ppl.Modifies: return ppl.Modifies.Descriptors @@ -78,7 +78,7 @@ def modifies(self) -> ppl.Modifies: class PassB(MyPass): def depends_on(self): - return {MyAnalysis} + return [MyAnalysis] def modifies(self) -> ppl.Modifies: return ppl.Modifies.Symbols @@ -86,7 +86,7 @@ def modifies(self) -> ppl.Modifies: class PassC(MyPass): def depends_on(self): - return {MyAnalysis} + return [MyAnalysis] def modifies(self) -> ppl.Modifies: return ppl.Modifies.Everything