Skip to content

Commit 679f919

Browse files
Generalize check of split_mode name (always lowercase)
1 parent 33e152c commit 679f919

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

kllr/regression_plotting.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1281,24 +1281,24 @@ def Plot_Cov_Corr_Matrix_Split(df, xlabel, ylabels, split_label, split_bins=[],
12811281

12821282
# Choose bin edges for binning data
12831283
if (isinstance(split_bins, int)):
1284-
if split_mode == 'Data':
1284+
if split_mode.lower() == 'data':
12851285
split_bins = [np.percentile(split_data, float(i / split_bins) * 100) for i in
12861286
range(0, split_bins + 1)]
1287-
elif split_mode == 'Residuals':
1287+
elif split_mode.lower() == 'residuals':
12881288
split_res = lm.residuals(x_data, split_data, xrange=None, bins = bins, nBootstrap = 1)
12891289
split_bins = [np.percentile(split_res, float(i / split_bins) * 100) for i in
12901290
range(0, split_bins + 1)]
1291-
elif isinstance(split_bins, (np.ndarray, list, tuple)) & (split_mode == 'Residuals'):
1291+
elif isinstance(split_bins, (np.ndarray, list, tuple)) & (split_mode.lower() == 'residuals'):
12921292
split_res = lm.residuals(x_data, split_data, xrange=None, bins = bins, nBootstrap = 1)
12931293

12941294
# Define Output_Data variable to store all computed data that is then plotted
12951295
output_Data = {'Bin' + str(i): {} for i in range(len(split_bins) - 1)}
12961296

12971297
for k in range(len(split_bins) - 1):
12981298

1299-
if split_mode == 'Data':
1299+
if split_mode.lower() == 'data':
13001300
split_Mask = (split_data <= split_bins[k + 1]) & (split_data > split_bins[k])
1301-
elif split_mode == 'Residuals':
1301+
elif split_mode.lower() == 'residuals':
13021302
split_Mask = (split_res <= split_bins[k + 1]) & (split_res > split_bins[k])
13031303

13041304
# Edge case for y_err
@@ -1321,9 +1321,9 @@ def Plot_Cov_Corr_Matrix_Split(df, xlabel, ylabels, split_label, split_bins=[],
13211321

13221322
if nBootstrap == 1: cov_corr = cov_corr[None, :]
13231323

1324-
if split_mode == 'Data':
1324+
if split_mode.lower() == 'data':
13251325
label = r'$%0.2f <$ %s $< %0.2f$' % (split_bins[k], labels[-1], split_bins[k + 1])
1326-
elif split_mode == 'Residuals':
1326+
elif split_mode.lower() == 'residuals':
13271327
label = r'$%0.2f < {\rm res}($%s$) < %0.2f$' % (split_bins[k], labels[-1], split_bins[k + 1])
13281328

13291329
if xlog:

0 commit comments

Comments
 (0)