We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d316fbf commit 560e8ffCopy full SHA for 560e8ff
segmentation_models_pytorch/losses/dice.py
@@ -89,10 +89,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
89
y_pred = y_pred * mask.unsqueeze(1)
90
91
y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C
92
- y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W
+ y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
93
else:
94
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
95
- y_true = y_true.permute(0, 2, 1) # H, C, H*W
+ y_true = y_true.permute(0, 2, 1) # N, C, H*W
96
97
if self.mode == MULTILABEL_MODE:
98
y_true = y_true.view(bs, num_classes, -1)
AltStyle によって変換されたページ (->オリジナル) / アドレス: モード: デフォルト 音声ブラウザ ルビ付き 配色反転 文字拡大 モバイル
0 commit comments