Giải thích ArcFace Loss
Giải thích chi tiết ArcFace Loss
Công thức tổng quát
\[L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{s(\cos(\theta_{y_i}+m))}}{e^{s(\cos(\theta_{y_i}+m))}+\sum_{j=1,j\neq y_i}^{n}e^{s\cos\theta_j}}\]
Ý nghĩa từng thành phần
1. Các ký hiệu cơ bản
| Ký hiệu | Ý nghĩa |
|---|---|
| Tổng số mẫu trong batch | |
| Tổng số class (danh tính) | |
| Nhãn đúng của mẫu thứ | |
| Góc giữa feature vector và weight vector của class | |
| Scale factor (thường = 64) | |
| Angular margin (thường = 0.5 radian ~ 28.6°) |
2. Góc \( \theta \) là gì?
Trong ArcFace, cả feature vector \( x_i \) và weight vector \( W_j \) đều được normalize về độ dài 1:
\[\cos\theta_j = \frac{W_j^T \cdot x_i}{\|W_j\| \cdot \|x_i\|} = W_j^T \cdot x_i\]
Vì sau khi normalize, tích vô hướng chính là cosine của góc giữa 2 vector.
3. Tại sao cộng thêm margin \( m \)?
Đây là điểm mấu chốt của ArcFace. Hãy xem xét 2 trường hợp:
Softmax thông thường:
\[P(y_i) = \frac{e^{s\cos\theta_{y_i}}}{e^{s\cos\theta_{y_i}} + \sum_{j \neq y_i} e^{s\cos\theta_j}}\]
ArcFace:
\[P(y_i) = \frac{e^{s\cos(\theta_{y_i} + m)}}{e^{s\cos(\theta_{y_i} + m)} + \sum_{j \neq y_i} e^{s\cos\theta_j}}\]
Việc cộng thêm \( m \) vào góc \( \theta_{y_i} \) có nghĩa:
- \( \cos(\theta + m) < \cos(\theta) \) khi \( 0 < \theta < \pi - m \)
- Model phải học để feature gần class đúng hơn một khoảng margin \( m \)
4. Hình dung trực quan
Class 1 weight
↑
/|\
/ | \
/ | \
/ θ | \
/ ↓ | \
/←-m-→| \
/ | \
←───────┼───────→ Feature space
|
↓
Class 2 weight
Với Softmax thường, decision boundary nằm ở giữa 2 class. Với ArcFace, mỗi class phải "lùi lại" một góc \( m \), tạo ra vùng margin giữa các class.
5. Phân tích công thức từng bước
Bước 1: Tính góc giữa feature và mỗi class weight
\[\theta_j = \arccos(W_j^T \cdot x_i)\]
Bước 2: Với class đúng, cộng thêm margin
\[\theta_{y_i}^{new} = \theta_{y_i} + m\]
Bước 3: Tính logit cho mỗi class
\[\text{logit}_j = \begin{cases}s \cdot \cos(\theta_j + m) & \text{nếu } j = y_i \\s \cdot \cos\theta_j & \text{nếu } j \neq y_i\end{cases}\]
Bước 4: Áp dụng Softmax và Cross-Entropy
\[L_i = -\log\frac{e^{\text{logit}_{y_i}}}{\sum_j e^{\text{logit}_j}}\]
Bước 5: Trung bình qua batch
\[L = \frac{1}{N}\sum_{i=1}^{N} L_i\]
6. Vai trò của Scale factor \( s \)
Scale factor \( s \) giúp:
- Phóng đại sự khác biệt giữa các logit
- Làm sắc nét phân phối softmax
- Không có \( s \), gradient sẽ rất nhỏ vì \( \cos\theta \in [-1, 1] \)
Ví dụ với \( s = 64 \):\[e^{64 \times 0.9} \approx 10^{25} \gg e^{64 \times 0.5} \approx 10^{14}\]
7. So sánh với các phương pháp khác
| Phương pháp | Công thức logit cho class đúng |
|---|---|
| Softmax | |
| SphereFace | |
| CosFace | |
| ArcFace |
ArcFace cộng margin trực tiếp vào góc, tạo ra angular margin có ý nghĩa hình học rõ ràng nhất.
8. Tại sao ArcFace hiệu quả?
- Angular margin có ý nghĩa hình học trực tiếp trên hypersphere
- Consistent margin ở mọi góc (không như multiplicative margin của SphereFace)
- Dễ implement và stable trong training
- Feature vectors được phân bố đều trên hypersphere
9. Code minh họa (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ArcFaceLoss(nn.Module):
def __init__(self, s=64.0, m=0.5):
super().__init__()
self.s = s
self.m = m
def forward(self, logits, labels):
# logits = cos(theta), đã normalize
# logits shape: (batch_size, num_classes)
# Lấy cos(theta) của class đúng
cos_theta = logits.gather(1, labels.view(-1, 1)).squeeze()
# Tính theta từ cos(theta)
theta = torch.acos(torch.clamp(cos_theta, -1 + 1e-7, 1 - 1e-7))
# Cộng margin và tính cos(theta + m)
cos_theta_m = torch.cos(theta + self.m)
# Thay thế logit của class đúng
one_hot = F.one_hot(labels, num_classes=logits.size(1))
output = logits * (1 - one_hot) + cos_theta_m.unsqueeze(1) * one_hot
# Scale và tính cross-entropy
output = output * self.s
loss = F.cross_entropy(output, labels)
return loss
10. Tóm tắt
ArcFace Loss buộc model học các feature sao cho:
- Feature của cùng class gần nhau (intra-class compactness)
- Feature của khác class xa nhau (inter-class discrepancy)
- Khoảng cách giữa các class ít nhất là góc \( 2m \)
Điều này đặc biệt quan trọng trong face recognition, nơi cần phân biệt hàng triệu danh tính khác nhau.