Skip to content

Commit

Permalink
refactored: second_y arg moved to TimeSeriesFigure.plot
Browse files Browse the repository at this point in the history
  • Loading branch information
alondmnt committed Sep 3, 2024
1 parent 558eac9 commit de30a05
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 32 deletions.
14 changes: 13 additions & 1 deletion nbs/15_timeseries_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
" n_axes: int = 1, \n",
" height: float = 1, \n",
" sharex: Union[str, int, plt.Axes] = None, \n",
" second_y: bool = False,\n",
" name: str = None, \n",
" ax: Union[str, int, plt.Axes] = None, \n",
" adjust_time: bool = True, \n",
Expand All @@ -97,6 +98,7 @@
" n_axes (int): The number of axes required. Default is 1.\n",
" height (float): The proportional height of the axes relative to a single unit axis.\n",
" sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.\n",
" second_y (bool): If True, plot will be done on a secondary y-axis in the plot. Default is False.s\n",
" name (str): Name or ID to assign to the axis.\n",
" ax (plt.Axes, str, int): Pre-existing axis (object, name, or index) or list of axes to plot on.\n",
" adjust_time (bool): Whether to adjust the time limits of all axes to match the data.\n",
Expand All @@ -110,9 +112,18 @@
" else:\n",
" ax = self.get_axes(ax, squeeze=True)\n",
"\n",
" if second_y:\n",
" ax.yaxis.grid(False)\n",
" ax = ax.twinx()\n",
"\n",
" plot_function(*args, ax=ax, **kwargs)\n",
" if adjust_time:\n",
" self.set_time_limits(None, None) # Adjust all axes to the same time limits\n",
" if second_y:\n",
" ax.yaxis.grid(False)\n",
" ax.yaxis.label.set_rotation(90)\n",
" ax.yaxis.label.set_ha('center')\n",
"\n",
"\n",
" return ax\n",
"\n",
Expand All @@ -121,7 +132,7 @@
" height: float = 1, \n",
" n_axes: int = 1, \n",
" sharex: Optional[Union[str, int, plt.Axes]] = None, \n",
" name: Optional[str] = None\n",
" name: Optional[str] = None,\n",
" ) -> Union[plt.Axes, Iterable[plt.Axes]]:\n",
" \"\"\"\n",
" Add one or more axes with a specific proportional height to the figure.\n",
Expand Down Expand Up @@ -195,6 +206,7 @@
"\n",
" Args:\n",
" ax: The axis object, index, name, or list of those to retrieve.\n",
" squeeze (bool): Whether to return a single axis object if only one is found.\n",
" \n",
" Returns:\n",
" Iterable[plt.Axes]: A list of axis objects.\n",
Expand Down
17 changes: 2 additions & 15 deletions nbs/16_diet_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@
" Plot a lollipop chart with pie charts representing nutrient composition for each meal.\n",
"\n",
" NOTE: The y-axis is scaled to match the units of the x-axis, to avoid distortion of the pie charts.\n",
" Use the `second_y` option to plot it with other y-axis data.\n",
" Due to scaling, if you intend to change `xlim` after plotting, you must also provide `date_range`.\n",
" Use the `second_y` of g.plot() option to plot it with other y-axis data.\n",
"\n",
" Args:\n",
" diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.\n",
Expand Down Expand Up @@ -221,11 +221,6 @@
" if ax is None:\n",
" fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)\n",
"\n",
" # Secondary y-axis\n",
" if second_y:\n",
" ax.yaxis.grid(False)\n",
" ax = ax.twinx()\n",
"\n",
" # Convert nutrients in mg to grams\n",
" for nut in grouped_nutrients['mg']:\n",
" df[nut.replace('_mg', '_g')] = df[nut] / 1000\n",
Expand Down Expand Up @@ -284,13 +279,7 @@
" ax.legend(handles=wedges, labels= pie_nuts, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)\n",
"\n",
" # Format x-axis to display dates properly\n",
" if second_y:\n",
" ha = 'center'\n",
" rotation = 90\n",
" else:\n",
" ha = 'right'\n",
" rotation=0\n",
" ax.set_ylabel(y.replace('_', ' ').title(), rotation=rotation, horizontalalignment=ha)\n",
" ax.set_ylabel(y.replace('_', ' ').title(), rotation=0, horizontalalignment='right')\n",
" ax.grid(True)\n",
"\n",
" # Set y-ticks and x-ticks\n",
Expand All @@ -299,8 +288,6 @@
" yticks = np.arange(0, ylim[1] / aspect_ratio, 100, dtype=int)\n",
" ax.set_yticks(yticks * aspect_ratio)\n",
" ax.set_yticklabels(yticks)\n",
" if second_y:\n",
" ax.yaxis.grid(False)\n",
"\n",
" format_xticks(ax, df[x])\n",
" if label is not None:\n",
Expand Down
17 changes: 2 additions & 15 deletions pheno_utils/diet_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def plot_nutrient_lollipop(
Plot a lollipop chart with pie charts representing nutrient composition for each meal.
NOTE: The y-axis is scaled to match the units of the x-axis, to avoid distortion of the pie charts.
Use the `second_y` option to plot it with other y-axis data.
Due to scaling, if you intend to change `xlim` after plotting, you must also provide `date_range`.
Use the `second_y` of g.plot() option to plot it with other y-axis data.
Args:
diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.
Expand Down Expand Up @@ -183,11 +183,6 @@ def plot_nutrient_lollipop(
if ax is None:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

# Secondary y-axis
if second_y:
ax.yaxis.grid(False)
ax = ax.twinx()

# Convert nutrients in mg to grams
for nut in grouped_nutrients['mg']:
df[nut.replace('_mg', '_g')] = df[nut] / 1000
Expand Down Expand Up @@ -246,13 +241,7 @@ def ytick_formatter(y, pos):
ax.legend(handles=wedges, labels= pie_nuts, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)

# Format x-axis to display dates properly
if second_y:
ha = 'center'
rotation = 90
else:
ha = 'right'
rotation=0
ax.set_ylabel(y.replace('_', ' ').title(), rotation=rotation, horizontalalignment=ha)
ax.set_ylabel(y.replace('_', ' ').title(), rotation=0, horizontalalignment='right')
ax.grid(True)

# Set y-ticks and x-ticks
Expand All @@ -261,8 +250,6 @@ def ytick_formatter(y, pos):
yticks = np.arange(0, ylim[1] / aspect_ratio, 100, dtype=int)
ax.set_yticks(yticks * aspect_ratio)
ax.set_yticklabels(yticks)
if second_y:
ax.yaxis.grid(False)

format_xticks(ax, df[x])
if label is not None:
Expand Down
14 changes: 13 additions & 1 deletion pheno_utils/timeseries_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def plot(
n_axes: int = 1,
height: float = 1,
sharex: Union[str, int, plt.Axes] = None,
second_y: bool = False,
name: str = None,
ax: Union[str, int, plt.Axes] = None,
adjust_time: bool = True,
Expand All @@ -58,6 +59,7 @@ def plot(
n_axes (int): The number of axes required. Default is 1.
height (float): The proportional height of the axes relative to a single unit axis.
sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.
second_y (bool): If True, plot will be done on a secondary y-axis in the plot. Default is False.s
name (str): Name or ID to assign to the axis.
ax (plt.Axes, str, int): Pre-existing axis (object, name, or index) or list of axes to plot on.
adjust_time (bool): Whether to adjust the time limits of all axes to match the data.
Expand All @@ -71,9 +73,18 @@ def plot(
else:
ax = self.get_axes(ax, squeeze=True)

if second_y:
ax.yaxis.grid(False)
ax = ax.twinx()

plot_function(*args, ax=ax, **kwargs)
if adjust_time:
self.set_time_limits(None, None) # Adjust all axes to the same time limits
if second_y:
ax.yaxis.grid(False)
ax.yaxis.label.set_rotation(90)
ax.yaxis.label.set_ha('center')


return ax

Expand All @@ -82,7 +93,7 @@ def add_axes(
height: float = 1,
n_axes: int = 1,
sharex: Optional[Union[str, int, plt.Axes]] = None,
name: Optional[str] = None
name: Optional[str] = None,
) -> Union[plt.Axes, Iterable[plt.Axes]]:
"""
Add one or more axes with a specific proportional height to the figure.
Expand Down Expand Up @@ -156,6 +167,7 @@ def get_axes(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Ax
Args:
ax: The axis object, index, name, or list of those to retrieve.
squeeze (bool): Whether to return a single axis object if only one is found.
Returns:
Iterable[plt.Axes]: A list of axis objects.
Expand Down

0 comments on commit de30a05

Please sign in to comment.