3434from . import art3d
3535from . import proj3d
3636from . import axis3d
37- from mpl_toolkits .mplot3d .art3d import Line3DCollection
3837
3938def unit_bbox ():
4039 box = Bbox (np .array ([[0 , 0 ], [1 , 1 ]]))
@@ -2431,19 +2430,23 @@ def quiver(self, *args, **kwargs):
24312430 *U*, *V*, *W*:
24322431 The direction vector that the arrow is pointing
24332432
2434- The arguments could be iterable or scalars they will be broadcast together. The arguments can
2435- also be masked arrays, if a position in any of argument is masked, then the corresponding
2436- quiver will not be plotted.
2433+ The arguments could be array-like or scalars, so long as they
2434+ they can be broadcast together. The arguments can also be
2435+ masked arrays. If an element in any of argument is masked, then
2436+ that corresponding quiver element will not be plotted.
24372437
24382438 Keyword arguments:
24392439
24402440 *length*: [1.0 | float]
2441- The length of each quiver, default to 1.0, the unit is the same with the axes
2441+ The length of each quiver, default to 1.0, the unit is
2442+ the same with the axes
24422443
24432444 *arrow_length_ratio*: [0.3 | float]
2444- The ratio of the arrow head with respect to the quiver, default to 0.3
2445+ The ratio of the arrow head with respect to the quiver,
2446+ default to 0.3
24452447
2446- Any additional keyword arguments are delegated to :class:`~matplotlib.collections.LineCollection`
2448+ Any additional keyword arguments are delegated to
2449+ :class:`~matplotlib.collections.LineCollection`
24472450
24482451 """
24492452 def calc_arrow (u , v , w , angle = 15 ):
@@ -2472,8 +2475,8 @@ def rotatefunction(angle):
24722475
24732476 # construct the rotation matrix
24742477 R = np .matrix ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c )- z * s , x * z * (1 - c )+ y * s ],
2475- [y * x * (1 - c )+ z * s , c + (y ** 2 )* (1 - c ), y * z * (1 - c )- x * s ],
2476- [z * x * (1 - c )- y * s , z * y * (1 - c )+ x * s , c + (z ** 2 )* (1 - c )]])
2478+ [y * x * (1 - c )+ z * s , c + (y ** 2 )* (1 - c ), y * z * (1 - c )- x * s ],
2479+ [z * x * (1 - c )- y * s , z * y * (1 - c )+ x * s , c + (z ** 2 )* (1 - c )]])
24772480
24782481 # construct the column vector for (u,v,w)
24792482 line = np .matrix ([[u ],[v ],[w ]])
@@ -2512,7 +2515,9 @@ def point_vector_to_line(point, vector, length):
25122515 # first 6 arguments are X, Y, Z, U, V, W
25132516 input_args = args [:argi ]
25142517 # if any of the args are scalar, convert into list
2515- input_args = [[k ] if isinstance (k , (int , float )) else k for k in input_args ]
2518+ input_args = [[k ] if isinstance (k , (int , float )) else k
2519+ for k in input_args ]
2520+
25162521 # extract the masks, if any
25172522 masks = [k .mask for k in input_args if isinstance (k , np .ma .MaskedArray )]
25182523 # broadcast to match the shape
@@ -2523,36 +2528,43 @@ def point_vector_to_line(point, vector, length):
25232528 # combine the masks into one
25242529 mask = reduce (np .logical_or , masks )
25252530 # put mask on and compress
2526- input_args = [np .ma .array (k , mask = mask ).compressed () for k in input_args ]
2531+ input_args = [np .ma .array (k , mask = mask ).compressed ()
2532+ for k in input_args ]
25272533 else :
25282534 input_args = [k .flatten () for k in input_args ]
25292535
2536+ if any (len (v ) == 0 for v in input_args ):
2537+ # No quivers, so just make an empty collection and return early
2538+ linec = art3d .Line3DCollection ([], * args [6 :], ** kwargs )
2539+ self .add_collection (linec )
2540+ return linec
2541+
25302542 points = input_args [:3 ]
25312543 vectors = input_args [3 :]
25322544
25332545 # Below assertions must be true before proceed
25342546 # must all be ndarray
2535- assert all ([ isinstance (k , np .ndarray ) for k in input_args ] )
2547+ assert all (isinstance (k , np .ndarray ) for k in input_args )
25362548 # must all in same shape
25372549 assert len (set ([k .shape for k in input_args ])) == 1
25382550
2539-
25402551 # X, Y, Z, U, V, W
2541- coords = list (map (lambda k : np .array (k ) if not isinstance (k , np .ndarray ) else k , args ))
2552+ coords = (np .array (k ) if not isinstance (k , np .ndarray ) else k
2553+ for k in args )
25422554 coords = [k .flatten () for k in coords ]
25432555 xs , ys , zs , us , vs , ws = coords
25442556 lines = []
25452557
25462558 # for each arrow
2547- for i in xrange (xs .shape [0 ]):
2559+ for i in range (xs .shape [0 ]):
25482560 # calulate body
25492561 x = xs [i ]
25502562 y = ys [i ]
25512563 z = zs [i ]
25522564 u = us [i ]
25532565 v = vs [i ]
25542566 w = ws [i ]
2555- if any ([ k is np .ma .masked for k in [x , y , z , u , v , w ] ]):
2567+ if any (k is np .ma .masked for k in [x , y , z , u , v , w ]):
25562568 continue
25572569
25582570 # (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
@@ -2590,7 +2602,7 @@ def point_vector_to_line(point, vector, length):
25902602 line = list (zip (la2x , la2y , la2z ))
25912603 lines .append (line )
25922604
2593- linec = Line3DCollection (lines , * args [6 :], ** kwargs )
2605+ linec = art3d . Line3DCollection (lines , * args [6 :], ** kwargs )
25942606 self .add_collection (linec )
25952607
25962608 self .auto_scale_xyz (xs , ys , zs , had_data )
0 commit comments