diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1141,7 +1141,8 @@ `!gpu.mma_matrix` is the source value containing the data to be stored into the destination memref which can be in global or shared memory. The store address is determined using the indices provided. The `leadDimension` attribute - specifies the leading dimension of the destination matrix. + specifies the leading dimension of the destination matrix. If the + `transpose` attribute is present then the op does a transposed store. This op is often meant to be used along with `gpu.subgroup_mma_load_matrix` and `gpu.subgroup_mma_compute`. @@ -1157,7 +1158,8 @@ let arguments = (ins Arg>:$src, Arg:$dstMemref, Variadic:$indices, - IndexAttr:$leadDimension); + IndexAttr:$leadDimension, + OptionalAttr:$transpose); let assemblyFormat = [{ $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref) @@ -1165,8 +1167,8 @@ let hasVerifier = 1; } -def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", - [Pure, AllTypesMatch<["opC", "res"]>]>{ +def GPU_SubgroupMmaComputeOp + : GPU_Op<"subgroup_mma_compute", [Pure, AllTypesMatch<["opC", "res"]>]> { let summary = "GPU warp synchronous matrix multiply accumulate"; @@ -1175,9 +1177,14 @@ operation using all the threads in a subgroup. This operation takes three `!gpu.mma_matrix`s as arguments: these hold `A`, - `B` and `C`operands for the mma operation. The operation performed is represented + `B` and `C`operands for the mma operation. The operation performed is represented as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of - the operation held by all threads in a subgroup. + the operation held by all threads in a subgroup. `a_transpose` or + `b_transpose` if present, signify that the respective operand was loaded in a + transposed manner. 
The transpose attributes are required to map to correct + underlying intrinsics but they currently do not seem to affect correctness + even if they are absent given that the operands were loaded correctly using + the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op. This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and `gpu.subgroup_mma_load_matrix` ops. @@ -1193,9 +1200,11 @@ let arguments = (ins Arg>:$opA, Arg>:$opB, - Arg>:$opC); + Arg>:$opC, + OptionalAttr:$a_transpose, + OptionalAttr:$b_transpose); - let results = (outs GPU_MMAMatrix:$res); + let results = (outs GPU_MMAMatrix : $res); let assemblyFormat = [{ $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res) @@ -1215,11 +1224,11 @@ The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with constant elements. - The operation takes a scalar input and return a `!gpu.mma_matrix` where each - element of is equal to the operand constant. The destination mma_matrix type - must have elememt type equal to the constant type. Since the layout of - `!gpu.mma_matrix` is opaque this only support setting all the elements to - the same value. + The operation takes a scalar input and returns a `!gpu.mma_matrix` where + each element is equal to the operand constant. The destination + mma_matrix type must have element type equal to the constant type. Since + the layout of `!gpu.mma_matrix` is opaque this only supports setting all the + elements to the same value. This op is meant to be used along with `gpu.subgroup_mma_compute`. diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -77,12 +77,11 @@ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); - // TODO: Support transposed mma loads. 
- if (subgroupMmaLoadMatrixOp.getTranspose()) - return failure(); - // Get the shape of the MMAMatrix type being returned. The shape will // choose which intrinsic this op will be lowered to. + NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose() + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; gpu::MMAMatrixType retType = subgroupMmaLoadMatrixOp.getRes().getType().cast(); ArrayRef retTypeShape = retType.getShape(); @@ -105,7 +104,6 @@ n = retTypeShape[1]; k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype); } - NVVM::MMALayout layout = NVVM::MMALayout::row; NVVM::MMAFrag frag = convertOperand(retType.getOperand()); // Check that there is an exisiting instruction for the combination we need. if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) @@ -154,7 +152,9 @@ gpu::MMAMatrixType srcType = subgroupMmaStoreMatrixOp.getSrc().getType().cast(); ArrayRef srcTypeShape = srcType.getShape(); - NVVM::MMALayout layout = NVVM::MMALayout::row; + NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose() + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; NVVM::MMATypes eltype = getElementType(srcType); int64_t m = srcTypeShape[0]; int64_t n = srcTypeShape[1]; @@ -224,10 +224,15 @@ int64_t m = cTypeShape[0]; int64_t n = cTypeShape[1]; int64_t k = aTypeShape[1]; - NVVM::MMALayout layout = NVVM::MMALayout::row; + NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose() + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; + NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose() + ? 
NVVM::MMALayout::col + : NVVM::MMALayout::row; NVVM::MMATypes sourceType = getElementType(aType); NVVM::MMATypes destType = getElementType(cType); - if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType, + if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType, destType) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); @@ -236,7 +241,7 @@ unpackOp(adaptor.getOpC()); rewriter.replaceOpWithNewOp( - op, adaptor.getOpC().getType(), m, n, k, layout, layout, sourceType, + op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType, destType, unpackedOps); return success(); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -87,10 +87,9 @@ auto i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); - bool useColMajor = - static_cast(subgroupMmaLoadMatrixOp.getTranspose()); + bool isColMajor = static_cast(subgroupMmaLoadMatrixOp.getTranspose()); auto columnMajor = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor)); + loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor)); rewriter.replaceOpWithNewOp( subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor, spirv::MemoryAccessAttr()); @@ -118,11 +117,13 @@ auto i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); - auto coloumnMajor = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + bool useColMajor = + static_cast(subgroupMmaStoreMatrixOp.getTranspose()); + auto columnMajor = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor)); rewriter.replaceOpWithNewOp( subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue, - coloumnMajor, 
spirv::MemoryAccessAttr()); + columnMajor, spirv::MemoryAccessAttr()); return success(); } }; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -473,9 +473,9 @@ assert(stride); OpBuilder b(op); Value matrix = valueMapping.find(op.getVector())->second; - b.create(op.getLoc(), matrix, op.getSource(), - op.getIndices(), - b.getIndexAttr(*stride)); + b.create( + op.getLoc(), matrix, op.getSource(), op.getIndices(), + b.getIndexAttr(*stride), /*transpose=*/UnitAttr()); op.erase(); } @@ -800,8 +800,9 @@ Value opA = valueMapping.find(op.getLhs())->second; Value opB = valueMapping.find(op.getRhs())->second; Value opC = valueMapping.find(op.getAcc())->second; - Value matmul = b.create(op.getLoc(), opC.getType(), - opA, opB, opC); + Value matmul = b.create( + op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), + /*b_transpose=*/UnitAttr()); valueMapping[op.getResult()] = matmul; } diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -10,7 +10,7 @@ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index %j = arith.constant 16 : index - %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -20,7 
+20,7 @@ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] - // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 @@ -32,7 +32,7 @@ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] - // CHECK32-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + // CHECK32-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, 
vector<2xf16>, vector<2xf16>, vector<2xf16>)> return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> } @@ -50,7 +50,7 @@ %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index %j = arith.constant 16 : index - gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3> + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3> // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> @@ -64,7 +64,7 @@ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]] - // CHECK-SAME: {eltype = #nvvm.mma_type, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> + // CHECK-SAME: {eltype = #nvvm.mma_type, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> // CHECK: llvm.return // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 @@ -80,7 +80,7 @@ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]] - // CHECK32-SAME: {eltype = #nvvm.mma_type, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> + // CHECK32-SAME: {eltype = 
#nvvm.mma_type, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> // CHECK32: llvm.return return } @@ -93,7 +93,7 @@ // CHECK-LABEL: func @gpu_wmma_mma_op // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) { - %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> @@ -115,7 +115,7 @@ // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: 
%[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] - // CHECK-SAME: {eltypeA = #nvvm.mma_type, eltypeB = #nvvm.mma_type, k = 16 : i32, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : ( + // CHECK-SAME: {eltypeA = #nvvm.mma_type, eltypeB = #nvvm.mma_type, k = 16 : i32, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : ( // CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> return %D : !gpu.mma_matrix<16x16xf16, "COp"> diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir @@ -12,7 +12,8 @@ attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { %i = arith.constant 16 : index %j = arith.constant 16 : index - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -22,6 +23,29 @@ // ----- +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func 
@gpu_wmma_load_op_transpose + // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi + gpu.func @gpu_wmma_load_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> + %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + module attributes { gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { @@ -35,7 +59,8 @@ attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { %i = arith.constant 16 : index %j = arith.constant 16 : index - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> // CHECK: spirv.Return gpu.return @@ -45,6 +70,30 @@ // ----- +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Logical GLSL450 { + // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose + // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> 
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} + // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) + // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi + gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> + gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + module attributes { gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { @@ -107,4 +156,4 @@ gpu.return } } -} \ No newline at end of file +} diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir --- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir @@ -22,10 +22,12 @@ %c32 = arith.constant 32 : index %c1 = arith.constant 1 : index - // Intialize the Input matrix with ones. + // Intialize the Input matrix with the column index in each row. scf.for %arg0 = %c0 to %c16 step %c1 { scf.for %arg1 = %c0 to %c16 step %c1 { - memref.store %f1, %0[%arg0, %arg1] : memref<16x16xf16> + %2 = arith.index_cast %arg1 : index to i16 + %3 = arith.sitofp %2 : i16 to f16 + memref.store %3, %0[%arg0, %arg1] : memref<16x16xf16> } } // Intialize the accumulator matrix with zeros. 
@@ -43,11 +45,11 @@ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) { - %A = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> + %A = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> %B = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> %C = gpu.subgroup_mma_load_matrix %22[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> - %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + %R = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> gpu.subgroup_mma_store_matrix %R, %0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16> gpu.terminator @@ -64,22 +66,22 @@ // Print the memref after computation. 
call @printMemrefF32(%3) : (memref<*xf32>) -> () - // CHECK: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], - // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16] + // CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + // CHECK-NEXT: [0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240], + // CHECK-NEXT: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480], + // CHECK-NEXT: [0, 48, 96, 144, 192, 240, 288, 336, 384, 432, 480, 528, 576, 624, 672, 720], + // CHECK-NEXT: [0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960], + // CHECK-NEXT: [0, 80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120, 1200], + // CHECK-NEXT: [0, 96, 192, 288, 384, 480, 576, 672, 768, 864, 960, 1056, 1152, 1248, 
1344, 1440], + // CHECK-NEXT: [0, 112, 224, 336, 448, 560, 672, 784, 896, 1008, 1120, 1232, 1344, 1456, 1568, 1680], + // CHECK-NEXT: [0, 128, 256, 384, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920], + // CHECK-NEXT: [0, 144, 288, 432, 576, 720, 864, 1008, 1152, 1296, 1440, 1584, 1728, 1872, 2016, 2160], + // CHECK-NEXT: [0, 160, 320, 480, 640, 800, 960, 1120, 1280, 1440, 1600, 1760, 1920, 2080, 2240, 2400], + // CHECK-NEXT: [0, 176, 352, 528, 704, 880, 1056, 1232, 1408, 1584, 1760, 1936, 2112, 2288, 2464, 2640], + // CHECK-NEXT: [0, 192, 384, 576, 768, 960, 1152, 1344, 1536, 1728, 1920, 2112, 2304, 2496, 2688, 2880], + // CHECK-NEXT: [0, 208, 416, 624, 832, 1040, 1248, 1456, 1664, 1872, 2080, 2288, 2496, 2704, 2912, 3120], + // CHECK-NEXT: [0, 224, 448, 672, 896, 1120, 1344, 1568, 1792, 2016, 2240, 2464, 2688, 2912, 3136, 3360], + // CHECK-NEXT: [0, 240, 480, 720, 960, 1200, 1440, 1680, 1920, 2160, 2400, 2640, 2880, 3120, 3360, 3600]] return }