View in Telegram
Cut Your Losses in Large-Vocabulary Language Models авторы вспоминают вычисление кросс энтропии
loss = F.cross_entropy((e @ c.T).float(), targets)
где e - финальные эмбеддинги, с - матрица, которая переводит из эмбеддингов в vocab_size вектор. и эта самая матрица e @ c.T, как оказывается, кушает оочень многоо памяти, а при увеличении вокаб сайза затрат по памяти становится еще больше и чтобы снизить память, авторы сделали свой тритоновский кернел, который не засовывает сразу все логиты в глобальную память, а вычисляет логиты для верных токенов (во flash memory из-за концепции teacher forcing) и log-sum-exp на лету (по аналогии с онлайн софтмаксом из флеш аттн), т.е. не аллоцирует память на e @ c.T к тому же еще при бекворде они заметили, что меньше чем 0.02% от нормализованных логитов являются ненулевыми (по большей части из-за действительно большого вокаб сайза) → не вычисляют градиент для элементов, которые не проходят трешхолд в 2е-12 для бф16 точности, что тоже снижает по памяти и прибавляет по скорости оверолл голова ллмок начала теперь кушать по памяти в среднем гигабайт вместо 28 при батч сайзе в 65к (при том экспы ставили на модельках от 1.3B do 70B за что респект). есть правда вопросы насколько стабильно будет этот метод работать для претрена с нуля, ибо авторы только “файнтюнили” очень понятно описано решение их кернела, как и сам код, рекомендуем к прочтению выглядит интересно и прикольно, на последней картинке только с осторожностью относился бы к абсолютным числам, которые у них получились на экспах по замеру времени 👀LINK
Love Center - Dating, Friends & Matches, NY, LA, Dubai, Global
Love Center - Dating, Friends & Matches, NY, LA, Dubai, Global
Find friends or serious relationships easily