diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py index cb60561f..2eb7ca74 100644 --- a/segmentation_models_pytorch/losses/dice.py +++ b/segmentation_models_pytorch/losses/dice.py @@ -89,10 +89,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = y_pred * mask.unsqueeze(1) y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C - y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W + y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W else: y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C - y_true = y_true.permute(0, 2, 1) # H, C, H*W + y_true = y_true.permute(0, 2, 1) # N, C, H*W if self.mode == MULTILABEL_MODE: y_true = y_true.view(bs, num_classes, -1)