Zettelkasten

Attention Pooling은 학습 가능한 가중치로 시퀀스를 집계한다

·수정 2026.04.23·수정 2

요약

Attention Pooling은 시퀀스의 각 요소에 학습된 가중치를 부여하여 하나의 벡터로 집계하는 방법이다. Mean/Max Pooling과 달리 중요한 요소에 더 높은 가중치를 동적으로 할당할 수 있다.

본문

기본 구조

Input: H = [h₁, h₂, ..., hₙ]  (n개의 hidden states, 각 d차원)
Output: c (context vector, d차원)

학습 과정

1. Attention Score 계산

eᵢ = v^T · tanh(W · hᵢ + b)
  • W (d × d): 학습 가능한 가중치 행렬
  • b (d): 학습 가능한 bias
  • v (d): 학습 가능한 context vector

2. Softmax로 정규화

αᵢ = softmax(eᵢ) = exp(eᵢ) / Σⱼ exp(eⱼ)

3. 가중합으로 집계

c = Σᵢ αᵢ · hᵢ

학습이 일어나는 방식

  1. Forward Pass: 현재 W, b, v로 attention weight α 계산 → context vector c 생성
  2. Loss 계산: downstream task (분류, 생성 등)의 loss 계산
  3. Backward Pass: loss에 대한 gradient가 c → α → (W, b, v)로 역전파
  4. 파라미터 업데이트: W, b, v가 task에 유용한 정보에 높은 가중치를 주도록 학습

핵심: "어떤 hidden state가 중요한가"를 end-to-end로 학습한다. 정답 레이블이 직접 attention을 supervise하지 않아도, task loss를 통해 간접적으로 학습된다.

Mean/Max Pooling과 비교

방식 학습 파라미터 특징
Mean Pooling 없음 모든 요소 동등 취급
Max Pooling 없음 가장 큰 값만 선택
Attention Pooling W, b, v 중요도 기반 가중합

코드 예시 (PyTorch)

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, H, mask=None):
        # H: (batch, seq_len, hidden_dim)
        scores = self.v(torch.tanh(self.W(H)))  # (batch, seq_len, 1)
        scores = scores.squeeze(-1)  # (batch, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len)
        context = torch.bmm(attn_weights.unsqueeze(1), H).squeeze(1)
        return context, attn_weights

참고

  • Self-Attention과 다름: Self-Attention은 Query가 입력에서 오지만, Attention Pooling은 학습된 Query(v) 사용
  • 문서 분류, 감성 분석 등 시퀀스를 하나의 벡터로 압축해야 할 때 유용