Categories
モデル

バッチ正規化

概要

 深層学習において各層の入力データの分布が大きく変わると、学習が不安定になることがあります。バッチ正則化は、勾配が急激に大きくなることを防ぎ、学習を安定させる仕組みです。学習が不安定になる原因の一つとして、内部共変量シフトが大きくなることが挙げられます。

内部共変量シフト

 ニューラルネットワークの各層でパラメータの更新が進み、各層での入力の分布が変化していきます。これにより、前の層からの入力の分布が変化し、これを内部共変量シフトといいます。このシフトが大きくなると、勾配が急激に大きくなり、学習が困難になることがあります。バッチ正規化を用いることで、内部共変量シフトを小さく抑えることが可能になります。

具体的な処理の内容としては、バッチ内の入力データに対して平均が0、分散が1になるように正規化します。これにより、バッチごとの分布の変化を抑制し、学習が安定します。

 
実装例

import numpy as np

class BatchNormalization:
    def __init__(self, epsilon=1e-5, momentum=0.9):
        self.epsilon = epsilon
        self.momentum = momentum
        self.running_mean = None
        self.running_var = None
        self.gamma = None
        self.beta = None
        self.std = None
        self.x_centered = None

    def forward_pass(self, x, train=True):
        if self.running_mean is None:
            self.running_mean = np.mean(x, axis=0)
            self.running_var = np.var(x, axis=0)

        if train:
            batch_mean = np.mean(x, axis=0)
            batch_var = np.var(x, axis=0)
            self.x_centered = x - batch_mean
            self.std = np.sqrt(batch_var + self.epsilon)
            x_normalized = self.x_centered / self.std

            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var
        else:
            x_normalized = (x - self.running_mean) / np.sqrt(self.running_var + self.epsilon)

        return x_normalized

    def backward_pass(self, grad):
        N, D = grad.shape
        dgamma = np.sum(grad * self.x_normalized, axis=0)
        dbeta = np.sum(grad, axis=0)
        dx_normalized = grad * self.gamma
        dx = (1. / N) * self.std * (N*dx_normalized - np.sum(dx_normalized, axis=0)
              - self.x_centered*np.sum(dx_normalized*self.x_centered, axis=0)/self.std**2)
        
        self.gamma -= learning_rate * dgamma
        self.beta -= learning_rate * dbeta

        return dx

# 使い方の例:
# バッチ正則化のインスタンスを作成
bn = BatchNormalization()

# 仮の入力データを作成(例えば、4つのデータポイントと3つの特徴量を持つ)
input_data = np.random.rand(4, 3)

# バッチ正則化を適用(学習時)
normalized_data = bn.forward_pass(input_data)
print("Normalized data (training):")
print(normalized_data)

# 逆伝播(バックプロパゲーション)
gradient = np.random.rand(4, 3)  # 仮の勾配
dx = bn.backward_pass(gradient)
print("\nBackward pass:")
print(dx)