From ec40b758254c7531a85c8a36a91ea8bdb1298a1f Mon Sep 17 00:00:00 2001 From: Daniel Kleine <53251018+d-kleine@users.noreply.github.com> Date: Fri, 8 Nov 2024 18:55:42 +0100 Subject: [PATCH] Improved `plot_splits` for time series splits (#1113) * fixed test_end_idx calculation * added legend title * fixed intersection of xticklabels for Sample index * added changes to changelog * Update CHANGELOG.md * removed redundancy * start x-axis at the origin of coordinates * fixed CV iteration labels * revert idx change * Update CHANGELOG.md * add groups legend * fixed cmap in groups legend --- docs/sources/CHANGELOG.md | 13 +++++++++++++ mlxtend/evaluate/time_series.py | 27 +++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 705a9844d..6b2bce57c 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -7,6 +7,19 @@ The CHANGELOG for the current development version is available at --- +### Version 0.23.3 (tbd) + +##### Downloads +... + +##### New Features and Enhancements + +Files updated: + - ['mlxtend.evaluate.time_series.plot_splits'](https://github.com/rasbt/mlxtend/blob/master/mlxtend/evaluate/time_series.py) + - Improved `plot_splits` for better visualization of time series splits + +##### Changes +... ### Version 0.23.2 (5 Nov 2024) diff --git a/mlxtend/evaluate/time_series.py b/mlxtend/evaluate/time_series.py index fb2d97169..58b5149ed 100644 --- a/mlxtend/evaluate/time_series.py +++ b/mlxtend/evaluate/time_series.py @@ -290,7 +290,7 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None s=marker_size, ) - yticklabels = list(range(n_splits)) + ["group"] + yticklabels = list(range(1, n_splits + 1)) + ["group"] ax.set( yticks=np.arange(n_splits + 1) + 0.5, yticklabels=yticklabels, @@ -299,15 +299,34 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None xlim=[-0.5, len(indices) - 0.5], ) - ax.legend( + legend_splits = ax.legend( [Patch(color=cmap_cv(0.2)), Patch(color=cmap_cv(0.8))], ["Training set", "Testing set"], - loc=(1.02, 0.8), + title="Data Splits", + loc="upper right", + fontsize=13, + ) + + ax.add_artist(legend_splits) + + group_labels = [f"{group}" for group in np.unique(groups)] + cmap = plt.cm.get_cmap("tab20", len(group_labels)) + + unique_patches = {} + for i, group in enumerate(np.unique(groups)): + unique_patches[group] = Patch(color=cmap(i), label=f"{group}") + + ax.legend( + handles=list(unique_patches.values()), + title="Groups", + loc="center left", + bbox_to_anchor=(1.02, 0.5), fontsize=13, ) ax.set_title("{}\n{}".format(type(cv).__name__, cv_args), fontsize=15) - ax.xaxis.set_major_locator(MaxNLocator(min_n_ticks=len(X), integer=True)) + ax.set_xlim(0, len(X)) + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.set_xlabel(xlabel="Sample index", fontsize=13) ax.set_ylabel(ylabel="CV iteration", fontsize=13) ax.tick_params(axis="both", which="major", labelsize=13)