From 9a4c2a5ce236a26d64ac06df947fb929eb56fa38 Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Wed, 19 Oct 2022 20:45:34 -0400 Subject: [PATCH] fix return type of get_stats The previous return type of `Tuple[torch.LongTensor]` implies that the tuple includes one item. This commit changes this to a tuple of four LongTensors to represent TP, FP, FN, and TN. --- segmentation_models_pytorch/metrics/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py index 1a6d80b6..435d6947 100644 --- a/segmentation_models_pytorch/metrics/functional.py +++ b/segmentation_models_pytorch/metrics/functional.py @@ -65,7 +65,7 @@ def get_stats( ignore_index: Optional[int] = None, threshold: Optional[Union[float, List[float]]] = None, num_classes: Optional[int] = None, -) -> Tuple[torch.LongTensor]: +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: """Compute true positive, false positive, false negative, true negative 'pixels' for each image and each class.