ライブラリのimport
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import numpy as np
!pip install -q segmentation_models_pytorch
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.checkpoint as C
import torchvision.transforms.functional as fn
import torchvision.transforms as T
import matplotlib.pyplot as plt
!pip install -q torchsummary
from torchvision import models
from torchsummary import summary
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Configクラスの設定
class CFG:
    """Static configuration values read by the dataset, model and training code."""

    # Path to the data folder (thanks to @Kenni).
    GLOBAL_PATH = '/kaggle/input/google-research-identify-contrails-preprocessing'

    # Image size: 256 is the base size; switching `resize` on selects 384 instead.
    resize = False
    resize_value = 384 if resize else 256

    # Model settings
    model = 'UNET'
    encoder = 'timm-resnest50d'
    weights = 'imagenet'

    # Optimisation settings
    batch_size = 16
    optimizer = 'Adam'
    lr = 5e-4
    epochs = 30
今回使用するモデルの構成を決めます。モデルのアーキテクチャにはU-Netを使用し、エンコーダーにはResNeSt-50d(`timm-resnest50d`)を使います。
データセットの作成
# A custom Dataset class must implement three functions: __init__, __len__, and __getitem__
class ContrailDataset(Dataset):
    """Dataset of per-record ``image.npy`` / ``human_pixel_masks.npy`` pairs.

    Every entry of ``<base_dir>/<data_type>`` is treated as one record
    directory containing the two arrays. Samples read from ``'train_images'``
    are augmented with a random centre crop that is scaled back to the
    pre-crop spatial size, so augmented and non-augmented samples share the
    same shape.

    Args:
        base_dir: Root folder of the dataset.
        data_type: Sub-folder to read; ``'train_images'`` or
            ``'validate_images'``.

    Raises:
        AssertionError: If ``data_type`` is not one of the two accepted values.
    """

    # Bounds of the random crop factor (fraction of the shorter image side).
    CROP_MIN = 0.5
    CROP_MAX = 1.0

    def __init__(self, base_dir, data_type='train_images'):
        # The default used to be 'train', which this very check rejects.
        assert data_type in ['train_images', 'validate_images'], \
            "'data_type' should be one of 'train_images' or 'validate_images'"
        self.base_dir = base_dir
        self.data_type = data_type
        self.record = os.listdir(os.path.join(self.base_dir, self.data_type))

    def __len__(self):
        """Return the number of records found in the selected sub-folder."""
        return len(self.record)

    def __getitem__(self, idx):
        """Load record ``idx`` and return ``(image, mask)`` as tensors.

        The mask is returned as float; the image keeps the dtype stored on
        disk (presumably already scaled to [0, 1] by the preprocessing step
        -- not verified here).
        """
        record_dir = os.path.join(self.base_dir, self.data_type, self.record[idx])
        false_color = torch.from_numpy(np.load(os.path.join(record_dir, 'image.npy')))
        human_pixel_mask = torch.from_numpy(
            np.load(os.path.join(record_dir, 'human_pixel_masks.npy'))
        )

        # Move the last axis to the front (assumes the arrays are stored
        # channels-last, i.e. H x W x C).
        false_color = torch.moveaxis(false_color, -1, 0)
        human_pixel_mask = torch.moveaxis(human_pixel_mask, -1, 0)

        # Augment training samples only. The check must match the accepted
        # folder name; comparing against 'train' left this branch unreachable.
        if self.data_type == 'train_images':
            false_color, human_pixel_mask = self._random_center_crop(
                false_color, human_pixel_mask
            )

        return false_color, human_pixel_mask.float()

    def _random_center_crop(self, image, mask):
        """Centre-crop image and mask by one random factor, then resize both back."""
        height, width = image.shape[-2:]
        span = self.CROP_MAX - self.CROP_MIN
        crop_factor = self.CROP_MIN + torch.rand(1).item() * span
        crop_size = max(1, int(crop_factor * min(height, width)))

        image = fn.center_crop(image, [crop_size, crop_size])
        mask = fn.center_crop(mask, [crop_size, crop_size])

        # Restore the pre-crop size so batches stay stackable and the mask keeps
        # the resolution the training loop compares predictions against.
        image = fn.resize(
            image,
            [height, width],
            interpolation=T.InterpolationMode.BILINEAR,
            antialias=True,
        )
        # Nearest-neighbour keeps the mask values discrete; `antialias` only
        # applies to bilinear/bicubic modes, so it is not passed here.
        mask = fn.resize(
            mask,
            [height, width],
            interpolation=T.InterpolationMode.NEAREST,
        )
        return image, mask
