더 빠르고 메모리 효율적인 트랜스포머를 위한 혁신, FlashAttention-3!
트랜스포머 아키텍처의 핵심 요소인 어텐션은 대규모 언어 모델(LLM)과 긴 컨텍스트 애플리케이션에서 성능의 병목 현상을 일으키는 주범이기도 해요. FlashAttention은 GPU에서 어텐션 연산을 가속화하기 위해 메모리 읽기/쓰기를 최소화하는 획기적인 방법을 제시했고, 지금은 대부분의 라이브러리에서 트랜스포머 학습과 추론 속도를 높이는 데 널리 활용되고 있죠. 덕분에 지난 2년 동안 LLM의 컨텍스트 길이가 기존 2~4K(GPT-3, OPT)에서 128K(GPT-4), 심지어 1M(Llama 3)까지 엄청나게 늘어났어요.
하지만 이런 성공에도 불구하고, FlashAttention은 최신 하드웨어의 새로운 기능을 아직 충분히 활용하지 못했어요. FlashAttention-2는 H100 GPU에서 이론상 최대 FLOPS(초당 부동 소수점 연산 횟수)의 35%밖에 활용하지 못했거든요. 그래서 PyTorch 팀에서는 Hopper(H100) GPU에서 어텐션 연산을 더욱 가속화하기 위한 새로운 기술들을 개발했고, 이를 바탕으로 FlashAttention-3을 선보이게 되었어요!
FlashAttention-3: H100 GPU 성능 극대화
FlashAttention-3은 어텐션 연산을 더욱 빠르게 만들기 위해 세 가지 핵심 기술을 적용했어요.
워프 특수화(Warp Specialization)를 통한 연산 및 데이터 이동 중첩
첫 번째 기술은 워프 특수화예요. 워프 특수화는 텐서 코어와 TMA(Tensor Memory Accelerator)의 비동기성을 활용하여 전체 연산과 데이터 이동을 동시에 수행하는 기법이에요. 쉽게 말해, 데이터를 메모리에서 가져오는 동시에 계산을 시작해서 시간을 절약하는 거죠. 마치 멀티태스킹처럼 여러 작업을 동시에 처리해서 속도를 높이는 거라고 생각하면 돼요.
블록 단위 MatMul 및 Softmax 연산 교차 수행
두 번째 기술은 블록 단위 MatMul과 Softmax 연산을 번갈아 가며 수행하는 거예요. 기존에는 MatMul 연산을 모두 끝낸 후에 Softmax 연산을 수행했는데, FlashAttention-3은 MatMul 연산과 Softmax 연산을 블록 단위로 섞어서 처리하도록 설계되었어요. 이를 통해 연산 파이프라인을 최적화하고, 연산 대기 시간을 줄여 속도를 높일 수 있게 되었죠.
저정밀도 FP8 활용 및 비일관적 처리
마지막 기술은 저정밀도 FP8을 위한 하드웨어 지원을 활용하는 비일관적 처리예요. FP8은 FP16보다 더 작은 메모리 공간을 사용하는데, 이를 통해 메모리 대역폭을 늘리고 더 많은 연산을 동시에 처리할 수 있게 해줘요. 또한, 비일관적 처리를 통해 데이터 접근 방식을 최적화하여 메모리 읽기/쓰기 오버헤드를 줄이고 연산 속도를 높였어요.
FlashAttention-3의 놀라운 성능 향상
FlashAttention-3은 이러한 기술들을 통해 기존 FlashAttention-2보다 훨씬 뛰어난 성능을 보여주고 있어요. FP16에서 FlashAttention-2보다 1.5~2.0배 빠르며, 최대 740 TFLOPS를 달성했어요. 즉, H100의 이론상 최대 FLOPS의 75%까지 활용할 수 있다는 뜻이죠.
FP8에서는 기존 FP8 어텐션 연산보다 오차는 2.6배 줄이면서도 1.2 PFLOPS에 가까운 성능을 달성했어요. 정말 대단하죠?
FP16 | 1.5-2.0x | 740 TFLOPS (75% of H100 peak) | - |
FP8 | - | ~1.2 PFLOPS | 2.6x smaller |
Precision Speedup over FlashAttention-2 Achieved Performance (TFLOPS/PFLOPS) Error (relative to baseline)
자주 묻는 질문 (FAQ)
Q1. FlashAttention은 왜 중요한가요?
A1. FlashAttention은 트랜스포머 모델의 어텐션 연산을 가속화하여 더 크고 복잡한 LLM을 훈련하고 추론하는 데 필수적인 기술이에요. 메모리 효율성을 높여 GPU 자원을 효율적으로 사용하고, 연산 속도를 높여 훈련 시간을 단축시키죠.
Q2. FlashAttention-3은 어떤 점이 개선되었나요?
A2. FlashAttention-3은 H100 GPU의 성능을 극대화하기 위해 워프 특수화, 블록 단위 MatMul 및 Softmax 연산 교차 수행, 저정밀도 FP8 활용 및 비일관적 처리 등의 기술을 도입했어요. 이를 통해 이전 버전보다 훨씬 빠르고 효율적인 어텐션 연산을 수행할 수 있게 되었죠.
Q3. FlashAttention-3을 어떻게 사용할 수 있나요?
A3. FlashAttention-3은 GitHub에서 오픈소스로 공개되어 있어요. 코드를 다운로드하여 자신의 프로젝트에 쉽게 적용할 수 있답니다.
마무리
FlashAttention-3은 H100 GPU에서 어텐션 연산을 가속화하는 혁신적인 기술이에요. 더 빠르고 효율적인 LLM 개발을 위한 중요한 발걸음이 될 것으로 기대됩니다.
키워드
플래시어텐션, FlashAttention, FlashAttention3, 어텐션, Attention, 트랜스포머, Transformer, LLM, 대규모언어모델, GPU, H100, Hopper, PyTorch, 딥러닝, DeepLearning, AI, 인공지능, 머신러닝, MachineLearning, FLOPS, TFLOPS, PFLOPS, 메모리효율성, 성능향상, 최적화, 비동기, WarpSpecialization, MatMul, Softmax, FP8, 저정밀도, 비일관적처리, 오픈소스, OpenSource, GitHub, 논문, 연구, 개발