-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackbones.py
57 lines (48 loc) · 1.75 KB
/
backbones.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet50
class DummyBackbone(nn.Module):
    """Minimal two-layer MLP classifier, usable as a lightweight stand-in backbone.

    Each sample is flattened to a vector and passed through
    ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, num_classes)``.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.act = nn.ReLU()

    def forward(self, x, y=None):
        """Compute class logits and, when targets are given, the cross-entropy loss.

        Args:
            x: Input batch; every dimension after the first is flattened, so
                the product of those dimensions must equal ``input_dim``.
            y: Optional class-index targets for ``F.cross_entropy``. When
                None (the default) no loss is computed.

        Returns:
            ``(y_hat,)`` when ``y`` is None, otherwise ``(loss, y_hat)``, where
            ``y_hat`` has shape ``(batch, num_classes)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, copying only when needed.
        x = x.reshape(x.size(0), -1)
        x = self.act(self.hidden(x))
        # Return raw logits: F.cross_entropy applies log-softmax itself, and a
        # ReLU here would clamp negative logits to 0 and zero their gradients.
        y_hat = self.fc(x)
        if y is None:
            return (y_hat,)
        loss = F.cross_entropy(y_hat, y)
        return (loss, y_hat)
class ResnetBackbone(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk followed by an MLP classification head.

    All children of torchvision's ``resnet50`` except the last two
    (``avgpool`` and ``fc``) form the trunk. Its feature map feeds a head of
    ``AdaptiveAvgPool2d -> Flatten -> BatchNorm1d -> Dropout -> Linear -> ReLU
    -> BatchNorm1d -> Dropout -> Linear``.

    Args:
        hidden_dim: Width of the head's hidden layer.
        num_classes: Number of output logits.
        freeze_resnet: If True, set ``requires_grad=False`` on every ResNet
            parameter so that only the head receives gradient updates.
    """

    def __init__(self, hidden_dim: int, num_classes: int, freeze_resnet: bool = False):
        super().__init__()
        self.resnet = self._pretrained_resnet50()
        if freeze_resnet:
            for param in self.resnet.parameters():
                param.requires_grad = False
        # NOTE: freezing only stops gradient updates. The trunk's BatchNorm2d
        # layers still update their running statistics whenever the module is
        # in train() mode.
        feat_dim = self.resnet.fc.in_features
        # NOTE: the BatchNorm1d layers need batch size > 1 in train() mode.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.BatchNorm1d(feat_dim),
            nn.Dropout(0.5),
            nn.Linear(in_features=feat_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(in_features=hidden_dim, out_features=num_classes),
        )
        # [:-2] drops ResNet's own avgpool and fc; the head re-pools and classifies.
        # self.resnet.fc stays registered (keeping state_dict keys unchanged) but is
        # never called in forward, so it receives no gradients.
        self.model = nn.Sequential(
            nn.Sequential(*list(self.resnet.children())[:-2]),
            self.head,
        )
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _pretrained_resnet50():
        """Build resnet50 with the IMAGENET1K_V1 checkpoint (what ``pretrained=True`` loads)."""
        try:
            # torchvision >= 0.13: multi-weight API; ``pretrained=`` is deprecated there.
            from torchvision.models import ResNet50_Weights
        except ImportError:
            return resnet50(pretrained=True)
        return resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    def forward(self, x, y=None):
        """Compute class logits and, when targets are given, the cross-entropy loss.

        Args:
            x: Image batch for the ResNet trunk — presumably (batch, 3, H, W);
                confirm with callers.
            y: Optional class-index targets. When None (the default) no loss
                is computed.

        Returns:
            ``(y_hat,)`` when ``y`` is None, otherwise ``(loss, y_hat)``.
        """
        y_hat = self.model(x)
        if y is None:
            return (y_hat,)
        # Cached on the module only when targets are given (original behaviour);
        # presumably read by external code after a loss step — verify before removing.
        self.y_hat = y_hat
        loss = self.criterion(y_hat, y)
        return (loss, y_hat)