diff --git a/README.md b/README.md index d75b9d3..0d7a0bc 100644 --- a/README.md +++ b/README.md @@ -26,19 +26,18 @@ remotely on an HPC system you have access to. You are most likely one of: -1. **Local user** — laptop only, no HPC. → [Local install](#local-install) (5 min). +1. **Local user** — laptop only, no HPC. → [Local install](#local-install). 2. **HPC user, endpoint already exists** — someone at your lab gave you a Globus Compute endpoint UUID. → [Local install](#local-install), then - [docs/remote-hpc.md](docs/remote-hpc.md) (15 min). + [docs/remote-hpc.md](docs/remote-hpc.md). 3. **HPC user, your own personal endpoint** — you have a Globus identity and shell access to an HPC machine, and want to stand up an endpoint just for yourself. → [Local install](#local-install), then - [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart) - (~30 min). + [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart). 4. **Group / shared endpoint operator** — you're standing one up for a team, project, or lab. → [Local install](#local-install), then the full [docs/operating-an-endpoint.md](docs/operating-an-endpoint.md) including - service-account migration and the MEP allowlist (~1 hr+, site-dependent). + service-account migration and the MEP allowlist. --- diff --git a/src/uxarray_mcp/domain/__init__.py b/src/uxarray_mcp/domain/__init__.py index 24cd380..d3c3b29 100644 --- a/src/uxarray_mcp/domain/__init__.py +++ b/src/uxarray_mcp/domain/__init__.py @@ -5,7 +5,7 @@ """ from .area import compute_area_stats -from .mesh import load_grid +from .mesh import load_dataset, load_grid from .variable import compute_variable_info from .vector_calc import ( compute_azimuthal_mean, @@ -17,6 +17,7 @@ __all__ = [ "load_grid", + "load_dataset", "compute_area_stats", "compute_variable_info", "compute_zonal_mean_stats", diff --git a/src/uxarray_mcp/domain/mesh.py b/src/uxarray_mcp/domain/mesh.py index 3ebc6bd..1952248 100644 --- a/src/uxarray_mcp/domain/mesh.py +++ b/src/uxarray_mcp/domain/mesh.py @@ -1,10 +1,11 @@ -"""Shared grid loading with HEALPix support.""" +"""Shared grid loading with HEALPix and GIS support.""" +import os from typing import Any def load_grid(file_path: str) -> Any: - """Load a UXarray Grid from a file path or HEALPix spec. + """Load a UXarray Grid from a file path, HEALPix spec, or shapefile/geojson. Parameters ---------- @@ -23,4 +24,40 @@ def load_grid(file_path: str) -> Any: zoom = int(parts[1]) if len(parts) > 1 else 1 return ux.Grid.from_healpix(zoom=zoom) + ext = os.path.splitext(file_path.lower())[1] + if ext in [".shp", ".geojson"]: + return ux.Grid.from_file(file_path, backend="geopandas") + return ux.open_grid(file_path) + + +def load_dataset(grid_path: str, data_path: str) -> Any: + """Load a UXarray Dataset from grid and data paths, supporting shapefiles/geojson. + + Parameters + ---------- + grid_path : str + Path to mesh grid file or "healpix:". + data_path : str + Path to netCDF data file. + + Returns + ------- + ux.UxDataset + Loaded dataset object. + """ + import uxarray as ux + + # HEALPix is a special case (usually grid-only, but we support it if matched) + if grid_path.lower().startswith("healpix"): + parts = grid_path.split(":") + zoom = int(parts[1]) if len(parts) > 1 else 1 + grid = ux.Grid.from_healpix(zoom=zoom) + return ux.open_dataset(grid.to_xarray(), data_path) + + ext = os.path.splitext(grid_path.lower())[1] + if ext in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + return ux.open_dataset(grid.to_xarray(), data_path) + + return ux.open_dataset(grid_path, data_path) diff --git a/src/uxarray_mcp/remote/compute_functions.py b/src/uxarray_mcp/remote/compute_functions.py index 5570063..154d961 100644 --- a/src/uxarray_mcp/remote/compute_functions.py +++ b/src/uxarray_mcp/remote/compute_functions.py @@ -9,6 +9,13 @@ from typing import Any, Dict, Optional +# NOTE on the inline branching below: Globus Compute serializes each remote_* +# function body and ships it to the worker. Closures over module-level helpers +# such as _remote_load_grid don't survive serialization reliably across SDK +# versions, so each function inlines the same ~6 lines of extension dispatch +# (HEALPix spec, .shp / .geojson via geopandas, else default open_grid / +# open_dataset). If you change one, change them all. + def remote_runtime_probe() -> Dict[str, Any]: """Return lightweight runtime diagnostics from the remote worker.""" @@ -138,11 +145,14 @@ def remote_inspect_mesh(file_path: str) -> Dict[str, Any]: This function executes on the HPC endpoint, not locally. All imports must be within function scope for serialization. """ + import os + import uxarray as ux if file_path.startswith("healpix:"): - zoom = int(file_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(file_path.split(":")[1])) + elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(file_path, backend="geopandas") else: grid = ux.open_grid(file_path) @@ -172,12 +182,15 @@ def remote_calculate_area(file_path: str) -> Dict[str, Any]: This function executes on the HPC endpoint, not locally. All imports must be within function scope for serialization. """ + import os + import numpy as np import uxarray as ux if file_path.startswith("healpix:"): - zoom = int(file_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(file_path.split(":")[1])) + elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(file_path, backend="geopandas") else: grid = ux.open_grid(file_path) @@ -218,10 +231,19 @@ def remote_inspect_variable( ----- This function executes on the HPC endpoint, not locally. """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) face_dims = {"n_face", "nCells"} node_dims = {"n_node", "nVertices"} @@ -312,12 +334,15 @@ def remote_plot_mesh( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux if grid_path.startswith("healpix:"): - zoom = int(grid_path.split(":")[1]) - grid = ux.Grid.from_healpix(zoom) + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") else: grid = ux.open_grid(grid_path) @@ -402,10 +427,19 @@ def remote_plot_variable( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) face_dims = {"n_face", "nCells"} @@ -531,10 +565,19 @@ def remote_plot_zonal_mean( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds: raise ValueError(f"Variable '{variable_name}' not found") @@ -612,9 +655,18 @@ def remote_calculate_zonal_mean( ----- This function executes on the HPC endpoint, not locally. """ + import os + import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds: raise ValueError(f"Variable '{variable_name}' not found") @@ -704,10 +756,17 @@ def remote_subset_bbox_plot( import matplotlib matplotlib.use("Agg") + import os + import matplotlib.pyplot as plt import uxarray as ux - grid = ux.open_grid(grid_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + else: + grid = ux.open_grid(grid_path) n_face_total = int(grid.n_face) # full-mesh mean area @@ -979,10 +1038,19 @@ def remote_calculate_gradient( grid_path: str, data_path: str, variable_name: str ) -> Dict[str, Any]: """Compute the spatial gradient of a face-centered scalar field on HPC.""" + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds.data_vars: raise ValueError( f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}" @@ -1024,10 +1092,19 @@ def remote_calculate_curl( zeta = dv/dx - du/dy """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) for name in (u_variable, v_variable): if name not in uxds.data_vars: raise ValueError( @@ -1070,10 +1147,19 @@ def remote_calculate_divergence( divergence = du/dx + dv/dy """ + import os + import numpy as np import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) for name in (u_variable, v_variable): if name not in uxds.data_vars: raise ValueError( @@ -1119,9 +1205,18 @@ def remote_calculate_azimuthal_mean( radius_step: float, ) -> Dict[str, Any]: """Compute the azimuthal (radial) mean around a centre point on HPC.""" + import os + import uxarray as ux - uxds = ux.open_dataset(grid_path, data_path) + if grid_path.startswith("healpix:"): + grid = ux.Grid.from_healpix(int(grid_path.split(":")[1])) + uxds = ux.open_dataset(grid.to_xarray(), data_path) + elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]: + grid = ux.Grid.from_file(grid_path, backend="geopandas") + uxds = ux.open_dataset(grid.to_xarray(), data_path) + else: + uxds = ux.open_dataset(grid_path, data_path) if variable_name not in uxds.data_vars: raise ValueError( f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}" diff --git a/src/uxarray_mcp/tools/advanced.py b/src/uxarray_mcp/tools/advanced.py index 67b7c27..8248912 100644 --- a/src/uxarray_mcp/tools/advanced.py +++ b/src/uxarray_mcp/tools/advanced.py @@ -12,7 +12,7 @@ import xarray as xr from matplotlib.path import Path as MplPath -from uxarray_mcp.domain.mesh import load_grid +from uxarray_mcp.domain.mesh import load_dataset, load_grid from uxarray_mcp.provenance import attach_provenance from uxarray_mcp.state import ( OperationTracker, @@ -58,7 +58,7 @@ def _load_dataarray( data_path: str, variable_name: str | None, ) -> tuple[ux.UxDataset, ux.UxDataArray, str]: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) selected = variable_name if selected is None: for name, var in uxds.data_vars.items(): @@ -434,8 +434,8 @@ def _load_comparison_arrays( variable_name: str, ) -> tuple[xr.DataArray, xr.DataArray]: if grid_path: - first = ux.open_dataset(grid_path, data_path_a)[variable_name].to_xarray() - second = ux.open_dataset(grid_path, data_path_b)[variable_name].to_xarray() + first = load_dataset(grid_path, data_path_a)[variable_name].to_xarray() + second = load_dataset(grid_path, data_path_b)[variable_name].to_xarray() else: first = xr.open_dataset(data_path_a)[variable_name] second = xr.open_dataset(data_path_b)[variable_name] @@ -702,7 +702,7 @@ def regrid_dataset( if resolved_data is None: raise ValueError("data_path is required to regrid a dataset.") - uxds = ux.open_dataset(resolved_grid, resolved_data) + uxds = load_dataset(resolved_grid, resolved_data) target_grid = load_grid(target_grid_path) if not hasattr(uxds[next(iter(uxds.data_vars))].remap, method): raise ValueError( diff --git a/src/uxarray_mcp/tools/capabilities.py b/src/uxarray_mcp/tools/capabilities.py index 9f11f8f..cd43f7e 100644 --- a/src/uxarray_mcp/tools/capabilities.py +++ b/src/uxarray_mcp/tools/capabilities.py @@ -5,6 +5,7 @@ import uxarray as ux +from uxarray_mcp.domain import load_dataset, load_grid from uxarray_mcp.provenance import attach_provenance from uxarray_mcp.remote.config import load_config @@ -74,7 +75,7 @@ def get_capabilities( if not grid_file.exists(): raise FileNotFoundError(f"Grid file not found: {grid_path}") try: - grid = ux.open_grid(grid_path) + grid = load_grid(grid_path) except Exception as e: raise RuntimeError(f"Failed to load grid file: {e}") from e grid_format = str(getattr(grid, "source_grid_spec", "Unknown")) @@ -110,7 +111,7 @@ def get_capabilities( if not data_file.exists(): raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {e}") from e diff --git a/src/uxarray_mcp/tools/inspection.py b/src/uxarray_mcp/tools/inspection.py index c7a27e7..a4d2fe1 100644 --- a/src/uxarray_mcp/tools/inspection.py +++ b/src/uxarray_mcp/tools/inspection.py @@ -6,12 +6,12 @@ from typing import Any, Dict, Optional import numpy as np -import uxarray as ux from uxarray_mcp.domain import ( compute_area_stats, compute_variable_info, compute_zonal_mean_stats, + load_dataset, load_grid, ) from uxarray_mcp.provenance import attach_provenance @@ -77,7 +77,7 @@ def _inspect_mesh_local(file_path: str) -> Dict[str, Any]: file_size_mb = path.stat().st_size / (1024 * 1024) try: - grid = ux.open_grid(file_path) + grid = load_grid(file_path) except Exception as e: raise RuntimeError(f"Failed to load mesh file: {str(e)}") @@ -151,7 +151,7 @@ def _inspect_variable_local( raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") @@ -291,7 +291,7 @@ def _calculate_zonal_mean_local( raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") @@ -367,7 +367,7 @@ def validate_dataset(grid_path: str, data_path: str) -> Dict[str, Any]: raise FileNotFoundError(f"Data file not found: {data_path}") try: - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") diff --git a/src/uxarray_mcp/tools/plotting.py b/src/uxarray_mcp/tools/plotting.py index 35bb94a..2b35cf6 100644 --- a/src/uxarray_mcp/tools/plotting.py +++ b/src/uxarray_mcp/tools/plotting.py @@ -5,10 +5,9 @@ from pathlib import Path from typing import Any, Optional -import uxarray as ux from mcp.types import ImageContent, TextContent -from uxarray_mcp.domain.mesh import load_grid +from uxarray_mcp.domain.mesh import load_dataset, load_grid from uxarray_mcp.domain.plotting import ( render_mesh, render_mesh_geo, @@ -616,7 +615,7 @@ def _plot_variable_local( "The file may not have been written correctly." ) - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) if variable_name is None: for var in uxds.data_vars: @@ -794,7 +793,7 @@ def _plot_zonal_mean_local( "The file may not have been written correctly." ) - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) zonal_result = compute_zonal_mean_stats( uxds, variable_name, lat_spec=lat_spec, conservative=conservative diff --git a/src/uxarray_mcp/tools/vector_calc.py b/src/uxarray_mcp/tools/vector_calc.py index d88c605..2053144 100644 --- a/src/uxarray_mcp/tools/vector_calc.py +++ b/src/uxarray_mcp/tools/vector_calc.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional +from uxarray_mcp.domain.mesh import load_dataset from uxarray_mcp.state import OperationTracker from .remote_tools import ( @@ -120,7 +121,6 @@ def calculate_gradient( >>> calculate_gradient("grid.nc", "data.nc", "temperature") {"components": ["d_temperature_d_x", "d_temperature_d_y"], ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_gradient from uxarray_mcp.provenance import attach_provenance @@ -132,7 +132,7 @@ def calculate_gradient( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_gradient(uxds, variable_name), tool="calculate_gradient", @@ -203,7 +203,6 @@ def calculate_curl( >>> calculate_curl("/hpc/grid.nc", "/hpc/data.nc", "u", "v", use_remote=True) {"stats": {...}, "_provenance": {"execution_venue": "hpc:...", ...}} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_curl from uxarray_mcp.provenance import attach_provenance @@ -216,7 +215,7 @@ def calculate_curl( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_curl(uxds, u_variable, v_variable), tool="calculate_curl", @@ -287,7 +286,6 @@ def calculate_divergence( ... ) {"interpretation": "horizontal divergence du/dx + dv/dy", ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_divergence from uxarray_mcp.provenance import attach_provenance @@ -300,7 +298,7 @@ def calculate_divergence( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_divergence(uxds, u_variable, v_variable), tool="calculate_divergence", @@ -386,7 +384,6 @@ def calculate_azimuthal_mean( ... ) {"radii_deg": [0.0, 0.5, 1.0, ...], "azimuthal_mean_values": [...], ...} """ - import uxarray as ux from uxarray_mcp.domain.vector_calc import compute_azimuthal_mean from uxarray_mcp.provenance import attach_provenance @@ -402,7 +399,7 @@ def calculate_azimuthal_mean( } def _local(): - uxds = ux.open_dataset(grid_path, data_path) + uxds = load_dataset(grid_path, data_path) return attach_provenance( compute_azimuthal_mean( uxds, variable_name, center_lon, center_lat, outer_radius, radius_step diff --git a/tests/test_inspect_mesh.py b/tests/test_inspect_mesh.py index 8e77c50..a3a68d0 100644 --- a/tests/test_inspect_mesh.py +++ b/tests/test_inspect_mesh.py @@ -49,6 +49,38 @@ def test_inspect_scrip_mesh(self, scrip_grid): result = inspect_mesh("/path/to/scrip.nc") assert result["format"] == "SCRIP" + def test_inspect_shapefile_mesh(self, base_grid): + """Test inspection of a Shapefile mesh.""" + base_grid.source_grid_spec = "Shapefile" + with ( + patch("uxarray_mcp.tools.inspection.Path") as MockPath, + patch("uxarray.Grid.from_file", return_value=base_grid) as mock_from_file, + ): + MockPath.return_value.exists.return_value = True + MockPath.return_value.stat.return_value.st_size = 1024 * 1024 + + result = inspect_mesh("/path/to/shapefile.shp") + mock_from_file.assert_called_once_with( + "/path/to/shapefile.shp", backend="geopandas" + ) + assert result["format"] == "Shapefile" + + def test_inspect_geojson_mesh(self, base_grid): + """Test inspection of a GeoJSON mesh.""" + base_grid.source_grid_spec = "GeoJSON" + with ( + patch("uxarray_mcp.tools.inspection.Path") as MockPath, + patch("uxarray.Grid.from_file", return_value=base_grid) as mock_from_file, + ): + MockPath.return_value.exists.return_value = True + MockPath.return_value.stat.return_value.st_size = 1024 * 1024 + + result = inspect_mesh("/path/to/geojson.geojson") + mock_from_file.assert_called_once_with( + "/path/to/geojson.geojson", backend="geopandas" + ) + assert result["format"] == "GeoJSON" + def test_inspect_healpix_mesh(self, base_grid): """Test inspection of a HEALPix mesh generation.""" # Mock grid returned by from_healpix diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 235f549..2bf8be1 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -160,8 +160,8 @@ class TestPlotVariableMocked: """Tests for plot_variable tool using mocks.""" @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_auto_select(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_auto_select(self, mock_load_dataset, mock_render): mock_var = MagicMock() mock_var.dims = ("n_face",) mock_uxds = MagicMock() @@ -170,7 +170,7 @@ def test_plot_variable_auto_select(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -201,8 +201,8 @@ def test_empty_data_file_raises(self): Path(empty_path).unlink(missing_ok=True) @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_vmin_vmax(self, mock_load_dataset, mock_render): """vmin/vmax are passed through to render_variable.""" mock_var = MagicMock() mock_var.dims = ("n_face",) @@ -212,7 +212,7 @@ def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -226,8 +226,8 @@ def test_plot_variable_vmin_vmax(self, mock_ux, mock_render): assert kwargs["vmax"] == 5.0 @patch("uxarray_mcp.tools.plotting.render_variable", return_value=b"\x89PNG_var") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_variable_custom_title(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_variable_custom_title(self, mock_load_dataset, mock_render): """title is passed through to render_variable.""" mock_var = MagicMock() mock_var.dims = ("n_face",) @@ -237,7 +237,7 @@ def test_plot_variable_custom_title(self, mock_ux, mock_render): mock_uxds.uxgrid.n_face = 100 mock_uxds.uxgrid.n_node = 200 mock_uxds.uxgrid.n_edge = 300 - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds with patch("uxarray_mcp.tools.plotting.Path") as MockPath: MockPath.return_value.exists.return_value = True @@ -258,10 +258,10 @@ class TestPlotZonalMeanMocked: """Tests for plot_zonal_mean tool using mocks.""" @patch("uxarray_mcp.tools.plotting.render_zonal_mean", return_value=b"\x89PNG_zm") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_zonal_mean_basic(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_zonal_mean_basic(self, mock_load_dataset, mock_render): mock_uxds = MagicMock() - mock_ux.open_dataset.return_value = mock_uxds + mock_load_dataset.return_value = mock_uxds zonal_stats = { "latitudes": [-90.0, 0.0, 90.0], @@ -291,10 +291,10 @@ def test_plot_zonal_mean_basic(self, mock_ux, mock_render): assert prov["_provenance"]["tool"] == "plot_zonal_mean" @patch("uxarray_mcp.tools.plotting.render_zonal_mean", return_value=b"\x89PNG_zm") - @patch("uxarray_mcp.tools.plotting.ux") - def test_plot_zonal_mean_line_color_and_title(self, mock_ux, mock_render): + @patch("uxarray_mcp.tools.plotting.load_dataset") + def test_plot_zonal_mean_line_color_and_title(self, mock_load_dataset, mock_render): """line_color and title are passed through to render_zonal_mean.""" - mock_ux.open_dataset.return_value = MagicMock() + mock_load_dataset.return_value = MagicMock() zonal_stats = { "latitudes": [-90.0, 0.0, 90.0], "zonal_mean_values": [270.0, 300.0, 270.0],