Skip to content

Commit b9d7e03

Browse files
committed
Added statistics lines to the violinplot function.
1 parent da40c9d commit b9d7e03

File tree

2 files changed

+110
-40
lines changed

2 files changed

+110
-40
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6725,7 +6725,8 @@ def matshow(self, Z, **kwargs):
67256725
integer=True))
67266726
return im
67276727

6728-
def violinplot(self, dataset, positions=None, width=0.5):
6728+
def violinplot(self, dataset, positions=None, widths=0.5, showmeans=False,
6729+
showextrema=True, showmedians=False):
67296730
"""
67306731
Make a violin plot.
67316732
@@ -6748,11 +6749,20 @@ def violinplot(self, dataset, positions=None, width=0.5):
67486749
Sets the positions of the violins. The ticks and limits are
67496750
automatically set to match the positions.
67506751
6751-
width : array-like, default = 0.5
6752+
widths : array-like, default = 0.5
67526753
Either a scalar or a vector that sets the maximal width of
67536754
each violin. The default is 0.5, which uses about half of the
67546755
available horizontal space.
67556756
6757+
showmeans : bool, default = False
6758+
If True, the means are rendered.
6759+
6760+
showextrema : bool, default = True
6761+
If True, the extrema are rendered.
6762+
6763+
showmedians : bool, default = False
6764+
If True, the medians are rendered.
6765+
67566766
Returns
67576767
-------
67586768
@@ -6763,24 +6773,58 @@ def violinplot(self, dataset, positions=None, width=0.5):
67636773
- bodies: A list of the
67646774
:class:`matplotlib.collections.PolyCollection` instances
67656775
containing the filled area of each violin.
6766-
- means: A list of the :class:`matplotlib.lines.Line2D` instances
6767-
created to identify the mean values for each of the violins.
6768-
- caps: A list of the :class:`matplotlib.lines.Line2D` instances
6769-
created to identify the extremal values of each violin's
6770-
data set.
6776+
- means: A :class:`matplotlib.collections.LineCollection` instance
6777+
created to identify the mean value of each violin's
6778+
distribution.
6779+
- mins: A :class:`matplotlib.collections.LineCollection` instance
6780+
created to identify the bottom of each violin's distribution.
6781+
- maxes: A :class:`matplotlib.collections.LineCollection` instance
6782+
created to identify the top of each violin's distribution.
6783+
- bars: A :class:`matplotlib.collections.LineCollection` instance
6784+
created to mark the center line of each violin, spanning its minimum to its maximum.
6785+
- medians: A :class:`matplotlib.collections.LineCollection` instance
6786+
created to identify the median value of each violin's
6787+
distribution.
67716788
67726789
"""
67736790

6774-
bodies = []
6791+
# Statistical quantities to be plotted on the violins
67756792
means = []
6776-
caps = []
6793+
mins = []
6794+
maxes = []
6795+
medians = []
6796+
6797+
# Collections to be returned
6798+
bodies = []
6799+
cmeans = None
6800+
cmaxes = None
6801+
cmins = None
6802+
cbars = None
6803+
cmedians = None
67776804

6805+
# Validate positions
67786806
if positions == None:
67796807
positions = range(1, len(dataset) + 1)
67806808
elif len(positions) != len(dataset):
67816809
raise ValueError(datashape_message.format("positions"))
67826810

6783-
for d,p in zip(dataset,positions):
6811+
# Validate widths
6812+
if np.isscalar(widths):
6813+
widths = [widths] * len(dataset)
6814+
elif len(widths) != len(dataset):
6815+
raise ValueError(datashape_message.format("widths"))
6816+
6817+
# Calculate mins and maxes for statistics lines
6818+
pmins = -0.25 * np.array(widths) + positions
6819+
pmaxes = 0.25 * np.array(widths) + positions
6820+
6821+
# Check hold status
6822+
if not self._hold:
6823+
self.cla()
6824+
holdStatus = self._hold
6825+
6826+
# Render violins
6827+
for d,p,w in zip(dataset,positions,widths):
67846828
# Calculate the kernel density
67856829
kde = mlab.ksdensity(d)
67866830
m = kde['xmin']
@@ -6793,18 +6837,43 @@ def violinplot(self, dataset, positions=None, width=0.5):
67936837
# Since each density value v is plotted from p-v to p+v,
67946838
# we need to scale it by an additional 0.5 factor so that we get
67956839
# correct width in the end.
6796-
v = 0.5 * width * v/v.max()
6840+
v = 0.5 * w * v/v.max()
67976841

