Skip to content

Commit

Permalink
Merge pull request #563 from meeseeksmachine/auto-backport-of-pr-534-…
Browse files Browse the repository at this point in the history
…on-2.0

Backport PR #534 on branch 2.0 (Add functools call to copy `plotter.plot` doc and signature to `plot`)
  • Loading branch information
Cadair authored Sep 23, 2022
2 parents 99b94d4 + d64709b commit 3fcbdbb
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog/534.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Dynamically copy docstring and function signature from `NDCube.plotter.plot()` to `NDCube.plot()`.
8 changes: 8 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from inspect import signature
from textwrap import dedent

import astropy.units as u
Expand Down Expand Up @@ -850,3 +851,10 @@ def test_reproject_exact_incompatible_wcs(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln,
with pytest.raises(ValueError):
_ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm='exact',
shape_out=(5, 10, 12, 8))


def test_plot_docstring():
cube = NDCube([], astropy.wcs.WCS())

assert cube.plot.__doc__ == cube.plotter.plot.__doc__
assert signature(cube.plot) == signature(cube.plotter.plot)
59 changes: 42 additions & 17 deletions ndcube/visualization/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

MISSING_MATPLOTLIB_ERROR_MSG = ("Matplotlib can not be imported, so the default plotting "
"functionality is disabled. Please install matplotlib.")

Expand All @@ -18,35 +20,58 @@ def __set_name__(self, owner, name):
# attribute name is the name of the attribute on the parent class where
# the data is stored.
self._attribute_name = f"_{name}"
plotter = self._resolve_default_type(raise_error=False)
if plotter is not None and hasattr(plotter, "plot"):
functools.update_wrapper(owner.plot, plotter.plot)

def __get__(self, obj, objtype=None):
if obj is None:
return

if getattr(obj, self._attribute_name, None) is None:

# We special case the default MatplotlibPlotter so that we can
# delay the import of matplotlib until the plotter is first
# accessed.
def _resolve_default_type(self, raise_error=True):
# We special case the default MatplotlibPlotter so that we can
# delay the import of matplotlib until the plotter is first
# accessed.
if self._default_type in ("mpl_plotter", "mpl_sequence_plotter"):
try:
if self._default_type == "mpl_plotter":
from ndcube.visualization.mpl_plotter import MatplotlibPlotter
self.__set__(obj, MatplotlibPlotter)
return MatplotlibPlotter
elif self._default_type == "mpl_sequence_plotter":
from ndcube.visualization.mpl_sequence_plotter import MatplotlibSequencePlotter
self.__set__(obj, MatplotlibSequencePlotter)
elif self._default_type is not None:
self.__set__(obj, self._default_type)
else:
# If we have no default type then just return None
return
return MatplotlibSequencePlotter
except ImportError as e:
raise ImportError(MISSING_MATPLOTLIB_ERROR_MSG) from e
if raise_error:
raise ImportError(MISSING_MATPLOTLIB_ERROR_MSG) from e

elif self._default_type is not None:
return self._default_type

# If we have no default type then just return None
else:
return

def __get__(self, obj, objtype=None):
if obj is None:
return

if getattr(obj, self._attribute_name, None) is None:
plotter_type = self._resolve_default_type()
if plotter_type is None:
return

self.__set__(obj, plotter_type)

return getattr(obj, self._attribute_name)

def __set__(self, obj, value):
if not isinstance(value, type):
raise TypeError(
"Plotter attribute can only be set with an uninitialised plotter object.")

setattr(obj, self._attribute_name, value(obj))
# here obj is the ndcube object and value is the plotter type
# Get the instantiated plotter we just assigned to the ndcube
plotter = getattr(obj, self._attribute_name)
# If the plotter has a plot object then update the signature and
# docstring of the cubes `plot()` method to match
# Docstrings of methods aren't writeable so we copy to the underlying
# function object instead
if hasattr(plotter, "plot"):
functools.update_wrapper(obj.plot.__func__, plotter.plot.__func__)

0 comments on commit 3fcbdbb

Please sign in to comment.