GEMM 形态

NT,TN,NN,TT 不同厂商的规则不一样,我们现在是把以下情况称之为 NN,也就是 A 是行优先,B 是列优先的时候。 再具体点就是 A 和 B 都是 K 方向连续的时候,称之为 NN。 当 A 或者 B 的连续方向改变,就变成了 NT, TT, TN。用下表总结一下,我把最后一个维度作为连续的维度。 上图对应的就是 id=0 的记录。 id A B 注释 0 [M, K] [N, K] NN 1 [M, K] [K, N] NT 2 [K, M] [N, K] TN 3 [K, M] [K, N] TT 从数据加载的角度来看,A 和 B 都是 K 方向连续的时候性能是最好的,因为这时可以尽量合并使用 b128 的指令。 而且对于减少 L1 的 set 冲突也是很有效的。 GROUPED-GEMM 用 NN 的数据布局来举例,A 的形状是 [M, K], B 的形状是 [NUM_GROUPS, N, K],结果 C 的形状是 [M, N]。伴随 A 的有一个输入是 m_indicies,它是一个长度为 M 的一维数组,用来指定 A 的每一行使用 B 的哪一个 group 的权重来进行矩阵乘。 ...

June 17, 2025

ThreadBlock-Swizzle 和 Persistent-Kernel

THREADBLOCK-SWIZZLE (GROUP-M, GROUP-N) 下面这张图可以直观地解释 threadblock-swizzle 想要优化的方向。假设一共有 6 个 Compute Unit;左图是 raster along M 同时 swizzle=1,右图是 raster along M 同时 swizzle=2。 看左图,出一波 6 个输出块需要加载 7 个输入块(6A + 1B);而右图,同样出一波 6 个输出块只需要加载 5 个输入快(3A + 2B)。如果不考虑 cache hit 的话,显然右侧对带宽的需求更低,性能更好。 这个也好理解,小学就学过,当面积相等的时候,正方形的周长最短,所以越接近方形的分组,需要加载的输入就越少。 不过在实际情况下,需要考虑 CU 的个数,L2 缓存的大小等因素,很难从第一性原理出发计算出最好的 raster 方向和 swizzle_size,一般是作为参数去 autoune。 比方说假设共有 6 个 CU,那么基本可以按照右图的顺序计算 tiles:第一波先计算 0-5,第二波计算 6-11,以此类推;但是如果有 12 个 CU,那么第一波此就可以计算 0-11,左图和右图就没有区别了。在这种情况下,我们会希望 CU/swizzle_size < M tiles,这样不同组的 tiles 确实会在不同的波次中计算。 PERSISTENT-KERNEL Persistent kernel 指的是每一个 cu 只起一个 threadblock, 这个 tg 一直执行同一个 kernel。比方说共有 4 个 cu,那么共起 4 个 tg,每个 cu 一个单独的 tg。 ...

June 13, 2025

StreamK 小结二

实验数据 性能的总体影响 由于之前的 streamk 的实现,导致了约一半的 GEMM MNK 组合出现了性能的下降;因此,我把 data-parallel,splitk 和 streamk (dp + 1 streamk 的版本) 都加入了 triton 的 autoune 的范围。在四千多个 MNK GEMM 组合上进行了测试,结果如下: 可以看到 99.2% 的组合都是性能提升的,也出现了一些性能下降的。这里需要强调一下,这次的测试除了引入 streamk 之外,也做了一些数据精度上的调整。之前的 GEMM 在 K 方向的累加是用 float16 实现的,这个跟甲方确认过没有问题。但是在实现 streamk 的过程中,因为 k 方向的累加次数可能会很多,另外也是为了使得 GEMM 的精度更高,因此统一都更换到了 float32 进行累加,这个改动会使得性能下降很多(大概比 float16 精度下降约一半)。因此这里的性能的改变并不是单单由于 streamk 导致的,实际上如果单独考量的话,streamk 的提升肯定会更多。 下面这个分布图可以直观地看出来性能提升的分布,可以看到平均数和中位数都大约是 120%-130%,另外在 150%-200% 之间也有分布。最多的提升到原来的 266%,而最差的是原来性能的 46% (因为累加精度的提升)。 性能下降的奇怪 pattern 在第一张图中,注意到一个奇怪的现象,性能下降的 MNK 似乎遵循某种 pattern,排列得非常有规律。下图很明显地体现出来,当 M * N 等于三个数的时候,性能是有较多下降的。这张图中横坐标是 M * N, 纵坐标是性能百分比,虚线表示 100% 的位置。至于原因,我研究了半天也没有发现为什么,怀疑可能跟用 atomic_add 写出 float32 的结果矩阵有关系,但是没有完全解释清楚。 ...

