Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -75,19 +75,8 @@
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    unsigned indexTypeBitwidth =
-        this->getTypeConverter()->getIndexTypeBitwidth();
-
-    // The corresponding intrinsics expects leadDimension to be a 32-bit
-    // integer, so all the calculations of linearizing the load address
-    // must also follow this restriction.
-    if (indexTypeBitwidth != 32)
-      return rewriter.notifyMatchFailure(
-          op, "Expected indices to the memref to be 32-bit wide.");
     Location loc = op->getLoc();
 
-    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-
     // MemRefDescriptor to extract alignedPtr and offset.
     MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
 
@@ -95,21 +84,21 @@
     // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
     // assumed to be normalized and hence the simple conversion works.
+    IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
     SmallVector<Value> indices(adaptor.indices());
     Value srcOffsetIVal = indices[0];
     Value srcOffsetJVal = indices[1];
-    Type i32Ty = rewriter.getI32Type();
-    Value leadingDim32 =
-        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, srcOffsetIVal.getType(), leadDimension);
     Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
-    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
-                                                    srcOffsetJVal);
+        rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
+    Value loadOffset =
+        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
 
     Value promotedSrcOpToUse;
     promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
-    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
-                                                      promotedSrcOpToUse);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
     Value loadAddress = rewriter.create<LLVM::GEPOp>(
         loc, promotedSrcOp.getElementPtrType(),
         promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -120,7 +109,8 @@
     Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(
-            i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
+            rewriter.getI32Type(),
+            promotedSrcOp.getElementPtrType().getAddressSpace()),
         loadAddress);
 
     // Get the shape of the MMAMatrix type being returned. The shape will
@@ -133,6 +123,8 @@
     StringRef operandStr = retType.getOperand();
 
     // Create nvvm.mma_load op according to the operand types.
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), leadDimension);
     SmallVector<Value> loadOpOperands({loadAddressCasted, leadingDim32});
     if (operandStr.equals("AOp")) {
       if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
@@ -182,40 +174,29 @@
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    unsigned indexTypeBitwidth =
-        this->getTypeConverter()->getIndexTypeBitwidth();
-    // The corresponding intrinsics expects leadDimension to be a 32-bit
-    // integer, so all the calculations of linearizing the store address
-    // must also follow this restriction.
-    if (indexTypeBitwidth != 32)
-      return rewriter.notifyMatchFailure(
-          op, "expected indices to the memref to be 32-bit wide.");
-
     Location loc = op->getLoc();
 
     // MemRefDescriptor to extract alignedPtr and offset.
     MemRefDescriptor promotedDstOp(adaptor.dstMemref());
 
-    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-
     // Emit ops which compute the store offset using `dstOffsetI`,
     // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * dstOffsetI) + dstOffsetJ)).
+    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
     SmallVector<Value> indices(adaptor.indices());
     Value dstOffsetIVal = indices[0];
     Value dstOffsetJVal = indices[1];
-    Type i32Ty = rewriter.getI32Type();
-    Value leadingDim32 =
-        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, dstOffsetIVal.getType(), leadDimension);
     Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
-    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
-                                                    dstOffsetJVal);
+        rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
+    Value loadOffset =
+        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
 
     Value promotedDstOpToUse;
     promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
-    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
-                                                      promotedDstOpToUse);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
     Value storeAddress = rewriter.create<LLVM::GEPOp>(
         loc, promotedDstOp.getElementPtrType(),
         promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -226,7 +207,8 @@
     Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(
-            i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
+            rewriter.getI32Type(),
+            promotedDstOp.getElementPtrType().getAddressSpace()),
         storeAddress);
 
     SmallVector<Value> storeOpOperands;
@@ -245,6 +227,8 @@
           rewriter.getI32ArrayAttr(i));
       storeOpOperands.push_back(toUse);
     }
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), leadDimension);
     storeOpOperands.push_back(leadingDim32);
     // Unpack the results from the source.
     if (srcType.getElementType().isF16()) {
Index: mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -1,26 +1,43 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
 
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_load_op() ->
  // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+  // CHECK32-LABEL: func @gpu_wmma_load_op() ->
   builtin.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
-    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
-    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
-    // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
 }
@@ -31,27 +48,48 @@
 
   // CHECK-LABEL: func @gpu_wmma_store_op
   // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  // CHECK32-LABEL: func @gpu_wmma_store_op
+  // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
   builtin.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
-    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
-    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
     // CHECK: llvm.return
+
+    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK32: llvm.return
     return
   }
 }
@@ -96,9 +134,9 @@
 
 // CHECK-LABEL: func @gpu_wmma_mma_loop_op
 // CHECK: %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
-// CHECK: ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
-// CHECK: llvm.cond_br %38, ^bb2, ^bb3
+// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+// CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
+// CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
 // CHECK: ^bb2: // pred: ^bb1
 // CHECK: %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -123,13 +161,13 @@
 // CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 // CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
 // CHECK: ^bb3: // pred: ^bb1
-// CHECK: %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+// CHECK: %[[E0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]], %{{.*}} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
 builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
   %c0 = arith.constant 0 : index
 
Index: mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
===================================================================
--- mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
 // RUN:   -gpu-kernel-outlining \
-// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
 // RUN:   --convert-scf-to-std -gpu-to-llvm \
 // RUN:   | mlir-cpu-runner \
 // RUN:     --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
Index: mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
===================================================================
--- mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
 // RUN:   -gpu-kernel-outlining \
-// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
 // RUN:   --convert-scf-to-std -gpu-to-llvm \
 // RUN:   | mlir-cpu-runner \
 // RUN:     --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
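Note for reviewers: the snippet below is a minimal, illustrative way to exercise both lowering paths touched by this patch by hand; it simply mirrors the load test above (the file name repro.mlir is hypothetical, and the op, types, and pass options are taken verbatim from the updated test). With the default lowering, the linearized offset arithmetic is emitted on the converted index type (i64), while the leading dimension handed to the NVVM intrinsic is still materialized as a separate i32 constant; with index-bitwidth=32, everything stays on i32 as before.

// repro.mlir (hypothetical file name), mirroring wmma-ops-to-nvvm.mlir
// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
  builtin.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
    %i = arith.constant 16 : index
    %j = arith.constant 16 : index
    // Default path: offset math on i64, intrinsic stride as an i32 constant (CHECK).
    // index-bitwidth=32 path: all of the above on i32 (CHECK32).
    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
  }
}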