Flash Attention

IO-Aware Tiling과 Online Softmax를 통한 메모리 벽 돌파

1. Introduction: Compute-bound인가, Memory-bound인가?

Transformer 모델의 최적화를 논할 때 가장 먼저 던져야 할 질문은 "무엇이 모델을 느리게 만드는가?"입니다. 거대한 행렬의 곱셈(MatMul)은 GPU의 막강한 병렬 처리 능력(Tensor Cores) 덕분에 연산기 자체의 속도 한계(Compute-bound)에 부딪히는 경우가 적습니다.

진짜 문제는 Memory-bound입니다. 연산기가 아무리 빨라도 데이터를 담고 있는 크고 느린 메모리(HBM)에서 작고 빠른 캐시 메모리(SRAM)로 데이터를 퍼 나르는 대역폭(Bandwidth)이 부족하여 연산기가 노는 현상이 발생합니다. FlashAttention은 연산량(FLOPs)을 줄이는 알고리즘이 아닙니다. 오히려 연산량을 조금 늘리더라도 메모리 접근 횟수(Memory Accesses)를 극단적으로 줄여(IO-Aware) 전체 속도를 비약적으로 끌어올린 혁명적인 시스템 알고리즘입니다.


2. 하드웨어적 비효율성: Standard Attention의 O(N2)O(N^2) IO 병목

표준 어텐션(Standard Attention)의 수식을 GPU 메모리 계층 관점에서 해부해 봅시다.

Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V

이 수식을 처리하기 위해 GPU는 다음과 같은 미련한 과정(Materialization)을 거칩니다.

  1. HBM에서 Q,KQ, K를 SRAM으로 읽어와 QKTQK^T를 계산하고, 그 결과인 N×NN \times N 크기의 거대한 Score 행렬 SS다시 HBM에 씁니다(Write).
  2. HBM에서 SS를 다시 읽어와(Read) Softmax를 적용한 행렬 PP를 만들고 또 HBM에 씁니다.
  3. 마지막으로 HBM에서 PPVV를 읽어와 최종 결과 행렬 OO를 계산하고 HBM에 씁니다.

시퀀스 길이 NN이 길어지면 이 N×NN \times N 크기의 중간 행렬(S,PS, P)을 HBM에 썼다가 지우는, 즉 HBM에 N2N^2 크기의 중간 결과물을 물리적으로 생성(Materialization) 하는 I/O(Input/Output) 트래픽이 기하급수적으로 폭발합니다. 이것이 긴 문맥(Long Context)을 처리할 때 발생하는 OOM(Out of Memory) 과 속도 저하의 주범입니다.


3. 해결책 1: Tiling (타일링) 기법 - SRAM 안에서 끝내기

FlashAttention의 첫 번째 핵심 아이디어는 Tiling(블록 단위 쪼개기) 입니다. 거대한 N×NN \times N 행렬을 한 번에 계산하는 대신, Q,K,VQ, K, V 행렬을 GPU의 작은 SRAM 용량(약 20MB 내외)에 쏙 들어갈 수 있는 크기의 블록(Block) 단위로 나눕니다.

  1. HBM에서 Q,K,VQ, K, V의 작은 블록들을 SRAM으로 한 번만 가져옵니다.
  2. SRAM 내부에서 해당 블록 간의 연산(MatMul \rightarrow Softmax \rightarrow MatMul)을 끝까지 수행합니다.
  3. 중간 행렬(S,PS, P)은 HBM에 절대 쓰지 않고 버리며, 오직 최종 결과 행렬 OO의 블록만 HBM에 씁니다.

FlashAttention은 N2N^2 행렬을 HBM에 쓰는 대신, SRAM 내에서 Fusion(연산 통합) 하여 최종 결과물 OO만 HBM으로 보냅니다.
이 단순해 보이는 아이디어를 통해 HBM 접근 횟수를 O(N2)O(N^2)에서 블록 크기 MM에 반비례하는 O(N2d/M)O(N^2 d / M) 수준으로 비약적으로 감소시킵니다.


4. 해결책 2: Online Softmax의 수학적 기적