June 6, 2025

StreamK

原理 StreamK 这个算法的目的是将 GEMM 中的 K 方向的 MAC 均分到每一个 Compute Unit,使得所有 CU 可以同时结束,从而减少 idle 的时间。 GEMM 最常见的调度算法是每一个输出的块由一个 TG 负责计算,这个 TG 把 K 方向的累加全部算完了。这个算法以前是没有什么问题的,性能很好。但是现在 CU 性能越来越强大,就使得切块约切越大,对于同一个 MKN,需要用到的 TG 越来越少;而同时 CU 数量越来越多,这就使得很多情况下,CU 没有被打满,浪费了算力。 当然可以减少切块的大小,但是这样会使得 cache 的利用率降低。也有一种 split-k 的算法,将一个输出的块平均分给 N 个 TG 来计算,最后再累加到一起。Split-k 的切分是通过在 K 方向均匀地切分,每个 TG 计算相等的一部分的 K 方向的 MAC。这个算法提高了 CU 的利用率,但是不能完全消灭 CU 的 idle 的时间。Stream-k 的算法就是进阶版的 split-k,K 方向不进行均匀切分,而是按需分配,相切多少切多少,满足的目标就是所有 CU 分配到的 K 方向的 MAC 是完全相等的,这样就完全没有 idle 的浪费。 Xprof 显示 80 个 CU 的耗时基本是平均的。 ...

June 3, 2025

AMDGPU 并行度计算器

传送门

March 4, 2025

在 AMDGPU 上优化 Triton Flash-Attention

最近在 AMDGPU 上优化用 Triton 实现的 Flash-Attention 算子,有一些优化手段值得记录下来。 通过调整 Block 发射顺序减少 SIMD 的 IDLE 时间 FA 的 Triton 实现中,将 Q 在 M 方向切分为了不同的 block。在前向过程中,如果 causal = True,那么 Q 只有左下三角的元素参与计载。即参与计算的元素在 M 方向从上到下逐渐增加。在默认的实现中,block 是从上到下按序发射的,即先发射负载小的块,再发射负载大的块。由于负载较大的块难以被分配到 SIMD 上,因此导致了较大的 SIMD IDLE。通过从倒序从下到上发射块,即先发射负载大的块,再发射负载小的块,由于负载小的块可以被更均衡地分配到各 SIMD 上,因此可以有效减少 SIMD IDLE。 先发射负载小的块,再发射负载大的块,导致较大的 SIMD IDLE 先发射负载大的块,再发射负载小的块,可以减少 SIMD IDLE 通过实现 chain-dot 减少对 LDS 的访存 在我们的硬件规范下,Q 和 K 矩阵乘的结果 QK 的 Layout 跟 Q 是不同的,因此需要先将 QK 的 Layout 转到跟 Q 一样才可以继续与 V 进行矩阵乘(和 Q 一样作为第一个操作数)。可以通过插入一些寄存器指令对线程之间的数据进行交换,以避免通过写入写出 LDS 来进行 Layout 的转换,这些指令(例如 bpermute, swizzle 等)的开销远小于 LDS 访存。 ...

January 26, 2025

用 MUBUF 替换 FA Kernel 中的 GLOBAL 指令引起了概率错误

