2025年9月20日土曜日

FaceLiVTの論文をGPT-5 proに要約させる。


https://arxiv.org/pdf/2506.10361



 以下は、実装者向けに最短距離でコード化できるよう整理した“FaceLiVT”の完全要約です。図表や式の参照も付けています。実装の核は RepMix(構造的再パラメータ化付きの深層可分畳み込み)MHLA(Multi‑Head Linear Attention) の2点です。図1(p.2)にブロック構造、表1(p.4)に各ステージ設定、表2–4(p.5)に結果とアブレーションの要点がまとまっています。


1) 目的と全体像(実装観点の要点)

  • FaceLiVT は、CNN×Transformerのハイブリッドで、前段=RepMix後段=MHLAトークンミキサとして組み合わせ、モバイル推論向けにレイテンシと精度の両立を図る設計。図1参照。

  • MHLA は自己注意(MHSA)の 二乗計算を避け、トークン次元に2層の線形(+非線形) をかける設計。計算量は Ω(MHLA)=2(N·Nr)·C と低く、MHSAの Ω=4NC²+2N²C より小さい(§3.4)。

  • 構造的再パラメータ化(BN融合・分岐結合・残差の折り畳み)で、学習時は表現力推論時は単一DWCへ畳み込み、Core ML等で高速化(§3.2)。


2) 入出力・ネットワーク構成

  • 入力:112×112 のアライン済み顔画像(Glint360Kで学習時に使用)。値域は [-1,1] 正規化。

  • Stem:3×3 s=2 畳み込み ×2(112→56→28)。図1、表1。

  • ステージ数:4(Stage1/2=RepMix、Stage3/4=MHLA 版ではMHLA)。各ブロックは TokenMixer → 残差加算 → Channel MLP → 残差加算(式(1))。

  • Channel MLP:2層全結合(拡張比 r=3)、BNGELU 1回(式(2))。実装は 1×1 Conv×2+BN+GELU でOK。

  • ヘッド:Global AvgPool → FC(512) で埋め込み512次元(CosFace学習用)。表1。


3) トークンミキサの詳細

3.1 RepMix(学習時)— 式(6)

  • DWC k×kDWC 1×1 を足し合わせ、その後 BN、元入力を残差接続:
    X' = X + BN( DWC_k×k(X) + DWC_1×1(X) )(非線形は省略)。図1(b)、式(6)。

3.2 RepMix(推論時)— 構造的再パラメータ化

  • (a) BN融合:Conv→BN を 重みW' = W·γ/σ, バイアス b' = (b−μ)·γ/σ + β で単一Conv化(式(3)(4))。

  • (b) 分岐の畳み込み
    1×1 DWC と 恒等(残差)を k×kのDWCカーネル中央に埋め込み、k×k分岐に合算。最終的に 単一のDWC(k×k)+(a)でBN融合済み になる。図1(b)、§3.2。

3.3 MHLA(Multi‑Head Linear Attention)— 式(9)(10)

  • 入力:X∈ℝ^{B×C×H×W} → N=H·Wの1Dトークンへ並べ X∈ℝ^{B×C×N}。Heads=Heチャネル分割:Xᵍ∈ℝ^{B×(C/He)×N}。図1(c)、§3.4。

  • 各ヘッドはトークン次元で 2層線形 + GELU
    Y = GELU( Xᵍ · Wᵢ )(Wᵢ∈ℝ^{N×Nr}), Z = Y · Wₒ(Wₒ∈ℝ^{Nr×N})。ヘッド結合で出力(式(10))。Q/K/Vやsoftmaxは使わない

  • 計算量:Ω=2·N·Nr·C(MHSAより小)。Nがステージで固定(例:Stage3=7×7=49、Stage4=4×4=16)なので、ステージごとに(N, Nr, He)を固定して重みを持つ 実装が簡潔。


4) バリアントとステージ設定(表1の実装写経)

FaceLiVT‑S / M と、MHLA置換版 S‑(Li) / M‑(Li)。サイズは 112×112 入力を前提。各行は「解像度 / チャネル / ブロック数 / ミキサ」を示す:

  • Stem:3×3 s=2 ×2、出力C=S:40, M:64

  • Stage1(28×28):C=S:40, M:64RepMixBlocks=2

  • Stage2(14×14):Down=RepMix 3×3 s=2、C=S:80, M:128RepMixBlocks=4

  • Stage3(7×7):Down=RepMix 3×3 s=2、C=S:160, M:256S/MはMHSA, S-(Li)/M-(Li)はMHLABlocks=6

  • Stage4(4×4):Down=RepMix 3×3 s=2、C=S:320, M:512、同上、Blocks=2

  • Head:AvgPool → FC(512)

