@@ -438,9 +438,36 @@ def valid(self):
438438 return not self ._force_update
439439
440440
441- def get_gridlines (self ):
441+ def get_gridlines (self , which , axis ):
442+ """
443+ Return list of grid lines as a list of paths (list of points).
444+
445+ *which* : "major" or "minor"
446+ *axis* : "both", "x" or "y"
447+ """
442448 return []
443449
450+ def new_gridlines (self , ax ):
451+ """
452+ Create and return a new GridlineCollection instance.
453+
454+ *which* : "major" or "minor"
455+ *axis* : "both", "x" or "y"
456+
457+ """
458+ gridlines = GridlinesCollection (None , transform = ax .transData ,
459+ colors = rcParams ['grid.color' ],
460+ linestyles = rcParams ['grid.linestyle' ],
461+ linewidths = rcParams ['grid.linewidth' ])
462+ ax ._set_artist_props (gridlines )
463+ gridlines .set_grid_helper (self )
464+
465+ ax .axes ._set_artist_props (gridlines )
466+ # gridlines.set_clip_path(self.axes.patch)
467+ # set_clip_path need to be deferred after Axes.cla is completed.
468+ # It is done inside the cla.
469+
470+ return gridlines
444471
445472
446473class GridHelperRectlinear (GridHelperBase ):
@@ -497,33 +524,43 @@ def new_floating_axis(self, nth_coord, value,
497524 return axisline
498525
499526
500- def get_gridlines (self ):
527+ def get_gridlines (self , which = "major" , axis = "both" ):
501528 """
502529 return list of gridline coordinates in data coordinates.
530+
531+ *which* : "major" or "minor"
532+ *axis* : "both", "x" or "y"
503533 """
504534
505535 gridlines = []
506536
507- locs = []
508- y1 , y2 = self .axes .get_ylim ()
509- if self .axes .xaxis ._gridOnMajor :
510- locs .extend (self .axes .xaxis .major .locator ())
511- if self .axes .xaxis ._gridOnMinor :
512- locs .extend (self .axes .xaxis .minor .locator ())
513537
514- for x in locs :
515- gridlines .append ([[x , x ], [y1 , y2 ]])
538+ if axis in ["both" , "x" ]:
539+ locs = []
540+ y1 , y2 = self .axes .get_ylim ()
541+ #if self.axes.xaxis._gridOnMajor:
542+ if which in ["both" , "major" ]:
543+ locs .extend (self .axes .xaxis .major .locator ())
544+ #if self.axes.xaxis._gridOnMinor:
545+ if which in ["both" , "minor" ]:
546+ locs .extend (self .axes .xaxis .minor .locator ())
516547
548+ for x in locs :
549+ gridlines .append ([[x , x ], [y1 , y2 ]])
517550
518- x1 , x2 = self .axes .get_xlim ()
519- locs = []
520- if self .axes .yaxis ._gridOnMajor :
521- locs .extend (self .axes .yaxis .major .locator ())
522- if self .axes .yaxis ._gridOnMinor :
523- locs .extend (self .axes .yaxis .minor .locator ())
524551
525- for y in locs :
526- gridlines .append ([[x1 , x2 ], [y , y ]])
552+ if axis in ["both" , "y" ]:
553+ x1 , x2 = self .axes .get_xlim ()
554+ locs = []
555+ if self .axes .yaxis ._gridOnMajor :
556+ #if which in ["both", "major"]:
557+ locs .extend (self .axes .yaxis .major .locator ())
558+ if self .axes .yaxis ._gridOnMinor :
559+ #if which in ["both", "minor"]:
560+ locs .extend (self .axes .yaxis .minor .locator ())
561+
562+ for y in locs :
563+ gridlines .append ([[x1 , x2 ], [y , y ]])
527564
528565 return gridlines
529566
@@ -627,20 +664,25 @@ def _get_axislines(self):
627664
628665 axis = property (_get_axislines )
629666
630- def _init_gridlines (self , grid_helper = None ):
631- gridlines = GridlinesCollection (None , transform = self .transData ,
632- colors = rcParams ['grid.color' ],
633- linestyles = rcParams ['grid.linestyle' ],
634- linewidths = rcParams ['grid.linewidth' ])
635- self ._set_artist_props (gridlines )
667+ def new_gridlines (self , grid_helper = None ):
668+ """
669+ Create and return a new GridlineCollection instance.
670+
671+ *which* : "major" or "minor"
672+ *axis* : "both", "x" or "y"
673+
674+ """
636675 if grid_helper is None :
637676 grid_helper = self .get_grid_helper ()
638- gridlines .set_grid_helper (grid_helper )
639677
640- self .axes ._set_artist_props (gridlines )
641- # gridlines.set_clip_path(self.axes.patch)
642- # set_clip_path need to be deferred after Axes.cla is completed.
678+ gridlines = grid_helper .new_gridlines (self )
679+
680+ return gridlines
681+
682+
683+ def _init_gridlines (self , grid_helper = None ):
643684 # It is done inside the cla.
685+ gridlines = self .new_gridlines (grid_helper )
644686
645687 self .gridlines = gridlines
646688
@@ -668,7 +710,7 @@ def grid(self, b=None, which='major', axis="both", **kwargs):
668710 # axes_grid and the original mpl's grid, because axes_grid
669711 # explicitly set the visibility of the gridlines.
670712
671- super (Axes , self ).grid (b , ** kwargs )
713+ super (Axes , self ).grid (b , which = which , axis = axis , ** kwargs )
672714 if not self ._axisline_on :
673715 return
674716
@@ -680,6 +722,8 @@ def grid(self, b=None, which='major', axis="both", **kwargs):
680722 else :
681723 b = False
682724
725+ self .gridlines .set_which (which )
726+ self .gridlines .set_axis (axis )
683727 self .gridlines .set_visible (b )
684728
685729 if len (kwargs ):
0 commit comments