【DL輪読会】Training-Free Activation Sparsity in Large Language Models

-- Views

December 11, 25

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

ダウンロード

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] TEAL: Training-Free Activation Sparsity in Large Language Models D1,Kim Yongmin, The University of Tokyo http://deeplearning.jp/ 1

2.

書誌情報 タイトル: TEAL: Training-Free Activation Sparsity in Large Language Models 著者: James Liu†, Pragaash Ponnusamy‡, Tianle Cai§, Han Guo§, Yonglong Tian‡, Tri Dao‡ 所属: †MIT, ‡Together AI, §Princeton University 発表情報: ICLR 2025 Spotlight 概要: Magnitude-based pruningで学習なしに40-50%のmodel-wide sparsityを達成し, LLM推論を最大1.8倍高速化 選定理由: LLM推論効率化の新アプローチとして,学習不要かつPruning後性能も高いため 実装: https://github.com/FasterDecoding /TEAL 2

3.

概要 背景 ・LLMの推論は"memory wall"問題に直面(演算速度 >> メモリ帯域幅) ・既存手法CATSはMLP層のみに適用可能 手法 ・TEAL: 全層(Attention + MLP)に適用可能なmagnitude-based activation sparsity ・Block-wise greedy optimizationで最適なsparsity配分を決定 結果 ・Perplexity: Llama-3-8Bで5.87→6.67 (50% sparsity, +13.6%) ・Downstream: 5タスク平均 72.8%→68.4% (50% sparsity, -4.4%) ・推論速度: A6000で最大1.80×高速化 (50% sparsity) 3

4.

背景 現代GPUの課題 ・演算速度がメモリ帯域幅を大幅に上回る ・LLM推論(特にバッチサイズ1のdecoding)はメモリ律速 ・重み行列のロード時間が推論時間の大部分を占める Arithmetic Intensity(演算強度) ・定義: 演算量 ÷ メモリアクセス量 ・LLM推論では約1(非常に低い)→ メモリ律速 ・高いほどGPU演算器を有効活用できる 解決アプローチ ・Weight quantization: 重みを低精度で保存 ・Activation sparsity: 活性化の一部のみを計算 ← 本研究 4

5.

背景 既存のActivation Sparsity手法の問題点 DejaVu [Liu+ 2023] ・Sparsity predictorの学習が必要(追加学習コスト) ・Predictorの精度に依存 ・推論時にpredictorを実行するオーバーヘッド CATS [Lee+ 2024](最も関連する先行研究) ・MLPのWup, Wdown層にのみ適用可能 ・Attention層には適用不可 ・40%以上のsparsityで性能が急激に低下 → 学習不要で全層に適用可能な手法が必要 5

6.

背景 従来手法 (CATS)の限界 CATSの手法 t: 閾値←Calibrationデータから目標sparistyに基づいて,決定 CATSの限界 ・Attention層(Q, K, V, O projection)への適用不可 ・Model-wide sparsity上限 Pruning 𝑆i𝐿𝑈 𝑥𝑊𝑔𝑎𝑡𝑒 𝑆i𝐿𝑈 Pruning ・高sparsity時(40%以上)で性能が急激に崩壊 MLP層の例 6

7.

背景 活性化分布の分析 Key Insight ・大きな活性化値が出力に大きく寄与 ・小さな活性化値をゼロにしても出力への影響は限定的 観察結果 ・MLP層: Down projection入力で明確なheavy-tail分布 ・Attention層: Output projection入力で同様の分布 結論 ・全ての各Blockの中間の値ではなくて,入力からPruning → Wgateの活性値からPruningするCATSより,もっといい Pruningが可能 C4でのLlama3-8Bの活性化分布 7

8.

背景 活性化分布の分析 Key Insight ・大きな活性化値が出力に大きく寄与 ・小さな活性化値をゼロにしても出力への影響は限定的 TEAL 観察結果 ・MLP層: Down projection入力で明確なheavy-tail分布 ・Attention層: Output projection入力で同様の分布 結論 ・全ての各Blockの中間の値ではなくて,入力からPruning → Wgateの活性値からPruningするCATSより,もっといい Pruningが可能 CATs SILU TEAL 𝑥 CATsとTEALのPruningするところの比較 7

9.

手法 TEAL: Training-Free Activation Sparsity 手法 ・入力活性化xの要素のうち, magnitude(絶対値)が小さいものをゼロ化 𝑡𝑝 : 閾値←Calibrationデータから目標sparistyに基づいて,決定 適用対象 ・MLPとAttentionブロックの入力 Sparsity設定 ・各層で異なる最適sparsity → Block-wise greedy optimizationで決定 ・後段層ほど高いsparsityが可能(経験的観察) ・学習不要: calibration dataのみで最適化 TEALの概略図 8

10.

手法 Block-wise Greedy Optimization 問題設定 ・目標: 全体sparsity制約下で,ブロック別sparsity配分を最適化 ・目的関数: ブロック別ℓ2活性化誤差の最小化 1. 全ブロックを最低sparsity0%で初期化 2. 各レイヤーのsparsityを少し増加させた場合の誤差を測定 3. Greedy: 誤差増加が最小のブロックからsparsity増加 Sparisty 最適化手順 4. 目標の全体sparsityに達するまで繰り返し 計算コスト ・Llama-3-8B: A100 1GPUで1時間未満 ・一度計算すれば,推論時は固定sparsity設定を使用 Layer Index Llama-3-70Bの block-wise greedy Optimizationの結果 10

11.

