diff --git a/src/ahp_graph/SSTGraph.py b/src/ahp_graph/SSTGraph.py index 96bc588..9b00934 100644 --- a/src/ahp_graph/SSTGraph.py +++ b/src/ahp_graph/SSTGraph.py @@ -98,12 +98,14 @@ def write_json(self, self.__write_model( f"{output}/{filename}", nranks, + rank, program_options) else: (base, ext) = os.path.splitext(f"{output}/{filename}") self.__write_model( base + str(rank) + ext, nranks, + rank, program_options) @staticmethod @@ -213,6 +215,7 @@ def recurseSubcomponents(dev: Device, comp: 'sst.Component') -> None: def __write_model(self, filename: str, nranks: int, + rank: int, program_options: dict = None) -> None: """ Write this DeviceGraph out as JSON. @@ -228,19 +231,21 @@ def __write_model(self, model["program_options"] = dict(program_options) # - # If running in parallel, then set up the SST SELF partitioner. + # If running in parallel, then set up the SST LINEAR partitioner. # if nranks > 1: - model["program_options"]["partitioner"] = "sst.self" + model["program_options"]["partitioner"] = "sst.linear" # # Set up global parameters. # - global_params = self.__encode(self.attr, True) - model["global_params"] = dict() - for (key, val) in global_params.items(): - model["global_params"][key] = dict({key: val}) - global_set = list(global_params.keys()) + # TODO: Revert if needed, but is causing parsing issue with + # SST 15.1+ + # global_params = self.__encode(self.attr, True) + # model["global_params"] = dict() + # for (key, val) in global_params.items(): + # model["global_params"][key] = dict({key: val}) + # global_set = list(global_params.keys()) def recurseSubcomponents(dev: Device) -> list: """Add subcomponents to the Device.""" @@ -257,7 +262,10 @@ def recurseSubcomponents(dev: Device) -> list: "type" : d1.library, "slot_number" : s1, "params" : self.__encode(d1.attr, True), - "params_global_sets" : global_set, + + # TODO: Revert if needed, but is causing parsing issue with + # SST 15.1+ + # "params_global_sets" : global_set, } if d1.subs: item["subcomponents"] = recurseSubcomponents(d1) @@ -276,7 +284,10 @@ def recurseSubcomponents(dev: Device) -> list: "name" : d0.name, "type" : d0.library, "params" : self.__encode(d0.attr, True), - "params_global_sets" : global_set, + + # TODO: Revert if needed, but is causing parsing issue with + # SST 15.1+ + # "params_global_sets" : global_set, } if d0.partition is not None: component["partition"] = { @@ -297,8 +308,6 @@ def recurseSubcomponents(dev: Device) -> list: # links = list() for ((p0,p1),t) in self.links.items(): - #assert p0.device.library is not None - #assert p1.device.library is not None if p0.device.library is None: raise RuntimeError(f"No SST library: {p0.device.name}") if p1.device.library is None: @@ -309,19 +318,57 @@ def recurseSubcomponents(dev: Device) -> list: name = f'{p0}__{t}__{p1}' else: name = f'{p1}__{t}__{p0}' - links.append({ - "name" : name, - "left" : { - "component" : p0.device.name, - "port" : p0.get_name(), - "latency" : latency - }, - "right" : { - "component" : p1.device.name, - "port" : p1.get_name(), - "latency" : latency - }, - }) + + d0 = p0.device + d1 = p1.device + r0, t0 = (d0.partition or (0, None)) + r1, t1 = (d1.partition or (0, None)) + t0 = 0 if t0 is None else t0 + t1 = 0 if t1 is None else t1 + + # Determine if link is inter-rank relative to this file's rank + is_nonlocal = nranks > 1 and r0 != r1 + + if is_nonlocal: + # Ensure left side is the local endpoint for this rank + if r0 == rank: + left = {"component": d0.name, "port": p0.get_name(), "latency": latency} + right = {"rank": r1, "thread": t1} + elif r1 == rank: + left = {"component": d1.name, "port": p1.get_name(), "latency": latency} + right = {"rank": r0, "thread": t0} + else: + # Neither endpoint is local to this rank, skip emitting + # This shouldn't really happen, but added as a safeguard. + continue + + links.append({ + "name": name, + "noCut": False, + "nonlocal": True, + "left": left, + "right": right, + }) + else: + # Local link (same-rank or single-rank output): keep both endpoints + # If multi-rank, only keep links where both endpoints are on this rank + if nranks > 1 and not (r0 == rank and r1 == rank): + continue + links.append({ + "name": name, + "noCut": False, + "nonlocal": False, + "left": { + "component": d0.name, + "port": p0.get_name(), + "latency": latency, + }, + "right": { + "component": d1.name, + "port": p1.get_name(), + "latency": latency, + }, + }) model["links"] = links