ライブラリ、データのロード等
学習に必要となるライブラリへのパスを通します。
import sys
sys.path.append("../input/pretrained-models-pytorch")
sys.path.append("../input/efficientnet-pytorch")
sys.path.append("/kaggle/input/smp-github/segmentation_models.pytorch-master")
sys.path.append("/kaggle/input/timm-pretrained-resnest/resnest/")
import segmentation_models_pytorch as smp
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp /kaggle/input/timm-pretrained-resnest/resnest/gluon_resnest26-50eb607c.pth /root/.cache/torch/hub/checkpoints/gluon_resnest26-50eb607c.pth
学習モデルのパラメータ設定
%%writefile config.yaml
# Training configuration. Nesting restored to match how the code reads it:
#   config["early_stop"]  -> EarlyStopping kwargs
#   config["trainer"]     -> pl.Trainer kwargs
#   config["model"]       -> LightningModule config (incl. optimizer/scheduler)
data_path: "/kaggle/input/contrails-images-ash-color"
output_dir: "models"
seed: 42
train_bs: 48
valid_bs: 128
workers: 2
progress_bar_refresh_rate: 1
early_stop:
  monitor: "val_loss"
  mode: "min"
  patience: 999          # effectively disables early stopping
  verbose: 1
trainer:
  max_epochs: 2
  min_epochs: 2
  enable_progress_bar: True
  precision: "16-mixed"
  devices: 2
model:
  seg_model: "Unet"
  encoder_name: "timm-resnest26d"
  loss_smooth: 1.0
  image_size: 384
  optimizer_params:
    lr: 0.0005
    weight_decay: 0.0
  scheduler:
    name: "CosineAnnealingLR"
    params:
      CosineAnnealingLR:
        T_max: 2          # in epochs here; rescaled to steps by the driver
        eta_min: 1.0e-6
        last_epoch: -1
      ReduceLROnPlateau:
        mode: "min"
        factor: 0.31622776601
        patience: 4
        verbose: True
