diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -35,6 +35,8 @@ mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(NVVMOpsStructs.h.inc -gen-struct-attr-decls) +mlir_tablegen(NVVMOpsStructs.cpp.inc -gen-struct-attr-defs) mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm) mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm) add_public_tablegen_target(MLIRNVVMConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -21,6 +21,7 @@ #include "llvm/IR/IntrinsicsNVPTX.h" #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc" +#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.h.inc" namespace mlir { namespace NVVM { diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -195,18 +195,6 @@ let assemblyFormat = "$n attr-dict"; } -def NVVM_MmaOp : - NVVM_Op<"mma.sync">, - Results<(outs LLVM_Type:$res)>, - Arguments<(ins Variadic:$args)> { - string llvmBuilder = [{ - $res = createIntrinsicCall( - builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args); - }]; - let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; - let hasVerifier = 1; -} - /// Helpers to instantiate different version of wmma intrinsics. /// This matches the hierarchy used in IntrinsicsNVVM.td to define all the /// combinations of the intrinsics. @@ -296,6 +284,7 @@ // Creates list of valid combinations of fragments. This is a subset of what // llvm supports and can be extended as needed. class NVVM_MMA_OPS { + // "wmma" operations list> tf32_wmma_ops = MMA_OPS< [GEOM<16, 16, 8>], ["tf32"], [], ["f32"], []>.ret; @@ -324,6 +313,32 @@ // Separate A/B/C fragments (loads) from D (stores). 
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+
+  // "mma_sync" operations
+  list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
+    [GEOM<16,8,4>, GEOM<16,8,8>],
+    ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
+    [GEOM<16,8,16>, GEOM<16,8,8>],
+    ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
+    [GEOM<8,8,4>],
+    ["f64"], [], ["f64"], []>.ret;
+  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+    [GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
+    ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+    [GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
+    ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+    [GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
+    ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+    [GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
+    ["b1"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
+    tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
+    fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
 
@@ -405,6 +420,150 @@
   string id = !foldl("", f, acc, el, acc # "\n" # el);
 }
 
+/// Enum attribute for binary (b1) MMA operation type
+def MMAB1OpNone : I32EnumAttrCase<"none", 0>;
+def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
+def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
+def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
+  [MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute type for the overflow behavior of MMA integer operations
+def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
+def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
+def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
+  [MMAIntOverflowSat, MMAIntOverflowWrap]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Attribute to hold the MMA shape
+def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
+  StructFieldAttr<"m", I32Attr>,
+  StructFieldAttr<"n", I32Attr>,
+  StructFieldAttr<"k", I32Attr>
+  ]> {
+  let summary = "Attribute for MMA operation shape.";
+}
+
+// Returns true if this combination of layout/satf for MMA ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+  // MMA ops check both layouts.
+  string layout = layout_a # ":" # layout_b;
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  // gcd is a shortcut used to identify instructions that depend on
+  // geom+frag_c+frag_d.
+  string gcd = geom # ":" # c_type # d_type;
+  bit ret = !cond(
+
+    // Limit satf to valid types
+    !and(!eq(satf, 1),
+         !ne(a_type, "s8"),
+         !ne(a_type, "u8"),
+         !ne(a_type, "s4"),
+         !ne(a_type, "u4")): false,
+
+    // m8n8k4 has no C=f32 D=f16 variant.
+    !eq(gcd, "m8n8k4:f32f16"): false,
+
+    // only m8n8k4 for f16 does not require row:col layout
+    !and(!ne(layout, "row:col"),
+         !or(!ne(geom, "m8n8k4"),
+             !ne(a_type, "f16"))) : false,
+
+    // m16n8k8 requires A and B to be the same type and C and D to be the same
+    // type.
+    !and(!eq(geom, "m16n8k8"),
+         !or(!ne(a_type, b_type),
+             !ne(c_type, d_type))): false,
+
+    // m16n8k8 requires C and D to be the same type.
+    !and(!eq(geom, "m16n8k8"),
+         !ne(c_type, d_type)): false,
+
+    // All other are OK.
+    true: true
+  );
+}
+
+// Returns a list of operation suffixes corresponding to possible b1
+// multiply-and-accumulate operations for all fragments which have a
+// b1 type. For all other fragments, the list returned holds a list
+// containing the empty string.
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+  list<string> ret = !cond(
+    !eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
+    true: [""]
+  );
+}
+
+/// Generate enum value of the mma.sync intrinsic.
+class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
+                    WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma"
+              # !if(!ne(b1op, ""), "_" # b1op, "")
+              # "_" # A.geom
+              # "_" # ALayout
+              # "_" # BLayout
+              # !if(Satfinite, "_satfinite", "")
+              # signature;
+}
+
+/// Helper to create the mapping between the configuration and the mma.sync
+/// intrinsic enum value.
+class MMA_SYNC_INTR {
+  list<list<list<list<list<string>>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
+      !foreach(layoutA, ["row", "col"],
+        !foreach(layoutB, ["row", "col"],
+          !foreach (sat, [0, 1],
+            !foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
+              !if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
+                                     layoutA, layoutB, sat>.ret,
+                "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
+                " m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
+                " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
+                # op[1].ptx_elt_type # "\" == eltypeB && "
+                # " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
+                # " \"" # op[3].ptx_elt_type # "\" == eltypeD "
+                # " && (sat.hasValue() ? " # sat # " == static_cast<int>(*sat) : true)"
+                # !if(!ne(b1op, ""), " && (b1Op.hasValue() ? MMAB1Op::" # b1op # " == b1Op.getValue() : true)", "") # ")\n"
+                # "  return " #
+                MMA_SYNC_NAME<layoutA, layoutB, b1op, sat,
+                              op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+              ) // b1op
+            ) // sat
+          ) // layoutB
+        ) // layoutA
+      ); // all_mma_sync_ops
+  list<list<list<string>>> f1 = !foldl([[[""]]],
+                                       !foldl([[[[""]]]], cond0, acc, el,
+                                              !listconcat(acc, el)),
+                                       acc1, el1, !listconcat(acc1, el1));
+  list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
 def MMALayoutRow : I32EnumAttrCase<"row", 0>;
 def MMALayoutCol : I32EnumAttrCase<"col", 1>;
 
@@ -418,13 +577,23 @@
   let assemblyFormat = "`<` $value `>`";
 }
 
+/// Enum attribute of the different PTX element types used for MMA operands.
 def MMATypeF16 : I32EnumAttrCase<"f16", 0>;
 def MMATypeF32 : I32EnumAttrCase<"f32", 1>;
 def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
+def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
+def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
+def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
+def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
+def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
+def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
+def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
-/// Enum attribute of the different matrix types.
 def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
-  [MMATypeF16, MMATypeF32, MMATypeTF32]> {
+  [MMATypeF16, MMATypeF32, MMATypeTF32,
+   MMATypeBF16, MMATypeS8, MMATypeU8,
+   MMATypeS32, MMATypeS4, MMATypeU4,
+   MMATypeB1]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -678,4 +847,141 @@
   let hasVerifier = 1;
 }
 
+def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
+
+  let summary = "cooperative matrix-multiply and accumulate";
+
+  let description = [{
+    The `nvvm.mma.sync` operation collectively performs the operation
+    `D = matmul(A, B) + C` using all threads in a warp.
+
+    All the threads in the warp must execute the same `mma.sync` operation.
+
+    For each possible multiplicand PTX data type, there are one or more possible
+    instruction shapes given as "mMnNkK". The table below describes the possibilities
+    as well as the types required for the operands. Note that the data type for
+    C (the accumulator) and D (the result) can vary independently when there are
+    multiple possibilities in the "C/D Type" column.
+
+    When an optional attribute cannot be immediately inferred from the types of
+    the operands and the result during parsing or validation, an error will be
+    raised.
+
+    `b1Op` is only relevant when the binary (b1) type is given as the
+    multiplicand type. It specifies how the multiply-and-accumulate is
+    performed and is either `xor_popc` or `and_popc`. The default is `xor_popc`.
+
+    `intOverflowBehavior` is only relevant when the multiplicand type is one of
+    `u8`, `s8`, `u4`, or `s4`; it describes how overflow is handled
+    in the accumulator. When the attribute is `satfinite`, the accumulator values
+    are clamped in the int32 range on overflow. This is the default behavior.
+    Alternatively, accumulator behavior `wrapped` can also be specified, in
+    which case overflow wraps from one end of the range to the other.
+
+    `layoutA` and `layoutB` are required and should generally be set to
+    `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
+    combinations are possible for certain layouts according to the table below.
+
+    ```
+    | A/B Type | Shape      | ALayout | BLayout | A Type   | B Type   | C/D Type             |
+    |----------|------------|---------|---------|----------|----------|----------------------|
+    | f64      | .m8n8k4    | row     | col     | 1x f64   | 1x f64   | 2x f64               |
+    | f16      | .m8n8k4    | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8x f32   |
+    |          | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4x f32   |
+    |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4x f32   |
+    | bf16     | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4x f32   |
+    |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4x f32   |
+    | tf32     | .m16n8k4   | row     | col     | 2x i32   | 1x i32   | 4x f32               |
+    |          | .m16n8k8   | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4x f32   |
+    | u8/s8    | .m8n8k16   | row     | col     | 1x i32   | 1x i32   | 2x i32               |
+    |          | .m16n8k16  | row     | col     | 2x i32   | 1x i32   | 4x i32               |
+    |          | .m16n8k32  | row     | col     | 4x i32   | 2x i32   | 4x i32               |
+    | u4/s4    | .m8n8k32   | row     | col     | 1x i32   | 1x i32   | 2x i32               |
+    |          | .m16n8k32  | row     | col     | 2x i32   | 1x i32   | 4x i32               |
+    |          | .m16n8k64  | row     | col     | 4x i32   | 2x i32   | 4x i32               |
+    | b1       | .m8n8k128  | row     | col     | 1x i32   | 1x i32   | 2x i32               |
+    |          | .m16n8k128 | row     | col     | 2x i32   | 1x i32   | 4x i32               |
+    ```
+
+    Example:
+    ```mlir
+
+    %128 = nvvm.mma.sync A[%120, %121, %122, %123]
+                         B[%124, %125]
+                         C[%126, %127]
+                         {layoutA = #nvvm.mma_layout<row>,
+                          layoutB = #nvvm.mma_layout<col>,
+                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
+            -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let results = (outs LLVM_AnyStruct:$res);
+  let arguments = (ins NVVM_MMAShapeAttr:$shape,
+                       OptionalAttr<MMAB1OpAttr>:$b1Op,
+                       OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+                       MMALayoutAttr:$layoutA,
+                       MMALayoutAttr:$layoutB,
+                       OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+                       OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+                       Variadic<LLVM_Type>:$operandA,
+                       Variadic<LLVM_Type>:$operandB,
+                       Variadic<LLVM_Type>:$operandC);
+
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            llvm::Optional<MMAB1Op> b1Op,
+            llvm::Optional<MMAIntOverflow> sat,
+            mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+        llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
+        llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+        llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+      }],
+    MMA_SYNC_INTR<>.id, [{
+      return 0;
+    }
+
+    static Optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandType,
+                                                              bool isAccum);
+
+    MMATypes accumPtxType();
+    MMATypes resultPtxType();
+  }]);
+
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+      "ValueRange":$operandB, "ValueRange":$operandC,
+      "ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
+      "Optional<MMAIntOverflow>":$intOverflow,
+      "Optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+      "Optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
+  ];
+
+  string llvmBuilder = [{
+    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+    auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
+        $shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
+        $b1Op, $intOverflowBehavior,
+        $layoutA, $layoutB,
+        $multiplicandAPtxType.getValue(),
+        $multiplicandBPtxType.getValue(),
+        op.accumPtxType(),
+        op.resultPtxType());
+
+    $res = createIntrinsicCall(
+        builder, intId, operands);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 #endif // NVVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -34,6 +34,7 @@
 
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.cpp.inc"
 
 //===----------------------------------------------------------------------===//
 // Printing/parsing for NVVM ops
@@ -69,47 +70,442 @@
   return success();
 }
 
