요약
Attention Pooling은 시퀀스의 각 요소에 학습된 가중치를 부여하여 하나의 벡터로 집계하는 방법이다. Mean/Max Pooling과 달리 중요한 요소에 더 높은 가중치를 동적으로 할당할 수 있다.
본문
기본 구조
Input: H = [h₁, h₂, ..., hₙ] (n개의 hidden states, 각 d차원)
Output: c (context vector, d차원)
학습 과정
1. Attention Score 계산
eᵢ = v^T · tanh(W · hᵢ + b)
W(d × d): 학습 가능한 가중치 행렬b(d): 학습 가능한 biasv(d): 학습 가능한 context vector
2. Softmax로 정규화
αᵢ = softmax(eᵢ) = exp(eᵢ) / Σⱼ exp(eⱼ)
3. 가중합으로 집계
c = Σᵢ αᵢ · hᵢ
학습이 일어나는 방식
- Forward Pass: 현재 W, b, v로 attention weight α 계산 → context vector c 생성
- Loss 계산: downstream task (분류, 생성 등)의 loss 계산
- Backward Pass: loss에 대한 gradient가 c → α → (W, b, v)로 역전파
- 파라미터 업데이트: W, b, v가 task에 유용한 정보에 높은 가중치를 주도록 학습
핵심: "어떤 hidden state가 중요한가"를 end-to-end로 학습한다. 정답 레이블이 직접 attention을 supervise하지 않아도, task loss를 통해 간접적으로 학습된다.
Mean/Max Pooling과 비교
| 방식 | 학습 파라미터 | 특징 |
|---|---|---|
| Mean Pooling | 없음 | 모든 요소 동등 취급 |
| Max Pooling | 없음 | 가장 큰 값만 선택 |
| Attention Pooling | W, b, v | 중요도 기반 가중합 |
코드 예시 (PyTorch)
class AttentionPooling(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.W = nn.Linear(hidden_dim, hidden_dim)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, H, mask=None):
# H: (batch, seq_len, hidden_dim)
scores = self.v(torch.tanh(self.W(H))) # (batch, seq_len, 1)
scores = scores.squeeze(-1) # (batch, seq_len)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len)
context = torch.bmm(attn_weights.unsqueeze(1), H).squeeze(1)
return context, attn_weights
참고
- Self-Attention과 다름: Self-Attention은 Query가 입력에서 오지만, Attention Pooling은 학습된 Query(v) 사용
- 문서 분류, 감성 분석 등 시퀀스를 하나의 벡터로 압축해야 할 때 유용