DS V4 sparse attention

DSA

使用了 Indexer Projection 的 trick,其中

  • 生成低维度的新 Q/V,并用此做 cross-attention 进行评分计算
  • 然后做 top-k 得到最重要的 k 个 kv
  • 然后每个 query 只选择自己最重要的 kv

CSA: Compressed sparse attention

相比于 DSA 更加压缩,其中

  • 首先先将序列切成 n/m 个 block,block 内部先 merge
  • 然后对这些 block 做 DSA,得到 \(K_1,V_1\)
  • 然后再对于没有 merge 过的序列取小滑动窗口得到 \(K_2,V_2\)
  • \(K_1\cup K_2,V_1\cup V_2\),然后过 attention block.

HCA: Heavily Compressed attention

将序列压缩程度更大地切成 block,每个 block 内部 merge 之后做 dense attention

MQA: Multi-query Attention

  • 传统 multi-head attention 每个 attention 都要缓存自己的 QKV,容易存不下。
  • 于是采用只有 query 多头,别的都单头的机制。

Grouped Output Projection

  • 注意到单个头的维度上升之后,会导致最后的 output projection 变大。
  • 于是将头们划分成若干个组,每个组内部 concat 后投影降维度,然后最后把这些 low-rank 的东西再拼起来然后做 up-projection

SpargeAttention (ICML2025)

和 DS 的 sparse attention 不同,是一个 plug-in 的东西. 基于 flash-attention 的分块.

Sparse block online prediction

对于 flash-attention 划分出来的一个块,计算其 Q,K 的内部平均 cosine similarity,然后通过一个阈值设定是否是相似块(selective / fix)

核心思想是,对 fix block 我们认真地计算 attention,对于 fix block,我们可以把最稀疏的东西略掉.

如何判断是否稀疏?

  • 首先每个块 mean pooling,然后计算内积后,得到对 attention 的估计 \(p\)
  • 将含有 fix block 的 pair 给强制设成 \(m_{i,j}=1\)\(p_{i,j}=-\infty\);然后做 softmax 后,排序后,通过 Top-CDF 截断来将一些比较重要的 \(m_{i,j}\) 设为 \(1\).

然后 \(m=0\) 的就直接略去了.

Sparse online softmax

在使用 flashattn 的 online softmax 的时候,注意到

  • 如果局部的 max 远小于全局历史的 max,那么这一项在 softmax 中就会占比极小可以忽略,就可以跳过和 value 的乘法
  • 然后就能少掉很多次 softmax 后和 value 的乘法

Hilbercurve permuation

注意到如果对于图片/视频,随意分块容易导致有很多 fix block

于是论文采用 Hilbert Curve 来遍历 \(T\times W\times H\) 的视觉 token,这样能够得提升 self-similarity.

实验结果

在图像生成上远超了其他的一些方法;视频生成模型带来了 1.8x 加速.

在长序列中更加使用,并且 fix block 在计算是必要.

Xattention (ICML2025)

这个也是一个 plug-in framework. 注意到 block sparse 为了计算 importance 会多计算不少东西,有一点亏.

考虑到对于任何的 sparse pattern,其一定与反对角线相交. 于是文章提出了一个新的计算 importance 的方法:对于一个 block,提取其反对角线上的元素,用其来计算 importance weight.

Threshold Block Selection

  • 对于块编号 \(i,j\),定义 \(A'_{i,j}\) 为原本的 \(A=QK^T\) 的反对角线按着一个 stride=S 求和.
  • 求这个很简单:直接提取 \(Q\) 的正序 stride=S 然后 \(K\) 的逆序 stride=S 即可.
  • 然后做 Top-CDF 截断获得需要计算的块.

Minimum Threshold Prediction

  • 注意到 Top-CDF 需要一个 threshold,而不同 head 的重要程度不尽相同,所需的 threshold 也不尽相同. 文章提出用 DP 来计算 threshold.

  • 定义一次调整为将一个头的 threshold 下降 10%. 然后令 \(DP[h][m]\) 表示只调整了前 h 个头,并且总共调整次数为 m 的模型最佳表现. 转移是 trivial 的.

    这样就能最后得到 accuracy 和 efficiency 的 curve.

实验结果

考虑 256K 上下文中,S=16 时注意力本身达到了 13x 加速,S=8 时达到 9x 加速.

由于反对角线很容易计算,所 以比 FlexPrefill 什么的有 5.9x 加速.