+// static
+Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type type,
+                                                          bool isAccum) {
+  auto half2Type =
+      LLVM::getFixedVectorType(Float16Type::get(type.getContext()), 2);
+  if (type.isa<Float16Type>() || type == half2Type)
+    return NVVM::MMATypes::f16;
+  if (type.dyn_cast<Float32Type>())
+    return NVVM::MMATypes::f32;
+  if (type.isa<IntegerType>()) {
+    if (isAccum)
+      return NVVM::MMATypes::s32;
+    return llvm::None;
+  }
+
+  if (auto structType = type.dyn_cast<LLVM::LLVMStructType>()) {
+    if (structType.getBody().empty())
+      return llvm::None;
+    return inferOperandMMAType(structType.getBody()[0], isAccum);
+  }
+
+  return llvm::None;
+}
+
+static bool isInt4PtxType(MMATypes type) {
+  return (type == MMATypes::u4 || type == MMATypes::s4);
+}
+
+static bool isInt8PtxType(MMATypes type) {
+  return (type == MMATypes::u8 || type == MMATypes::s8);
+}
+
+static bool isIntegerPtxType(MMATypes type) {
+  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
+         type == MMATypes::s32;
+}
+
+MMATypes MmaOp::accumPtxType() {
+  Optional<MMATypes> val = inferOperandMMAType(
+      getODSOperands(2).getTypes().front(), /*isAccum=*/true);
+  assert(val.hasValue() && "accumulator PTX type should always be inferrable");
+  return val.getValue();
+}
+
+MMATypes MmaOp::resultPtxType() {
+  Optional<MMATypes> val =
+      inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
+  assert(val.hasValue() && "result PTX type should always be inferrable");
+  return val.getValue();
+}
+
+void MmaOp::print(OpAsmPrinter &p) {
+  SmallVector<Type> regTypes;
+  struct OperandFragment {
+    StringRef operandName;
+    StringRef ptxTypeAttr;
+    SmallVector<Value> regs;
+    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+        : operandName(name), ptxTypeAttr(ptxTypeName) {}
+  };
+
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", multiplicandAPtxTypeAttrName()),
+      OperandFragment("B", multiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  SmallVector<StringRef> ignoreAttrNames{
+      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
+
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(this->getOperand(operandIdx));
+      if (operandIdx == 0) {
+        regTypes.push_back(this->getOperand(operandIdx).getType());
+      }
+    }
+    Optional<MMATypes> inferredType =
+        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
+    if (inferredType)
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+
+  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+    p << " " << frag.operandName;
+    p << "[";
+    p.printOperands(frag.regs);
+    p << "] ";
+  };
+
+  for (const auto &frag : frags) {
+    printMmaOperand(frag);
+  }
+
+  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+  // Print the types of the operands and result.
+  p << " : "
+    << "(";
+  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+                                             frags[1].regs[0].getType(),
+                                             frags[2].regs[0].getType()},
+                        p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{this->res().getType()});
+}
+
+// static
+void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
+                  ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
+                  Optional<MMAIntOverflow> intOverflow,
+                  Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+                  Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
+
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+  MLIRContext *ctx = builder.getContext();
+  Type i32 = builder.getIntegerType(32);
+  result.addAttribute(
+      "shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
+                                 builder.getIntegerAttr(i32, shape[1]),
+                                 builder.getIntegerAttr(i32, shape[2]), ctx));
+
+  result.addOperands(operandA);
+  result.addOperands(operandB);
+  result.addOperands(operandC);
+
+  if (multiplicandPtxTypes.hasValue()) {
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+  } else {
+    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
+      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
+      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+  }
+
+  if (multiplicandLayouts.hasValue()) {
+    result.addAttribute("layoutA",
+                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
+    result.addAttribute("layoutB",
+                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
+  } else {
+    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
+    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
+  }
+
+  if (intOverflow.hasValue())
+    result.addAttribute("intOverflowBehavior",
+                        MMAIntOverflowAttr::get(ctx, *intOverflow));
+  if (b1Op.hasValue())
+    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
+
+  result.addTypes(resultType);
+  result.addAttribute(
+      MmaOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
+                                static_cast<int32_t>(operandB.size()),
+                                static_cast<int32_t>(operandC.size())}));
+}
+
+// <operation> :=
+//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
+//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
+//   `->` type($res)
+ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
+  struct OperandFragment {
+    Optional<MMATypes> elemtype;
+    SmallVector<OpAsmParser::OperandType, 4> regs;
+    SmallVector<Type> regTypes;
+  };
+
+  Builder &builder = parser.getBuilder();
+  std::array<OperandFragment, 4> frags;
+
+  NamedAttrList namedAttributes;
+
+  // A helper to parse the operand segments.
+  auto parseMmaOperand = [&](StringRef operandName,
+                             OperandFragment &frag) -> LogicalResult {
+    if (parser.parseKeyword(operandName).failed())
+      return failure();
+    if (parser
+            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+            .failed())
+      return failure();
+    return success();
+  };
+
+  // Parse the operand segments.
+  if (parseMmaOperand("A", frags[0]).failed())
+    return failure();
+  if (parseMmaOperand("B", frags[1]).failed())
+    return failure();
+  if (parseMmaOperand("C", frags[2]).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse the type specification and resolve operands.
+  SmallVector<Type> operandTypes;
+  if (failed(parser.parseColon()))
+    return failure();
+  if (failed(parser.parseLParen()))
+    return failure();
+  if (failed(parser.parseTypeList(operandTypes)))
+    return failure();
+  if (failed(parser.parseRParen()))
+    return failure();
+  if (operandTypes.size() != 3)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expected one type for each operand segment but got " +
+            Twine(operandTypes.size()) + " types");
+  for (auto iter : llvm::enumerate(operandTypes)) {
+    auto &frag = frags[iter.index()];
+    frag.regTypes.resize(frag.regs.size(), iter.value());
+    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+                                      parser.getNameLoc(), result.operands)))
+      return failure();
+    frag.elemtype =
+        inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
+  }
+
+  Type resultType;
+  if (parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
+
+  std::array<StringRef, 2> names{"multiplicandAPtxType",
+                                 "multiplicandBPtxType"};
+  for (unsigned idx = 0; idx < names.size(); idx++) {
+    const auto &frag = frags[idx];
+    Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+    if (!frag.elemtype.hasValue() && !attr.hasValue()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "attribute " + names[idx] +
+              " is not provided explicitly and cannot be inferred");
+    }
+    if (!attr.hasValue())
+      result.addAttribute(
+          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+  }
+
+  result.addTypes(resultType);
+  if (!namedAttributes.empty())
+    result.addAttributes(namedAttributes);
+  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                      }));
+  return success();
+}
+
 LogicalResult MmaOp::verify() {
   MLIRContext *context = getContext();
   auto f16Ty = Float16Type::get(context);
+  auto i32Ty = IntegerType::get(context, 32);
   auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
   auto f32Ty = Float32Type::get(context);
   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
-  auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
-      context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
-
-  auto operandTypes = getOperandTypes();
-  if (operandTypes != SmallVector<Type>(8, f16x2Ty) &&
-      operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
-                                     f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                     f32Ty}) {
-    return emitOpError("expected operands to be 4 <halfx2>s followed by either "
-                       "4 <halfx2>s or 8 floats");
+
+  auto s32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+  auto f32x8StructTy =
+      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  auto f16x2x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+  auto f32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+  auto s32x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+  std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
+                                  shapeAttr().n().getInt(),
+                                  shapeAttr().k().getInt()};
+
+  // These variables define the set of allowed data types for matrices A, B, C,
+  // and result.
+  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+  using AllowedTypes = SmallVector<SmallVector<Type>, 2>;
+  AllowedShapes allowedShapes;
+  AllowedTypes expectedA;
+  AllowedTypes expectedB;
+  AllowedTypes expectedC;
+  SmallVector<Type> expectedResult;
+
+  // When M = 16, we just need to calculate the number of 8xk tiles, where
+  // k is a factor that depends on the data type.
+  if (mmaShape[0] == 16) {
+    int64_t kFactor;
+    Type multiplicandFragType;
+    switch (multiplicandAPtxType().getValue()) {
+    // case MMATypes::f64:
+    case MMATypes::tf32:
+      kFactor = 4;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+      break;
+    case MMATypes::f16:
+    case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = f16x2Ty;
+      expectedResult.push_back(f16x2x2StructTy);
+      expectedResult.push_back(f32x4StructTy);
+      break;
+    case MMATypes::s4:
+    case MMATypes::u4:
+      kFactor = 32;
+      break;
+    case MMATypes::b1:
+      kFactor = 128;
+      break;
+    case MMATypes::s8:
+    case MMATypes::u8:
+      kFactor = 16;
+      break;
+    default:
+      return emitError("invalid shape or multiplicand type: " +
+                       stringifyEnum(multiplicandAPtxType().getValue()));
+    }
+
+    if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+      expectedResult.push_back(s32x4StructTy);
+      expectedC.emplace_back(4, i32Ty);
+      multiplicandFragType = i32Ty;
+    } else {
+      expectedC.emplace_back(2, f16x2Ty);
+      expectedC.emplace_back(4, f32Ty);
+    }
+
+    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
+    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+    expectedA.emplace_back(unitA, multiplicandFragType);
+    expectedB.emplace_back(unitB, multiplicandFragType);
+    allowedShapes.push_back({16, 8, kFactor});
+    allowedShapes.push_back({16, 8, kFactor * 2});
   }
