diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1497,6 +1497,28 @@ let assemblyFormat = "`<` $value `>`"; } +/// Enum attribute of the different PTX element types used for WGMMA operands. +def WGMMATypeF16 : I32EnumAttrCase<"f16", 0>; +def WGMMATypeTF32 : I32EnumAttrCase<"tf32", 1>; +def WGMMATypeU8 : I32EnumAttrCase<"u8", 2>; +def WGMMATypeS8 : I32EnumAttrCase<"s8", 3>; +def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>; +def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>; +def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>; +def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>; +def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types", + [WGMMATypeF16, WGMMATypeTF32, + WGMMATypeU8, WGMMATypeS8, + WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, + WGMMATypeF8E5M2]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def WGMMATypesAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + + def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async", [DeclareOpInterfaceMethods, PredOpTrait<"input struct and result struct must be the same type", @@ -1508,15 +1530,14 @@ I64:$descriptorA, I64:$descriptorB, NVVM_MMAShapeAttr:$shape, - MMATypesAttr:$typeA, - MMATypesAttr:$typeB, + WGMMATypesAttr:$typeA, + WGMMATypesAttr:$typeB, WGMMAScaleOutAttr:$scaleD, WGMMAScaleInAttr:$scaleA, WGMMAScaleInAttr:$scaleB, MMALayoutAttr:$layoutA, MMALayoutAttr:$layoutB, OptionalAttr:$satfinite - // OptionalAttr:$satfinite ); let assemblyFormat = [{ @@ -1536,44 +1557,50 @@ Supported shapes: ``` - |-------------------|--------------------|----------------|---------------| - | | f16 += f16 * f16 | s32 += s8 * s8 | | - |f32 += tf32 * tf32 | f32 += f16 * f16 | s32 += s8 * u8 |s32 += b1 * b1 | - | | f32 += bf16 * bf16 | s32 += u8 * u8 | | - |-------------------|--------------------|----------------|---------------| - | .m64n8k8 | 
.m64n8k16 | .m64n8k32 | .m64n8k256 | - | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | - | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | - | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | - | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | - | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | - | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | - | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | - | .m64n72k8 | .m64n72k16 | .m64n112k32 | .m64n112k256 | - | .m64n80k8 | .m64n80k16 | .m64n128k32 | .m64n128k256 | - | .m64n88k8 | .m64n88k16 | .m64n144k32 | .m64n144k256 | - | .m64n96k8 | .m64n96k16 | .m64n160k32 | .m64n160k256 | - | .m64n104k8 | .m64n104k16 | .m64n176k32 | .m64n176k256 | - | .m64n112k8 | .m64n112k16 | .m64n192k32 | .m64n192k256 | - | .m64n120k8 | .m64n120k16 | .m64n208k32 | .m64n208k256 | - | .m64n128k8 | .m64n128k16 | .m64n224k32 | .m64n224k256 | - | .m64n136k8 | .m64n136k16 | .m64n240k32 | .m64n240k256 | - | .m64n144k8 | .m64n144k16 | .m64n256k32 | .m64n256k256 | - | .m64n152k8 | .m64n152k16 | | | - | .m64n160k8 | .m64n160k16 | | | - | .m64n168k8 | .m64n168k16 | | | - | .m64n176k8 | .m64n176k16 | | | - | .m64n184k8 | .m64n184k16 | | | - | .m64n192k8 | .m64n192k16 | | | - | .m64n200k8 | .m64n200k16 | | | - | .m64n208k8 | .m64n208k16 | | | - | .m64n216k8 | .m64n216k16 | | | - | .m64n224k8 | .m64n224k16 | | | - | .m64n232k8 | .m64n232k16 | | | - | .m64n240k8 | .m64n240k16 | | | - | .m64n248k8 | .m64n248k16 | | | - | .m64n256k8 | .m64n256k16 | | | - |-------------------|--------------------|----------------|---------------| + |--------------|--------------|------------|--------------|---------------| + | | | | |f16+=e4m3*e4m3 | + | | | | |f16+=e5m2*e5m2 | + |f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 | + | |f32+=f16 *f16 | s32+=u8*u8 | |f16+=e4m3*e5m2 | + | |f32+=bf16*bf16| s32+=u8*u8 | |f16+=e4m3*e5m2 | + | |f32+=bf16*bf16| s32+=s8*u8 | |f32+=e4m3*e4m3 | + | | | s32+=u8*s8 | |f32+=e5m2*e5m2 | + | 
| | | |f32+=e4m3*e5m2 | + | | | | |f32+=e4m3*e5m2 | + |--------------|--------------|------------|--------------|---------------| + | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 | .m64n8k32 | + | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | .m64n16k32 | + | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | .m64n24k32 | + | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | .m64n32k32 | + | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | .m64n40k32 | + | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | .m64n48k32 | + | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | .m64n56k32 | + | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | .m64n64k32 | + | .m64n72k8 | .m64n72k16 | .m64n112k32| .m64n112k256 | .m64n72k32 | + | .m64n80k8 | .m64n80k16 | .m64n128k32| .m64n128k256 | .m64n80k32 | + | .m64n88k8 | .m64n88k16 | .m64n144k32| .m64n144k256 | .m64n88k32 | + | .m64n96k8 | .m64n96k16 | .m64n160k32| .m64n160k256 | .m64n96k32 | + | .m64n104k8 | .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32 | + | .m64n112k8 | .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32 | + | .m64n120k8 | .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32 | + | .m64n128k8 | .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32 | + | .m64n136k8 | .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32 | + | .m64n144k8 | .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32 | + | .m64n152k8 | .m64n152k16 | | | .m64n152k32 | + | .m64n160k8 | .m64n160k16 | | | .m64n160k32 | + | .m64n168k8 | .m64n168k16 | | | .m64n168k32 | + | .m64n176k8 | .m64n176k16 | | | .m64n176k32 | + | .m64n184k8 | .m64n184k16 | | | .m64n184k32 | + | .m64n192k8 | .m64n192k16 | | | .m64n192k32 | + | .m64n200k8 | .m64n200k16 | | | .m64n200k32 | + | .m64n208k8 | .m64n208k16 | | | .m64n208k32 | + | .m64n216k8 | .m64n216k16 | | | .m64n216k32 | + | .m64n224k8 | .m64n224k16 | | | .m64n224k32 | + | .m64n232k8 | .m64n232k16 | | | .m64n232k32 | + | .m64n240k8 | .m64n240k16 | | | .m64n240k32 | + | .m64n248k8 | 
.m64n248k16 | | | .m64n248k32 | + | .m64n256k8 | .m64n256k16 | | | .m64n256k32 | + |--------------|--------------|------------|--------------|---------------| ``` See for more information: diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -37,6 +37,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include #include #include @@ -708,6 +709,81 @@ return success(); } +FailureOr getAllowedSizeK(NVVM::WGMMATypes typeA) { + if (typeA == NVVM::WGMMATypes::tf32) + return 8; + if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16) + return 16; + if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8) + return 32; + if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2) + return 32; + if (typeA == NVVM::WGMMATypes::b1) + return 256; + return failure(); +} + +LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA, + NVVM::WGMMATypes typeB) { + switch (typeA) { + case NVVM::WGMMATypes::f16: + if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16) + return success(); + break; + case NVVM::WGMMATypes::tf32: + if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32) + return success(); + break; + case NVVM::WGMMATypes::u8: + case NVVM::WGMMATypes::s8: + if (typeD.isInteger(32) && + (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) + return success(); + break; + case NVVM::WGMMATypes::b1: + if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1) + return success(); + break; + case NVVM::WGMMATypes::bf16: + if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16) + return success(); + break; + case NVVM::WGMMATypes::e4m3: + case NVVM::WGMMATypes::e5m2: + if ((typeD.isF32() || typeD.isF16()) && + (typeB == NVVM::WGMMATypes::e5m2 || typeB == 
NVVM::WGMMATypes::e4m3)) + return success(); + break; + } + return failure(); +} + +LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { + SmallVector allowedN = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector allowedNshort = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + switch (typeA) { + case mlir::NVVM::WGMMATypes::f16: + case mlir::NVVM::WGMMATypes::tf32: + case mlir::NVVM::WGMMATypes::bf16: + case mlir::NVVM::WGMMATypes::e4m3: + case mlir::NVVM::WGMMATypes::e5m2: + if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; })) + return success(); + break; + case mlir::NVVM::WGMMATypes::u8: + case mlir::NVVM::WGMMATypes::s8: + case mlir::NVVM::WGMMATypes::b1: + if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; })) + return success(); + } + return failure(); +} + LogicalResult NVVM::WgmmaMmaSyncOp::verify() { Value outValue = getResults(); auto stype = dyn_cast(outValue.getType()); @@ -730,142 +806,49 @@ return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg"; } + mlir::NVVM::WGMMATypes typeA = getTypeA(); + mlir::NVVM::WGMMATypes typeB = getTypeB(); + if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) { + return emitOpError() << outputType + << " += " << NVVM::stringifyWGMMATypes(typeA) << " * " + << NVVM::stringifyWGMMATypes(typeB) + << ", it is not supported."; + } + // Check M if (getShape().getM() != 64) return emitOpError() << "shape 'm' must be 64"; // Check K - mlir::NVVM::MMATypes typeA = getTypeA(); - mlir::NVVM::MMATypes typeB = getTypeB(); - switch (typeA) { - case mlir::NVVM::MMATypes::bf16: - case mlir::NVVM::MMATypes::f16: - if (typeA != typeB) { - return emitOpError() << "input types must be same but got " - << NVVM::stringifyMMATypes(typeA) << " and " - << NVVM::stringifyMMATypes(typeB); - } - if (getShape().getK() != 16) { - 
return emitOpError() << "shape 'k' must be 16 " - << "for input type " - << NVVM::stringifyMMATypes(typeA); - } - break; - case mlir::NVVM::MMATypes::tf32: - if (typeA != typeB) { - return emitOpError() << "input types must be same but got " - << NVVM::stringifyMMATypes(typeA) << " and " - << NVVM::stringifyMMATypes(typeB); - } - if (getShape().getK() != 8) { - return emitOpError() << "shape 'k' must be 8 " - << "for input type " - << NVVM::stringifyMMATypes(typeA); - } - break; - case mlir::NVVM::MMATypes::s8: - case mlir::NVVM::MMATypes::u8: - if (typeB != mlir::NVVM::MMATypes::s8 && - typeB != mlir::NVVM::MMATypes::u8) { - return emitOpError() << "input type of rhs could be " - << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8) - << " or " - << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8) - << " same but got and " - << NVVM::stringifyMMATypes(typeB); - } - if (getShape().getK() != 32) { - emitOpError() << "shape 'k' must be 32 " - << "for input type " << NVVM::stringifyMMATypes(typeA); - } - break; - case mlir::NVVM::MMATypes::b1: - if (typeA != typeB) { - return emitOpError() << "input types must be same but got " - << NVVM::stringifyMMATypes(typeA) << " and " - << NVVM::stringifyMMATypes(typeB); - } - if (getShape().getK() != 256) { - return emitOpError() << "shape 'k' must be 256 " - << "for input type " - << NVVM::stringifyMMATypes(typeA); - } - break; - default: - return emitOpError() << "Unsupported input type " - << NVVM::stringifyMMATypes(typeA) << " and " - << NVVM::stringifyMMATypes(typeB); - } + FailureOr allowedK = getAllowedSizeK(typeA); + if (failed(allowedK) || allowedK.value() != getShape().getK()) + return emitOpError() << "shape 'k' must be " << allowedK.value() + << " for input type " + << NVVM::stringifyWGMMATypes(typeA); // Check N - SmallVector allowedNShapesF16 = {8, 16, 24, 32, 40, 48, 56, 64, - 72, 80, 88, 96, 104, 112, 120, 128, - 136, 144, 152, 160, 168, 176, 184, 192, - 200, 208, 216, 224, 232, 240, 248, 256}; - SmallVector 
allowedNShapesU8S8B1 = {8, 16, 24, 32, 48, 64, - 80, 96, 112, 128, 144, 160, - 176, 192, 208, 224, 240, 256}; - - bool validGEMMType = false; - // f16 += f16 * f16 - if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) { - if (!llvm::any_of(allowedNShapesF16, - [&](int n) { return getShape().getN() == n; })) { - return emitOpError() << "has input type " - << NVVM::stringifyMMATypes(typeA) << " n is set to " - << getShape().getN() << ", it is not supported."; - } - validGEMMType = true; - } - // f32 += tf32 * tf32| f32 += f16 * f16| f16 += bf16 * bf16 - if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 || - typeA == mlir::NVVM::MMATypes::tf32 || - typeA == mlir::NVVM::MMATypes::f16)) { - if (!llvm::any_of(allowedNShapesF16, - [&](int n) { return getShape().getN() == n; })) { - return emitOpError() << "has input type " - << NVVM::stringifyMMATypes(typeA) << " n is set to " - << getShape().getN() << ", it is not supported."; - } - validGEMMType = true; - } - // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1 - if (outputType.isInteger(32) && - (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 || - typeA == mlir::NVVM::MMATypes::b1)) { - if (!llvm::any_of(allowedNShapesU8S8B1, - [&](int n) { return getShape().getN() == n; })) { - return emitOpError() << "has input type " - << NVVM::stringifyMMATypes(typeA) << " n is set to " - << getShape().getN() << ", it is not supported."; - } - validGEMMType = true; + if (failed(isAllowedSizeN(getShape().getN(), typeA))) { + return emitOpError() << "has input type " + << NVVM::stringifyWGMMATypes(typeA) << " n is set to " + << getShape().getN() << ", it is not supported."; } - if (!validGEMMType) { - return emitOpError() << outputType - << " += " << NVVM::stringifyMMATypes(typeA) << " * " - << NVVM::stringifyMMATypes(typeB) - << ", it is not supported."; - } - - // Check transpose is needed from the given layouts. It is only - // supported for bf16 or f16. 
- if ((typeA != mlir::NVVM::MMATypes::f16 && - typeA != mlir::NVVM::MMATypes::bf16) && + // Check transpose (only available for f16/bf16) + if ((typeA != mlir::NVVM::WGMMATypes::f16 && + typeA != mlir::NVVM::WGMMATypes::bf16) && (getLayoutA() == mlir::NVVM::MMALayout::col || getLayoutB() == mlir::NVVM::MMALayout::col)) { return emitOpError() << "given layouts layout_a = " << stringifyMMALayout(getLayoutA()) << " and layout_b = " << stringifyMMALayout(getLayoutB()) - << " for input types " << stringifyMMATypes(typeA) << " and " - << stringifyMMATypes(typeB) + << " for input types " << stringifyWGMMATypes(typeA) << " and " + << stringifyWGMMATypes(typeB) << " requires transpose. However, this is only supported for: " << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and " << stringifyMMATypes(mlir::NVVM::MMATypes::bf16); } - // Check number of result registers + // Check result registers int expectedOutput; if (outputType.isF32() || outputType.isInteger(32)) expectedOutput = getShape().getN() / 2; @@ -876,7 +859,7 @@ << ", however output struct has " << outputSize << " elements"; } - // Check satfinite is set. It is only for s32 accumulator + // Check satfinite (only available for s32 accumulator) if (!outputType.isInteger(32) && getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == NVVM::MMAIntOverflow::satfinite) { @@ -892,8 +875,8 @@ std::string NVVM::WgmmaMmaSyncOp::getPtx() { int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); - bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 || - getTypeA() == mlir::NVVM::MMATypes::bf16; + bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 || + getTypeA() == mlir::NVVM::WGMMATypes::bf16; Value outValue = getResults() ? 
getResults() : getInouts(); auto stype = dyn_cast(outValue.getType()); @@ -901,10 +884,13 @@ std::string outputTypeName; if (outputType.isF16()) outputTypeName = "f16"; - if (outputType.isF32()) + else if (outputType.isF32()) outputTypeName = "f32"; else if (outputType.isInteger(32)) outputTypeName = "s32"; + else + assert(false && "unsupported output type"); + int expectedOutputRegisters; if (outputType.isF32() || outputType.isInteger(32)) expectedOutputRegisters = getShape().getN() / 2; @@ -921,7 +907,8 @@ << ", 0;\n" "wgmma.mma_async.sync.aligned.m" << m << "n" << n << "k" << k << "." << outputTypeName << "." - << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB()); + << stringifyWGMMATypes(getTypeA()) << "." + << stringifyWGMMATypes(getTypeB()); if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == NVVM::MMAIntOverflow::satfinite) ss << ".satfinite"; @@ -959,8 +946,8 @@ Value outValue = getResults() ? getResults() : getInouts(); auto stype = dyn_cast(outValue.getType()); Type outputType = stype.getBody().front(); - bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 || - getTypeA() == mlir::NVVM::MMATypes::bf16; + bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 || + getTypeA() == mlir::NVVM::WGMMATypes::bf16; if (getResults()) asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); if (getInouts()) diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir --- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir @@ -62,7 +62,7 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> - // expected-error @+1 {{op shape 'k' must be 16 for input type f16}} + // expected-error @+1 {{op shape 'k' must be 16 for input type f16}} %res = nvvm.wgmma.mma_async %descA, %descB, #nvvm.shape, D [%result, ], @@ -116,4 +116,19 @@ : !llvm.struct<(i32, i32, 
i32, i32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> return +} + +// ----- + +func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { + %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}} + %res = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, ], + A [, #nvvm.wgmma_scale_in, ], + B [, #nvvm.wgmma_scale_in, ] + : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + return } \ No newline at end of file diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -257,14 +257,71 @@ %result1 = nvvm.wgmma.mma_async %descA, %descB, #nvvm.shape, D [%result, #nvvm.wgmma_scale_out], - A [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], - B [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] : !mat32f32 -> !mat32f32 %result2 = nvvm.wgmma.mma_async %descA, %descB, #nvvm.shape, D [%result1, #nvvm.wgmma_scale_out], - A [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], - B [#nvvm.mma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + : !mat32f32 -> !mat32f32 + return %result2 : !mat32f32 +} + + +// ----- + +!mat32f32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @wgmma_f32_e4m3_e4m3 +func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 { + // CHECK: %[[RES:.+]] = 
llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + %result = llvm.mlir.undef : !mat32f32 + %result1 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out], + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + : !mat32f32 -> !mat32f32 + %result2 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out], + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + : !mat32f32 -> !mat32f32 + return %result2 : !mat32f32 +} + +// ----- + +!mat32f32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @wgmma_f32_e5m2_e4m3 +func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 { + 
// CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n" + %result = llvm.mlir.undef : !mat32f32 + %result1 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result, #nvvm.wgmma_scale_out], + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] + : !mat32f32 -> !mat32f32 + %result2 = nvvm.wgmma.mma_async %descA, %descB, + #nvvm.shape, + D [%result1, #nvvm.wgmma_scale_out], + A [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout], + B [#nvvm.wgmma_type, #nvvm.wgmma_scale_in, #nvvm.mma_layout] : !mat32f32 -> !mat32f32 return %result2 : !mat32f32 }