67986842
bodies += [self.fill_betweenx(coords,
67996843
-v+p,
68006844
v+p,
68016845
facecolor='y',
68026846
alpha=0.3)]
68036847

6848+
means.append(mean)
6849+
mins.append(m)
6850+
maxes.append(M)
6851+
medians.append(median)
6852+
6853+
# Render means
6854+
if showmeans:
6855+
cmeans = self.hlines(means, pmins, pmaxes, colors='r')
6856+
6857+
# Render extrema
6858+
if showextrema:
6859+
cmaxes = self.hlines(maxes, pmins, pmaxes, colors='r')
6860+
cmins = self.hlines(mins, pmins, pmaxes, colors='r')
6861+
cbars = self.vlines(positions, mins, maxes, colors='r')
6862+
6863+
# Render medians
6864+
if showmedians:
6865+
cmedians = self.hlines(medians, pmins, pmaxes, colors='r')
6866+
6867+
# Reset hold
6868+
self.hold(holdStatus)
6869+
68046870
return {
68056871
'bodies' : bodies,
6806-
'means' : means,
6807-
'caps' : caps
6872+
'means' : cmeans,
6873+
'mins' : cmins,
6874+
'maxes' : cmaxes,
6875+
'bars' : cbars,
6876+
'medians' : cmedians
68086877
}
68096878

68106879