Tiling 기법을 구현하는 데 있어 가장 거대한 수학적 장벽이 존재합니다. 바로 Softmax의 특성입니다. 표준 Softmax 수식은 다음과 같습니다.

softmax(x)i=exij=1Nexj\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}

분모를 계산하려면 해당 행(Row)의 모든 원소를 한 번에 다 봐야만 합니다. 하지만 Tiling은 데이터를 쪼개서 가져오므로 전체 합을 알 수 없습니다. 이를 수학적으로 해결한 것이 Online Softmax (Safe Softmax[지수 연산(exe^x) 시 발생할 수 있는 Overflow를 방지하기 위해 매 스텝 최댓값(mm)을 뺌]의 점진적 업데이트) 입니다.

FlashAttention은 전체 합을 기다리지 않고, 새로운 블록이 들어올 때마다 로컬 최대값(mm)과 지수합(\ell)을 점진적으로 갱신(Update)합니다.

두 개의 블록(이전 블록 1, 새로운 블록 2)이 있다고 가정할 때, 결합된 새로운 최대값과 지수합은 다음과 같이 업데이트됩니다.

mnew=max(m1,m2)m_{\text{new}} = \max(m_1, m_2)

new=em1mnew1+em2mnew2\ell_{\text{new}} = e^{m_1 - m_{\text{new}}} \ell_1 + e^{m_2 - m_{\text{new}}} \ell_2

이 수학적 트릭을 통해 이전 블록의 결과를 새로운 로컬 최대값에 맞춰 재조정(Rescaling)할 수 있습니다. 결과적으로, 모든 데이터를 한 번에 보지 않고 부분합만으로 연산했음에도 불구하고, 근사치(Approximation)가 아닌 표준 Attention과 소수점 아래까지 완벽히 동일한(Exact) 결과를 산출합니다.


5. Code Snippets

5.1 standard vs online softmax

import torch

# 1. Standard Softmax (전체 데이터를 한 번에 봐야 함)
def standard_softmax(x):
    # 전체 행에서 최댓값을 찾아야 함 -> 전역적 데이터 접근 발생
    m = torch.max(x, dim=-1, keepdim=True).values
    # 전체 행의 지수 합을 구해야 함 -> HBM Read/Write 반복
    exp_x = torch.exp(x - m)
    s = torch.sum(exp_x, dim=-1, keepdim=True)
    return exp_x / s

# 2. Online Softmax (데이터를 블록 단위로 쪼개서 점진적 업데이트)
# FlashAttention의 핵심: 전체를 다 보지 않고도 결과를 확정할 수 있음
def online_softmax_block_update(x_block, prev_m, prev_l):
    """
    x_block: 현재 처리 중인 데이터 블록
    prev_m: 이전 블록까지의 최댓값 (m_i-1)
    prev_l: 이전 블록까지의 지수 합 (l_i-1)
    """
    # 1. 현재 블록의 최댓값 계산
    curr_m = torch.max(x_block)
    
    # 2. 새로운 전역 최댓값 갱신
    new_m = torch.max(prev_m, curr_m)
    
    # 3. Rescale factor 계산 (이전 값들을 새로운 최댓값 기준으로 보정)
    # 이 수식이 있어 데이터가 쪼개져 있어도 수학적 정확성이 유지됨
    # l_new = e^(m_prev - m_new) * l_prev + e^(m_curr - m_new) * l_curr
    # 여기서 l_curr는 현재 블록의 exp 합: sum(exp(x_block - new_m))
    new_l = torch.exp(prev_m - new_m) * prev_l + torch.sum(torch.exp(x_block - new_m))
    
    return new_m, new_l

# 비교:
# Standard는 N이 커지면 메모리 부족(OOM)이 발생하지만,
# Online은 블록 단위로 연산기에 로드하여 처리하므로 메모리 I/O를 획기적으로 줄임

5.2 flashattention kernel (Triton)

## 5.2 FlashAttention Kernel: Forward & Backward 

