Cut Your Losses in Large-Vocabulary Language Models
авторы вспоминают вычисление кросс энтропии
loss = F.cross_entropy((e @ c.T).float(), targets)
где e - финальные эмбеддинги, с - матрица, которая переводит из эмбеддингов в vocab_size вектор. и эта самая матрица
e @ c.T
, как оказывается, кушает оочень многоо памяти, а при увеличении вокаб сайза затрат по памяти становится еще больше
и чтобы снизить память, авторы сделали свой тритоновский кернел, который не засовывает сразу все логиты в глобальную память, а вычисляет логиты для верных токенов (во flash memory из-за концепции teacher forcing) и log-sum-exp на лету (по аналогии с онлайн софтмаксом из флеш аттн), т.е. не аллоцирует память на
e @ c.T
к тому же еще при бекворде они заметили, что меньше чем 0.02% от нормализованных логитов являются ненулевыми (по большей части из-за действительно большого вокаб сайза) → не вычисляют градиент для элементов, которые не проходят трешхолд в 2е-12 для бф16 точности, что тоже снижает по памяти и прибавляет по скорости
оверолл голова ллмок начала теперь кушать по памяти в среднем гигабайт вместо 28 при батч сайзе в 65к (при том экспы ставили на модельках от 1.3B do 70B за что респект). есть правда вопросы насколько стабильно будет этот метод работать для претрена с нуля, ибо авторы только “файнтюнили”
очень понятно описано решение их кернела, как и сам код, рекомендуем к прочтению
выглядит интересно и прикольно, на последней картинке только с осторожностью относился бы к абсолютным числам, которые у них получились на экспах по замеру времени
👀LINK