He(ヘッド数) はアブレーションで 8 or 16 を検討。16 は精度が上がり(CFP‑FP 94.6%)、8 はレイテンシ短縮(0.41ms)。表4。


5) 学習レシピ(そのまま再現可能)

  • データGlint360K(112×112、テンソル化、[-1,1]正規化)。分散学習。

  • 損失CosFace。最後のFC(512)出力をPartialFCと併用。

  • 最適化AdamW, LR=6e‑3, Polynomial decayBatch=25620 or 40 エポック。埋め込み 512

  • 評価LFW / CFP‑FP / AgeDB‑30 / IJB‑B / IJB‑CiPhone 15 Pro(Core ML) でレイテンシ計測。


6) 主要結果(モバイル実用バランス)

  • FaceLiVT‑M‑(Li)9.75M Param / 386M FLOP、LFW 99.7–99.8CFP‑FP 96.0–97.2AgeDB 96.7–97.6IJB‑C 94.1–95.70.67ms

  • FaceLiVT‑S‑(Li)5.05M Param / 160M FLOP、LFW 99.6–99.7CFP‑FP 94.6–95.1AgeDB 95.6–96.6IJB‑C 82.5–92.70.47ms

  • 速度優位:EdgeFace‑XS(0.6) 比 8.6×、純ViT系比 21.2× 高速(本文・表2、抄録)。


7) アブレーション(実装判断に効く要点)

  • 再パラメータ化の効果(S‑(Li)):

    • fused BNなし0.47→0.60ms と悪化、

    • 残差の再パラメータ化なし0.47→0.50ms

    • DWC 1×1除去 → 精度低下(CFP‑FP −1.4pt, AgeDB −1.0pt)。表3。

  • He=8 vs 16:He=16 で精度↑(CFP‑FP 94.6%)、He=8 で速度↑(0.41ms)。表4。


8) 最小実装スケルトン(PyTorch風・擬似コード)

※形だけ把握できるように簡略化。実環境では SyncBN / AMP / DDP / PartialFC などを追加してください。MHLAはステージごとにN(トークン数)固定の重みを持つ前提で実装しています。

