JEST —
Joint
Example
Selec
Tion
📜 Статья от DeepMind 2024 года, описывающая обучение на сложных батчах (hard negatives) для CLIP-ранжировщиков. Вклад статьи:
умный batch composition и
уменьшение затрат на обучение. Авторы создают
JEST — модель, обучаемая в разы быстрее, чем
SigLIP.
👋 Batch composition:— Подход, похожий на дистилляцию, в котором есть
learner-модель, которая обучается, и
reference-модель, которая уже умеет решать задачи zero-shot classification/retrieval.
— В методе явно выделены
2️⃣ стадии: inference (pseudo-labeling) на большом батче и forward + backward на малом, но сложном.
— Исходный большой батч от даталоадера просеивается на 80-90%. Оставшиеся 10-20% собираются итерационно на основе прогнозов
learner и
reference моделей. Через получившийся сложный батч
💪 будет считаться backward для learner-модели.
— Сложный батч собирается по итерациям, исходя из разницы матриц ошибок
learner и
reference-моделей для большого батча. На каждой итерации в сложный батч сэмплируются примеры, образующие с уже имеющимися наибольшую ошибку (похожие на имеющиеся).
👋 Training cost:— Авторы снижают затраты на обучение, меняя
ViT на
Flexi-ViT 🎉, метрика которого максимизируется с двумя patch sizes — 16 и 32. На первой стадии используется модель с patch size 32 (что уменьшает FLOPs на 72% по сравнению с patch size = 16).
— Сформировав сложный батч, половина его примеров обрабатывается с patch size 16, а другая с patch size 32 (multi-resolution training).
Reference-моделью является SigLIP, прогнозы которого лучше закэшировать т.к. в этом сетапе он не обучается.
Авторы тюнят reference-модель на 2 версиях своего датасета:
WebLI-curated и
WebLI-curated++ (в 7 раз больше данных; 100M
🆚 700M). Датасеты очень хорошо отфильтрованы
👍.
Архитектура, использующая Flexi-ViT в качестве vision tower называется Flexi-JEST, иначе просто JEST. Если reference-модель была затюнена на WebLI-curated++, то к модели приписывается «++».
💪 В итоге авторы создают модель
Flexi-JEST++, которая тратит на обучение в 9 раз меньше FLOPs, чем SigLIP (2023 года), достигая при этом сопоставимых метрик.
⚡️ Но и без замены vision tower, авторы демонстрируют как
JEST++ тратит 23% FLOPs по сравнению с SigLIP только за счет обучения на hard negative батчах.
Мы опробовали
batch composition авторов, и подтверждаем, что метод докидывает в качестве по сравнению с обучением на большом батче, в котором примеры фильтровались независимо (а не совместно).
Автор:
@darkasevgen#paperwatch