[S4] Efficiently Modeling Long Sequences with Structured State Spaces
Albert Gu, Karan Goel, Christopher Ré
Статья:
https://arxiv.org/abs/2111.00396
Давно мы не писали про RNN, последний раз, кажется, это было про LEM (
https://t.center/gonzo_ML/857). Тогда также хотелось написать про State Space Models (SSM) и их ярких представителей в лице HiPPO и S4, но началась война и стало не до этого. Тема важная, хочется к ней таки вернуться.
Работа про S4 была принята на мою любимую конференцию ICLR в 2022 году и получила там Outstanding Paper Honorable Mentions (
https://blog.iclr.cc/2022/04/20/announcing-the-iclr-2022-outstanding-paper-award-recipients/).
Есть большая проблема с моделированием длинных последовательностей. Длинные, это типа 10к элементов и больше. Ныне рулящие трансформеры по-прежнему испытывают с этим проблемы, в первую очередь благодаря квадратичной сложности механизма внимания (правда, есть куча неквадратичных механизмов, про многие из которых мы писали, например,
https://t.center/gonzo_ML/404 и далее по ссылкам).
На эту тему есть бенчмарк Long Range Arena (LRA,
https://paperswithcode.com/sota/long-range-modeling-on-lra), где рулят в целом не трансформеры (но лидером там всё же теперь свежий скорее-трансформер одноголовый Mega,
https://arxiv.org/abs/2209.10655).
Было много заходов на решение проблемы моделирования реально длинных последовательностей, и один из свежих это как раз state space model (SSM).
SSM описывается уравнениями:
x′(t)=Ax(t)+Bu(t)
y(t)=Cx(t)+Du(t)
где u(t) -- входной сигнал, x(t) -- n-мерное латентное представление, y(t) выходной сигнал, а A,B,C и D -- обучаемые матрицы.
В работе D считается равным нулю, потому что это аналог skip-connection и его легко вычислить.
На практике такая модель работает плохо, потому что решение такой ODE ведёт к экспоненциальной функции со всеми знакомыми по RNN прелестями в виде vanishing/exploding gradients.
Фреймворк HiPPO с NeurIPS 2020 (
https://arxiv.org/abs/2008.07669) предлагает специальный класс матриц A, позволяющих лучше запоминать историю входных данных. Самая важная матрица этого класса, называемая HiPPO Matrix, выглядит так:
/ (2n + 1)^1/2 * (2k + 1)^1/2 if n > k
A_nk = − { n + 1 if n = k
\ 0 if n < k
Замена в SSM случайной матрицы A на эту приводит к улучшению результата на sequential MNIST с 60% to 98%.
Это было непрерывное описание, а чтобы применить его к дискретным входам, SSM надо дискретизировать с шагом Δ и входной сигнал будет сэмплиться с этим шагом. Для этого используется билинейный метод, заменяющий матрицу A на её аппроксимацию
x_k = Ax_k−1 + Bu_k
y_k = Cx_k
A = (I − ∆/2 · A)^−1 * (I + ∆/2 · A)
B = (I − ∆/2 · A)^ −1 * ∆B
C = C.
Уравнения состояния теперь рекуррентная формула с x_k по типу RNN, а сам x_k может рассматриваться как скрытое состояние с матрицей перехода A.
И есть ещё отдельный хак, что такую рекуррентную формулу можно развернуть и заменить на свёртку (с ядром K) размером в длину последовательности, но эту формулу я тут приводить не буду, потому что текстом она выглядит не очень. Такое можно быстро считать на современном железе.
Боттлнек discrete-time SSM в том, что надо многократно умножать на матрицу A.
Текущая работа предлагает расширение и улучшение SSM под названием Structured State Spaces (S4). Предлагается новая параметризация, где матрица A (которая HiPPO) раскладывается в сумму низкорангового и нормального (
https://en.wikipedia.org/wiki/Normal_matrix) терма, Normal Plus Low-Rank (NPLR).
В дополнение к этому применяются ещё несколько техник: truncated SSM generating function + Cauchy kernel + применение Woodbury identity + подсчёт спектра ядра свёртки через truncated generating function + обратное БПФ, за подробностями и доказательствами вэлкам в статью. Доказывается, что у всех HiPPO матриц есть NPLR представление, и что свёрточный фильтр SSM можно вычислить за O(N + L) операций и памяти (и это главный технический контрибьюшн статьи).