From d7735b454579bd167360fc01b7ba288d150a8c25 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 14 Oct 2020 20:31:29 +0200 Subject: [PATCH 1/5] Add fvm_advect stencil proposals --- docs/gt_frontend/guides/fvm/fvm.rst | 66 ++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/docs/gt_frontend/guides/fvm/fvm.rst b/docs/gt_frontend/guides/fvm/fvm.rst index 9374d3b..cb1ce9a 100644 --- a/docs/gt_frontend/guides/fvm/fvm.rst +++ b/docs/gt_frontend/guides/fvm/fvm.rst @@ -70,7 +70,7 @@ Up until now we have just considered a single control volume without actually ta .. figure:: mesh.png :width: 300 :align: center - + Schematic of a 2D mesh At this point different choices for the quantities to be solved for are possible. We will here use a vertex-centered approach where the unknowns are choosen to be the densities at the vertices of the mesh :math:`\rho_i^n = \rho^n_i(x_i)`, which are a first order approximation of the average cell density :math:`\bar \rho_i^n` appearing in the time discretized form above. @@ -80,12 +80,12 @@ At this point different choices for the quantities to be solved for are possible \rho_i^{n+1} &= \rho_i^{n} - \frac{\delta t}{|\mathcal{V}_i|} \int_{\partial {\mathcal{V}_i}} \rho^n \mathbf{v} \cdot \mathbf{n} \mathrm{\,dA} \end{align} -The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes form another mesh, denoted the dual mesh. +The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes form another mesh, denoted the dual mesh. .. figure:: fvm_median_dual_mesh_cv.png :width: 300 :align: center - + Schematic of the median-dual mesh in 2D. Primary mesh in black, dual mesh in blue. The control volume :math:`\mathcal{V}_i` around the vertex :math:`v_i` is constructed by joining the (bary)centers of adjacent cells with the midpoint of the outgoing edges of :math:`v_i`. It remains to derive a discrete representation for the surface integral by first splitting the integral into its contributions on a set of segments :math:`S_j`, where each segment can be attributed to the edges adjacent to :math:`v_i`. Let :math:`|\mathcal{V}_i|` be the area of the control volume and :math:`l(i)` the number of edges adjacent to :math:`v_i` then @@ -129,7 +129,63 @@ The resulting fully discrete time stepping scheme then reads **Implementation in GT4Py** -To be written. +.. code-block:: python + + @gtscript.stencil(externals={"vel": vel}) + def fvm_advect( + mesh: Mesh, + rho: gtscript.Field[Vertex, dtype], + rho_next: gtscript.Field[Vertex, dtype], + volume: gtscript.Field[Vertex, dtype], + dual_normal: gtscript.Field[Edge, dtype], + dual_face_length: gtscript.Field[Edge, dtype], + face_orientation: gtscript.Field[Vertex, Edge, dtype], # either -1 or 1 + #flux: gtscript.Field[Edge, dtype] + ): + with computation(PARALLEL): + gtscript.Field[Edge, dtype] + # compute flux density through the intersection of the two + # control volumes around the dual cells associated with + # the vertices of `e` using an upwind scheme + with location(Edge) as e: + # upwind flux (instructive) + v1, v2 = vertices(e) + normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal + if dot(vel, dual_normal[e]) > 0: + flux = rho[v1] * normal_velocity[e] * dual_face_length[e] + else: + flux = rho[v2] * normal_velocity[e] * dual_face_length[e] + + # upwind flux (compact) + v1, v2 = vertices(e) + normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal + flux = dual_face_area[e]*(max(0., normal_velocity)*rho[v1] + min(0., normal_velocity)*rho[v2]) + + # upwind flux (compact with weights) + flux = dual_face_area[e]*sum(rho, weights=[max(0., normal_velocity), min(0., normal_velocity)]) + + # centered flux (different flux just for comparison here) + flux = 0.5*sum(rho[v]*vel for v in vertices(e)) + with location(Vertex) as v: + # compute density in the next timestep + rho_next = rho - δt/volume*sum(flux*face_orientation[v, e] for e in edges(v)) + + # parameters + vel = [1., 2.] # velocity + δt = 1e-6 # time step + niter = 100 + + # initialize mesh + # ... + + # initialize fields + rho = zeros(mesh, dtype) + rho_next = zeros(mesh, dtype) + # todo: geometry: dual_volume, dual_normal, face_orientation + + for i in range(niter): + fvm_advect(mesh, rho, rho_next, dual_volume, dual_normal, face_orientation, flux) + copyto(rho_next, rho) **TODO** @@ -143,4 +199,4 @@ Frame extension to IFS-FVM **Notes** -Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc \ No newline at end of file +Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc From 07468b3bac84cd073c6a3bbcc0395690ed66437c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 16 Oct 2020 13:35:44 +0200 Subject: [PATCH 2/5] Refactored py_to_gtscript transformer to support fields of type Optional[List] --- src/gt_frontend/py_to_gtscript.py | 41 +++++++++++++++++-------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/gt_frontend/py_to_gtscript.py b/src/gt_frontend/py_to_gtscript.py index 04cad89..bef03e1 100644 --- a/src/gt_frontend/py_to_gtscript.py +++ b/src/gt_frontend/py_to_gtscript.py @@ -63,6 +63,8 @@ def _all_subclasses(typ, *, module=None): # map to symbols in the gtscript ast and are resolved there assert issubclass(typ, enum.Enum) return {typ} + elif typing_inspect.get_origin(typ) == list: + return {typing.List[sub_cls] for sub_cls in PyToGTScript._all_subclasses(typing_inspect.get_args(typ)[0], module=module)} elif typing_inspect.is_union_type(typ): return { sub_cls @@ -170,7 +172,20 @@ def transform(self, node, eligible_node_types=None): if eligible_node_types is None: eligible_node_types = [gtscript_ast.Computation] - if isinstance(node, ast.AST): + if isinstance(node, typing.List): + # extract eligable node types which are lists + eligable_list_node_types = list(filter(lambda node_type: typing_inspect.get_origin(node_type) == list, + eligible_node_types)) + if len(eligable_list_node_types) == 0: + raise ValueError( + f"Expected a list node, but got {type(node)}." + ) + + eligable_el_node_types = list(map(lambda list_node_type: typing_inspect.get_args(list_node_type)[0], + eligable_list_node_types)) + + return [self.transform(el, eligable_el_node_types) for el in node] + elif isinstance(node, ast.AST): is_leaf_node = len(list(ast.iter_fields(node))) == 0 if is_leaf_node: if not type(node) in self.leaf_map: @@ -197,24 +212,12 @@ def transform(self, node, eligible_node_types=None): name in node_type.__annotations__ ), f"Invalid capture. No field named `{name}` in `{str(node_type)}`" field_type = node_type.__annotations__[name] - if typing_inspect.get_origin(field_type) == list: - # determine eligible capture types - el_type = typing_inspect.get_args(field_type)[0] - eligible_capture_types = self._all_subclasses(el_type, module=module) - - # transform captures recursively - transformed_captures[name] = [] - for child_capture in capture: - transformed_captures[name].append( - self.transform(child_capture, eligible_capture_types) - ) - else: - # determine eligible capture types - eligible_capture_types = self._all_subclasses(field_type, module=module) - # transform captures recursively - transformed_captures[name] = self.transform( - capture, eligible_capture_types - ) + # determine eligible capture types + eligible_capture_types = self._all_subclasses(field_type, module=module) + # transform captures recursively + transformed_captures[name] = self.transform( + capture, eligible_capture_types + ) return node_type(**transformed_captures) raise ValueError( "Expected a node of type {}".format( From 44a68c11ffa24f6c0264f4bdb71bba368aaf146f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 17 Oct 2020 18:22:21 +0200 Subject: [PATCH 3/5] Introduce weights argument to gtscript_ast and gtir (lowering not functional since missing translation from gtir downwards) --- src/gt_frontend/gtscript_ast.py | 23 +++++++++++++++++++++-- src/gt_frontend/gtscript_to_gtir.py | 13 +++++++++---- src/gt_frontend/py_to_gtscript.py | 6 +++++- src/gtc/unstructured/gtir.py | 1 + 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/gt_frontend/gtscript_ast.py b/src/gt_frontend/gtscript_ast.py index ca90502..1fc6c2f 100644 --- a/src/gt_frontend/gtscript_ast.py +++ b/src/gt_frontend/gtscript_ast.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later # todo(tehrengruber): document nodes -from typing import List, Union +from typing import List, Union, Optional import gtc.common as common from eve import Node @@ -109,11 +109,31 @@ class BinaryOp(Expr): left: Expr right: Expr +class ListNode(Expr): # todo: this node is not valid in every context + elts: List[Expr] + +class Keyword(GTScriptASTNode): + key: str + value: Expr class Call(Expr): args: List[Expr] + keywords: Optional[List[Keyword]] func: str + #todo(tehrengruber: validate each keyword arg occurs only once) + + def get_keyword_args_as_dict(self): + return {arg.key: arg.value for arg in (self.keywords if self.keywords else [])} + + def has_keyword_arg(self, key): + return key in self.get_keyword_args_as_dict() + + def get_keyword_arg(self, key): + if not self.has_keyword_arg(key): + raise ValueError(f"Call to {self.func} has no keyword argument {key}") + return self.get_keyword_args_as_dict()[key] + # TODO(tehrengruber): can be enabled as soon as eve_toolchain#58 lands # class Call(Generic[T]): @@ -132,7 +152,6 @@ class Generator(Expr): generators: List[LocationComprehension] elt: Expr - class Assign(Statement): target: Union[Symbol, SubscriptSingle, SubscriptMultiple] value: Expr diff --git a/src/gt_frontend/gtscript_to_gtir.py b/src/gt_frontend/gtscript_to_gtir.py index 6ab8c96..7ef5a77 100644 --- a/src/gt_frontend/gtscript_to_gtir.py +++ b/src/gt_frontend/gtscript_to_gtir.py @@ -32,6 +32,7 @@ Generator, Interval, IterationOrder, + ListNode, LocationComprehension, LocationSpecification, Pass, @@ -341,18 +342,22 @@ def visit_Call(self, node: Call, *, location_stack, **kwargs): neighbors = self.visit( node.args[0].generators[0], **{**kwargs, "location_stack": location_stack} ) + weights = None + if node.has_keyword_arg("weights"): + weights = node.get_keyword_arg("weights") + if not isinstance(weights, ListNode): + raise ValueError(f"Weights argument to neighbor reduction must be a list") - # operand gets new location stack - new_location_stack = location_stack + [neighbors] - + weights = list(self.visit(weight, **{**kwargs, "location_stack": location_stack}) for weight in weights.elts) operand = self.visit( - node.args[0].elt, **{**kwargs, "location_stack": new_location_stack} + node.args[0].elt, **{**kwargs, "location_stack": location_stack + [neighbors]} ) return gtir.NeighborReduce( op=op, operand=operand, neighbors=neighbors, + weights=weights, location_type=location_stack[-1].chain.elements[-1], ) diff --git a/src/gt_frontend/py_to_gtscript.py b/src/gt_frontend/py_to_gtscript.py index bef03e1..9f17e65 100644 --- a/src/gt_frontend/py_to_gtscript.py +++ b/src/gt_frontend/py_to_gtscript.py @@ -135,7 +135,11 @@ class Patterns: BinaryOp = ast.BinOp(op=Capture("op"), left=Capture("left"), right=Capture("right")) - Call = ast.Call(args=Capture("args"), func=ast.Name(id=Capture("func"))) + ListNode = ast.List(elts=Capture("elts")) + + Keyword = ast.keyword(arg=Capture("key"), value=Capture("value")) + + Call = ast.Call(args=Capture("args"), keywords=Capture("keywords"), func=ast.Name(id=Capture("func"))) LocationComprehension = ast.comprehension( target=Capture("target"), iter=Capture("iterator") diff --git a/src/gtc/unstructured/gtir.py b/src/gtc/unstructured/gtir.py index d9fb1fb..41fc20a 100644 --- a/src/gtc/unstructured/gtir.py +++ b/src/gtc/unstructured/gtir.py @@ -77,6 +77,7 @@ class NeighborReduce(Expr): operand: Expr op: ReduceOperator neighbors: LocationComprehension + weights: Optional[List[Expr]] @root_validator(pre=True) def check_location_type(cls, values): From 7a000dd72b3579a53e86eec4d307c50a737c4efe Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 21 Oct 2020 13:00:28 +0200 Subject: [PATCH 4/5] Adopted nir and usid to support weights argument to neighbor reduction (should be functional, but untested) --- src/gtc/unstructured/gtir.py | 1 - src/gtc/unstructured/gtir_to_nir.py | 40 +++++++++++++--- src/gtc/unstructured/nir.py | 42 ++++++++++++++++- src/gtc/unstructured/nir_to_usid.py | 46 +++++++++++++------ src/gtc/unstructured/usid.py | 18 ++++++++ src/gtc/unstructured/usid_codegen.py | 13 +++++- tests/tests_gtc/nir_utils.py | 4 +- .../test_nir_merge_horizontal_loops.py | 14 +++--- .../unit_tests/stencil_definitions.py | 7 ++- 9 files changed, 151 insertions(+), 34 deletions(-) diff --git a/src/gtc/unstructured/gtir.py b/src/gtc/unstructured/gtir.py index 41fc20a..b7c85db 100644 --- a/src/gtc/unstructured/gtir.py +++ b/src/gtc/unstructured/gtir.py @@ -125,7 +125,6 @@ def check_location_type(cls, values): return values - class VerticalDimension(Node): pass diff --git a/src/gtc/unstructured/gtir_to_nir.py b/src/gtc/unstructured/gtir_to_nir.py index 08b3e54..3cfaf59 100644 --- a/src/gtc/unstructured/gtir_to_nir.py +++ b/src/gtc/unstructured/gtir_to_nir.py @@ -122,15 +122,32 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg loc_comprehension[node.neighbors.name] = node.neighbors kwargs["location_comprehensions"] = loc_comprehension + neighbor_loop_name = "neighbor_loop_"+str(node.id_attr_) + + if node.weights and node.neighbors.chain.elements != [common.LocationType.Edge, common.LocationType.Vertex]: + raise ValueError("Invalid usage of weights in NeighborReduce.") + body_location = node.neighbors.chain.elements[-1] reduce_var_name = "local" + str(node.id_attr_) last_block.declarations.append( - nir.LocalVar( + nir.ScalarLocalVar( name=reduce_var_name, vtype=common.DataType.FLOAT64, # TODO location_type=node.location_type, ) ) + if node.weights: + weights_var_name = "local_weights_" + str(node.id_attr_) + last_block.declarations.append( + nir.LocalFieldVar( + name=weights_var_name, + vtype=common.DataType.FLOAT64, # TODO + domain=body_location, + init=self.visit(node.weights, **kwargs), + max_size=2, # TODO + location_type=node.location_type, + ) + ) last_block.statements.append( nir.AssignStmt( left=nir.VarAccess(name=reduce_var_name, location_type=node.location_type), @@ -142,17 +159,25 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg location_type=node.location_type, ), ) + reduction_item = self.visit(node.operand, in_neighbor_loop=True, **kwargs) + if node.weights: + reduction_item = nir.BinaryOp( + left=nir.LocalFieldAccess(name=weights_var_name, location=nir.NeighborLoopLocationAccess( + name=neighbor_loop_name, location_type=body_location), location_type=body_location), + op=common.BinaryOperator.MUL, right=reduction_item, location_type=body_location) + + reduction_intermediate = nir.BinaryOp( + left=nir.VarAccess(name=reduce_var_name, location_type=body_location), + op=self.REDUCE_OP_TO_BINOP[node.op], + right=reduction_item, + location_type=body_location, + ) body = nir.BlockStmt( declarations=[], statements=[ nir.AssignStmt( left=nir.VarAccess(name=reduce_var_name, location_type=body_location), - right=nir.BinaryOp( - left=nir.VarAccess(name=reduce_var_name, location_type=body_location), - op=self.REDUCE_OP_TO_BINOP[node.op], - right=self.visit(node.operand, in_neighbor_loop=True, **kwargs), - location_type=body_location, - ), + right=reduction_intermediate, location_type=body_location, ) ], @@ -160,6 +185,7 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg ) last_block.statements.append( nir.NeighborLoop( + name=neighbor_loop_name, neighbors=self.visit(node.neighbors.chain), body=body, location_type=node.location_type, diff --git a/src/gtc/unstructured/nir.py b/src/gtc/unstructured/nir.py index b0086aa..47e98d5 100644 --- a/src/gtc/unstructured/nir.py +++ b/src/gtc/unstructured/nir.py @@ -64,10 +64,38 @@ def __str__(self): class LocalVar(Node): name: Str - vtype: common.DataType location_type: common.LocationType +class ScalarLocalVar(LocalVar): + vtype: common.DataType + init: Optional[Expr] # TODO: use in gtir to nir lowering for reduction var + + +class TensorLocalVar(LocalVar): + vtype: common.DataType + shape: List[int] + init: Optional[List[Expr]] + + +class LocalFieldVar(TensorLocalVar): + def __init__(self, *args, max_size, **kwargs): + assert "shape" not in kwargs + return super().__init__(*args, shape=[max_size], **kwargs) + + domain: common.LocationType # the type of locations the LocalField is defined on + + @validator('shape', pre=True, always=True) + def ensure_one_dimensional(cls, shape): + if len(shape) != 1: + raise ValueError("Invalid shape for LocalFieldVar.") + return shape + + @property + def max_size(self): + return self.shape[0] + + class BlockStmt(Stmt): declarations: List[LocalVar] statements: List[Stmt] @@ -97,6 +125,7 @@ def check_location_type(cls, values): class NeighborLoop(Stmt): + name: Str neighbors: NeighborChain body: BlockStmt @@ -123,6 +152,17 @@ def extent(self): class VarAccess(Access): pass +#class IndexAccess(Access): # TODO(tehrengruber): use for TensorLocalVar +# indices: List[int] + +class LocationLocalIdAccess(Access): + pass + +class NeighborLoopLocationAccess(LocationLocalIdAccess): + pass + +class LocalFieldAccess(Access): + location: LocationLocalIdAccess class AssignStmt(Stmt): left: Access # there are no local variables in gtir, only fields diff --git a/src/gtc/unstructured/nir_to_usid.py b/src/gtc/unstructured/nir_to_usid.py index 69af9f8..c94ef5c 100644 --- a/src/gtc/unstructured/nir_to_usid.py +++ b/src/gtc/unstructured/nir_to_usid.py @@ -77,6 +77,7 @@ def visit_Literal(self, node: nir.Literal, **kwargs): def visit_NeighborLoop(self, node: nir.NeighborLoop, **kwargs): return usid.NeighborLoop( + iter_var=node.name+"_neigh", outer_sid=kwargs["sids_tbl"][usid.NeighborChain(elements=[node.location_type])].name, connectivity=kwargs["conn_tbl"][node.neighbors].name, sid=kwargs["sids_tbl"][node.neighbors].name @@ -104,21 +105,40 @@ def visit_AssignStmt(self, node: nir.AssignStmt, **kwargs): location_type=node.location_type, ) + def visit_ScalarLocalVar(self, node: nir.ScalarLocalVar, **kwargs): + return usid.ScalarVarDecl( + name=node.name, + init=usid.Literal( + value="0.0", vtype=node.vtype, location_type=kwargs["location_type"] + ), + vtype=node.vtype, + location_type=kwargs["location_type"], + ) + + def visit_TensorLocalVar(self, node: nir.TensorLocalVar, **kwargs): + assert len(node.shape) == 1, "Only one-dimensional arrays allowed" + return usid.StaticArrayDecl( + name=node.name, + init=[self.visit(init_el, **kwargs) for init_el in node.init] if node.init else None, + vtype=node.vtype, + length=node.shape[0], + location_type=kwargs["location_type"], # TODO: why not from node? + ) + + def visit_LocalFieldAccess(self, node: nir.LocalFieldAccess, **kwargs): + return usid.StaticArrayAccess( + name=node.name, + index=self.visit(node.location, **kwargs), + location_type=node.location_type + ) + + def visit_NeighborLoopLocationAccess(self, node: nir.NeighborLoopLocationAccess, **kwargs): + return usid.IndexAccess(name=node.name+"_neigh", location_type=node.location_type) + def visit_BlockStmt(self, node: nir.BlockStmt, **kwargs): statements = [] - for decl in node.declarations: - statements.append( - usid.VarDecl( - name=decl.name, - init=usid.Literal( - value="0.0", vtype=decl.vtype, location_type=node.location_type - ), - vtype=decl.vtype, - location_type=node.location_type, - ) - ) - for stmt in node.statements: - statements.append(self.visit(stmt, **kwargs)) + statements += [self.visit(decl, location_type=node.location_type) for decl in node.declarations] + statements += [self.visit(stmt, **kwargs) for stmt in node.statements] return statements def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs): diff --git a/src/gtc/unstructured/usid.py b/src/gtc/unstructured/usid.py index 4570096..2969eda 100644 --- a/src/gtc/unstructured/usid.py +++ b/src/gtc/unstructured/usid.py @@ -56,17 +56,34 @@ def __str__(self): return "_".join([common.LocationType(loc).name.lower() for loc in self.elements]) +class IndexAccess(Expr): + name: Str # TODO(tehrengruber): Maybe IndexDecl in NeighborLoop? + + class FieldAccess(Expr): name: Str # symbol ref to SidCompositeEntry sid: Str # symbol ref +class StaticArrayAccess(Expr): + name: Str + index: IndexAccess # TODO(tehrengruber): Union[Index, IndexAccess] + + class VarDecl(Stmt): name: Str + + +class ScalarVarDecl(VarDecl): init: Expr vtype: common.DataType +class StaticArrayDecl(VarDecl): + init: List[Expr] + vtype: common.DataType + length: int + class Literal(Expr): value: Union[common.BuiltInLiteral, Str] vtype: common.DataType @@ -227,6 +244,7 @@ class NeighborLoop(Stmt): sid: Optional[ Str ] # symbol ref to SidComposite where the fields of the loop body live (None if only sparse fields are accessed) + iter_var: Str class Kernel(Node): diff --git a/src/gtc/unstructured/usid_codegen.py b/src/gtc/unstructured/usid_codegen.py index d61b181..30941e6 100644 --- a/src/gtc/unstructured/usid_codegen.py +++ b/src/gtc/unstructured/usid_codegen.py @@ -29,6 +29,7 @@ Kernel, KernelCall, SidCompositeNeighborTableEntry, + StaticArrayAccess, Temporary, ) @@ -174,6 +175,8 @@ def visit_Kernel(self, node: Kernel, **kwargs): %>*gridtools::device::at_key<${ sid_entry_deref.tag_name }>(${ sid_deref.ptr_name })""" ) + StaticArrayAccess = as_fmt("{ name }[{ index }]") + AssignStmt = as_fmt("{left} = {right};") BinaryOp = as_fmt("({left} {op} {right})") @@ -185,7 +188,7 @@ def visit_Kernel(self, node: Kernel, **kwargs): conn_deref = symbol_tbl_conn[_this_node.connectivity] body_location = _this_generator.LOCATION_TYPE_TO_STR[sid_deref.location.elements[-1]] if sid_deref else None %> - for (int neigh = 0; neigh < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++neigh) { + for (int ${ iter_var } = 0; ${ iter_var } < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++${ iter_var }) { auto absolute_neigh_index = *gridtools::device::at_key<${ conn_deref.neighbor_tbl_tag }>(${ outer_sid_deref.ptr_name}); if (absolute_neigh_index != gridtools::next::connectivity::skip_value(${ conn_deref.name })) { % if sid_deref: @@ -214,10 +217,16 @@ def visit_Kernel(self, node: Kernel, **kwargs): VarAccess = as_fmt("{name}") - VarDecl = as_mako( + IndexAccess = as_fmt("{name}") + + ScalarVarDecl = as_mako( "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };" ) + StaticArrayDecl = as_mako( + "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name }[${ length }] = {${ ', '.join(init) }};" + ) + def visit_Computation(self, node: Computation, **kwargs): symbol_tbl_kernel = {k.name: k for k in node.kernels} sid_tags = set() diff --git a/tests/tests_gtc/nir_utils.py b/tests/tests_gtc/nir_utils.py index 8bb22c8..a021561 100644 --- a/tests/tests_gtc/nir_utils.py +++ b/tests/tests_gtc/nir_utils.py @@ -57,8 +57,8 @@ def make_horizontal_loop_with_copy(write: Str, read: Str, read_has_extent: Bool) ) -def make_local_var(name: Str): - return nir.LocalVar(name=name, vtype=default_vtype, location_type=default_location) +def make_scalar_local_var(name: Str): + return nir.ScalarLocalVar(name=name, vtype=default_vtype, location_type=default_location) def make_init(field: Str): diff --git a/tests/tests_gtc/test_nir_merge_horizontal_loops.py b/tests/tests_gtc/test_nir_merge_horizontal_loops.py index d21c5f0..53fe3af 100644 --- a/tests/tests_gtc/test_nir_merge_horizontal_loops.py +++ b/tests/tests_gtc/test_nir_merge_horizontal_loops.py @@ -31,7 +31,7 @@ make_horizontal_loop_with_copy, make_horizontal_loop_with_init, make_init, - make_local_var, + make_scalar_local_var, make_vertical_loop, ) @@ -168,11 +168,11 @@ def test_merge_empty_loops(self): assert len(result.horizontal_loops) == 1 def test_merge_loops_with_stats_and_decls(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) @@ -187,11 +187,11 @@ def test_merge_loops_with_stats_and_decls(self): # TODO more precise checks def test_find_and_merge(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) @@ -205,11 +205,11 @@ def test_find_and_merge(self): # TODO more precise checks def test_find_and_merge_with_2_vertical_loops(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) diff --git a/tests/tests_gtc/unit_tests/stencil_definitions.py b/tests/tests_gtc/unit_tests/stencil_definitions.py index e633ac4..a424d94 100644 --- a/tests/tests_gtc/unit_tests/stencil_definitions.py +++ b/tests/tests_gtc/unit_tests/stencil_definitions.py @@ -34,7 +34,12 @@ dtype = common.DataType.FLOAT64 -valid_stencils = ["edge_reduction", "sparse_ex", "nested", "fvm_nabla", "temporary_field"] +valid_stencils = ["edge_reduction", "sparse_ex", "nested", "fvm_nabla", "temporary_field", "weighted_neighbor_reduction"] + +def weighted_neighbor_reduction(mesh: Mesh, vertex_field : Field[Vertex, dtype], edge_field : Field[Vertex, dtype]): + with computation(FORWARD), interval(0, None): + with location(Edge) as e: + edge_field = sum((vertex_field[v] for v in vertices(e)), weights=[1, 2]) def copy(mesh: Mesh, field_in: Field[Vertex, dtype], field_out: Field[Vertex, dtype]): From 6b69107944b83fd73d5a7b9108fa7820faf28e19 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 22 Oct 2020 00:14:47 +0200 Subject: [PATCH 5/5] Reformat --- src/gt_frontend/gtscript_ast.py | 10 ++++-- src/gt_frontend/gtscript_to_gtir.py | 5 ++- src/gt_frontend/py_to_gtscript.py | 35 ++++++++++++------- src/gtc/unstructured/gtir.py | 1 + src/gtc/unstructured/gtir_to_nir.py | 23 ++++++++---- src/gtc/unstructured/nir.py | 11 ++++-- src/gtc/unstructured/nir_to_usid.py | 14 ++++---- src/gtc/unstructured/usid.py | 3 +- src/gtc/unstructured/usid_codegen.py | 1 - .../unit_tests/stencil_definitions.py | 16 +++++++-- 10 files changed, 82 insertions(+), 37 deletions(-) diff --git a/src/gt_frontend/gtscript_ast.py b/src/gt_frontend/gtscript_ast.py index 1fc6c2f..9142359 100644 --- a/src/gt_frontend/gtscript_ast.py +++ b/src/gt_frontend/gtscript_ast.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later # todo(tehrengruber): document nodes -from typing import List, Union, Optional +from typing import List, Optional, Union import gtc.common as common from eve import Node @@ -109,19 +109,22 @@ class BinaryOp(Expr): left: Expr right: Expr -class ListNode(Expr): # todo: this node is not valid in every context + +class ListNode(Expr): # todo: this node is not valid in every context elts: List[Expr] + class Keyword(GTScriptASTNode): key: str value: Expr + class Call(Expr): args: List[Expr] keywords: Optional[List[Keyword]] func: str - #todo(tehrengruber: validate each keyword arg occurs only once) + # todo(tehrengruber: validate each keyword arg occurs only once) def get_keyword_args_as_dict(self): return {arg.key: arg.value for arg in (self.keywords if self.keywords else [])} @@ -152,6 +155,7 @@ class Generator(Expr): generators: List[LocationComprehension] elt: Expr + class Assign(Statement): target: Union[Symbol, SubscriptSingle, SubscriptMultiple] value: Expr diff --git a/src/gt_frontend/gtscript_to_gtir.py b/src/gt_frontend/gtscript_to_gtir.py index 7ef5a77..7e5bd59 100644 --- a/src/gt_frontend/gtscript_to_gtir.py +++ b/src/gt_frontend/gtscript_to_gtir.py @@ -348,7 +348,10 @@ def visit_Call(self, node: Call, *, location_stack, **kwargs): if not isinstance(weights, ListNode): raise ValueError(f"Weights argument to neighbor reduction must be a list") - weights = list(self.visit(weight, **{**kwargs, "location_stack": location_stack}) for weight in weights.elts) + weights = list( + self.visit(weight, **{**kwargs, "location_stack": location_stack}) + for weight in weights.elts + ) operand = self.visit( node.args[0].elt, **{**kwargs, "location_stack": location_stack + [neighbors]} ) diff --git a/src/gt_frontend/py_to_gtscript.py b/src/gt_frontend/py_to_gtscript.py index 9f17e65..a4f8032 100644 --- a/src/gt_frontend/py_to_gtscript.py +++ b/src/gt_frontend/py_to_gtscript.py @@ -64,7 +64,12 @@ def _all_subclasses(typ, *, module=None): assert issubclass(typ, enum.Enum) return {typ} elif typing_inspect.get_origin(typ) == list: - return {typing.List[sub_cls] for sub_cls in PyToGTScript._all_subclasses(typing_inspect.get_args(typ)[0], module=module)} + return { + typing.List[sub_cls] + for sub_cls in PyToGTScript._all_subclasses( + typing_inspect.get_args(typ)[0], module=module + ) + } elif typing_inspect.is_union_type(typ): return { sub_cls @@ -139,7 +144,9 @@ class Patterns: Keyword = ast.keyword(arg=Capture("key"), value=Capture("value")) - Call = ast.Call(args=Capture("args"), keywords=Capture("keywords"), func=ast.Name(id=Capture("func"))) + Call = ast.Call( + args=Capture("args"), keywords=Capture("keywords"), func=ast.Name(id=Capture("func")) + ) LocationComprehension = ast.comprehension( target=Capture("target"), iter=Capture("iterator") @@ -178,15 +185,21 @@ def transform(self, node, eligible_node_types=None): if isinstance(node, typing.List): # extract eligable node types which are lists - eligable_list_node_types = list(filter(lambda node_type: typing_inspect.get_origin(node_type) == list, - eligible_node_types)) - if len(eligable_list_node_types) == 0: - raise ValueError( - f"Expected a list node, but got {type(node)}." + eligable_list_node_types = list( + filter( + lambda node_type: typing_inspect.get_origin(node_type) == list, + eligible_node_types, ) + ) + if len(eligable_list_node_types) == 0: + raise ValueError(f"Expected a list node, but got {type(node)}.") - eligable_el_node_types = list(map(lambda list_node_type: typing_inspect.get_args(list_node_type)[0], - eligable_list_node_types)) + eligable_el_node_types = list( + map( + lambda list_node_type: typing_inspect.get_args(list_node_type)[0], + eligable_list_node_types, + ) + ) return [self.transform(el, eligable_el_node_types) for el in node] elif isinstance(node, ast.AST): @@ -219,9 +232,7 @@ def transform(self, node, eligible_node_types=None): # determine eligible capture types eligible_capture_types = self._all_subclasses(field_type, module=module) # transform captures recursively - transformed_captures[name] = self.transform( - capture, eligible_capture_types - ) + transformed_captures[name] = self.transform(capture, eligible_capture_types) return node_type(**transformed_captures) raise ValueError( "Expected a node of type {}".format( diff --git a/src/gtc/unstructured/gtir.py b/src/gtc/unstructured/gtir.py index b7c85db..41fc20a 100644 --- a/src/gtc/unstructured/gtir.py +++ b/src/gtc/unstructured/gtir.py @@ -125,6 +125,7 @@ def check_location_type(cls, values): return values + class VerticalDimension(Node): pass diff --git a/src/gtc/unstructured/gtir_to_nir.py b/src/gtc/unstructured/gtir_to_nir.py index 3cfaf59..f400d89 100644 --- a/src/gtc/unstructured/gtir_to_nir.py +++ b/src/gtc/unstructured/gtir_to_nir.py @@ -122,10 +122,13 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg loc_comprehension[node.neighbors.name] = node.neighbors kwargs["location_comprehensions"] = loc_comprehension - neighbor_loop_name = "neighbor_loop_"+str(node.id_attr_) + neighbor_loop_name = "neighbor_loop_" + str(node.id_attr_) - if node.weights and node.neighbors.chain.elements != [common.LocationType.Edge, common.LocationType.Vertex]: - raise ValueError("Invalid usage of weights in NeighborReduce.") + if node.weights and node.neighbors.chain.elements != [ + common.LocationType.Edge, + common.LocationType.Vertex, + ]: + raise ValueError("Invalid usage of weights in NeighborReduce.") body_location = node.neighbors.chain.elements[-1] reduce_var_name = "local" + str(node.id_attr_) @@ -162,9 +165,17 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg reduction_item = self.visit(node.operand, in_neighbor_loop=True, **kwargs) if node.weights: reduction_item = nir.BinaryOp( - left=nir.LocalFieldAccess(name=weights_var_name, location=nir.NeighborLoopLocationAccess( - name=neighbor_loop_name, location_type=body_location), location_type=body_location), - op=common.BinaryOperator.MUL, right=reduction_item, location_type=body_location) + left=nir.LocalFieldAccess( + name=weights_var_name, + location=nir.NeighborLoopLocationAccess( + name=neighbor_loop_name, location_type=body_location + ), + location_type=body_location, + ), + op=common.BinaryOperator.MUL, + right=reduction_item, + location_type=body_location, + ) reduction_intermediate = nir.BinaryOp( left=nir.VarAccess(name=reduce_var_name, location_type=body_location), diff --git a/src/gtc/unstructured/nir.py b/src/gtc/unstructured/nir.py index 47e98d5..a958dfe 100644 --- a/src/gtc/unstructured/nir.py +++ b/src/gtc/unstructured/nir.py @@ -69,7 +69,7 @@ class LocalVar(Node): class ScalarLocalVar(LocalVar): vtype: common.DataType - init: Optional[Expr] # TODO: use in gtir to nir lowering for reduction var + init: Optional[Expr] # TODO: use in gtir to nir lowering for reduction var class TensorLocalVar(LocalVar): @@ -85,7 +85,7 @@ def __init__(self, *args, max_size, **kwargs): domain: common.LocationType # the type of locations the LocalField is defined on - @validator('shape', pre=True, always=True) + @validator("shape", pre=True, always=True) def ensure_one_dimensional(cls, shape): if len(shape) != 1: raise ValueError("Invalid shape for LocalFieldVar.") @@ -152,18 +152,23 @@ def extent(self): class VarAccess(Access): pass -#class IndexAccess(Access): # TODO(tehrengruber): use for TensorLocalVar + +# class IndexAccess(Access): # TODO(tehrengruber): use for TensorLocalVar # indices: List[int] + class LocationLocalIdAccess(Access): pass + class NeighborLoopLocationAccess(LocationLocalIdAccess): pass + class LocalFieldAccess(Access): location: LocationLocalIdAccess + class AssignStmt(Stmt): left: Access # there are no local variables in gtir, only fields right: Expr diff --git a/src/gtc/unstructured/nir_to_usid.py b/src/gtc/unstructured/nir_to_usid.py index c94ef5c..a353340 100644 --- a/src/gtc/unstructured/nir_to_usid.py +++ b/src/gtc/unstructured/nir_to_usid.py @@ -77,7 +77,7 @@ def visit_Literal(self, node: nir.Literal, **kwargs): def visit_NeighborLoop(self, node: nir.NeighborLoop, **kwargs): return usid.NeighborLoop( - iter_var=node.name+"_neigh", + iter_var=node.name + "_neigh", outer_sid=kwargs["sids_tbl"][usid.NeighborChain(elements=[node.location_type])].name, connectivity=kwargs["conn_tbl"][node.neighbors].name, sid=kwargs["sids_tbl"][node.neighbors].name @@ -108,9 +108,7 @@ def visit_AssignStmt(self, node: nir.AssignStmt, **kwargs): def visit_ScalarLocalVar(self, node: nir.ScalarLocalVar, **kwargs): return usid.ScalarVarDecl( name=node.name, - init=usid.Literal( - value="0.0", vtype=node.vtype, location_type=kwargs["location_type"] - ), + init=usid.Literal(value="0.0", vtype=node.vtype, location_type=kwargs["location_type"]), vtype=node.vtype, location_type=kwargs["location_type"], ) @@ -129,15 +127,17 @@ def visit_LocalFieldAccess(self, node: nir.LocalFieldAccess, **kwargs): return usid.StaticArrayAccess( name=node.name, index=self.visit(node.location, **kwargs), - location_type=node.location_type + location_type=node.location_type, ) def visit_NeighborLoopLocationAccess(self, node: nir.NeighborLoopLocationAccess, **kwargs): - return usid.IndexAccess(name=node.name+"_neigh", location_type=node.location_type) + return usid.IndexAccess(name=node.name + "_neigh", location_type=node.location_type) def visit_BlockStmt(self, node: nir.BlockStmt, **kwargs): statements = [] - statements += [self.visit(decl, location_type=node.location_type) for decl in node.declarations] + statements += [ + self.visit(decl, location_type=node.location_type) for decl in node.declarations + ] statements += [self.visit(stmt, **kwargs) for stmt in node.statements] return statements diff --git a/src/gtc/unstructured/usid.py b/src/gtc/unstructured/usid.py index 2969eda..2b61d0b 100644 --- a/src/gtc/unstructured/usid.py +++ b/src/gtc/unstructured/usid.py @@ -57,7 +57,7 @@ def __str__(self): class IndexAccess(Expr): - name: Str # TODO(tehrengruber): Maybe IndexDecl in NeighborLoop? + name: Str # TODO(tehrengruber): Maybe IndexDecl in NeighborLoop? class FieldAccess(Expr): @@ -84,6 +84,7 @@ class StaticArrayDecl(VarDecl): vtype: common.DataType length: int + class Literal(Expr): value: Union[common.BuiltInLiteral, Str] vtype: common.DataType diff --git a/src/gtc/unstructured/usid_codegen.py b/src/gtc/unstructured/usid_codegen.py index 30941e6..b14f819 100644 --- a/src/gtc/unstructured/usid_codegen.py +++ b/src/gtc/unstructured/usid_codegen.py @@ -29,7 +29,6 @@ Kernel, KernelCall, SidCompositeNeighborTableEntry, - StaticArrayAccess, Temporary, ) diff --git a/tests/tests_gtc/unit_tests/stencil_definitions.py b/tests/tests_gtc/unit_tests/stencil_definitions.py index a424d94..d229919 100644 --- a/tests/tests_gtc/unit_tests/stencil_definitions.py +++ b/tests/tests_gtc/unit_tests/stencil_definitions.py @@ -34,9 +34,19 @@ dtype = common.DataType.FLOAT64 -valid_stencils = ["edge_reduction", "sparse_ex", "nested", "fvm_nabla", "temporary_field", "weighted_neighbor_reduction"] - -def weighted_neighbor_reduction(mesh: Mesh, vertex_field : Field[Vertex, dtype], edge_field : Field[Vertex, dtype]): +valid_stencils = [ + "edge_reduction", + "sparse_ex", + "nested", + "fvm_nabla", + "temporary_field", + "weighted_neighbor_reduction", +] + + +def weighted_neighbor_reduction( + mesh: Mesh, vertex_field: Field[Vertex, dtype], edge_field: Field[Vertex, dtype] +): with computation(FORWARD), interval(0, None): with location(Edge) as e: edge_field = sum((vertex_field[v] for v in vertices(e)), weights=[1, 2])