Categories
kaggle

Baseline with PyTorch

Loading the libraries and data

Add the libraries required for training to the Python path.

import sys
sys.path.append("../input/pretrained-models-pytorch")
sys.path.append("../input/efficientnet-pytorch")
sys.path.append("/kaggle/input/smp-github/segmentation_models.pytorch-master")
sys.path.append("/kaggle/input/timm-pretrained-resnest/resnest/")
import segmentation_models_pytorch as smp

!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp /kaggle/input/timm-pretrained-resnest/resnest/gluon_resnest26-50eb607c.pth /root/.cache/torch/hub/checkpoints/gluon_resnest26-50eb607c.pth

Setting the model and training parameters

%%writefile config.yaml

data_path: "/kaggle/input/contrails-images-ash-color"
output_dir: "models"

seed: 42

train_bs: 48
valid_bs: 128
workers: 2

progress_bar_refresh_rate: 1

early_stop:
    monitor: "val_loss"
    mode: "min"
    patience: 999
    verbose: 1

trainer:
    max_epochs: 2
    min_epochs: 2
    enable_progress_bar: True
    precision: "16-mixed"
    devices: 2

model:
    seg_model: "Unet"
    encoder_name: "timm-resnest26d"
    loss_smooth: 1.0
    image_size: 384
    optimizer_params:
        lr: 0.0005
        weight_decay: 0.0
    scheduler:
        name: "CosineAnnealingLR"
        params:
            CosineAnnealingLR:
                T_max: 2
                eta_min: 1.0e-6
                last_epoch: -1
            ReduceLROnPlateau:
                mode: "min"
                factor: 0.31622776601
                patience: 4
                verbose: True

The number of training epochs can be changed with min_epochs and max_epochs.
The segmentation model is a U-Net, and the encoder is a ResNeSt from timm (timm-resnest26d).
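If you want to try a different backbone, the encoder names supported by segmentation_models_pytorch can be listed and dropped into encoder_name in config.yaml. A minimal sketch (assuming the bundled smp version provides get_encoder_names, as recent releases do):

# List the encoders available in segmentation_models_pytorch; any of these
# names can be used as encoder_name in config.yaml.
# Assumption: the smp version in the attached dataset exposes get_encoder_names.
import segmentation_models_pytorch as smp

encoder_names = smp.encoders.get_encoder_names()
print(len(encoder_names), "encoders available")
print([name for name in encoder_names if "resnest" in name])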

# Dataset

import torch
import numpy as np
import torchvision.transforms as T

class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_size=256, train=True):

        self.df = df
        self.trn = train
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        con_path = row.path
        con = np.load(str(con_path))

        img = con[..., :-1]
        label = con[..., -1]

        label = torch.tensor(label)

        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)

        if self.image_size != 256:
            img = self.resize_image(img)

        img = self.normalize_image(img)

        return img.float(), label.float()

    def __len__(self):
        return len(self.df)

# Lightning module

import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW
import torch.nn as nn
from torchmetrics.functional import dice

seg_models = {
    "Unet": smp.Unet,
    "Unet++": smp.UnetPlusPlus,
    "MAnet": smp.MAnet,
    "Linknet": smp.Linknet,
    "FPN": smp.FPN,
    "PSPNet": smp.PSPNet,
    "PAN": smp.PAN,
    "DeepLabV3": smp.DeepLabV3,
    "DeepLabV3+": smp.DeepLabV3Plus,
}


class LightningModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = seg_models[config["seg_model"]](
            encoder_name=config["encoder_name"],
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        imgs = batch
        preds = self.model(imgs)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss = self.loss_module(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=16)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs)
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)

Running the training

# Actual training

import warnings

warnings.filterwarnings("ignore")

import os
import torch
import yaml
import pandas as pd
import pytorch_lightning as pl
from pprint import pprint
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
from torch.utils.data import DataLoader

with open("config.yaml", "r") as file_obj:
    config = yaml.safe_load(file_obj)

contrails = os.path.join(config["data_path"], "contrails/")
train_path = os.path.join(config["data_path"], "train_df.csv")
valid_path = os.path.join(config["data_path"], "valid_df.csv")

train_df = pd.read_csv(train_path)
valid_df = pd.read_csv(valid_path)

train_df["path"] = contrails + train_df["record_id"].astype(str) + ".npy"
valid_df["path"] = contrails + valid_df["record_id"].astype(str) + ".npy"

dataset_train = ContrailsDataset(train_df, config["model"]["image_size"], train=True)
dataset_validation = ContrailsDataset(valid_df, config["model"]["image_size"], train=False)

data_loader_train = DataLoader(
    dataset_train,
    batch_size=config["train_bs"],
    shuffle=True,
    num_workers=config["workers"],
)
data_loader_validation = DataLoader(
    dataset_validation,
    batch_size=config["valid_bs"],
    shuffle=False,
    num_workers=config["workers"],
)

checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,
    monitor="val_dice",
    dirpath=config["output_dir"],
    mode="max",
    filename="model",
    save_top_k=1,
    verbose=1,
)

progress_bar_callback = TQDMProgressBar(
    refresh_rate=config["progress_bar_refresh_rate"]
)

early_stop_callback = EarlyStopping(**config["early_stop"])

trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback],
    **config["trainer"],
)

config["model"]["scheduler"]["params"]["CosineAnnealingLR"]["T_max"] *= len(data_loader_train)/config["trainer"]["devices"]
model = LightningModule(config["model"])

trainer.fit(model, data_loader_train, data_loader_validation)
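Before moving on to inference, it can be worth confirming where the best checkpoint landed. best_model_path and best_model_score are standard ModelCheckpoint attributes:

# The best checkpoint (selected by val_dice) is written to models/model.ckpt.
print(checkpoint_callback.best_model_path)
print(checkpoint_callback.best_model_score)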

Outputting the results

batch_size = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_root = '/kaggle/input/google-research-identify-contrails-reduce-global-warming/test/'

filenames = os.listdir(data_root)
test_df = pd.DataFrame(filenames, columns=['record_id'])
test_df['path'] = data_root + test_df['record_id'].astype(str)

class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_size=256, train=True):
        
        self.df = df
        self.trn = train
        self.df_idx = pd.DataFrame({'idx': os.listdir('/kaggle/input/google-research-identify-contrails-reduce-global-warming/test')})
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)
    
    def read_record(self, directory):
        record_data = {}
        for band in ["band_11", "band_14", "band_15"]:
            record_data[band] = np.load(os.path.join(directory, band + ".npy"))
        return record_data

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])
    
    def get_false_color(self, record_data):
        # "Ash" false-color composite used in the competition data:
        # R = brightness-temperature difference band 15 - band 14,
        # G = difference band 14 - band 11, B = band 14 itself.
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)

        # The labeled frame is the 5th of the 8 time steps (index 4).
        N_TIMES_BEFORE = 4

        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        img = false_color[..., N_TIMES_BEFORE]

        return img
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        con_path = row.path
        data = self.read_record(con_path)    
        
        img = self.get_false_color(data)
        
        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)
        
        if self.image_size != 256:
            img = self.resize_image(img)
        
        img = self.normalize_image(img)
        
        image_id = int(self.df_idx.iloc[index]['idx'])
            
        return img.float(), torch.tensor(image_id)
    
    def __len__(self):
        return len(self.df)

test_ds = ContrailsDataset(test_df, config["model"]["image_size"], train=False)

test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=1)

class LightningModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = smp.Unet(encoder_name="timm-resnest26d",
                              encoder_weights=None,
                              in_channels=3,
                              classes=1,
                              activation=None,
                              )

    def forward(self, batch):
        return self.model(batch)

model = LightningModule.load_from_checkpoint("/kaggle/working/models/model.ckpt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
model.zero_grad()

def rle_encode(x, fg_val=1):
    """
    Args:
        x:  numpy array of shape (height, width), 1 - mask, 0 - background
    Returns: run length encoding as list
    """

    dots = np.where(
        x.T.flatten() == fg_val)[0]  # .T sets Fortran order down-then-right
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return run_lengths

def list_to_string(x):
    """
    Converts list to a string representation
    Empty list returns '-'
    """
    if x: # non-empty list
        s = str(x).replace("[", "").replace("]", "").replace(",", "")
    else:
        s = '-'
    return s
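To see what these helpers produce, here is a small worked example on a 2x2 toy mask (pixels are numbered top-to-bottom, column by column, starting from 1):

# Worked example: the two foreground pixels sit at column-major positions
# 2 and 3, i.e. a single run starting at 2 with length 2.
toy_mask = np.array([[0, 1],
                     [1, 0]])
print(rle_encode(toy_mask))                  # [2, 2]
print(list_to_string(rle_encode(toy_mask)))  # '2 2'
print(list_to_string([]))                    # '-' for an empty mask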

submission = pd.read_csv('/kaggle/input/google-research-identify-contrails-reduce-global-warming/sample_submission.csv', index_col='record_id')

for i, data in enumerate(test_dl):
    images, image_id = data
    
    images = images.to(device)
    with torch.no_grad():
        predicted_mask = model(images)
    if config["model"]["image_size"] != 256:
        predicted_mask = torch.nn.functional.interpolate(predicted_mask, size=256, mode='bilinear')
    predicted_mask = torch.sigmoid(predicted_mask).cpu().detach().numpy()
    
    predicted_mask_with_threshold = (predicted_mask[:, 0, :, :] > 0.5).astype(np.float32)
    
    for img_num in range(0, images.shape[0]):
        current_mask = predicted_mask_with_threshold[img_num, :, :]
        current_image_id = image_id[img_num].item()
        
        submission.loc[int(current_image_id), 'encoded_pixels'] = list_to_string(rle_encode(current_mask))

submission.to_csv('submission.csv')

Reference

https://www.kaggle.com/code/manishkumar7432698/google-reseach-baseline-pytorch