-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
executable file
·66 lines (53 loc) · 2.6 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
def CE_Loss(inputs, target, num_classes=21):
    """Pixel-wise cross-entropy loss for semantic segmentation.

    Args:
        inputs: Logits of shape ``(N, C, H, W)``.
        target: Integer (``torch.long``) class-index map of shape
            ``(N, 1, Ht, Wt)``; it is flattened with ``view(-1)`` and must
            yield one label per logit row.
        num_classes: Label value passed to ``NLLLoss`` as ``ignore_index``;
            pixels whose target equals it do not contribute to the loss.

    Returns:
        Scalar tensor: mean negative log-likelihood over non-ignored pixels.
    """
    _, c, h, w = inputs.size()
    _, _, ht, wt = target.size()
    # Resize the logits when EITHER spatial dim differs from the labels;
    # checking with `and` skipped the resize for a one-dim mismatch and the
    # flattened logits/labels then disagreed in length.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): one row of class logits per pixel.
    temp_inputs = inputs.permute(0, 2, 3, 1).contiguous().view(-1, c)
    temp_target = target.view(-1)
    # log_softmax followed by NLLLoss is cross-entropy over the class axis.
    CE_loss = nn.NLLLoss(ignore_index=num_classes)(F.log_softmax(temp_inputs, dim=-1), temp_target)
    return CE_loss
def f_score(inputs, target, beta=1, smooth=1e-5, threhold=0.5):
    """Mean F-beta score (Dice coefficient when ``beta == 1``) of hard predictions.

    Args:
        inputs: Logits of shape ``(N, C, H, W)``.
        target: Channel-last labels of shape ``(N, Ht, Wt, C + 1)``; the last
            channel is excluded from scoring (``[..., :-1]``).
        beta: Weight of recall relative to precision.
        smooth: Added to numerator and denominator to avoid 0/0.
        threhold: Softmax probability above which a class counts as predicted
            (name kept as-is for caller compatibility).

    Returns:
        Scalar tensor: the per-class score averaged over the C classes.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()
    # Resize the logits when EITHER spatial dim differs from the labels;
    # checking with `and` skipped the resize for a one-dim mismatch.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (N, C, H, W) -> (N, H*W, C) class probabilities per pixel.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.view(n, -1, ct)
    # --------------------------------------------#
    # Compute the Dice / F-beta coefficient on thresholded predictions
    # --------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    # Per-class counts, summed over batch and pixels.
    tp = torch.sum(temp_target[..., :-1] * temp_inputs, dim=(0, 1))
    fp = torch.sum(temp_inputs, dim=(0, 1)) - tp
    fn = torch.sum(temp_target[..., :-1], dim=(0, 1)) - tp
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    return torch.mean(score)
class JointsMSELoss(nn.Module):
    """Half mean-squared error between predicted and ground-truth heatmaps.

    The loss is ``0.5 * mean((w * pred - w * gt) ** 2)`` taken over batch,
    joints and heatmap pixels, where ``w`` is the target weight when
    ``use_target_weight`` is true (otherwise no weighting). Because every joint
    contributes the same number of elements, this equals the per-joint
    half-MSE averaged over joints.
    """

    def __init__(self, use_target_weight):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        """Return the scalar heatmap loss.

        Args:
            output: Predicted heatmaps of shape ``(N, J, ...)``.
            target: Ground-truth heatmaps with the same element count as ``output``.
            target_weight: Weights of shape ``(N, J, 1)`` or ``(N, J)``; a
                per-pixel map with ``N * J * P`` elements also works. Only read
                when ``use_target_weight`` is true.
        """
        batch_size, num_joints = output.size(0), output.size(1)
        # Flatten each joint's heatmap into one row: (N, J, P).
        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)
        if self.use_target_weight:
            # Reshape to (N, J, 1|P) so weights apply per sample and joint; a
            # 2-D (N, J) weight used to be broadcast across the pixel axis instead.
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight
        return 0.5 * self.criterion(heatmaps_pred, heatmaps_gt)