@@ -793,15 +793,27 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
793793 *ncols* : int
794794 Number of columns of the subplot grid. Defaults to 1.
795795
796- *sharex* : bool
796+ *sharex* : string or bool
797797 If *True*, the X axis will be shared amongst all subplots. If
798798 *True* and you have multiple rows, the x tick labels on all but
799799 the last row of plots will have visible set to *False*
800-
801- *sharey* : bool
800+ If a string must be one of "row", "col", "all", or "none".
801+ "all" has the same effect as *True*, "none" has the same effect
802+ as *False*.
803+ If "row", each subplot row will share a X axis.
804+ If "col", each subplot column will share a X axis and the x tick
805+ labels on all but the last row will have visible set to *False*.
806+
807+ *sharey* : string or bool
802808 If *True*, the Y axis will be shared amongst all subplots. If
803809 *True* and you have multiple columns, the y tick labels on all but
804810 the first column of plots will have visible set to *False*
811+ If a string must be one of "row", "col", "all", or "none".
812+ "all" has the same effect as *True*, "none" has the same effect
813+ as *False*.
814+ If "row", each subplot row will share a Y axis.
815+ If "col", each subplot column will share a Y axis and the y tick
816+ labels on all but the last row will have visible set to *False*.
805817
806818 *squeeze* : bool
807819 If *True*, extra dimensions are squeezed out from the
@@ -859,7 +871,36 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
859871
860872 # Four polar axes
861873 plt.subplots(2, 2, subplot_kw=dict(polar=True))
874+
875+ # Share a X axis with each column of subplots
876+ plt.subplots(2, 2, sharex='col')
877+
878+ # Share a Y axis with each row of subplots
879+ plt.subplots(2, 2, sharey='row')
880+
881+ # Share a X and Y axis with all subplots
882+ plt.subplots(2, 2, sharex='all', sharey='all')
883+ # same as
884+ plt.subplots(2, 2, sharex=True, sharey=True)
862885 """
886+ # for backwards compatability
887+ if isinstance (sharex , bool ):
888+ if sharex :
889+ sharex = "all"
890+ else :
891+ sharex = "none"
892+ if isinstance (sharey , bool ):
893+ if sharey :
894+ sharey = "all"
895+ else :
896+ sharey = "none"
897+ share_values = ["all" , "row" , "col" , "none" ]
898+ if sharex not in share_values :
899+ raise ValueError ("sharex [%s] must be one of %s" % \
900+ (sharex , share_values ))
901+ if sharey not in share_values :
902+ raise ValueError ("sharey [%s] must be one of %s" % \
903+ (sharey , share_values ))
863904
864905 if subplot_kw is None :
865906 subplot_kw = {}
@@ -873,34 +914,52 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
873914
874915 # Create first subplot separately, so we can share it if requested
875916 ax0 = fig .add_subplot (nrows , ncols , 1 , ** subplot_kw )
876- if sharex :
877- subplot_kw ['sharex' ] = ax0
878- if sharey :
879- subplot_kw ['sharey' ] = ax0
917+ # if sharex:
918+ # subplot_kw['sharex'] = ax0
919+ # if sharey:
920+ # subplot_kw['sharey'] = ax0
880921 axarr [0 ] = ax0
881922
923+ r , c = np .mgrid [:nrows , :ncols ]
924+ r = r .flatten () * ncols
925+ c = c .flatten ()
926+ lookup = {
927+ "none" : np .arange (nplots ),
928+ "all" : np .zeros (nplots , dtype = int ),
929+ "row" : r ,
930+ "col" : c ,
931+ }
932+ sxs = lookup [sharex ]
933+ sys = lookup [sharey ]
934+
882935 # Note off-by-one counting because add_subplot uses the MATLAB 1-based
883936 # convention.
884937 for i in range (1 , nplots ):
885- axarr [i ] = fig .add_subplot (nrows , ncols , i + 1 , ** subplot_kw )
886-
887-
938+ if sxs [i ] == i :
939+ subplot_kw ['sharex' ] = None
940+ else :
941+ subplot_kw ['sharex' ] = axarr [sxs [i ]]
942+ if sys [i ] == i :
943+ subplot_kw ['sharey' ] = None
944+ else :
945+ subplot_kw ['sharey' ] = axarr [sys [i ]]
946+ axarr [i ] = fig .add_subplot (nrows , ncols , i + 1 , ** subplot_kw )
888947
889948 # returned axis array will be always 2-d, even if nrows=ncols=1
890949 axarr = axarr .reshape (nrows , ncols )
891950
892-
893951 # turn off redundant tick labeling
894- if sharex and nrows > 1 :
952+ if sharex in ["col" , "all" ] and nrows > 1 :
953+ #if sharex and nrows>1:
895954 # turn off all but the bottom row
896- for ax in axarr [:- 1 ,:].flat :
955+ for ax in axarr [:- 1 , :].flat :
897956 for label in ax .get_xticklabels ():
898957 label .set_visible (False )
899958
900-
901- if sharey and ncols > 1 :
959+ if sharey in [ "row" , "all" ] and ncols > 1 :
960+ # if sharey and ncols>1:
902961 # turn off all but the first column
903- for ax in axarr [:,1 :].flat :
962+ for ax in axarr [:, 1 :].flat :
904963 for label in ax .get_yticklabels ():
905964 label .set_visible (False )
906965
0 commit comments