diff --git a/optbinning/binning/binning_statistics.py b/optbinning/binning/binning_statistics.py index d12fd19..3260db3 100644 --- a/optbinning/binning/binning_statistics.py +++ b/optbinning/binning/binning_statistics.py @@ -608,7 +608,7 @@ def build(self, show_digits=2, add_totals=True): return df def plot(self, metric="woe", add_special=True, add_missing=True, - style="bin", show_bin_labels=False, savefig=None, figsize=None): + style="bin", show_bin_labels=False, savefig=None, figsize=None, save_kwargs=None): """Plot the binning table. Visualize the non-event and event count, and the Weight of Evidence or @@ -642,6 +642,9 @@ def plot(self, metric="woe", add_special=True, add_missing=True, figsize : tuple or None (default=None) Size of the plot. + + save_kwargs : dict or None (default=None) + Additional keyword arguments to be passed to `plt.savefig`. """ _check_is_built(self) @@ -863,7 +866,13 @@ def plot(self, metric="woe", add_special=True, add_missing=True, if not isinstance(savefig, str): raise TypeError("savefig must be a string path; got {}." .format(savefig)) - plt.savefig(savefig) + if save_kwargs is None: + save_kwargs = {} + else: + if not isinstance(save_kwargs, dict): + raise TypeError("save_kwargs must be a dictionary; got {}." + .format(save_kwargs)) + plt.savefig(savefig, **save_kwargs) plt.close() def analysis(self, pvalue_test="chi2", n_samples=100, print_output=True): diff --git a/optbinning/binning/multidimensional/binning_statistics_2d.py b/optbinning/binning/multidimensional/binning_statistics_2d.py index 37ae0b8..faf7cb3 100644 --- a/optbinning/binning/multidimensional/binning_statistics_2d.py +++ b/optbinning/binning/multidimensional/binning_statistics_2d.py @@ -338,7 +338,7 @@ def build(self, show_digits=2, show_bin_xy=False, add_totals=True): return df - def plot(self, metric="woe", savefig=None): + def plot(self, metric="woe", savefig=None, save_kwargs=None): """Plot the binning table. Visualize the Weight of Evidence or the event rate for each bin as a @@ -352,6 +352,9 @@ def plot(self, metric="woe", savefig=None): savefig : str or None (default=None) Path to save the plot figure. + + save_kwargs : dict or None (default=None) + Additional keyword arguments to be passed to `plt.savefig`. """ _check_is_built(self) @@ -437,7 +440,13 @@ def plot(self, metric="woe", savefig=None): if not isinstance(savefig, str): raise TypeError("savefig must be a string path; got {}." .format(savefig)) - plt.savefig(savefig) + if save_kwargs is None: + save_kwargs = {} + else: + if not isinstance(save_kwargs, dict): + raise TypeError("save_kwargs must be a dictionary; got {}." + .format(save_kwargs)) + plt.savefig(savefig, **save_kwargs) plt.close() def analysis(self, pvalue_test="chi2", n_samples=100, print_output=True): diff --git a/optbinning/binning/piecewise/binning_statistics.py b/optbinning/binning/piecewise/binning_statistics.py index e05109d..ad4328f 100644 --- a/optbinning/binning/piecewise/binning_statistics.py +++ b/optbinning/binning/piecewise/binning_statistics.py @@ -177,7 +177,7 @@ def build(self, show_digits=2, add_totals=True): return df - def plot(self, metric="woe", n_samples=10000, savefig=None): + def plot(self, metric="woe", n_samples=10000, savefig=None, save_kwargs=None): """Plot the binning table. Visualize the non-event and event count, and the predicted Weight of @@ -194,6 +194,9 @@ def plot(self, metric="woe", n_samples=10000, savefig=None): savefig : str or None (default=None) Path to save the plot figure. + + save_kwargs : dict or None (default=None) + Additional keyword arguments to be passed to `plt.savefig`. """ _check_is_built(self) @@ -258,7 +261,13 @@ def plot(self, metric="woe", n_samples=10000, savefig=None): if not isinstance(savefig, str): raise TypeError("savefig must be a string path; got {}." .format(savefig)) - plt.savefig(savefig) + if save_kwargs is None: + save_kwargs = {} + else: + if not isinstance(save_kwargs, dict): + raise TypeError("save_kwargs must be a dictionary; got {}." + .format(save_kwargs)) + plt.savefig(savefig, **save_kwargs) plt.close() def analysis(self, pvalue_test="chi2", n_samples=100, print_output=True):