学習回数はmin_epoch、max_epochで変えることが出来ます。
セグメンテーションモデルには、UNetを使っています。エンコーダーには、timmにあるresnest（timm-resnest26d）を使用しています。
# Dataset
import torch
import numpy as np
import torchvision.transforms as T
class ContrailsDataset(torch.utils.data.Dataset):
    """Dataset over precomputed ash-color contrail arrays stored as .npy files.

    Each file holds a 256x256 array whose leading 3 channels are the
    false-color image and whose last channel is the binary label mask.
    `df` must expose a `path` column pointing at those files.
    """

    def __init__(self, df, image_size=256, train=True):
        self.df = df
        self.trn = train  # kept for API parity; no train-time augmentation here
        self.image_size = image_size
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # Source arrays are 256x256; only build a resize op when another size is requested.
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def __getitem__(self, index):
        record = self.df.iloc[index]
        arr = np.load(str(record.path))
        # Split the image channels from the trailing label channel.
        image, mask = arr[..., :-1], arr[..., -1]
        mask = torch.tensor(mask)
        image = torch.tensor(np.reshape(image, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)
        if self.image_size != 256:
            image = self.resize_image(image)
        image = self.normalize_image(image)
        return image.float(), mask.float()

    def __len__(self):
        return len(self.df)
# Lightning module
import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW
import torch.nn as nn
from torchmetrics.functional import dice
# Registry of segmentation architectures selectable via config["seg_model"].
# Display names (with "++"/"+") map to the corresponding smp class.
seg_models = {
    display_name: getattr(smp, attr_name)
    for display_name, attr_name in [
        ("Unet", "Unet"),
        ("Unet++", "UnetPlusPlus"),
        ("MAnet", "MAnet"),
        ("Linknet", "Linknet"),
        ("FPN", "FPN"),
        ("PSPNet", "PSPNet"),
        ("PAN", "PAN"),
        ("DeepLabV3", "DeepLabV3"),
        ("DeepLabV3+", "DeepLabV3Plus"),
    ]
}
class LightningModule(pl.LightningModule):
    """Trains an smp segmentation model with binary Dice loss.

    `config` is the "model" section of config.yaml: seg_model, encoder_name,
    loss_smooth, image_size, optimizer_params, scheduler.
    Logs train_loss / val_loss / val_dice / lr.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = seg_models[config["seg_model"]](
            encoder_name=config["encoder_name"],
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,  # raw logits; sigmoid is applied where needed
        )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        # Accumulated across validation steps for an epoch-level dice score.
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        return self.model(batch)

    def configure_optimizers(self):
        """AdamW plus the scheduler named in config["scheduler"]["name"]."""
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])
        name = self.config["scheduler"]["name"]
        params = self.config["scheduler"]["params"]
        if name == "CosineAnnealingLR":
            # T_max is pre-scaled from epochs to optimizer steps by the
            # training driver, hence interval="step".
            scheduler = CosineAnnealingLR(optimizer, **params["CosineAnnealingLR"])
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
            }
        if name == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optimizer, **params["ReduceLROnPlateau"])
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
            }
        # Bug fix: previously an unknown name fell through and returned None,
        # which Lightning treats as "no optimizer" — fail loudly instead.
        raise ValueError(f"Unsupported scheduler: {name}")

    def _resize_to_label(self, preds):
        # Labels are always 256x256; bring predictions back down/up when
        # training at a different resolution.
        if self.config["image_size"] != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        return preds

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self._resize_to_label(self.model(imgs))
        loss = self.loss_module(preds, labels)
        # Bug fix: batch_size was hardcoded to 16 while train_bs is 48 in the
        # config; use the actual batch size so epoch averaging is weighted right.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 batch_size=imgs.size(0))
        # Log the current learning rate (last param group, matching the old loop's
        # final logged value).
        lr = self.trainer.optimizers[0].param_groups[-1]["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self._resize_to_label(self.model(imgs))
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        """Compute dice over the whole validation set, then clear the buffers."""
        all_preds = torch.sigmoid(torch.cat(self.val_step_outputs))
        all_labels = torch.cat(self.val_step_labels)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)
トレーニング実行
# Actual training
import warnings
warnings.filterwarnings("ignore")
import os
import torch
import yaml
import pandas as pd
import pytorch_lightning as pl
from pprint import pprint
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
from torch.utils.data import DataLoader
# Load the YAML config written above, wire up data, callbacks and the trainer,
# then fit. `config` is reused by the inference cells below.
with open("config.yaml", "r") as fh:
    config = yaml.safe_load(fh)

contrails = os.path.join(config["data_path"], "contrails/")
train_df = pd.read_csv(os.path.join(config["data_path"], "train_df.csv"))
valid_df = pd.read_csv(os.path.join(config["data_path"], "valid_df.csv"))
# Each record_id maps to one precomputed .npy array.
for frame in (train_df, valid_df):
    frame["path"] = contrails + frame["record_id"].astype(str) + ".npy"

data_loader_train = DataLoader(
    ContrailsDataset(train_df, config["model"]["image_size"], train=True),
    batch_size=config["train_bs"],
    shuffle=True,
    num_workers=config["workers"],
)
data_loader_validation = DataLoader(
    ContrailsDataset(valid_df, config["model"]["image_size"], train=False),
    batch_size=config["valid_bs"],
    shuffle=False,
    num_workers=config["workers"],
)

trainer = pl.Trainer(
    callbacks=[
        ModelCheckpoint(
            save_weights_only=True,
            monitor="val_dice",
            dirpath=config["output_dir"],
            mode="max",
            filename="model",
            save_top_k=1,
            verbose=1,
        ),
        EarlyStopping(**config["early_stop"]),
        TQDMProgressBar(refresh_rate=config["progress_bar_refresh_rate"]),
    ],
    **config["trainer"],
)

# Rescale T_max from epochs to per-device optimizer steps.
config["model"]["scheduler"]["params"]["CosineAnnealingLR"]["T_max"] *= len(data_loader_train) / config["trainer"]["devices"]
model = LightningModule(config["model"])
trainer.fit(model, data_loader_train, data_loader_validation)
結果出力
# Inference setup: enumerate the test records (one directory per record_id).
batch_size = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = '/kaggle/input/google-research-identify-contrails-reduce-global-warming'
data_root = '/kaggle/input/google-research-identify-contrails-reduce-global-warming/test/'
filenames = os.listdir(data_root)
test_df = pd.DataFrame({'record_id': filenames})
test_df['path'] = data_root + test_df['record_id'].astype(str)
class ContrailsDataset(torch.utils.data.Dataset):
    """Inference dataset: builds ash false-color images from raw band files.

    Expects `df` with columns `record_id` and `path`, where each path is a
    directory containing band_11.npy, band_14.npy and band_15.npy.
    Returns (normalized image tensor, record id tensor).
    """

    def __init__(self, df, image_size=256, train=True):
        self.df = df
        self.trn = train
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def read_record(self, directory):
        """Load the three brightness-temperature bands used by the ash scheme."""
        record_data = {}
        for band in ("band_11", "band_14", "band_15"):
            record_data[band] = np.load(os.path.join(directory, band + ".npy"))
        return record_data

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])

    def get_false_color(self, record_data):
        """Compose the ash false-color image and pick the labeled frame."""
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)
        N_TIMES_BEFORE = 4  # index of the labeled frame in the time sequence
        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        return false_color[..., N_TIMES_BEFORE]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img = self.get_false_color(self.read_record(row.path))
        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)
        if self.image_size != 256:
            img = self.resize_image(img)
        img = self.normalize_image(img)
        # Bug fix: the id now comes from the row itself. The old code built a
        # second DataFrame from an independent os.listdir() call and indexed it
        # positionally, silently producing wrong ids if the two listdir orders
        # ever diverged (and breaking outside that hardcoded path).
        image_id = int(row.record_id)
        return img.float(), torch.tensor(image_id)

    def __len__(self):
        return len(self.df)
# Build the test dataset/loader at the same image size used for training.
test_ds = ContrailsDataset(test_df, config["model"]["image_size"], train=False)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=1)
class LightningModule(pl.LightningModule):
    """Bare inference wrapper; weights are restored from the checkpoint."""

    def __init__(self):
        super().__init__()
        # Must mirror the training architecture so checkpoint keys line up.
        self.model = smp.Unet(
            encoder_name="timm-resnest26d",
            encoder_weights=None,  # no ImageNet download at inference time
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, batch):
        preds = self.model(batch)
        return preds
# load_from_checkpoint is a classmethod: call it on the class. The previous
# `LightningModule().load_from_checkpoint(...)` built a throwaway model first
# and then discarded it, constructing the network twice for no benefit.
model = LightningModule.load_from_checkpoint("/kaggle/working/models/model.ckpt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
model.zero_grad()
def rle_encode(x, fg_val=1):
    """
    Args:
        x: numpy array of shape (height, width), 1 - mask, 0 - background
    Returns: run length encoding as list
    """
    # Fortran-order scan: transposing then flattening walks down-then-right,
    # which is the ordering the competition's RLE format expects.
    fg_positions = np.where(x.T.flatten() == fg_val)[0]
    encoding = []
    previous = -2  # sentinel so the first foreground pixel always opens a run
    for pos in fg_positions:
        if pos > previous + 1:
            # Gap found: start a new (1-based start, length 1) run.
            encoding += [pos + 1, 1]
        else:
            # Contiguous pixel: extend the current run.
            encoding[-1] += 1
        previous = pos
    return encoding
def list_to_string(x):
    """
    Converts list to a string representation
    Empty list returns '-'
    """
    # An empty RLE means "no mask" and is encoded as a single dash.
    if not x:
        return '-'
    # Strip list syntax from str(x): "[2, 3]" -> "2 3".
    return str(x).replace("[", "").replace("]", "").replace(",", "")
# Predict every test batch, threshold probabilities at 0.5, and write the
# RLE-encoded masks into the sample submission.
submission = pd.read_csv('/kaggle/input/google-research-identify-contrails-reduce-global-warming/sample_submission.csv', index_col='record_id')
for batch in test_dl:
    images, image_ids = batch
    images = images.to(device)
    with torch.no_grad():
        logits = model.forward(images)
    # Bring predictions back to the 256x256 label grid if needed.
    if config["model"]["image_size"] != 256:
        logits = torch.nn.functional.interpolate(logits, size=256, mode='bilinear')
    probs = torch.sigmoid(logits).cpu().detach().numpy()
    # Strictly > 0.5 is foreground; exactly 0.5 stays background, as before.
    masks = (probs[:, 0, :, :] > 0.5).astype(np.float64)
    for img_num in range(images.shape[0]):
        rid = image_ids[img_num].item()
        submission.loc[int(rid), 'encoded_pixels'] = list_to_string(rle_encode(masks[img_num]))
submission.to_csv('submission.csv')
参考
https://www.kaggle.com/code/manishkumar7432698/google-reseach-baseline-pytorch