Categories
モデル

グラフニューラルネットワーク

概要

グラフニューラルネットワーク(Graph Neural Network, GNN)は、グラフ構造を持つデータ(例: ソーシャルネットワーク、分子構造、交通ネットワークなど)を扱うための機械学習モデルです。

通常のニューラルネットワークは、各データポイントがベクトルや行列で表現されると仮定しますが、グラフデータは異なります。グラフデータはノード(Node)とエッジ(Edge)から成り立ち、ノードはデータポイントを表し、エッジはノード間の関係性を示します。

GNNは、これらのノードとエッジの情報を活用して、グラフ全体の特徴を学習します。これは、隣接するノード間の情報を共有することで実現されます。

GNNの主な要素としては、以下のようなものがあります:

  1. 隣接行列 (Adjacency Matrix): グラフ内のノード間の接続関係を表す行列。各エントリは、ノード間の接続が存在する場合は1、存在しない場合は0となります。
  2. 特徴行列 (Feature Matrix): 各ノードに関連付けられた特徴を表す行列。通常は、各行がノードを表し、各列が特徴を表します。
  3. 隣接情報の集約 (Aggregation of Neighboring Information): GNNは、隣接するノードから情報を集め、それを元に各ノードの表現を更新します。この際、隣接ノードの情報を組み合わせる手法が用いられます。
  4. 更新規則 (Update Rule): 各ノードの新しい表現を計算するためのアルゴリズム。通常は、隣接ノードの情報を使ってノードの表現を更新します。

GNNは、異なるタスクに応用することができます。例えば、ノード分類、リンク予測、グラフ分類などがあります。

GNNは、社会ネットワーク分析、分子構造の予測、交通流の予測など、様々な分野で活用されています。例としては、Graph Convolutional Network (GCN)、GraphSAGE、GAT(Graph Attention Network)などがあります。

コード例

import dgl
import dgl.function as fn
import torch
import torch.nn as nn

# グラフの定義
G = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # (0->1, 1->2, 2->3, 3->0)

# ノードの特徴を定義
G.ndata['h'] = torch.eye(4)

# グラフ畳み込み層の定義
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # 通常の線形変換
        h = self.linear(feature)
        # 隣接ノードの情報を集約
        g.ndata['h'] = h
        g.update_all(fn.copy_src(src='h', out='m'),
                     fn.sum(msg='m', out='h'))
        return g.ndata.pop('h')

# モデルの定義
class GCN(nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.layer1 = GCNLayer(4, 4)  # 入力特徴の次元数が4, 出力特徴の次元数も4

    def forward(self, g, features):
        x = self.layer1(g, features)
        return x

# モデルのインスタンス化
model = GCN()

# 順伝播
output = model(G, G.ndata['h'])
print(output)