diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -81,8 +81,10 @@
     return NVVM::MMATypes::f64;
   if (operandElType.isF16() || operandElType == half2Type)
     return NVVM::MMATypes::f16;
-  if (operandElType.isF32())
+  if (operandElType.isF32() && isAccumulator)
     return NVVM::MMATypes::f32;
+  if (operandElType.isF32() && !isAccumulator)
+    return NVVM::MMATypes::tf32;
   if (operandElType.isa<IntegerType>()) {
     if (isAccumulator)
       return NVVM::MMATypes::s32;
@@ -291,7 +293,7 @@
         parser.getNameLoc(),
         "expected one type for each operand segment but got " +
             Twine(operandTypes.size()) + " types");
-  for (const auto& iter : llvm::enumerate(operandTypes)) {
+  for (const auto &iter : llvm::enumerate(operandTypes)) {
     auto &frag = frags[iter.index()];
     frag.regTypes.resize(frag.regs.size(), iter.value());
     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
@@ -376,8 +378,9 @@
   switch (multiplicandAPtxType().getValue()) {
   case MMATypes::tf32:
     kFactor = 4;
+    multiplicandFragType = i32Ty;
     expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
-        context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+        context, {f32Ty, f32Ty, f32Ty, f32Ty}));
     break;
   case MMATypes::f16:
   case MMATypes::bf16:
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -152,6 +152,17 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }

+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+      multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+      shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -203,6 +203,17 @@
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }

+llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+      multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+      shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
 // CHECK-LABEL: @gpu_wmma_load_op