最近在优化用 triton 写的 flash-attention 的算子的性能,有一个优化是用 MUBUF 指令(buffer_load_dowrdx4,buffer_store_dwordx4 等)替换 GLOBAL 指令(global_store_dwordx4,global_store_dwordx4等),因为通过利用 MUBUF 的 swizzle 的特性可以增加对 L1 Cache 的利用率,同时 MUBUF 指令可以传递一些有用的 cache modifer。 替换完成之后,在 Z 卡上验证精度没有问题,但是在 B 卡上却出现了概率性的精度问题。起初怀疑是在 B 卡上的编译结果有问题,但是直接将 Z 卡上的汇编文件编译到目标为 B 卡的二进制之后,在 B 卡上仍然有精度问题。 仔细对比了 MUBUF 指令替换前后的 LLVM IR,但是看不出来有任何问题。又对比 golden 数据和 triton 的结果,发现错误的数据存在一些规律,但是错误的坐标和数值也存在随机性。于是不再纠结数据是否正确,转而去看是什么原因导致了结果的随机性。有个这个目标后,不断地简化 flash-attention 的实现,直到只剩下几行代码可以稳定地复现随机性。代码简化后,汇编文件也很简单了,直接在汇编文件上修改,同时观察结果。最终发现了出现问题的 pattern: buffer_store_dwordx4 v[0:3], v32, s[4:7], 0 offen v_add_u32_e32 v83, s2, v82 v_lshlrev_b32_e32 v0, 1, v33 第一行的 buffer_store_dwordx4 将 v[0:3] 的数据写出,第三行 v0 被覆写。虽然根据 ISA 文档,这个 pattern 不存在 data harzard,但是如果在 buffer_store 之后插入一个 nop,精度问题就不存在了。 buffer_store_dwordx4 v[0:3], v32, s[4:7], 0 offen s_nop 0 v_add_u32_e32 v83, s2, v82 v_lshlrev_b32_e32 v0, 1, v33 了解到最近固件的更新会导致非常相似的问题,因此我们怀疑是这个问题同样是由于固件更新导致的。本身不是由 MUBUF 指令引起的,应该是正好撞到了这个 pattern。 ...

January 17, 2025

一个关于 AMDGPU Page Fault 的报错排查

背景 最近在做一个项目,其中涉及到构建一个函数,函数的输入是一个内存地址,函数需要解析这个内存地址中的值然后执行相应的动作。已知的是输入的内存地址是 int64 类型的数值,代表的是 amdgpu scratch memory 中的一个地址。 基于项目当前的架构,我的设计是在 opencl 中实现这个函数,然后将这个函数编译到一个 bitcode 文件中,作为一个 bitcode 文件,它可以在项目编译的时候通过 -mlink-builtin-bitcode 链接进来,调用方只需要在模块中正确地声明这个函数就可以。 这个函数是在 runtime 执行的,因为在编译期间不能获取到内存中存的数值。 问题浮现 当我完成函数代码的编写,并且成功编译到可执行文件,看起来一切顺利。但是当我运行可执行文件的时候,意外发生了,出现了 Page Fault 的错误。 原因分析 经过一盘排查之后,发现问题出在加载了非法的指针地址。输入的指针地址指向的是 scratch 地址空间(i8 addrspace(5)),为了可以访问这个地址,我通过 inttoptr 将其转换到了 generic 地址空间的指针(i8),然后再访问这个地址。我用下面的伪代码来表达一下这番操作: define protected void void @foo(i64 %0) { Entry: %1 = inttoptr i64 %0 to i8* %2 = load i8, i8* %1, align 1 ... } 而调用方首先通过 ptrtoint 将 i8 addrspace(5)* 转成了 int64,所以整个流程可以简化为下面的伪代码: define amdgpu_kernel void @bar(...) { Entry: %0 = ptrtoint i8 addrspace(5)* %ptr to i64 %1 = inttoptr i64 %0 to i8* %2 = load i8, i8* %1, align 1 ... } 看起来一切都符合逻辑,除了 scratch 的地址空间被转到了 generic 的地址空间,但是理论上 generic 指针也可以处理 scratch 的指针,只要指针指向的地址没有改变就可以了。但是指针指向的地址确实被改变了! ...

May 9, 2024