Add support for loading, computing, and storing gpu.subgroup WMMA ops
in transpose mode. Update the GPU to NVVM lowerings to support
transpose mode, and update the integration tests accordingly.
Repository: rG LLVM Github Monorepo
Event Timeline
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | ||
---|---|---|
1162 | Please add some documentation (or move it from below). | |
1182–1188 | Do you understand why the compute operation has to know that? The transposition should happen during the load, so I don't understand why we need to know this at that point. It looks like you don't set it in your test and it works fine. | |
1231–1232 | This doc seems to be at the wrong spot? | |
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir | ||
13 | We should also test the non-transposed case. Is it tested somewhere? |
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | ||
---|---|---|
157 | Please also change the lowering to SPIR-V to avoid a miscompile (it's fine to fail the lowering pattern, but ignoring the new field is wrong). |
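
To illustrate the distinction drawn in this comment (failing the pattern versus silently ignoring the new field), here is a minimal, self-contained C++ sketch. `LoadMatrixOp`, `SpirvLoad`, and `lowerLoadToSpirv` are hypothetical stand-ins, not the MLIR or SPIR-V conversion API, and the sketch assumes the target lowering has no way to express a transposed load.

```cpp
#include <optional>

// Hypothetical stand-in for a GPU-dialect load op carrying the new attribute.
struct LoadMatrixOp {
  int leadDimension;
  bool transpose; // the newly added field
};

// Hypothetical stand-in for the lowered target op.
struct SpirvLoad {
  int stride;
};

// Returns std::nullopt (i.e. the pattern fails to match) when the op requests
// a transpose the target cannot express, instead of dropping the attribute
// and producing a load with the wrong semantics.
std::optional<SpirvLoad> lowerLoadToSpirv(const LoadMatrixOp &op) {
  if (op.transpose)
    return std::nullopt;
  return SpirvLoad{op.leadDimension};
}
```

A failed pattern leaves the op in place, so the problem shows up as a conversion error rather than as wrong results at run time.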
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | ||
---|---|---|
1182–1188 | I didn't understand this question. Please see the NVVM lowering below. Why would the LLVM/NVVM builder take aLayout and bLayout in that case? Before: `rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(op, adaptor.getOpC().getType(), m, n, k, layout, layout, sourceType, destType, unpackedOps);` After: `rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType, destType, unpackedOps);` It looks like the mma intrinsic allows you to multiply elements in the transposed order and this is orthogonal to the load. So, you can load in whatever way, but do the multiplication on the transpose of the values if desired? | |
1228–1232 | Doc added at the wrong place. | |
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir | ||
96 | As a convention, attributes like these should use snake case: a_transpose? | |
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir | ||
30 | How is this change related/needed? The comment above would now be inaccurate. |
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | ||
---|---|---|
1182–1188 | From the PTX spec, this doesn't sound like what it is: "The qualifiers .alayout and .blayout must match the layout specified on the wmma.load instructions that produce the contents of operands a and b respectively. Similarly, the qualifiers .atype, .btype and .ctype must match the corresponding qualifiers on the wmma.load instructions that produce the contents of operands a, b and c respectively." I'm guessing this is a way to allow a more flexible implementation of wmma, but it makes the representation awkward. Interestingly enough, the SPIR-V equivalent doesn't have this restriction. This is why I wanted to understand whether we can simplify the design. |
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | ||
---|---|---|
1182–1188 | Shouldn't the LLVM (and NVVM) dialect mimic 1:1 the underlying intrinsics? If the underlying intrinsic expects it and requires it to be matched, the MLIR counterpart shouldn't remove it but expose the same requirements on the attributes. This should also be documented. |
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | ||
---|---|---|
1182–1188 | The LLVM and NVVM dialects should mimic the intrinsics 1:1, but the GPU dialect can try to abstract those API-specific details if possible. Anyway, I don't have a good solution, so it is fine to leave as is if nobody has a better idea; however, the current solution feels very error-prone. |
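
The matching requirement quoted from the PTX spec in this thread could be checked along the following lines. This is only a sketch under stated assumptions: `Layout`, `LoadOp`, `MmaOp`, and `layoutsMatch` are hypothetical names made up for illustration, not part of the GPU or NVVM dialects, and the patch under review does not necessarily include such a check.

```cpp
// Hypothetical layout tag mirroring the row/column qualifiers.
enum class Layout { Row, Col };

// Hypothetical stand-ins for the load and compute ops in the discussion.
struct LoadOp {
  Layout layout; // layout used when the fragment was loaded
};

struct MmaOp {
  Layout aLayout; // layout the compute op assumes for operand a
  Layout bLayout; // layout the compute op assumes for operand b
};

// True only when the compute op's layouts agree with the loads that produced
// operands a and b, which is the constraint quoted from the PTX spec above.
bool layoutsMatch(const MmaOp &mma, const LoadOp &loadA, const LoadOp &loadB) {
  return mma.aLayout == loadA.layout && mma.bLayout == loadB.layout;
}
```

A check of this kind would address the "error-prone" concern by rejecting mismatched attributes early, at the cost of keeping the duplicated attribute on the compute op.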
Several minor improvements needed on the readability side.
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | ||
---|---|---|
80–82 | Please use a ternary operator - more readable and compact. | |
157 | Likewise. | |
228–233 | Likewise. | |
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | ||
121 | useColMajor -> isColumnMajor to be consistent with the one below. | |
123 | You can fix the typo in the variable name while you're at it. | |
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | ||
478 | Use `/*arg=..*/` style for the last one. | |
804 | Likewise. | |
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir | ||
25 | with the column index |
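
The two readability suggestions in this round (a ternary instead of an if/else assignment, and `/*arg=..*/` comments on literal arguments) can be sketched as follows. All names here (`Layout`, `layoutVerbose`, `layoutCompact`, `describeLoad`, `example`) are hypothetical and chosen for illustration; they are not taken from the patch.

```cpp
#include <string>

// Hypothetical layout tag, used only for this illustration.
enum class Layout { Row, Col };

// Verbose form: declare the variable, then assign it in an if/else.
Layout layoutVerbose(bool transpose) {
  Layout layout;
  if (transpose)
    layout = Layout::Col;
  else
    layout = Layout::Row;
  return layout;
}

// Compact form: the same selection written as a ternary expression.
Layout layoutCompact(bool transpose) {
  return transpose ? Layout::Col : Layout::Row;
}

// Hypothetical helper that takes a bare integer and a bare boolean.
std::string describeLoad(int leadDimension, bool transpose) {
  return std::to_string(leadDimension) + (transpose ? ":col" : ":row");
}

// Argument comments make the meaning of literal arguments visible at the
// call site.
std::string example() {
  return describeLoad(/*leadDimension=*/16, /*transpose=*/true);
}
```

Both forms of the layout selection behave identically; the ternary only removes the uninitialized declaration and the extra branches.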