Как получить hard-батч из super-batch?
Сэмплы, итерационно, добавляются чанками (
n_chunks), причем на каждой итерации добавляются наиболее путаемые* с уже существующими в батче.
*на самом деле здесь сэмплирование, а не добавление «в лоб».
Если вернуться к sigmoid-лоссу, то
внутри происходит подсчет logsigmoid-матрицы на основе img/text alignment-матрицы. Именно с 1-ой и работают авторы.
Ее размер (bs, bs), и ее можно назвать матрицей ошибок.
На практике, у нас есть
learner — модель, которую мы будем обучать и
reference — модель, которая уже умеет матчить картинки с текстами.
Для каждого super-батча и 2-ух моделей авторы получают 2 logsigmoid-матрицы, вычитая, которые получают
learnability scores — матрицу ошибок, значения которой выступают в роли логитов для сэмплирования следующих элементов в батч (exp(scores)) чтобы подтолкнуть
learner к
reference.
Пример:
Допустим у нас есть super-батч размера 1000 элементов.
filter_ratio выберем 0.88 — это гиперпараметр, отвечающий за пропорцию оставленного hard-батча, (итоговый hard-батч будет 1000 * 0.12 = 120 элементов).
n_chunks выберем 10, то есть будем итеративно (10 - 1 = 9) раз добавлять в hard-подбатч по 12 сэмплов.
После dataloader-а,
learner (обучаемый) и
reference (замороженный) прогнозируют свои logsigmoid-матрицы, которые вычитаются, и в результате остается learnability scores матрица (1000, 1000).
Первоначально батч инициализируется 12 элементами на основе такой матрицы.
Далее, для выбранных элементов из этой матрицы берутся 2-подматрицы (12, 1000) и (1000, 12), которые суммируются и для оставшихся 988 элементов получается сумма logsigmoid-ошибок с существующими. Таким образом следующие 12 сэмплов будут сэмплироваться на основе ошибок с существующими. Алгоритм набора сэмплов повторяется
n_chunks - 1 раз.
** Прогнозы
reference-модели можно закэшировать, чтобы не прогнозировать ею один и тот же сэмпл каждую эпоху.
*** Можно в качестве
learnability scores выбрать только матрицу ошибок
reference-модели, но это хуже чем учитывать 2 матрицы (статья: фигура 3 middle vs right).
**** Можно в качестве
learnability scores выбрать только матрицу ошибок
learner-модели, но это черевато бэкпропингу на мусоре (или слишком общих описаниях).