記念写真
https://arxiv.org/pdf/2506.10361
以下は、実装者向けに最短距離でコード化できるよう整理した“FaceLiVT”の完全要約です。図表や式の参照も付けています。実装の核は RepMix(構造的再パラメータ化付きの深層可分畳み込み) と MHLA(Multi‑Head Linear Attention) の2点です。図1(p.2)にブロック構造、表1(p.4)に各ステージ設定、表2–4(p.5)に結果とアブレーションの要点がまとまっています。
FaceLiVT は、CNN×Transformerのハイブリッドで、前段=RepMix、後段=MHLA をトークンミキサとして組み合わせ、モバイル推論向けにレイテンシと精度の両立を図る設計。図1参照。
MHLA は自己注意(MHSA)の 二乗計算を避け、トークン次元に2層の線形(+非線形) をかける設計。計算量は Ω(MHLA)=2(N·Nr)·C と低く、MHSAの Ω=4NC²+2N²C より小さい(§3.4)。
構造的再パラメータ化(BN融合・分岐結合・残差の折り畳み)で、学習時は表現力、推論時は単一DWCへ畳み込み、Core ML等で高速化(§3.2)。
入力:112×112 のアライン済み顔画像(Glint360Kで学習時に使用)。値域は [-1,1] 正規化。
Stem:3×3 s=2 畳み込み ×2(112→56→28)。図1、表1。
ステージ数:4(Stage1/2=RepMix、Stage3/4=MHLA 版ではMHLA)。各ブロックは TokenMixer → 残差加算 → Channel MLP → 残差加算(式(1))。
Channel MLP:2層全結合(拡張比 r=3)、BN、GELU 1回(式(2))。実装は 1×1 Conv×2+BN+GELU でOK。
ヘッド:Global AvgPool → FC(512) で埋め込み512次元(CosFace学習用)。表1。
DWC k×k と DWC 1×1 を足し合わせ、その後 BN、元入力を残差接続:
X' = X + BN( DWC_k×k(X) + DWC_1×1(X) )
(非線形は省略)。図1(b)、式(6)。
(a) BN融合:Conv→BN を 重みW' = W·γ/σ, バイアス b' = (b−μ)·γ/σ + β で単一Conv化(式(3)(4))。
(b) 分岐の畳み込み:
1×1 DWC と 恒等(残差)を k×kのDWCカーネル中央に埋め込み、k×k分岐に合算。最終的に 単一のDWC(k×k)+(a)でBN融合済み になる。図1(b)、§3.2。
入力:X∈ℝ^{B×C×H×W} → N=H·Wの1Dトークンへ並べ X∈ℝ^{B×C×N}。Heads=He に チャネル分割:Xᵍ∈ℝ^{B×(C/He)×N}。図1(c)、§3.4。
各ヘッドはトークン次元で 2層線形 + GELU:
Y = GELU( Xᵍ · Wᵢ )
(Wᵢ∈ℝ^{N×Nr}), Z = Y · Wₒ
(Wₒ∈ℝ^{Nr×N})。ヘッド結合で出力(式(10))。Q/K/Vやsoftmaxは使わない。
計算量:Ω=2·N·Nr·C(MHSAより小)。Nがステージで固定(例:Stage3=7×7=49、Stage4=4×4=16)なので、ステージごとに(N, Nr, He)を固定して重みを持つ 実装が簡潔。
FaceLiVT‑S / M と、MHLA置換版 S‑(Li) / M‑(Li)。サイズは 112×112 入力を前提。各行は「解像度 / チャネル / ブロック数 / ミキサ」を示す:
Stem:3×3 s=2 ×2、出力C=S:40, M:64
Stage1(28×28):C=S:40, M:64、RepMix、Blocks=2
Stage2(14×14):Down=RepMix 3×3 s=2、C=S:80, M:128、RepMix、Blocks=4
Stage3(7×7):Down=RepMix 3×3 s=2、C=S:160, M:256、S/MはMHSA, S-(Li)/M-(Li)はMHLA、Blocks=6
Stage4(4×4):Down=RepMix 3×3 s=2、C=S:320, M:512、同上、Blocks=2
Head:AvgPool → FC(512)
He(ヘッド数) はアブレーションで 8 or 16 を検討。16 は精度が上がり(CFP‑FP 94.6%)、8 はレイテンシ短縮(0.41ms)。表4。
データ:Glint360K(112×112、テンソル化、[-1,1]正規化)。分散学習。
損失:CosFace。最後のFC(512)出力をPartialFCと併用。
最適化:AdamW, LR=6e‑3, Polynomial decay。Batch=256。20 or 40 エポック。埋め込み 512。
評価:LFW / CFP‑FP / AgeDB‑30 / IJB‑B / IJB‑C。iPhone 15 Pro(Core ML) でレイテンシ計測。
FaceLiVT‑M‑(Li):9.75M Param / 386M FLOP、LFW 99.7–99.8、CFP‑FP 96.0–97.2、AgeDB 96.7–97.6、IJB‑C 94.1–95.7、0.67ms。
FaceLiVT‑S‑(Li):5.05M Param / 160M FLOP、LFW 99.6–99.7、CFP‑FP 94.6–95.1、AgeDB 95.6–96.6、IJB‑C 82.5–92.7、0.47ms。
速度優位:EdgeFace‑XS(0.6) 比 8.6×、純ViT系比 21.2× 高速(本文・表2、抄録)。
再パラメータ化の効果(S‑(Li)):
fused BNなし → 0.47→0.60ms と悪化、
残差の再パラメータ化なし → 0.47→0.50ms、
DWC 1×1除去 → 精度低下(CFP‑FP −1.4pt, AgeDB −1.0pt)。表3。
He=8 vs 16:He=16 で精度↑(CFP‑FP 94.6%)、He=8 で速度↑(0.41ms)。表4。
※形だけ把握できるように簡略化。実環境では SyncBN / AMP / DDP / PartialFC などを追加してください。MHLAはステージごとにN(トークン数)固定の重みを持つ前提で実装しています。
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------- 共通 ----------
def dw_conv(ch, k=3, s=1, padding=None):
if padding is None: padding = k // 2
return nn.Conv2d(ch, ch, k, stride=s, padding=padding, groups=ch, bias=True) # bias=True: BN融合前提
class ChannelMLP(nn.Module):
# 1x1 Conv -> BN -> GELU -> 1x1 Conv -> BN
def __init__(self, ch, r=3):
super().__init__()
mid = ch * r
self.fc1 = nn.Conv2d(ch, mid, 1, bias=True)
self.bn1 = nn.BatchNorm2d(mid)
self.fc2 = nn.Conv2d(mid, ch, 1, bias=True)
self.bn2 = nn.BatchNorm2d(ch)
def forward(self, x):
y = self.bn1(self.fc1(x))
y = F.gelu(y)
y = self.bn2(self.fc2(y))
return x + y # 残差(式(1))
# ---------- RepMix(学習時) ----------
class RepMix(nn.Module):
def __init__(self, ch, k=3, stride=1):
super().__init__()
self.dw_k = dw_conv(ch, k=k, s=stride)
self.dw_1 = dw_conv(ch, k=1, s=stride, padding=0)
self.bn = nn.BatchNorm2d(ch)
# 推論時に切替
self.reparam = False
self.fused_dw = None
def forward(self, x):
if self.reparam:
return x + self.fused_dw(x) # 既にBN融合&分岐結合済み
y = self.dw_k(x) + self.dw_1(x)
y = self.bn(y)
return x + y
@torch.no_grad()
def fuse_for_inference(self):
"""
1) dw_k, dw_1 のBN融合 → (W', b')
2) 1x1分岐と恒等(残差)をk×kへ埋め込み、k×kへ合算
3) 最終的に単一DWCに置換
"""
# (a) Conv+BN の融合: W' = W*γ/σ, b' = (b-μ)*γ/σ + β
def fuse_conv_bn(conv, bn):
W = conv.weight
b = conv.bias if conv.bias is not None else torch.zeros(W.size(0), device=W.device)
gamma = bn.weight; beta = bn.bias
mean = bn.running_mean; var = bn.running_var; eps = bn.eps
std = torch.sqrt(var + eps)
Wf = W * (gamma / std).reshape(-1, 1, 1, 1)
bf = (b - mean) * (gamma / std) + beta
return Wf, bf
# dw_k と dw_1 を個別にBN融合(同一BNを通るので厳密には和を取ってから融合でも良いが、実装は合算等価に調整)
Wk, bk = self.dw_k.weight.data.clone(), self.dw_k.bias.data.clone()
W1, b1 = self.dw_1.weight.data.clone(), self.dw_1.bias.data.clone()
# BNパラメータを各分岐に適用するため、一旦和を想定しつつ「総和に対するBN融合」に近づける実装が必要。
# 実運用では、(dw_k + dw_1) を1つの擬似Convとして重み合算→そのConvにBN融合、が簡潔。
# ここでは簡易に「擬似合算Conv」を作る:
k = self.dw_k.kernel_size[0]
pad_k = 0
# 1x1をk×kへ埋め込み
W1_expanded = torch.zeros_like(Wk)
center = k // 2
W1_expanded[:, :, center:center+1, center:center+1] = W1
# 合算(残差=恒等は畳み込みに埋め込まない。ブロック外で x + ... を維持)
W_sum = Wk + W1_expanded
b_sum = bk + b1
# 合算Convに対して BN を融合
gamma, beta = self.bn.weight.data, self.bn.bias.data
mean, var, eps = self.bn.running_mean, self.bn.running_var, self.bn.eps
std = torch.sqrt(var + eps)
Wf = W_sum * (gamma / std).reshape(-1, 1, 1, 1)
bf = (b_sum - mean) * (gamma / std) + beta
fused = dw_conv(Wf.size(0), k=k, s=self.dw_k.stride[0])
fused.weight.data.copy_(Wf); fused.bias.data.copy_(bf)
self.fused_dw = fused
self.reparam = True
# 不要モジュールを無効化
del self.dw_k, self.dw_1, self.bn
# ---------- MHLA(ステージ毎にN固定) ----------
class MHLA(nn.Module):
"""
トークン次元(N=H*W)に2層線形をかける。ヘッド毎にW_i, W_oを持つ。
Nr = int(N * r) 程度(rは表記上の拡張率)。He=8/16など。
"""
def __init__(self, C, H, W, He=16, r=2):
super().__init__()
self.C, self.H, self.W = C, H, W
self.N = H * W
self.He = He
self.C_h = C // He
Nr = int(self.N * r)
# (He, N, Nr) / (He, Nr, N)
self.Wi = nn.Parameter(torch.randn(He, self.N, Nr) * (self.N ** -0.5))
self.Wo = nn.Parameter(torch.randn(He, Nr, self.N) * (Nr ** -0.5))
def forward(self, x): # x: (B, C, H, W)
B, C, H, W = x.shape
assert (C == self.C) and (H == self.H) and (W == self.W)
N = self.N
x = x.view(B, C, N).view(B, self.He, self.C_h, N) # (B, He, C_h, N)
# bmmでトークン次元に線形を適用: (B, He, C_h, N) @ (He, N, Nr) -> (B, He, C_h, Nr)
Y = torch.einsum('bhcn,hnm->bhcm', x, self.Wi)
Y = F.gelu(Y)
Z = torch.einsum('bhcm,hmn->bhcn', Y, self.Wo) # (B, He, C_h, N)
Z = Z.contiguous().view(B, C, N).view(B, C, H, W)
return Z
# ---------- FaceLiVTブロック ----------
class FaceLiVTBlock(nn.Module):
def __init__(self, mixer, ch, mlp_ratio=3):
super().__init__()
self.mixer = mixer # RepMix か MHLA
self.mlp = ChannelMLP(ch, r=mlp_ratio)
def forward(self, x):
x = x + self.mixer(x) # 式(1)の前半
x = self.mlp(x) # 式(1)の後半(残差はChannelMLP内で加算)
return x
# ---------- ネットワーク骨格(S-(Li)例) ----------
class FaceLiVT_S_Li(nn.Module):
def __init__(self, He=16):
super().__init__()
# Stem
self.stem = nn.Sequential(
nn.Conv2d(3, 40, 3, 2, 1, bias=True),
nn.BatchNorm2d(40), nn.GELU(),
nn.Conv2d(40, 40, 3, 2, 1, bias=True),
nn.BatchNorm2d(40), nn.GELU()
) # 112->56->28
# Stage1: 28x28, C=40, RepMix x2
self.s1 = nn.Sequential(*[
FaceLiVTBlock(RepMix(40, k=3, stride=1), 40) for _ in range(2)
])
# Down to 14x14
self.down2 = RepMix(40, k=3, stride=2)
# Stage2: 14x14, C=80, RepMix x4
self.s2_proj = nn.Conv2d(40, 80, 1, bias=True)
self.s2 = nn.Sequential(*[
FaceLiVTBlock(RepMix(80, k=3, stride=1), 80) for _ in range(4)
])
# Down to 7x7
self.down3 = RepMix(80, k=3, stride=2)
# Stage3: 7x7, C=160, MHLA x6
self.s3_proj = nn.Conv2d(80, 160, 1, bias=True)
self.s3 = nn.Sequential(*[
FaceLiVTBlock(MHLA(160, 7, 7, He=He, r=2), 160) for _ in range(6)
])
# Down to 4x4
self.down4 = RepMix(160, k=3, stride=2)
# Stage4: 4x4, C=320, MHLA x2
self.s4_proj = nn.Conv2d(160, 320, 1, bias=True)
self.s4 = nn.Sequential(*[
FaceLiVTBlock(MHLA(320, 4, 4, He=He, r=2), 320) for _ in range(2)
])
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(320, 512, bias=True) # CosFace学習で使用
)
def forward(self, x):
x = self.stem(x) # 28x28, C=40
x = self.s1(x)
x = self.down2(x); x = self.s2_proj(x) # 14x14, C=80
x = self.s2(x)
x = self.down3(x); x = self.s3_proj(x) # 7x7, C=160
x = self.s3(x)
x = self.down4(x); x = self.s4_proj(x) # 4x4, C=320
x = self.s4(x)
emb = self.head(x) # 512-d
return emb
@torch.no_grad()
def reparam_for_inference(self):
# RepMixの再パラメータ化(推論高速化)
for m in self.modules():
if isinstance(m, RepMix):
m.fuse_for_inference()
学習→推論への切り替え時に、RepMixの fuse_for_inference() を必ず実行(BN融合+分岐統合)。表3の通り、fused BN はレイテンシ短縮に効く(0.47→0.60ms防止)。
MHLA は ステージ解像度依存(N固定) の重みを持つため、Core ML / TensorRT 変換前に形状を固定。
iPhone 15 Pro(Core ML) の計測では、S-(Li)で0.47ms、M-(Li)で0.67ms。エッジ展開では INT8/FP16 量子化も候補。
112×112, [-1,1] 正規化
CosFace + PartialFC, 埋め込み512
AdamW, LR=6e‑3, poly decay, batch=256, 20/40ep
S-(Li)=Param≈5.05M / FLOP≈160M / LFW≈99.6–99.7 / CFP‑FP≈94.6–95.1 / AgeDB≈95.6–96.6 / IJB‑C≈82.5–92.7 / 0.47ms
M-(Li)=Param≈9.75M / FLOP≈386M / LFW≈99.7–99.8 / CFP‑FP≈96.0–97.2 / AgeDB≈96.7–97.6 / IJB‑C≈94.1–95.7 / 0.67ms
He=16で精度優先、He=8で速度優先(表4)。
前段はRepMix:DWC(k×k)+DWC(1×1)→BN→残差。推論では単一DWCへ折り畳む(BN融合を忘れない)。
後段はMHLA:トークン次元に線形×2+GELU。Nはステージで固定。MHSA不要で長距離関係を近似しつつ高速。
ハイパラ:r(MLP)=3, He=8/16, S/Mは表1のチャンネル/ブロック数 を踏襲。
実運用:Core ML変換+再パラメータ化済み重みで、EdgeFaceやViT系より高速。
図1(p.2)のブロック図、表1(p.4)の設定、表2–4(p.5)の数値をそのままコードに落とし込めば、学習〜推論まで一気通貫で再現できます。
――以上。
あいうえお