手法 Hardware-Aware Acceleration Sparse GEMVカーネル ・Tritonベースの実装 ・Dense GEMVより高速化を達成 最適化技術 1. Mask fusion: マスク生成とスパース行列演算を融合 (CATsと同様) → メモリ演算が大幅に軽減 2. FP16 accumulation: FP32→FP16で帯域幅削減 3. PTX eviction policy: L2キャッシュ効率化 Torch Step 1: マスク生成 x = [0.1, 2.5, -0.3, 1.8, -0.05] threshold = 0.5 mask = [0, 1, 0, 1, 0] ← メモリに保存 全部のweight Step 2: Sparse演算 をメモリから読み込む メモリからmaskを読み込み sparse_x = x * mask = [0, 2.5, 0, 1.8, 0] y = sparse_x @ W 1 stepに軽減 Custom Kernal (CATs, TEAL) Step 1: マスク生成 + Sparse演算を同時に x = [0.1, 2.5, -0.3, 1.8, -0.05] 必要なweightだけ threshold = 0.5 メモリから読み込む → |x[i]| > 0.5のものだけ直接計算 → maskをメモリに保存する必要なし 11

12.

手法 Hardware-Aware Acceleration Sparse GEMVカーネル ・Tritonベースの実装 ・Dense GEMVより高速化を達成 最適化技術 1. Mask fusion: マスク生成とスパース行列演算を融合 (CATsと同様) → メモリ演算が大幅に軽減 2. FP16 accumulation: FP32→FP16で帯域幅削減 3. PTX eviction policy: L2キャッシュ効率化 TEALと他のカーネルとTorchとの速度比較 11

13.

実験 実験設定 モデル ・Llama-2-7B, Llama-2-13B, Llama-3-8B, Llama-3-70B ・Mistral-7B, Falcon-7B 評価データセット ・Perplexity: WikiText-2, C4 ・Downstream: ARC-E, ARC-C, HellaSwag, PIQA, WinoGrande Prefillの後半の50%からPruning.GSM8Kのような長いGenerationタスクはPrefillはPruningなし Baseline ・Dense (sparsityなし) ・CATS [Lee+ 2024] ・Uniform Sparsity(全層同一sparsity) Calibration ・C4から128シーケンス(2048トークン/シーケンス) 12

14.

結果 Perplexity評価結果 WikiText-2 Perplexity Llama-3-8B - Dense: 5.87 - TEAL 40% sparsity: 6.21 (+5.8%) - TEAL 50% sparsity: 6.67 (+13.6%) ・Llama-2-7B - Dense: 5.07 - TEAL 40% sparsity: 5.22 (+3.0%) - TEAL 50% sparsity: 5.43 (+7.1%) 主要な発見 ・Block-wise greedy > Uniform sparsity (全sparsityレベルで優位) ・50%以上のsparsityでも破滅的な性能劣化なし 13

15.

結果 Downstreamタスク評価結果 6タスク平均精度 (MMLU, ARC, HellaSwag, GSM8K, PIQA, Winogrande) ・Llama-3-8B - Dense: 68.07% - TEAL 40% sparsity: 66.21% (-1.9%p) - TEAL 50% sparsity: 63.42% (-4.7%p) ・Llama-2-7B - Dense: 56.50% - TEAL 40% sparsity: 55.45% (-1.1%p) - TEAL 50% sparsity: 54.26% (-2.2%p) 結論: - 40% sparsityまでは実用的な精度を維持 - モデルサイズ大きくなるほど,もっとPruning可能 14

16.

結果 推論速度評価結果 A6000での測定結果 (Llama-3-8B) ・Dense baseline: 1.00× ・TEAL 40% sparsity: 1.53× 高速化 ・TEAL 50% sparsity: 1.80× 高速化 A100での測定結果 ・Dense baseline: 1.00× ・TEAL 50% sparsity: 1.45× 高速化 高速化の要因 ・メモリ帯域幅削減(sparse要素のみロード) ・演算量削減(ゼロ要素との積はスキップ) 15

17.

まとめ まとめ・所感 TEALの貢献 ・学習不要のactivation sparsity手法を提案 ・全層(Attention + MLP)に適用可能 ・40-50%のmodel-wide sparsityで安定動作 ・既存のカスタムカーネルに加えて,さらに高速化 主要結果 ・WikiText PPL: +13%増加で50% sparsity達成 ・推論速度: 最大1.8×高速化(A6000) ・Weight quantizationと互換 所感・今後の課題 ・ 多様な分析とカスタムカーネル実装など充実な研究という感じ ・ 各入力の分布しているため,複数バッチの場合,各サンプルが違うところをPruning → 複数バッチが不可能 ・ マイトークンPruningするところを変更 → Reasoning taskでも性能維持可能 ・入力からPruningするところを決定 → 少し違和感 19

18.

参考文献 [Liu+ 2024] James Liu et al. "TEAL: Training-Free Activation Sparsity in Large Language Models" arXiv:2408.14690 [Lee+ 2024] Donghyuk Lee et al. "CATS: Contextually-Aware Thresholding for Sparsity in Large Language Models" arXiv:2404.08763 [Liu+ 2023] Zichang Liu et al. "DejaVu: Contextual Sparsity for Efficient LLMs at Inference Time" ICML 2023 [Song+ 2023] Yixin Song et al. "PowerInfer: Fast Large Language Model Serving with a Consumer-grade GPU" arXiv:2312.12456 [Frantar+ 2023] Elias Frantar et al. "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot" ICML 2023 [Lin+ 2024] Ji Lin et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration" MLSys 2024 20