Skip to content

Commit

Permalink
Improved plot_splits for time series splits (#1113)
Browse files Browse the repository at this point in the history
* fixed test_end_idx calculation

* added legend title

* fixed intersection of xticklabels for Sample index

* added changes to changelog

* Update CHANGELOG.md

* removed redundancy

* start x-axis at the origin of coordinates

* fixed CV iteration labels

* revert idx change

* Update CHANGELOG.md

* add groups legend

* fixed cmap in groups legend
  • Loading branch information
d-kleine authored Nov 8, 2024
1 parent c229178 commit ec40b75
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
13 changes: 13 additions & 0 deletions docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ The CHANGELOG for the current development version is available at

---

### Version 0.23.3 (tbd)

##### Downloads
...

##### New Features and Enhancements

Files updated:
- ['mlxtend.evaluate.time_series.plot_splits'](https://github.com/rasbt/mlxtend/blob/master/mlxtend/evaluate/time_series.py)
- Improved `plot_splits` for better visualization of time series splits

##### Changes
...

### Version 0.23.2 (5 Nov 2024)

Expand Down
27 changes: 23 additions & 4 deletions mlxtend/evaluate/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
s=marker_size,
)

yticklabels = list(range(n_splits)) + ["group"]
yticklabels = list(range(1, n_splits + 1)) + ["group"]
ax.set(
yticks=np.arange(n_splits + 1) + 0.5,
yticklabels=yticklabels,
Expand All @@ -299,15 +299,34 @@ def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None
xlim=[-0.5, len(indices) - 0.5],
)

ax.legend(
legend_splits = ax.legend(
[Patch(color=cmap_cv(0.2)), Patch(color=cmap_cv(0.8))],
["Training set", "Testing set"],
loc=(1.02, 0.8),
title="Data Splits",
loc="upper right",
fontsize=13,
)

ax.add_artist(legend_splits)

group_labels = [f"{group}" for group in np.unique(groups)]
cmap = plt.cm.get_cmap("tab20", len(group_labels))

unique_patches = {}
for i, group in enumerate(np.unique(groups)):
unique_patches[group] = Patch(color=cmap(i), label=f"{group}")

ax.legend(
handles=list(unique_patches.values()),
title="Groups",
loc="center left",
bbox_to_anchor=(1.02, 0.5),
fontsize=13,
)

ax.set_title("{}\n{}".format(type(cv).__name__, cv_args), fontsize=15)
ax.xaxis.set_major_locator(MaxNLocator(min_n_ticks=len(X), integer=True))
ax.set_xlim(0, len(X))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel(xlabel="Sample index", fontsize=13)
ax.set_ylabel(ylabel="CV iteration", fontsize=13)
ax.tick_params(axis="both", which="major", labelsize=13)
Expand Down

0 comments on commit ec40b75

Please sign in to comment.