Dataloaderの作成
# Training loader: shuffled, and the incomplete final batch is dropped so every
# optimisation step sees exactly CFG.batch_size samples.
training_data = ContrailDataset(base_dir=CFG.GLOBAL_PATH, data_type='train_images')
train_dataloader = DataLoader(
    training_data,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=4 if torch.cuda.is_available() else 0,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
)

# Validation loader: fixed order, and the final partial batch is kept so the
# reported metrics cover the whole validation set.
validation_data = ContrailDataset(base_dir=CFG.GLOBAL_PATH, data_type='validate_images')
validation_dataloader = DataLoader(
    validation_data,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=4 if torch.cuda.is_available() else 0,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)
Unetモデル作成
# Build the segmentation network selected in CFG.
if CFG.model == 'UNET':
    model = smp.Unet(
        encoder_name=CFG.encoder,
        encoder_weights=CFG.weights,  # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=1,                    # model output channels (number of classes in your dataset)
        activation="sigmoid",         # the Dice functions below consume the sigmoid output directly
    )
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unsupported CFG.model: {CFG.model!r} (only 'UNET' is implemented)")

model.to(device)
# Summarise the network at the input resolution it is configured for.
summary(model, (3, CFG.resize_value, CFG.resize_value))
最適化
# Adam optimiser over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=CFG.lr)
# Multiply the learning rate by `factor` when the monitored value has not
# decreased for `patience` consecutive epochs. The deprecated `verbose` flag is
# not passed; the learning rate is printed explicitly instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=4,
    factor=0.31622776601,
)
print(f'learning rate: {optimizer.param_groups[0]["lr"]}')
損失関数
# Average dice score for the examples in a batch
def dice_avg(y_p, y_t, smooth=1e-3):
    """Return the Dice score computed per (example, channel) and then averaged.

    Args:
        y_p: Predictions; must have at least four dimensions because the sums
            run over dims 2 and 3.
        y_t: Targets, broadcast-compatible with ``y_p``.
        smooth: Constant added to numerator and denominator so that empty
            predictions and targets do not divide by zero.

    Returns:
        A scalar tensor holding the mean of the per-item Dice scores.
    """
    spatial_axes = (2, 3)
    overlap = (y_p * y_t).sum(dim=spatial_axes)
    total = y_p.sum(dim=spatial_axes) + y_t.sum(dim=spatial_axes)
    per_item = (2 * overlap + smooth) / (total + smooth)
    return per_item.mean()
def dice_loss_avg(y_p, y_t):
    """Return the Dice loss ``1 - dice_avg(y_p, y_t)``.

    The previous body called ``dice_score_jan``, a name that is not defined
    anywhere in this file, so the function raised NameError on first use.
    """
    return 1 - dice_avg(y_p, y_t)
def dice_global(y_p, y_t, smooth=1e-3):
    """Return one Dice score computed over every element of the inputs at once.

    Args:
        y_p: Predictions.
        y_t: Targets, broadcast-compatible with ``y_p``.
        smooth: Constant added to numerator and denominator so that empty
            predictions and targets do not divide by zero.

    Returns:
        A scalar tensor with the Dice score of the whole batch.
    """
    overlap = (y_p * y_t).sum()
    total = y_p.sum() + y_t.sum()
    return (2.0 * overlap + smooth) / (total + smooth)
def dice_loss_global(y_p, y_t):
    """Return the Dice loss ``1 - dice_global(y_p, y_t)`` for the whole batch."""
    score = dice_global(y_p, y_t)
    return 1 - score
学習、推論の実行
# Per-epoch metric history, plotted after training.
train_dice_global = []
train_dice_avg = []
eval_dice_global = []
eval_dice_avg = []

# Best validation global Dice seen so far and the epoch it was reached in.
bst_dice = 0
bst_epoch = 1

