From 00f9a1c278471d2699062cb470992ac611773df8 Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Tue, 9 Jun 2026 13:01:38 -0500 Subject: [PATCH] Support shapefile / GeoJSON meshes via geopandas; drop times from README router MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds GIS vector formats (.shp, .geojson) as first-class mesh inputs alongside the existing UGRID / MPAS / SCRIP / NetCDF / HEALPix paths. Loaders: - domain/mesh.py: new load_dataset() companion to load_grid(). Both branch on extension: .shp / .geojson go through ux.Grid.from_file(..., backend="geopandas"), HEALPix specs continue to use ux.Grid.from_healpix(), everything else falls through to ux.open_grid / ux.open_dataset. - domain/__init__.py: re-export load_dataset. Tools — route every local file open through the new loaders so the GIS path applies uniformly (no behaviour change for existing formats): - tools/inspection.py, tools/plotting.py, tools/capabilities.py, tools/advanced.py, tools/vector_calc.py: ux.open_dataset / ux.open_grid → load_dataset / load_grid. Remote — Globus Compute serializes each remote_* function body and ships it to the worker, so closures over a module-level helper aren't reliable across SDK versions. Each remote_* function inlines the same ~6 lines of extension dispatch. A NOTE at the top of remote/compute_functions.py explains why and warns "change one, change all". The earlier draft also contained unused _remote_load_grid / _remote_load_dataset helpers that have been removed to prevent a future maintainer from refactoring them in and breaking serialization. Extension list: dropped .shx / .dbf from the accepted-extension list — only .shp is a valid entry point (geopandas picks up the siblings automatically), and a user passing .shx or .dbf directly would have gotten a confusing geopandas error rather than the helpful fall-through. Tests: - tests/test_inspect_mesh.py: +test_inspect_shapefile_mesh, +test_inspect_geojson_mesh. - tests/test_plotting.py: 5 tests updated to patch load_dataset instead of the now-unused ux.open_dataset import path. - 295 / 295 pass locally. - Verified against real fixtures in ~/uxarray/test/meshfiles: outCSne30 ugrid (5400 faces), ne30pg2 scrip (21600 faces), chicago_neighborhoods.shp (101 faces), sample_chicago_buildings.geojson (10 faces). - Verified remote NetCDF path unchanged: inspect_mesh_remote + calculate_area_remote against the Polaris oi240lr240 base_mesh on chrysalis return identical results to prior runs (10302 faces, total_area = 4π unit sphere) with execution_venue=hpc:chrysalis. README: dropped "(5 min)", "(15 min)" suffixes from the four-row router for consistency; concrete times in step-by-step docs are kept. --- README.md | 9 +- src/uxarray_mcp/domain/__init__.py | 3 +- src/uxarray_mcp/domain/mesh.py | 41 ++++++- src/uxarray_mcp/remote/compute_functions.py | 125 +++++++++++++++++--- src/uxarray_mcp/tools/advanced.py | 10 +- src/uxarray_mcp/tools/capabilities.py | 5 +- src/uxarray_mcp/tools/inspection.py | 10 +- src/uxarray_mcp/tools/plotting.py | 7 +- src/uxarray_mcp/tools/vector_calc.py | 13 +- tests/test_inspect_mesh.py | 32 +++++ tests/test_plotting.py | 30 ++--- 11 files changed, 223 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index d75b9d3..0d7a0bc 100644 --- a/README.md +++ b/README.md @@ -26,19 +26,18 @@ remotely on an HPC system you have access to. You are most likely one of: -1. **Local user** — laptop only, no HPC. → [Local install](#local-install) (5 min). +1. **Local user** — laptop only, no HPC. → [Local install](#local-install). 2. **HPC user, endpoint already exists** — someone at your lab gave you a Globus Compute endpoint UUID. → [Local install](#local-install), then - [docs/remote-hpc.md](docs/remote-hpc.md) (15 min). + [docs/remote-hpc.md](docs/remote-hpc.md). 3. **HPC user, your own personal endpoint** — you have a Globus identity and shell access to an HPC machine, and want to stand up an endpoint just for yourself. → [Local install](#local-install), then - [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart) - (~30 min). + [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart). 4. **Group / shared endpoint operator** — you're standing one up for a team, project, or lab. → [Local install](#local-install), then the full [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md) including - service-account migration and the MEP allowlist (~1 hr+, site-dependent). + service-account migration and the MEP allowlist. --- diff --git a/src/uxarray_mcp/domain/__init__.py b/src/uxarray_mcp/domain/__init__.py index 24cd380..d3c3b29 100644 --- a/src/uxarray_mcp/domain/__init__.py +++ b/src/uxarray_mcp/domain/__init__.py @@ -5,7 +5,7 @@ """ from .area import compute_area_stats -from .mesh import load_grid +from .mesh import load_dataset, load_grid from .variable import compute_variable_info from .vector_calc import ( compute_azimuthal_mean, @@ -17,6 +17,7 @@ __all__ = [ "load_grid", + "load_dataset", "compute_area_stats", "compute_variable_info", "compute_zonal_mean_stats", diff --git a/src/uxarray_mcp/domain/mesh.py b/src/uxarray_mcp/domain/mesh.py index 3ebc6bd..1952248 100644 --- a/src/uxarray_mcp/domain/mesh.py +++ b/src/uxarray_mcp/domain/mesh.py @@ -1,10 +1,11 @@ -"""Shared grid loading with HEALPix support.""" +"""Shared grid loading with HEALPix and GIS support.""" +import os from typing import Any def load_grid(file_path: str) -> Any: - """Load a UXarray Grid from a file path or HEALPix spec. + """Load a UXarray Grid from a file path, HEALPix spec, or shapefile/geojson. Parameters ---------- @@ -23,4 +24,40 @@ def load_grid(file_path: str) -> Any: zoom = int(parts[1]) if len(parts) > 1 else 1 return ux.Grid.from_healpix(zoom=zoom) + ext = os.path.splitext(file_path.lower())[1] + if ext in [".shp", ".geojson"]: + return ux.Grid.from_file(file_path, backend="geopandas") + return ux.open_grid(file_path) + + +def load_dataset(grid_path: str, data_path: str) -> Any: + """Load a UXarray Dataset from grid and data paths, supporting shapefiles/geojson. + + Parameters + ---------- + grid_path : str + Path to mesh grid file or "healpix:". + data_path : str + Path to netCDF data file. + + Returns + ------- + ux.UxDataset + Loaded dataset object. + """ + import uxarray as ux + + # HEALPix is a special case (usually grid-only, but we support it if matched) + if grid_path.lower().startswith("healpix"): + parts = grid_path.split(":") + zoom = int(parts[1]) if len(parts) > 1 else 1 + grid = ux.Grid.from_healpix(zoom=zoom) + return ux.open_dataset(grid.to_xarray(), data_path) + + ext = os.path.splitext(grid_path.lower())[1] + if ext in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + return ux.open_dataset(grid.to_xarray(), data_path) + + return ux.open_dataset(grid_path, data_path) diff --git a/src/uxarray_mcp/remote/compute_functions.py b/src/uxarray_mcp/remote/compute_functions.py index 5570063..154d961 100644 --- a/src/uxarray_mcp/remote/compute_functions.py +++ b/src/uxarray_mcp/remote/compute_functions.py @@ -9,6 +9,13 @@ from typing import Any, Dict, Optional +# NOTE on the inline branching below: Globus Compute serializes each remote_* +# function body and ships it to the worker. Closures over module-level helpers +# such as _remote_load_grid don't survive serialization reliably across SDK +# versions, so each function inlines the same ~6 lines of extension dispatch +# (HEALPix spec, .shp / .geojson via geopandas, else default open_grid / +# open_dataset). If you change one, change them all. + def remote_runtime_probe() -> Dict[str, Any]: """Return lightweight runtime diagnostics from the remote worker.""" @@ -138,11 +145,14 @@ def remote_inspect_mesh(file_path: str) -> Dict[str, Any]: This function executes on the HPC endpoint, not locally. All imports must be within function scope for serialization. """ + import os + import uxarray as ux if file_path.startswith("healpix:"): - zoom = int(file_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(file_path.split(":")[1])) + elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(file_path, backend="geopandas") else: grid = ux.open_grid(file_path) @@ -172,12 +182,15 @@ def remote_calculate_area(file_path: str) -> Dict[str, Any]: This function executes on the HPC endpoint, not locally. All imports must be within function scope for serialization. """ + import os + import numpy as np import uxarray as ux if file_path.startswith("healpix:"): - zoom = int(file_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(file_path.split(":")[1])) + elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(file_path, backend="geopandas") else: grid = ux.open_grid(file_path) @@ -218,10 +231,19 @@ def remote_inspect_variable( ----- This function executes on the HPC endpoint, not locally. """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) face_dims = {"n_face", "nCells"} node_dims = {"n_node", "nVertices"} @@ -312,12 +334,15 @@ def remote_plot_mesh( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux if grid_path.startswith("healpix:"): - zoom = int(grid_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") else: grid = ux.open_grid(grid_path) @@ -402,10 +427,19 @@ def remote_plot_variable( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) face_dims = {"n_face", "nCells"} @@ -531,10 +565,19 @@ def remote_plot_zonal_mean( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds: raise ValueError(f"Variable '{variable_name}' not found") @@ -612,9 +655,18 @@ def remote_calculate_zonal_mean( ----- This function executes on the HPC endpoint, not locally. """ + import os + import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds: raise ValueError(f"Variable '{variable_name}' not found") @@ -704,10 +756,17 @@ def remote_subset_bbox_plot( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - grid = ux.open_grid(grid_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + else: + grid = ux.open_grid(grid_path) n_face_total = int(grid.n_face) # full-mesh mean area @@ -979,10 +1038,19 @@ def remote_calculate_gradient( grid_path: str, data_path: str, variable_name: str ) -> Dict[str, Any]: """Compute the spatial gradient of a face-centered scalar field on HPC.""" + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds.data_vars: raise ValueError( f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}" @@ -1024,10 +1092,19 @@ def remote_calculate_curl( zeta = dv/dx - du/dy """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) for name in (u_variable, v_variable): if name not in uxds.data_vars: raise ValueError( @@ -1070,10 +1147,19 @@ def remote_calculate_divergence( divergence = du/dx + dv/dy """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) for name in (u_variable, v_variable): if name not in uxds.data_vars: raise ValueError( @@ -1119,9 +1205,18 @@ def remote_calculate_azimuthal_mean( radius_step: float, ) -> Dict[str, Any]: """Compute the azimuthal (radial) mean around a centre point on HPC.""" + import os + import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds.data_vars: raise ValueError( f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}" diff --git a/src/uxarray_mcp/tools/advanced.py b/src/uxarray_mcp/tools/advanced.py index 67b7c27..8248912 100644 --- a/src/uxarray_mcp/tools/advanced.py +++ b/src/uxarray_mcp/tools/advanced.py @@ -12,7 +12,7 @@ import xarray as xr from matplotlib.path import Path as MplPath -from uxarray_mcp.domain.mesh import load_grid +from uxarray_mcp.domain.mesh import load_dataset, load_grid from uxarray_mcp.provenance import attach_provenance from uxarray_mcp.state import ( OperationTracker, @@ -58,7 +58,7 @@ def _load_dataarray( data_path: str, variable_name: str | None, ) -> tuple[ux.UxDataset, ux.UxDataArray, str]: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) selected = variable_name if selected is None: for name, var in uxds.data_vars.items(): @@ -434,8 +434,8 @@ def _load_comparison_arrays( variable_name: str, ) -> tuple[xr.DataArray, xr.DataArray]: if grid_path: - first = ux.open_dataset(grid_path, data_path_a)[variable_name].to_xarray() - second = ux.open_dataset(grid_path, data_path_b)[variable_name].to_xarray() + first = load_dataset(grid_path, data_path_a)[variable_name].to_xarray() + second = load_dataset(grid_path, data_path_b)[variable_name].to_xarray() else: first = xr.open_dataset(data_path_a)[variable_name] second = xr.open_dataset(data_path_b)[variable_name] @@ -702,7 +702,7 @@ def regrid_dataset( if resolved_data is None: raise ValueError("data_path is required to regrid a dataset.") - uxds = ux.open_dataset(resolved_grid, resolved_data) + uxds = load_dataset(resolved_grid, resolved_data) target_grid = load_grid(target_grid_path) if not hasattr(uxds[next(iter(uxds.data_vars))].remap, method): raise ValueError( diff --git a/src/uxarray_mcp/tools/capabilities.py b/src/uxarray_mcp/tools/capabilities.py index 9f11f8f..cd43f7e 100644 --- a/src/uxarray_mcp/tools/capabilities.py +++ b/src/uxarray_mcp/tools/capabilities.py @@ -5,6 +5,7 @@ import uxarray as ux +from uxarray_mcp.domain import load_dataset, load_grid from uxarray_mcp.provenance import attach_provenance from uxarray_mcp.remote.config import load_config @@ -74,7 +75,7 @@ def get_capabilities( if not grid_file.exists(): raise FileNotFoundError(f"Grid file not found: {grid_path}") try: - grid = ux.open_grid(grid_path) + grid = load_grid(grid_path) except Exception as e: raise RuntimeError(f"Failed to load grid file: {e}") from e grid_format = str(getattr(grid, "source_grid_spec", "Unknown")) @@ -110,7 +111,7 @@ def get_capabilities( if not data_file.exists(): raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {e}") from e diff --git a/src/uxarray_mcp/tools/inspection.py b/src/uxarray_mcp/tools/inspection.py index c7a27e7..a4d2fe1 100644 --- a/src/uxarray_mcp/tools/inspection.py +++ b/src/uxarray_mcp/tools/inspection.py @@ -6,12 +6,12 @@ from typing import Any, Dict, Optional import numpy as np -import uxarray as ux from uxarray_mcp.domain import ( compute_area_stats, compute_variable_info, compute_zonal_mean_stats, + load_dataset, load_grid, ) from uxarray_mcp.provenance import attach_provenance @@ -77,7 +77,7 @@ def _inspect_mesh_local(file_path: str) -> Dict[str, Any]: file_size_mb = path.stat().st_size / (1024 * 1024) try: - grid = ux.open_grid(file_path) + grid = load_grid(file_path) except Exception as e: raise RuntimeError(f"Failed to load mesh file: {str(e)}") @@ -151,7 +151,7 @@ def _inspect_variable_local( raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") @@ -291,7 +291,7 @@ def _calculate_zonal_mean_local( raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") @@ -367,7 +367,7 @@ def validate_dataset(grid_path: str, data_path: str) -> Dict[str, Any]: raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") diff --git a/src/uxarray_mcp/tools/plotting.py b/src/uxarray_mcp/tools/plotting.py index 35bb94a..2b35cf6 100644 --- a/src/uxarray_mcp/tools/plotting.py +++ b/src/uxarray_mcp/tools/plotting.py @@ -5,10 +5,9 @@ from pathlib import Path from typing import Any, Optional -import uxarray as ux from mcp.types import ImageContent, TextContent -from uxarray_mcp.domain.mesh import load_grid +from uxarray_mcp.domain.mesh import load_dataset, load_grid from uxarray_mcp.domain.plotting import ( render_mesh, render_mesh_geo, @@ -616,7 +615,7 @@ def _plot_variable_local( "The file may not have been written correctly." ) - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) if variable_name is None: for var in uxds.data_vars: @@ -794,7 +793,7 @@ def _plot_zonal_mean_local( "The file may not have been written correctly." ) - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) zonal_result = compute_zonal_mean_stats( uxds, variable_name, lat_spec=lat_spec, conservative=conservative diff --git a/src/uxarray_mcp/tools/vector_calc.py b/src/uxarray_mcp/tools/vector_calc.py index d88c605..2053144 100644 --- a/src/uxarray_mcp/tools/vector_calc.py +++ b/src/uxarray_mcp/tools/vector_calc.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional +from uxarray_mcp.domain.mesh import load_dataset from uxarray_mcp.state import OperationTracker from .remote_tools import ( @@ -120,7 +121,6 @@ def calculate_gradient( >>> calculate_gradient("grid.nc", "data.nc", "temperature") {"components": ["d_temperature_d_x", "d_temperature_d_y"], ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_gradient from uxarray_mcp.provenance import attach_provenance @@ -132,7 +132,7 @@ def calculate_gradient( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_gradient(uxds, variable_name), tool="calculate_gradient", @@ -203,7 +203,6 @@ def calculate_curl( >>> calculate_curl("/hpc/grid.nc", "/hpc/data.nc", "u", "v", use_remote=True) {"stats": {...}, "_provenance": {"execution_venue": "hpc:...", ...}} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_curl from uxarray_mcp.provenance import attach_provenance @@ -216,7 +215,7 @@ def calculate_curl( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_curl(uxds, u_variable, v_variable), tool="calculate_curl", @@ -287,7 +286,6 @@ def calculate_divergence( ... ) {"interpretation": "horizontal divergence du/dx + dv/dy", ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_divergence from uxarray_mcp.provenance import attach_provenance @@ -300,7 +298,7 @@ def calculate_divergence( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_divergence(uxds, u_variable, v_variable), tool="calculate_divergence", @@ -386,7 +384,6 @@ def calculate_azimuthal_mean( ... ) {"radii_deg": [0.0, 0.5, 1.0, ...], "azimuthal_mean_values": [...], ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_azimuthal_mean from uxarray_mcp.provenance import attach_provenance @@ -402,7 +399,7 @@ def calculate_azimuthal_mean( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_azimuthal_mean( uxds, variable_name, center_lon, center_lat, outer_radius, radius_step diff --git a/tests/test_inspect_mesh.py b/tests/test_inspect_mesh.py index 8e77c50..a3a68d0 100644 --- a/tests/test_inspect_mesh.py +++ b/tests/test_inspect_mesh.py @@ -49,6 +49,38 @@ def test_inspect_scrip_mesh(self, scrip_grid): result = inspect_mesh("/path/to/scrip.nc") assert result["format"] == "SCRIP" + def test_inspect_shapefile_mesh(self, base_grid): + """Test inspection of a Shapefile mesh.""" + base_grid.source_grid_spec = "Shapefile" + with ( + patch("uxarray_mcp.tools.inspection.Path") as MockPath, + patch("uxarray.Grid.from_file", return_value=base_grid) as mock_from_file, + ): + MockPath.return_value.exists.return_value = True + MockPath.return_value.stat.return_value.st_size = 1024 * 1024 + + result = inspect_mesh("/path/to/shapefile.shp") + mock_from_file.assert_called_once_with( + "/path/to/shapefile.shp", backend="geopandas" + ) + assert result["format"] == "Shapefile" + + def test_inspect_geojson_mesh(self, base_grid): + """Test inspection of a GeoJSON mesh.""" + base_grid.source_grid_spec = "GeoJSON" + with ( + patch("uxarray_mcp.tools.inspection.Path") as MockPath, + patch("uxarray.Grid.from_file", return_value=base_grid) as mock_from_file, + ): + MockPath.return_value.exists.return_value = True + MockPath.return_value.stat.return_value.st_size = 1024 * 1024 + + result = inspect_mesh("/path/to/geojson.geojson") + mock_from_file.assert_called_once_with( + "/path/to/geojson.geojson", backend="geopandas" + ) + assert result["format"] == "GeoJSON" + def test_inspect_healpix_mesh(self, base_grid): """Test inspection of a HEALPix mesh generation.""" # Mock grid returned by from_healpix diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 235f549..2bf8be1 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -160,8 +160,8 @@ class TestPlotVariableMocked: """Tests for plot_variable tool using mocks.""" @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_auto_select(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_auto_select(self, mock_load_dataset, mock_render): mock_var = MagicMock() mock_var.dims = ("n_face",) mock_uxds = MagicMock() @@ -170,7 +170,7 @@ def test_plot_variable_auto_select(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -201,8 +201,8 @@ def test_empty_data_file_raises(self): Path(empty_path).unlink(missing_ok=True) @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_vmin_vmax(self, mock_load_dataset, mock_render): """vmin/vmax are passed through to render_variable.""" mock_var = MagicMock() mock_var.dims = ("n_face",) @@ -212,7 +212,7 @@ def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -226,8 +226,8 @@ def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): assert kwargs["vmax"] == 5.0 @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_custom_title(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_custom_title(self, mock_load_dataset, mock_render): """title is passed through to render_variable.""" mock_var = MagicMock() mock_var.dims = ("n_face",) @@ -237,7 +237,7 @@ def test_plot_variable_custom_title(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -258,10 +258,10 @@ class TestPlotZonalMeanMocked: """Tests for plot_zonal_mean tool using mocks.""" @patch("uxarray_mcp.tools.plotting.render_zonal_mean", return_value=b"\x89PNG_zm") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_zonal_mean_basic(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_zonal_mean_basic(self, mock_load_dataset, mock_render): mock_uxds = MagicMock() - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds zonal_stats = { "latitudes": [-90.0, 0.0, 90.0], @@ -291,10 +291,10 @@ def test_plot_zonal_mean_basic(self, mock_ux, mock_render): assert prov["_provenance"]["tool"] == "plot_zonal_mean" @patch("uxarray_mcp.tools.plotting.render_zonal_mean", return_value=b"\x89PNG_zm") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_zonal_mean_line_color_and_title(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_zonal_mean_line_color_and_title(self, mock_load_dataset, mock_render): """line_color and title are passed through to render_zonal_mean.""" - mock_ux.open_dataset.return_value = MagicMock() + mock_load_dataset.return_value = MagicMock() zonal_stats = { "latitudes": [-90.0, 0.0, 90.0], "zonal_mean_values": [270.0, 300.0, 270.0],