-  if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
-    return emitOpError("expected result type to be a struct of either 4 "
-                       "<halfx2>s or 8 floats");
+
+  // In the M=8 case, there is only 1 possible case per data type.
+  if (mmaShape[0] == 8) {
+    if (multiplicandAPtxType().getValue() == MMATypes::f16) {
+      expectedA.emplace_back(2, f16x2Ty);
+      expectedB.emplace_back(2, f16x2Ty);
+      expectedResult.push_back(f16x2x4StructTy);
+      expectedResult.push_back(f32x8StructTy);
+      expectedC.emplace_back(4, f16x2Ty);
+      expectedC.emplace_back(8, f32Ty);
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+      expectedA.push_back({i32Ty});
+      expectedB.push_back({i32Ty});
+      expectedC.push_back({i32Ty, i32Ty});
+      expectedResult.push_back(s32x2StructTy);
+      if (isInt4PtxType(multiplicandAPtxType().getValue()))
+        allowedShapes.push_back({8, 8, 32});
+      if (isInt8PtxType(multiplicandAPtxType().getValue()))
+        allowedShapes.push_back({8, 8, 16});
+      if (multiplicandAPtxType().getValue() == MMATypes::b1)
+        allowedShapes.push_back({8, 8, 128});
+    }
   }
 
