π SageAttention: ΠΌΠ΅ΡΠΎΠ΄ ΠΊΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΡ ΠΌΠ΅Ρ
Π°Π½ΠΈΠ·ΠΌΠ° Π²Π½ΠΈΠΌΠ°Π½ΠΈΡ Π² Π°ΡΡ
ΠΈΡΠ΅ΠΊΡΡΡΠ°Ρ
ΡΡΠ°Π½ΡΡΠΎΡΠΌΠ΅ΡΠΎΠ².
ΠΠ½ΠΈΠΌΠ°Π½ΠΈΠ΅ - ΠΊΠ»ΡΡΠ΅Π²ΠΎΠΉ ΠΊΠΎΠΌΠΏΠΎΠ½Π΅Π½Ρ ΡΡΠ°Π½ΡΡΠΎΡΠΌΠ΅ΡΠΎΠ², Π½ΠΎ Π΅Π³ΠΎ ΠΊΠ²Π°Π΄ΡΠ°ΡΠΈΡΠ½Π°Ρ ΡΠ»ΠΎΠΆΠ½ΠΎΡΡΡ Π²ΡΡΠΈΡΠ»Π΅Π½ΠΈΠΉ ΡΡΠ°Π½ΠΎΠ²ΠΈΡΡΡ ΠΏΡΠΎΠ±Π»Π΅ΠΌΠΎΠΉ ΠΏΡΠΈ ΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ΅ Π΄Π»ΠΈΠ½Π½ΡΡ
ΠΏΠΎΡΠ»Π΅Π΄ΠΎΠ²Π°ΡΠ΅Π»ΡΠ½ΠΎΡΡΠ΅ΠΉ. ΠΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΠ΅ ΡΡΠΏΠ΅ΡΠ½ΠΎ ΠΏΡΠΈΠΌΠ΅Π½ΡΠ΅ΡΡΡ Π΄Π»Ρ ΡΡΠΊΠΎΡΠ΅Π½ΠΈΡ Π»ΠΈΠ½Π΅ΠΉΠ½ΡΡ
ΡΠ»ΠΎΠ΅Π², Π½ΠΎ ΠΎΠ½ΠΎ ΠΌΠ°Π»ΠΎ ΠΈΠ·ΡΡΠ΅Π½ΠΎ ΠΏΡΠΈΠΌΠ΅Π½ΠΈΡΠ΅Π»ΡΠ½ΠΎ ΠΊ ΠΌΠ΅Ρ
Π°Π½ΠΈΠ·ΠΌΡ Π²Π½ΠΈΠΌΠ°Π½ΠΈΡ.
SageAttention - ΡΠΊΡΠΏΠ΅ΡΠΈΠΌΠ΅Π½ΡΠ°Π»ΡΠ½ΡΠΉ ΠΌΠ΅ΡΠΎΠ΄, ΠΊΠΎΡΠΎΡΡΠΉ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅Ρ 8-Π±ΠΈΡΠ½ΠΎΠ΅ ΠΊΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅Ρ
Π°Π½ΠΈΠ·ΠΌΠ° Π²Π½ΠΈΠΌΠ°Π½ΠΈΡ Π΄Π»Ρ ΡΡΠΊΠΎΡΠ΅Π½ΠΈΡ Π²ΡΡΠΈΡΠ»Π΅Π½ΠΈΠΉ ΠΈ ΡΠΎΡ
ΡΠ°Π½Π΅Π½ΠΈΡ ΡΠΎΡΠ½ΠΎΡΡΠΈ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
ΠΠ΅ΡΠΎΠ΄ Π½Π΅ ΡΡΠ΅Π±ΡΠ΅Ρ ΡΠΏΠ΅ΡΠΈΠ°Π»ΡΠ½ΠΎΠ³ΠΎ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ ΠΈ ΠΊΠΎΠ½Π²Π΅ΡΡΠ°ΡΠΈΠΈ ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ Π² ΠΊΠ°ΠΊΠΎΠΉ-Π»ΠΈΠ±ΠΎ ΡΠΎΡΠΌΠ°Ρ, ΠΎΠ½ ΠΏΡΠΈΠΌΠ΅Π½ΡΠ΅ΡΡΡ ΠΊ ΡΡΡΠ΅ΡΡΠ²ΡΡΡΠΈΠΌ ΡΡΠ°Π½ΡΡΠΎΡΠΌΠ΅Π½ΡΠΌ ΠΌΠΎΠ΄Π΅Π»ΡΠΌ Π² ΡΠ΅ΠΆΠΈΠΌΠ΅ "plug-and-play".
ΠΠ»ΡΡΠ΅Π²ΡΠ΅ ΠΎΡΠΎΠ±Π΅Π½Π½ΠΎΡΡΠΈ ΠΌΠ΅ΡΠΎΠ΄Π°:
π’ΠΠ»Ρ ΡΠΌΠ΅Π½ΡΡΠ΅Π½ΠΈΡ ΠΎΡΠΈΠ±ΠΊΠΈ ΠΊΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ ΡΠ³Π»Π°ΠΆΠΈΠ²Π°Π½ΠΈΡ ΠΌΠ°ΡΡΡΠΈΡΡ Π (ΡΡΠ΅Π΄Π½Π΅Π΅ Π·Π½Π°ΡΠ΅Π½ΠΈΠ΅ K Π²ΡΡΠΈΡΠ°Π΅ΡΡΡ ΠΏΠΎ Π²ΡΠ΅ΠΌ ΡΠΎΠΊΠ΅Π½Π°ΠΌ);
π’ΠΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΠ΅ Q ΠΈ K Π² INT8;
INT8 Π² ΡΠ΅ΡΡΡΠ΅ ΡΠ°Π·Π° Π±ΡΡΡΡΠ΅Π΅, ΡΠ΅ΠΌ Π² FP16, ΠΈ Π² Π΄Π²Π° ΡΠ°Π·Π° Π±ΡΡΡΡΠ΅Π΅, ΡΠ΅ΠΌ Π² FP8.
π’Matmul PV Π²ΡΠΏΠΎΠ»Π½ΡΠ΅ΡΡΡ Ρ FP16-Π½Π°ΠΊΠΎΠΏΠΈΡΠ΅Π»Π΅ΠΌ;
Π£ΠΌΠ½ΠΎΠΆΠ΅Π½ΠΈΠ΅ ΠΌΠ°ΡΡΠΈΡ Π² Π²ΡΡΠΎΠΊΠΎΠΉ ΡΠ°Π·ΡΡΠ΄Π½ΠΎΡΡΠΈ ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ ΡΡΠΊΠΎΡΠΈΡΡ Π²ΡΡΠΈΡΠ»Π΅Π½ΠΈΡ Π±Π΅Π· ΠΏΠΎΡΠ΅ΡΠΈ ΡΠΎΡΠ½ΠΎΡΡΠΈ.
π’ΠΠ΄Π°ΠΏΡΠΈΠ²Π½ΠΎΠ΅ ΠΊΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΠ΅;
ΠΠ»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ ΡΠ»ΠΎΡ Π²Π½ΠΈΠΌΠ°Π½ΠΈΡ Π²ΡΠ±ΠΈΡΠ°Π΅ΡΡΡ Π½Π°ΠΈΠ±ΠΎΠ»Π΅Π΅ Π±ΡΡΡΡΡΠΉ Π²Π°ΡΠΈΠ°Π½Ρ ΠΊΠ²Π°Π½ΡΠΎΠ²Π°Π½ΠΈΡ.
SageAttention ΡΠ΅Π°Π»ΠΈΠ·ΠΎΠ²Π°Π½ Ρ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ
Triton
ΠΈ ΠΎΠΏΡΠΈΠΌΠΈΠ·ΠΈΡΠΎΠ²Π°Π½ Π΄Π»Ρ GPU RTX4090 ΠΈ 3090. ΠΠ΅ΡΠΎΠ΄ ΠΏΡΠ΅Π²ΠΎΡΡ
ΠΎΠ΄ΠΈΡ FlashAttention2 ΠΈ xformers ΠΏΠΎ ΡΠΊΠΎΡΠΎΡΡΠΈ ΠΏΡΠΈΠΌΠ΅ΡΠ½ΠΎ Π² 2,1 ΠΈ 2,7 ΡΠ°Π·Π° ΡΠΎΠΎΡΠ²Π΅ΡΡΡΠ²Π΅Π½Π½ΠΎ.
Π’Π΅ΡΡΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ Π½Π° Llama2, CogvideoX, Unidiffuser ΠΈ TIMM ΠΏΠΎΠ΄ΡΠ²Π΅ΡΠ΄ΠΈΠ»ΠΎ ΡΠΎΡ
ΡΠ°Π½Π΅Π½ΠΈΠ΅ ΠΌΠ΅ΡΡΠΈΠΊ ΡΠΎΡΠ½ΠΎΡΡΠΈ ΠΏΡΠΈ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΠΈ SageAttention.
β οΈ ΠΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΠ΅ SageAttention ΡΠ΅ΠΊΠΎΠΌΠ΅Π½Π΄ΡΠ΅ΡΡΡ Ρ Π²Π΅ΡΡΠΈΡΠΌΠΈ:
π python>=3.11;
π torch>=2.4.0;
π triton-nightly.
β οΈ SageAttention ΠΎΠΏΡΠΈΠΌΠΈΠ·ΠΈΡΠΎΠ²Π°Π½ Π΄Π»Ρ RTX4090 ΠΈ RTX3090. ΠΠ° Π΄ΡΡΠ³ΠΈΡ
Π°ΡΡ
ΠΈΡΠ΅ΠΊΡΡΡΠ°Ρ
GPU ΠΏΡΠΈΡΠΎΡΡ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ΄ΠΈΡΠ΅Π»ΡΠ½ΠΎΡΡΠΈ ΠΌΠΎΠΆΠ΅Ρ Π±ΡΡΡ Π½Π΅Π·Π½Π°ΡΠΈΡΠ΅Π»ΡΠ½ΡΠΌ.
βΆοΈΠΡΠΈΠΌΠ΅Ρ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΡ:
# Install sageattention
pip install sageattention
# How to use
from sageattention import sageattn
attn_output = sageattn(q, k, v, is_causal=False, smooth_k=True)
# Plug-and-play example with Cogvideo
# add the following codes and run
from sageattention import sageattn
import torch.nn.functional as F
F.scaled_dot_product_attention = sageattn
# Specifically
cd example
python sageattn_cogvideo.py
πΠΠΈΡΠ΅Π½Π·ΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅: BSD-3-Clause license.
π‘Arxiv
π₯GitHub
@ai_machinelearning_big_data
#AI #ML #SageAttention #Transformers