-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jeff Hernandez
authored
Sep 10, 2019
1 parent
49dca85
commit 50e9f90
Showing
14 changed files
with
442 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import matplotlib as mpl | ||
import pandas as pd | ||
import seaborn as sns | ||
|
||
pd.plotting.register_matplotlib_converters() | ||
sns.set_context('notebook') | ||
sns.set_style('darkgrid') | ||
COLOR = sns.color_palette("Set1", n_colors=100, desat=.75) | ||
|
||
|
||
class LabelPlots: | ||
"""Creates plots for Label Times.""" | ||
|
||
def __init__(self, label_times): | ||
"""Initializes Label Plots. | ||
Args: | ||
label_times (LabelTimes) : instance of Label Times | ||
""" | ||
self._label_times = label_times | ||
|
||
def count_by_time(self, ax=None, **kwargs): | ||
"""Plots the label distribution across cutoff times.""" | ||
count_by_time = self._label_times.count_by_time | ||
count_by_time.sort_index(inplace=True) | ||
|
||
ax = ax or mpl.pyplot.axes() | ||
vmin = count_by_time.index.min() | ||
vmax = count_by_time.index.max() | ||
ax.set_xlim(vmin, vmax) | ||
|
||
locator = mpl.dates.AutoDateLocator() | ||
formatter = mpl.dates.AutoDateFormatter(locator) | ||
ax.xaxis.set_major_locator(locator) | ||
ax.xaxis.set_major_formatter(formatter) | ||
ax.figure.autofmt_xdate() | ||
|
||
if len(count_by_time.shape) > 1: | ||
ax.stackplot( | ||
count_by_time.index, | ||
count_by_time.values.T, | ||
labels=count_by_time.columns, | ||
colors=COLOR, | ||
alpha=.9, | ||
**kwargs, | ||
) | ||
|
||
ax.legend( | ||
loc='upper left', | ||
title=self._label_times.name, | ||
facecolor='w', | ||
framealpha=.9, | ||
) | ||
|
||
ax.set_title('Label Count vs. Cutoff Times') | ||
ax.set_ylabel('Count') | ||
ax.set_xlabel('Time') | ||
|
||
else: | ||
ax.fill_between( | ||
count_by_time.index, | ||
count_by_time.values.T, | ||
color=COLOR[1], | ||
) | ||
|
||
ax.set_title('Label vs. Cutoff Times') | ||
ax.set_ylabel(self._label_times.name) | ||
ax.set_xlabel('Time') | ||
|
||
return ax | ||
|
||
@property | ||
def dist(self): | ||
"""Alias for distribution.""" | ||
return self.distribution | ||
|
||
def distribution(self, **kwargs): | ||
"""Plots the label distribution.""" | ||
dist = self._label_times[self._label_times.name] | ||
|
||
if self._label_times.is_discrete: | ||
ax = sns.countplot(dist, palette=COLOR, **kwargs) | ||
else: | ||
ax = sns.distplot(dist, kde=True, color=COLOR[1], **kwargs) | ||
|
||
ax.set_title('Label Distribution') | ||
ax.set_ylabel('Count') | ||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,21 @@ | ||
def test_distribution_plot(labels): | ||
labels = labels.threshold(200) | ||
plot = labels.plot.distribution() | ||
assert plot.get_title() == 'Label Distribution' | ||
def test_count_by_time_categorical(total_spent): | ||
labels = range(2) | ||
total_spent = total_spent.bin(2, labels=labels) | ||
ax = total_spent.plot.count_by_time() | ||
assert ax.get_title() == 'Label Count vs. Cutoff Times' | ||
|
||
|
||
def test_count_by_time_plot(labels): | ||
labels = labels.threshold(200) | ||
plot = labels.plot.count_by_time() | ||
assert plot.get_title() == 'Label Count vs. Time' | ||
def test_count_by_time_continuous(total_spent): | ||
ax = total_spent.plot.count_by_time() | ||
assert ax.get_title() == 'Label vs. Cutoff Times' | ||
|
||
|
||
def test_distribution_categorical(total_spent): | ||
ax = total_spent.bin(2, labels=range(2)) | ||
ax = ax.plot.dist() | ||
assert ax.get_title() == 'Label Distribution' | ||
|
||
|
||
def test_distribution_continuous(total_spent): | ||
ax = total_spent.plot.dist() | ||
assert ax.get_title() == 'Label Distribution' |
Oops, something went wrong.