diff --git a/docs/gt_frontend/guides/fvm/fvm.rst b/docs/gt_frontend/guides/fvm/fvm.rst index 9374d3b..cb1ce9a 100644 --- a/docs/gt_frontend/guides/fvm/fvm.rst +++ b/docs/gt_frontend/guides/fvm/fvm.rst @@ -70,7 +70,7 @@ Up until now we have just considered a single control volume without actually ta .. figure:: mesh.png :width: 300 :align: center - + Schematic of a 2D mesh At this point different choices for the quantities to be solved for are possible. We will here use a vertex-centered approach where the unknowns are choosen to be the densities at the vertices of the mesh :math:`\rho_i^n = \rho^n_i(x_i)`, which are a first order approximation of the average cell density :math:`\bar \rho_i^n` appearing in the time discretized form above. @@ -80,12 +80,12 @@ At this point different choices for the quantities to be solved for are possible \rho_i^{n+1} &= \rho_i^{n} - \frac{\delta t}{|\mathcal{V}_i|} \int_{\partial {\mathcal{V}_i}} \rho^n \mathbf{v} \cdot \mathbf{n} \mathrm{\,dA} \end{align} -The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes form another mesh, denoted the dual mesh. +The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes form another mesh, denoted the dual mesh. .. figure:: fvm_median_dual_mesh_cv.png :width: 300 :align: center - + Schematic of the median-dual mesh in 2D. Primary mesh in black, dual mesh in blue. The control volume :math:`\mathcal{V}_i` around the vertex :math:`v_i` is constructed by joining the (bary)centers of adjacent cells with the midpoint of the outgoing edges of :math:`v_i`. It remains to derive a discrete representation for the surface integral by first splitting the integral into its contributions on a set of segments :math:`S_j`, where each segment can be attributed to the edges adjacent to :math:`v_i`. Let :math:`|\mathcal{V}_i|` be the area of the control volume and :math:`l(i)` the number of edges adjacent to :math:`v_i` then @@ -129,7 +129,63 @@ The resulting fully discrete time stepping scheme then reads **Implementation in GT4Py** -To be written. +.. code-block:: python + + @gtscript.stencil(externals={"vel": vel}) + def fvm_advect( + mesh: Mesh, + rho: gtscript.Field[Vertex, dtype], + rho_next: gtscript.Field[Vertex, dtype], + volume: gtscript.Field[Vertex, dtype], + dual_normal: gtscript.Field[Edge, dtype], + dual_face_length: gtscript.Field[Edge, dtype], + face_orientation: gtscript.Field[Vertex, Edge, dtype], # either -1 or 1 + #flux: gtscript.Field[Edge, dtype] + ): + with computation(PARALLEL): + gtscript.Field[Edge, dtype] + # compute flux density through the intersection of the two + # control volumes around the dual cells associated with + # the vertices of `e` using an upwind scheme + with location(Edge) as e: + # upwind flux (instructive) + v1, v2 = vertices(e) + normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal + if dot(vel, dual_normal[e]) > 0: + flux = rho[v1] * normal_velocity[e] * dual_face_length[e] + else: + flux = rho[v2] * normal_velocity[e] * dual_face_length[e] + + # upwind flux (compact) + v1, v2 = vertices(e) + normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal + flux = dual_face_area[e]*(max(0., normal_velocity)*rho[v1] + min(0., normal_velocity)*rho[v2]) + + # upwind flux (compact with weights) + flux = dual_face_area[e]*sum(rho, weights=[max(0., normal_velocity), min(0., normal_velocity)]) + + # centered flux (different flux just for comparison here) + flux = 0.5*sum(rho[v]*vel for v in vertices(e)) + with location(Vertex) as v: + # compute density in the next timestep + rho_next = rho - δt/volume*sum(flux*face_orientation[v, e] for e in edges(v)) + + # parameters + vel = [1., 2.] # velocity + δt = 1e-6 # time step + niter = 100 + + # initialize mesh + # ... + + # initialize fields + rho = zeros(mesh, dtype) + rho_next = zeros(mesh, dtype) + # todo: geometry: dual_volume, dual_normal, face_orientation + + for i in range(niter): + fvm_advect(mesh, rho, rho_next, dual_volume, dual_normal, face_orientation, flux) + copyto(rho_next, rho) **TODO** @@ -143,4 +199,4 @@ Frame extension to IFS-FVM **Notes** -Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc \ No newline at end of file +Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc diff --git a/src/gt_frontend/gtscript_ast.py b/src/gt_frontend/gtscript_ast.py index ca90502..9142359 100644 --- a/src/gt_frontend/gtscript_ast.py +++ b/src/gt_frontend/gtscript_ast.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later # todo(tehrengruber): document nodes -from typing import List, Union +from typing import List, Optional, Union import gtc.common as common from eve import Node @@ -110,10 +110,33 @@ class BinaryOp(Expr): right: Expr +class ListNode(Expr): # todo: this node is not valid in every context + elts: List[Expr] + + +class Keyword(GTScriptASTNode): + key: str + value: Expr + + class Call(Expr): args: List[Expr] + keywords: Optional[List[Keyword]] func: str + # todo(tehrengruber: validate each keyword arg occurs only once) + + def get_keyword_args_as_dict(self): + return {arg.key: arg.value for arg in (self.keywords if self.keywords else [])} + + def has_keyword_arg(self, key): + return key in self.get_keyword_args_as_dict() + + def get_keyword_arg(self, key): + if not self.has_keyword_arg(key): + raise ValueError(f"Call to {self.func} has no keyword argument {key}") + return self.get_keyword_args_as_dict()[key] + # TODO(tehrengruber): can be enabled as soon as eve_toolchain#58 lands # class Call(Generic[T]): diff --git a/src/gt_frontend/gtscript_to_gtir.py b/src/gt_frontend/gtscript_to_gtir.py index 6ab8c96..7e5bd59 100644 --- a/src/gt_frontend/gtscript_to_gtir.py +++ b/src/gt_frontend/gtscript_to_gtir.py @@ -32,6 +32,7 @@ Generator, Interval, IterationOrder, + ListNode, LocationComprehension, LocationSpecification, Pass, @@ -341,18 +342,25 @@ def visit_Call(self, node: Call, *, location_stack, **kwargs): neighbors = self.visit( node.args[0].generators[0], **{**kwargs, "location_stack": location_stack} ) - - # operand gets new location stack - new_location_stack = location_stack + [neighbors] - + weights = None + if node.has_keyword_arg("weights"): + weights = node.get_keyword_arg("weights") + if not isinstance(weights, ListNode): + raise ValueError(f"Weights argument to neighbor reduction must be a list") + + weights = list( + self.visit(weight, **{**kwargs, "location_stack": location_stack}) + for weight in weights.elts + ) operand = self.visit( - node.args[0].elt, **{**kwargs, "location_stack": new_location_stack} + node.args[0].elt, **{**kwargs, "location_stack": location_stack + [neighbors]} ) return gtir.NeighborReduce( op=op, operand=operand, neighbors=neighbors, + weights=weights, location_type=location_stack[-1].chain.elements[-1], ) diff --git a/src/gt_frontend/py_to_gtscript.py b/src/gt_frontend/py_to_gtscript.py index 04cad89..a4f8032 100644 --- a/src/gt_frontend/py_to_gtscript.py +++ b/src/gt_frontend/py_to_gtscript.py @@ -63,6 +63,13 @@ def _all_subclasses(typ, *, module=None): # map to symbols in the gtscript ast and are resolved there assert issubclass(typ, enum.Enum) return {typ} + elif typing_inspect.get_origin(typ) == list: + return { + typing.List[sub_cls] + for sub_cls in PyToGTScript._all_subclasses( + typing_inspect.get_args(typ)[0], module=module + ) + } elif typing_inspect.is_union_type(typ): return { sub_cls @@ -133,7 +140,13 @@ class Patterns: BinaryOp = ast.BinOp(op=Capture("op"), left=Capture("left"), right=Capture("right")) - Call = ast.Call(args=Capture("args"), func=ast.Name(id=Capture("func"))) + ListNode = ast.List(elts=Capture("elts")) + + Keyword = ast.keyword(arg=Capture("key"), value=Capture("value")) + + Call = ast.Call( + args=Capture("args"), keywords=Capture("keywords"), func=ast.Name(id=Capture("func")) + ) LocationComprehension = ast.comprehension( target=Capture("target"), iter=Capture("iterator") @@ -170,7 +183,26 @@ def transform(self, node, eligible_node_types=None): if eligible_node_types is None: eligible_node_types = [gtscript_ast.Computation] - if isinstance(node, ast.AST): + if isinstance(node, typing.List): + # extract eligable node types which are lists + eligable_list_node_types = list( + filter( + lambda node_type: typing_inspect.get_origin(node_type) == list, + eligible_node_types, + ) + ) + if len(eligable_list_node_types) == 0: + raise ValueError(f"Expected a list node, but got {type(node)}.") + + eligable_el_node_types = list( + map( + lambda list_node_type: typing_inspect.get_args(list_node_type)[0], + eligable_list_node_types, + ) + ) + + return [self.transform(el, eligable_el_node_types) for el in node] + elif isinstance(node, ast.AST): is_leaf_node = len(list(ast.iter_fields(node))) == 0 if is_leaf_node: if not type(node) in self.leaf_map: @@ -197,24 +229,10 @@ def transform(self, node, eligible_node_types=None): name in node_type.__annotations__ ), f"Invalid capture. No field named `{name}` in `{str(node_type)}`" field_type = node_type.__annotations__[name] - if typing_inspect.get_origin(field_type) == list: - # determine eligible capture types - el_type = typing_inspect.get_args(field_type)[0] - eligible_capture_types = self._all_subclasses(el_type, module=module) - - # transform captures recursively - transformed_captures[name] = [] - for child_capture in capture: - transformed_captures[name].append( - self.transform(child_capture, eligible_capture_types) - ) - else: - # determine eligible capture types - eligible_capture_types = self._all_subclasses(field_type, module=module) - # transform captures recursively - transformed_captures[name] = self.transform( - capture, eligible_capture_types - ) + # determine eligible capture types + eligible_capture_types = self._all_subclasses(field_type, module=module) + # transform captures recursively + transformed_captures[name] = self.transform(capture, eligible_capture_types) return node_type(**transformed_captures) raise ValueError( "Expected a node of type {}".format( diff --git a/src/gtc/unstructured/gtir.py b/src/gtc/unstructured/gtir.py index d9fb1fb..41fc20a 100644 --- a/src/gtc/unstructured/gtir.py +++ b/src/gtc/unstructured/gtir.py @@ -77,6 +77,7 @@ class NeighborReduce(Expr): operand: Expr op: ReduceOperator neighbors: LocationComprehension + weights: Optional[List[Expr]] @root_validator(pre=True) def check_location_type(cls, values): diff --git a/src/gtc/unstructured/gtir_to_nir.py b/src/gtc/unstructured/gtir_to_nir.py index 08b3e54..f400d89 100644 --- a/src/gtc/unstructured/gtir_to_nir.py +++ b/src/gtc/unstructured/gtir_to_nir.py @@ -122,15 +122,35 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg loc_comprehension[node.neighbors.name] = node.neighbors kwargs["location_comprehensions"] = loc_comprehension + neighbor_loop_name = "neighbor_loop_" + str(node.id_attr_) + + if node.weights and node.neighbors.chain.elements != [ + common.LocationType.Edge, + common.LocationType.Vertex, + ]: + raise ValueError("Invalid usage of weights in NeighborReduce.") + body_location = node.neighbors.chain.elements[-1] reduce_var_name = "local" + str(node.id_attr_) last_block.declarations.append( - nir.LocalVar( + nir.ScalarLocalVar( name=reduce_var_name, vtype=common.DataType.FLOAT64, # TODO location_type=node.location_type, ) ) + if node.weights: + weights_var_name = "local_weights_" + str(node.id_attr_) + last_block.declarations.append( + nir.LocalFieldVar( + name=weights_var_name, + vtype=common.DataType.FLOAT64, # TODO + domain=body_location, + init=self.visit(node.weights, **kwargs), + max_size=2, # TODO + location_type=node.location_type, + ) + ) last_block.statements.append( nir.AssignStmt( left=nir.VarAccess(name=reduce_var_name, location_type=node.location_type), @@ -142,17 +162,33 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg location_type=node.location_type, ), ) + reduction_item = self.visit(node.operand, in_neighbor_loop=True, **kwargs) + if node.weights: + reduction_item = nir.BinaryOp( + left=nir.LocalFieldAccess( + name=weights_var_name, + location=nir.NeighborLoopLocationAccess( + name=neighbor_loop_name, location_type=body_location + ), + location_type=body_location, + ), + op=common.BinaryOperator.MUL, + right=reduction_item, + location_type=body_location, + ) + + reduction_intermediate = nir.BinaryOp( + left=nir.VarAccess(name=reduce_var_name, location_type=body_location), + op=self.REDUCE_OP_TO_BINOP[node.op], + right=reduction_item, + location_type=body_location, + ) body = nir.BlockStmt( declarations=[], statements=[ nir.AssignStmt( left=nir.VarAccess(name=reduce_var_name, location_type=body_location), - right=nir.BinaryOp( - left=nir.VarAccess(name=reduce_var_name, location_type=body_location), - op=self.REDUCE_OP_TO_BINOP[node.op], - right=self.visit(node.operand, in_neighbor_loop=True, **kwargs), - location_type=body_location, - ), + right=reduction_intermediate, location_type=body_location, ) ], @@ -160,6 +196,7 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg ) last_block.statements.append( nir.NeighborLoop( + name=neighbor_loop_name, neighbors=self.visit(node.neighbors.chain), body=body, location_type=node.location_type, diff --git a/src/gtc/unstructured/nir.py b/src/gtc/unstructured/nir.py index b0086aa..a958dfe 100644 --- a/src/gtc/unstructured/nir.py +++ b/src/gtc/unstructured/nir.py @@ -64,10 +64,38 @@ def __str__(self): class LocalVar(Node): name: Str - vtype: common.DataType location_type: common.LocationType +class ScalarLocalVar(LocalVar): + vtype: common.DataType + init: Optional[Expr] # TODO: use in gtir to nir lowering for reduction var + + +class TensorLocalVar(LocalVar): + vtype: common.DataType + shape: List[int] + init: Optional[List[Expr]] + + +class LocalFieldVar(TensorLocalVar): + def __init__(self, *args, max_size, **kwargs): + assert "shape" not in kwargs + return super().__init__(*args, shape=[max_size], **kwargs) + + domain: common.LocationType # the type of locations the LocalField is defined on + + @validator("shape", pre=True, always=True) + def ensure_one_dimensional(cls, shape): + if len(shape) != 1: + raise ValueError("Invalid shape for LocalFieldVar.") + return shape + + @property + def max_size(self): + return self.shape[0] + + class BlockStmt(Stmt): declarations: List[LocalVar] statements: List[Stmt] @@ -97,6 +125,7 @@ def check_location_type(cls, values): class NeighborLoop(Stmt): + name: Str neighbors: NeighborChain body: BlockStmt @@ -124,6 +153,22 @@ class VarAccess(Access): pass +# class IndexAccess(Access): # TODO(tehrengruber): use for TensorLocalVar +# indices: List[int] + + +class LocationLocalIdAccess(Access): + pass + + +class NeighborLoopLocationAccess(LocationLocalIdAccess): + pass + + +class LocalFieldAccess(Access): + location: LocationLocalIdAccess + + class AssignStmt(Stmt): left: Access # there are no local variables in gtir, only fields right: Expr diff --git a/src/gtc/unstructured/nir_to_usid.py b/src/gtc/unstructured/nir_to_usid.py index 69af9f8..a353340 100644 --- a/src/gtc/unstructured/nir_to_usid.py +++ b/src/gtc/unstructured/nir_to_usid.py @@ -77,6 +77,7 @@ def visit_Literal(self, node: nir.Literal, **kwargs): def visit_NeighborLoop(self, node: nir.NeighborLoop, **kwargs): return usid.NeighborLoop( + iter_var=node.name + "_neigh", outer_sid=kwargs["sids_tbl"][usid.NeighborChain(elements=[node.location_type])].name, connectivity=kwargs["conn_tbl"][node.neighbors].name, sid=kwargs["sids_tbl"][node.neighbors].name @@ -104,21 +105,40 @@ def visit_AssignStmt(self, node: nir.AssignStmt, **kwargs): location_type=node.location_type, ) + def visit_ScalarLocalVar(self, node: nir.ScalarLocalVar, **kwargs): + return usid.ScalarVarDecl( + name=node.name, + init=usid.Literal(value="0.0", vtype=node.vtype, location_type=kwargs["location_type"]), + vtype=node.vtype, + location_type=kwargs["location_type"], + ) + + def visit_TensorLocalVar(self, node: nir.TensorLocalVar, **kwargs): + assert len(node.shape) == 1, "Only one-dimensional arrays allowed" + return usid.StaticArrayDecl( + name=node.name, + init=[self.visit(init_el, **kwargs) for init_el in node.init] if node.init else None, + vtype=node.vtype, + length=node.shape[0], + location_type=kwargs["location_type"], # TODO: why not from node? + ) + + def visit_LocalFieldAccess(self, node: nir.LocalFieldAccess, **kwargs): + return usid.StaticArrayAccess( + name=node.name, + index=self.visit(node.location, **kwargs), + location_type=node.location_type, + ) + + def visit_NeighborLoopLocationAccess(self, node: nir.NeighborLoopLocationAccess, **kwargs): + return usid.IndexAccess(name=node.name + "_neigh", location_type=node.location_type) + def visit_BlockStmt(self, node: nir.BlockStmt, **kwargs): statements = [] - for decl in node.declarations: - statements.append( - usid.VarDecl( - name=decl.name, - init=usid.Literal( - value="0.0", vtype=decl.vtype, location_type=node.location_type - ), - vtype=decl.vtype, - location_type=node.location_type, - ) - ) - for stmt in node.statements: - statements.append(self.visit(stmt, **kwargs)) + statements += [ + self.visit(decl, location_type=node.location_type) for decl in node.declarations + ] + statements += [self.visit(stmt, **kwargs) for stmt in node.statements] return statements def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs): diff --git a/src/gtc/unstructured/usid.py b/src/gtc/unstructured/usid.py index 4570096..2b61d0b 100644 --- a/src/gtc/unstructured/usid.py +++ b/src/gtc/unstructured/usid.py @@ -56,17 +56,35 @@ def __str__(self): return "_".join([common.LocationType(loc).name.lower() for loc in self.elements]) +class IndexAccess(Expr): + name: Str # TODO(tehrengruber): Maybe IndexDecl in NeighborLoop? + + class FieldAccess(Expr): name: Str # symbol ref to SidCompositeEntry sid: Str # symbol ref +class StaticArrayAccess(Expr): + name: Str + index: IndexAccess # TODO(tehrengruber): Union[Index, IndexAccess] + + class VarDecl(Stmt): name: Str + + +class ScalarVarDecl(VarDecl): init: Expr vtype: common.DataType +class StaticArrayDecl(VarDecl): + init: List[Expr] + vtype: common.DataType + length: int + + class Literal(Expr): value: Union[common.BuiltInLiteral, Str] vtype: common.DataType @@ -227,6 +245,7 @@ class NeighborLoop(Stmt): sid: Optional[ Str ] # symbol ref to SidComposite where the fields of the loop body live (None if only sparse fields are accessed) + iter_var: Str class Kernel(Node): diff --git a/src/gtc/unstructured/usid_codegen.py b/src/gtc/unstructured/usid_codegen.py index d61b181..b14f819 100644 --- a/src/gtc/unstructured/usid_codegen.py +++ b/src/gtc/unstructured/usid_codegen.py @@ -174,6 +174,8 @@ def visit_Kernel(self, node: Kernel, **kwargs): %>*gridtools::device::at_key<${ sid_entry_deref.tag_name }>(${ sid_deref.ptr_name })""" ) + StaticArrayAccess = as_fmt("{ name }[{ index }]") + AssignStmt = as_fmt("{left} = {right};") BinaryOp = as_fmt("({left} {op} {right})") @@ -185,7 +187,7 @@ def visit_Kernel(self, node: Kernel, **kwargs): conn_deref = symbol_tbl_conn[_this_node.connectivity] body_location = _this_generator.LOCATION_TYPE_TO_STR[sid_deref.location.elements[-1]] if sid_deref else None %> - for (int neigh = 0; neigh < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++neigh) { + for (int ${ iter_var } = 0; ${ iter_var } < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++${ iter_var }) { auto absolute_neigh_index = *gridtools::device::at_key<${ conn_deref.neighbor_tbl_tag }>(${ outer_sid_deref.ptr_name}); if (absolute_neigh_index != gridtools::next::connectivity::skip_value(${ conn_deref.name })) { % if sid_deref: @@ -214,10 +216,16 @@ def visit_Kernel(self, node: Kernel, **kwargs): VarAccess = as_fmt("{name}") - VarDecl = as_mako( + IndexAccess = as_fmt("{name}") + + ScalarVarDecl = as_mako( "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };" ) + StaticArrayDecl = as_mako( + "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name }[${ length }] = {${ ', '.join(init) }};" + ) + def visit_Computation(self, node: Computation, **kwargs): symbol_tbl_kernel = {k.name: k for k in node.kernels} sid_tags = set() diff --git a/tests/tests_gtc/nir_utils.py b/tests/tests_gtc/nir_utils.py index 8bb22c8..a021561 100644 --- a/tests/tests_gtc/nir_utils.py +++ b/tests/tests_gtc/nir_utils.py @@ -57,8 +57,8 @@ def make_horizontal_loop_with_copy(write: Str, read: Str, read_has_extent: Bool) ) -def make_local_var(name: Str): - return nir.LocalVar(name=name, vtype=default_vtype, location_type=default_location) +def make_scalar_local_var(name: Str): + return nir.ScalarLocalVar(name=name, vtype=default_vtype, location_type=default_location) def make_init(field: Str): diff --git a/tests/tests_gtc/test_nir_merge_horizontal_loops.py b/tests/tests_gtc/test_nir_merge_horizontal_loops.py index d21c5f0..53fe3af 100644 --- a/tests/tests_gtc/test_nir_merge_horizontal_loops.py +++ b/tests/tests_gtc/test_nir_merge_horizontal_loops.py @@ -31,7 +31,7 @@ make_horizontal_loop_with_copy, make_horizontal_loop_with_init, make_init, - make_local_var, + make_scalar_local_var, make_vertical_loop, ) @@ -168,11 +168,11 @@ def test_merge_empty_loops(self): assert len(result.horizontal_loops) == 1 def test_merge_loops_with_stats_and_decls(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) @@ -187,11 +187,11 @@ def test_merge_loops_with_stats_and_decls(self): # TODO more precise checks def test_find_and_merge(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) @@ -205,11 +205,11 @@ def test_find_and_merge(self): # TODO more precise checks def test_find_and_merge_with_2_vertical_loops(self): - var1 = make_local_var("var1") + var1 = make_scalar_local_var("var1") assignment1, _ = make_init("field1") first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1])) - var2 = make_local_var("var2") + var2 = make_scalar_local_var("var2") assignment2, _ = make_init("field2") second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) diff --git a/tests/tests_gtc/unit_tests/stencil_definitions.py b/tests/tests_gtc/unit_tests/stencil_definitions.py index e633ac4..d229919 100644 --- a/tests/tests_gtc/unit_tests/stencil_definitions.py +++ b/tests/tests_gtc/unit_tests/stencil_definitions.py @@ -34,7 +34,22 @@ dtype = common.DataType.FLOAT64 -valid_stencils = ["edge_reduction", "sparse_ex", "nested", "fvm_nabla", "temporary_field"] +valid_stencils = [ + "edge_reduction", + "sparse_ex", + "nested", + "fvm_nabla", + "temporary_field", + "weighted_neighbor_reduction", +] + + +def weighted_neighbor_reduction( + mesh: Mesh, vertex_field: Field[Vertex, dtype], edge_field: Field[Vertex, dtype] +): + with computation(FORWARD), interval(0, None): + with location(Edge) as e: + edge_field = sum((vertex_field[v] for v in vertices(e)), weights=[1, 2]) def copy(mesh: Mesh, field_in: Field[Vertex, dtype], field_out: Field[Vertex, dtype]):