在 AMDGPU 上优化 Triton Flash-Attention

最近在 AMDGPU 上优化用 Triton 实现的 Flash-Attention 算子,有一些优化手段值得记录下来。 通过调整 Block 发射顺序减少 SIMD 的 IDLE 时间 FA 的 Triton 实现中,将 Q 在 M 方向切分为了不同的 block。在前向过程中,如果 causal = True,那么 Q 只有左下三角的元素参与计载。即参与计算的元素在 M 方向从上到下逐渐增加。在默认的实现中,block 是从上到下按序发射的,即先发射负载小的块,再发射负载大的块。由于负载较大的块难以被分配到 SIMD 上,因此导致了较大的 SIMD IDLE。通过从倒序从下到上发射块,即先发射负载大的块,再发射负载小的块,由于负载小的块可以被更均衡地分配到各 SIMD 上,因此可以有效减少 SIMD IDLE。 先发射负载小的块,再发射负载大的块,导致较大的 SIMD IDLE 先发射负载大的块,再发射负载小的块,可以减少 SIMD IDLE 通过实现 chain-dot 减少对 LDS 的访存 在我们的硬件规范下,Q 和 K 矩阵乘的结果 QK 的 Layout 跟 Q 是不同的,因此需要先将 QK 的 Layout 转到跟 Q 一样才可以继续与 V 进行矩阵乘(和 Q 一样作为第一个操作数)。可以通过插入一些寄存器指令对线程之间的数据进行交换,以避免通过写入写出 LDS 来进行 Layout 的转换,这些指令(例如 bpermute, swizzle 等)的开销远小于 LDS 访存。 ...

January 26, 2025

用 MUBUF 替换 FA Kernel 中的 GLOBAL 指令引起了概率错误

最近在优化用 triton 写的 flash-attention 的算子的性能,有一个优化是用 MUBUF 指令(buffer_load_dowrdx4,buffer_store_dwordx4 等)替换 GLOBAL 指令(global_store_dwordx4,global_store_dwordx4等),因为通过利用 MUBUF 的 swizzle 的特性可以增加对 L1 Cache 的利用率,同时 MUBUF 指令可以传递一些有用的 cache modifer。 替换完成之后,在 Z 卡上验证精度没有问题,但是在 B 卡上却出现了概率性的精度问题。起初怀疑是在 B 卡上的编译结果有问题,但是直接将 Z 卡上的汇编文件编译到目标为 B 卡的二进制之后,在 B 卡上仍然有精度问题。 仔细对比了 MUBUF 指令替换前后的 LLVM IR,但是看不出来有任何问题。又对比 golden 数据和 triton 的结果,发现错误的数据存在一些规律,但是错误的坐标和数值也存在随机性。于是不再纠结数据是否正确,转而去看是什么原因导致了结果的随机性。有个这个目标后,不断地简化 flash-attention 的实现,直到只剩下几行代码可以稳定地复现随机性。代码简化后,汇编文件也很简单了,直接在汇编文件上修改,同时观察结果。最终发现了出现问题的 pattern: buffer_store_dwordx4 v[0:3], v32, s[4:7], 0 offen v_add_u32_e32 v83, s2, v82 v_lshlrev_b32_e32 v0, 1, v33 第一行的 buffer_store_dwordx4 将 v[0:3] 的数据写出,第三行 v0 被覆写。虽然根据 ISA 文档,这个 pattern 不存在 data harzard,但是如果在 buffer_store 之后插入一个 nop,精度问题就不存在了。 buffer_store_dwordx4 v[0:3], v32, s[4:7], 0 offen s_nop 0 v_add_u32_e32 v83, s2, v82 v_lshlrev_b32_e32 v0, 1, v33 了解到最近固件的更新会导致非常相似的问题,因此我们怀疑是这个问题同样是由于固件更新导致的。本身不是由 MUBUF 指令引起的,应该是正好撞到了这个 pattern。 ...

January 17, 2025