Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,28 @@ def test_uxdataset_init_from_xarray_dataset():
assert "a" in uxds.data_vars
assert "x" in uxds.coords
assert uxds.attrs["source"] == "testing"

def test_uxdataset_to_array():
"""Tests UxDataset.to_array(), ensuring `dim` and `name` kwargs work too."""
uxds = UxDataset(
data_vars={
"a": ("x", [1, 2]),
"b": ("x", [3, 4]),
"c": ("y", [-1, -2, -3, -4]),
},
coords={"x": [10, 20], "y": [-10, -20, -30, -40]},
attrs={"source": "testing"},
)
# first check basic functionality without worrying about kwargs
arr = uxds.to_array()
assert isinstance(arr, ux.UxDataArray)
assert arr.sizes == {"variable": 3, "x": 2, "y": 4}
assert arr.attrs["source"] == "testing"
for k, c in arr.coords.items():
assert k in arr.coords and c.equals(arr.coords[k])
# next check that dim & name args/kwargs work as expected.
arr1 = uxds.to_array('custom_dim')
assert arr1.sizes == {"custom_dim": 3, "x": 2, "y": 4}
assert arr1.name is None
arr2 = uxds.to_array(dim='custom_dim', name='custom_name')
assert arr2.name == 'custom_name'
31 changes: 26 additions & 5 deletions uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from html import escape
from typing import IO, Any, Mapping
from typing import IO, Any, Hashable, Mapping
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -621,11 +621,32 @@ def integrate(self, quadrature_rule="triangular", order=4):

return integral

def to_array(self) -> UxDataArray:
"""Override to make the result an instance of
``uxarray.UxDataArray``."""
def to_array(
self,
dim: Hashable = "variable",
name: Hashable = None,
) -> UxDataArray:
"""Convert this ``uxarray.UxDataset`` into a ``uxarray.UxDataArray``,
attaching this UxDataset's uxgrid to the result.

Similarly to xarray.Dataset.to_array(), the data variables will be
broadcast against each other and stacked along the first axis of
the new array. All coordinates of this dataset will remain coordinates.

Parameters
----------
dim : Hashable, optional
Name of the new dimension. Defaults to "variable"
name : Hashable or None, optional
Name of the new data array.

Returns
-------
UxDataArray
The ``uxarray.UxDataset`` represented as a ``uxarray.UxDataArray``
"""

xarr = super().to_array()
xarr = super().to_array(dim=dim, name=name)
return UxDataArray(xarr, uxgrid=self.uxgrid)

def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
Expand Down
Loading