Skip to content

Commit

Permalink
Merge pull request #678 from hayesla/pow
Browse files Browse the repository at this point in the history
Adding __pow__ to NDCube
  • Loading branch information
DanRyanIrish authored Apr 23, 2024
2 parents 3a62c59 + dc6fe12 commit 1db2816
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog/678.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable `~ndcube.NDCube` to be raised to a power.
6 changes: 3 additions & 3 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def skycoord_2d_lut(shape):
return SkyCoord(*data, unit=u.deg)


def data_nd(shape):
def data_nd(shape, dtype=float):
nelem = np.prod(shape)
return np.arange(nelem).reshape(shape)
return np.arange(nelem, dtype=dtype).reshape(shape)


def time_extra_coords(shape, axis, base):
Expand Down Expand Up @@ -330,7 +330,7 @@ def ndcube_4d_ln_l_t_lt(wcs_4d_lt_t_l_ln):
def ndcube_4d_ln_lt_l_t(wcs_4d_t_l_lt_ln):
shape = (5, 8, 10, 12)
wcs_4d_t_l_lt_ln.array_shape = shape
data_cube = data_nd(shape)
data_cube = data_nd(shape, dtype=int)
return NDCube(data_cube, wcs=wcs_4d_t_l_lt_ln)


Expand Down
22 changes: 22 additions & 0 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,28 @@ def __rmul__(self, value):
def __truediv__(self, value):
return self.__mul__(1/value)

def __pow__(self, value):
new_data = self.data ** value
new_unit = self.unit if self.unit is None else self.unit ** value
new_uncertainty = self.uncertainty

if self.uncertainty is not None:
try:
new_uncertainty = new_uncertainty.propagate(np.power, self, self.data ** value, correlation=1)
except ValueError as e:
if "unsupported operation" in e.args[0]:
new_uncertainty = None
warnings.warn(f"{type(self.uncertainty)} does not support propagation of uncertainties for power. Setting uncertainties to None.",
UserWarning, stacklevel=2)
elif "does not support uncertainty propagation" in e.args[0]:
new_uncertainty = None
warnings.warn(f"{e.args[0]} Setting uncertainties to None.",
UserWarning, stacklevel=2)
else:
raise e

return self._new_instance_from_op(new_data, new_unit, new_uncertainty)

def to(self, new_unit, **kwargs):
"""Convert instance to another unit.
Expand Down
27 changes: 27 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,33 @@ def test_cube_arithmetic_multiply_notimplementederror(ndcube_2d_ln_lt_units):
_ = ndcube_2d_ln_lt_units * ndcube_2d_ln_lt_units



@pytest.mark.parametrize('power', [2, -2, 10, 0.5])
def test_cube_arithmetic_power(ndcube_2d_ln_lt, power):
cube_quantity = u.Quantity(ndcube_2d_ln_lt.data, ndcube_2d_ln_lt.unit)
with np.errstate(divide='ignore'):
new_cube = ndcube_2d_ln_lt ** power
check_arithmetic_value_and_units(new_cube, cube_quantity**power)


@pytest.mark.parametrize('power', [2, -2, 10, 0.5])
def test_cube_arithmetic_power_unknown_uncertainty(ndcube_4d_unit_uncertainty, power):
cube_quantity = u.Quantity(ndcube_4d_unit_uncertainty.data, ndcube_4d_unit_uncertainty.unit)
with pytest.warns(UserWarning, match="UnknownUncertainty does not support uncertainty propagation with correlation. Setting uncertainties to None."):
with np.errstate(divide='ignore'):
new_cube = ndcube_4d_unit_uncertainty ** power
check_arithmetic_value_and_units(new_cube, cube_quantity**power)


@pytest.mark.parametrize('power', [2, -2, 10, 0.5])
def test_cube_arithmetic_power_std_uncertainty(ndcube_2d_ln_lt_uncert, power):
cube_quantity = u.Quantity(ndcube_2d_ln_lt_uncert.data, ndcube_2d_ln_lt_uncert.unit)
with pytest.warns(UserWarning, match=r"<class 'astropy.nddata.nduncertainty.StdDevUncertainty'> does not support propagation of uncertainties for power. Setting uncertainties to None."):
with np.errstate(divide='ignore'):
new_cube = ndcube_2d_ln_lt_uncert ** power
check_arithmetic_value_and_units(new_cube, cube_quantity**power)


@pytest.mark.parametrize('new_unit', [u.mJ, 'mJ'])
def test_to(ndcube_1d_l, new_unit):
cube = ndcube_1d_l
Expand Down

0 comments on commit 1db2816

Please sign in to comment.