diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -168,7 +168,7 @@
     /*desc=*/[{Generate constant value.}],
     /*retType=*/"::mlir::Value",
     /*methodName=*/"makeConstantI32",
-    /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "unsigned" : $val),
+    /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
     /*methodBody=*/"",
     /*defaultImpl=*/ [{
       mlir::Operation* op = $_op;
@@ -1473,6 +1473,118 @@
   }];
 }
 
+/// Enum attribute type for negating the input operands
+def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>;
+def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA input scale options",
+  [WGMMAScaleInOne, WGMMAScaleInNeg]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute type for scaling the output operand
+def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>;
+def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA output scale options",
+  [WGMMAScaleOutZero, WGMMAScaleOutOne]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async",
+    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+  let results = (outs LLVM_AnyAggregate:$results);
+  let arguments = (ins
+    LLVM_AnyAggregate:$inouts,
+    I64:$descriptorA,
+    I64:$descriptorB,
+    NVVM_MMAShapeAttr:$shape,
+    MMATypesAttr:$typeA,
+    MMATypesAttr:$typeB,
+    WGMMAScaleOutAttr:$scaleD,
+    WGMMAScaleInAttr:$scaleA,
+    WGMMAScaleInAttr:$scaleB,
+    MMALayoutAttr:$layoutA,
+    MMALayoutAttr:$layoutB,
+    OptionalAttr<MMAIntOverflowAttr>:$satfinite
+  );
+
+  let assemblyFormat = [{
+    $descriptorA `,` $descriptorB `,` $shape `,`
+    `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+    `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
+    `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
+    attr-dict `:`
+    type($inouts) `->` type($results)
+  }];
+
+  let description = [{
+    The warpgroup-level (128 threads) matrix multiply-and-accumulate operation
+    has either of the following forms, where matrix D is called the accumulator:
+      D = A * B + D
+      D = A * B, where the input from the accumulator D is disabled.
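+
+    A minimal usage sketch is shown below; the shape, element types, scales,
+    and layouts are just one supported combination, `%descA`/`%descB` are the
+    matrix descriptors, and `%acc0` is the accumulator struct:
+    ```
+    %acc1 = nvvm.wgmma.mma_async %descA, %descB,
+        #nvvm.shape<m = 64, n = 8, k = 16>,
+        D [%acc0, #nvvm.wgmma_scale_out<one>],
+        A [<f16>, #nvvm.wgmma_scale_in<one>, <row>],
+        B [<f16>, #nvvm.wgmma_scale_in<one>, <row>]
+        : !llvm.struct<(f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32)>
+    ```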
+ + Supported shapes: + ``` + |-------------------|--------------------|----------------|---------------| + | | f16 += f16 * f16 | s32 += s8 * s8 | | + |f32 += tf32 * tf32 | f32 += f16 * f16 | s32 += s8 * u8 |s32 += b1 * b1 | + | | f32 += bf16 * bf16 | s32 += u8 * u8 | | + |-------------------|--------------------|----------------|---------------| + | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 | + | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | + | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | + | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | + | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | + | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | + | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | + | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | + | .m64n72k8 | .m64n72k16 | .m64n112k32 | .m64n112k256 | + | .m64n80k8 | .m64n80k16 | .m64n128k32 | .m64n128k256 | + | .m64n88k8 | .m64n88k16 | .m64n144k32 | .m64n144k256 | + | .m64n96k8 | .m64n96k16 | .m64n160k32 | .m64n160k256 | + | .m64n104k8 | .m64n104k16 | .m64n176k32 | .m64n176k256 | + | .m64n112k8 | .m64n112k16 | .m64n192k32 | .m64n192k256 | + | .m64n120k8 | .m64n120k16 | .m64n208k32 | .m64n208k256 | + | .m64n128k8 | .m64n128k16 | .m64n224k32 | .m64n224k256 | + | .m64n136k8 | .m64n136k16 | .m64n240k32 | .m64n240k256 | + | .m64n144k8 | .m64n144k16 | .m64n256k32 | .m64n256k256 | + | .m64n152k8 | .m64n152k16 | | | + | .m64n160k8 | .m64n160k16 | | | + | .m64n168k8 | .m64n168k16 | | | + | .m64n176k8 | .m64n176k16 | | | + | .m64n184k8 | .m64n184k16 | | | + | .m64n192k8 | .m64n192k16 | | | + | .m64n200k8 | .m64n200k16 | | | + | .m64n208k8 | .m64n208k16 | | | + | .m64n216k8 | .m64n216k16 | | | + | .m64n224k8 | .m64n224k16 | | | + | .m64n232k8 | .m64n232k16 | | | + | .m64n240k8 | .m64n240k16 | | | + | .m64n248k8 | .m64n248k16 | | | + | .m64n256k8 | .m64n256k16 | | | + |-------------------|--------------------|----------------|---------------| + ``` + + See for more information: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + void getAsmValues(RewriterBase &rewriter, + llvm::SmallVectorImpl> &asmValues); + }]; +} + //===----------------------------------------------------------------------===// // NVVM target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -53,7 +53,7 @@ namespace { class PtxBuilder { - Operation *op; + NVVM::BasicPtxBuilderInterface op; PatternRewriter &rewriter; std::string asmStr; SmallVector asmVals; @@ -62,30 +62,35 @@ bool hasResult = false; // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints - char getRegisterType(Value v) { - if (v.getDefiningOp()) - return 'n'; - if (v.getType().isInteger(16)) + char getRegisterType(Type type) { + if (type.isInteger(16)) return 'h'; - if (v.getType().isInteger(32)) + if (type.isInteger(32)) return 'r'; - if (v.getType().isInteger(64)) + if (type.isInteger(64)) return 'l'; - if (v.getType().isF32()) + if (type.isF32()) return 'f'; - if (v.getType().isF64()) + if (type.isF64()) return 'd'; - if (auto ptr = v.getType().dyn_cast()) { + if (auto ptr = type.dyn_cast()) { // Shared address spaces is addressed with 32-bit pointers. 
     if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
       return 'r';
     }
     return 'l';
   }
-    assert(false && "Register type is not handled yet");
+    op->emitError() << "Register type could not be deduced from MLIR type: "
+                    << type;
     return ' ';
   }
+  char getRegisterType(Value v) {
+    if (v.getDefiningOp<LLVM::ConstantOp>())
+      return 'n';
+    return getRegisterType(v.getType());
+  }
+
 public:
   PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
              bool sideEffects = false)
@@ -93,26 +98,60 @@
        sideEffects(sideEffects) {}

   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
-    llvm::raw_string_ostream ss(asmConstraints);
-    if (itype == PTXRegisterMod::Read) {
-      asmVals.push_back(v);
-    } else if (itype == PTXRegisterMod::ReadWrite) {
-      asmVals.push_back(v);
-      ss << "+";
-      hasResult = true;
-    } else if (itype == PTXRegisterMod::Write) {
-      ss << "=";
+    LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+    auto getModifier = [&]() -> const char * {
+      if (itype == PTXRegisterMod::ReadWrite) {
+        assert(false && "Read-Write modifier is not supported. Try setting the "
+                        "same value as Write and Read separately.");
+        return "+";
+      }
+      if (itype == PTXRegisterMod::Write) {
+        return "=";
+      }
+      return "";
+    };
+    auto addValue = [&](Value v) {
+      if (itype == PTXRegisterMod::Read) {
+        asmVals.push_back(v);
+        return;
+      }
+      if (itype == PTXRegisterMod::ReadWrite)
+        asmVals.push_back(v);
       hasResult = true;
+    };
+
+    llvm::raw_string_ostream ss(asmConstraints);
+    // Handle Structs
+    if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+      if (itype == PTXRegisterMod::Write) {
+        addValue(v);
+      }
+      for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+        if (itype != PTXRegisterMod::Write) {
+          Value extractValue =
+              rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
+          addValue(extractValue);
+        }
+        if (itype == PTXRegisterMod::ReadWrite) {
+          ss << idx << ",";
+        } else {
+          ss << getModifier() << getRegisterType(t) << ",";
+        }
+        ss.flush();
+      }
+      return;
     }
-    ss << getRegisterType(v) << ",";
+    // Handle Scalars
+    addValue(v);
+    ss << getModifier() << getRegisterType(v) << ",";
     ss.flush();
   }

   LLVM::InlineAsmOp build() {
     auto asmDialectAttr =
         LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
-    Type resultType = hasResult ? op->getResult(0).getType()
-                                : LLVM::LLVMVoidType::get(op->getContext());
+
+    auto resultTypes = op->getResultTypes();

     // Remove the last comma from the constraints string.
if (!asmConstraints.empty() && @@ -123,7 +162,8 @@ std::replace(asmStr.begin(), asmStr.end(), '%', '$'); return rewriter.create( - op->getLoc(), resultType, + op->getLoc(), + /*result types=*/resultTypes, /*operands=*/asmVals, /*asm_string=*/llvm::StringRef(asmStr), /*constraints=*/asmConstraints.data(), @@ -159,6 +199,7 @@ } SmallVector> asmValues; + LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect()); op.getAsmValues(rewriter, asmValues); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -26,7 +26,9 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Attributes.h" @@ -34,6 +36,7 @@ #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -705,6 +708,292 @@ return success(); } +LogicalResult NVVM::WgmmaMmaSyncOp::verify() { + Value outValue = getResults(); + Value inoutValue = getInouts(); + if (outValue.getType() != inoutValue.getType()) { + return emitOpError() + << "expected output type and inout type must be same struct type"; + } + auto stype = dyn_cast(outValue.getType()); + if (!stype) + return emitOpError() << "expected results to be struct"; + Type outputType = stype.getBody().front(); + int outputSize = stype.getBody().size(); + for (Type t : stype.getBody()) { + if (t != outputType) + return emitOpError() + << "all elements in struct must be same type but there is " << t; + } + + if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) { + return emitOpError() << "does not support the given output type " + << outputType; + } + if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg || + getScaleB() == NVVM::WGMMAScaleIn::neg)) { + return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg"; + } + + // Check M + if (getShape().getM() != 64) + return emitOpError() << "shape 'm' must be 64"; + + // Check K + mlir::NVVM::MMATypes typeA = getTypeA(); + mlir::NVVM::MMATypes typeB = getTypeB(); + switch (typeA) { + case mlir::NVVM::MMATypes::bf16: + case mlir::NVVM::MMATypes::f16: + if (typeA != typeB) { + return emitOpError() << "input types must be same but got " + << NVVM::stringifyMMATypes(typeA) << " and " + << NVVM::stringifyMMATypes(typeB); + } + if (getShape().getK() != 16) { + return emitOpError() << "shape 'k' must be 16 " + << "for input type " + << NVVM::stringifyMMATypes(typeA); + } + break; + case mlir::NVVM::MMATypes::tf32: + if (typeA != typeB) { + return emitOpError() << "input types must be same but got " + << NVVM::stringifyMMATypes(typeA) << " and " + << NVVM::stringifyMMATypes(typeB); + } + if (getShape().getK() != 8) { + return emitOpError() << "shape 'k' must be 8 " + << "for input type " + << NVVM::stringifyMMATypes(typeA); + } + break; + case mlir::NVVM::MMATypes::s8: + case mlir::NVVM::MMATypes::u8: + if (typeB != mlir::NVVM::MMATypes::s8 && + typeB != mlir::NVVM::MMATypes::u8) { + return emitOpError() << "input type of rhs could be " + << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8) + << " or " + << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8) + << " same but 
got and " + << NVVM::stringifyMMATypes(typeB); + } + if (getShape().getK() != 32) { + emitOpError() << "shape 'k' must be 32 " + << "for input type " << NVVM::stringifyMMATypes(typeA); + } + break; + case mlir::NVVM::MMATypes::b1: + if (typeA != typeB) { + return emitOpError() << "input types must be same but got " + << NVVM::stringifyMMATypes(typeA) << " and " + << NVVM::stringifyMMATypes(typeB); + } + if (getShape().getK() != 256) { + return emitOpError() << "shape 'k' must be 256 " + << "for input type " + << NVVM::stringifyMMATypes(typeA); + } + break; + default: + return emitOpError() << "Unsupported input type " + << NVVM::stringifyMMATypes(typeA) << " and " + << NVVM::stringifyMMATypes(typeB); + } + + // Check N + SmallVector allowedNShapesF16 = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector allowedNShapesU8S8B1 = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + + bool validGEMMType = false; + // f16 += f16 * f16 + if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) { + if (!llvm::any_of(allowedNShapesF16, + [&](int n) { return getShape().getN() == n; })) { + return emitOpError() << "has input type " + << NVVM::stringifyMMATypes(typeA) << " n is set to " + << getShape().getN() << ", it is not supported."; + } + validGEMMType = true; + } + // f32 += tf32 * tf32| f32 += f16 * f16| f16 += bf16 * bf16 + if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 || + typeA == mlir::NVVM::MMATypes::tf32 || + typeA == mlir::NVVM::MMATypes::f16)) { + if (!llvm::any_of(allowedNShapesF16, + [&](int n) { return getShape().getN() == n; })) { + return emitOpError() << "has input type " + << NVVM::stringifyMMATypes(typeA) << " n is set to " + << getShape().getN() << ", it is not supported."; + } + validGEMMType = true; + } + // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1 + if (outputType.isInteger(32) && + (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 || + typeA == mlir::NVVM::MMATypes::b1)) { + if (!llvm::any_of(allowedNShapesU8S8B1, + [&](int n) { return getShape().getN() == n; })) { + return emitOpError() << "has input type " + << NVVM::stringifyMMATypes(typeA) << " n is set to " + << getShape().getN() << ", it is not supported."; + } + validGEMMType = true; + } + + if (!validGEMMType) { + return emitOpError() << outputType + << " += " << NVVM::stringifyMMATypes(typeA) << " * " + << NVVM::stringifyMMATypes(typeB) + << ", it is not supported."; + } + + // Check transpose is needed from the given layouts. It is only + // supported for bf16 or f16. + if ((typeA != mlir::NVVM::MMATypes::f16 && + typeA != mlir::NVVM::MMATypes::bf16) && + (getLayoutA() == mlir::NVVM::MMALayout::col || + getLayoutB() == mlir::NVVM::MMALayout::col)) { + return emitOpError() + << "given layouts layout_a = " << stringifyMMALayout(getLayoutA()) + << " and layout_b = " << stringifyMMALayout(getLayoutB()) + << " for input types " << stringifyMMATypes(typeA) << " and " + << stringifyMMATypes(typeB) + << " requires transpose. 
However, this is only supported for: " + << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and " + << stringifyMMATypes(mlir::NVVM::MMATypes::bf16); + } + + // Check number of result registers + int expectedOutput; + if (outputType.isF32() || outputType.isInteger(32)) + expectedOutput = getShape().getN() / 2; + if (outputType.isF16()) + expectedOutput = getShape().getN() / 4; + if (outputSize != expectedOutput) { + return emitOpError() << "results " << expectedOutput + << ", however output struct has " << outputSize + << " elements"; + } + // Check satfinite is set. It is only for s32 accumulator + if (!outputType.isInteger(32) && + getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == + NVVM::MMAIntOverflow::satfinite) { + return emitOpError() + << " `satfinite` can be only used with s32 accumulator, however " + "the current accumulator is " + << outputType; + } + + return success(); +} + +std::string NVVM::WgmmaMmaSyncOp::getPtx() { + + int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); + bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 || + getTypeA() == mlir::NVVM::MMATypes::bf16; + + Value outValue = getResults() ? getResults() : getInouts(); + auto stype = dyn_cast(outValue.getType()); + Type outputType = stype.getBody().front(); + std::string outputTypeName; + if (outputType.isF16()) + outputTypeName = "f16"; + if (outputType.isF32()) + outputTypeName = "f32"; + else if (outputType.isInteger(32)) + outputTypeName = "s32"; + int expectedOutputRegisters; + if (outputType.isF32() || outputType.isInteger(32)) + expectedOutputRegisters = getShape().getN() / 2; + if (outputType.isF16()) + expectedOutputRegisters = getShape().getN() / 4; + + std::string ptx; + llvm::raw_string_ostream ss(ptx); + + ss << "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, $" + << (expectedOutputRegisters + 2) + << ", 0;\n" + "wgmma.mma_async.sync.aligned.m" + << m << "n" << n << "k" << k << "." << outputTypeName << "." + << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB()); + if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == + NVVM::MMAIntOverflow::satfinite) + ss << ".satfinite"; + ss << " {"; + int regCnt = 0; + for (; regCnt < expectedOutputRegisters; ++regCnt) { + ss << "$" << regCnt; + if (regCnt != expectedOutputRegisters - 1) + ss << ", "; + } + + ss << "},"; + ss << " $" << (expectedOutputRegisters) << "," + << " $" << (expectedOutputRegisters + 1) << "," + << " p"; + if (!outputType.isInteger(32)) { + ss << ", $" << (expectedOutputRegisters + 3) << ", $" + << (expectedOutputRegisters + 4); + } + // Don't add transpose parameters unless needed. + if (isF16) { + ss << ", $" << (expectedOutputRegisters + 5) << ", $" + << (expectedOutputRegisters + 6); + } + ss << ";\n" + << "}\n"; + ss.flush(); + return ptx; +} + +void NVVM::WgmmaMmaSyncOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl> + &asmValues) { + Value outValue = getResults() ? 
getResults() : getInouts(); + auto stype = dyn_cast(outValue.getType()); + Type outputType = stype.getBody().front(); + bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 || + getTypeA() == mlir::NVVM::MMATypes::bf16; + if (getResults()) + asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); + if (getInouts()) + asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite}); + asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read}); + asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); + asmValues.push_back({makeConstantI32(rewriter, static_cast(getScaleD())), + mlir::NVVM::PTXRegisterMod::Read}); + if (!outputType.isInteger(32)) { + asmValues.push_back( + {makeConstantI32(rewriter, + getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1), + mlir::NVVM::PTXRegisterMod::Read}); + asmValues.push_back( + {makeConstantI32(rewriter, + getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1), + mlir::NVVM::PTXRegisterMod::Read}); + } + if (isF16) { + asmValues.push_back( + {makeConstantI32(rewriter, static_cast(getLayoutA())), + mlir::NVVM::PTXRegisterMod::Read}); + asmValues.push_back( + {makeConstantI32(rewriter, static_cast(getLayoutB())), + mlir::NVVM::PTXRegisterMod::Read}); + } +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVVMToLLVM/code.mlir b/mlir/test/Conversion/NVVMToLLVM/code.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/NVVMToLLVM/code.mlir @@ -0,0 +1,248 @@ +module { + llvm.func @init_mbarrier_arrive_expect_tx(%arg0: !llvm.ptr<3>, %arg1: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> () + llvm.return + } + llvm.func @init_mbarrier_arrive_expect_tx_generic(%arg0: !llvm.ptr, %arg1: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> () + llvm.return + } + llvm.func @init_mbarrier_try_wait_shared(%arg0: !llvm.ptr<3>, %arg1: i32, %arg2: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r" %arg0, %arg2, %arg1 : (!llvm.ptr<3>, i32, i32) -> () + llvm.return + } + llvm.func @init_mbarrier_try_wait(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "l,r,r" %arg0, %arg2, %arg1 : (!llvm.ptr, i32, i32) -> () + llvm.return + } + func.func @async_cp(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { + nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> + nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1> + return + } + func.func @async_cp_zfill(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>, %arg2: i32) { + %0 = llvm.mlir.constant(16 : i32) : i32 + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %arg0, %arg1, %0, %arg2 : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> () + %1 = llvm.mlir.constant(4 : i32) 
: i32 + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %arg0, %arg1, %1, %arg2 : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> () + return + } + func.func @tma_load_1d(%arg0: !llvm.ptr, %arg1: !llvm.ptr<3>, %arg2: !llvm.ptr<3>, %arg3: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r" %arg1, %arg0, %arg2, %arg3 : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32) -> () + return + } + func.func @tma_load_2d(%arg0: !llvm.ptr, %arg1: !llvm.ptr<3>, %arg2: !llvm.ptr<3>, %arg3: i32, %arg4: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r" %arg1, %arg0, %arg2, %arg3, %arg4 : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32) -> () + return + } + func.func @tma_load_3d(%arg0: !llvm.ptr, %arg1: !llvm.ptr<3>, %arg2: !llvm.ptr<3>, %arg3: i32, %arg4: i32, %arg5: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5} ], [$2];", "r,l,r,r,r,r" %arg1, %arg0, %arg2, %arg3, %arg4, %arg5 : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32) -> () + return + } + func.func @tma_load_4d(%arg0: !llvm.ptr, %arg1: !llvm.ptr<3>, %arg2: !llvm.ptr<3>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6} ], [$2];", "r,l,r,r,r,r,r" %arg1, %arg0, %arg2, %arg3, %arg4, %arg5, %arg6 : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32) -> () + return + } + func.func @tma_load_5d(%arg0: !llvm.ptr, %arg1: !llvm.ptr<3>, %arg2: !llvm.ptr<3>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7} ], [$2];", "r,l,r,r,r,r,r,r" %arg1, %arg0, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32) -> () + return + } + func.func @wgmma_execute() { + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", "" : () -> () + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", "" : () -> () + %0 = llvm.mlir.constant(0 : i32) : i32 + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %0 : (i32) -> () + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", "" : () -> () + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", "" : () -> () + %1 = llvm.mlir.constant(1 : i32) : i32 + llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %1 : (i32) -> () + return + } +} + + +// ----- +module { + func.func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> { + %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %1 = llvm.mlir.constant(0 : i32) : i32 + %2 = llvm.mlir.constant(-1 : i32) : i32 + %3 = llvm.mlir.constant(-1 : i32) : i32 + %4 = llvm.mlir.constant(1 : i32) : i32 + %5 = llvm.mlir.constant(1 : i32) : i32 
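+    // The accumulator struct is unpacked below; its elements are passed to the
+    // inline asm as read operands tied to the "=f" outputs (constraints
+    // "0".."15"), followed by the two i64 matrix descriptors ("l") and the
+    // scale/layout immediates ("n").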
+ %6 = llvm.extractvalue %0[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %7 = llvm.extractvalue %0[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %8 = llvm.extractvalue %0[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %9 = llvm.extractvalue %0[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %10 = llvm.extractvalue %0[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %11 = llvm.extractvalue %0[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %12 = llvm.extractvalue %0[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %13 = llvm.extractvalue %0[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %14 = llvm.extractvalue %0[8] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %15 = llvm.extractvalue %0[9] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %16 = llvm.extractvalue %0[10] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %17 = llvm.extractvalue %0[11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %18 = llvm.extractvalue %0[12] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %19 = llvm.extractvalue %0[13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %20 = llvm.extractvalue %0[14] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %21 = llvm.extractvalue %0[15] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %22 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %arg0, %arg1, %1, %2, %3, %4, %5 : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %c2_i64 = arith.constant 2 : i64 + %23 = arith.addi %arg0, %c2_i64 : i64 + %24 = arith.addi %arg1, %c2_i64 : i64 + %25 = llvm.mlir.constant(0 : i32) : i32 + %26 = llvm.mlir.constant(-1 : i32) : i32 + %27 = llvm.mlir.constant(-1 : i32) : i32 + %28 = llvm.mlir.constant(1 : i32) : i32 + %29 = llvm.mlir.constant(1 : i32) : i32 + %30 = llvm.extractvalue %22[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %31 = llvm.extractvalue %22[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %32 = llvm.extractvalue %22[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %33 = llvm.extractvalue %22[3] : !llvm.struct<(f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %34 = llvm.extractvalue %22[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %35 = llvm.extractvalue %22[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %36 = llvm.extractvalue %22[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %37 = llvm.extractvalue %22[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %38 = llvm.extractvalue %22[8] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %39 = llvm.extractvalue %22[9] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %40 = llvm.extractvalue %22[10] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %41 = llvm.extractvalue %22[11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %42 = llvm.extractvalue %22[12] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %43 = llvm.extractvalue %22[13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %44 = llvm.extractvalue %22[14] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %45 = llvm.extractvalue %22[15] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %46 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %23, %24, %25, %26, %27, %28, %29 : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + return %46 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + } +} + + +// ----- +module { + func.func @wgmma_s32_s8_s8_satfinite(%arg0: i64, %arg1: i64) -> !llvm.struct<(i32, i32, i32, i32)> { + %0 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, i32, i32)> + %3 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, i32, i32)> + %4 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, i32, i32)> + %5 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, i32, i32)> + %6 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %2, %3, %4, %5, %arg0, %arg1, %1 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %7 = llvm.mlir.constant(1 : i32) : i32 + %8 = llvm.extractvalue %6[0] : !llvm.struct<(i32, i32, i32, i32)> + %9 = llvm.extractvalue %6[1] : !llvm.struct<(i32, i32, i32, i32)> + %10 = llvm.extractvalue %6[2] : !llvm.struct<(i32, i32, i32, i32)> + %11 = llvm.extractvalue %6[3] 
: !llvm.struct<(i32, i32, i32, i32)> + %12 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %8, %9, %10, %11, %arg0, %arg1, %7 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %13 = llvm.mlir.constant(1 : i32) : i32 + %14 = llvm.extractvalue %12[0] : !llvm.struct<(i32, i32, i32, i32)> + %15 = llvm.extractvalue %12[1] : !llvm.struct<(i32, i32, i32, i32)> + %16 = llvm.extractvalue %12[2] : !llvm.struct<(i32, i32, i32, i32)> + %17 = llvm.extractvalue %12[3] : !llvm.struct<(i32, i32, i32, i32)> + %18 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %14, %15, %16, %17, %arg0, %arg1, %13 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return %18 : !llvm.struct<(i32, i32, i32, i32)> + } + func.func @wgmma_s32_u8_u8(%arg0: i64, %arg1: i64) -> !llvm.struct<(i32, i32, i32, i32)> { + %0 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, i32, i32)> + %3 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, i32, i32)> + %4 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, i32, i32)> + %5 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, i32, i32)> + %6 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %2, %3, %4, %5, %arg0, %arg1, %1 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %7 = llvm.mlir.constant(1 : i32) : i32 + %8 = llvm.extractvalue %6[0] : !llvm.struct<(i32, i32, i32, i32)> + %9 = llvm.extractvalue %6[1] : !llvm.struct<(i32, i32, i32, i32)> + %10 = llvm.extractvalue %6[2] : !llvm.struct<(i32, i32, i32, i32)> + %11 = llvm.extractvalue %6[3] : !llvm.struct<(i32, i32, i32, i32)> + %12 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %8, %9, %10, %11, %arg0, %arg1, %7 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %13 = llvm.mlir.constant(1 : i32) : i32 + %14 = llvm.extractvalue %12[0] : !llvm.struct<(i32, i32, i32, i32)> + %15 = llvm.extractvalue %12[1] : !llvm.struct<(i32, i32, i32, i32)> + %16 = llvm.extractvalue %12[2] : !llvm.struct<(i32, i32, i32, i32)> + %17 = llvm.extractvalue %12[3] : !llvm.struct<(i32, i32, i32, i32)> + %18 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %14, %15, %16, %17, %arg0, %arg1, %13 : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return %18 : !llvm.struct<(i32, i32, i32, i32)> + } +} + + +// ----- +module { + func.func @wgmma_f32_tf32_tf32(%arg0: i64, %arg1: i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> { + %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.mlir.constant(1 : i32) : i32 + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.extractvalue %0[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %5 = llvm.extractvalue %0[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %6 = llvm.extractvalue %0[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %7 = llvm.extractvalue %0[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %8 = llvm.extractvalue %0[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %9 = llvm.extractvalue %0[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %10 = llvm.extractvalue %0[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %11 = llvm.extractvalue %0[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %12 = llvm.extractvalue %0[8] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %13 = llvm.extractvalue %0[9] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %14 = llvm.extractvalue %0[10] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %15 = llvm.extractvalue %0[11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %16 = llvm.extractvalue %0[12] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %17 = llvm.extractvalue %0[13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %18 = llvm.extractvalue %0[14] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %19 = llvm.extractvalue %0[15] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %20 = llvm.extractvalue %0[16] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %21 = llvm.extractvalue %0[17] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %22 = llvm.extractvalue %0[18] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %23 = llvm.extractvalue %0[19] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %24 = llvm.extractvalue %0[20] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %25 = llvm.extractvalue %0[21] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %26 = llvm.extractvalue %0[22] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %27 = llvm.extractvalue %0[23] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %28 = llvm.extractvalue %0[24] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %29 = llvm.extractvalue %0[25] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %30 = llvm.extractvalue %0[26] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %31 = llvm.extractvalue %0[27] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %32 = llvm.extractvalue %0[28] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %33 = llvm.extractvalue %0[29] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %34 = llvm.extractvalue %0[30] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %35 = llvm.extractvalue %0[31] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %36 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 
0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %arg0, %arg1, %1, %2, %3 : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %37 = llvm.mlir.constant(1 : i32) : i32 + %38 = llvm.mlir.constant(1 : i32) : i32 + %39 = llvm.mlir.constant(1 : i32) : i32 + %40 = llvm.extractvalue %36[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %41 = llvm.extractvalue %36[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %42 = llvm.extractvalue %36[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %43 = llvm.extractvalue %36[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %44 = llvm.extractvalue %36[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %45 = llvm.extractvalue %36[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %46 = llvm.extractvalue %36[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %47 = llvm.extractvalue %36[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %48 = llvm.extractvalue %36[8] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %49 = llvm.extractvalue %36[9] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %50 = llvm.extractvalue %36[10] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %51 = llvm.extractvalue %36[11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %52 = llvm.extractvalue %36[12] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %53 = llvm.extractvalue %36[13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %54 = llvm.extractvalue %36[14] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %55 = llvm.extractvalue %36[15] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %56 = llvm.extractvalue %36[16] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %57 = llvm.extractvalue %36[17] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %58 = llvm.extractvalue %36[18] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %59 = llvm.extractvalue %36[19] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %60 = llvm.extractvalue %36[20] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %61 = llvm.extractvalue %36[21] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %62 = llvm.extractvalue %36[22] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %63 = llvm.extractvalue %36[23] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %64 = llvm.extractvalue %36[24] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %65 = llvm.extractvalue %36[25] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %66 = llvm.extractvalue %36[26] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %67 = llvm.extractvalue %36[27] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %68 = llvm.extractvalue %36[28] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %69 = llvm.extractvalue %36[29] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %70 = llvm.extractvalue %36[30] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %71 = llvm.extractvalue %36[31] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %72 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %arg0, %arg1, %37, %38, %39 : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + return %72 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + } +} + diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir @@ -0,0 +1,104 @@ +// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file -verify-diagnostics %s + +!mat64f32 = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32)> +func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ + %result = llvm.mlir.undef : !mat64f32 + // expected-error @+1 {{'nvvm.wgmma.mma_async' op results 64, however output struct has 7 elements}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat64f32 -> !mat64f32 + return %res : !mat64f32 +} + +// ----- + +func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is 'f32'}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, , ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + return +} + +// ----- + +func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{shape 'm' must be 
64}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + return +} + +// ----- + +func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> + return +} + +// ----- + +func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{op shape 'k' must be 16 for input type f16}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + return +} + +// ----- + +func.func @wgmma_transpose(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. However, this is only supported for: f16 and bf16}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + return +} + +// ----- + +func.func @wgmma_transpose(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)> + // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + :!llvm.struct<(f16, f16, f16, f16)> + -> !llvm.struct<(f16, f16, f16, f16)> + return +} diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -103,3 +103,168 @@ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32) return } + +// ----- + +!mat64f32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @wgmma_f32_f16_f16( +// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64 +func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{ + // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct + // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: 
%[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[C2:.*]] = arith.constant 2 : i64 + // CHECK: %[[DESCa:.+]] = arith.addi %[[ARG0]], %[[C2]] : i64 + // CHECK: %[[DESCb:.+]] = arith.addi %[[ARG1]], %[[C2]] : i64 + // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %result = llvm.mlir.undef : !mat64f32 + %result1 = nvvm.wgmma.mma_async + %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + :!mat64f32 -> !mat64f32 + %c2 = arith.constant 2 : i64 + %descAnext = arith.addi %descA, %c2 : i64 + %descBnext = arith.addi %descB, %c2 : i64 + %result2 = nvvm.wgmma.mma_async + %descAnext, %descBnext, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat64f32 -> !mat64f32 + return %result2 : !mat64f32 +} + +// ----- + +!mat16i32 = !llvm.struct<(i32, i32, i32, i32)> + +// 
CHECK-LABEL: @wgmma_s32_s8_s8_satfinite( +// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64 +func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{ + %result = llvm.mlir.undef : !mat16i32 +// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct +// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] +// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1] +// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2] +// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3] +// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> +// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0] +// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1] +// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2] +// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3] +// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}} +// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0] +// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1] +// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2] +// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3] +// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} + %result1 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + %result2 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + %result3 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result2, #nvvm.wgmma_scale_out, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + return %result3 : !mat16i32 +} + +// CHECK-LABEL: @wgmma_s32_u8_u8( + // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64 +func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 { +// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct +// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] +// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1] +// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2] +// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3] +// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)> +// CHECK: %[[V0_2:.*]] = 
llvm.extractvalue %[[RES_2]][0] +// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1] +// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2] +// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3] +// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}} +// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0] +// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1] +// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2] +// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3] +// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} + %result = llvm.mlir.undef : !mat16i32 + %result1 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + %result2 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + %result3 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result2, #nvvm.wgmma_scale_out], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !mat16i32 -> !mat16i32 + return %result3 : !mat16i32 +} + +// ----- + +!mat32f32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @wgmma_f32_tf32_tf32 +func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 { + // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + %result = llvm.mlir.undef : !mat32f32 + %result1 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out], + A [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + : !mat32f32 -> !mat32f32 + %result2 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out], + A 
[#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in, #nvvm.mma_layout],
+      B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in, #nvvm.mma_layout]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
\ No newline at end of file