Skip to content

Commit

Permalink
Try better user flexibility when creating figures (#27)
Browse files Browse the repository at this point in the history
* Try better user flexibility when creating figures

* Adjusted function - removed fignum

* Improved function

* Missing xarray package - fixed

* Modified all plotting functions for consistency

---------

Co-authored-by: ChiaraMonforte <[email protected]>
  • Loading branch information
MOchiara and ChiaraMonforte authored Jul 16, 2024
1 parent 06cd63a commit 9fca3c1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 28 deletions.
52 changes: 35 additions & 17 deletions glidertest/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skyfield import api
from skyfield import almanac
from tqdm import tqdm
import xarray as xr


def grid2d(x, y, v, xi=1, yi=1):
Expand Down Expand Up @@ -73,7 +74,11 @@ def updown_bias(ds, var='PSAL', v_res=1):
df = pd.DataFrame(data={'dc' : dc, 'cd' : cd,'depth': depthG[0,:]})
return df

def plot_updown_bias(df, ax, xlabel='Temperature [C]'):
def plot_updown_bias(df: pd.DataFrame,ax: plt.Axes = None, xlabel='Temperature [C]', **kw: dict,)-> tuple({plt.Figure, plt.Axes}):
if ax is None:
fig, ax = plt.subplots( figsize=(5,5))
else:
fig = plt.gcf()
"""
This function can be used to plot the up and downcast differences computed with the updown_bias function
Expand All @@ -96,7 +101,7 @@ def plot_updown_bias(df, ax, xlabel='Temperature [C]'):
ax.set_xlabel(xlabel)
ax.set_ylim(df.depth.max() + 10, -df.depth.max()/30)
ax.grid()
return ax
return fig, ax


def find_cline(var, depth_array):
Expand Down Expand Up @@ -427,7 +432,7 @@ def day_night_avg(ds, sel_var='CHLA', start_time='2024-04-18', end_time='2024-04
night_av.loc[np.where(night_av.batch==i)[0],'date'] = date_val
return day_av, night_av

def plot_daynight_avg(day, night, ax, sel_day='2023-09-09', xlabel='Chlorophyll [mg m-3]'):
def plot_daynight_avg(day: pd.DataFrame, night: pd.DataFrame, ax: plt.Axes = None, sel_day='2023-09-09', xlabel='Chlorophyll [mg m-3]', **kw: dict,) -> tuple({plt.Figure, plt.Axes}):
"""
This function can be used to plot the day and night averages computed with the day_night_avg function
Expand All @@ -444,34 +449,42 @@ def plot_daynight_avg(day, night, ax, sel_day='2023-09-09', xlabel='Chlorophyll
A line plot comparing the day and night average over depth for the selcted day
"""
if ax is None:
fig, ax = plt.subplots(figsize=(5,5))
else:
fig = plt.gcf()
ax.plot(night.where(night.date==sel_day).dropna().dat, night.where(night.date==sel_day).dropna().depth, label='Night time average')
ax.plot(day.where(day.date==sel_day).dropna().dat, day.where(day.date==sel_day).dropna().depth, label='Daytime average')
ax.legend()
ax.invert_yaxis()
ax.grid()
ax.set(xlabel= xlabel, ylabel='Depth [m]')
ax.set_title(sel_day)
return ax
return fig, ax

def plot_section_with_srss(ds, ax, sel_var='TEMP',start_time = '2023-09-06', end_time = '2023-09-10', ylim=45):
def plot_section_with_srss(ds: xr.Dataset, sel_var: str, ax: plt.Axes = None, start_time = '2023-09-06', end_time = '2023-09-10', ylim=45, **kw: dict,) -> tuple({plt.Figure, plt.Axes}):
"""
This function can be used to plot sections for any variable with the sunrise and sunset plotted over
Parameters
----------
ax: axis to plot the data
ds: xarray on OG1 format containing at least time, depth, latitude, longitude and the selected variable.
Data should not be gridded.
sel_var: seleted variable to plot
ax: axis to plot the data
start_time: Start date of the data selection. As missions can be long and came make it hard to visualise NPQ effetc,
end_time: End date of the data selection. As missions can be long and came make it hard to visualise NPQ effetc,
ylim: specified limit for the maximum y axis value. The minumum is computed as ylim/30
Returns
-------
A section showing the variability of the selcted data over time and depth
"""
if ax is None:
fig, ax = plt.subplots(figsize=(5,5))
else:
fig = plt.gcf()

if not "TIME" in ds.indexes.keys():
ds = ds.set_xindex('TIME')
ds_sel = ds.sel(TIME=slice(start_time, end_time))
Expand All @@ -486,17 +499,22 @@ def plot_section_with_srss(ds, ax, sel_var='TEMP',start_time = '2023-09-06', end
ax.axvline(np.unique(m), c='orange')
ax.set_ylabel('Depth [m]')
plt.colorbar(c, label=f'{sel_var} [{ds[sel_var].units}]')
return ax
return fig, ax

def check_temporal_drift(ds, ax1, ax2, var='DOXY'):
ax1.scatter(mdates.date2num(ds.TIME),ds[var], s=10)
ax1.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax1.set(ylim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel=var)
def check_temporal_drift(ds: xr.Dataset, var: str,ax: plt.Axes = None, **kw: dict,)-> tuple({plt.Figure, plt.Axes}):
if ax is None:
fig, ax = plt.subplots(1, 2, figsize=(14, 6))
else:
fig = plt.gcf()

ax[0].scatter(mdates.date2num(ds.TIME),ds[var], s=10)
ax[0].xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax[0].set(ylim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel=var)

c=ax2.scatter(ds[var],ds.DEPTH,c=mdates.date2num(ds.TIME), s=10)
ax2.set(xlim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)),ylabel='Depth (m)', xlabel=var)
ax2.invert_yaxis()
c=ax[1].scatter(ds[var],ds.DEPTH,c=mdates.date2num(ds.TIME), s=10)
ax[1].set(xlim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)),ylabel='Depth (m)', xlabel=var)
ax[1].invert_yaxis()

[a.grid() for a in [ax1, ax2]]
[a.grid() for a in ax]
plt.colorbar(c, format=DateFormatter('%b %d'))
return ax1, ax2
return fig, ax
21 changes: 13 additions & 8 deletions notebooks/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "7c437da6-c3b3-4c48-b272-ee5b8ac27f69",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -142,8 +142,7 @@
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n",
"tools.check_temporal_drift(ds,ax[0], ax[1], var='CHLA')"
"tools.check_temporal_drift(ds, var='CHLA')"
]
},
{
Expand All @@ -156,7 +155,7 @@
"# Let's visually check a section of chlorphyll and see if we observe any NPQ\n",
"fig, ax = plt.subplots(1, 1, figsize=(15, 5))\n",
"\n",
"tools.plot_section_with_srss(ds, ax, sel_var='CHLA',start_time = '2023-09-06', end_time = '2023-09-10', ylim=35)"
"tools.plot_section_with_srss(ds, 'CHLA', ax, start_time = '2023-09-06', end_time = '2023-09-10', ylim=35)"
]
},
{
Expand Down Expand Up @@ -278,8 +277,7 @@
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n",
"tools.check_temporal_drift(ds, ax[0], ax[1], var='BBP700')"
"tools.check_temporal_drift(ds, var='BBP700')"
]
},
{
Expand All @@ -301,9 +299,16 @@
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n",
"tools.check_temporal_drift(ds, ax[0], ax[1], var='DOXY')"
"tools.check_temporal_drift(ds, var='DOXY')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bc6bf18-ab77-4446-a8c4-eaecb9460b59",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
6 changes: 3 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_quench_sequence():
if not "TIME" in ds.indexes.keys():
ds = ds.set_xindex('TIME')
fig, ax = plt.subplots()
tools.plot_section_with_srss(ds, ax, sel_var='CHLA',start_time = '2023-09-06', end_time = '2023-09-10', ylim=35)
tools.plot_section_with_srss(ds, 'CHLA', ax,start_time = '2023-09-06', end_time = '2023-09-10', ylim=35)
dayT, nightT = tools.day_night_avg(ds, sel_var='TEMP',start_time = '2023-09-06', end_time = '2023-09-10')
fig, ax = plt.subplots()
tools.plot_daynight_avg( dayT, nightT,ax,sel_day='2023-09-08', xlabel='Temperature [C]')
Expand All @@ -35,6 +35,6 @@ def test_temporal_drift():
ds = fetchers.load_sample_dataset()
fig, ax = plt.subplots(1, 2)
if 'DOXY' in ds.variables:
tools.check_temporal_drift(ds,ax[0], ax[1], var='DOXY')
tools.check_temporal_drift(ds,'DOXY', ax)
if 'CHLA' in ds.variables:
tools.check_temporal_drift(ds,ax[0], ax[1], var='CHLA')
tools.check_temporal_drift(ds,'CHLA')

0 comments on commit 9fca3c1

Please sign in to comment.