diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py index 435d6947..38f60c1e 100644 --- a/segmentation_models_pytorch/metrics/functional.py +++ b/segmentation_models_pytorch/metrics/functional.py @@ -121,7 +121,7 @@ def get_stats( ) if torch.is_floating_point(output) and mode == "multiclass": - raise ValueError(f"For ``multiclass`` mode ``target`` should be one of the integer types, got {output.dtype}.") + raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.") if mode not in {"binary", "multiclass", "multilabel"}: raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.")