Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,18 @@ remotely on an HPC system you have access to.

You are most likely one of:

1. **Local user** — laptop only, no HPC. → [Local install](#local-install) (5 min).
1. **Local user** — laptop only, no HPC. → [Local install](#local-install).
2. **HPC user, endpoint already exists** — someone at your lab gave you a
Globus Compute endpoint UUID. → [Local install](#local-install), then
[docs/remote-hpc.md](docs/remote-hpc.md) (15 min).
[docs/remote-hpc.md](docs/remote-hpc.md).
3. **HPC user, your own personal endpoint** — you have a Globus identity and
shell access to an HPC machine, and want to stand up an endpoint just for
yourself. → [Local install](#local-install), then
[docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart)
(~30 min).
[docs/operating-an-endpoint.md](docs/operating-an-endpoint.md#solo-personal-endpoint-quickstart).
4. **Group / shared endpoint operator** — you're standing one up for a team,
project, or lab. → [Local install](#local-install), then the full
[docs/operating-an-endpoint.md](docs/operating-an-endpoint.md) including
service-account migration and the MEP allowlist (~1 hr+, site-dependent).
service-account migration and the MEP allowlist.

---

Expand Down
3 changes: 2 additions & 1 deletion src/uxarray_mcp/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from .area import compute_area_stats
from .mesh import load_grid
from .mesh import load_dataset, load_grid
from .variable import compute_variable_info
from .vector_calc import (
compute_azimuthal_mean,
Expand All @@ -17,6 +17,7 @@

__all__ = [
"load_grid",
"load_dataset",
"compute_area_stats",
"compute_variable_info",
"compute_zonal_mean_stats",
Expand Down
41 changes: 39 additions & 2 deletions src/uxarray_mcp/domain/mesh.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Shared grid loading with HEALPix support."""
"""Shared grid loading with HEALPix and GIS support."""

import os
from typing import Any


def load_grid(file_path: str) -> Any:
"""Load a UXarray Grid from a file path or HEALPix spec.
"""Load a UXarray Grid from a file path, HEALPix spec, or shapefile/geojson.

Parameters
----------
Expand All @@ -23,4 +24,40 @@ def load_grid(file_path: str) -> Any:
zoom = int(parts[1]) if len(parts) > 1 else 1
return ux.Grid.from_healpix(zoom=zoom)

ext = os.path.splitext(file_path.lower())[1]
if ext in [".shp", ".geojson"]:
return ux.Grid.from_file(file_path, backend="geopandas")

return ux.open_grid(file_path)


def load_dataset(grid_path: str, data_path: str) -> Any:
"""Load a UXarray Dataset from grid and data paths, supporting shapefiles/geojson.

Parameters
----------
grid_path : str
Path to mesh grid file or "healpix:<zoom>".
data_path : str
Path to netCDF data file.

Returns
-------
ux.UxDataset
Loaded dataset object.
"""
import uxarray as ux

# HEALPix is a special case (usually grid-only, but we support it if matched)
if grid_path.lower().startswith("healpix"):
parts = grid_path.split(":")
zoom = int(parts[1]) if len(parts) > 1 else 1
grid = ux.Grid.from_healpix(zoom=zoom)
return ux.open_dataset(grid.to_xarray(), data_path)

ext = os.path.splitext(grid_path.lower())[1]
if ext in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
return ux.open_dataset(grid.to_xarray(), data_path)

return ux.open_dataset(grid_path, data_path)
125 changes: 110 additions & 15 deletions src/uxarray_mcp/remote/compute_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@

from typing import Any, Dict, Optional

# NOTE on the inline branching below: Globus Compute serializes each remote_*
# function body and ships it to the worker. Closures over module-level helpers
# such as _remote_load_grid don't survive serialization reliably across SDK
# versions, so each function inlines the same ~6 lines of extension dispatch
# (HEALPix spec, .shp / .geojson via geopandas, else default open_grid /
# open_dataset). If you change one, change them all.


def remote_runtime_probe() -> Dict[str, Any]:
"""Return lightweight runtime diagnostics from the remote worker."""
Expand Down Expand Up @@ -138,11 +145,14 @@ def remote_inspect_mesh(file_path: str) -> Dict[str, Any]:
This function executes on the HPC endpoint, not locally.
All imports must be within function scope for serialization.
"""
import os

import uxarray as ux

if file_path.startswith("healpix:"):
zoom = int(file_path.split(":")[1])
grid = ux.Grid.from_healpix(zoom)
grid = ux.Grid.from_healpix(int(file_path.split(":")[1]))
elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(file_path, backend="geopandas")
else:
grid = ux.open_grid(file_path)

Expand Down Expand Up @@ -172,12 +182,15 @@ def remote_calculate_area(file_path: str) -> Dict[str, Any]:
This function executes on the HPC endpoint, not locally.
All imports must be within function scope for serialization.
"""
import os

import numpy as np
import uxarray as ux

if file_path.startswith("healpix:"):
zoom = int(file_path.split(":")[1])
grid = ux.Grid.from_healpix(zoom)
grid = ux.Grid.from_healpix(int(file_path.split(":")[1]))
elif os.path.splitext(file_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(file_path, backend="geopandas")
else:
grid = ux.open_grid(file_path)

Expand Down Expand Up @@ -218,10 +231,19 @@ def remote_inspect_variable(
-----
This function executes on the HPC endpoint, not locally.
"""
import os

import numpy as np
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)

face_dims = {"n_face", "nCells"}
node_dims = {"n_node", "nVertices"}
Expand Down Expand Up @@ -312,12 +334,15 @@ def remote_plot_mesh(
import matplotlib

matplotlib.use("Agg")
import os

import matplotlib.pyplot as plt
import uxarray as ux

if grid_path.startswith("healpix:"):
zoom = int(grid_path.split(":")[1])
grid = ux.Grid.from_healpix(zoom)
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
else:
grid = ux.open_grid(grid_path)

Expand Down Expand Up @@ -402,10 +427,19 @@ def remote_plot_variable(
import matplotlib

matplotlib.use("Agg")
import os

import matplotlib.pyplot as plt
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)

face_dims = {"n_face", "nCells"}

Expand Down Expand Up @@ -531,10 +565,19 @@ def remote_plot_zonal_mean(
import matplotlib

matplotlib.use("Agg")
import os

import matplotlib.pyplot as plt
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)

if variable_name not in uxds:
raise ValueError(f"Variable '{variable_name}' not found")
Expand Down Expand Up @@ -612,9 +655,18 @@ def remote_calculate_zonal_mean(
-----
This function executes on the HPC endpoint, not locally.
"""
import os

import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)

if variable_name not in uxds:
raise ValueError(f"Variable '{variable_name}' not found")
Expand Down Expand Up @@ -704,10 +756,17 @@ def remote_subset_bbox_plot(
import matplotlib

matplotlib.use("Agg")
import os

import matplotlib.pyplot as plt
import uxarray as ux

grid = ux.open_grid(grid_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
else:
grid = ux.open_grid(grid_path)
n_face_total = int(grid.n_face)

# full-mesh mean area
Expand Down Expand Up @@ -979,10 +1038,19 @@ def remote_calculate_gradient(
grid_path: str, data_path: str, variable_name: str
) -> Dict[str, Any]:
"""Compute the spatial gradient of a face-centered scalar field on HPC."""
import os

import numpy as np
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)
if variable_name not in uxds.data_vars:
raise ValueError(
f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}"
Expand Down Expand Up @@ -1024,10 +1092,19 @@ def remote_calculate_curl(

zeta = dv/dx - du/dy
"""
import os

import numpy as np
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)
for name in (u_variable, v_variable):
if name not in uxds.data_vars:
raise ValueError(
Expand Down Expand Up @@ -1070,10 +1147,19 @@ def remote_calculate_divergence(

divergence = du/dx + dv/dy
"""
import os

import numpy as np
import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)
for name in (u_variable, v_variable):
if name not in uxds.data_vars:
raise ValueError(
Expand Down Expand Up @@ -1119,9 +1205,18 @@ def remote_calculate_azimuthal_mean(
radius_step: float,
) -> Dict[str, Any]:
"""Compute the azimuthal (radial) mean around a centre point on HPC."""
import os

import uxarray as ux

uxds = ux.open_dataset(grid_path, data_path)
if grid_path.startswith("healpix:"):
grid = ux.Grid.from_healpix(int(grid_path.split(":")[1]))
uxds = ux.open_dataset(grid.to_xarray(), data_path)
elif os.path.splitext(grid_path.lower())[1] in [".shp", ".geojson"]:
grid = ux.Grid.from_file(grid_path, backend="geopandas")
uxds = ux.open_dataset(grid.to_xarray(), data_path)
else:
uxds = ux.open_dataset(grid_path, data_path)
if variable_name not in uxds.data_vars:
raise ValueError(
f"Variable '{variable_name}' not found. Available: {list(uxds.data_vars)}"
Expand Down
10 changes: 5 additions & 5 deletions src/uxarray_mcp/tools/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import xarray as xr
from matplotlib.path import Path as MplPath

from uxarray_mcp.domain.mesh import load_grid
from uxarray_mcp.domain.mesh import load_dataset, load_grid
from uxarray_mcp.provenance import attach_provenance
from uxarray_mcp.state import (
OperationTracker,
Expand Down Expand Up @@ -58,7 +58,7 @@ def _load_dataarray(
data_path: str,
variable_name: str | None,
) -> tuple[ux.UxDataset, ux.UxDataArray, str]:
uxds = ux.open_dataset(grid_path, data_path)
uxds = load_dataset(grid_path, data_path)
selected = variable_name
if selected is None:
for name, var in uxds.data_vars.items():
Expand Down Expand Up @@ -434,8 +434,8 @@ def _load_comparison_arrays(
variable_name: str,
) -> tuple[xr.DataArray, xr.DataArray]:
if grid_path:
first = ux.open_dataset(grid_path, data_path_a)[variable_name].to_xarray()
second = ux.open_dataset(grid_path, data_path_b)[variable_name].to_xarray()
first = load_dataset(grid_path, data_path_a)[variable_name].to_xarray()
second = load_dataset(grid_path, data_path_b)[variable_name].to_xarray()
else:
first = xr.open_dataset(data_path_a)[variable_name]
second = xr.open_dataset(data_path_b)[variable_name]
Expand Down Expand Up @@ -702,7 +702,7 @@ def regrid_dataset(
if resolved_data is None:
raise ValueError("data_path is required to regrid a dataset.")

uxds = ux.open_dataset(resolved_grid, resolved_data)
uxds = load_dataset(resolved_grid, resolved_data)
target_grid = load_grid(target_grid_path)
if not hasattr(uxds[next(iter(uxds.data_vars))].remap, method):
raise ValueError(
Expand Down
Loading
Loading