Categories
kaggle

ResNetによる解法-train

ライブラリのimport

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import numpy as np
!pip install -q segmentation_models_pytorch
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.checkpoint as C
import torchvision.transforms.functional as fn
import torchvision.transforms as T
import matplotlib.pyplot as plt
!pip install -q torchsummary
from torchvision import models
from torchsummary import summary

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Configファイルの設定

class CFG:
    """Static configuration shared by the dataset, model and training loop."""

    # Path to the data folder (Thanks to @Kenni)
    GLOBAL_PATH = '/kaggle/input/google-research-identify-contrails-preprocessing'

    # Whether the images are up-scaled before being fed to the network.
    resize = False

    # Side length of the network input: 256 is the base image size,
    # 384 is used when `resize` is switched on.
    resize_value = 384 if resize else 256

    # Model settings
    model = 'UNET'
    encoder = 'timm-resnest50d'
    weights = 'imagenet'

    # Optimisation settings
    batch_size = 16
    optimizer = 'Adam'
    lr = 5e-4
    epochs = 30  # an earlier setting of 40 epochs was replaced by this value

今回使用するモデルの構成を決めます。モデルのアーキテクチャにはU-Netを使用し、エンコーダーにはImageNetで事前学習済みのResNeSt-50d(`timm-resnest50d`)を使います。

データセットの作成

# A custom Dataset class must implement three functions: __init__, __len__, and __getitem__
class ContrailDataset(Dataset):
    """Dataset of ``image.npy`` / ``human_pixel_masks.npy`` pairs.

    Every entry of ``<base_dir>/<data_type>`` is treated as one record
    directory holding both files.  Samples of the training split are augmented
    with a randomly sized centre crop; samples of the validation split are
    returned as stored.

    Args:
        base_dir: Folder that contains the ``train_images`` and
            ``validate_images`` sub-folders.
        data_type: ``'train_images'`` or ``'validate_images'``.
    """

    # Name of the split that receives the random-crop augmentation.
    TRAIN_SPLIT = 'train_images'

    def __init__(self, base_dir, data_type='train_images'):
        # The previous default ('train') could never pass this check, so the
        # default now names the real training split.
        assert data_type in ['train_images', 'validate_images'], \
            "'data_type' should be one of 'train_images' or 'validate_images'"

        self.base_dir = base_dir
        self.data_type = data_type
        self.record = os.listdir(os.path.join(self.base_dir, self.data_type))

        # Cropped training images are scaled to the network input size.
        self.resize_image = T.Resize(
            CFG.resize_value,
            interpolation=T.InterpolationMode.BILINEAR,
            antialias=True,
        )

    def __len__(self):
        """Return the number of record directories in the split."""
        return len(self.record)

    def __getitem__(self, idx):
        """Load record ``idx`` and return ``(image, mask)`` as tensors.

        The last axis of both arrays is moved to the front (assumes the files
        are stored channels-last — TODO confirm).  The mask is returned as a
        float tensor.
        """
        record_dir = os.path.join(self.base_dir, self.data_type, self.record[idx])

        false_color = torch.from_numpy(np.load(os.path.join(record_dir, 'image.npy')))
        human_pixel_mask = torch.from_numpy(
            np.load(os.path.join(record_dir, 'human_pixel_masks.npy'))
        )

        false_color = torch.moveaxis(false_color, -1, 0)
        human_pixel_mask = torch.moveaxis(human_pixel_mask, -1, 0)

        # The old check compared against 'train', a value __init__ rejects,
        # so the augmentation below was never executed.
        if self.data_type == self.TRAIN_SPLIT:
            false_color, human_pixel_mask = self._random_center_crop(
                false_color, human_pixel_mask
            )

        # NOTE: the original author states the false-color image is already
        # scaled between 0 and 1 — not verified in this class.
        return false_color, human_pixel_mask.float()

    def _random_center_crop(self, false_color, human_pixel_mask):
        """Centre-crop image and mask by one shared random factor in [0.5, 1).

        The image is then resized to ``CFG.resize_value``.  The mask is
        resized back to its own original height/width so that it keeps the
        size of the stored masks (the training loop compares it against a
        prediction of that size).
        """
        mask_size = list(human_pixel_mask.shape[-2:])

        crop_min, crop_max = 0.5, 1
        crop_factor = crop_min + torch.rand(1).item() * (crop_max - crop_min)
        # Relative to the actual image side instead of a hard-coded 256.
        crop_size = int(crop_factor * min(false_color.shape[-2:]))

        # Local transform: storing it on `self` would mutate shared state on
        # every item access.
        crop = T.CenterCrop(size=crop_size)
        # Nearest-neighbour keeps the mask values intact; `antialias` only
        # applies to bilinear/bicubic resizing, so it is not passed here.
        resize_mask = T.Resize(mask_size, interpolation=T.InterpolationMode.NEAREST)

        false_color = self.resize_image(crop(false_color))
        human_pixel_mask = resize_mask(crop(human_pixel_mask))
        return false_color, human_pixel_mask

