|  | 
|  | 1 | +import itertools | 
|  | 2 | +from functools import partial | 
|  | 3 | + | 
|  | 4 | +import numpy as np | 
|  | 5 | +import matplotlib.pyplot as plt | 
|  | 6 | +from cycler import cycler | 
|  | 7 | +from six.moves import zip | 
|  | 8 | + | 
|  | 9 | + | 
|  | 10 | +def filled_hist(ax, edges, values, bottoms=None, orientation='v', | 
|  | 11 | +                **kwargs): | 
|  | 12 | +    """ | 
|  | 13 | +    Draw a histogram as a stepped patch. | 
|  | 14 | +
 | 
|  | 15 | +    Extra kwargs are passed through to `fill_between` | 
|  | 16 | +
 | 
|  | 17 | +    Parameters | 
|  | 18 | +    ---------- | 
|  | 19 | +    ax : Axes | 
|  | 20 | +        The axes to plot to | 
|  | 21 | +
 | 
|  | 22 | +    edges : array | 
|  | 23 | +        A length n+1 array giving the left edges of each bin and the | 
|  | 24 | +        right edge of the last bin. | 
|  | 25 | +
 | 
|  | 26 | +    values : array | 
|  | 27 | +        A length n array of bin counts or values | 
|  | 28 | +
 | 
|  | 29 | +    bottoms : scalar or array, optional | 
|  | 30 | +        A length n array of the bottom of the bars.  If None, zero is used. | 
|  | 31 | +
 | 
|  | 32 | +    orientation : {'v', 'h'} | 
|  | 33 | +       Orientation of the histogram.  'v' (default) has | 
|  | 34 | +       the bars increasing in the positive y-direction. | 
|  | 35 | +
 | 
|  | 36 | +    Returns | 
|  | 37 | +    ------- | 
|  | 38 | +    ret : PolyCollection | 
|  | 39 | +        Artist added to the Axes | 
|  | 40 | +    """ | 
|  | 41 | +    print(orientation) | 
|  | 42 | +    if orientation not in set('hv'): | 
|  | 43 | +        raise ValueError("orientation must be in {'h', 'v'} " | 
|  | 44 | +                         "not {o}".format(o=orientation)) | 
|  | 45 | + | 
|  | 46 | +    kwargs.setdefault('step', 'post') | 
|  | 47 | +    edges = np.asarray(edges) | 
|  | 48 | +    values = np.asarray(values) | 
|  | 49 | +    if len(edges) - 1 != len(values): | 
|  | 50 | +        raise ValueError('Must provide one more bin edge than value not: ' | 
|  | 51 | +                         'len(edges): {lb} len(values): {lv}'.format( | 
|  | 52 | +                             lb=len(edges), lv=len(values))) | 
|  | 53 | + | 
|  | 54 | +    if bottoms is None: | 
|  | 55 | +        bottoms = np.zeros_like(values) | 
|  | 56 | +    if np.isscalar(bottoms): | 
|  | 57 | +        bottoms = np.ones_like(values) * bottoms | 
|  | 58 | + | 
|  | 59 | +    values = np.r_[values, values[-1]] | 
|  | 60 | +    bottoms = np.r_[bottoms, bottoms[-1]] | 
|  | 61 | +    if orientation == 'h': | 
|  | 62 | +        return ax.fill_betweenx(edges, values, bottoms, **kwargs) | 
|  | 63 | +    elif orientation == 'v': | 
|  | 64 | +        return ax.fill_between(edges, values, bottoms, **kwargs) | 
|  | 65 | +    else: | 
|  | 66 | +        raise AssertionError("you should never be here") | 
|  | 67 | + | 
|  | 68 | + | 
|  | 69 | +def stack_hist(ax, stacked_data, sty_cycle, bottoms=None, | 
|  | 70 | +               hist_func=None, labels=None, | 
|  | 71 | +               plot_func=None, plot_kwargs=None): | 
|  | 72 | +    """ | 
|  | 73 | +    ax : axes.Axes | 
|  | 74 | +        The axes to add artists too | 
|  | 75 | +
 | 
|  | 76 | +    stacked_data : array or Mapping | 
|  | 77 | +        A (N, M) shaped array.  The first dimension will be iterated over to | 
|  | 78 | +        compute histograms row-wise | 
|  | 79 | +
 | 
|  | 80 | +    sty_cycle : Cycler or operable of dict | 
|  | 81 | +        Style to apply to each set | 
|  | 82 | +
 | 
|  | 83 | +    bottoms : array, optional | 
|  | 84 | +        The initial positions of the bottoms, defaults to 0 | 
|  | 85 | +
 | 
|  | 86 | +    hist_func : callable, optional | 
|  | 87 | +        Must have signature `bin_vals, bin_edges = f(data)`. | 
|  | 88 | +        `bin_edges` expected to be one longer than `bin_vals` | 
|  | 89 | +
 | 
|  | 90 | +    labels : list of str, optional | 
|  | 91 | +        The label for each set. | 
|  | 92 | +
 | 
|  | 93 | +        If not given and stacked data is an array defaults to 'default set {n}' | 
|  | 94 | +
 | 
|  | 95 | +        If stacked_data is a mapping, and labels is None, default to the keys | 
|  | 96 | +        (which may come out in a random order). | 
|  | 97 | +
 | 
|  | 98 | +        If stacked_data is a mapping and labels is given then only | 
|  | 99 | +        the columns listed by be plotted. | 
|  | 100 | +
 | 
|  | 101 | +    plot_func : callable, optional | 
|  | 102 | +        Function to call to draw the histogram must have signature: | 
|  | 103 | +
 | 
|  | 104 | +          ret = plot_func(ax, edges, top, bottoms=bottoms, | 
|  | 105 | +                          label=label, **kwargs) | 
|  | 106 | +
 | 
|  | 107 | +    plot_kwargs : dict, optional | 
|  | 108 | +        Any extra kwargs to pass through to the plotting function.  This | 
|  | 109 | +        will be the same for all calls to the plotting function and will | 
|  | 110 | +        over-ride the values in cycle. | 
|  | 111 | +
 | 
|  | 112 | +    Returns | 
|  | 113 | +    ------- | 
|  | 114 | +    arts : dict | 
|  | 115 | +        Dictionary of artists keyed on their labels | 
|  | 116 | +    """ | 
|  | 117 | +    # deal with default binning function | 
|  | 118 | +    if hist_func is None: | 
|  | 119 | +        hist_func = np.histogram | 
|  | 120 | + | 
|  | 121 | +    # deal with default plotting function | 
|  | 122 | +    if plot_func is None: | 
|  | 123 | +        plot_func = filled_hist | 
|  | 124 | + | 
|  | 125 | +    # deal with default | 
|  | 126 | +    if plot_kwargs is None: | 
|  | 127 | +        plot_kwargs = {} | 
|  | 128 | +    print(plot_kwargs) | 
|  | 129 | +    try: | 
|  | 130 | +        l_keys = stacked_data.keys() | 
|  | 131 | +        label_data = True | 
|  | 132 | +        if labels is None: | 
|  | 133 | +            labels = l_keys | 
|  | 134 | + | 
|  | 135 | +    except AttributeError: | 
|  | 136 | +        label_data = False | 
|  | 137 | +        if labels is None: | 
|  | 138 | +            labels = itertools.repeat(None) | 
|  | 139 | + | 
|  | 140 | +    if label_data: | 
|  | 141 | +        loop_iter = enumerate((stacked_data[lab], lab, s) for lab, s in | 
|  | 142 | +                              zip(labels, sty_cycle)) | 
|  | 143 | +    else: | 
|  | 144 | +        loop_iter = enumerate(zip(stacked_data, labels, sty_cycle)) | 
|  | 145 | + | 
|  | 146 | +    arts = {} | 
|  | 147 | +    for j, (data, label, sty) in loop_iter: | 
|  | 148 | +        if label is None: | 
|  | 149 | +            label = 'default set {n}'.format(n=j) | 
|  | 150 | +        label = sty.pop('label', label) | 
|  | 151 | +        vals, edges = hist_func(data) | 
|  | 152 | +        if bottoms is None: | 
|  | 153 | +            bottoms = np.zeros_like(vals) | 
|  | 154 | +        top = bottoms + vals | 
|  | 155 | +        print(sty) | 
|  | 156 | +        sty.update(plot_kwargs) | 
|  | 157 | +        print(sty) | 
|  | 158 | +        ret = plot_func(ax, edges, top, bottoms=bottoms, | 
|  | 159 | +                        label=label, **sty) | 
|  | 160 | +        bottoms = top | 
|  | 161 | +        arts[label] = ret | 
|  | 162 | +    ax.legend() | 
|  | 163 | +    return arts | 
|  | 164 | + | 
|  | 165 | + | 
|  | 166 | +# set up histogram function to fixed bins | 
|  | 167 | +edges = np.linspace(-3, 3, 20, endpoint=True) | 
|  | 168 | +hist_func = partial(np.histogram, bins=edges) | 
|  | 169 | + | 
|  | 170 | +# set up style cycles | 
|  | 171 | +color_cycle = cycler('facecolor', 'rgbm') | 
|  | 172 | +label_cycle = cycler('label', ['set {n}'.format(n=n) for n in range(4)]) | 
|  | 173 | +hatch_cycle = cycler('hatch', ['/', '*', '+', '|']) | 
|  | 174 | + | 
|  | 175 | +# make some synthetic data | 
|  | 176 | +stack_data = np.random.randn(4, 12250) | 
|  | 177 | +dict_data = {lab: d for lab, d in zip(list(c['label'] for c in label_cycle), | 
|  | 178 | +                                      stack_data)} | 
|  | 179 | + | 
|  | 180 | +# work with plain arrays | 
|  | 181 | +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) | 
|  | 182 | +arts = stack_hist(ax1, stack_data, color_cycle + label_cycle + hatch_cycle, | 
|  | 183 | +                  hist_func=hist_func) | 
|  | 184 | + | 
|  | 185 | +arts = stack_hist(ax2, stack_data, color_cycle, | 
|  | 186 | +                  hist_func=hist_func, | 
|  | 187 | +                  plot_kwargs=dict(edgecolor='w', orientation='h')) | 
|  | 188 | +ax1.set_ylabel('counts') | 
|  | 189 | +ax1.set_xlabel('x') | 
|  | 190 | +ax2.set_xlabel('counts') | 
|  | 191 | +ax2.set_ylabel('x') | 
|  | 192 | + | 
|  | 193 | +# work with labeled data | 
|  | 194 | + | 
|  | 195 | +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), | 
|  | 196 | +                               tight_layout=True, sharey=True) | 
|  | 197 | + | 
|  | 198 | +arts = stack_hist(ax1, dict_data, color_cycle + hatch_cycle, | 
|  | 199 | +                  hist_func=hist_func) | 
|  | 200 | + | 
|  | 201 | +arts = stack_hist(ax2, dict_data, color_cycle + hatch_cycle, | 
|  | 202 | +                  hist_func=hist_func, labels=['set 0', 'set 3']) | 
|  | 203 | + | 
|  | 204 | +ax1.set_ylabel('counts') | 
|  | 205 | +ax1.set_xlabel('x') | 
|  | 206 | +ax2.set_xlabel('x') | 
0 commit comments