lib/matplotlib/mlab.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3661,7 +3661,7 @@ def ksdensity(dataset, bw_method=None):
36613661
Representation of a kernel-density estimate using Gaussian kernels.
36623662
36633663
Call signature::
3664-
xmin, xmax, result = ksdensity(dataset, 'scott')
3664+
kde_dict = ksdensity(dataset, 'silverman')
36653665
36663666
Parameters
36673667
----------
@@ -3714,22 +3714,22 @@ def ksdensity(dataset, bw_method=None):
37143714
"""
37153715

37163716
# This implementation with minor modification was too good to pass up.
3717-
# from scipy: https://github.com/scipy/scipy/blob/master/scipy/stats/kde.py
3717+
# from scipy: https://github.com/scipy/scipy/blob/master/scipy/stats/kde.py
37183718

3719-
dataset = np.atleast_2d(dataset)
3719+
dataset = np.array(np.atleast_2d(dataset))
37203720
xmin = dataset.min()
37213721
xmax = dataset.max()
37223722

37233723
if not dataset.size > 1:
37243724
raise ValueError("`dataset` input should have multiple elements.")
37253725

3726-
d, n = dataset.shape
3726+
dim, num_dp = dataset.shape
37273727

37283728
# ----------------------------------------------
37293729
# Set Bandwidth, defaulted to Scott's Factor
37303730
# ----------------------------------------------
3731-
scotts_factor = lambda: np.power(n, -1./(d+4))
3732-
silverman_factor = lambda: np.power(n*(d+2.0)/4.0, -1./(d+4))
3731+
scotts_factor = lambda: np.power(num_dp, -1./(dim+4))
3732+
silverman_factor = lambda: np.power(num_dp*(dim+2.0)/4.0, -1./(dim+4))
37333733

37343734
# Default method to calculate bandwidth, can be overwritten by subclass
37353735
covariance_factor = scotts_factor
@@ -3740,7 +3740,7 @@ def ksdensity(dataset, bw_method=None):
37403740
covariance_factor = scotts_factor
37413741
elif bw_method == 'silverman':
37423742
covariance_factor = silverman_factor
3743-
elif np.isscalar(bw_method) and not isinstance(bw_method, string_types):
3743+
elif np.isscalar(bw_method) and not isinstance(bw_method, six.string_types):
37443744
covariance_factor = lambda: bw_method
37453745
else:
37463746
msg = "`bw_method` should be 'scott', 'silverman', or a scalar"
@@ -3752,53 +3752,54 @@ def ksdensity(dataset, bw_method=None):
37523752
factor = covariance_factor()
37533753

37543754
# Cache covariance and inverse covariance of the data
3755-
data_covariance = np.atleast_2d(np.cov(dataset, rowvar=1,bias=False))
3755+
data_covariance = np.atleast_2d(np.cov(dataset, rowvar=1, bias=False))
37563756
data_inv_cov = np.linalg.inv(data_covariance)
37573757

37583758
covariance = data_covariance * factor**2
37593759
inv_cov = data_inv_cov / factor**2
3760-
norm_factor = np.sqrt(np.linalg.det(2*np.pi*covariance)) * n
3760+
norm_factor = np.sqrt(np.linalg.det(2*np.pi*covariance)) * num_dp
37613761

37623762
# ----------------------------------------------
37633763
# Evaluate the estimated pdf on a set of points.
37643764
# ----------------------------------------------
3765-
points = np.atleast_2d(np.arange(xmin,xmax, (xmax-xmin)/100.))
3765+
points = np.atleast_2d(np.arange(xmin, xmax, (xmax-xmin)/100.))
37663766

3767-
d1, m1 = points.shape
3768-
if d1 != d:
3769-
if d1 == 1 and m1 == d:
3767+
dim_pts, num_dp_pts = np.array(points).shape
3768+
if dim_pts != dim:
3769+
if dim_pts == 1 and num_dp_pts == num_dp:
37703770
# points was passed in as a row vector
3771-
points = np.reshape(points, (d, 1))
3772-
m1 = 1
3771+
points = np.reshape(points, (dim, 1))
3772+
num_dp_pts = 1
37733773
else:
3774-
msg = "points have dimension %s, dataset has dimension %s" % (d1, d)
3774+
msg = "points have dimension %s,\
3775+
dataset has dimension %s" % (dim_pts, dim)
37753776
raise ValueError(msg)
37763777

3777-
result = np.zeros((m1,), dtype=np.float)
3778+
result = np.zeros((num_dp_pts,), dtype=np.float)
37783779

3779-
if m1 >= n:
3780+
if num_dp_pts >= num_dp:
37803781
# there are more points than data, so loop over data
3781-
for i in range(n):
3782+
for i in range(num_dp):
37823783
diff = dataset[:, i, np.newaxis] - points
37833784
tdiff = np.dot(inv_cov, diff)
3784-
energy = np.sum(diff*tdiff,axis=0) / 2.0
3785+
energy = np.sum(diff*tdiff, axis=0) / 2.0
37853786
result = result + np.exp(-energy)
37863787
else:
37873788
# loop over points
3788-
for i in range(m):
3789-
diff = dataset - points[:, i, newaxis]
3789+
for i in range(num_dp_pts):
3790+
diff = dataset - points[:, i, np.newaxis]
37903791
tdiff = np.dot(inv_cov, diff)
37913792
energy = np.sum(diff * tdiff, axis=0) / 2.0
37923793
result[i] = np.sum(np.exp(-energy), axis=0)
37933794

37943795
result = result / norm_factor
37953796

37963797
return {
3797-
'xmin' : xmin,
3798-
'xmax' : xmax,
3799-
'mean' : np.mean(result),
3800-
'median' : np.median(result),
3801-
'result' : result
3798+
'xmin': xmin,
3799+
'xmax': xmax,
3800+
'mean': np.mean(dataset),
3801+
'median': np.median(dataset),
3802+
'result': result
38023803
}
38033804

38043805
##################################################

0 commit comments

Comments
 (0)