U-Net network architecture
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
### Hyper Parameters
lr = 1e-3
batch_size = 4
num_epoch = 100
data_dir = 'workspace/data'
ckpt_dir = 'workspace/checkpoint'
log_dir = 'workspace/log'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
### Network
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        ## Convolution, BatchNormalization, ReLU 2D
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            # Conv
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            # Batch Normalization
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            # ReLU
            layers += [nn.ReLU()]
            cbr = nn.Sequential(*layers)
            return cbr
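        # Note: with kernel_size=3, stride=1, padding=1, CBR2d keeps the spatial size unchanged,
        # so only the 2x2 max-pooling and 2x2 transposed convolutions change the feature-map resolution.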
        ## Contracting path (enc1_1 : Encoder First Stage First Step)
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)
        ## Expansive path (dec5_1 : Decoder Fifth Stage First Step)
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)
        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0,
                                          bias=True)
        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)  # skip connection doubles the input channels
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)  # match with enc4_1
        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)  # skip connection doubles the input channels
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)  # match with enc3_1
        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)  # skip connection doubles the input channels
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)  # match with enc2_1
        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)  # skip connection doubles the input channels
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)
        # 1x1 convolution mapping 64 channels to a single-channel segmentation map (raw scores, no activation here)
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1,
                            stride=1, padding=0, bias=True)
    def forward(self, x):
        ## Contracting path
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)
        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)
        enc5_1 = self.enc5_1(pool4)

        ## Expansive path
        dec5_1 = self.dec5_1(enc5_1)
        # torch.cat concatenates along dim=[0:batch, 1:channel, 2:height, 3:width]
        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)
        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)
        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)
        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)
        x = self.fc(dec1_1)
        return x
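To sanity-check the channel bookkeeping above, a dummy tensor can be pushed through the network. This is a quick sketch I added (not part of the original code); it assumes a single-channel 512x512 input, matching the (512, 512, 1) samples used below.
## Quick shape check with a dummy input (sketch)
net = UNet().to(device)
with torch.no_grad():
    out = net(torch.randn(1, 1, 512, 512, device=device))  # (batch, channel, height, width)
print(out.shape)  # expected: torch.Size([1, 1, 512, 512])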
Data Loader
### Data Loader
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        lst_data = os.listdir(self.data_dir)
        lst_label = [f for f in lst_data if f.startswith('label')]
        lst_input = [f for f in lst_data if f.startswith('input')]
        lst_label.sort()
        lst_input.sort()
        self.lst_label = lst_label
        self.lst_input = lst_input

    def __len__(self):
        return len(self.lst_label)

    def __getitem__(self, index):
        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))
        input = np.load(os.path.join(self.data_dir, self.lst_input[index]))
        # scale 8-bit pixel values to [0, 1]
        label = label / 255.0
        input = input / 255.0
        # add a channel axis so every sample is (Y, X, CH)
        if label.ndim == 2:
            label = label[:, :, np.newaxis]
        if input.ndim == 2:
            input = input[:, :, np.newaxis]
        data = {'input': input, 'label': label}
        if self.transform:
            data = self.transform(data)
        return data
DataLoader operation test
##
dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'))
##
data = dataset_train.__getitem__(0)
input = data['input']
label = data['label']
##
plt.subplot(121)
plt.imshow(input.squeeze())  # drop the trailing channel axis; imshow expects (Y, X) or (Y, X, 3/4)
plt.subplot(122)
plt.imshow(label.squeeze())
plt.show()
print(label.shape)  # (512, 512, 1)
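Since DataLoader is already imported, the Dataset can be wrapped directly for batched loading. The sketch below is my own addition; shuffle is illustrative, and without the transforms from the next section the batch still comes out in (B, Y, X, CH) float64 layout rather than the (B, CH, Y, X) float32 layout used for training.
## Wrapping the Dataset in a DataLoader (sketch)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
batch = next(iter(loader_train))
print(batch['input'].shape)  # torch.Size([4, 512, 512, 1]) here; [4, 1, 512, 512] once ToTensor is applied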
Transform
- NumPy image layout = (Y, X, CH)
- Tensor image layout = (CH, Y, X) (a short illustration of the reordering follows below)
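The snippet below is just a small illustration I added of that axis reordering; the (512, 512, 1) shape is assumed so it matches the dataset samples.
## Axis-order illustration (sketch)
img_np = np.zeros((512, 512, 1), dtype=np.float32)     # NumPy image: (Y, X, CH)
img_t = torch.from_numpy(img_np.transpose((2, 0, 1)))  # Tensor image: (CH, Y, X)
print(img_np.shape, tuple(img_t.shape))                # (512, 512, 1) (1, 512, 512)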
### Transform
class ToTensor(object):
    def __call__(self, data):
        label, input = data['label'], data['input']
        # NumPy (Y, X, CH) -> Tensor (CH, Y, X)
        label = label.transpose((2, 0, 1)).astype(np.float32)
        input = input.transpose((2, 0, 1)).astype(np.float32)
        data = {'label': torch.from_numpy(label),
                'input': torch.from_numpy(input)}
        return data
class Normalization(object):
    def __init__(self, mean=0.5, std=0.5):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        label, input = data['label'], data['input']
        # normalize the input only; the label stays in [0, 1]
        input = (input - self.mean) / self.std
        data = {'label': label, 'input': input}
        return data
class RandomFlip(object):
    def __call__(self, data):
        label, input = data['label'], data['input']
        if np.random.rand() > 0.5:
            label = np.fliplr(label)  # left-right flip
            input = np.fliplr(input)
        if np.random.rand() > 0.5:
            label = np.flipud(label)  # up-down flip
            input = np.flipud(input)
        data = {'label': label, 'input': input}
        return data
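A side note I am adding here (not from the original video): np.fliplr/np.flipud return views with negative strides, which torch.from_numpy cannot consume directly. In this pipeline it is not a problem because ToTensor's astype(np.float32) makes a copy first, but np.ascontiguousarray is the explicit fix if a flipped array is ever passed to from_numpy on its own.
## Negative-stride note (sketch)
flipped = np.fliplr(np.zeros((4, 4, 1), dtype=np.float32))
# torch.from_numpy(flipped) would raise a ValueError because of the negative stride;
# copying to a contiguous array avoids it:
tensor = torch.from_numpy(np.ascontiguousarray(flipped))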
Transform operation test
##
transform = transforms.Compose([Normalization(), RandomFlip(), ToTensor()])
dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=transform)
##
data = dataset_train.__getitem__(0)
input = data['input']
label = data['label']
##
plt.subplot(1,2,1)
plt.imshow(input.squeeze())
plt.subplot(1,2,2)
plt.imshow(label.squeeze())
plt.show()
- Bug fix: after ToTensor the channel dimension moves to the front, which breaks imshow's expected shape, so squeeze() is used to drop the channel axis (see the short display sketch after this list).
- Effect of Normalization
  - input: dark regions are pushed toward -1 and the bright yellow regions toward 1.
  - label: no normalization is applied; after the division by 255 the values stay in [0, 1] (dark pixels near 0, bright pixels near 1).
- The images are randomly flipped left-right and/or up-down.
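For reference, a sketch I added of a more general way to display a (CH, Y, X) tensor with matplotlib: move the channel axis back with permute, then squeeze for single-channel images. The cmap is purely illustrative.
## Displaying a (CH, Y, X) tensor (sketch)
img = data['input']                                      # (1, 512, 512) after ToTensor
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
plt.show()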
Code source:
https://www.youtube.com/watch?v=1gMnChpUS9k
Image source:
https://towardsdatascience.com/unet-line-by-line-explanation-9b191c76baf5