diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -167,7 +167,7 @@
         /*desc=*/[{Generate constant value.}],
         /*retType=*/"::mlir::Value",
         /*methodName=*/"makeConstantI32",
-        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "unsigned" : $val),
+        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
        /*methodBody=*/"",
        /*defaultImpl=*/ [{
          mlir::Operation* op = $_op;
@@ -1472,4 +1472,108 @@
  }];
}

+/// Enum attribute for scaling (and optionally negating) the input operands
+def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>;
+def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA input scale options",
+  [WGMMAScaleInOne, WGMMAScaleInNeg]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute for scaling the output accumulator
+def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>;
+def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA output scale options",
+  [WGMMAScaleOutZero, WGMMAScaleOutOne]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async",
+    [DeclareOpInterfaceMethods<BasicPtxBuilderInterface>]> {
+  let results = (outs Optional<LLVM_AnyStruct>:$results);
+  let arguments = (ins
+    Optional<LLVM_AnyStruct>:$inouts,
+    I64:$descriptorA,
+    I64:$descriptorB,
+    NVVM_MMAShapeAttr:$shape,
+    MMATypesAttr:$typeA,
+    MMATypesAttr:$typeB,
+    WGMMAScaleOutAttr:$scaleD,
+    WGMMAScaleInAttr:$scaleA,
+    WGMMAScaleInAttr:$scaleB,
+    MMALayoutAttr:$layoutA,
+    MMALayoutAttr:$layoutB,
+    OptionalAttr<UnitAttr>:$satfinite
+  );
+
+  let assemblyFormat = "$descriptorA `,` $descriptorB `,` $shape `,` `scale_out` `=` $scaleD `,` `[` `lhs` `=` $typeA `,` $scaleA `,` $layoutA `]` `,` `[` `rhs` `=` $typeB `,` $scaleB `,` $layoutB `]` (`,` `inout` `=` $inouts `->` type($inouts)^)? attr-dict (`->` type($results)^)?";
+
+  let description = [{
+    The warpgroup (128 threads) level matrix multiply-and-accumulate
+    operation has either of the following forms, where matrix D is called
+    the accumulator:
+      D = A * B + D
+      D = A * B, where the input from accumulator D is disabled.
+
+    Supported shapes:
+    ```
+    |-------------------|--------------------|----------------|---------------|
+    |                   | f16 += f16 * f16   | s32 += s8 * s8 |               |
+    |f32 += tf32 * tf32 | f32 += f16 * f16   | s32 += s8 * u8 |s32 += b1 * b1 |
+    |                   | f32 += bf16 * bf16 | s32 += u8 * u8 |               |
+    |-------------------|--------------------|----------------|---------------|
+    | .m64n8k8          | .m64n8k16          | .m64n8k32      | .m64n8k256    |
+    | .m64n16k8         | .m64n16k16         | .m64n16k32     | .m64n16k256   |
+    | .m64n24k8         | .m64n24k16         | .m64n24k32     | .m64n24k256   |
+    | .m64n32k8         | .m64n32k16         | .m64n32k32     | .m64n32k256   |
+    | .m64n40k8         | .m64n40k16         | .m64n48k32     | .m64n48k256   |
+    | .m64n48k8         | .m64n48k16         | .m64n64k32     | .m64n64k256   |
+    | .m64n56k8         | .m64n56k16         | .m64n80k32     | .m64n80k256   |
+    | .m64n64k8         | .m64n64k16         | .m64n96k32     | .m64n96k256   |
+    | .m64n72k8         | .m64n72k16         | .m64n112k32    | .m64n112k256  |
+    | .m64n80k8         | .m64n80k16         | .m64n128k32    | .m64n128k256  |
+    | .m64n88k8         | .m64n88k16         | .m64n144k32    | .m64n144k256  |
+    | .m64n96k8         | .m64n96k16         | .m64n160k32    | .m64n160k256  |
+    | .m64n104k8        | .m64n104k16        | .m64n176k32    | .m64n176k256  |
+    | .m64n112k8        | .m64n112k16        | .m64n192k32    | .m64n192k256  |
+    | .m64n120k8        | .m64n120k16        | .m64n208k32    | .m64n208k256  |
+    | .m64n128k8        | .m64n128k16        | .m64n224k32    | .m64n224k256  |
+    | .m64n136k8        | .m64n136k16        | .m64n240k32    | .m64n240k256  |
+    | .m64n144k8        | .m64n144k16        | .m64n256k32    | .m64n256k256  |
+    | .m64n152k8        | .m64n152k16        |                |               |
+    | .m64n160k8        | .m64n160k16        |                |               |
+    | .m64n168k8        | .m64n168k16        |                |               |
+    | .m64n176k8        | .m64n176k16        |                |               |
+    | .m64n184k8        | .m64n184k16        |                |               |
+    | .m64n192k8        | .m64n192k16        |                |               |
+    | .m64n200k8        | .m64n200k16        |                |               |
+    | .m64n208k8        | .m64n208k16        |                |               |
+    | .m64n216k8        | .m64n216k16        |                |               |
+    | .m64n224k8        | .m64n224k16        |                |               |
+    | .m64n232k8        | .m64n232k16        |                |               |
+    | .m64n240k8        | .m64n240k16        |                |               |
+    | .m64n248k8        | .m64n248k16        |                |               |
+    | .m64n256k8        | .m64n256k16        |                |               |
+    |-------------------|--------------------|----------------|---------------|
+    ```
+
+    For more information, see the PTX ISA documentation:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions
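+
+    Example (an illustrative sketch; the two i64 operands are assumed to
+    hold valid WGMMA shared-memory matrix descriptors):
+    ```mlir
+    %r = nvvm.wgmma.mma_async %descA, %descB,
+        #nvvm.shape<m = 64, n = 32, k = 16>,
+        scale_out = <one>,
+        [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+        [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+        -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32,
+                         f32, f32, f32, f32, f32, f32, f32, f32)>
+    ```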
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    void getAsmValues(RewriterBase &rewriter,
+        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+            &asmValues);
+  }];
+}
+
 #endif // NVVMIR_OPS
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -52,7 +52,7 @@
 namespace {
 class PtxBuilder {
-  Operation *op;
+  NVVM::BasicPtxBuilderInterface op;
   PatternRewriter &rewriter;
   std::string asmStr;
   SmallVector<Value> asmVals;
@@ -61,30 +61,35 @@
   bool hasResult = false;
 
   // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
-  char getRegisterType(Value v) {
-    if (v.getDefiningOp<LLVM::ConstantOp>())
-      return 'n';
-    if (v.getType().isInteger(16))
+  char getRegisterType(Type type) {
+    if (type.isInteger(16))
       return 'h';
-    if (v.getType().isInteger(32))
+    if (type.isInteger(32))
       return 'r';
-    if (v.getType().isInteger(64))
+    if (type.isInteger(64))
       return 'l';
-    if (v.getType().isF32())
+    if (type.isF32())
       return 'f';
-    if (v.getType().isF64())
+    if (type.isF64())
       return 'd';
-    if (auto ptr = v.getType().dyn_cast<LLVM::LLVMPointerType>()) {
+    if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
       // Shared address spaces is addressed with 32-bit pointers.
       if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
         return 'r';
       }
       return 'l';
     }
-    assert(false && "Register type is not handled yet");
+    op->emitError() << "Register type could not be deduced from MLIR type: "
+                    << type;
     return ' ';
   }
 
+  char getRegisterType(Value v) {
+    if (v.getDefiningOp<LLVM::ConstantOp>())
+      return 'n';
+    return getRegisterType(v.getType());
+  }
+
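+  // Illustrative only: for an op reading an i32, a generic pointer, and an
+  // f32, the letters above assemble into the constraint string "r,l,f". The
+  // PTX mnemonic below is a placeholder, not a real instruction:
+  //   llvm.inline_asm has_side_effects asm_dialect = att
+  //       "foo.op $0, $1, $2;", "r,l,f" %i, %p, %v : (i32, !llvm.ptr, f32) -> ()
+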
 public:
   PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
              bool sideEffects = false)
@@ -92,26 +97,52 @@
         sideEffects(sideEffects) {}
 
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
-    llvm::raw_string_ostream ss(asmConstraints);
-    if (itype == PTXRegisterMod::Read) {
-      asmVals.push_back(v);
-    } else if (itype == PTXRegisterMod::ReadWrite) {
-      asmVals.push_back(v);
-      ss << "+";
-      hasResult = true;
-    } else if (itype == PTXRegisterMod::Write) {
-      ss << "=";
+    LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+    auto getModifier = [&]() -> const char * {
+      if (itype == PTXRegisterMod::ReadWrite)
+        return "+";
+      if (itype == PTXRegisterMod::Write)
+        return "=";
+      return "";
+    };
+    auto addValue = [&](Value v) {
+      if (itype == PTXRegisterMod::Read) {
+        asmVals.push_back(v);
+        return;
+      }
+      if (itype == PTXRegisterMod::ReadWrite)
+        asmVals.push_back(v);
       hasResult = true;
+    };
+
+    llvm::raw_string_ostream ss(asmConstraints);
+    // Handle structs: emit one constraint per struct element.
+    if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+      if (itype == PTXRegisterMod::Write) {
+        addValue(v);
+      }
+      for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+        if (itype != PTXRegisterMod::Write) {
+          Value extractValue =
+              rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
+          addValue(extractValue);
+        }
+        ss << getModifier() << getRegisterType(t) << ",";
+        ss.flush();
+      }
+      return;
     }
-    ss << getRegisterType(v) << ",";
+    // Handle scalars.
+    addValue(v);
+    ss << getModifier() << getRegisterType(v) << ",";
     ss.flush();
   }
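+
+  // Sketch of the constraint expansion above (illustrative): a write-only
+  // !llvm.struct<(i32, i32)> result yields "=r,=r" with no extra operands,
+  // while the same struct passed as inout yields "+r,+r" plus one
+  // extractvalue operand per element.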
 
   LLVM::InlineAsmOp build() {
     auto asmDialectAttr =
         LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
-    Type resultType = hasResult ? op->getResult(0).getType()
-                                : LLVM::LLVMVoidType::get(op->getContext());
+
+    auto resultTypes = op->getResultTypes();
 
     // Remove the last comma from the constraints string.
     if (!asmConstraints.empty() &&
@@ -122,7 +153,8 @@
     std::replace(asmStr.begin(), asmStr.end(), '%', '$');
 
     return rewriter.create<LLVM::InlineAsmOp>(
-        op->getLoc(), resultType,
+        op->getLoc(),
+        /*result types=*/resultTypes,
         /*operands=*/asmVals,
         /*asm_string=*/llvm::StringRef(asmStr),
         /*constraints=*/asmConstraints.data(),
@@ -158,6 +190,7 @@
     }
 
     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
+    LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
    PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
    op.getAsmValues(rewriter, asmValues);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -24,7 +24,9 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
@@ -32,6 +34,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <optional>
@@ -703,6 +706,286 @@
   return success();
 }
 
+LogicalResult NVVM::WgmmaMmaSyncOp::verify() {
+  if ((!getResults() && !getInouts()) || (getResults() && getInouts()))
+    return emitOpError() << "expected output or inout";
+  Value outValue = getResults() ? getResults() : getInouts();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  if (!stype)
+    return emitOpError() << "expected results to be struct";
+  Type outputType = stype.getBody().front();
+  int outputSize = stype.getBody().size();
+  for (Type t : stype.getBody()) {
+    if (t != outputType)
+      return emitOpError()
+             << "all elements in struct must be same type but there is " << t;
+  }
+
+  if (!outputType.isF32() && !outputType.isInteger(32) &&
+      !outputType.isF16()) {
+    return emitOpError() << "does not support the given output type "
+                         << outputType;
+  }
+  if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
+                                   getScaleB() == NVVM::WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  }
+
+  // Check M.
+  if (getShape().getM() != 64)
+    return emitOpError() << "shape 'm' must be 64";
+
+  // Check K.
+  mlir::NVVM::MMATypes typeA = getTypeA();
+  mlir::NVVM::MMATypes typeB = getTypeB();
+  switch (typeA) {
+  case mlir::NVVM::MMATypes::bf16:
+  case mlir::NVVM::MMATypes::f16:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 16) {
+      return emitOpError() << "shape 'k' must be 16 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::tf32:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 8) {
+      return emitOpError() << "shape 'k' must be 8 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::s8:
+  case mlir::NVVM::MMATypes::u8:
+    if (typeB != mlir::NVVM::MMATypes::s8 &&
+        typeB != mlir::NVVM::MMATypes::u8) {
+      return emitOpError() << "input type of rhs must be "
+                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8)
+                           << " or "
+                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8)
+                           << " but got " << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 32) {
+      return emitOpError() << "shape 'k' must be 32 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::b1:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 256) {
+      return emitOpError() << "shape 'k' must be 256 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  default:
+    return emitOpError() << "unsupported input types "
+                         << NVVM::stringifyMMATypes(typeA) << " and "
+                         << NVVM::stringifyMMATypes(typeB);
+  }
+
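+  // For illustration, an operation that satisfies the checks above for s8
+  // inputs (k must be 32; n is drawn from the integer shape list below):
+  //   nvvm.wgmma.mma_async %a, %b, #nvvm.shape<m = 64, n = 16, k = 32>, ...
+  //       -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+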
+  // Check N.
+  SmallVector<int> allowedNShapesF16 = {8,   16,  24,  32,  40,  48,  56,  64,
+                                        72,  80,  88,  96,  104, 112, 120, 128,
+                                        136, 144, 152, 160, 168, 176, 184, 192,
+                                        200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNShapesU8S8B1 = {8,   16,  24,  32,  48,  64,
+                                           80,  96,  112, 128, 144, 160,
+                                           176, 192, 208, 224, 240, 256};
+
+  bool validGEMMType = false;
+  // f16 += f16 * f16
+  if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) {
+    if (!llvm::any_of(allowedNShapesF16,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " but n = "
+                           << getShape().getN() << " is not supported";
+    }
+    validGEMMType = true;
+  }
+  // f32 += tf32 * tf32 | f32 += f16 * f16 | f32 += bf16 * bf16
+  if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 ||
+                             typeA == mlir::NVVM::MMATypes::tf32 ||
+                             typeA == mlir::NVVM::MMATypes::f16)) {
+    if (!llvm::any_of(allowedNShapesF16,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " but n = "
+                           << getShape().getN() << " is not supported";
+    }
+    validGEMMType = true;
+  }
+  // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1
+  if (outputType.isInteger(32) &&
+      (typeA == mlir::NVVM::MMATypes::s8 ||
+       typeA == mlir::NVVM::MMATypes::u8 ||
+       typeA == mlir::NVVM::MMATypes::b1)) {
+    if (!llvm::any_of(allowedNShapesU8S8B1,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " but n = "
+                           << getShape().getN() << " is not supported";
+    }
+    validGEMMType = true;
+  }
+
+  if (!validGEMMType) {
+    return emitOpError() << outputType
+                         << " += " << NVVM::stringifyMMATypes(typeA) << " * "
+                         << NVVM::stringifyMMATypes(typeB)
+                         << ", it is not supported.";
+  }
+
+  // Check whether the given layouts require a transpose; that is only
+  // supported for f16 or bf16 inputs.
+  if ((typeA != mlir::NVVM::MMATypes::f16 &&
+       typeA != mlir::NVVM::MMATypes::bf16) &&
+      (getLayoutA() == mlir::NVVM::MMALayout::col ||
+       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+    return emitOpError()
+           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
+           << " and layout_b = " << stringifyMMALayout(getLayoutB())
+           << " for input types " << stringifyMMATypes(typeA) << " and "
+           << stringifyMMATypes(typeB)
+           << " requires transpose. However, this is only supported for: "
+           << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
+           << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+  }
+
+  // Check number of result registers.
+  int expectedOutput = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutput = getShape().getN() / 2;
+  if (outputType.isF16())
+    expectedOutput = getShape().getN() / 4;
+  if (outputSize != expectedOutput) {
+    return emitOpError() << "expected " << expectedOutput
+                         << " elements in the output struct, but it has "
+                         << outputSize;
+  }
+
+  // Check `satfinite`; it is only valid for an s32 accumulator.
+  if (!outputType.isInteger(32) && getSatfinite().value_or(false)) {
+    return emitOpError()
+           << "`satfinite` can only be used with an s32 accumulator, however "
+              "the current accumulator is "
+           << outputType;
+  }
+
+  return success();
+}
+
+std::string NVVM::WgmmaMmaSyncOp::getPtx() {
+  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
+  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+               getTypeA() == mlir::NVVM::MMATypes::bf16;
+
+  Value outValue = getResults() ? getResults() : getInouts();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  Type outputType = stype.getBody().front();
+  std::string outputTypeName;
+  if (outputType.isF16())
+    outputTypeName = "f16";
+  else if (outputType.isF32())
+    outputTypeName = "f32";
+  else if (outputType.isInteger(32))
+    outputTypeName = "s32";
+
+  int expectedOutputRegisters = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutputRegisters = getShape().getN() / 2;
+  if (outputType.isF16())
+    expectedOutputRegisters = getShape().getN() / 4;
+
+  std::string ptx;
+  llvm::raw_string_ostream ss(ptx);
+
+  ss << "{\n"
+        ".reg .pred p;\n"
+        "setp.ne.b32 p, $"
+     << (expectedOutputRegisters + 2)
+     << ", 0;\n"
+        "wgmma.mma_async.sync.aligned.m"
+     << m << "n" << n << "k" << k << "." << outputTypeName << "."
+     << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB());
+  if (getSatfinite().value_or(false))
+    ss << ".satfinite";
+  ss << " {";
+  int regCnt = 0;
+  for (; regCnt < expectedOutputRegisters; ++regCnt) {
+    ss << "$" << regCnt;
+    if (regCnt != expectedOutputRegisters - 1)
+      ss << ", ";
+  }
+  ss << "},";
+  ss << " $" << (expectedOutputRegisters) << ","
+     << " $" << (expectedOutputRegisters + 1) << ","
+     << " p";
+  if (!outputType.isInteger(32)) {
+    ss << ", $" << (expectedOutputRegisters + 3) << ", $"
+       << (expectedOutputRegisters + 4);
+  }
+  // Don't add transpose parameters unless needed.
+  if (isF16) {
+    ss << ", $" << (expectedOutputRegisters + 5) << ", $"
+       << (expectedOutputRegisters + 6);
+  }
+  ss << ";\n"
+     << "}\n";
+  ss.flush();
+  return ptx;
+}
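+
+// For reference, the string produced above for the m64n16k32 s32 += s8 * s8
+// case with `satfinite` (decoded from the lowering tests below; the wgmma
+// instruction is a single line in the real output):
+//   {
+//   .reg .pred p;
+//   setp.ne.b32 p, $10, 0;
+//   wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite
+//       {$0, $1, $2, $3, $4, $5, $6, $7}, $8, $9, p;
+//   }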
+
+void NVVM::WgmmaMmaSyncOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  Value outValue = getResults() ? getResults() : getInouts();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  Type outputType = stype.getBody().front();
+  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+               getTypeA() == mlir::NVVM::MMATypes::bf16;
+  if (getResults())
+    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
+  if (getInouts())
+    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
+  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
+  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
+  asmValues.push_back(
+      {makeConstantI32(rewriter, static_cast<int>(getScaleD())),
+       mlir::NVVM::PTXRegisterMod::Read});
+  if (!outputType.isInteger(32)) {
+    asmValues.push_back(
+        {makeConstantI32(rewriter,
+                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+         mlir::NVVM::PTXRegisterMod::Read});
+    asmValues.push_back(
+        {makeConstantI32(rewriter,
+                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+         mlir::NVVM::PTXRegisterMod::Read});
+  }
+  if (isF16) {
+    asmValues.push_back(
+        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
+         mlir::NVVM::PTXRegisterMod::Read});
+    asmValues.push_back(
+        {makeConstantI32(rewriter, static_cast<int>(getLayoutB())),
+         mlir::NVVM::PTXRegisterMod::Read});
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file -verify-diagnostics %s
+
+!mat64f32 = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32)>
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32 {
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op expected 64 elements in the output struct, but it has 7}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 128, k = 16>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !mat64f32
+  return %res : !mat64f32
+}
+
+// -----
+
+func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{`satfinite` can only be used with an s32 accumulator, however the current accumulator is 'f32'}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 16>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>] {satfinite}
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{shape 'm' must be 64}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 32, n = 16, k = 16>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return
+}
+
+// -----
+
+func.func @wgmma_mixed_struct(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 16>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !llvm.struct<(f32, f32, i32, i32, f32, f32, f32, f32)>
+  return
+}
+
+// -----
+
+func.func @wgmma_wrong_k(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 32>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return
+}
+
+// -----
+
+func.func @wgmma_transpose(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. However, this is only supported for: f16 and bf16}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 8>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>],
+      [rhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return
+}
+
+// -----
+
+func.func @wgmma_f16_acc(%descA : i64, %descB : i64) {
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}}
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 8>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !llvm.struct<(f16, f16, f16, f16)>
+  return
+}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -100,3 +100,115 @@
   // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32)
   return
 }
+
+!mat64f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_f16_f16(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32 {
+  // CHECK: %[[A0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[A1:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A3:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[RES:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;\0Awgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63}, $64, $65, p, $67, $68, $69, $70;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l,n,n,n,n,n" %[[ARG0]], %[[ARG1]], %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : (i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[V45:.*]] = llvm.extractvalue %[[RES]][45] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[V63:.*]] = llvm.extractvalue %[[RES]][63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;\0Awgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63}, $64, $65, p, $67, $68, $69, $70;\0A}\0A", "+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,+f,l,l,n,n,n,n,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V45]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V63]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> ()
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 128, k = 16>,
+      scale_out = <zero>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>]
+      -> !mat64f32
+  %c2 = arith.constant 2 : i64
+  %descAnext = arith.addi %descA, %c2 : i64
+  %descBnext = arith.addi %descB, %c2 : i64
+  nvvm.wgmma.mma_async
+      %descAnext, %descBnext,
+      #nvvm.shape<m = 64, n = 128, k = 16>,
+      scale_out = #nvvm.wgmma_scale_out<one>,
+      [lhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>],
+      [rhs = #nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>],
+      inout = %res -> !mat64f32
+  return %res : !mat64f32
+}
+
+// -----
+
+!mat16i32 = !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+
+// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite
+func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32 {
+  // CHECK: %[[A0:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $10, 0;\0Awgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite {$0, $1, $2, $3, $4, $5, $6, $7}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,=r,=r,=r,=r,l,l,n" %{{.*}}, %{{.*}}, %[[A0]] : (i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: %[[V0:.+]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: %[[V7:.+]] = llvm.extractvalue %[[RES]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $10, 0;\0Awgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 {$0, $1, $2, $3, $4, $5, $6, $7}, $8, $9, p;\0A}\0A", "+r,+r,+r,+r,+r,+r,+r,+r,l,l,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V7]], %{{.*}}, %{{.*}} : (i32, i32, i32, i32, i32, i32, i32, i32, i64, i64, i32) -> ()
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 32>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      {satfinite}
+      -> !mat16i32
+  nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 32>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      inout = %res -> !mat16i32
+  return %res : !mat16i32
+}
+
+// CHECK-LABEL: @wgmma_s32_u8_u8
+func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
+  // CHECK: %[[A0:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $10, 0;\0Awgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 {$0, $1, $2, $3, $4, $5, $6, $7}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,=r,=r,=r,=r,l,l,n" %{{.*}}, %{{.*}}, %[[A0]] : (i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: %[[V0:.+]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: %[[V7:.+]] = llvm.extractvalue %[[RES]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $10, 0;\0Awgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 {$0, $1, $2, $3, $4, $5, $6, $7}, $8, $9, p;\0A}\0A", "+r,+r,+r,+r,+r,+r,+r,+r,l,l,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V7]], %{{.*}}, %{{.*}} : (i32, i32, i32, i32, i32, i32, i32, i32, i64, i64, i32) -> ()
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 32>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<u8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<u8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !mat16i32
+  nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 16, k = 32>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<u8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<u8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      inout = %res -> !mat16i32
+  return %res : !mat16i32
+}
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32,
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_tf32_tf32
+func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
+  // CHECK: %[[A0:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A2:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l,n,n,n" %{{.*}}, %{{.*}}, %[[A0]], %[[A1]], %[[A2]] : (i64, i64, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  %res = nvvm.wgmma.mma_async %descA, %descB,
+      #nvvm.shape<m = 64, n = 64, k = 8>,
+      scale_out = <one>,
+      [lhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+      [rhs = #nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      -> !mat32f32
+  return %res : !mat32f32
+}
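
Note: as the tests above show, dependent WGMMA steps chain the accumulator through `inout`. A minimal sketch of that pattern (illustrative; descriptor arithmetic and warpgroup synchronization elided):
```mlir
%acc = nvvm.wgmma.mma_async %a0, %b0,
    #nvvm.shape<m = 64, n = 16, k = 32>,
    scale_out = <one>,
    [lhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
    [rhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
    -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
nvvm.wgmma.mma_async %a1, %b1,
    #nvvm.shape<m = 64, n = 16, k = 32>,
    scale_out = <one>,
    [lhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
    [rhs = #nvvm.mma_type<s8>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
    inout = %acc -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
```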