diff --git a/changelog/534.bugfix.rst b/changelog/534.bugfix.rst new file mode 100644 index 000000000..cbb7802ca --- /dev/null +++ b/changelog/534.bugfix.rst @@ -0,0 +1 @@ +Dynamically copy docstring and function signature from `NDCube.plotter.plot()` to `NDCube.plot()`. diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index 4efa8dd0c..a80f6b011 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -1,3 +1,4 @@ +from inspect import signature from textwrap import dedent import astropy.units as u @@ -850,3 +851,10 @@ def test_reproject_exact_incompatible_wcs(ndcube_4d_ln_l_t_lt, wcs_4d_lt_t_l_ln, with pytest.raises(ValueError): _ = ndcube_4d_ln_l_t_lt.reproject_to(wcs_4d_lt_t_l_ln, algorithm='exact', shape_out=(5, 10, 12, 8)) + + +def test_plot_docstring(): + cube = NDCube([], astropy.wcs.WCS()) + + assert cube.plot.__doc__ == cube.plotter.plot.__doc__ + assert signature(cube.plot) == signature(cube.plotter.plot) diff --git a/ndcube/visualization/descriptor.py b/ndcube/visualization/descriptor.py index f45a77281..b4e58ce0d 100644 --- a/ndcube/visualization/descriptor.py +++ b/ndcube/visualization/descriptor.py @@ -1,3 +1,5 @@ +import functools + MISSING_MATPLOTLIB_ERROR_MSG = ("Matplotlib can not be imported, so the default plotting " "functionality is disabled. Please install matplotlib.") @@ -18,30 +20,43 @@ def __set_name__(self, owner, name): # attribute name is the name of the attribute on the parent class where # the data is stored. self._attribute_name = f"_{name}" + plotter = self._resolve_default_type(raise_error=False) + if plotter is not None and hasattr(plotter, "plot"): + functools.update_wrapper(owner.plot, plotter.plot) - def __get__(self, obj, objtype=None): - if obj is None: - return - - if getattr(obj, self._attribute_name, None) is None: - - # We special case the default MatplotlibPlotter so that we can - # delay the import of matplotlib until the plotter is first - # accessed. + def _resolve_default_type(self, raise_error=True): + # We special case the default MatplotlibPlotter so that we can + # delay the import of matplotlib until the plotter is first + # accessed. + if self._default_type in ("mpl_plotter", "mpl_sequence_plotter"): try: if self._default_type == "mpl_plotter": from ndcube.visualization.mpl_plotter import MatplotlibPlotter - self.__set__(obj, MatplotlibPlotter) + return MatplotlibPlotter elif self._default_type == "mpl_sequence_plotter": from ndcube.visualization.mpl_sequence_plotter import MatplotlibSequencePlotter - self.__set__(obj, MatplotlibSequencePlotter) - elif self._default_type is not None: - self.__set__(obj, self._default_type) - else: - # If we have no default type then just return None - return + return MatplotlibSequencePlotter except ImportError as e: - raise ImportError(MISSING_MATPLOTLIB_ERROR_MSG) from e + if raise_error: + raise ImportError(MISSING_MATPLOTLIB_ERROR_MSG) from e + + elif self._default_type is not None: + return self._default_type + + # If we have no default type then just return None + else: + return + + def __get__(self, obj, objtype=None): + if obj is None: + return + + if getattr(obj, self._attribute_name, None) is None: + plotter_type = self._resolve_default_type() + if plotter_type is None: + return + + self.__set__(obj, plotter_type) return getattr(obj, self._attribute_name) @@ -49,4 +64,14 @@ def __set__(self, obj, value): if not isinstance(value, type): raise TypeError( "Plotter attribute can only be set with an uninitialised plotter object.") + setattr(obj, self._attribute_name, value(obj)) + # here obj is the ndcube object and value is the plotter type + # Get the instantiated plotter we just assigned to the ndcube + plotter = getattr(obj, self._attribute_name) + # If the plotter has a plot object then update the signature and + # docstring of the cubes `plot()` method to match + # Docstrings of methods aren't writeable so we copy to the underlying + # function object instead + if hasattr(plotter, "plot"): + functools.update_wrapper(obj.plot.__func__, plotter.plot.__func__)