# [Forward] 학습 시 통계량(L, M)을 저장하여 O(N^2) 메모리 점유를 O(N)으로 방지
@triton.jit
def flash_attn_forward_kernel(Q, K, V, L, M, Out):
    pid = tl.program_id(0)
    q_tile = tl.load(Q + pid * BLOCK_SIZE)
    
    # Online Softmax를 이용해 점진적으로 결과를 계산
    for start_n in range(0, seq_len, BLOCK_SIZE):
        k_tile = tl.load(K + start_n)
        v_tile = tl.load(V + start_n)
        
        # 1. 현재 블록의 어텐션 스코어 계산 (SRAM 내 MatMul)
        qk = tl.dot(q_tile, tl.trans(k_tile)) # S_curr
        
        # 2. 현재 블록의 통계량 계산
        m_curr = tl.max(qk, axis=1)
        p_curr = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p_curr, axis=1)
        
        # 3. 새로운 전역 최댓값 갱신
        m_new = tl.maximum(m_prev, m_curr)
        
        # 4. [핵심] Rescale Factor 계산
        # 이전 값들을 새로운 최댓값 m_new 기준으로 '축소' 보정함
        alpha = tl.exp(m_prev - m_new)
        beta = tl.exp(m_curr - m_new)
        
        # 5. 결과(Output)와 지수합(l) 업데이트
        # 이전 출력값(acc)에 alpha를 곱해 보정하고 새로운 값을 더함
        acc = acc * alpha[:, None] + tl.dot(beta[:, None] * p_curr, v_tile)
        l_new = alpha * l_prev + beta * l_curr
        
        # 6. 다음 루프를 위해 통계량 갱신
        m_prev = m_new
        l_prev = l_new
        
    # Backward Pass 재계산을 위해 Softmax 통계량(m, l)만 HBM에 저장
    tl.store(L + pid, l_final) 
    tl.store(M + pid, m_final)
    tl.store(Out + pid, acc)

# [Backward] 저장된 통계량으로 SRAM 내에서 즉석 재계산(Recomputation) 수행
@triton.jit
def flash_attn_backward_kernel(Q, K, V, L, M, dOut, dQ, dK, dV):
    # HBM에서 거대한 P 행렬을 읽어오는 대신, 작은 통계량(L, M)만 로드
    l = tl.load(L + pid)
    m = tl.load(M + pid)
    
    # [Recomputation] Q, K를 다시 읽어와 SRAM 내부에서 P 행렬을 즉석 재건축
    # 재계산 비용이 HBM I/O 비용보다 훨씬 저렴함 (Memory-bound 해결)
    p_recomputed = tl.exp(tl.dot(q_tile, k_tile.T) - m) / l
    
    # 재계산된 P를 이용해 Gradient 계산
    # dV = P^T * dOut, dP = dOut * V^T ...
    ...

FlashAttention은 메모리 절약을 위해 Backward Pass에서 사용할 중간 행렬(S,PS, P)을 저장하지 않습니다. 대신 Forward 때 썼던 Online Softmax 통계량(m,m, \ell)만 저장하고, Backward 때 이를 이용해 SRAM에서 즉석 재계산(Recomputation) 을 수행하여 메모리 사용량을 O(N)O(N)으로 유지합니다

6. 결론 및 시스템적 파급 효과

FlashAttention은 인공지능 연구가 단순히 수학적 모델링에 머물러서는 안 되며, 코드가 실행되는 물리적 하드웨어(GPU)의 아키텍처를 이해해야만(Hardware-Aware) 진정한 혁신을 이룰 수 있음을 증명한 기념비적인 연구입니다.

  • 속도 및 메모리: IO 비용을 극단적으로 줄여 훈련/추론 속도를 2~4배 향상시켰으며, N2N^2 메모리 요구량을 선형 수준(O(N)O(N))으로 감소시켜 32K,128K32K, 128K 이상의 긴 문맥 처리(Long-context LLM) 시대를 열었습니다.
  • 통합: 현재 PyTorch F.scaled_dot_product_attention 백엔드의 기본값으로 채택되어 현대 LLM 인프라의 표준이 되었습니다.

한계: 이러한 IO 최적화에도 불구하고 디코딩(Serving) 단계에서 발생하는 KV Cache의 불규칙한 메모리 단편화(Fragmentation) 문제는 여전히 남아있습니다.