
U-Net Practice 4 - Model Testing

scone 2023. 7. 7. 18:58

Model Testing
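
This section runs the trained U-Net on the test set: the latest checkpoint is restored, the BCE loss is printed per batch, and every input, label, and predicted mask is saved to workspace/results as both a PNG image and a .npy array.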

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms, datasets

### Hyper Parameters
lr = 1e-3
batch_size = 4
num_epoch = 100

data_dir = 'workspace/data'
ckpt_dir = 'workspace/checkpoint'
log_dir = 'workspace/log'
result_dir = 'workspace/results'

if not os.path.exists(result_dir):
    os.mkdir(result_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Network
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
    
        ## Convolution + Batch Normalization + ReLU (2D)
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            # Conv
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            # Batch Normalization
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            # ReLU
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)
            return cbr
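
        # Note: the original U-Net paper uses unpadded (valid) 3x3 convolutions
        # and crops each skip connection before concatenation; here padding=1
        # keeps the spatial size, so encoder and decoder features line up
        # without cropping.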
        
        ## Contracting path (enc1_1 : Encoder First Stage First Step)
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)

        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

        ## Expansive path (dec5_1 : Decoder Fifth Stage First Step)
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0,
                                          bias=True)
        
        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512) # skip connection doubles the input channels
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256) # mirrors enc4_1

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256) # skip connection doubles the input channels
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128) # mirrors enc3_1

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128) # skip connection doubles the input channels
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64) # mirrors enc2_1

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64) # skip connection doubles the input channels
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1,
                            stride=1, padding=0, bias=True) # 1x1 conv: 64 channels -> 1 output logit per pixel
    
    def forward(self, x):
        ## Contracting path
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)
        
        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)
        
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)
        
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        ## Expansive path
        # torch.cat joins each skip connection along the channel axis:
        # dim = [0: batch, 1: channel, 2: height, 3: width]
        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x
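
# Optional sanity check (a sketch, not part of the original post): every
# convolution is padded and each of the four max-pools is undone by a
# transposed convolution, so the output keeps the input's spatial size.
# net = UNet()
# x = torch.randn(1, 1, 512, 512)
# print(net(x).shape)  # expected: torch.Size([1, 1, 512, 512])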

### Data Loader
class Dataset(torch.utils.data.Dataset) :
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        lst_data = os.listdir(self.data_dir)
        lst_label = [f for f in lst_data if f.startswith('label')]
        lst_input = [f for f in lst_data if f.startswith('input')]

        lst_label.sort()
        lst_input.sort()

        self.lst_label = lst_label
        self.lst_input = lst_input
    
    def __len__(self):
        return len(self.lst_label)
    
    def __getitem__(self,index):
        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))
        input = np.load(os.path.join(self.data_dir, self.lst_input[index]))

        label = label/255.0
        input = input/255.0

        if label.ndim == 2 :
            label = label[:,:,np.newaxis]
        if input.ndim == 2 :
            input = input[:,:,np.newaxis]
        
        data = {'input': input, 'label':label}

        if self.transform:
            data = self.transform(data)

        return data
    
# ##
# dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'))

# ##
# data =dataset_train.__getitem__(0)

# input = data['input']
# label = data['label']

# ##
# plt.subplot(121)
# plt.imshow(input)

# plt.subplot(122)
# plt.imshow(label)

# plt.show()

# # (512, 512, 1)
# print(label.shape)


### Transform 
class ToTensor(object):
    def __call__(self,data) :
        label, input = data['label'], data['input']

        # np = (Y, X, CH) -> tensor = (CH, Y, X)
        label = label.transpose((2,0,1)).astype(np.float32)
        input = input.transpose((2,0,1)).astype(np.float32)
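        # astype(np.float32) also copies the array; this matters after
        # RandomFlip, because np.fliplr / np.flipud return negative-stride
        # views that torch.from_numpy cannot wrap directly.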

        data = {'label':torch.from_numpy(label),
                'input':torch.from_numpy(input)}
        
        return data

class Normalization(object):
    def __init__(self, mean=0.5, std=0.5):
        self.mean = mean
        self.std = std
    
    def __call__(self, data):
        label, input = data['label'], data['input']
        input = (input-self.mean) / self.std
        data = {'label':label, 'input':input}
        return data
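
# note: only the input is normalized; the label is left in [0, 1], which is
# the binary target that BCEWithLogitsLoss expects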

class RandomFlip(object):
    def __call__(self, data) :
        label, input = data['label'], data['input']
        
        if np.random.rand() > 0.5 :
            label = np.fliplr(label) # left-right
            input = np.fliplr(input)
        
        if np.random.rand() > 0.5 :
            label = np.flipud(label) # up-down
            input = np.flipud(input)
        
        data = {'label':label, 'input': input}
        return data
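
# note: RandomFlip is a training-time augmentation; the test transform below
# uses only Normalization and ToTensor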
    
# ##
# transform = transforms.Compose([Normalization(),RandomFlip(),ToTensor()])
# dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=transform)

# ##
# data = dataset_train.__getitem__(0)

# input = data['input']
# label = data['label']

# ##
# plt.subplot(1,2,1)
# plt.imshow(input.squeeze())

# plt.subplot(1,2,2)
# plt.imshow(label.squeeze())

# plt.show()
# -------------------------------------------------------------------------------------------------------------------
### Setting for Test
# Data load
transform = transforms.Compose([Normalization(),ToTensor()])
dataset_test = Dataset(data_dir=os.path.join(data_dir, 'test'), transform=transform)
loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=8)
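# shuffle=False keeps the dataset order, so the file index used when saving
# results below matches each sample's position in the test folder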

# Model
net = UNet().to(device)

# Loss Function
fn_loss = nn.BCEWithLogitsLoss().to(device)

# Optimizer
optim = torch.optim.Adam(net.parameters(), lr=lr)

# variables
num_data_test = len(dataset_test)

num_batch_test = int(np.ceil(num_data_test / batch_size))

# helper functions
fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)  # (N, C, H, W) tensor -> (N, H, W, C) numpy
fn_denorm = lambda x, mean, std: (x * std) + mean  # undo Normalization for display
fn_class = lambda x: 1.0 * (x > 0.5)  # binarize probabilities at 0.5

# save Network
def save(ckpt_dir, net, optim, epoch):
    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir)
    torch.save({'net':net.state_dict(), 'optim':optim.state_dict()},
               "./%s/model_epoch%d.pth" % (ckpt_dir, epoch))

# Load Network
def load(ckpt_dir, net, optim):
    if not os.path.exists(ckpt_dir):
        epoch = 0
        return net, optim, epoch
    
    ckpt_lst = os.listdir(ckpt_dir)
    ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    dict_model = torch.load('./%s/%s' % (ckpt_dir, ckpt_lst[-1]))

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])
    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])

    return net, optim, epoch

### Test
st_epoch = 0
net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

with torch.no_grad():
    net.eval()
    loss_arr = []

    for batch, data in enumerate(loader_test, 1):
        # forward pass
        label = data['label'].to(device)
        input = data['input'].to(device)

        output = net(input)

        # compute the loss
        loss = fn_loss(output, label)
        loss_arr += [loss.item()]

        print("TEST: BATCH %04d / %04d | LOSS %.4f" %
                (batch, num_batch_test, np.mean(loss_arr)))
        
        # save results to disk (PNG images and .npy arrays)
        label = fn_tonumpy(label)
        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))
        output = fn_tonumpy(fn_class(torch.sigmoid(output)))  # logits -> probabilities -> binary mask

        for j in range(label.shape[0]):
            id = batch_size * (batch - 1) + j  # global index of the sample within the test set

            plt.imsave(os.path.join(result_dir, 'label_%04d.png' % id),
                       label[j].squeeze(), cmap='gray')
            plt.imsave(os.path.join(result_dir, 'input_%04d.png' % id),
                       input[j].squeeze(), cmap='gray')
            plt.imsave(os.path.join(result_dir, 'output_%04d.png' % id),
                       output[j].squeeze(), cmap='gray')
            
            np.save(os.path.join(result_dir, 'label_%04d.npy' % id),
                    label[j].squeeze())
            np.save(os.path.join(result_dir, 'input_%04d.npy' % id),
                    input[j].squeeze())
            np.save(os.path.join(result_dir, 'output_%04d.npy' % id),
                    output[j].squeeze())
    
print("Average TEST: BATCH %04d / %04d | LOSS %.4f" %
        (batch, num_batch_test, np.mean(loss_arr)))
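
The saved .npy arrays can be inspected later without rerunning inference, and they also make it easy to compute a segmentation metric beyond the BCE loss. A minimal sketch (the file names follow the save loop above; index 0 and the dice helper are illustrative, not part of the original code):

## reload one saved test sample and compare input / label / prediction
id = 0  # any index for which results were saved

input = np.load(os.path.join(result_dir, 'input_%04d.npy' % id))
label = np.load(os.path.join(result_dir, 'label_%04d.npy' % id))
output = np.load(os.path.join(result_dir, 'output_%04d.npy' % id))

plt.subplot(131); plt.imshow(input, cmap='gray'); plt.title('input')
plt.subplot(132); plt.imshow(label, cmap='gray'); plt.title('label')
plt.subplot(133); plt.imshow(output, cmap='gray'); plt.title('output')
plt.show()

## Dice coefficient: 2|A∩B| / (|A| + |B|); 1.0 means perfect overlap
def dice(pred, gt, eps=1e-7):
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

print('DICE %.4f' % dice(output > 0.5, label > 0.5))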

Test Results

(result images omitted)

Practice and code source:

https://www.youtube.com/watch?v=igvk1W1JtHA 

 

Docker environment I set up:

https://hub.docker.com/repository/docker/kimjungtaek/u-net/general