Model Training
### Setting for Training
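This section picks up from Part 2 (network, Dataset, and transforms). For reference, here is a minimal sketch of the imports and setup the code below assumes; the hyperparameter values and directory paths are assumptions for illustration, not necessarily the exact ones used in this series.

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# UNet, Dataset, Normalization, RandomFlip, and ToTensor are the custom
# classes implemented in Part 2 (assumed to be in scope here)

# hyperparameters and paths (assumed values -- adjust to your setup)
lr = 1e-3
batch_size = 4
num_epoch = 100

data_dir = './datasets'
ckpt_dir = './checkpoint'
log_dir = './log'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')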
# Data load
# note: random augmentation (RandomFlip) belongs on the training set only,
# so the validation set gets its own transform without it
transform_train = transforms.Compose([Normalization(), RandomFlip(), ToTensor()])
transform_val = transforms.Compose([Normalization(), ToTensor()])

dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=transform_train)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)

dataset_val = Dataset(data_dir=os.path.join(data_dir, 'val'), transform=transform_val)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=8)
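As an optional sanity check, you can pull one batch and confirm the shapes before training; this assumes the Dataset from Part 2 returns dicts with 'input' and 'label' keys.

# hypothetical check: inspect one training batch
sample = next(iter(loader_train))
print(sample['input'].shape, sample['label'].shape)  # e.g. torch.Size([4, 1, 512, 512]) each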
# Model
net = UNet().to(device)
# Loss Function
fn_loss = nn.BCEWithLogitsLoss().to(device)
# Optimizer
optim = torch.optim.Adam(net.parameters(), lr=lr)
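BCEWithLogitsLoss applies the sigmoid internally, which is why the network's final layer outputs raw logits. A small sketch (not from the original post) showing that the fused loss matches BCELoss on sigmoid-activated outputs, while being more numerically stable:

# fused sigmoid + BCE vs. explicit sigmoid followed by BCE
logits = torch.randn(2, 1, 4, 4)                    # raw network outputs
target = torch.randint(0, 2, (2, 1, 4, 4)).float()  # binary mask

fused = nn.BCEWithLogitsLoss()(logits, target)
split = nn.BCELoss()(torch.sigmoid(logits), target)
print(torch.allclose(fused, split, atol=1e-6))      # True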
# variables
num_data_train = len(dataset_train)
num_data_val = len(dataset_val)

# cast to int so the batch counts work as integer Tensorboard steps
num_batch_train = int(np.ceil(num_data_train / batch_size))
num_batch_val = int(np.ceil(num_data_val / batch_size))
# functions
fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
fn_denorm = lambda x, mean, std: (x * std) + mean  # undo Normalization for display
# the network outputs logits, so apply the sigmoid before thresholding at 0.5
fn_clss = lambda x: 1.0 * (torch.sigmoid(x) > 0.5)
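Because the network outputs logits, thresholding the raw output at 0.5 would actually correspond to a probability of about 0.62; applying the sigmoid first keeps the 0.5 threshold meaningful. A quick illustration:

# fn_clss maps logits -> probabilities -> binary mask
logits = torch.tensor([-2.0, -0.1, 0.1, 2.0])
print(torch.sigmoid(logits))  # tensor([0.1192, 0.4750, 0.5250, 0.8808])
print(fn_clss(logits))        # tensor([0., 0., 1., 1.])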
# SummaryWriter for Tensorboard
writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))
writer_val = SummaryWriter(log_dir=os.path.join(log_dir, 'val'))
# save network
def save(ckpt_dir, net, optim, epoch):
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},
               os.path.join(ckpt_dir, 'model_epoch%d.pth' % epoch))
# load network
def load(ckpt_dir, net, optim):
    # start from scratch if there is no checkpoint yet
    if not os.path.exists(ckpt_dir) or not os.listdir(ckpt_dir):
        epoch = 0
        return net, optim, epoch

    ckpt_lst = os.listdir(ckpt_dir)
    ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    dict_model = torch.load(os.path.join(ckpt_dir, ckpt_lst[-1]))
    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])
    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])

    return net, optim, epoch
### Training
st_epoch = 0
net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

for epoch in range(st_epoch + 1, num_epoch + 1):
    net.train()
    loss_arr = []

    for batch, data in enumerate(loader_train, 1):
        # forward pass
        label = data['label'].to(device)
        input = data['input'].to(device)
        output = net(input)

        # backward pass
        optim.zero_grad()
        loss = fn_loss(output, label)
        loss.backward()
        optim.step()

        # track the loss
        loss_arr += [loss.item()]
        print("TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f" %
              (epoch, num_epoch, batch, num_batch_train, np.mean(loss_arr)))

        # save images to Tensorboard
        label = fn_tonumpy(label)
        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))
        output = fn_tonumpy(fn_clss(output))

        writer_train.add_image('label', label, num_batch_train * (epoch - 1) + batch,
                               dataformats='NHWC')
        writer_train.add_image('input', input, num_batch_train * (epoch - 1) + batch,
                               dataformats='NHWC')
        writer_train.add_image('output', output, num_batch_train * (epoch - 1) + batch,
                               dataformats='NHWC')

    writer_train.add_scalar('loss', np.mean(loss_arr), epoch)
    # validation: no backprop
    with torch.no_grad():
        net.eval()
        loss_arr = []

        for batch, data in enumerate(loader_val, 1):
            # forward pass
            label = data['label'].to(device)
            input = data['input'].to(device)
            output = net(input)

            # track the loss (no backward pass here)
            loss = fn_loss(output, label)
            loss_arr += [loss.item()]
            print("VALID: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f" %
                  (epoch, num_epoch, batch, num_batch_val, np.mean(loss_arr)))

            # save images to Tensorboard
            label = fn_tonumpy(label)
            input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))
            output = fn_tonumpy(fn_clss(output))

            writer_val.add_image('label', label, num_batch_val * (epoch - 1) + batch,
                                 dataformats='NHWC')
            writer_val.add_image('input', input, num_batch_val * (epoch - 1) + batch,
                                 dataformats='NHWC')
            writer_val.add_image('output', output, num_batch_val * (epoch - 1) + batch,
                                 dataformats='NHWC')

    writer_val.add_scalar('loss', np.mean(loss_arr), epoch)

    # save a checkpoint every 10 epochs
    if epoch % 10 == 0:
        save(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch)

writer_train.close()
writer_val.close()
### Tensorboard
- Run Tensorboard from the log folder:
tensorboard --logdir=.
- The training loss curve looks like overfitting.
- This model currently has no dropout; it might be worth adding later (a sketch follows below).
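For the dropout idea, one hedged sketch: nn.Dropout2d could be appended to the Conv-BatchNorm-ReLU blocks of the Part 2 UNet. The block structure and the 0.5 rate below are assumptions for illustration, not code from this series.

import torch.nn as nn

# hypothetical variant of the Part 2 conv block with spatial dropout added
def cbr_dropout(in_channels, out_channels, p=0.5):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Dropout2d(p=p),  # randomly zeroes whole feature maps during training
    )

Note that net.train() / net.eval() in the loops above would already toggle dropout on and off correctly.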
Source code reference:
https://www.youtube.com/watch?v=rBb597ct_FQ