import torch import torch.nn as nn import torch.nn.functional as F # ---------- 共通 ---------- def dw_conv(ch, k=3, s=1, padding=None): if padding is None: padding = k // 2 return nn.Conv2d(ch, ch, k, stride=s, padding=padding, groups=ch, bias=True) # bias=True: BN融合前提 class ChannelMLP(nn.Module): # 1x1 Conv -> BN -> GELU -> 1x1 Conv -> BN def __init__(self, ch, r=3): super().__init__() mid = ch * r self.fc1 = nn.Conv2d(ch, mid, 1, bias=True) self.bn1 = nn.BatchNorm2d(mid) self.fc2 = nn.Conv2d(mid, ch, 1, bias=True) self.bn2 = nn.BatchNorm2d(ch) def forward(self, x): y = self.bn1(self.fc1(x)) y = F.gelu(y) y = self.bn2(self.fc2(y)) return x + y # 残差(式(1))
# ---------- RepMix(学習時) ---------- class RepMix(nn.Module): def __init__(self, ch, k=3, stride=1): super().__init__() self.dw_k = dw_conv(ch, k=k, s=stride) self.dw_1 = dw_conv(ch, k=1, s=stride, padding=0) self.bn = nn.BatchNorm2d(ch) # 推論時に切替 self.reparam = False self.fused_dw = None def forward(self, x): if self.reparam: return x + self.fused_dw(x) # 既にBN融合&分岐結合済み y = self.dw_k(x) + self.dw_1(x) y = self.bn(y) return x + y @torch.no_grad() def fuse_for_inference(self): """ 1) dw_k, dw_1 のBN融合 → (W', b') 2) 1x1分岐と恒等(残差)をk×kへ埋め込み、k×kへ合算 3) 最終的に単一DWCに置換 """ # (a) Conv+BN の融合: W' = W*γ/σ, b' = (b-μ)*γ/σ + β def fuse_conv_bn(conv, bn): W = conv.weight b = conv.bias if conv.bias is not None else torch.zeros(W.size(0), device=W.device) gamma = bn.weight; beta = bn.bias mean = bn.running_mean; var = bn.running_var; eps = bn.eps std = torch.sqrt(var + eps) Wf = W * (gamma / std).reshape(-1, 1, 1, 1) bf = (b - mean) * (gamma / std) + beta return Wf, bf # dw_k と dw_1 を個別にBN融合(同一BNを通るので厳密には和を取ってから融合でも良いが、実装は合算等価に調整) Wk, bk = self.dw_k.weight.data.clone(), self.dw_k.bias.data.clone() W1, b1 = self.dw_1.weight.data.clone(), self.dw_1.bias.data.clone() # BNパラメータを各分岐に適用するため、一旦和を想定しつつ「総和に対するBN融合」に近づける実装が必要。 # 実運用では、(dw_k + dw_1) を1つの擬似Convとして重み合算→そのConvにBN融合、が簡潔。 # ここでは簡易に「擬似合算Conv」を作る: k = self.dw_k.kernel_size[0] pad_k = 0 # 1x1をk×kへ埋め込み W1_expanded = torch.zeros_like(Wk) center = k // 2 W1_expanded[:, :, center:center+1, center:center+1] = W1 # 合算(残差=恒等は畳み込みに埋め込まない。ブロック外で x + ... を維持) W_sum = Wk + W1_expanded b_sum = bk + b1 # 合算Convに対して BN を融合 gamma, beta = self.bn.weight.data, self.bn.bias.data mean, var, eps = self.bn.running_mean, self.bn.running_var, self.bn.eps std = torch.sqrt(var + eps) Wf = W_sum * (gamma / std).reshape(-1, 1, 1, 1) bf = (b_sum - mean) * (gamma / std) + beta fused = dw_conv(Wf.size(0), k=k, s=self.dw_k.stride[0]) fused.weight.data.copy_(Wf); fused.bias.data.copy_(bf) self.fused_dw = fused self.reparam = True # 不要モジュールを無効化 del self.dw_k, self.dw_1, self.bn
# ---------- MHLA(ステージ毎にN固定) ---------- class MHLA(nn.Module): """ トークン次元(N=H*W)に2層線形をかける。ヘッド毎にW_i, W_oを持つ。 Nr = int(N * r) 程度(rは表記上の拡張率)。He=8/16など。 """ def __init__(self, C, H, W, He=16, r=2): super().__init__() self.C, self.H, self.W = C, H, W self.N = H * W self.He = He self.C_h = C // He Nr = int(self.N * r) # (He, N, Nr) / (He, Nr, N) self.Wi = nn.Parameter(torch.randn(He, self.N, Nr) * (self.N ** -0.5)) self.Wo = nn.Parameter(torch.randn(He, Nr, self.N) * (Nr ** -0.5)) def forward(self, x): # x: (B, C, H, W) B, C, H, W = x.shape assert (C == self.C) and (H == self.H) and (W == self.W) N = self.N x = x.view(B, C, N).view(B, self.He, self.C_h, N) # (B, He, C_h, N) # bmmでトークン次元に線形を適用: (B, He, C_h, N) @ (He, N, Nr) -> (B, He, C_h, Nr) Y = torch.einsum('bhcn,hnm->bhcm', x, self.Wi) Y = F.gelu(Y) Z = torch.einsum('bhcm,hmn->bhcn', Y, self.Wo) # (B, He, C_h, N) Z = Z.contiguous().view(B, C, N).view(B, C, H, W) return Z
# ---------- FaceLiVTブロック ---------- class FaceLiVTBlock(nn.Module): def __init__(self, mixer, ch, mlp_ratio=3): super().__init__() self.mixer = mixer # RepMix か MHLA self.mlp = ChannelMLP(ch, r=mlp_ratio) def forward(self, x): x = x + self.mixer(x) # 式(1)の前半 x = self.mlp(x) # 式(1)の後半(残差はChannelMLP内で加算) return x
# ---------- ネットワーク骨格(S-(Li)例) ---------- class FaceLiVT_S_Li(nn.Module): def __init__(self, He=16): super().__init__() # Stem self.stem = nn.Sequential( nn.Conv2d(3, 40, 3, 2, 1, bias=True), nn.BatchNorm2d(40), nn.GELU(), nn.Conv2d(40, 40, 3, 2, 1, bias=True), nn.BatchNorm2d(40), nn.GELU() ) # 112->56->28 # Stage1: 28x28, C=40, RepMix x2 self.s1 = nn.Sequential(*[ FaceLiVTBlock(RepMix(40, k=3, stride=1), 40) for _ in range(2) ]) # Down to 14x14 self.down2 = RepMix(40, k=3, stride=2) # Stage2: 14x14, C=80, RepMix x4 self.s2_proj = nn.Conv2d(40, 80, 1, bias=True) self.s2 = nn.Sequential(*[ FaceLiVTBlock(RepMix(80, k=3, stride=1), 80) for _ in range(4) ]) # Down to 7x7 self.down3 = RepMix(80, k=3, stride=2) # Stage3: 7x7, C=160, MHLA x6 self.s3_proj = nn.Conv2d(80, 160, 1, bias=True) self.s3 = nn.Sequential(*[ FaceLiVTBlock(MHLA(160, 7, 7, He=He, r=2), 160) for _ in range(6) ]) # Down to 4x4 self.down4 = RepMix(160, k=3, stride=2) # Stage4: 4x4, C=320, MHLA x2 self.s4_proj = nn.Conv2d(160, 320, 1, bias=True) self.s4 = nn.Sequential(*[ FaceLiVTBlock(MHLA(320, 4, 4, He=He, r=2), 320) for _ in range(2) ]) self.head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(320, 512, bias=True) # CosFace学習で使用 ) def forward(self, x): x = self.stem(x) # 28x28, C=40 x = self.s1(x) x = self.down2(x); x = self.s2_proj(x) # 14x14, C=80 x = self.s2(x) x = self.down3(x); x = self.s3_proj(x) # 7x7, C=160 x = self.s3(x) x = self.down4(x); x = self.s4_proj(x) # 4x4, C=320 x = self.s4(x) emb = self.head(x) # 512-d return emb @torch.no_grad() def reparam_for_inference(self): # RepMixの再パラメータ化(推論高速化) for m in self.modules(): if isinstance(m, RepMix): m.fuse_for_inference()

