99import  matplotlib .collections  as  mcollections 
1010import  matplotlib .patches  as  patches 
1111
12+ 
1213__all__  =  ['streamplot' ]
1314
1415
@@ -46,11 +47,16 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
4647    *minlength* : float 
4748        Minimum length of streamline in axes coordinates. 
4849
49-     Returns *streamlines* : :class:`~matplotlib.collections.LineCollection` 
50-         Line collection with all streamlines as a series of line segments. 
51-         Currently, there is no way to differentiate between line segments 
52-         on different streamlines (other than manually checking that segments 
53-         are connected). 
50+     Returns 
51+     ------- 
52+     *stream_container* : StreamplotSet 
53+         Container object with attributes 
54+             lines : `matplotlib.collections.LineCollection` of streamlines 
55+             arrows : collection of `matplotlib.patches.FancyArrowPatch` objects 
56+                 repesenting arrows half-way along stream lines. 
57+         This container will probably change in the future to allow changes to 
58+         the colormap, alpha, etc. for both lines and arrows, but these changes 
59+         should be backward compatible. 
5460    """ 
5561    grid  =  Grid (x , y )
5662    mask  =  StreamMask (density )
@@ -108,6 +114,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
108114            cmap  =  cm .get_cmap (cmap )
109115
110116    streamlines  =  []
117+     arrows  =  []
111118    for  t  in  trajectories :
112119        tgx  =  np .array (t [0 ])
113120        tgy  =  np .array (t [1 ])
@@ -139,6 +146,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
139146                                    transform = transform , 
140147                                    ** arrow_kw )
141148        axes .add_patch (p )
149+         arrows .append (p )
142150
143151    lc  =  mcollections .LineCollection (streamlines , 
144152                                     transform = transform , 
@@ -151,7 +159,17 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
151159
152160    axes .update_datalim (((x .min (), y .min ()), (x .max (), y .max ())))
153161    axes .autoscale_view (tight = True )
154-     return  lc 
162+ 
163+     ac  =  matplotlib .collections .PatchCollection (arrows )
164+     stream_container  =  StreamplotSet (lc , ac )
165+     return  stream_container 
166+ 
167+ 
168+ class  StreamplotSet (object ):
169+ 
170+     def  __init__ (self , lines , arrows , ** kwargs ):
171+         self .lines  =  lines 
172+         self .arrows  =  arrows 
155173
156174
157175# Coordinate definitions 
0 commit comments