@@ -221,17 +221,22 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
221221            self .surface_  =  plot_func (self .xx0 , self .xx1 , self .response , ** kwargs )
222222        else :  # self.response.ndim == 3 
223223            n_responses  =  self .response .shape [- 1 ]
224-             if  (
225-                 isinstance (self .multiclass_colors , str )
226-                 or  self .multiclass_colors  is  None 
224+             for  kwarg  in  ("cmap" , "colors" ):
225+                 if  kwarg  in  kwargs :
226+                     warnings .warn (
227+                         f"'{ kwarg }  
228+                         "in the multiclass case when the response method is " 
229+                         "'decision_function' or 'predict_proba'." 
230+                     )
231+                     del  kwargs [kwarg ]
232+ 
233+             if  self .multiclass_colors  is  None  or  isinstance (
234+                 self .multiclass_colors , str 
227235            ):
228-                 if  isinstance ( self .multiclass_colors ,  str ) :
229-                     cmap  =  self . multiclass_colors 
236+                 if  self .multiclass_colors   is   None :
237+                     cmap  =  "tab10"   if   n_responses   <=   10   else   "gist_rainbow" 
230238                else :
231-                     if  n_responses  <=  10 :
232-                         cmap  =  "tab10" 
233-                     else :
234-                         cmap  =  "gist_rainbow" 
239+                     cmap  =  self .multiclass_colors 
235240
236241                # Special case for the tab10 and tab20 colormaps that encode a 
237242                # discrete set of colors that are easily distinguishable 
@@ -241,40 +246,41 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
241246                elif  cmap  ==  "tab20"  and  n_responses  <=  20 :
242247                    colors  =  plt .get_cmap ("tab20" , 20 ).colors [:n_responses ]
243248                else :
244-                     colors  =  plt .get_cmap (cmap , n_responses ).colors 
245-             elif  isinstance (self .multiclass_colors , str ):
246-                 colors  =  colors  =  plt .get_cmap (
247-                     self .multiclass_colors , n_responses 
248-                 ).colors 
249-             else :
249+                     cmap  =  plt .get_cmap (cmap , n_responses )
250+                     if  not  hasattr (cmap , "colors" ):
251+                         # For LinearSegmentedColormap 
252+                         colors  =  cmap (np .linspace (0 , 1 , n_responses ))
253+                     else :
254+                         colors  =  cmap .colors 
255+             elif  isinstance (self .multiclass_colors , list ):
250256                colors  =  [mpl .colors .to_rgba (color ) for  color  in  self .multiclass_colors ]
257+             else :
258+                 raise  ValueError ("'multiclass_colors' must be a list or a str." )
251259
252260            self .multiclass_colors_  =  colors 
253-             multiclass_cmaps  =  [
254-                 mpl .colors .LinearSegmentedColormap .from_list (
255-                     f"colormap_{ class_idx }  , [(1.0 , 1.0 , 1.0 , 1.0 ), (r , g , b , 1.0 )]
256-                 )
257-                 for  class_idx , (r , g , b , _ ) in  enumerate (colors )
258-             ]
259- 
260-             self .surface_  =  []
261-             for  class_idx , cmap  in  enumerate (multiclass_cmaps ):
262-                 response  =  np .ma .array (
263-                     self .response [:, :, class_idx ],
264-                     mask = ~ (self .response .argmax (axis = 2 ) ==  class_idx ),
261+             if  plot_method  ==  "contour" :
262+                 # Plot only argmax map for contour 
263+                 class_map  =  self .response .argmax (axis = 2 )
264+                 self .surface_  =  plot_func (
265+                     self .xx0 , self .xx1 , class_map , colors = colors , ** kwargs 
265266                )
266-                 # `cmap` should not be in kwargs 
267-                 safe_kwargs  =  kwargs .copy ()
268-                 if  "cmap"  in  safe_kwargs :
269-                     del  safe_kwargs ["cmap" ]
270-                     warnings .warn (
271-                         "Plotting max class of multiclass 'decision_function' or " 
272-                         "'predict_proba', thus 'multiclass_colors' used and " 
273-                         "'cmap' kwarg ignored." 
267+             else :
268+                 multiclass_cmaps  =  [
269+                     mpl .colors .LinearSegmentedColormap .from_list (
270+                         f"colormap_{ class_idx }  , [(1.0 , 1.0 , 1.0 , 1.0 ), (r , g , b , 1.0 )]
271+                     )
272+                     for  class_idx , (r , g , b , _ ) in  enumerate (colors )
273+                 ]
274+ 
275+                 self .surface_  =  []
276+                 for  class_idx , cmap  in  enumerate (multiclass_cmaps ):
277+                     response  =  np .ma .array (
278+                         self .response [:, :, class_idx ],
279+                         mask = ~ (self .response .argmax (axis = 2 ) ==  class_idx ),
280+                     )
281+                     self .surface_ .append (
282+                         plot_func (self .xx0 , self .xx1 , response , cmap = cmap , ** kwargs )
274283                    )
275-                 self .surface_ .append (
276-                     plot_func (self .xx0 , self .xx1 , response , cmap = cmap , ** safe_kwargs )
277-                 )
278284
279285        if  xlabel  is  not None  or  not  ax .get_xlabel ():
280286            xlabel  =  self .xlabel  if  xlabel  is  None  else  xlabel 
0 commit comments