diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -991,7 +991,7 @@
     ```
   }];
 
-  let arguments = (ins Arg, "", [MemRead]>:$srcMemref,
+  let arguments = (ins Arg, "", [MemRead]>:$srcMemref,
                        Variadic:$indices,
                        IndexAttr:$leadDimension);
 
@@ -1031,7 +1031,7 @@
   }];
 
   let arguments = (ins Arg>:$src,
-                       Arg, "",[MemWrite]>:$dstMemref,
+                       Arg, "",[MemWrite]>:$dstMemref,
                        Variadic:$indices,
                        IndexAttr:$leadDimension);
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,44 +77,6 @@
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    Location loc = op->getLoc();
-
-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
-
-    // Emit ops which compute the load offset using `srcOffsetI`,
-    // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
-    // assumed to be normalized and hence the simple conversion works.
-    IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-    SmallVector indices(adaptor.indices());
-    Value srcOffsetIVal = indices[0];
-    Value srcOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, srcOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
-
-    Value promotedSrcOpToUse;
-    promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
-    Value loadAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedSrcOp.getElementPtrType(),
-        promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedSrcOp.getElementPtrType().getAddressSpace()),
-        loadAddress);
-
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =
@@ -146,15 +108,18 @@
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
 
     Type resType = convertMMAToLLVMType(retType);
+    Location loc = op->getLoc();
 
     // Create nvvm.mma_load op according to the operand types.
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
+        adaptor.srcMemref(), adaptor.indices(), rewriter);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaLoadMatrixOp.leadDimensionAttr());
 
     rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
-        op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
-        frag);
-
+        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
     return success();
   }
 };
@@ -178,41 +143,6 @@
 
     Location loc = op->getLoc();
 
-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedDstOp(adaptor.dstMemref());
-
-    // Emit ops which compute the store offset using `dstOffsetI`,
-    // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * dstOffsetI) + dstOffsetJ)).
-    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-    SmallVector indices(adaptor.indices());
-    Value dstOffsetIVal = indices[0];
-    Value dstOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, dstOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
-
-    Value promotedDstOpToUse;
-    promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
-    Value storeAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedDstOp.getElementPtrType(),
-        promotedDstOp.alignedPtr(rewriter, loc), ArrayRef{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedDstOp.getElementPtrType().getAddressSpace()),
-        storeAddress);
-
     SmallVector storeOpOperands;
     // Get the shape of the MMAMatrix type being stored. The shape will
     // choose which intrinsic this op will be lowered to.
@@ -234,12 +164,15 @@
                                                   rewriter.getI32ArrayAttr(i));
       storeOpOperands.push_back(toUse);
     }
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
-    rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
-                                       eltype, storeOpOperands, leadingDim32);
-    rewriter.eraseOp(op);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
+        adaptor.dstMemref(), adaptor.indices(), rewriter);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaStoreMatrixOp.leadDimensionAttr());
+    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
+        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -13,32 +13,26 @@
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
     // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i64) -> !llvm.ptr
-    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr
     // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 
     // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
     // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
     // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
-    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr
-    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr
     // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
@@ -59,40 +53,34 @@
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
-    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i64) -> !llvm.ptr
-    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr
     // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr
     // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK: llvm.return
 
     // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
-    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
-    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr
-    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr
     // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr
     // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK32: llvm.return
     return
   }
@@ -139,13 +127,13 @@
 gpu.module @test_module {
 
 // CHECK-LABEL: func @gpu_wmma_mma_loop_op
-// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
 // CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
 // CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
 // CHECK: ^bb2: // pred: ^bb1
-// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -173,7 +161,7 @@
 // CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
 
 builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
   %c0 = arith.constant 0 : index