[MRG] ENH Adds plot_confusion matrix #15083
lint ;)
fmt = '.2f' if self.normalize else 'd'
thresh = cm.max() / 2.
for i, j in product(range(cm.shape[0]), range(cm.shape[1])):
    color = "white" if cm[i, j] < thresh else "black"
I think that's weird as it doesn't depend on the colormap.
Here's how I usually do it:
https://github.com/amueller/mglearn/blob/master/mglearn/tools.py#L76
Without depending on the colormap there's no way this works, right? Someone could use greys and greys_r, and they clearly need the opposite colors.
I think it should be pcolormesh not pcolor, though.
also: shouldn't this go in a separate helper function? It's probably not the only time we want to show a heatmap (grid search will need this as well). The main question then is if that will be public or not :-/
Updated PR, with an alternative: it uses the colormap to get the colors for the text.
Looks reasonable.
Can you maybe add a test? Like calling ConfusionMatrixDisplay with np.eye(2) and plt.cm.greys and check that the text colors are black white black white, and with plt.cm.greys_r and check that the text colors are white black white black?
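The colormap-based text color discussed in this thread can be sketched roughly as below. This is a minimal illustration, not the code from the PR: `text_colors` is a hypothetical helper, and `greys` / `greys_r` are hand-written stand-ins for a grey colormap and its reverse (any callable mapping a value in [0, 1] to a color, which is how matplotlib colormaps behave). The idea is that each label takes the colormap extreme opposite to its cell's background.

```python
from itertools import product


def text_colors(cm, cmap):
    """Pick a per-cell text color from the two extremes of `cmap`.

    `cm` is a list of rows; `cmap` is any callable mapping a value in
    [0, 1] to a color. Both names are assumptions made for this sketch.
    """
    flat = [value for row in cm for value in row]
    thresh = (max(flat) + min(flat)) / 2.0
    cmap_min, cmap_max = cmap(0.0), cmap(1.0)
    colors = {}
    for i, j in product(range(len(cm)), range(len(cm[0]))):
        # Low cells are drawn near cmap(0), so they get cmap(1) as text
        # color, and high cells get cmap(0).
        colors[i, j] = cmap_max if cm[i][j] < thresh else cmap_min
    return colors


# Hypothetical stand-ins for a grey colormap (0 -> white, 1 -> black)
# and its reverse, returning RGB tuples.
def greys(x):
    return (1.0 - x,) * 3


def greys_r(x):
    return (x,) * 3


eye = [[1, 0], [0, 1]]
print(text_colors(eye, greys))
print(text_colors(eye, greys_r))
```

Swapping the colormap for its reverse flips every text color, which is the property the requested test would check.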
titles_options = [("Confusion matrix, without normalization", False),
                  ("Normalized confusion matrix", True)]
for title, normalize in titles_options:
    fig, ax = plt.subplots()
Why? There's no reason to pass ax, right?
For setting the title you could just do plt.gca().set_title(title).
Or do you not like using the state like that?
Updated so that ax no longer has to be defined and passed in; the example now uses the axes stored in the Display object.
ah, even better.
how about adding it in all the examples that use confusion_matrix:
    select a subset of labels. If `None` is given, those that appear at
    least once in `y_true` or `y_pred` are used in sorted order.

target_names : array-like of shape (n_classes,), default=None
Don't call this target names. That implies multiple targets. Rather, display_labels will be sufficient?
Hmmm, do you think classes or class_names would be better?
    Includes values in confusion matrix.

normalize : bool, default=False
    Normalizes confusion matrix.
The user might want to normalise over either axis, or altogether.
Four options? I guess we can do 'row', 'column', 'all', None?
I'm okay to not provide this flexibility, too. Another way to specify it is "all", "recall", "precision", None.
Would it make sense to use "truth" and "predicted" instead of "recall" and "precision"?
Updated PR to use 'truth' and 'predicted'. Almost feels like this should be in confusion_matrix itself.
Updated with using
CC @NicolasHug
- Needs a what's new
- The example doesn't render properly https://79914-843222-gh.circle-artifacts.com/0/doc/auto_examples/model_selection/plot_confusion_matrix.html. Also I think the original color map was nicer but no strong opinion
- The list at the end of https://79914-843222-gh.circle-artifacts.com/0/doc/visualizations.html#visualizations should be updated
- The User guide about confusion matrix should be updated too
                          cmap='viridis', ax=None):
    """Plot Confusion Matrix.

    Read more in the :ref:`User Guide <visualizations>`.
This should probably link to https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix instead.
    Rotation of xtick labels.

values_format : str, default=None
    Format specification for values in confusion matrix. If None,
- Format specification for values in confusion matrix. If None,
+ Format specification for values in confusion matrix. If `None`,
Just this nitpick
@glemaitre yes our classifiers work with list of strings, but our simple example using
Updated PR by adding normalize={'all', 'truth', 'predicted'} and None support.
Conceptually
lmk if you need reviews.
Let's call them "true", "pred" for consistency?
With the latest changes, LGTM
So it seems that we need them in case we want to overwrite it. So we can keep it as it is until, by default, we don't need to specify it.
Thanks @thomasjpfan , mostly looks good.
I'm slightly concerned about testing time and coupling though
    Includes values in confusion matrix.

normalize : {'true', 'pred', 'all'}, default=None
    Normalizes confusion matrix over the true, predicited conditions or
Just a suggestion
- Normalizes confusion matrix over the true, predicited conditions or
+ Normalizes confusion matrix over the true (rows), predicited conditions (columns) or
                          labels=labels)

    if normalize == 'true':
        cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
I think we should not convert to float (see other msg about high coupling)
cm = self.confusion_matrix
n_classes = cm.shape[0]
normalized = np.issubdtype(cm.dtype, np.float_)
This logic involves a strong coupling between confusion_matrix -> plot_confusion_matrix -> ConfusionMatrixDisplay and might cause silent bugs in the future.
I would rather pass an is_normalized parameter (or remove it, see below).
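To illustrate the kind of silent bug this coupling could cause, here is a small sketch under stated assumptions. `looks_normalized` is a hypothetical stand-in for the dtype check quoted above, and it uses `np.floating` rather than `np.float_` so that it runs on recent NumPy releases. The `weighted` matrix is a made-up example of a float-valued matrix that is not normalized, such as one might get from non-integer sample weights.

```python
import numpy as np


def looks_normalized(cm):
    # Hypothetical stand-in for the dtype-based check: any float matrix
    # is assumed to be a normalized one.
    return np.issubdtype(cm.dtype, np.floating)


counts = np.array([[3, 1], [0, 2]])                      # raw integer counts
normalized = counts / counts.sum(axis=1, keepdims=True)  # rows sum to 1
weighted = np.array([[1.5, 0.5], [0.0, 2.0]])            # float, NOT normalized

for name, cm in [("counts", counts),
                 ("normalized", normalized),
                 ("weighted", weighted)]:
    print(name, looks_normalized(cm), cm.sum(axis=1))
```

The check cannot tell `weighted` apart from `normalized`, which is why an explicit flag, or a format that does not depend on the distinction, is the safer design.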
if include_values:
    self.text_ = np.empty_like(cm, dtype=object)
    if values_format is None:
        values_format = '.2f' if normalized else 'd'
I think that the .2g option is what we need, and you wouldn't have to use the normalized variable anymore:
In [15]: "{:.2g} -- {:.2g} -- {:.2g}".format(2, 2.0000, 2.23425)
Out[15]: '2 -- 2 -- 2.2'
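As a slightly broader check of the `.2g` suggestion, the sketch below formats a few more values with Python's built-in `format`. The sample values are chosen for illustration only. One thing worth keeping in mind is that `.2g` keeps two significant digits, so counts with three or more digits switch to exponent notation.

```python
# Values one might find in a confusion matrix: small counts, fractions
# from a normalized matrix, and a larger count.
values = [2, 2.0, 2.23425, 0.5, 10, 123]

formatted = [format(v, ".2g") for v in values]
for v, text in zip(values, formatted):
    print(repr(v), "->", text)
# 2 -> 2
# 2.0 -> 2
# 2.23425 -> 2.2
# 0.5 -> 0.5
# 10 -> 10
# 123 -> 1.2e+02
```

So a single format covers both integer and normalized matrices, at the cost of less readable labels for large counts.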
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("cmap", ['viridis', 'plasma'])
@pytest.mark.parametrize("with_custom_axes", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
@pytest.mark.parametrize("include_values", [True, False])
Do we really need each of these combinations to be tested independently?
It seems to me that most of the checks in this test could be independent test functions. Parametrization is nice but seems way overkill here.
This will test 256 instances, and it takes about 10s on my machine, which is not negligible considering that small increments in testing time really add up over time.
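The 256 figure follows from how stacked `parametrize` decorators combine: pytest runs the test once per element of the cross product. The arithmetic can be sketched as below; the dictionary just records how many values each decorator quoted above supplies, and the "separate" count assumes one small test per parameter, which is an illustrative alternative rather than what the PR ended up doing.

```python
from math import prod

# Number of values supplied by each stacked @pytest.mark.parametrize
# decorator in the hunk above.
grid = {
    "normalize": 4,
    "with_sample_weight": 2,
    "with_labels": 2,
    "cmap": 2,
    "with_custom_axes": 2,
    "with_display_labels": 2,
    "include_values": 2,
}

stacked_cases = prod(grid.values())   # one test over the full cross product
separate_cases = sum(grid.values())   # one small test per parameter

print(stacked_cases, separate_cases)  # 256 16
```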
To be consistent with the
Ah
last nits
    assert disp.ax_ == ax

if normalize == 'true':
    cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
You don't need the conversion anymore, right?
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
@pytest.mark.parametrize("include_values", [True, False])
The main reason I'm not a fan of this is that such parametrization suggests that all 4 of these parameters are intertwined and depend on one another, but in reality this isn't the case.
I think we could still remove some parametrizations, but that's fine.
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
attributes.

Read more in the :ref:`User Guide <confusion_matrix>`.
Shouldn't this link to the visualization UG?
include_values : bool, default=True
    Includes values in confusion matrix.

normalize : {'true', 'pred', 'all'}, default=None
If we decide to support normalize here, perhaps we should also support it in confusion_matrix (See #14478).
And I can't understand why we need normalize="all".
Good remark. normalize='all' will normalize by the total support.
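For reference, the three normalization modes discussed in this review can be sketched with NumPy as follows. `normalize_cm` is a hypothetical helper written for this illustration, not the function from the PR, and it does not handle rows or columns that sum to zero.

```python
import numpy as np


def normalize_cm(cm, normalize=None):
    """Normalize a confusion matrix (rows = true labels, columns = predicted).

    Hypothetical helper for illustration; zero sums are not handled.
    """
    cm = np.asarray(cm, dtype=float)
    if normalize == "true":
        # Each row sums to 1: fractions of each true class.
        return cm / cm.sum(axis=1, keepdims=True)
    if normalize == "pred":
        # Each column sums to 1: fractions of each predicted class.
        return cm / cm.sum(axis=0, keepdims=True)
    if normalize == "all":
        # The whole matrix sums to 1: fractions of the total support.
        return cm / cm.sum()
    if normalize is None:
        return cm
    raise ValueError("normalize must be one of 'true', 'pred', 'all', None")


cm = [[5, 1], [2, 2]]
for mode in ("true", "pred", "all", None):
    print(mode)
    print(normalize_cm(cm, mode))
```

With 'all', every cell is divided by the total number of samples, so the entries read as shares of the whole dataset rather than of a single class.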
However, I would suggest adding it in another PR.
I made a push to solve the conflicts
and I added a similar test to the other plotting for pipeline.
OK merging this one. I will open a new PR to address the problem raised by @qinhanmin2014 in #15083 (comment)
Reference Issues/PRs
Related to #7116
What does this implement/fix? Explain your changes.
Adds plotting function for the confusion matrix.