View in Telegram
Forwarded from Machinelearning
🌟 SageAttention: ΠΌΠ΅Ρ‚ΠΎΠ΄ квантования ΠΌΠ΅Ρ…Π°Π½ΠΈΠ·ΠΌΠ° внимания Π² Π°Ρ€Ρ…ΠΈΡ‚Π΅ΠΊΡ‚ΡƒΡ€Π°Ρ… трансформСров. Π’Π½ΠΈΠΌΠ°Π½ΠΈΠ΅ - ΠΊΠ»ΡŽΡ‡Π΅Π²ΠΎΠΉ ΠΊΠΎΠΌΠΏΠΎΠ½Π΅Π½Ρ‚ трансформСров, Π½ΠΎ Π΅Π³ΠΎ квадратичная ΡΠ»ΠΎΠΆΠ½ΠΎΡΡ‚ΡŒ вычислСний становится ΠΏΡ€ΠΎΠ±Π»Π΅ΠΌΠΎΠΉ ΠΏΡ€ΠΈ ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ΅ Π΄Π»ΠΈΠ½Π½Ρ‹Ρ… ΠΏΠΎΡΠ»Π΅Π΄ΠΎΠ²Π°Ρ‚Π΅Π»ΡŒΠ½ΠΎΡΡ‚Π΅ΠΉ. ΠšΠ²Π°Π½Ρ‚ΠΎΠ²Π°Π½ΠΈΠ΅ ΡƒΡΠΏΠ΅ΡˆΠ½ΠΎ примСняСтся для ускорСния Π»ΠΈΠ½Π΅ΠΉΠ½Ρ‹Ρ… слоСв, Π½ΠΎ ΠΎΠ½ΠΎ ΠΌΠ°Π»ΠΎ ΠΈΠ·ΡƒΡ‡Π΅Π½ΠΎ ΠΏΡ€ΠΈΠΌΠ΅Π½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ ΠΊ ΠΌΠ΅Ρ…Π°Π½ΠΈΠ·ΠΌΡƒ внимания. SageAttention - ΡΠΊΡΠΏΠ΅Ρ€ΠΈΠΌΠ΅Π½Ρ‚Π°Π»ΡŒΠ½Ρ‹ΠΉ ΠΌΠ΅Ρ‚ΠΎΠ΄, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ 8-Π±ΠΈΡ‚Π½ΠΎΠ΅ ΠΊΠ²Π°Π½Ρ‚ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅Ρ…Π°Π½ΠΈΠ·ΠΌΠ° внимания для ускорСния вычислСний ΠΈ сохранСния точности ΠΌΠΎΠ΄Π΅Π»ΠΈ. ΠœΠ΅Ρ‚ΠΎΠ΄ Π½Π΅ Ρ‚Ρ€Π΅Π±ΡƒΠ΅Ρ‚ ΡΠΏΠ΅Ρ†ΠΈΠ°Π»ΡŒΠ½ΠΎΠ³ΠΎ обучСния ΠΈ ΠΊΠΎΠ½Π²Π΅Ρ€Ρ‚Π°Ρ†ΠΈΠΈ ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ Π² ΠΊΠ°ΠΊΠΎΠΉ-Π»ΠΈΠ±ΠΎ Ρ„ΠΎΡ€ΠΌΠ°Ρ‚, ΠΎΠ½ примСняСтся ΠΊ ΡΡƒΡ‰Π΅ΡΡ‚Π²ΡƒΡŽΡ‰ΠΈΠΌ трансформСным модСлям Π² Ρ€Π΅ΠΆΠΈΠΌΠ΅ "plug-and-play". ΠšΠ»ΡŽΡ‡Π΅Π²Ρ‹Π΅ особСнности ΠΌΠ΅Ρ‚ΠΎΠ΄Π°: πŸŸ’Π”Π»Ρ ΡƒΠΌΠ΅Π½ΡŒΡˆΠ΅Π½ΠΈΡ ошибки квантования ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ сглаТивания ΠΌΠ°Ρ‚Ρ€Ρ‚ΠΈΡ†Ρ‹ К (срСднСС Π·Π½Π°Ρ‡Π΅Π½ΠΈΠ΅ K вычитаСтся ΠΏΠΎ всСм Ρ‚ΠΎΠΊΠ΅Π½Π°ΠΌ); πŸŸ’ΠšΠ²Π°Π½Ρ‚ΠΎΠ²Π°Π½ΠΈΠ΅ Q ΠΈ K Π² INT8; INT8 Π² Ρ‡Π΅Ρ‚Ρ‹Ρ€Π΅ Ρ€Π°Π·Π° быстрСС, Ρ‡Π΅ΠΌ Π² FP16, ΠΈ Π² Π΄Π²Π° Ρ€Π°Π·Π° быстрСС, Ρ‡Π΅ΠΌ Π² FP8. 🟒Matmul PV выполняСтся с FP16-Π½Π°ΠΊΠΎΠΏΠΈΡ‚Π΅Π»Π΅ΠΌ; Π£ΠΌΠ½ΠΎΠΆΠ΅Π½ΠΈΠ΅ ΠΌΠ°Ρ‚Ρ€ΠΈΡ† Π² высокой разрядности позволяСт ΡƒΡΠΊΠΎΡ€ΠΈΡ‚ΡŒ вычислСния Π±Π΅Π· ΠΏΠΎΡ‚Π΅Ρ€ΠΈ точности. πŸŸ’ΠΠ΄Π°ΠΏΡ‚ΠΈΠ²Π½ΠΎΠ΅ ΠΊΠ²Π°Π½Ρ‚ΠΎΠ²Π°Π½ΠΈΠ΅; Для ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ слоя внимания выбираСтся Π½Π°ΠΈΠ±ΠΎΠ»Π΅Π΅ быстрый Π²Π°Ρ€ΠΈΠ°Π½Ρ‚ квантования. SageAttention Ρ€Π΅Π°Π»ΠΈΠ·ΠΎΠ²Π°Π½ с использованиСм Triton ΠΈ ΠΎΠΏΡ‚ΠΈΠΌΠΈΠ·ΠΈΡ€ΠΎΠ²Π°Π½ для GPU RTX4090 ΠΈ 3090. ΠœΠ΅Ρ‚ΠΎΠ΄ прСвосходит FlashAttention2 ΠΈ xformers ΠΏΠΎ скорости ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π½ΠΎ Π² 2,1 ΠΈ 2,7 Ρ€Π°Π·Π° соотвСтствСнно. ВСстированиС Π½Π° Llama2, CogvideoX, Unidiffuser ΠΈ TIMM ΠΏΠΎΠ΄Ρ‚Π²Π΅Ρ€Π΄ΠΈΠ»ΠΎ сохранСниС ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊ точности ΠΏΡ€ΠΈ использовании SageAttention. ⚠️ ИспользованиС SageAttention рСкомСндуСтся с вСрсиями: 🟠python>=3.11; 🟠torch>=2.4.0; 🟠triton-nightly. ⚠️ SageAttention ΠΎΠΏΡ‚ΠΈΠΌΠΈΠ·ΠΈΡ€ΠΎΠ²Π°Π½ для RTX4090 ΠΈ RTX3090. На Π΄Ρ€ΡƒΠ³ΠΈΡ… Π°Ρ€Ρ…ΠΈΡ‚Π΅ΠΊΡ‚ΡƒΡ€Π°Ρ… GPU прирост ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ΄ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΡΡ‚ΠΈ ΠΌΠΎΠΆΠ΅Ρ‚ Π±Ρ‹Ρ‚ΡŒ Π½Π΅Π·Π½Π°Ρ‡ΠΈΡ‚Π΅Π»ΡŒΠ½Ρ‹ΠΌ. β–ΆοΈΠŸΡ€ΠΈΠΌΠ΅Ρ€ использования:
# Install sageattention
pip install sageattention

# How to use
from sageattention import sageattn
attn_output = sageattn(q, k, v, is_causal=False, smooth_k=True)

# Plug-and-play example with Cogvideo
# add the following codes and run
from sageattention import sageattn
import torch.nn.functional as F

F.scaled_dot_product_attention = sageattn

# Specifically
cd example
python sageattn_cogvideo.py
πŸ“ŒΠ›ΠΈΡ†Π΅Π½Π·ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅: BSD-3-Clause license. 🟑Arxiv πŸ–₯GitHub @ai_machinelearning_big_data #AI #ML #SageAttention #Transformers
Please open Telegram to view this post
VIEW IN TELEGRAM
Telegram Center
Telegram Center
Channel