Skip to content

Commit

Permalink
Merge pull request #136 from MOchiara/chiara-patch-37
Browse files Browse the repository at this point in the history
[FEAT] Added section to demo showing resize/change labels etc.
  • Loading branch information
MOchiara authored Nov 22, 2024
2 parents 0a43fc5 + e2fbdbb commit f32ff9c
Show file tree
Hide file tree
Showing 2 changed files with 276 additions and 77 deletions.
92 changes: 46 additions & 46 deletions glidertest/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def plot_glider_track(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple(
gl.top_labels = False
gl.right_labels = False
plt.show()
ax = plt.gca()

return fig, ax

Expand Down Expand Up @@ -622,7 +623,8 @@ def plot_grid_spacing(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple(

return fig, ax

def plot_ts(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple({plt.Figure, plt.Axes}):

def plot_ts(ds: xr.Dataset, axs: plt.Axes = None, **kw: dict) -> tuple({plt.Figure, plt.Axes}):
"""
This function plots histograms of temperature and salinity values (middle 95%), and a 2D histogram of salinity and temperature with density contours.
Expand All @@ -644,23 +646,19 @@ def plot_ts(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple({plt.Figur
"""
utilities._check_necessary_variables(ds, ['DEPTH', 'LONGITUDE', 'LATITUDE', 'PSAL', 'TEMP'])
with plt.style.context(glidertest_style_file):
if ax is None:
fig= plt.figure()
if axs is None:
fig, ax = plt.subplots(2, 3)
plt.subplots_adjust(wspace=0.03, hspace=0.03)
force_plot = True
axs = ax.flatten()
else:
fig = plt.gcf()
plt.subplots_adjust(wspace=0.03, hspace=0.03)
force_plot = False

axs[3].set_visible(False)
axs[5].set_visible(False)
num_bins = 30

gs = fig.add_gridspec(2, 2, width_ratios=(1, 1), height_ratios=(1, 1),
left=0.1, right=0.9, bottom=0.1, top=0.9,
wspace=0.03, hspace=0.03)
# Create the Axes
ax = fig.add_subplot(gs[0, 1])
ax_histy = fig.add_subplot(gs[0, 0], sharey=ax)
ax_histx = fig.add_subplot(gs[1, 1], sharex=ax)

temp_orig = ds.TEMP.values
sal_orig = ds.PSAL.values

Expand All @@ -687,51 +685,52 @@ def plot_ts(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple({plt.Figur
zi = gsw.sigma0(xi, yi)

# Temperature histogram
ax_histy.hist(CT_filtered, bins=num_bins, orientation="horizontal", **kw)
ax_histy.set_ylabel('Conservative Temperature (°C)')
ax_histy.set_xlabel('Frequency', rotation="horizontal")
ax_histy.invert_xaxis()
axs[0].hist(CT_filtered, bins=num_bins, orientation="horizontal", **kw)
axs[0].set_ylabel('Conservative Temperature (°C)')
axs[0].set_xlabel('Frequency', rotation="horizontal")
axs[0].invert_xaxis()

# Salinity histogram
ax_histx.hist(SA_filtered, bins=num_bins, **kw)
ax_histx.set_xlabel('Absolute Salinity ( )')
ax_histx.set_ylabel('Frequency', rotation="vertical")
ax_histx.yaxis.set_label_position("right")
ax_histx.yaxis.tick_right()
ax_histx.invert_yaxis()

for tick in ax.xaxis.get_major_ticks():
axs[4].hist(SA_filtered, bins=num_bins, **kw)
axs[4].set_xlabel('Absolute Salinity ( )')
axs[4].set_ylabel('Frequency', rotation="vertical")
axs[4].yaxis.set_label_position("right")
axs[4].yaxis.tick_right()
axs[4].invert_yaxis()

for tick in axs[1].xaxis.get_major_ticks():
tick.tick1line.set_visible(False)
tick.tick2line.set_visible(False)
tick.label1.set_visible(False)
tick.label2.set_visible(False)
for tick in ax.yaxis.get_major_ticks():
for tick in axs[1].yaxis.get_major_ticks():
tick.tick1line.set_visible(False)
tick.tick2line.set_visible(False)
tick.label1.set_visible(False)
tick.label2.set_visible(False)

# 2-d T-S histogram
h = ax.hist2d(SA_filtered, CT_filtered, bins=num_bins, cmap='viridis', norm=mcolors.LogNorm(), **kw)
ax.contour(xi, yi, zi, colors='black', alpha=0.5, linewidths=0.5)
ax.clabel(ax.contour(xi, yi, zi, colors='black', alpha=0.5, linewidths=0.5), inline=True)
cb_ax = fig.add_axes([.91, 0.52, .04, .37])
cbar = fig.colorbar(h[3], orientation='vertical', cax=cb_ax)
h = axs[1].hist2d(SA_filtered, CT_filtered, bins=num_bins, cmap='viridis', norm=mcolors.LogNorm(), **kw)
axs[1].contour(xi, yi, zi, colors='black', alpha=0.5, linewidths=0.5)
axs[1].clabel(axs[1].contour(xi, yi, zi, colors='black', alpha=0.5, linewidths=0.5), inline=True)
cbar = fig.colorbar(h[3], orientation='vertical',cax=axs[2])
cbar.set_label('Log Counts')
ax.set_title('2D Histogram \n (Log Scale)')
axs[1].set_title('2D Histogram \n (Log Scale)')
#Resize axs[2] as colorbar
box2 = axs[2].get_position()
axs[2].set_position([box2.x0, box2.y0, box2.width / 6, box2.height])
# Set x-limits based on salinity plot and y-limits based on temperature plot
ax.set_xlim(ax_histx.get_xlim())
ax.set_ylim(ax_histy.get_ylim())
axs[1].set_xlim(axs[4].get_xlim())
axs[1].set_ylim(axs[0].get_ylim())

# Set font sizes for all annotations
for axes in [ax, ax_histx, ax_histy]:
for axes in [axs[0],axs[1],axs[4]]:
axes.tick_params(axis='both', which='major')
axes.grid(True, which='both', linestyle='--', linewidth=0.5, color='grey')
if force_plot:
plt.show()
all_ax = [ax, ax_histy, ax_histx]
return fig, all_ax

all_ax = axs
return fig, all_ax
def plot_vertical_speeds_with_histograms(ds, start_prof=None, end_prof=None):
"""
Plot vertical speeds with histograms for diagnostic purposes.
Expand Down Expand Up @@ -948,6 +947,7 @@ def plot_combined_velocity_profiles(ds_out_dives: xr.Dataset, ds_out_climbs: xr.
plt.show()
return fig, ax


def plot_hysteresis(ds, var='DOXY', v_res=1, perct_err=2, ax=None):
"""
This function creates 4 plots which can help the user visualize any possible hysteresis
Expand Down Expand Up @@ -975,17 +975,17 @@ def plot_hysteresis(ds, var='DOXY', v_res=1, perct_err=2, ax=None):
with plt.style.context(glidertest_style_file):
if ax is None:
fig = plt.figure()
ax = [plt.subplot(3, 3, 1), plt.subplot(3, 3, 2), plt.subplot(3, 3, 3), plt.subplot(3, 1, 2)]
force_plot = True
else:
fig = plt.gcf()
force_plot = False
gs = fig.add_gridspec(2, 3)
# open axes/subplots
ax = []
ax.append(fig.add_subplot(gs[0, 0]))
ax.append(fig.add_subplot(gs[0, 1]))
ax.append(fig.add_subplot(gs[0, 2]))
ax.append(fig.add_subplot(gs[1, :]))
if len(ax) == 1:
fig = plt.gcf()
ax = [plt.subplot(3, 3, 1), plt.subplot(3, 3, 2), plt.subplot(3, 3, 3), plt.subplot(3, 1, 2)]
force_plot = False
else:
fig = plt.gcf()
force_plot = False

ax[0].plot(df.climb, df.depth, label='Dive')
ax[0].plot(df.dive, df.depth, label='Climb')
ax[0].legend()
Expand Down
Loading

0 comments on commit f32ff9c

Please sign in to comment.