Neural Network Diffusion
Kai Wang, Zhaopan Xu, Yukun Zhou, Zelin Zang, Trevor Darrell, Zhuang Liu, Yang You
Статья:
https://arxiv.org/abs/2402.13144
Код:
https://github.com/NUS-HPC-AI-Lab/Neural-Network-Diffusion
Диффузионные модели сейчас рулят, создавая прекрасные картинки и не только. Авторы предложили, что они могут генерить и параметры нейросетей. Вообще, мне кажется, они изобрели hypernetwork (писали про них тут
https://t.center/gonzo_ML/1696) через диффузию.
Для тех, кто не знает как работают диффузионные модели, совсем в двух словах и на пальцах. Прямой диффузионный процесс получает на вход картинку (вместо картинки может быть любой другой сигнал) и последовательно шаг за шагом добавляет в неё шум, пока она не превратится в совсем шумный сигнал. Прямой диффузионный процесс не очень интересен, интересен обратный -- он получает на вход шум и последовательно его убирает, “открывая” (создавая) скрывающуюся за ним картинку (как бы делая denoising). Примеры диффузионных моделей мы разбирали в лице DALLE 2 (
https://t.center/gonzo_ML/919) и Imagen (
https://t.center/gonzo_ML/980).
Обучение нейросети через SGD идейно похоже на обратный диффузионный процесс: стартуем с рандомной инициализации и последовательно обновляем веса, пока не достигнем высокого качества на заданной задачи. Свой подход авторы назвали
neural network diffusion или
p-diff (от parameter diffusion).
Идея и реализация просты и по-своему красивы.
Во-первых, мы собираем датасет с параметрами нейросетей, обученных SGD и обучаем на нём автоэнкодер, из которого потом возьмём latent representation (можем это делать не на полном наборе параметров, а на подмножестве). Вторым шагом мы обучаем диффузионную модель, которая из случайного шума сгенерит latent representation, который в свою очередь через декодер обученного на первом шаге автоэнкодера мы восстановим в сами веса. Теоретически можно было бы и обучить диффузию на самих весах сразу, но это требует сильно больше памяти.
Для автоэнкодера параметры преобразуются в одномерный вектор, также используется одновременная аугментация шумом входных параметров и латентного представления. Обучение диффузионной модели -- это классический DDPM (
https://arxiv.org/abs/2006.11239). Использовались 4-слойные 1D CNN энкодер и декодер.
Проверяли на картиночных датасетах MNIST, CIFAR-10, CIFAR-100, STL-10, Flowers, Pets, F-101, ImageNet-1K и на сетях ResNet-18/50, ViT-Tiny/Base, ConvNeXt-T/B.
Для каждой архитектуры накапливали 200 точек для обучения (чекпойнты последней эпохи). Я не до конца уловил, что именно они сохраняли, говорят про два последних слоя нормализации (только параметры BatchNorm’а чтоли?) и фиксированные остальные параметры. В большинстве случаев обучение автоэнкодера и диффузионки требовало 1-3 часа на одной A100 40G.
На инференсе генерят 100 новых параметров, из них оставляют один с максимальным перформансом на training set, его оценивают на validation set и этот результат и репортят.
В качестве бейзлайнов выступают 1) оригинальные модели и 2) ансамбли в виде усреднённого супа файнтюненных моделей (
“Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time”,
https://arxiv.org/abs/2203.05482).
Результат в большинстве случаев не хуже обоих бейзлайнов. То есть выучивается распределение high-performing параметров. Метод стабильно хорошо работает на разных датасетах.
Провели много абляций на ResNet-18 + CIFAR-100.
Чем больше моделей было в обучении, тем лучше. Метод генерит более качественные модели для слоёв на любой глубине. При этом на последних слоях результат самый высокий (предполагают, что это из-за меньшего накопления ошибок во время forward prop). Аугментация шумом в автоэнкодере очень важна, особенно для латентного состояния (а лучше одновременно и для входа тоже).
Это всё было для подмножества весов. Проверили также на генерации полного набора весов на маленьких сетях MLP-3 и ConvNet-3 и MNIST/CIFAR-10/100. Размеры сетей здесь 25-155к параметров. Также работает.