Index: mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -33,6 +33,8 @@
 add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)

 add_mlir_dialect(ROCDLOps rocdl)
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -18,6 +18,16 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+
+#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
+
+/// Return the element type and number of elements associated with a wmma
+/// matrix of given characteristics. This matches the logic in the WMMA_REGS
+/// structure in IntrinsicsNVVM.td.
+std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
+                                             mlir::NVVM::MMAFrag frag,
+                                             mlir::MLIRContext *context);

 ///// Ops /////
 #define GET_OP_CLASSES
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -151,124 +151,355 @@
   let verifier = [{ return ::verify(*this); }];
 }

-// Base class for all the variants of WMMA loadOps that may be defined.
-class NVVM_WMMALoadOp<string mnemonic> : NVVM_Op<mnemonic>,
-  Results<(outs LLVM_AnyStruct:$res)>,
-  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+/// Helpers to instantiate different versions of the wmma intrinsics.
+/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
+/// combinations of the intrinsics.
+class GEOM<int M, int N, int K> {
+  int m = M;
+  int n = N;
+  int k = K;
+}

-  let summary = "Warp synchronous matrix load";
+/// Class containing information about valid mma matrix types.
+class WMMA_REGS<GEOM Geom, string Frag, string PtxEltType> {
+  int m = Geom.m;
+  int n = Geom.n;
+  int k = Geom.k;
+  string geom = "m"#Geom.m#"n"#Geom.n#"k"#Geom.k;
+  string frag = Frag;
+  string ptx_elt_type = PtxEltType;
+  string gft = geom#":"#Frag#":"#ptx_elt_type;
+}

-  string baseDescription = [{"The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation"
-    "loads a matrix collectively using all the threads in a warp."
+/// Generate the enum value of the mma.load/mma.store intrinsic.
+class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, bit WithStride> {
+  string id = "llvm::Intrinsic::nvvm_wmma"
+              # "_" # Frag.geom
+              # "_" # Op
+              # "_" # Frag.frag
+              # "_" # Frag.ptx_elt_type
+              # "_" # Layout
+              # !if(WithStride, "_stride", "");
+}

-    "The operation takes two arguments, the address from where the matrix"
-    "elements are to be loaded from and a stride. The stride argument"
-    "represents the leading dimension of the source matrix. The address and"
-    "the stride are required to be the same across all threads in the warp."
-    "Each thread in a warp holds a certain number of elements. The Op returns"
-    "a LLVMStruct which holds the elements of the matrix held by this thread."
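For intuition, `WMMA_NAME_LDST` simply concatenates the C++ enumerator name of an intrinsic from the fragment's geometry, fragment letter, element type, layout, and stride-ness. A minimal C++ sketch of that concatenation (the helper below is hypothetical and only illustrates the string being built; the real mapping is emitted into `getIntrinsicID` further down):

```cpp
#include <string>

// Hypothetical illustration: reproduces the string WMMA_NAME_LDST builds.
// E.g. ("m16n16k16", "load", "a", "f16", "row", true) yields
// "llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride".
std::string wmmaLdstEnumeratorName(const std::string &geom,
                                   const std::string &op,
                                   const std::string &frag,
                                   const std::string &eltType,
                                   const std::string &layout,
                                   bool withStride) {
  return "llvm::Intrinsic::nvvm_wmma_" + geom + "_" + op + "_" + frag + "_" +
         eltType + "_" + layout + (withStride ? "_stride" : "");
}
```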
+/// Generate the signature part of the mma intrinsic name.
+class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  list<WMMA_REGS> id_frags = !cond(
+     // FP16 ops are identified by accumulator & result type.
+     !eq(A.ptx_elt_type, "f16") : [D, C],
+     // Other ops are identified by input types.
+     !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
+     true: [A]
+     );
+  string ret = !foldl("", id_frags, a, b, !strconcat(a, "_", b.ptx_elt_type));
+}

-    "This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and"
-    "`nvvm.wmma.m*n*k*.mma`."}];
+/// Generate the enum value of the wmma.mma intrinsic.
+class WMMA_NAME<string ALayout, string BLayout, WMMA_REGS A, WMMA_REGS B,
+                WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_wmma"
+              # "_" # A.geom
+              # "_mma"
+              # "_" # ALayout
+              # "_" # BLayout
+              # signature;
+}

-  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
+//   Geom: list of supported geometries.
+//   TypeN: PTX type of the corresponding fragment's element.
+//   TypeB and TypeD may be empty if they must match TypeA and TypeC.
+class MMA_OPS<list<GEOM> Geom, list<string> TypeA, list<string> TypeB,
+              list<string> TypeC, list<string> TypeD> {
+  list<list<WMMA_REGS>> ret =
+     !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
+     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
+     !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
+     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
+            [[WMMA_REGS<geom, "a", type_a>,
+              WMMA_REGS<geom, "b", type_b>,
+              WMMA_REGS<geom, "c", type_c>,
+              WMMA_REGS<geom, "d", type_d>]]))))))))));
+  // Debugging aid for readable representation of the list above.
+  list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
+}

-def NVVM_WMMALoadAM16N16K16Op :
-  NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
+/// Creates a list of combinations of load/store operations supported.
+class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+  // Debugging aid for readable representation of the list above.
+  list<string> ops = !foreach(x, ret, x.gft);
+}

-  string llvmBuilder = [{
-    $res = createNvvmIntrinsicCall(
-      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
-  }];
+// Creates list of valid combinations of fragments. This is a subset of what
+// llvm supports and can be extended as needed.
+class NVVM_MMA_OPS {
+  list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
+            [GEOM<16, 16, 8>],
+            ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
+            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+            ["f16"], [], ["f16", "f32"], []>.ret;
+  list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+            tf32_wmma_ops,
+            fp_wmma_ops);
+
+  list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
+            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+            ["a", "b"], ["f16"]>.ret;
+  list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
+            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+            ["c", "d"], ["f16", "f32"]>.ret;
+  list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
+            [GEOM<16, 16, 8>],
+            ["a", "b"], ["tf32"]>.ret;
+  list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
+            [GEOM<16, 16, 8>],
+            ["c", "d"], ["f32"]>.ret;
+  list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
+                                             ldst_tf32_ab_ops,
+                                             ldst_tf32_cd_ops);
+  // Separate A/B/C fragments (loads) from D (stores).
+  list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
+  list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+}

-  string opDescription = [{
-    Example:
+def NVVM_MMA_OPS : NVVM_MMA_OPS;
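The nested `!foldl` in `MMA_OPS` and `MMA_LDST_OPS` is just a cross product over geometries, fragments, and element types. A throwaway C++ sketch (hypothetical, not part of the patch) that prints the same `gft` strings as `MMA_LDST_OPS<...>.ops` yields for the f16 A/B fragments:

```cpp
#include <cstdio>

int main() {
  struct Geom { int m, n, k; };
  const Geom geoms[] = {{16, 16, 16}, {32, 8, 16}, {8, 32, 16}};
  const char *frags[] = {"a", "b"};
  const char *types[] = {"f16"};
  // Mirrors ldst_ab_ops: prints "m16n16k16:a:f16", "m16n16k16:b:f16", ...
  for (const Geom &g : geoms)
    for (const char *frag : frags)
      for (const char *type : types)
        std::printf("m%dn%dk%d:%s:%s\n", g.m, g.n, g.k, frag, type);
  return 0;
}
```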
+/// Helper to create the mapping between the configuration and the store
+/// intrinsic enum value.
+class MMA_ST_INTR<string op> {
+  list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_st_ops,
+                        !foreach(layout, ["row", "col"],
+    "if (layout == \"" # layout # "\" && m == " # frag.m # " &&"
+    "    n == " #frag.n # " && k == " # frag.k # " && \"" #
+    frag.ptx_elt_type # "\" == eltype)"
+    "  return " #WMMA_NAME_LDST<op, frag, layout, 1>.id #";"));
+  string id = !foldl("",
+                     !foldl([""], cond0, acc, el, !listconcat(acc, el)),
+                     acc1, el1, acc1 # "\n" # el1);
+}

-    ```mlir
-    %2 = nvvm.wmma.m16n16k16.load.a %0, %1 : !llvm.ptr, !llvm.i32 ->
-      !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
-      vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
-    ```
-  }];

+/// Helper to map a mxk shape to a supported mxnxk matrix type. This will return
+/// the n value of the supported configuration.
+class MMA_ST_INFER_N<list<WMMA_REGS> ldst> {
+  list<string> cond = !foreach(frag, ldst,
+    "if (m == " # frag.m # " && k == " #frag.k # " && \"" #
+    frag.ptx_elt_type # "\" == eltype)"
+    "  return "# frag.n #";");
+  string id = !foldl("", cond, acc, el, acc # "\n" # el);
+}

-  let description = !strconcat(baseDescription, opDescription);
+/// Helper to map a kxn shape to a supported mxnxk matrix type. This will return
+/// the m value of the supported configuration.
+class MMA_ST_INFER_M<list<WMMA_REGS> ldst> {
+  list<string> cond = !foreach(frag, ldst,
+    "if (n == " # frag.n # " && k == " #frag.k # " && \"" #
+    frag.ptx_elt_type # "\" == eltype)"
+    "  return "# frag.m #";");
+  string id = !foldl("", cond, acc, el, acc # "\n" # el);
+}

-  let verifier = [{ return ::verify(*this); }];
+/// Helper to map a mxn shape to a supported mxnxk matrix type. This will return
+/// the k value of the supported configuration.
+class MMA_ST_INFER_K<list<WMMA_REGS> ldst> {
+  list<string> cond = !foreach(frag, ldst,
+    "if (m == " # frag.m # " && n == " #frag.n # " && \"" #
+    frag.ptx_elt_type # "\" == eltype)"
+    "  return "# frag.k #";");
+  string id = !foldl("", cond, acc, el, acc # "\n" # el);
 }

-def NVVM_WMMALoadBM16N16K16Op :
-  NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
+/// Helper to create the mapping between the configuration and the load
+/// intrinsic enum value.
+class MMA_LD_INTR<string op> {
+  list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_ld_ops,
+                        !foreach(layout, ["row", "col"],
+    "if (layout == \"" # layout # "\" && m == " # frag.m # " &&"
+    "    n == " #frag.n # " && k == " # frag.k # " && \"" #
+    frag.ptx_elt_type # "\" == eltype && frag == \""#frag.frag#"\")"
+    "  return "# WMMA_NAME_LDST<op, frag, layout, 1>.id #";"));
+  string id = !foldl("",
+                     !foldl([""], cond0, acc, el, !listconcat(acc, el)),
+                     acc1, el1, acc1 # "\n" # el1);
+}

-  string llvmBuilder = [{
-    $res = createNvvmIntrinsicCall(
-      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
-  }];
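Each `MMA_ST_INFER_*` class emits a chain of `if` statements that is pasted verbatim into a C++ helper. Roughly, the body generated for the store-side `inferKDimension` looks like the hand-written sketch below (illustrative only; the real body is generated from `all_st_ops`, one `if` per supported configuration, first match wins):

```cpp
#include "llvm/ADT/StringRef.h"

// Hand-written rendering of what MMA_ST_INFER_K<...>.id expands to; this is
// a sketch, not the generated code itself.
static int inferKDimensionSketch(int m, int n, llvm::StringRef eltype) {
  if (m == 16 && n == 16 && eltype == "f16")
    return 16;
  if (m == 32 && n == 8 && eltype == "f16")
    return 16;
  if (m == 8 && n == 32 && eltype == "f16")
    return 16;
  // ... the f32 and tf32 configurations follow the same pattern ...
  return 0; // sentinel: no intrinsic exists for this shape/type
}
```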
+/// Helper to create the mapping between the configuration and the wmma.mma
+/// intrinsic enum value.
+class MMA_MMA_INTR<string op> {
+  list<list<list<string>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_wmma_ops,
+    !foreach(layoutA, ["row", "col"],
+    !foreach(layoutB, ["row", "col"],
+      "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
+      " m == " # op[0].m # " && n == " #op[0].n # " && k == " # op[0].k #
+      " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
+      # op[3].ptx_elt_type # "\" == eltypeB)"
+      "  return " #
+      WMMA_NAME<layoutA, layoutB, op[0], op[1], op[2], op[3]>.id # ";")));
+  list<string> f = !foldl([""],
+                          !foldl([[""]], cond0, acc, el, !listconcat(acc, el)),
+                          acc1, el1, !listconcat(acc1, el1));
+  string id = !foldl("", f, acc, el, acc # "\n" # el);
+}

-  string opDescription = [{
-    Example:
+def MMALayoutRow : StrEnumAttrCase<"row">;
+def MMALayoutCol : StrEnumAttrCase<"col">;

-    ```mlir
-    %2 = nvvm.wmma.m16n16k16.load.b %0, %1 : !llvm.ptr, !llvm.i32 ->
-      !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
-      vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
-    ```
-  }];
+/// Enum attribute of the different matrix layouts.
+def MMALayout : StrEnumAttr<"MMALayout", "NVVM MMA layout",
+    [MMALayoutRow, MMALayoutCol]> {
+  let cppNamespace = "::mlir::NVVM";
+  let storageType = "mlir::StringAttr";
+  let returnType = "NVVM::MMALayout";
+  let convertFromStorage = "*symbolizeEnum<NVVM::MMALayout>($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}

-  let description = !strconcat(baseDescription, opDescription);
+def MMATypeF16 : StrEnumAttrCase<"f16">;
+def MMATypeF32 : StrEnumAttrCase<"f32">;
+def MMATypeTF32 : StrEnumAttrCase<"tf32">;

-  let verifier = [{ return ::verify(*this); }];
+/// Enum attribute of the different matrix types.
+def MMATypes : StrEnumAttr<"MMATypes", "NVVM MMA types",
+    [MMATypeF16, MMATypeF32, MMATypeTF32]> {
+  let cppNamespace = "::mlir::NVVM";
+  let storageType = "mlir::StringAttr";
+  let returnType = "NVVM::MMATypes";
+  let convertFromStorage = "*symbolizeEnum<NVVM::MMATypes>($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
 }

-def NVVM_WMMALoadCF16M16N16K16Op :
-  NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
-  string llvmBuilder = [{
-    $res = createNvvmIntrinsicCall(
-      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
-  }];
+def MMAFragA : StrEnumAttrCase<"a">;
+def MMAFragB : StrEnumAttrCase<"b">;
+def MMAFragC : StrEnumAttrCase<"c">;

-  string opDescription = [{
-    Example:
+/// Enum attribute of the different frag types.
+def MMAFragAttr : StrEnumAttr<"MMAFrag", "NVVM MMA frag type",
+    [MMAFragA, MMAFragB, MMAFragC]> {
+  let cppNamespace = "::mlir::NVVM";
+  let storageType = "mlir::StringAttr";
+  let returnType = "NVVM::MMAFrag";
+  let convertFromStorage = "*symbolizeEnum<NVVM::MMAFrag>($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}

-    ```mlir
-    %2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1 : !llvm.ptr, !llvm.i32 ->
-      !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
-    ```
-  }];
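The three StrEnumAttr definitions give tablegen-generated stringify/symbolize helpers in `::mlir::NVVM`, which the ops below use to move between the textual attribute form and the C++ enum. A usage sketch, assuming the conventional generated names `symbolizeMMALayout`/`stringifyMMALayout` (not spelled out in this patch):

```cpp
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include <cassert>

void roundTripLayout() {
  // String -> enum; returns an empty Optional for an invalid spelling.
  llvm::Optional<mlir::NVVM::MMALayout> layout =
      mlir::NVVM::symbolizeMMALayout("row");
  assert(layout && *layout == mlir::NVVM::MMALayout::row);
  // Enum -> string, e.g. when building the storage StringAttr.
  llvm::StringRef spelling = mlir::NVVM::stringifyMMALayout(*layout); // "row"
  (void)spelling;
}
```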
+def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
+             I32Attr:$n, I32Attr:$k, MMALayout:$layout, MMATypes:$eltype,
+             MMAFragAttr:$frag)> {

-  let description = !strconcat(baseDescription, opDescription);
+  let summary = "Warp synchronous matrix load";
+
+  // Since LLVM intrinsic IDs are enums that cannot be dynamically generated
+  // in C++, we instantiate a function in tablegen to map each valid
+  // configuration to the corresponding intrinsic ID.
+  // Because we want a single source of truth, this means the source of truth
+  // about valid combinations needs to be in tablegen; therefore we generate
+  // extra helpers to query valid configurations based on the shapes of
+  // load/store operations.
+  let extraClassDeclaration =
+    "static llvm::Intrinsic::ID getIntrinsicID("
+    "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum,"
+    "mlir::NVVM::MMATypes eltypeEnum, mlir::NVVM::MMAFrag fragEnum) {"
+    "llvm::StringRef layout = stringifyEnum(layoutEnum);"
+    "llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
+    "llvm::StringRef frag = stringifyEnum(fragEnum);"
+    #MMA_LD_INTR<"load">.id# "\n"
+    "return 0;"
+    "}\n"
+    "/// Helpers to find valid n dimension based on mxk load shape.\n"
+    "static int inferNDimension(int m, int k, mlir::NVVM::MMATypes eltypeEnum) {"
+    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
+    #MMA_ST_INFER_N<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "a"))>.id# "\n"
+    "return 0;"
+    "}\n"
+    "/// Helpers to find valid m dimension based on kxn load shape.\n"
+    "static int inferMDimension(int k, int n, mlir::NVVM::MMATypes eltypeEnum) {"
+    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
+    #MMA_ST_INFER_M<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "b"))>.id# "\n"
+    "return 0;"
+    "}\n"
+    "/// Helpers to find valid k dimension based on mxn load shape.\n"
+    "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {"
+    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
+    #MMA_ST_INFER_K<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "c"))>.id# "\n"
+    "return 0;"
+    "}\n";

-  let verifier = [{ return ::verify(*this); }];
-}

-def NVVM_WMMALoadCF32M16N16K16Op :
-  NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{
   string llvmBuilder = [{
-    $res = createNvvmIntrinsicCall(
-      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args);
+    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+    auto intId = mlir::NVVM::WMMALoadOp::getIntrinsicID(
+        $m, $n, $k, $layout, $eltype, $frag);
+    $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
   }];

-  string opDescription = [{
+  string baseDescription = [{
+    The `nvvm.wmma.load` operation loads a matrix collectively using all the
+    threads in a warp.
+
+    The operation takes two arguments: the address from which the matrix
+    elements are to be loaded, and a stride. The stride argument represents
+    the leading dimension of the source matrix. The address and the stride
+    are required to be the same across all threads in the warp. Each thread
+    in a warp holds a certain number of elements. The op returns an
+    LLVMStruct which holds the elements of the matrix held by this thread.
+
+    This op is meant to be used along with `nvvm.wmma.store` and
+    `nvvm.wmma.mma`.
+
     Example:

     ```mlir
-    %2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1 : !llvm.ptr, !llvm.i32 ->
-      !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+    %2 = nvvm.wmma.load %0, %1
+      {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+      : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     ```
-  }];
-
-  let description = !strconcat(baseDescription, opDescription);
+  }];
+  let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)";
   let verifier = [{ return ::verify(*this); }];
 }

-// Base class for all the variants of WMMA storeOps that may be defined.
-class NVVM_WMMAStoreOp : NVVM_Op, - Arguments<(ins Variadic:$args)>{ +def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">, + Arguments<(ins LLVM_AnyPointer: $ptr, + I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayout:$layout, + MMATypes:$eltype, Variadic:$args, I32: $stride)>{ let summary = "Warp synchronous matrix store"; + let extraClassDeclaration = + "static llvm::Intrinsic::ID getIntrinsicID(" + "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum," + "mlir::NVVM::MMATypes eltypeEnum) {" + " llvm::StringRef layout = stringifyEnum(layoutEnum);" + " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" + #MMA_ST_INTR<"store">.id# "\n" + "return 0;" + "}\n" + "/// Helpers to find valid k dimension based on mxn store shape.\n" + "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {" + " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" + #MMA_ST_INFER_K.id# "\n" + "return 0;" + "}"; + + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = + mlir::NVVM::WMMAStoreOp::getIntrinsicID($m, $n, $k, $layout, $eltype); + createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); + }]; + string baseDescription = [{ - The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using + The `nvvm.wmma.store` operation stores a matrix collectively using all the threads in a warp. The operation takes as arguments the address to where the matrix elements are @@ -279,60 +510,50 @@ This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and `nvvm.wmma.m16n16k16.mma`. - }]; - - let assemblyFormat = "$args attr-dict `:` type($args)"; -} - -def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> { - string llvmBuilder = [{ - createNvvmIntrinsicCall( - builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args); - }]; - - string opDescription = [{ - Example: - - ```mlir - nvvm.wmma.m16n16k16.stored.f16.row.stride %0, %1, %2, %3, %4, %5, %6 : !llvm.ptr, - !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>, !llvm.i32 - ``` - }]; - - let description = !strconcat(baseDescription, opDescription); - - let verifier = [{ return ::verify(*this); }]; -} -def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> { - string llvmBuilder = [{ - createNvvmIntrinsicCall( - builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args); - }]; - - string opDescription = [{ Example: ```mlir - nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, - %10 : !llvm.ptr, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, - !llvm.i32 + nvvm.wmma.store %0, %1, %2, %3, %4, %5 + {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} + : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16> ``` }]; - let description = !strconcat(baseDescription, opDescription); - + let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)"; let verifier = [{ return ::verify(*this); }]; } // Base class for all the variants of WMMA mmaOps that may be defined. 
-class NVVM_WMMAMmaOp : NVVM_Op, +def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, Results<(outs LLVM_AnyStruct:$res)>, - Arguments<(ins Variadic:$args)>{ + Arguments<(ins I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayout:$layoutA, + MMALayout:$layoutB, MMATypes:$eltypeA, MMATypes:$eltypeB, + Variadic:$args)>{ let summary = "Warp synchronous matrix-multiply accumulate using tensor cores."; + let extraClassDeclaration = + "static llvm::Intrinsic::ID getIntrinsicID(" + "int m, int n, int k, mlir::NVVM::MMALayout layoutAEnum," + "mlir::NVVM::MMALayout layoutBEnum, mlir::NVVM::MMATypes eltypeAEnum," + "mlir::NVVM::MMATypes eltypeBEnum) {" + "llvm::StringRef layoutA = stringifyEnum(layoutAEnum);" + "llvm::StringRef layoutB = stringifyEnum(layoutBEnum);" + "llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);" + "llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);" + #MMA_MMA_INTR<"mma">.id# "\n" + "return 0;" + "}"; + + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = mlir::NVVM::WMMAMmaOp::getIntrinsicID( + $m, $n, $k, $layoutA, $layoutB, $eltypeA, $eltypeB); + $res = createIntrinsicCall(builder, intId, operands); + }]; + string baseDescription = [{ - The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate + The `nvvm.wmma.mma` operation performs a matrix-multiply accumulate (mma) operation using all the threads in a warp. The operation performed is represented as `D = A * B + C`. The operation takes @@ -340,64 +561,20 @@ current thread. The op returns a LLVM struct which holds a part of the result held by the current thread. - This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and `nvvm.wmma. - m16n16k16.store`. - }]; -} - -def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{ - string llvmBuilder = [{ - $res = createNvvmIntrinsicCall( - builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args); - }]; - - string opDescription = [{ - Example: - - ```mlir - %20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8, - %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> -> !llvm.struct - <(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> - ``` - }]; - - let parser = [{ - return parseWMMAMmaF16F16M16N16K16Op(parser, result); - }]; - - let printer = [{ - printWMMAMmaF16F16M16N16K16Op(p, *this); - }]; + This op is meant to be used along with `nvvm.wmma.load` and + `nvvm.wmma.store`. 
-  let description = !strconcat(baseDescription, opDescription);
-
-  let verifier = [{ return ::verify(*this); }];
-}
-
-def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{
-  string llvmBuilder = [{
-    $res = createNvvmIntrinsicCall(
-      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args);
-  }];
-
-  string opDescription = [{
     Example:

     ```mlir
-    %24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8
-      %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 :
-      (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
-      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
-      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
-      vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32,
-      f32, f32, f32, f32, f32, f32, f32)>
+    %16 = nvvm.wmma.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15
+      {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+      : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32)
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
     ```
   }];

   let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
-
-  let description = !strconcat(baseDescription, opDescription);
-
   let verifier = [{ return ::verify(*this); }];
 }
Index: mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
===================================================================
--- mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -359,14 +359,6 @@
                                  llvm::Intrinsic::ID intrinsic,
                                  ArrayRef<llvm::Value *> args = {},
                                  ArrayRef<llvm::Type *> tys = {});
-
-/// Creates a call to an LLVM IR intrinsic function with the given arguments
-/// for NVVM WMMA ops. Handles cases where the intrinsic name is overloaded
-/// using the types of arguments supplied. Selects the correct intrinsic
-/// by inspecting the argument types.
-llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
-                                     llvm::Intrinsic::ID intrinsic,
-                                     ArrayRef<llvm::Value *> args = {});
 } // namespace detail
 } // namespace LLVM
Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -36,26 +36,36 @@
   return success();
 }

-/// Error string to emit when unimplemented WMMA variant is encountered.
-static constexpr StringRef kInvalidCaseStr =
-    "Unimplemented WMMA variant, Only M16N16K16 version implemented.";
+/// Error string to emit when an unimplemented WMMA variant is encountered.
+static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
+
+static NVVM::MMAFrag convertOperand(StringRef operandName) {
+  if (operandName.equals("AOp"))
+    return NVVM::MMAFrag::a;
+  if (operandName.equals("BOp"))
+    return NVVM::MMAFrag::b;
+  if (operandName.equals("COp"))
+    return NVVM::MMAFrag::c;
+  llvm_unreachable("Unknown operand name");
+}
+
+static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
+  if (type.getElementType().isF16())
+    return NVVM::MMATypes::f16;
+  if (type.getElementType().isF32())
+    return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
+                                           : NVVM::MMATypes::tf32;
+  llvm_unreachable("Unsupported type");
+}
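Since the conversions below need the per-thread register layout, it helps to tabulate what `inferMMAType` (implemented in NVVMDialect.cpp further down) returns. The snippet is a hand-written summary of its cases, not new behavior:

```cpp
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

// Hand-written summary of inferMMAType's mapping:
//   f16,  frag a/b -> 8 x vector<2xf16>
//   f16,  frag c   -> 4 x vector<2xf16>
//   f32,  frag c   -> 8 x f32
//   tf32, frag a/b -> 4 x i32
void exampleQuery(mlir::MLIRContext *ctx) {
  std::pair<mlir::Type, unsigned> info =
      inferMMAType(mlir::NVVM::MMATypes::f16, mlir::NVVM::MMAFrag::a, ctx);
  // info.first == vector<2xf16>, info.second == 8, so the load result type is
  // !llvm.struct<(vector<2xf16>, ..., vector<2xf16>)> with 8 fields.
  (void)info;
}
```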
 /// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
 static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
-  StringRef operandStr = type.getOperand();
-  assert(type.getElementType().isa<FloatType>());
-  Type baseType = type.getElementType().isF16()
-                      ? VectorType::get(2, type.getElementType())
-                      : type.getElementType();
-  auto getLLVMType = [&](int64_t numElements) {
-    return LLVM::LLVMStructType::getLiteral(
-        type.getContext(), SmallVector<Type, 8>(numElements, baseType));
-  };
-  if (operandStr.equals("AOp") || operandStr.equals("BOp"))
-    return getLLVMType(8);
-  if (type.getElementType().isF16())
-    return getLLVMType(4);
-  return getLLVMType(8);
+  NVVM::MMAFrag frag = convertOperand(type.getOperand());
+  NVVM::MMATypes eltType = getElementType(type);
+  std::pair<Type, unsigned> typeInfo =
+      inferMMAType(eltType, frag, type.getContext());
+  return LLVM::LLVMStructType::getLiteral(
+      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }

 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
@@ -118,41 +128,41 @@
     gpu::MMAMatrixType retType =
         subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> retTypeShape = retType.getShape();
+    int64_t m = 0;
+    int64_t n = 0;
+    int64_t k = 0;
+    NVVM::MMATypes eltype = getElementType(retType);
+    // NVVM intrinsics require the mxnxk dimensions to be given; infer the
+    // missing dimension based on the valid intrinsics available.
+    if (retType.getOperand().equals("AOp")) {
+      m = retTypeShape[0];
+      k = retTypeShape[1];
+      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
+    } else if (retType.getOperand().equals("BOp")) {
+      k = retTypeShape[0];
+      n = retTypeShape[1];
+      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
+    } else if (retType.getOperand().equals("COp")) {
+      m = retTypeShape[0];
+      n = retTypeShape[1];
+      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
+    }
+    NVVM::MMALayout layout = NVVM::MMALayout::row;
+    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
+    // Check that there is an existing instruction for the combination we need.
+    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

     Type resType = convertMMAToLLVMType(retType);
-    StringRef operandStr = retType.getOperand();

     // Create nvvm.mma_load op according to the operand types.
Value leadingDim32 = rewriter.create( loc, rewriter.getI32Type(), leadDimension); - SmallVector loadOpOperands({loadAddressCasted, leadingDim32}); - if (operandStr.equals("AOp")) { - if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { - rewriter.replaceOpWithNewOp(op, resType, - loadOpOperands); - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - } else if (operandStr.equals("BOp")) { - if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { - rewriter.replaceOpWithNewOp(op, resType, - loadOpOperands); - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - } else { - if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { - if (retType.getElementType().isF16()) { - rewriter.replaceOpWithNewOp( - op, resType, loadOpOperands); - } else if (retType.getElementType().isF32()) { - rewriter.replaceOpWithNewOp( - op, resType, loadOpOperands); - } - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - } + + rewriter.replaceOpWithNewOp( + op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype, + frag); + return success(); } }; @@ -212,13 +222,18 @@ storeAddress); SmallVector storeOpOperands; - storeOpOperands.push_back(storeAddressCasted); - // Get the shape of the MMAMatrix type being stored. The shape will // choose which intrinsic this op will be lowered to. gpu::MMAMatrixType srcType = subgroupMmaStoreMatrixOp.src().getType().cast(); ArrayRef srcTypeShape = srcType.getShape(); + NVVM::MMALayout layout = NVVM::MMALayout::row; + NVVM::MMATypes eltype = getElementType(srcType); + int64_t m = srcTypeShape[0]; + int64_t n = srcTypeShape[1]; + int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype); + if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0) + return rewriter.notifyMatchFailure(op, kInvalidCaseStr); auto matrixType = adaptor.src().getType().cast(); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { @@ -229,29 +244,11 @@ } Value leadingDim32 = rewriter.create( loc, rewriter.getI32Type(), leadDimension); - storeOpOperands.push_back(leadingDim32); - // Unpack the results from the source. - if (srcType.getElementType().isF16()) { - // Create nvvm.mma_store op. - if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) { - rewriter.create(loc, storeOpOperands); - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - rewriter.eraseOp(op); - return success(); - } - if (srcType.getElementType().isF32()) { - // Create nvvm.mma_store op. 
- if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) - rewriter.create(loc, storeOpOperands); - else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - rewriter.eraseOp(op); - return success(); - } - return failure(); + rewriter.create(loc, storeAddressCasted, m, n, k, layout, + eltype, storeOpOperands, leadingDim32); + + rewriter.eraseOp(op); + return success(); } }; @@ -292,40 +289,27 @@ gpu::MMAMatrixType aType = subgroupMmaComputeOp.opA().getType().cast(); ArrayRef aTypeShape = aType.getShape(); - gpu::MMAMatrixType bType = - subgroupMmaComputeOp.opB().getType().cast(); - ArrayRef bTypeShape = bType.getShape(); gpu::MMAMatrixType cType = subgroupMmaComputeOp.opC().getType().cast(); ArrayRef cTypeShape = cType.getShape(); + int64_t m = cTypeShape[0]; + int64_t n = cTypeShape[1]; + int64_t k = aTypeShape[1]; + NVVM::MMALayout layout = NVVM::MMALayout::row; + NVVM::MMATypes sourceType = getElementType(aType); + NVVM::MMATypes destType = getElementType(cType); + if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType, + destType) == 0) + return rewriter.notifyMatchFailure(op, kInvalidCaseStr); unpackOp(adaptor.opA()); unpackOp(adaptor.opB()); unpackOp(adaptor.opC()); - if (cType.getElementType().isF16()) { - if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && - bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { - // Create nvvm.wmma.mma op. - rewriter.replaceOpWithNewOp( - op, adaptor.opC().getType(), unpackedOps); - - return success(); - } - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - if (cType.getElementType().isF32()) { - if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && - bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { - // Create nvvm.wmma.mma op. 
- rewriter.replaceOpWithNewOp( - op, adaptor.opC().getType(), unpackedOps); - - return success(); - } - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } - return failure(); + rewriter.replaceOpWithNewOp( + op, adaptor.opC().getType(), m, n, k, layout, layout, sourceType, + destType, unpackedOps); + return success(); } }; Index: mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -31,6 +31,7 @@ using namespace NVVM; #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" +#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" //===----------------------------------------------------------------------===// // Printing/parsing for NVVM ops @@ -132,201 +133,100 @@ return op.emitOpError("unimplemented mma.sync variant"); } -template -static LogicalResult verifyWMMALoadOp(T op, StringRef operand) { - MLIRContext *context = op.getContext(); - auto i32Ty = IntegerType::get(context, 32); - auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); - auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); - auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); - auto f16Ty = FloatType::getF16(context); - auto f32Ty = FloatType::getF32(context); - auto f16x2Ty = VectorType::get(2, f16Ty); - auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( - context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); - auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral( - context, - {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); - auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( - context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - - SmallVector operandTypes(op.getOperandTypes().begin(), - op.getOperandTypes().end()); - if (operandTypes != SmallVector{i32Ptr1Ty, i32Ty} && - operandTypes != SmallVector{i32Ptr3Ty, i32Ty} && - operandTypes != SmallVector{i32Ptr0Ty, i32Ty}) { - return op.emitOpError("expected operands to be a source pointer in memory " - "space 0, 1, 3 followed by ldm of the source"); +std::pair +inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) { + unsigned numberElements = 0; + Type elementType; + OpBuilder builder(context); + Type f16x2 = VectorType::get(2, builder.getF16Type()); + if (type == NVVM::MMATypes::f16) { + elementType = f16x2; + if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) + numberElements = 8; + else + numberElements = 4; + } else if (type == NVVM::MMATypes::f32) { + elementType = builder.getF32Type(); + numberElements = 8; + } else if (type == NVVM::MMATypes::tf32) { + elementType = builder.getI32Type(); + numberElements = 4; } - - if (operand.equals("AOp") || operand.equals("BOp")) { - if (op.getType() != f16x2x8StructTy) { - return op.emitOpError("expected result type of loadAOp and loadBOp to be " - "a struct of 8 s"); - } - } else if (operand.equals("COp")) { - if (op.getType() != f16x2x4StructTy && op.getType() != f32x8StructTy) { - return op.emitOpError("expected result type of loadCOp to be a struct of " - "4 s or 8 f32s"); - } - } - - return success(); -} - -static LogicalResult verify(WMMALoadAM16N16K16Op op) { - return verifyWMMALoadOp(op, "AOp"); -} - -static LogicalResult verify(WMMALoadBM16N16K16Op op) { - return verifyWMMALoadOp(op, "BOp"); -} - -static LogicalResult verify(WMMALoadCF16M16N16K16Op op) { - return verifyWMMALoadOp(op, "COp"); -} - -static LogicalResult verify(WMMALoadCF32M16N16K16Op op) { - return 
verifyWMMALoadOp(op, "COp"); -} - -template -static bool verifyWMMAStoreOp(T op, SmallVector &containedElems) { - SmallVector operandTypes(op.getOperandTypes().begin(), - op.getOperandTypes().end()); - if (operandTypes == containedElems) - return true; - - return false; -} - -static LogicalResult verify(WMMAStoreF16M16N16K16Op op) { - MLIRContext *context = op.getContext(); - auto i32Ty = IntegerType::get(context, 32); - auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); - auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); - auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); - auto f16Ty = FloatType::getF16(context); - auto f16x2Ty = VectorType::get(2, f16Ty); - SmallVector type1{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; - SmallVector type0{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; - SmallVector type3{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; - if (verifyWMMAStoreOp(op, type1) || verifyWMMAStoreOp(op, type0) || - verifyWMMAStoreOp(op, type3)) - return success(); - - return op.emitOpError("expected operands to be a source pointer in memory" - "space 0, 1, 3 followed by ldm of the source"); -} - -static LogicalResult verify(WMMAStoreF32M16N16K16Op op) { - MLIRContext *context = op.getContext(); - auto i32Ty = IntegerType::get(context, 32); - auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); - auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); - auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); - auto f32Ty = FloatType::getF32(context); - - SmallVector type1{i32Ptr1Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; - SmallVector type0{i32Ptr0Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; - SmallVector type3{i32Ptr3Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; - if (verifyWMMAStoreOp(op, type0) || verifyWMMAStoreOp(op, type1) || - verifyWMMAStoreOp(op, type3)) - return success(); - - return op.emitOpError("expected operands to be a source pointer in memory" - "space 0, 1, 3 followed by ldm of the source"); + assert(numberElements != 0 && elementType != nullptr); + return std::make_pair(elementType, numberElements); } -static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) { - MLIRContext *context = op.getContext(); - auto f16Ty = FloatType::getF16(context); - auto f16x2Ty = VectorType::get(2, f16Ty); - auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( - context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); - - SmallVector operandTypes(op.getOperandTypes().begin(), - op.getOperandTypes().end()); - if (operandTypes != SmallVector(20, f16x2Ty)) - return op.emitOpError("expected 20 s as operands"); - - if (op.getResult().getType() != f16x2x4StructTy) - return op.emitOpError("expected result type to be a struct of 4 s"); - +static LogicalResult verify(NVVM::WMMALoadOp op) { + unsigned addressSpace = + op.ptr().getType().cast().getAddressSpace(); + if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) + return op.emitOpError("expected source pointer in memory " + "space 0, 1, 3"); + + if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), + op.eltype(), op.frag()) == 0) + return op.emitOpError() << "invalid attribute combination"; + std::pair typeInfo = + inferMMAType(op.eltype(), op.frag(), op.getContext()); + Type dstType = LLVM::LLVMStructType::getLiteral( + op.getContext(), SmallVector(typeInfo.second, typeInfo.first)); + if (op.getType() != dstType) + return op.emitOpError("expected destination type is a structure of ") + << 
typeInfo.second << " elements of type " << typeInfo.first; return success(); } -static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser, - OperationState &result) { - SmallVector operands; - ::llvm::SMLoc operandsLoc; - Type operandType; - Type resType; - - operandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(operands) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.parseType(operandType) || parser.parseArrow()) - return failure(); - - unsigned numOperands = operands.size(); - SmallVector operandTypes(numOperands, operandType); - if (parser.parseType(resType)) - return failure(); - result.addTypes(resType); - if (parser.resolveOperands(operands, operandTypes, operandsLoc, - result.operands)) - return failure(); +static LogicalResult verify(NVVM::WMMAStoreOp op) { + unsigned addressSpace = + op.ptr().getType().cast().getAddressSpace(); + if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) + return op.emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3"); + + if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), + op.eltype()) == 0) + return op.emitOpError() << "invalid attribute combination"; + std::pair typeInfo = + inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext()); + if (op.args().size() != typeInfo.second) + return op.emitOpError() + << "expected " << typeInfo.second << " data operands"; + if (llvm::any_of(op.args(), [&typeInfo](Value operands) { + return operands.getType() != typeInfo.first; + })) + return op.emitOpError() + << "expected data operands of type " << typeInfo.first; return success(); } -static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p, - WMMAMmaF16F16M16N16K16Op &op) { - p << ' '; - p << op.args(); - p.printOptionalAttrDict(op->getAttrs(), {}); - p << " : "; - p << op->getOperand(0).getType(); - p << ' ' << "->"; - p << ' '; - p << ::llvm::ArrayRef<::mlir::Type>(op.res().getType()); -} - -static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) { - unsigned numABOperands = 16; - unsigned numCOperands = 8; - MLIRContext *context = op.getContext(); - auto f16Ty = FloatType::getF16(context); - auto f32Ty = FloatType::getF32(context); - auto f16x2Ty = VectorType::get(2, f16Ty); - auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( - context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - - SmallVector abOpTypes; - SmallVector bOpTypes; - SmallVector cOpTypes; - - for (auto operand : op->getOperands().take_front(numABOperands)) { - abOpTypes.push_back(operand.getType()); - } - - for (auto operand : - op->getOperands().drop_front(numABOperands).take_front(numCOperands)) { - cOpTypes.push_back(operand.getType()); +static LogicalResult verify(NVVM::WMMAMmaOp op) { + if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(), + op.layoutB(), op.eltypeA(), + op.eltypeB()) == 0) + return op.emitOpError() << "invalid attribute combination"; + std::pair typeInfoA = + inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext()); + std::pair typeInfoB = + inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext()); + std::pair typeInfoC = + inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext()); + SmallVector arguments; + arguments.append(typeInfoA.second, typeInfoA.first); + arguments.append(typeInfoB.second, typeInfoB.first); + arguments.append(typeInfoC.second, typeInfoC.first); + unsigned numArgs = arguments.size(); + if (op.args().size() != numArgs) + return op.emitOpError() << 
"expected " << numArgs << " arguments"; + for (unsigned i = 0; i < numArgs; i++) { + if (op.args()[i].getType() != arguments[i]) + return op.emitOpError() + << "expected argument " << i << " to be of type " << arguments[i]; } - - if (abOpTypes != SmallVector(16, f16x2Ty)) - return op.emitOpError("expected 16 s for `a` and `b` operand"); - - if (cOpTypes != SmallVector(8, f32Ty)) - return op.emitOpError("expected 8 f32s for `c` operand"); - - if (op.getResult().getType() != f32x8StructTy) - return op.emitOpError("expected result type to be a struct of 8 f32s"); - + Type dstType = LLVM::LLVMStructType::getLiteral( + op.getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); + if (op.getType() != dstType) + return op.emitOpError("expected destination type is a structure of ") + << typeInfoC.second << " elements of type " << typeInfoC.first; return success(); } Index: mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -22,7 +22,6 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; -using mlir::LLVM::detail::createNvvmIntrinsicCall; static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, bool withPredicate) { Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -442,29 +442,6 @@ return builder.CreateCall(fn, args); } -llvm::Value * -mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder, - llvm::Intrinsic::ID intrinsic, - ArrayRef args) { - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn; - if (llvm::Intrinsic::isOverloaded(intrinsic)) { - if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 && - intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) { - // NVVM load and store instrinsic names are overloaded on the - // source/destination pointer type. Pointer is the first argument in the - // corresponding NVVM Op. - fn = llvm::Intrinsic::getDeclaration(module, intrinsic, - {args[0]->getType()}); - } else { - fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {}); - } - } else { - fn = llvm::Intrinsic::getDeclaration(module, intrinsic); - } - return builder.CreateCall(fn, args); -} - /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. 
LogicalResult Index: mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -22,7 +22,8 @@ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 - // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]] + // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 @@ -36,7 +37,8 @@ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 - // CHECK32: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]] + // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> } @@ -70,7 +72,8 @@ // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 - // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32 + // CHECK: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]] + // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> // CHECK: llvm.return // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 @@ -88,7 +91,8 @@ // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, 
vector<2xf16>, vector<2xf16>)> // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 - // CHECK32: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32 + // CHECK32: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]] + // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> // CHECK32: llvm.return return } @@ -122,7 +126,9 @@ // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> - // CHECK: %[[RES:.*]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + // CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] + // CHECK-SAME: {eltypeA = "f16", eltypeB = "f16", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32} : ( + // CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> return %D : !gpu.mma_matrix<16x16xf16, "COp"> } @@ -133,13 +139,13 @@ gpu.module @test_module { // CHECK-LABEL: func @gpu_wmma_mma_loop_op -// CHECK: %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) // CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2 // CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3 // CHECK: ^bb2: // pred: ^bb1 -// CHECK: %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> -// CHECK: %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, 
vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> @@ -160,14 +166,14 @@ // CHECK: %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> -// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.mma %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] {eltypeA = "f16", eltypeB = "f16", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32} : (vector<2xf16>, {{.*}} -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) // CHECK: ^bb3: // pred: ^bb1 // CHECK: %[[E0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> -// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]], %{{.*}} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32 +// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) { %c0 = arith.constant 0 : index Index: mlir/test/Dialect/LLVMIR/invalid.mlir =================================================================== --- 
mlir/test/Dialect/LLVMIR/invalid.mlir
+++ mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1010,44 +1010,30 @@
 // -----

 llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
-  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 5>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-
-  llvm.return
-}
-
-// -----
-
-llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32>, %arg1: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
-  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-
-  llvm.return
-}
-
-// -----
-
-llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32>, %arg1: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
-  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-
+  // expected-error@+1 {{'nvvm.wmma.load' op expected source pointer in memory space 0, 1, 3}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32, 5>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   llvm.return
 }

 // -----

 llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32>, %arg1: i32) {
-  // expected-error@+1 {{nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
-  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-
+  // expected-error@+1 {{'nvvm.wmma.load' op expected destination type is a structure of 8 elements of type 'vector<2xf16>'}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   llvm.return
 }

 // -----

 llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr<i32>, %arg1: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
-  %0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  // expected-error@+1 {{'nvvm.wmma.load' op expected destination type is a structure of 8 elements of type 'vector<2xf16>'}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

   llvm.return
 }
@@ -1055,29 +1041,23 @@
 // -----

 llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr<i32>, %arg1: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 <halfx2>s or 8 f32s}}
-  %0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // expected-error@+1 {{'nvvm.wmma.load' op expected destination type is a structure of 4 elements of type 'vector<2xf16>'}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>

   llvm.return
 }

 // -----

-llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
+llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32,
                                          %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
-                                         %arg4: vector<2 xf16>, %arg5: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
-  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
-                                           %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
-                                           %arg4: vector<2 xf16>, %arg5: i32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
-  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
+                                         %arg4: vector<2 x f16>, %arg5: vector<2 xf16>) {
+  // expected-error@+1 {{'nvvm.wmma.store' op expected operands to be a source pointer in memory space 0, 1, 3}}
+  nvvm.wmma.store %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
   llvm.return
 }

@@ -1093,8 +1073,12 @@
                        %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
                        %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
                        %arg18: vector<2 x f16>) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 <halfx2>s as operands}}
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+  // expected-error@+1 {{'nvvm.wmma.mma' op expected 20 arguments}}
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18
+    {eltypeA = "f16", eltypeB = "f16", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>,
+       vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)
+    -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
   llvm.return
 }

@@ -1110,9 +1094,12 @@
                        %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
                        %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
                        %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
-  // expected-error@+1 {{expected result type to be a struct of 4 <halfx2>s}}
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
-  llvm.return
+  // expected-error@+1 {{'nvvm.wmma.mma' op expected destination type is a structure of 4 elements of type 'vector<2xf16>'}}
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19
+    {eltypeA = "f16", eltypeB = "f16", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>,
+       vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)
+    -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
   llvm.return
 }

 // -----

@@ -1127,8 +1114,10 @@
                        %arg14: vector<2 x f16>, %arg15: f32, %arg16: f32,
                        %arg17: f32, %arg18: f32, %arg19: f32, %arg20: f32,
                        %arg21: f32, %arg22: f32, %arg23: f32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 <halfx2>s for `a` and `b` operand}}
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{'nvvm.wmma.mma' op expected argument 15 to be of type 'vector<2xf16>'}}
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23
+    {eltypeA = "f16", eltypeB = "f32", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return
 }

@@ -1144,8 +1133,10 @@
                        %arg14: vector<2 x f16>, %arg15: vector<2xf16>, %arg16: f32,
                        %arg17: f32, %arg18: f32, %arg19: f32, %arg20: f32,
                        %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}}
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{'nvvm.wmma.mma' op expected argument 23 to be of type 'f32'}}
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23
+    {eltypeA = "f16", eltypeB = "f32", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return
 }

@@ -1161,8 +1152,10 @@
                        %arg14: vector<2 x f16>, %arg15: vector<2xf16>, %arg16: f32,
                        %arg17: f32, %arg18: f32, %arg19: f32, %arg20: f32,
                        %arg21: f32, %arg22: f32, %arg23: f32) {
-  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}}
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
+  // expected-error@+1 {{'nvvm.wmma.mma' op expected destination type is a structure of 8 elements of type 'f32'}}
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23
+    {eltypeA = "f16", eltypeB = "f32", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
   llvm.return
 }
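For contrast with the negative tests above, a form the new verifier accepts; this is a sketch assembled from the attribute and type conventions exercised in these tests (the function name @wmma_load_a_valid and the addrspace(3) pointer are illustrative, not part of the patch):

  llvm.func @wmma_load_a_valid(%ptr: !llvm.ptr<i32, 3>, %stride: i32) {
    // An m16n16k16 f16 `a` fragment unpacks to eight vector<2xf16> values per thread.
    %0 = nvvm.wmma.load %ptr, %stride
      {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
      : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    llvm.return
  }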
Index: mlir/test/Dialect/LLVMIR/nvvm.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -69,6 +69,27 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }

+func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: nvvm.wmma.load {{.*}} {eltype = "tf32", frag = "a", k = 8 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "tf32", frag = "a", k = 8 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
+                    %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32,
+                    %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32)
+    -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
+  // CHECK: nvvm.wmma.mma {{.*}} {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+  %r = nvvm.wmma.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15
+    {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32)
+    -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+}
+
+
 // -----

 // expected-error@below {{attribute attached to unexpected op}}
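As the round-trip tests above show, tf32 fragments are carried as plain i32 values (four per thread for the m16n16k8 `a` fragment) while the accumulator stays f32. For f16, the per-thread element count also depends on the fragment role; below is a sketch of a well-formed `c`-fragment load, inferred from the verifier message ("structure of 4 elements") exercised in invalid.mlir rather than copied from the patch (the function name is hypothetical):

  llvm.func @wmma_load_c_f16(%ptr: !llvm.ptr<i32>, %stride: i32) {
    // An m16n16k16 f16 `c` fragment is four vector<2xf16> values per thread
    // (an f32 accumulator would instead be eight scalar f32 values).
    %0 = nvvm.wmma.load %ptr, %stride
      {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
      : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    llvm.return
  }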
Index: mlir/test/Target/LLVMIR/nvvmir.mlir
===================================================================
--- mlir/test/Target/LLVMIR/nvvmir.mlir
+++ mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -77,18 +77,22 @@
 // in the LLVM NVPTX backend.
 llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
   // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
-  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   llvm.return
 }

 // The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
-llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
+llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32,
                              %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
-                             %arg4: vector<2 xf16>, %arg5: i32) {
+                             %arg4: vector<2 xf16>, %arg5: vector<2 x f16>) {
   // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
-  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
+  nvvm.wmma.store %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
   llvm.return
 }

@@ -105,8 +109,32 @@
                          %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
                          %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
   // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
-  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+  %0 = nvvm.wmma.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19
+    {eltypeA = "f16", eltypeB = "f16", k = 16 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>,
+       vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>,
+       vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>,
+       vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)
+    -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+  llvm.return
+}
+
+llvm.func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) {
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = "tf32", frag = "a", k = 8 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
+    : (!llvm.ptr<i32>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
+                         %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32,
+                         %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32) {
+  // CHECK: { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %r = nvvm.wmma.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15
+    {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
+    : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32)
+    -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return
 }
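The CHECK lines above pin down the translation contract: the op's attributes are concatenated into the NVVM intrinsic name, and the pointer's address space selects the trailing pointer suffix. A comment-only summary of the three mappings these tests exercise (the exact name assembly lives in the tablegen helpers in NVVMOps.td; the layout below is an editorial restatement of the CHECK lines, not code from the patch):

  // nvvm.wmma.load  {m = 16, n = 16, k = 16, frag = "a", layout = "row", eltype = "f16"}
  //   -> @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32   (pointer in addrspace(3))
  // nvvm.wmma.store {m = 16, n = 16, k = 16, layout = "row", eltype = "f16"}
  //   -> @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32
  // nvvm.wmma.mma   {m = 16, n = 16, k = 8, layoutA = "row", layoutB = "row", eltypeA = "tf32", eltypeB = "f32"}
  //   -> @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32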
eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32} + : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32) + -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return } Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel =================================================================== --- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3295,6 +3295,14 @@ ], "include/mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc", ), + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/LLVMIR/NVVMOps.td",