Categories
kaggle

RSNA 2023 Abdominal Trauma Detection – CNNを使った解法 train編

導入

CTスキャン画像（PNGに変換済みのスライス画像）を入力として、CNNモデルで腸・血管外漏出・腎臓・肝臓・脾臓の外傷を分類するモデルを学習します。

ライブラリ準備

まずは、学習を行う前に必要なライブラリのインストール、インポートを行います。

!pip install -q keras-cv-attention-models
!pip install -qU wandb
!pip install -qU scikit-learn
!pip install -q seaborn

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # to avoid too many logging messages
import pandas as pd, numpy as np, random, shutil
import tensorflow as tf, re, math
import tensorflow.keras.backend as K
import sklearn
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow_probability as tfp
import wandb
import yaml

from IPython import display as ipd
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold, StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
from sklearn.utils.class_weight import compute_class_weight

Wandbの設定

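学習のパラメータや評価結果を記録するために、ここでは Weights & Biases（wandb）を使用します。Kaggle上で自分のアカウントに記録する場合は、Add-ons → Secrets に `WANDB` というラベルでAPIキーを登録してください（登録がない場合は匿名モードで記録されます）。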

import wandb

# Log in to W&B with the API key stored as the Kaggle secret "WANDB".
# Outside Kaggle, or without that secret, fall back to anonymous logging.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("WANDB")

    wandb.login(key=api_key)
    anonymous = None
except Exception as err:
    # Best effort: any failure (no kaggle_secrets module, missing secret,
    # login error) switches to anonymous mode, but report why instead of
    # silently swallowing it. Ctrl-C / SystemExit are no longer caught.
    anonymous = "must"
    print(f'W&B login via Kaggle secret failed: {err!r}')
    print('To use your W&B account,\nGo to Add-ons -> Secrets and provide your W&B access token. Use the Label name as WANDB. \nGet your W&B access token from here: https://wandb.ai/authorize')

モデルの構成

学習全体の設定（バックボーンモデル、画像サイズ、バッチサイズ、エポック数、データ拡張、学習率スケジューラ、目的変数など）をCFGクラスにまとめて記述します。

class CFG:
    # Global experiment configuration. Values are read as CFG.<name> throughout
    # the notebook. wandb_init dumps every attribute whose name does not
    # contain '__' into the W&B run config.
    wandb         = True   # enable W&B init, callbacks and result logging
    competition   = 'rsna-atd' 
    _wandb_kernel = 'awsaf49'
    debug         = False  # True: train/validate on batch_size*REPLICAS*2 samples for 2 epochs
    comment       = 'EfficientNetV1B0-256x256-low_lr-vflip'
    exp_name      = 'baseline-v4: new_ds + multi_head' # name of the experiment, folds will be grouped using 'exp_name'
    
    # use verbose=0 for silent, verbose=1 for interactive,
    verbose      = 0
    display_plot = True  # not referenced in the visible code

    # device: "TPU-VM" (local TPU), "TPU", "GPU" or "CPU"; the device setup
    # overwrites this with "GPU"/"CPU" when the requested device is unavailable
    device = "TPU-VM" #or "GPU"

    # attribute name looked up in keras_cv_attention_models.efficientnet
    model_name = 'EfficientNetV1B0'

    # seed for data-split, layer init, augs
    seed = 42

    # number of folds for data-split
    folds = 4
    
    # which folds to train
    selected_folds = [0, 1, 2]

    # size of the image, as [height, width]
    img_size = [256, 256]
#     eq_dim = np.prod(img_size)**0.5

    # batch_size (per replica; the global batch is batch_size * REPLICAS) and epochs
    batch_size = 48
    epochs = 10

    # loss -- informational only: build_model ignores loss_name and always uses
    # BCE for the bowel/extra heads and CCE for the organ heads
    loss      = 'BCE & CCE'  # BCE, Focal
    
    # optimizer -- informational only: build_model always creates Adam
    optimizer = 'Adam'

    # augmentation (applied to the training set only)
    augment   = True

    # scale-shift-rotate-shear
    transform = 0.90  # transform prob
    fill_mode = 'constant'  # fill mode for tfa.image.transform / rotate
    rot    = 2.0   # rotation std in degrees (rot * N(0, 1))
    shr    = 2.0   # shear std in degrees
    hzoom  = 50.0  # zoom = 1 + N(0, 1) / hzoom
    wzoom  = 50.0
    hshift = 10.0  # shift = hshift * N(0, 1) (presumably pixels)
    wshift = 10.0

    # flip
    hflip = True
    vflip = True

    # clip pixel values to [0, 1] after augmentation
    clip = False

    # lr-scheduler: 'cosine' or 'exp' (see get_lr_callback)
    scheduler   = 'cosine' # cosine

    # coarse dropout: apply probability, number of squares, square size as a
    # fraction of min(img_size)
    drop_prob   = 0.6
    drop_cnt    = 5
    drop_size   = 0.05
    
    # cut-mix-up (a probability of 0.0 disables the augmentation entirely)
    mixup_prob = 0.0
    mixup_alpha = 0.5
    
    cutmix_prob = 0.0
    cutmix_alpha = 2.5

    # pixel-augment
    pixel_aug = 0.90  # prob of pixel_aug
    sat  = [0.7, 1.3]  # saturation factor range
    cont = [0.8, 1.2]  # contrast factor range
    bri  = 0.15        # max brightness delta
    hue  = 0.05        # max hue delta

    # test-time augs (not referenced in the visible training code)
    tta = 1
    
    # target columns; build_decoder slices the label vector by position,
    # so any change to this order must be mirrored there
    target_col  = [ "bowel_injury", "extravasation_injury", "kidney_healthy", "kidney_low",
                   "kidney_high", "liver_healthy", "liver_low", "liver_high",
                   "spleen_healthy", "spleen_low", "spleen_high"] # not using "bowel_healthy" & "extravasation_healthy"
def seeding(SEED):
    """Seed Python's `random`, NumPy and TensorFlow with SEED for repeatable runs.

    PYTHONHASHSEED is exported as well. That variable is only read when an
    interpreter starts, so it affects processes launched afterwards, not the
    string hashing of the current process.
    """
    os.environ['PYTHONHASHSEED'] = str(SEED)
    for seed_everything_with in (random.seed, np.random.seed, tf.random.set_seed):
        seed_everything_with(SEED)
    print('seeding done!!!')


seeding(CFG.seed)

デバイス（TPU / GPU / CPU）とデータパスの設定

# Pick a tf.distribute strategy: TPU if requested and reachable, otherwise
# multi-GPU / single-GPU / CPU depending on what TensorFlow can see.
if "TPU" in CFG.device:
    tpu = 'local' if CFG.device == 'TPU-VM' else None
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu=tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
    except Exception as err:
        # Best effort: no reachable TPU -> fall through to GPU/CPU detection,
        # but report the reason instead of swallowing it silently.
        print(f"TPU connection failed ({err!r}); falling back to GPU/CPU")
        CFG.device = "GPU"

# GPU/CPU path (also taken after a failed TPU connection).
if CFG.device == "GPU" or CFG.device == "CPU":
    ngpu = len(tf.config.experimental.list_physical_devices('GPU'))
    if ngpu > 1:
        print("Using multi GPU")
        strategy = tf.distribute.MirroredStrategy()
    elif ngpu == 1:
        print("Using single GPU")
        strategy = tf.distribute.get_strategy()
    else:
        print("Using CPU")
        strategy = tf.distribute.get_strategy()
        CFG.device = "CPU"

if CFG.device == "GPU":
    print("Num GPUs Available: ", ngpu)


AUTO     = tf.data.experimental.AUTOTUNE
# number of devices the strategy splits each global batch across
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# root of the Kaggle input dataset (train.csv, test.csv, *_images/)
BASE_PATH = '/kaggle/input/rsna-atd-512x512-png-v2-dataset'

データ読み込み

データの読み込みを行っていきます。train.csvファイルには、以下の情報が入っています。

  • patient_id: 患者ごとのユニークなID。
  • series_id: スキャン（シリーズ）ごとのユニークなID。
  • instance_number: スキャン内での画像（スライス）の番号。
  • [bowel/extravasation]_[healthy/injury]: 健康/外傷の2値で表される外傷ラベル。
  • [kidney/liver/spleen]_[healthy/low/high]: 健康/軽度/重度の3値で表される外傷ラベル。
  • any_injury: 何らかの外傷があるかどうか。

train.csvを読み込み、DataFrameに入れていきます。DataFrameに新たに画像パス(image_path)の列を追加し、パスを追加していきます。test.csvに対しても同様の処理を行います。

def _add_image_paths(frame, split):
    """Add an `image_path` column and drop exact duplicate rows.

    The path is `{BASE_PATH}/{split}_images/<patient_id>/<series_id>/<instance_number>.png`.
    The column is added to `frame` in place; the de-duplicated frame is returned.
    """
    frame['image_path'] = (f'{BASE_PATH}/{split}_images'
                           + '/' + frame.patient_id.astype(str)
                           + '/' + frame.series_id.astype(str)
                           + '/' + frame.instance_number.astype(str) + '.png')
    return frame.drop_duplicates()


# train
df = _add_image_paths(pd.read_csv(f'{BASE_PATH}/train.csv'), 'train')
print('Train:')
display(df.head(2))

# test
test_df = _add_image_paths(pd.read_csv(f'{BASE_PATH}/test.csv'), 'test')

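データ分割

同じ患者の画像が学習用と検証用の両方に入らないよう、patient_idをグループとしたStratifiedGroupKFoldを使い、ラベルの組み合わせで層化しながらCFG.folds個に分割します。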

# Concatenate all label columns into one string per row (e.g. "01100100100"),
# so the split is stratified on the joint label combination.
df['stratify'] = ''
for col in CFG.target_col:
    df['stratify'] += df[col].astype(str)

df = df.reset_index(drop=True)
# Pre-fill so 'fold' is an integer column. Assigning to a missing column via a
# partial .loc would create it as float64 with NaN placeholders. Every row is
# overwritten below, because the validation sets partition the data.
df['fold'] = -1
skf = StratifiedGroupKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
# Group by patient so no patient appears in both train and validation.
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['stratify'], df["patient_id"])):
    df.loc[val_idx, 'fold'] = fold

データ拡張

モデルの汎化性能を向上させるために、学習時に画像へランダムな加工（データ拡張）を行います。上下左右の反転、回転・せん断、拡大/縮小、平行移動、色相・彩度・コントラスト・輝度の変化、ランダムな矩形領域の欠損（coarse dropout）などを適用します。

def get_mat(shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 affine matrix `shear @ zoom @ shift` used by `transform`.

    Args:
        shear: shear angle in degrees (converted to radians here).
        height_zoom, width_zoom: zoom factors; the matrix holds their reciprocals.
        height_shift, width_shift: translation terms placed in the last column.

    The inputs are expected as float32 tensors of shape [1], as `transform`
    produces them. Rotation is not part of this matrix; `transform` applies
    it separately.
    """
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    def as_3x3(entries):
        # nine [1]-shaped entries, row-major, -> a 3x3 matrix
        return tf.reshape(tf.concat([entries], axis=0), [3, 3])

    shear_rad = math.pi * shear / 180.
    cos_s = tf.math.cos(shear_rad)
    sin_s = tf.math.sin(shear_rad)

    shear_part = as_3x3([one,  sin_s, zero,
                         zero, cos_s, zero,
                         zero, zero,  one])
    zoom_part = as_3x3([one / height_zoom, zero,             zero,
                        zero,              one / width_zoom, zero,
                        zero,              zero,             one])
    shift_part = as_3x3([one,  zero, height_shift,
                         zero, one,  width_shift,
                         zero, zero, one])

    zoom_then_shift = K.dot(zoom_part, shift_part)
    return K.dot(shear_part, zoom_then_shift)

def transform(image, DIM=CFG.img_size):#[rot,shr,h_zoom,w_zoom,h_shift,w_shift]):
    """Apply a random shear/zoom/shift, then a random rotation, to one image.

    Args:
        image: a single image (not a batch); the final reshape forces it to
            shape [*DIM, 3].
        DIM: target [height, width]. A non-square image is zero-padded to a
            square before the warp and cropped back afterwards. For a square
            DIM neither padding branch runs; `pad` / `NEW_DIM` are only used
            inside the matching branches further down.

    The random parameters are drawn from CFG.rot, CFG.shr, CFG.hzoom,
    CFG.wzoom, CFG.hshift and CFG.wshift. Whether to call this at all
    (CFG.transform) is decided by the caller.

    Returns:
        The transformed image, shape [*DIM, 3].
    """
    if DIM[0]>DIM[1]:
        # taller than wide: pad the width (axis 1) to a square
        diff  = (DIM[0]-DIM[1])
        pad   = [diff//2, diff//2 + diff%2]
        image = tf.pad(image, [[0, 0], [pad[0], pad[1]],[0, 0]])
        NEW_DIM = DIM[0]
    elif DIM[0]<DIM[1]:
        # wider than tall: pad the height (axis 0) to a square
        diff  = (DIM[1]-DIM[0])
        pad   = [diff//2, diff//2 + diff%2]
        image = tf.pad(image, [[pad[0], pad[1]], [0, 0],[0, 0]])
        NEW_DIM = DIM[1]
    
    # random parameters; rot and shr are in degrees
    rot = CFG.rot * tf.random.normal([1], dtype='float32')
    shr = CFG.shr * tf.random.normal([1], dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1], dtype='float32') / CFG.hzoom
    w_zoom = 1.0 + tf.random.normal([1], dtype='float32') / CFG.wzoom
    h_shift = CFG.hshift * tf.random.normal([1], dtype='float32') 
    w_shift = CFG.wshift * tf.random.normal([1], dtype='float32') 
    
    # inverted because tfa.image.transform maps output coordinates to input coordinates
    transformation_matrix=tf.linalg.inv(get_mat(shr,h_zoom,w_zoom,h_shift,w_shift))
    
    flat_tensor=tfa.image.transform_ops.matrices_to_flat_transforms(transformation_matrix)
    
    image=tfa.image.transform(image,flat_tensor, fill_mode=CFG.fill_mode)
    
    # rotation is applied separately because get_mat has no rotation term
    rotation = math.pi * rot / 180.
    
    image=tfa.image.rotate(image,-rotation, fill_mode=CFG.fill_mode)
    
    # undo the square padding
    if DIM[0]>DIM[1]:
        image=tf.reshape(image, [NEW_DIM, NEW_DIM,3])
        image = image[:, pad[0]:-pad[1],:]
    elif DIM[1]>DIM[0]:
        image=tf.reshape(image, [NEW_DIM, NEW_DIM,3])
        image = image[pad[0]:-pad[1],:,:]
    # restore the static shape for downstream tf.data ops
    image = tf.reshape(image, [*DIM, 3])    
    return image

def dropout(image,DIM=CFG.img_size, PROBABILITY = 0.6, CT = 5, SZ = 0.1):
    """Coarse dropout: with probability PROBABILITY, zero out CT square patches.

    Args:
        image: ONE image of shape [*DIM, 3], not a batch [b, *DIM, 3].
        DIM: [height, width] of the image.
        PROBABILITY: chance that any dropout is applied to this image.
        CT: number of squares to erase. Squares may overlap.
        SZ: square side as a fraction of min(DIM). The actual side is
            2 * (int(SZ * min(DIM)) // 2), clipped at the image borders.

    Returns:
        The image with the chosen squares set to zero, reshaped to [*DIM, 3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image with up to CT squares (side ~SZ*min(DIM)) zeroed out
    
    # DO DROPOUT WITH PROBABILITY DEFINED ABOVE
    # P is 1 (apply) or 0 (skip), drawn once per image
    P = tf.cast( tf.random.uniform([],0,1)<PROBABILITY, tf.int32)
    if (P==0)|(CT==0)|(SZ==0): 
        return image
    
    for k in range(CT):
        # CHOOSE RANDOM LOCATION (square centre)
        x = tf.cast( tf.random.uniform([],0,DIM[1]),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM[0]),tf.int32)
        # COMPUTE SQUARE (clipped to the image); P is 1 here, so * P is a no-op
        WIDTH = tf.cast( SZ*min(DIM),tf.int32) * P
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM[0],y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM[1],x+WIDTH//2)
        # DROPOUT IMAGE: rebuild rows ya:yb as [left | zeros | right], then
        # stitch them back between the untouched top and bottom rows
        one = image[ya:yb,0:xa,:]
        two = tf.zeros([yb-ya,xb-xa,3], dtype = image.dtype) 
        three = image[ya:yb,xb:DIM[1],:]
        middle = tf.concat([one,two,three],axis=1)
        image = tf.concat([image[0:ya,:,:],middle,image[yb:DIM[0],:,:]],axis=0)
        image = tf.reshape(image,[*DIM,3])

#     image = tf.reshape(image,[*DIM,3])
    return image

mixup / cutmix

バッチ内の画像とラベルを混ぜ合わせるデータ拡張です。CFGでは mixup_prob と cutmix_prob が 0.0 のため、この設定では使用されません。

def random_int(shape=[], minval=0, maxval=1):
    """Sample int32 values uniformly from [minval, maxval) with the given shape."""
    sample = tf.random.uniform(shape=shape, minval=minval, maxval=maxval,
                               dtype=tf.int32)
    return sample


def random_float(shape=[], minval=0.0, maxval=1.0):
    """Sample float32 values uniformly from [minval, maxval) with the given shape."""
    return tf.random.uniform(shape=shape, minval=minval, maxval=maxval,
                             dtype=tf.float32)

# mixup
def get_mixup(alpha=0.2, prob=0.5):
    """Return a batched-dataset map fn that applies mixup with probability `prob`.

    Each sample is blended with the previous sample in the batch (tf.roll
    by one), using a single lam ~ Beta(alpha, alpha) per batch.

    The returned fn takes (images, labels). `images` is a batched image
    tensor. `labels` may be a single tensor or any nested structure, e.g. the
    5-tuple of per-head targets produced by build_decoder. Every leaf must
    have the batch on axis 0 and is mixed with the same lam.
    """
    @tf.function
    def mixup(images, labels, alpha=alpha, prob=prob):
        if random_float() > prob:
            return images, labels

        image_shape = tf.shape(images)

        beta = tfp.distributions.Beta(alpha, alpha)
        lam = beta.sample(1)[0]

        images = lam * images + (1.0 - lam) * tf.roll(images, shift=1, axis=0)
        # tf.nest handles both a plain label tensor and a tuple of per-head
        # tensors; tf.shape / arithmetic on the tuple itself would fail.
        labels = tf.nest.map_structure(
            lambda y: lam * y + (1.0 - lam) * tf.roll(y, shift=1, axis=0),
            labels)

        images = tf.reshape(images, image_shape)
        return images, labels
    return mixup

# cutmix
def get_cutmix(alpha, prob=0.5):
    """Return a batched-dataset map fn that applies CutMix with probability `prob`.

    A box covering roughly (1 - lam) of the image area, with
    lam ~ Beta(alpha, alpha), is cut from each sample's partner (the previous
    sample in the batch, via tf.roll) and pasted into the sample. Labels are
    then mixed by the actual pixel shares: lam becomes the fraction of pixels
    that still come from the sample itself after the box is clipped to the
    image borders.

    Args:
        alpha: Beta distribution concentration.
        prob: probability of applying CutMix to a given batch.

    The returned fn takes (images, labels). `images` has shape
    [batch, H, W, C]. `labels` is a tensor or any nested structure (e.g. the
    5-tuple of per-head targets) whose leaves have the batch on axis 0.
    """
    @tf.function
    def cutmix(images, labels, alpha=alpha, prob=prob):
        if random_float() > prob:
            return images, labels
        image_shape = tf.shape(images)
        H = image_shape[1]
        W = image_shape[2]

        beta = tfp.distributions.Beta(alpha, alpha)
        lam = beta.sample(1)[0]

        # partner for every sample: the previous sample in the batch
        images_rolled = tf.roll(images, shift=1, axis=0)

        # random box centre; half sides chosen so the unclipped box area is
        # ~(1 - lam) * H * W
        r_x = random_int([], minval=0, maxval=W)
        r_y = random_int([], minval=0, maxval=H)
        r = 0.5 * tf.math.sqrt(1.0 - lam)
        r_w_half = tf.cast(r * tf.cast(W, tf.float32), tf.int32)
        r_h_half = tf.cast(r * tf.cast(H, tf.float32), tf.int32)

        x1 = tf.clip_by_value(r_x - r_w_half, 0, W)
        x2 = tf.clip_by_value(r_x + r_w_half, 0, W)
        y1 = tf.clip_by_value(r_y - r_h_half, 0, H)
        y2 = tf.clip_by_value(r_y + r_h_half, 0, H)

        # zero-padded copies of the box region -> [0, 0, box, box, 0, 0]
        paddings = [[0, 0], [y1, H - y2], [x1, W - x2], [0, 0]]
        own_box = tf.pad(images[:, y1:y2, x1:x2, :], paddings)
        partner_box = tf.pad(images_rolled[:, y1:y2, x1:x2, :], paddings)

        # keep the sample outside the box, paste the partner inside it
        images = images - own_box + partner_box

        box_frac = (tf.cast((x2 - x1) * (y2 - y1), tf.float32)
                    / tf.cast(W * H, tf.float32))
        lam = 1.0 - box_frac  # share of pixels that still belong to the sample
        labels = tf.nest.map_structure(
            lambda y: lam * y + (1.0 - lam) * tf.roll(y, shift=1, axis=0),
            labels)

        images = tf.reshape(images, image_shape)
        return images, labels

    return cutmix

パイプライン

画像ファイルの読み込み（デコード、リサイズ、0〜1への正規化）、ラベルの各ヘッド用への分割、データ拡張、バッチ化を行う tf.data パイプラインを構築します。

def build_decoder(with_labels=True, target_size=CFG.img_size, ext='png'):
    """Return a tf.data map fn that loads an image (and splits its label vector).

    Args:
        with_labels: if True, the returned fn maps (path, label) ->
            (image, per-head label tuple); otherwise it maps path -> image.
        target_size: [height, width] to resize to.
        ext: 'png', 'jpg' or 'jpeg'. Any other value raises ValueError when the
            fn is traced (i.e. when it is passed to Dataset.map).

    Images are decoded as 3-channel, bilinearly resized, scaled to [0, 1]
    float32 and reshaped to [*target_size, 3].
    """
    def decode_image(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3, dtype=tf.uint8)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.image.resize(img, target_size, method='bilinear')
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.reshape(img, [*target_size, 3])

        return img
    
    def decode_label(label):
        label = tf.cast(label, tf.float32)
        # Label vector layout (CFG.target_col order):
        #   [0] bowel_injury, [1] extravasation_injury,
        #   [2:5] kidney_*, [5:8] liver_*, [8:11] spleen_*.
        # Keras matches a tuple of targets to the model outputs by POSITION,
        # and build_model orders its outputs (bowel, extra, liver, kidney,
        # spleen). Liver must therefore come before kidney here; otherwise
        # each of those two heads is trained on the other organ's labels.
        return (label[0:1], label[1:2], label[5:8], label[2:5], label[8:11])
    
    def decode_with_labels(path, label):
        return decode_image(path), decode_label(label)
    
    # the unlabeled variant is decode_image itself (previously the undefined
    # name `decode`, which raised NameError)
    return decode_with_labels if with_labels else decode_image


def build_augmenter(with_labels=True, dim=CFG.img_size):
    """Return a per-image augmentation fn for use with Dataset.map.

    The augmentation runs, in this order:
      1. the geometric `transform` (with probability CFG.transform);
      2. random horizontal / vertical flips (if CFG.hflip / CFG.vflip);
      3. random hue, saturation, contrast and brightness jitter, all together
         with probability CFG.pixel_aug;
      4. clipping to [0, 1] if CFG.clip;
      5. a reshape to [*dim, 3] to restore the static shape.

    Args:
        with_labels: if True, the returned fn takes (img, label) and passes
            label through unchanged; otherwise it takes only img.
        dim: [height, width] of the images.
    """
    def augment(img, dim=dim):
        if random_float() < CFG.transform:
            img = transform(img,DIM=dim)
        img = tf.image.random_flip_left_right(img) if CFG.hflip else img
        img = tf.image.random_flip_up_down(img) if CFG.vflip else img
        if random_float() < CFG.pixel_aug:
            img = tf.image.random_hue(img, CFG.hue)
            img = tf.image.random_saturation(img, CFG.sat[0], CFG.sat[1])
            img = tf.image.random_contrast(img, CFG.cont[0], CFG.cont[1])
            img = tf.image.random_brightness(img, CFG.bri)
        img = tf.clip_by_value(img, 0, 1)  if CFG.clip else img         
        img = tf.reshape(img, [*dim, 3])
        return img
    
    def augment_with_labels(img, label):    
        return augment(img), label
    
    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, batch_size=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024, 
                  cache_dir="", drop_remainder=False):
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)
    
    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)
    
    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)
    
    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)
    
    ds = tf.data.Dataset.from_tensor_slices(slices)
    ds = ds.map(decode_fn, num_parallel_calls=AUTO)
    ds = ds.cache(cache_dir) if cache else ds
    ds = ds.repeat() if repeat else ds
    if shuffle: 
        ds = ds.shuffle(shuffle, seed=CFG.seed)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)
    ds = ds.map(augment_fn, num_parallel_calls=AUTO) if augment else ds
    if augment and labels is not None:
        ds = ds.map(lambda img, label: (dropout(img, 
                                               DIM=CFG.img_size, 
                                               PROBABILITY=CFG.drop_prob, 
                                               CT=CFG.drop_cnt,
                                               SZ=CFG.drop_size), label),num_parallel_calls=AUTO)
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    if augment and labels is not None:
        if CFG.cutmix_prob:
            ds = ds.map(get_cutmix(alpha=CFG.cutmix_alpha,prob=CFG.cutmix_prob),num_parallel_calls=AUTO)
        if CFG.mixup_prob:
            ds = ds.map(get_mixup(alpha=CFG.mixup_alpha,prob=CFG.mixup_prob),num_parallel_calls=AUTO)
    ds = ds.prefetch(AUTO)
    return ds

# Sanity check: pull one un-augmented batch of 20 images from fold 0 to make
# sure the input pipeline runs end to end.
fold = 0
fold_df = df.loc[df['fold'] == fold].sample(frac=1.0)
paths = fold_df['image_path'].tolist()
labels = fold_df[CFG.target_col].values
ds = build_dataset(
    paths,
    labels,
    batch_size=32,
    cache=False,
    augment=False,
    repeat=True,
    shuffle=True,
)
ds = ds.unbatch().batch(20)
batch = next(iter(ds))

学習モデルの生成

from keras_cv_attention_models import efficientnet

def build_model(model_name=CFG.model_name,
                loss_name=CFG.loss,
                dim=CFG.img_size,
                compile_model=True,
                include_top=False):
    """Build the multi-head organ-injury classifier on an EfficientNet backbone.

    An ImageNet-pretrained backbone from keras_cv_attention_models.efficientnet
    (classifier removed) is global-average-pooled. The pooled features feed
    five heads, each behind its own 32-unit SiLU "neck" layer:
        bowel, extra          -> 1 unit, sigmoid
        liver, kidney, spleen -> 3 units, softmax
    The model outputs are a list in exactly that order.

    Args:
        model_name: backbone attribute name in `efficientnet`.
        loss_name: unused (the losses are fixed below); kept for API compatibility.
        dim: input [height, width]; the input has 3 channels.
        compile_model: if True, compile with Adam(lr=1e-4), label-smoothed
            (0.05) BCE / CCE losses and per-head accuracy.
        include_top: unused; kept for API compatibility.

    Returns:
        The (optionally compiled) tf.keras.Model.
    """
    backbone = getattr(efficientnet, model_name)(input_shape=(*dim, 3),
                                                 pretrained='imagenet',
                                                 num_classes=0)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)

    # (output name, units, activation), in model-output order
    head_specs = (
        ('bowel', 1, 'sigmoid'),
        ('extra', 1, 'sigmoid'),
        ('liver', 3, 'softmax'),
        ('kidney', 3, 'softmax'),
        ('spleen', 3, 'softmax'),
    )

    # all necks are created first, then all heads (keeps layer creation order)
    necks = [tf.keras.layers.Dense(32, activation='silu')(pooled)
             for _ in head_specs]
    outputs = [tf.keras.layers.Dense(units, name=name, activation=activation)(neck)
               for (name, units, activation), neck in zip(head_specs, necks)]

    model = tf.keras.Model(inputs=backbone.inputs, outputs=outputs)
    if not compile_model:
        return model

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    # binary heads use BCE, 3-class heads use CCE
    losses = {
        name: (tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05)
               if units == 1 else
               tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.05))
        for name, units, _ in head_specs
    }
    metrics = {name: ['accuracy'] for name, _, _ in head_specs}
    model.compile(optimizer=optimizer, loss=losses, metrics=metrics)
    return model


tmp = build_model(CFG.model_name, dim=CFG.img_size, compile_model=True)

学習率のスケジュール

def get_lr_callback(batch_size=8, plot=False):
    lr_start   = 0.000005
    lr_max     = 0.00000050 * REPLICAS * batch_size
    lr_min     = 0.000001
    lr_ramp_ep = 4
    lr_sus_ep  = 0
    lr_decay   = 0.8
   
    def lrfn(epoch):
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
            
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
            
        elif CFG.scheduler=='exp':
            lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
            
        elif CFG.scheduler=='cosine':
            decay_total_epochs = CFG.epochs - lr_ramp_ep - lr_sus_ep + 3
            decay_epoch_index = epoch - lr_ramp_ep - lr_sus_ep
            phase = math.pi * decay_epoch_index / decay_total_epochs
            cosine_decay = 0.4 * (1 + math.cos(phase))
            lr = (lr_max - lr_min) * cosine_decay + lr_min
        return lr
    if plot:
        plt.figure(figsize=(10,5))
        plt.plot(np.arange(CFG.epochs), [lrfn(epoch) for epoch in np.arange(CFG.epochs)], marker='o')
        plt.xlabel('epoch'); plt.ylabel('learnig rate')
        plt.title('Learning Rate Scheduler')
        plt.show()

    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)
    return lr_callback

# Plot the schedule once for inspection; the returned callback is discarded.
_ = get_lr_callback(CFG.batch_size, plot=True)

Wandbログの設定

# create directory to save gradcam imgs
!mkdir -p gradcam

# initialize wandb run
def wandb_init(fold, working_dir=None):
    """Save the CFG config for `fold` as YAML and start a W&B run.

    Args:
        fold: fold index; stored in the config and used in the run name.
        working_dir: directory for `config fold-{fold}.yaml`. Defaults to the
            author's local working directory. Pass e.g. '/kaggle/working/'
            when running on Kaggle. Created if missing.

    Returns:
        The run object returned by wandb.init.
    """
    if working_dir is None:
        working_dir = 'D:/OneDrive/sourceCode/jupyter/kaggle-competitions/202308_RSNA 2023 Abdominal Trauma Detection/RSNA-ATD CNN/working/'
    os.makedirs(working_dir, exist_ok=True)
    config_path = os.path.join(working_dir, f'config fold-{fold}.yaml')

    config = {k: v for k, v in dict(vars(CFG)).items() if '__' not in k}
    config.update({"fold": int(fold)})
    # Round-trip through the YAML file (presumably so W&B receives exactly the
    # plain values that were saved). `with` guarantees the file is flushed
    # and closed before it is read back.
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    with open(config_path, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    run = wandb.init(project="rsna-atd-public",
                     name=f"fold-{fold}|dim-{CFG.img_size[0]}x{CFG.img_size[1]}|model-{CFG.model_name}",
                     config=config,
                     anonymous=anonymous,
                     group=CFG.exp_name)
    return run

def log_wandb(fold):
    """Log the current fold's best validation results to the active W&B run.

    Reads the module-level best_* values set by the training loop. `fold` is
    accepted for symmetry with the other helpers but is not used.
    """
    best_results = dict(
        best_acc=best_acc,
        best_loss=best_loss,
        best_epoch=best_epoch,
        best_acc_bowel=best_acc_bowel,
        best_acc_extra=best_acc_extra,
        best_acc_liver=best_acc_liver,
        best_acc_kidney=best_acc_kidney,
        best_acc_spleen=best_acc_spleen,
    )
    wandb.log(best_results)

def get_wb_callbacks(fold):
    """Return the W&B callbacks for one fold: a best-model checkpoint and a metrics logger.

    The checkpoint keeps the full model (not only its weights) with the lowest
    val_loss in 'fold-<fold>.h5'.
    """
    checkpoint_cb = wandb.keras.WandbModelCheckpoint(
        filepath='fold-%i.h5' % fold,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=False,
        verbose=CFG.verbose,
    )
    metrics_cb = wandb.keras.WandbMetricsLogger()
    return [checkpoint_cb, metrics_cb]

モデル学習の実行

# Train one model per selected fold. The best_* values are deliberately left
# as module-level names because log_wandb reads them.
scores = []

for fold in np.arange(CFG.folds):
    
    # ignore not selected folds
    if fold not in CFG.selected_folds:
        continue
        
    # init wandb
    if CFG.wandb:
        run = wandb_init(fold)
        wb_callbacks = get_wb_callbacks(fold)
            
    # train and valid dataframe
    train_df = df.query("fold!=@fold")
    valid_df = df.query("fold==@fold")
    
    # get image_paths and labels
    train_paths = train_df.image_path.values; train_labels = train_df[CFG.target_col].values.astype(np.float32)
    valid_paths = valid_df.image_path.values; valid_labels = valid_df[CFG.target_col].values.astype(np.float32)
    test_paths  = test_df.image_path.values
    
    # shuffle train data
    index = np.arange(len(train_df))
    np.random.shuffle(index)
    train_paths  = train_paths[index]
    train_labels = train_labels[index]
    
    # global batch size (per-replica batch times number of replicas)
    global_batch_size = CFG.batch_size * REPLICAS

    # min samples in debug mode
    min_samples = global_batch_size * 2
    
    # for debug model run on small portion
    if CFG.debug:
        train_paths = train_paths[:min_samples]; train_labels = train_labels[:min_samples]
        valid_paths = valid_paths[:min_samples]; valid_labels = valid_labels[:min_samples]
    
    # show message
    print('#'*40); print('#### FOLD: ',fold)
    print('#### IMAGE_SIZE: (%i, %i) | MODEL_NAME: %s | BATCH_SIZE: %i'%
          (CFG.img_size[0],CFG.img_size[1],CFG.model_name,global_batch_size))
    
    # data stat
    num_train = len(train_paths)
    num_valid = len(valid_paths)
    if CFG.wandb:
        wandb.log({'num_train':num_train,
                   'num_valid':num_valid})
    print('#### NUM_TRAIN: {:,} | NUM_VALID: {:,}'.format(num_train, num_valid))
    
    # build model
    K.clear_session()
    with strategy.scope():
        model = build_model(CFG.model_name, dim=CFG.img_size, compile_model=True)

    # build dataset (decoded images are cached only on TPU)
    cache = 1 if 'TPU' in CFG.device else 0
    train_ds = build_dataset(train_paths, train_labels, cache=cache, batch_size=global_batch_size,
                   repeat=True, shuffle=True, augment=CFG.augment)
    val_ds = build_dataset(valid_paths, valid_labels, cache=cache, batch_size=global_batch_size,
                   repeat=False, shuffle=False, augment=False)
    print('#'*40)   
    
    # callbacks
    callbacks = []
    ## save best model after each fold
    ## NOTE(review): with CFG.wandb the W&B checkpoint also writes 'fold-<k>.h5'
    sv = tf.keras.callbacks.ModelCheckpoint(
        'fold-%i.h5'%fold, monitor='val_loss', verbose=CFG.verbose, save_best_only=True,
        save_weights_only=False, mode='min', save_freq='epoch')
    callbacks +=[sv]
    ## lr-scheduler
    callbacks += [get_lr_callback(CFG.batch_size)]
    ## wandb callbacks
    if CFG.wandb:
        callbacks += wb_callbacks
        
    # The training set repeats forever, so one epoch = this many whole global
    # batches. Use an int, since `len / bs // REPLICAS` produced a float.
    # Keep at least 1 so tiny (debug) splits still train.
    steps_per_epoch = max(1, num_train // global_batch_size)

    # train
    print('Training...')
    history = model.fit(
        train_ds, 
        epochs=CFG.epochs if not CFG.debug else 2, 
        callbacks = callbacks, 
        steps_per_epoch=steps_per_epoch,
        validation_data=val_ds, 
        verbose=CFG.verbose
    )
    
    # store best results (epoch with the lowest total val_loss)
    best_epoch = np.argmin(history.history['val_loss'])
    best_loss = history.history['val_loss'][best_epoch]
    best_acc_bowel = history.history['val_bowel_accuracy'][best_epoch]
    best_acc_extra = history.history['val_extra_accuracy'][best_epoch]
    best_acc_liver = history.history['val_liver_accuracy'][best_epoch]
    best_acc_kidney = history.history['val_kidney_accuracy'][best_epoch]
    best_acc_spleen = history.history['val_spleen_accuracy'][best_epoch]

    # Find mean accuracy
    best_acc = np.mean([best_acc_bowel, best_acc_extra, 
                        best_acc_liver, best_acc_kidney, best_acc_spleen])

    print(f'\n{"="*17} FOLD {fold} RESULTS {"="*17}')
    print(f'>>>> BEST Loss  : {best_loss:.3f}\n>>>> BEST Acc   : {best_acc:.3f}\n>>>> BEST Epoch : {best_epoch}\n')
    print('ORGAN Acc:')
    print(f'  >>>> {"Bowel".ljust(15)} : {best_acc_bowel:.3f}')
    print(f'  >>>> {"Extravasation".ljust(15)} : {best_acc_extra:.3f}')
    print(f'  >>>> {"Liver".ljust(15)} : {best_acc_liver:.3f}')
    print(f'  >>>> {"Kidney".ljust(15)} : {best_acc_kidney:.3f}')
    print(f'  >>>> {"Spleen".ljust(15)} : {best_acc_spleen:.3f}')

    print(f'{"="*50}\n')

    scores.append([best_loss, best_acc, 
                   best_acc_bowel, best_acc_extra, 
                   best_acc_liver, best_acc_kidney, best_acc_spleen])
    
    # log best result on wandb & plot
    if CFG.wandb:
        log_wandb(fold) # log
        wandb.run.finish() # finish the run
        display(ipd.IFrame(run.url, width=1080, height=720)) # show wandb dashboard

参考

https://www.kaggle.com/code/awsaf49/rsna-atd-cnn-tpu-train/notebook