-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodified_unet.py
More file actions
41 lines (34 loc) · 1.65 KB
/
modified_unet.py
File metadata and controls
41 lines (34 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.optim as optim
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from torchvision.models.resnet import resnet18
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a U-Net generator with a pretrained ResNet-18 encoder.

    Args:
        n_input: number of input channels (1 = the L channel of Lab color space).
        n_output: number of output channels (2 = the ab chrominance channels).
        size: spatial size of the (square) input images.

    Returns:
        A fastai ``DynamicUnet`` moved to CUDA when available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cut=-2 drops ResNet's avgpool + fc head, keeping only the conv backbone.
    body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Pretrain the generator alone with a pixelwise loss (no discriminator).

    Args:
        net_G: generator network mapping an L channel to predicted ab channels.
        train_dl: dataloader yielding dicts with 'L' and 'ab' tensor batches.
        opt: optimizer over ``net_G``'s parameters.
        criterion: pixelwise loss (an L1 loss in this script).
        epochs: number of passes over ``train_dl``.
    """
    # BUGFIX: the original read a global `device` that was only defined
    # locally inside build_res_unet (NameError at runtime). Take the device
    # from the model itself so inputs always land where the model lives.
    device = next(net_G.parameters()).device
    for e in range(epochs):
        loss_meter = AverageMeter()
        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per-sample.
            loss_meter.update(loss.item(), L.size(0))
        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")
# --- Stage 1: pretrain the generator with a plain L1 loss -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()
pretrain_generator(net_G, train_dl, opt, criterion, 20)
# BUGFIX: this save was commented out, but the load just below requires the
# file — the script would train for 20 epochs, then crash with
# FileNotFoundError and lose the pretrained weights. Save them.
torch.save(net_G.state_dict(), "res18-unet.pt")

# --- Stage 2: adversarial training seeded with the pretrained generator -----
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
train_model(model, train_dl, 20)
# Inference / resume path: rebuild the generator and load the fully trained
# weights from disk instead of retraining (uncomment to use).
# net_G = build_res_unet(n_input=1, n_output=2, size=256)
# net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
# model = MainModel(net_G=net_G)
# model.load_state_dict(torch.load("final_model_weights.pt", map_location=device))