Dataloaderの作成

# Training split: reshuffled every epoch, trailing incomplete batch dropped.
training_data = ContrailDataset(CFG.GLOBAL_PATH, data_type='train_images')
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=CFG.batch_size,
    shuffle=True,
    # Worker processes are only requested when a GPU is available.
    num_workers=0 if not torch.cuda.is_available() else 4,
    pin_memory=True,
    drop_last=True,
)

# Validation split: fixed order, and every sample is evaluated.
validation_data = ContrailDataset(base_dir=CFG.GLOBAL_PATH, data_type='validate_images')
validation_dataloader = DataLoader(
    validation_data,
    batch_size=CFG.batch_size,
    shuffle=False,
    # Worker processes are only requested when a GPU is available.
    num_workers=4 if torch.cuda.is_available() else 0,
    pin_memory=True,
    # Was True: that silently left the trailing samples out of the validation
    # metric whenever the split size is not a multiple of the batch size.
    drop_last=False,
)

Unetモデル作成

# Only the U-Net architecture is implemented.  Fail immediately with a clear
# message instead of leaving `model` undefined for the cells that follow.
if CFG.model != 'UNET':
    raise ValueError(f"Unsupported CFG.model {CFG.model!r}: only 'UNET' is implemented")

model = smp.Unet(
    encoder_name=CFG.encoder,
    encoder_weights=CFG.weights,    # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
    activation="sigmoid",           # activation applied to the model output
)
model.to(device)

# Print a layer-by-layer summary for a 3 x 256 x 256 input.
summary(model, (3, 256, 256))

最適化

# Adam over all model parameters, starting from the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=CFG.lr)

# Scale the learning rate by `factor` (roughly 1/sqrt(10)) when the monitored
# value, for which lower is better (mode='min'), stops improving for longer
# than `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=4,
    factor=0.31622776601,
    verbose=True,
)
print(f'learning rate: {optimizer.param_groups[0]["lr"]}')

損失関数

# Average dice score for the examples in a batch
def dice_avg(y_p, y_t, smooth=1e-3):
    """Return the soft Dice score averaged over the leading two dimensions.

    A score is computed for every index of dimensions 0 and 1 by summing over
    dimensions 2 and 3, and the mean of those scores is returned.

    Args:
        y_p: Predictions; must have at least four dimensions
            (presumably batch, channel, height, width — confirm with caller).
        y_t: Targets, broadcastable against ``y_p``.
        smooth: Added to numerator and denominator so that an empty
            prediction/target pair gives a score of 1 instead of 0/0.

    Returns:
        Scalar tensor holding the mean Dice score.
    """
    intersection = torch.sum(y_p * y_t, dim=(2, 3))
    total = torch.sum(y_p, dim=(2, 3)) + torch.sum(y_t, dim=(2, 3))
    score = (2 * intersection + smooth) / (total + smooth)
    return torch.mean(score)


def dice_loss_avg(y_p, y_t):
    """Return ``1 - dice_avg(y_p, y_t)`` so that a better overlap gives a lower loss."""
    # Previously called `dice_score_jan`, which is not defined anywhere in
    # this file and raised NameError on first use.
    return 1 - dice_avg(y_p, y_t)

def dice_global(y_p, y_t, smooth=1e-3):
    """Return the soft Dice score computed over all elements at once.

    Unlike a per-example average, the products and sums are taken over the
    whole tensors, so the result is a single score for the entire input.

    Args:
        y_p: Predictions.
        y_t: Targets, broadcastable against ``y_p``.
        smooth: Added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor holding the Dice score.
    """
    overlap = (y_p * y_t).sum()
    total = y_p.sum() + y_t.sum()
    return (2.0 * overlap + smooth) / (total + smooth)


def dice_loss_global(y_p, y_t):
    """Return one minus the global Dice score, for use as a loss."""
    score = dice_global(y_p, y_t)
    return 1 - score

学習、推論の実行

# Per-epoch history of the mean batch Dice scores (used for plotting afterwards).
train_dice_global = []
train_dice_avg = []
eval_dice_global = []
eval_dice_avg = []

# Best mean validation global Dice seen so far and the epoch that produced it.
bst_dice = 0
bst_epoch = 1

