diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py index 1a6d80b6..812bd97c 100644 --- a/segmentation_models_pytorch/metrics/functional.py +++ b/segmentation_models_pytorch/metrics/functional.py @@ -260,7 +260,7 @@ def _compute_metric( tn = tn.sum() score = metric_fn(tp, fp, fn, tn, **metric_kwargs) - elif reduction == "macro" or reduction == "weighted": + elif reduction == "macro" : tp = tp.sum(0) fp = fp.sum(0) fn = fn.sum(0) @@ -268,6 +268,15 @@ def _compute_metric( score = metric_fn(tp, fp, fn, tn, **metric_kwargs) score = _handle_zero_division(score, zero_division) score = (score * class_weights).mean() + + elif reduction == "weighted": + tp = tp.sum(0) + fp = fp.sum(0) + fn = fn.sum(0) + tn = tn.sum(0) + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + score = _handle_zero_division(score, zero_division) + score = (score * class_weights).sum() elif reduction == "micro-imagewise": tp = tp.sum(1)