for epoch in range(1, CFG.epochs + 1):
    print(f'________epoch: {epoch}________')

    # Early stopping: give up once 10 epochs have passed since the best epoch.
    if epoch - bst_epoch >= 10:
        print(f'early stopping in epoch {epoch}')
        break

    # ---------------------------- training ----------------------------
    model.train()
    bar = tqdm(train_dataloader)
    tot_loss_global = 0
    tot_dice_global = 0
    tot_dice_avg = 0
    count = 0
    for image, mask in bar:
        # Enlarge the input only when resizing is enabled, as validation does.
        if CFG.resize:
            image = torch.nn.functional.interpolate(image,
                                                    size=CFG.resize_value,
                                                    mode='bilinear'
                                                    )
        # Transfer to device
        image, mask = image.to(device), mask.to(device)

        # Set optimizer gradients to zero
        optimizer.zero_grad()

        # Forward pass
        pred_mask = model(image)
        # If the image was resized, bring the prediction back to the mask size
        if CFG.resize:
            pred_mask = torch.nn.functional.interpolate(pred_mask,
                                                        size=mask.shape[-2:],
                                                        mode='bilinear'
                                                        )

        # Calculate the loss and do a backward pass
        loss = dice_loss_global(pred_mask, mask)
        loss.backward()

        # Adjust the weights
        optimizer.step()

        # Running sums; the progress bar shows the mean over batches so far.
        tot_loss_global += loss.item()
        tot_dice_global += 1 - loss.item()
        tot_dice_avg += dice_avg(pred_mask.detach(), mask).item()
        count += 1
        bar.set_postfix(TrainDiceLossGlobal=f'{tot_loss_global/count:.4f}',
                        TrainDiceGlobal=f'{tot_dice_global/count:.4f}',
                        TrainDiceAvg=f'{tot_dice_avg/count:.4f}')

    train_dice_global.append(np.array(tot_dice_global / count))
    train_dice_avg.append(np.array(tot_dice_avg / count))

    # --------------------------- validation ---------------------------
    model.train(False)
    bar = tqdm(validation_dataloader)
    tot_dice_global = 0
    tot_dice_avg = 0
    count = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time without changing the computed metrics.
    with torch.no_grad():
        for image, mask in bar:
            if CFG.resize:
                image = torch.nn.functional.interpolate(image,
                                                        size=CFG.resize_value,
                                                        mode='bilinear'
                                                        )
            image, mask = image.to(device), mask.to(device)
            pred_mask = model(image)
            if CFG.resize:
                pred_mask = torch.nn.functional.interpolate(pred_mask,
                                                            size=mask.shape[-2:],
                                                            mode='bilinear'
                                                            )
            tot_dice_global += dice_global(pred_mask, mask).item()
            tot_dice_avg += dice_avg(pred_mask, mask).item()
            count += 1
            bar.set_postfix(ValidDiceGlobal=f'{tot_dice_global/count:.4f}',
                            ValidDiceAvg=f'{tot_dice_avg/count:.4f}')

    eval_dice_global.append(np.array(tot_dice_global / count))
    eval_dice_avg.append(np.array(tot_dice_avg / count))

    # The scheduler monitors the validation Dice loss (1 - global Dice).
    scheduler.step(1 - (tot_dice_global / count))
    print(f'learning rate: {optimizer.param_groups[0]["lr"]}')

    # Checkpoint whenever the validation global Dice improves.
    if tot_dice_global / count > bst_dice:
        bst_dice = tot_dice_global / count
        bst_epoch = epoch
        torch.save(model.state_dict(), f'model_state_dict_epoch_{epoch}_dice_{bst_dice:.4f}.pth')
        torch.save(model, f'model_epoch_{epoch}_dice_{bst_dice:.4f}.pt')
        print(f"current model saved! Epoch: {epoch} global dice: {bst_dice} avg dice: {tot_dice_avg/count}")
学習、推論の履歴表示
# Plot the per-epoch Dice history for training and validation.
plt.plot(train_dice_global, label='train_dice_global')
plt.plot(train_dice_avg, label='train_dice_avg')
plt.plot(eval_dice_global, label='eval_dice_global')
plt.plot(eval_dice_avg, label='eval_dice_avg')
plt.legend()
# `plt.show` must be called; the bare attribute reference did nothing.
plt.show()
参考
https://www.kaggle.com/code/bibanh/lb-0-623-resnet26d-unet-training