for epoch in range(1, CFG.epochs + 1):

    print(f'________epoch: {epoch}________')

    # Early stopping: give up once 10 epochs have passed since the best epoch.
    if epoch - bst_epoch >= 10:
        print(f'early stopping in epoch {epoch}')
        break

    # ------------------------------ training ------------------------------
    model.train()
    bar = tqdm(train_dataloader)
    tot_loss_global = 0
    tot_dice_global = 0
    tot_dice_avg = 0
    count = 0
    for image, mask in bar:

        # Bring the input to the configured network size.  This is applied
        # unconditionally, as in the original code.
        image = torch.nn.functional.interpolate(image,
                                                size=CFG.resize_value,
                                                mode='bilinear'
                                               )

        # Transfer to Device
        image, mask = image.to(device), mask.to(device)

        # Set optimizer gradients to zero
        optimizer.zero_grad()

        # Forward pass
        pred_mask = model(image)

        # If the image was resized, use a resizing step to make 256 again
        if CFG.resize:
            pred_mask = torch.nn.functional.interpolate(pred_mask,
                                                        size=256,
                                                        mode='bilinear'
                                                       )

        # Calculate the loss and do a backward pass
        loss = dice_loss_global(pred_mask, mask)
        loss.backward()

        # Adjust the weights
        optimizer.step()

        tot_loss_global += loss.item()
        tot_dice_global += 1 - loss.item()
        # Monitoring only: detach so no autograd graph is built for the metric.
        tot_dice_avg += dice_avg(pred_mask.detach(), mask).item()
        count += 1
        bar.set_postfix(TrainDiceLossGlobal=f'{tot_loss_global/count:.4f}',
                        TrainDiceGlobal=f'{tot_dice_global/count:.4f}',
                        TrainDiceAvg=f'{tot_dice_avg/count:.4f}')

    train_dice_global.append(np.array(tot_dice_global / count))
    train_dice_avg.append(np.array(tot_dice_avg / count))

    # ----------------------------- validation -----------------------------
    model.train(False)
    bar = tqdm(validation_dataloader)
    tot_dice_global = 0
    tot_dice_avg = 0
    count = 0
    # No gradients are needed for evaluation; without no_grad() every forward
    # pass would keep its autograd graph alive and waste memory.
    with torch.no_grad():
        for image, mask in bar:

            if CFG.resize:
                image = torch.nn.functional.interpolate(image,
                                                        size=CFG.resize_value,
                                                        mode='bilinear'
                                                       )
            image, mask = image.to(device), mask.to(device)
            pred_mask = model(image)

            if CFG.resize:
                pred_mask = torch.nn.functional.interpolate(pred_mask,
                                                            size=256,
                                                            mode='bilinear'
                                                           )

            tot_dice_global += dice_global(pred_mask, mask).item()
            tot_dice_avg += dice_avg(pred_mask, mask).item()
            count += 1
            bar.set_postfix(ValidDiceGlobal=f'{tot_dice_global/count:.4f}',
                            ValidDiceAvg=f'{tot_dice_avg/count:.4f}')

    eval_dice_global.append(np.array(tot_dice_global / count))
    eval_dice_avg.append(np.array(tot_dice_avg / count))

    # The scheduler runs in 'min' mode, so it is fed the validation Dice loss.
    scheduler.step(1 - (tot_dice_global / count))
    print(f'learning rate: {optimizer.param_groups[0]["lr"]}')

    # Checkpoint whenever the mean validation global Dice improves; both the
    # state dict and the full model object are written.
    if tot_dice_global / count > bst_dice:
        bst_dice = tot_dice_global / count
        bst_epoch = epoch
        torch.save(model.state_dict(), f'model_state_dict_epoch_{epoch}_dice_{bst_dice:.4f}.pth')
        torch.save(model, f'model_epoch_{epoch}_dice_{bst_dice:.4f}.pt')
        print(f"current model saved! Epoch: {epoch} global dice: {bst_dice} avg dice: {tot_dice_avg/count}")
        
 

学習、推論の履歴表示

# Plot the per-epoch Dice history of the training and validation passes.
plt.plot(train_dice_global, label='train_dice_global')
plt.plot(train_dice_avg, label='train_dice_avg')
plt.plot(eval_dice_global, label='eval_dice_global')
plt.plot(eval_dice_avg, label='eval_dice_avg')
plt.legend()
# `plt.show` without parentheses only referenced the function; call it so the
# figure is actually displayed.
plt.show()

参考

https://www.kaggle.com/code/bibanh/lb-0-623-resnet26d-unet-training