1. Introduction: Compute-bound인가, Memory-bound인가?
Transformer 모델의 최적화를 논할 때 가장 먼저 던져야 할 질문은 "무엇이 모델을 느리게 만드는가?"입니다. 거대한 행렬의 곱셈(MatMul)은 GPU의 막강한 병렬 처리 능력(Tensor Cores) 덕분에 연산기 자체의 속도 한계(Compute-bound)에 부딪히는 경우가 적습니다.
진짜 문제는 Memory-bound입니다. 연산기가 아무리 빨라도 데이터를 담고 있는 크고 느린 메모리(HBM)에서 작고 빠른 캐시 메모리(SRAM)로 데이터를 퍼 나르는 대역폭(Bandwidth)이 부족하여 연산기가 노는 현상이 발생합니다. FlashAttention은 연산량(FLOPs)을 줄이는 알고리즘이 아닙니다. 오히려 연산량을 조금 늘리더라도 메모리 접근 횟수(Memory Accesses)를 극단적으로 줄여(IO-Aware) 전체 속도를 비약적으로 끌어올린 혁명적인 시스템 알고리즘입니다.
2. 하드웨어적 비효율성: Standard Attention의 IO 병목
표준 어텐션(Standard Attention)의 수식을 GPU 메모리 계층 관점에서 해부해 봅시다.
이 수식을 처리하기 위해 GPU는 다음과 같은 미련한 과정(Materialization)을 거칩니다.
- HBM에서 를 SRAM으로 읽어와 를 계산하고, 그 결과인 크기의 거대한 Score 행렬 를 다시 HBM에 씁니다(Write).
- HBM에서 를 다시 읽어와(Read) Softmax를 적용한 행렬 를 만들고 또 HBM에 씁니다.
- 마지막으로 HBM에서 와 를 읽어와 최종 결과 행렬 를 계산하고 HBM에 씁니다.
시퀀스 길이 이 길어지면 이 크기의 중간 행렬()을 HBM에 썼다가 지우는, 즉 HBM에 크기의 중간 결과물을 물리적으로 생성(Materialization) 하는 I/O(Input/Output) 트래픽이 기하급수적으로 폭발합니다. 이것이 긴 문맥(Long Context)을 처리할 때 발생하는 OOM(Out of Memory) 과 속도 저하의 주범입니다.
3. 해결책 1: Tiling (타일링) 기법 - SRAM 안에서 끝내기
FlashAttention의 첫 번째 핵심 아이디어는 Tiling(블록 단위 쪼개기) 입니다. 거대한 행렬을 한 번에 계산하는 대신, 행렬을 GPU의 작은 SRAM 용량(약 20MB 내외)에 쏙 들어갈 수 있는 크기의 블록(Block) 단위로 나눕니다.
- HBM에서 의 작은 블록들을 SRAM으로 한 번만 가져옵니다.
- SRAM 내부에서 해당 블록 간의 연산(MatMul Softmax MatMul)을 끝까지 수행합니다.
- 중간 행렬()은 HBM에 절대 쓰지 않고 버리며, 오직 최종 결과 행렬 의 블록만 HBM에 씁니다.
FlashAttention은 행렬을 HBM에 쓰는 대신, SRAM 내에서 Fusion(연산 통합) 하여 최종 결과물 만 HBM으로 보냅니다.
이 단순해 보이는 아이디어를 통해 HBM 접근 횟수를 에서 블록 크기 에 반비례하는 수준으로 비약적으로 감소시킵니다.
4. 해결책 2: Online Softmax의 수학적 기적
Tiling 기법을 구현하는 데 있어 가장 거대한 수학적 장벽이 존재합니다. 바로 Softmax의 특성입니다. 표준 Softmax 수식은 다음과 같습니다.
분모를 계산하려면 해당 행(Row)의 모든 원소를 한 번에 다 봐야만 합니다. 하지만 Tiling은 데이터를 쪼개서 가져오므로 전체 합을 알 수 없습니다. 이를 수학적으로 해결한 것이 Online Softmax (Safe Softmax[지수 연산() 시 발생할 수 있는 Overflow를 방지하기 위해 매 스텝 최댓값()을 뺌]의 점진적 업데이트) 입니다.
FlashAttention은 전체 합을 기다리지 않고, 새로운 블록이 들어올 때마다 로컬 최대값()과 지수합()을 점진적으로 갱신(Update)합니다.
두 개의 블록(이전 블록 1, 새로운 블록 2)이 있다고 가정할 때, 결합된 새로운 최대값과 지수합은 다음과 같이 업데이트됩니다.
이 수학적 트릭을 통해 이전 블록의 결과를 새로운 로컬 최대값에 맞춰 재조정(Rescaling)할 수 있습니다. 결과적으로, 모든 데이터를 한 번에 보지 않고 부분합만으로 연산했음에도 불구하고, 근사치(Approximation)가 아닌 표준 Attention과 소수점 아래까지 완벽히 동일한(Exact) 결과를 산출합니다.
5. Code Snippets
5.1 standard vs online softmax
import torch
# 1. Standard Softmax (전체 데이터를 한 번에 봐야 함)
def standard_softmax(x):
# 전체 행에서 최댓값을 찾아야 함 -> 전역적 데이터 접근 발생
m = torch.max(x, dim=-1, keepdim=True).values
# 전체 행의 지수 합을 구해야 함 -> HBM Read/Write 반복
exp_x = torch.exp(x - m)
s = torch.sum(exp_x, dim=-1, keepdim=True)
return exp_x / s
# 2. Online Softmax (데이터를 블록 단위로 쪼개서 점진적 업데이트)
# FlashAttention의 핵심: 전체를 다 보지 않고도 결과를 확정할 수 있음
def online_softmax_block_update(x_block, prev_m, prev_l):
"""
x_block: 현재 처리 중인 데이터 블록
prev_m: 이전 블록까지의 최댓값 (m_i-1)
prev_l: 이전 블록까지의 지수 합 (l_i-1)
"""
# 1. 현재 블록의 최댓값 계산
curr_m = torch.max(x_block)
# 2. 새로운 전역 최댓값 갱신
new_m = torch.max(prev_m, curr_m)
# 3. Rescale factor 계산 (이전 값들을 새로운 최댓값 기준으로 보정)
# 이 수식이 있어 데이터가 쪼개져 있어도 수학적 정확성이 유지됨
# l_new = e^(m_prev - m_new) * l_prev + e^(m_curr - m_new) * l_curr
# 여기서 l_curr는 현재 블록의 exp 합: sum(exp(x_block - new_m))
new_l = torch.exp(prev_m - new_m) * prev_l + torch.sum(torch.exp(x_block - new_m))
return new_m, new_l
# 비교:
# Standard는 N이 커지면 메모리 부족(OOM)이 발생하지만,
# Online은 블록 단위로 연산기에 로드하여 처리하므로 메모리 I/O를 획기적으로 줄임
5.2 flashattention kernel (Triton)
## 5.2 FlashAttention Kernel: Forward & Backward
# [Forward] 학습 시 통계량(L, M)을 저장하여 O(N^2) 메모리 점유를 O(N)으로 방지
@triton.jit
def flash_attn_forward_kernel(Q, K, V, L, M, Out):
pid = tl.program_id(0)
q_tile = tl.load(Q + pid * BLOCK_SIZE)
# Online Softmax를 이용해 점진적으로 결과를 계산
for start_n in range(0, seq_len, BLOCK_SIZE):
k_tile = tl.load(K + start_n)
v_tile = tl.load(V + start_n)
# 1. 현재 블록의 어텐션 스코어 계산 (SRAM 내 MatMul)
qk = tl.dot(q_tile, tl.trans(k_tile)) # S_curr
# 2. 현재 블록의 통계량 계산
m_curr = tl.max(qk, axis=1)
p_curr = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p_curr, axis=1)
# 3. 새로운 전역 최댓값 갱신
m_new = tl.maximum(m_prev, m_curr)
# 4. [핵심] Rescale Factor 계산
# 이전 값들을 새로운 최댓값 m_new 기준으로 '축소' 보정함
alpha = tl.exp(m_prev - m_new)
beta = tl.exp(m_curr - m_new)
# 5. 결과(Output)와 지수합(l) 업데이트
# 이전 출력값(acc)에 alpha를 곱해 보정하고 새로운 값을 더함
acc = acc * alpha[:, None] + tl.dot(beta[:, None] * p_curr, v_tile)
l_new = alpha * l_prev + beta * l_curr
# 6. 다음 루프를 위해 통계량 갱신
m_prev = m_new
l_prev = l_new
# Backward Pass 재계산을 위해 Softmax 통계량(m, l)만 HBM에 저장
tl.store(L + pid, l_final)
tl.store(M + pid, m_final)
tl.store(Out + pid, acc)
# [Backward] 저장된 통계량으로 SRAM 내에서 즉석 재계산(Recomputation) 수행
@triton.jit
def flash_attn_backward_kernel(Q, K, V, L, M, dOut, dQ, dK, dV):
# HBM에서 거대한 P 행렬을 읽어오는 대신, 작은 통계량(L, M)만 로드
l = tl.load(L + pid)
m = tl.load(M + pid)
# [Recomputation] Q, K를 다시 읽어와 SRAM 내부에서 P 행렬을 즉석 재건축
# 재계산 비용이 HBM I/O 비용보다 훨씬 저렴함 (Memory-bound 해결)
p_recomputed = tl.exp(tl.dot(q_tile, k_tile.T) - m) / l
# 재계산된 P를 이용해 Gradient 계산
# dV = P^T * dOut, dP = dOut * V^T ...
...
FlashAttention은 메모리 절약을 위해 Backward Pass에서 사용할 중간 행렬()을 저장하지 않습니다. 대신 Forward 때 썼던 Online Softmax 통계량()만 저장하고, Backward 때 이를 이용해 SRAM에서 즉석 재계산(Recomputation) 을 수행하여 메모리 사용량을 으로 유지합니다
6. 결론 및 시스템적 파급 효과
FlashAttention은 인공지능 연구가 단순히 수학적 모델링에 머물러서는 안 되며, 코드가 실행되는 물리적 하드웨어(GPU)의 아키텍처를 이해해야만(Hardware-Aware) 진정한 혁신을 이룰 수 있음을 증명한 기념비적인 연구입니다.
- 속도 및 메모리: IO 비용을 극단적으로 줄여 훈련/추론 속도를 2~4배 향상시켰으며, 메모리 요구량을 선형 수준()으로 감소시켜 이상의 긴 문맥 처리(Long-context LLM) 시대를 열었습니다.
- 통합: 현재 PyTorch
F.scaled_dot_product_attention백엔드의 기본값으로 채택되어 현대 LLM 인프라의 표준이 되었습니다.
한계: 이러한 IO 최적화에도 불구하고 디코딩(Serving) 단계에서 발생하는 KV Cache의 불규칙한 메모리 단편화(Fragmentation) 문제는 여전히 남아있습니다.