-  auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
-  auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
+  std::string errorMessage;
+  llvm::raw_string_ostream errorStream(errorMessage);
 
-  if (!(alayout && blayout) ||
-      !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
-      !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
-    return emitOpError("alayout and blayout attributes must be set to either "
-                       "\"row\" or \"col\"");
+  // Check that we matched an existing shape/dtype combination.
+  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+      !llvm::any_of(allowedShapes,
+                    [&](const auto &allowed) { return allowed == mmaShape; })) {
+    errorStream << "unimplemented variant for MMA shape <";
+    llvm::interleaveComma(mmaShape, errorStream);
+    errorStream << ">";
+    return emitOpError(errorMessage);
   }
 
-  if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
-                                     f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                     f32Ty} &&
-      getType() == f32x8StructTy && alayout.getValue() == "row" &&
-      blayout.getValue() == "col") {
-    return success();
+  // Verify the operand types for segments of A, B, and C operands.
+  std::array<StringRef, 3> operandNames{"A", "B", "C"};
+  for (const auto &iter : llvm::enumerate(
+           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+    auto spec = this->getODSOperandIndexAndLength(iter.index());
+    SmallVector<Type> operandTySeg(operand_type_begin() + spec.first,
+                                   operand_type_begin() + spec.first +
+                                       spec.second);
+    bool match =
+        llvm::any_of(iter.value(), [&](const SmallVector<Type> &typeSet) {
+          return typeSet == operandTySeg;
+        });
+
+    if (!match) {
+      errorStream << "Could not match types for the "
+                  << operandNames[iter.index()]
+                  << " operands; expected one of ";
+      for (const auto &x : iter.value()) {
+        errorStream << x.size() << "x" << x[0] << " ";
+      }
+      errorStream << "but got ";
+      llvm::interleaveComma(operandTySeg, errorStream);
+      return emitOpError(errorStream.str());
+    }
+  }
+
+  // Check the result type
+  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+        return expectedResultType == getResult().getType();
+      })) {
+    errorStream
+        << "Could not match allowed types for the result; expected one of ";
+    llvm::interleaveComma(expectedResult, errorStream);
+    errorStream << " but got " << getResult().getType();
+    return emitOpError(errorStream.str());
+  }
+
+  // Ensure that binary MMA variants have a b1 MMA operation defined.
+  if (multiplicandAPtxType() == MMATypes::b1 && !b1Op().hasValue()) {
+    return emitOpError("op requires " + b1OpAttrName().strref() + " attribute");
   }
 
