[Research Note] 1.58-bit Ultra-low-bit Quantization (BitNet)

Mathematical implementation and QAT design

📝 [Research Note] 1.58-bit Ultra-low-bit 양자화(BitNet) 수학적 구현 및 QAT 설계

1. 문제 정의 및 목표

배경: Edge 환경에서의 Memory Wall 한계
최근 대형 언어 모델(LLM)의 발전에도 불구하고, 이를 Jetson, MPU와 같은 리소스가 제한된 Edge 디바이스에 배포하는 데는 치명적인 병목 현상이 존재합니다. LLM 추론 지연의 근본적인 원인은 연산력(Compute-bound)이 아닌, 메모리에서 가중치를 불러오는 속도인 대역폭(Memory-bound)의 한계, 즉 Memory Wall에 있습니다.

해결책 제안: 1.58-bit의 도입
기존의 FP16이나 INT8 양자화 방식은 여전히 곱셈 연산을 수반하며 일정 수준 이상의 메모리 대역폭을 요구합니다. 이에 대한 돌파구로, 가중치를 1,0,1{-1, 0, 1}의 3가지 값(Ternary)으로 매핑하는 **1.58-bit 양자화(BitNet b1.58)**가 대두되었습니다. (3개의 상태를 표현하기 위해 비트가 필요)

연구 목표
본 문서는 1.58-bit 양자화 알고리즘의 수학적 원리를 분석하고, 미분 불가능성(Non-differentiability)을 해결하기 위한 STE(Straight-Through Estimator) 의 PyTorch 기반 QAT(Quantization-Aware Training) 커스텀 구현 과정을 기록합니다. 더불어 이를 온디바이스 추론 엔진에 올렸을 때의 하드웨어 가속 원리를 고찰합니다.


2. 양자화 알고리즘의 수학적 모델링

BitNet b1.58의 핵심은 복잡한 부동소수점 가중치를 정수형으로 스케일링하여 매핑하는 것입니다.

Weight Quantization (Absmean 기반 스케일링)
가중치 행렬 를 1,0,1{-1, 0, 1}로 양자화하기 위해 평균 절댓값(Absmean) 스케일링 방식인 γ\gamma를 사용합니다.

γ=1nmi,jWij\gamma = \frac{1}{nm} \sum_{i,j} |W_{ij}|

이 스케일링 팩터 를 이용하여 가중치를 나누고, 반올림(Round)한 뒤 범위를 클리핑(Clipping)합니다. ϵ\epsilon은 0으로 나누어지는 것을 방지하는 작은 상수입니다.

Wq=Clip(Round(Wγ+ϵ),1,1)W_q = \text{Clip}(\text{Round}(\frac{W}{\gamma + \epsilon}), -1, 1)

Activation Quantization
활성화 함수(Activation)의 출력값 역시 연산 효율을 위해 양자화(통상 8-bit)합니다. 여기서는 최댓값(Absmax)을 기준으로 스케일링을 수행합니다.

Xq=Clip(Round(Xmax(X)×127),128,127)X_q = \text{Clip}(\text{Round}(\frac{X}{\max(|X|)} \times 127), -128, 127)


3. PyTorch QAT 및 STE(Straight-Through Estimator) 구현

양자화 과정에서 사용되는 함수는 계단식(Step) 형태를 가지므로 기울기(Gradient)가 0이 되어 오차 역전파(Backpropagation)가 불가능해집니다. 이 미분 불가능성 문제를 해결하기 위해 STE(Straight-Through Estimator) 를 도입했습니다.

STE의 핵심은 Forward 패스에서는 양자화를 수행하고, Backward 패스에서는 미분값을 그대로(Straight-through) 흘려보내는 것입니다. 이를 PyTorch의 torch.autograd.Function을 활용하여 커스텀 레이어로 구현했습니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ActivationQuantizerSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits=8):
        # 8-bit Absmax 양자화
        Qb = 2 ** (num_bits - 1) - 1
        scale = x.abs().max().clamp(min=1e-5) / Qb
        x_q = torch.round(x / scale).clamp(-Qb, Qb)
        ctx.save_for_backward(x_q)
        return x_q * scale # 역양자화 형태로 반환

    @staticmethod
    def backward(ctx, grad_output):
        # STE: Forward에서의 Round 연산을 무시하고 Gradient를 그대로 통과
        return grad_output, None

class WeightQuantizerSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        # 1.58-bit (Ternary) Absmean 양자화
        scale = weight.abs().mean().clamp(min=1e-5)
        weight_q = torch.round(weight / scale).clamp(-1, 1)
        
        ctx.save_for_backward(weight, scale)
        # 실제 연산을 위해 스케일을 복원한 값 반환
        return weight_q * scale 

    @staticmethod
    def backward(ctx, grad_output):
        # STE 적용
        return grad_output

class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # 1. Weight를 1.58-bit로 양자화
        quantized_weight = WeightQuantizerSTE.apply(self.weight)
        
        # 2. LayerNorm 이후의 Activation 양자화 (논문 기준)
        quantized_x = ActivationQuantizerSTE.apply(x, 8)
        
        # 3. 선형 연산 수행
        out = F.linear(quantized_x, quantized_weight, self.bias)
        return out


4. Algorithm-Hardware Co-design 통찰

이러한 극단적인 양자화가 하드웨어 관점에서 가지는 의미는 단순한 '용량 감소'를 넘어섭니다. ALU(산술논리연산장치) 레벨에서의 혁신적인 변화를 이끌어냅니다.

Multiplier(곱셈기)의 소멸과 Adder(덧셈기)로의 완전 대체
기존 모델의 행렬 곱(MatMul)은 Y=W×XY = W \times X 형태의 수많은 MAC(Multiply-Accumulate) 연산을 필요로 합니다. 하지만 W{1,0,1}W \in \{-1, 0, 1\} 상황에서는 연산의 패러다임이 바뀝니다.

  • W=1W = 1 : XX를 더함 (+X+X)
  • W=1W = -1 : XX를 뺌 (X-X)
  • W=0W = 0 : 연산 생략

즉, 부동소수점 곱셈기가 하드웨어 파이프라인에서 완전히 배제되고, 순수한 정수형 덧셈기(Adder)만으로 모든 선형 연산 처리가 가능해집니다.

Edge 환경 최적화 시나리오

  • Memory Bandwidth: 가중치를 로드하는 데 필요한 I/O 대역폭이 FP16 대비 10배 이상 감소합니다.
  • Power Consumption: 하드웨어 연산 시 부동소수점 곱셈(FP16 MAC)은 정수형 덧셈(INT Add) 대비 막대한 전력을 소모합니다. 본 구조를 On-device 추론 엔진에 포팅할 경우 배터리 기반의 Edge Device에서 획기적인 에너지 효율 최적화를 달성할 수 있습니다.

5. 결론 및 향후 계획

본 연구를 통해 1.58-bit 단위의 양자화 기법이 단순한 가중치 압축을 넘어, 하드웨어 친화적(Hardware-aware) 아키텍처 재설계의 핵심이 될 수 있음을 확인했습니다.

소규모 텐서를 대상으로 한 단위 테스트(Unit Test) 수준의 가설 검증 결과, STE 적용 시 Weight 맵핑 과정에서 초기 MSE(Mean Squared Error) 변동폭이 관찰되었으나, 적절한 Learning Rate 조절을 통해 수렴 가능한 궤적을 확인했습니다.

Next Steps

  • QAT 파이프라인 구축:
    소규모의 파라미터를 가진 Model(SLM등)을 대상으로 본 커스텀 BitLinear 레이어를 삽입하여 Full QAT 실험 진행.