diff --git a/ocw/plotter.py b/ocw/plotter.py index 7f9b0920..f0af03fe 100755 --- a/ocw/plotter.py +++ b/ocw/plotter.py @@ -367,9 +367,24 @@ def draw_subregions(subregions, lats, lons, fname, fmt='png', ptitle='', fig.clf() +def _get_colors(num_colors): + """ + matplotlib will recycle colors after a certain number. This can make + line type charts confusing as colors will be reused. This function + provides a distribution of colors across the default color map + to better approximate uniqueness. + + :param num_colors: The number of unique colors to generate. + :return: A color map with num_colors. + """ + cmap = plt.get_cmap() + return [cmap(1. * i / num_colors) for i in range(num_colors)] + + def draw_time_series(results, times, labels, fname, fmt='png', gridshape=(1, 1), xlabel='', ylabel='', ptitle='', subtitles=None, - label_month=False, yscale='linear', aspect=None): + label_month=False, yscale='linear', aspect=None, + cycle_colors=True, cmap=None): ''' Draw a time series plot. :param results: 3D array of time series data. @@ -415,7 +430,22 @@ def draw_time_series(results, times, labels, fname, fmt='png', gridshape=(1, 1), :param aspect: (Optional) approximate aspect ratio of each subplot (width / height). Default is 8.5 / 5.5 :type aspect: :class:`float` + + :param cycle_colors: (Optional) flag to toggle whether to allow matlibplot + to re-use colors when plotting or force an evenly distributed range. + :type cycle_colors: :class:`bool` + + :param cmap: (Optional) string or :class:`matplotlib.colors.LinearSegmentedColormap` + instance denoting the colormap. This must be able to be recognized by + `Matplotlib's get_cmap function `_. + Maps like rainbow and spectral with wide spectrum of colors are nice choices when used with + the cycle_colors option. tab20, tab20b, and tab20c are good if the plot has less than 20 datasets. + :type cmap: :mod:`string` or :class:`matplotlib.colors.LinearSegmentedColormap` + ''' + if cmap is not None: + set_cmap(cmap) + # Handle the single plot case. if results.ndim == 2: results = results.reshape(1, *results.shape) @@ -448,6 +478,10 @@ def draw_time_series(results, times, labels, fname, fmt='png', gridshape=(1, 1), # Make the plots for i, ax in enumerate(grid): data = results[i] + + if not cycle_colors: + ax.set_prop_cycle('color', _get_colors(data.shape[0])) + if label_month: xfmt = mpl.dates.DateFormatter('%b') xloc = mpl.dates.MonthLocator()