diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index a93a5ffd6..d9760d7d9 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -112,3 +112,28 @@ def test_uxdataset_init_from_xarray_dataset(): assert "a" in uxds.data_vars assert "x" in uxds.coords assert uxds.attrs["source"] == "testing" + +def test_uxdataset_to_array(): + """Tests UxDataset.to_array(), ensuring `dim` and `name` kwargs work too.""" + uxds = UxDataset( + data_vars={ + "a": ("x", [1, 2]), + "b": ("x", [3, 4]), + "c": ("y", [-1, -2, -3, -4]), + }, + coords={"x": [10, 20], "y": [-10, -20, -30, -40]}, + attrs={"source": "testing"}, + ) + # first check basic functionality without worrying about kwargs + arr = uxds.to_array() + assert isinstance(arr, ux.UxDataArray) + assert arr.sizes == {"variable": 3, "x": 2, "y": 4} + assert arr.attrs["source"] == "testing" + for k, c in arr.coords.items(): + assert k in arr.coords and c.equals(arr.coords[k]) + # next check that dim & name args/kwargs work as expected. + arr1 = uxds.to_array('custom_dim') + assert arr1.sizes == {"custom_dim": 3, "x": 2, "y": 4} + assert arr1.name is None + arr2 = uxds.to_array(dim='custom_dim', name='custom_name') + assert arr2.name == 'custom_name' diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index fe0be4cbe..64023f058 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -3,7 +3,7 @@ import os import sys from html import escape -from typing import IO, Any, Mapping +from typing import IO, Any, Hashable, Mapping from warnings import warn import numpy as np @@ -621,11 +621,32 @@ def integrate(self, quadrature_rule="triangular", order=4): return integral - def to_array(self) -> UxDataArray: - """Override to make the result an instance of - ``uxarray.UxDataArray``.""" + def to_array( + self, + dim: Hashable = "variable", + name: Hashable = None, + ) -> UxDataArray: + """Convert this ``uxarray.UxDataset`` into a ``uxarray.UxDataArray``, + attaching this UxDataset's uxgrid to the result. + + Similarly to xarray.Dataset.to_array(), the data variables will be + broadcast against each other and stacked along the first axis of + the new array. All coordinates of this dataset will remain coordinates. + + Parameters + ---------- + dim : Hashable, optional + Name of the new dimension. Defaults to "variable" + name : Hashable or None, optional + Name of the new data array. + + Returns + ------- + UxDataArray + The ``uxarray.UxDataset`` represented as a ``uxarray.UxDataArray`` + """ - xarr = super().to_array() + xarr = super().to_array(dim=dim, name=name) return UxDataArray(xarr, uxgrid=self.uxgrid) def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset: