diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py index 1a6d80b6..435d6947 100644 --- a/segmentation_models_pytorch/metrics/functional.py +++ b/segmentation_models_pytorch/metrics/functional.py @@ -65,7 +65,7 @@ def get_stats( ignore_index: Optional[int] = None, threshold: Optional[Union[float, List[float]]] = None, num_classes: Optional[int] = None, -) -> Tuple[torch.LongTensor]: +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: """Compute true positive, false positive, false negative, true negative 'pixels' for each image and each class.