-  return emitOpError("unimplemented mma.sync variant");
+  // Ensure int4/int8 MMA variants specify the accum overflow behavior
+  // attribute.
+  if (isInt4PtxType(*multiplicandAPtxType()) ||
+      isInt8PtxType(*multiplicandAPtxType())) {
+    if (!intOverflowBehavior().hasValue())
+      return emitOpError("op requires " +
+                         intOverflowBehaviorAttrName().strref() + " attribute");
+  }
+
+  return success();
 }
 
 LogicalResult ShflOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -514,12 +514,13 @@
 
 // -----
 
-func @nvvm_invalid_mma_0(%a0 : f16, %a1 : vector<2xf16>,
+func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (f16, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -529,8 +530,9 @@
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
+  // expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
 }
 
@@ -540,8 +542,9 @@
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{alayout and blayout attributes must be set to either "row" or "col"}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{op requires attribute 'layoutA'}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -549,55 +552,23 @@
 func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : vector<2xf16>, %c1 : vector<2xf16>,
-                         %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_4(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+                         %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+  // expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
 // -----
 
-func @nvvm_invalid_mma_5(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_6(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{invalid kind of type specified}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{op requires one result}}
-  %0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, i32)
-  llvm.return %0#0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+func @nvvm_invalid_mma_8(%a0
: i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // expected-error@+1 {{op requires b1Op attribute}} + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } // ----- diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -66,15 +66,164 @@ llvm.return %0 : i32 } -func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>, +// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32 +func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, - %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, - %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { - // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> - %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { + // CHECK: nvvm.mma.sync + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) { + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +} + +func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32, + %c0 : i32, %c1 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + llvm.return %0 : !llvm.struct<(i32, i32)> +} + +func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : 
vector<2xf16>, + %b0 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>) { + // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> +} + +func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> +} + +func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + +func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> +} + +func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : 
i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + +func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + b1Op = #nvvm.mma_b1op, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) 
-> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + b1Op = #nvvm.mma_b1op, + shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1 +func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + b1Op = #nvvm.mma_b1op, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32)> +} + +// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4 +func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +// CHECK-LABEL: @nvvm_wmma_load_tf32 func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> { // CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 8 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} %0 = nvvm.wmma.load %arg0, %arg1 diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -88,17 +88,115 @@ llvm.return %3 : i32 } -llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>, +// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32 +llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> { // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32 - %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, 
%c7] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { + // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16 + %0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> +} + +// f32 return type, f16 accumulate type +llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> { + // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16 + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + +// f16 return type, f32 accumulate type +llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { + // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32 + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> +} + +// f32 return type, f32 accumulate type +llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, + %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> { + // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32 + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + +llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> { + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8 + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = 
#nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> { + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8 + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1 + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + b1Op = #nvvm.mma_b1op, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, + %b0 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4 + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + + // The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic // in the LLVM NVPTX backend. +// CHECK-LABEL: @gpu_wmma_load_op llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr, %arg1: i32) { // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}}) %0 = nvvm.wmma.load %arg0, %arg1