9) 推論最適化・移植のポイント

  • 学習→推論への切り替え時に、RepMixの fuse_for_inference() を必ず実行(BN融合+分岐統合)。表3の通り、fused BN はレイテンシ短縮に効く(0.47→0.60ms防止)。

  • MHLAステージ解像度依存(N固定) の重みを持つため、Core ML / TensorRT 変換前に形状を固定

  • iPhone 15 Pro(Core ML) の計測では、S-(Li)で0.47msM-(Li)で0.67ms。エッジ展開では INT8/FP16 量子化も候補。


10) 評価・再現用チェックリスト

  • 112×112, [-1,1] 正規化

  • CosFace + PartialFC, 埋め込み512

  • AdamW, LR=6e‑3, poly decay, batch=256, 20/40ep

  • S-(Li)Param≈5.05M / FLOP≈160M / LFW≈99.6–99.7 / CFP‑FP≈94.6–95.1 / AgeDB≈95.6–96.6 / IJB‑C≈82.5–92.7 / 0.47ms

  • M-(Li)Param≈9.75M / FLOP≈386M / LFW≈99.7–99.8 / CFP‑FP≈96.0–97.2 / AgeDB≈96.7–97.6 / IJB‑C≈94.1–95.7 / 0.67ms

  • He=16で精度優先、He=8で速度優先(表4)。


11) まとめ(実装の勘所)

  • 前段はRepMix:DWC(k×k)+DWC(1×1)→BN→残差。推論では単一DWCへ折り畳む(BN融合を忘れない)。

  • 後段はMHLAトークン次元に線形×2+GELUNはステージで固定MHSA不要で長距離関係を近似しつつ高速。

  • ハイパラr(MLP)=3, He=8/16, S/Mは表1のチャンネル/ブロック数 を踏襲。

  • 実運用Core ML変換再パラメータ化済み重みで、EdgeFaceやViT系より高速

図1(p.2)のブロック図、表1(p.4)の設定、表2–4(p.5)の数値をそのままコードに落とし込めば、学習〜推論まで一気通貫で再現できます。

――以上。

0 件のコメント:

コメントを投稿

あいうえお