@@ -221,17 +221,22 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
221
221
self .surface_ = plot_func (self .xx0 , self .xx1 , self .response , ** kwargs )
222
222
else : # self.response.ndim == 3
223
223
n_responses = self .response .shape [- 1 ]
224
- if (
225
- isinstance (self .multiclass_colors , str )
226
- or self .multiclass_colors is None
224
+ for kwarg in ("cmap" , "colors" ):
225
+ if kwarg in kwargs :
226
+ warnings .warn (
227
+ f"'{ kwarg } ' is ignored in favor of 'multiclass_colors' "
228
+ "in the multiclass case when the response method is "
229
+ "'decision_function' or 'predict_proba'."
230
+ )
231
+ del kwargs [kwarg ]
232
+
233
+ if self .multiclass_colors is None or isinstance (
234
+ self .multiclass_colors , str
227
235
):
228
- if isinstance ( self .multiclass_colors , str ) :
229
- cmap = self . multiclass_colors
236
+ if self .multiclass_colors is None :
237
+ cmap = "tab10" if n_responses <= 10 else "gist_rainbow"
230
238
else :
231
- if n_responses <= 10 :
232
- cmap = "tab10"
233
- else :
234
- cmap = "gist_rainbow"
239
+ cmap = self .multiclass_colors
235
240
236
241
# Special case for the tab10 and tab20 colormaps that encode a
237
242
# discrete set of colors that are easily distinguishable
@@ -241,40 +246,41 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
241
246
elif cmap == "tab20" and n_responses <= 20 :
242
247
colors = plt .get_cmap ("tab20" , 20 ).colors [:n_responses ]
243
248
else :
244
- colors = plt .get_cmap (cmap , n_responses ).colors
245
- elif isinstance (self .multiclass_colors , str ):
246
- colors = colors = plt .get_cmap (
247
- self .multiclass_colors , n_responses
248
- ).colors
249
- else :
249
+ cmap = plt .get_cmap (cmap , n_responses )
250
+ if not hasattr (cmap , "colors" ):
251
+ # For LinearSegmentedColormap
252
+ colors = cmap (np .linspace (0 , 1 , n_responses ))
253
+ else :
254
+ colors = cmap .colors
255
+ elif isinstance (self .multiclass_colors , list ):
250
256
colors = [mpl .colors .to_rgba (color ) for color in self .multiclass_colors ]
257
+ else :
258
+ raise ValueError ("'multiclass_colors' must be a list or a str." )
251
259
252
260
self .multiclass_colors_ = colors
253
- multiclass_cmaps = [
254
- mpl .colors .LinearSegmentedColormap .from_list (
255
- f"colormap_{ class_idx } " , [(1.0 , 1.0 , 1.0 , 1.0 ), (r , g , b , 1.0 )]
256
- )
257
- for class_idx , (r , g , b , _ ) in enumerate (colors )
258
- ]
259
-
260
- self .surface_ = []
261
- for class_idx , cmap in enumerate (multiclass_cmaps ):
262
- response = np .ma .array (
263
- self .response [:, :, class_idx ],
264
- mask = ~ (self .response .argmax (axis = 2 ) == class_idx ),
261
+ if plot_method == "contour" :
262
+ # Plot only argmax map for contour
263
+ class_map = self .response .argmax (axis = 2 )
264
+ self .surface_ = plot_func (
265
+ self .xx0 , self .xx1 , class_map , colors = colors , ** kwargs
265
266
)
266
- # `cmap` should not be in kwargs
267
- safe_kwargs = kwargs .copy ()
268
- if "cmap" in safe_kwargs :
269
- del safe_kwargs ["cmap" ]
270
- warnings .warn (
271
- "Plotting max class of multiclass 'decision_function' or "
272
- "'predict_proba', thus 'multiclass_colors' used and "
273
- "'cmap' kwarg ignored."
267
+ else :
268
+ multiclass_cmaps = [
269
+ mpl .colors .LinearSegmentedColormap .from_list (
270
+ f"colormap_{ class_idx } " , [(1.0 , 1.0 , 1.0 , 1.0 ), (r , g , b , 1.0 )]
271
+ )
272
+ for class_idx , (r , g , b , _ ) in enumerate (colors )
273
+ ]
274
+
275
+ self .surface_ = []
276
+ for class_idx , cmap in enumerate (multiclass_cmaps ):
277
+ response = np .ma .array (
278
+ self .response [:, :, class_idx ],
279
+ mask = ~ (self .response .argmax (axis = 2 ) == class_idx ),
280
+ )
281
+ self .surface_ .append (
282
+ plot_func (self .xx0 , self .xx1 , response , cmap = cmap , ** kwargs )
274
283
)
275
- self .surface_ .append (
276
- plot_func (self .xx0 , self .xx1 , response , cmap = cmap , ** safe_kwargs )
277
- )
278
284
279
285
if xlabel is not None or not ax .get_xlabel ():
280
286
xlabel = self .xlabel if xlabel is None else xlabel
0 commit comments