Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -37,6 +37,69 @@ // MISC // +// Helper class for construction of n-element list [t,t,...,t] +class RepLLVMType { + list ret = !if(N, !listconcat(RepLLVMType.ret, [T]), []); +} + +// Helper class that represents a 'fragment' of an NVPTX *MMA instruction. +// Geom: mnk. E.g. m8n32k16 +// Frag: [abcd] +// PtxEltType: PTX type for the element. +class WMMA_REGS { + string geom = Geom; + string frag = Frag; + string ptx_elt_type = PtxEltType; + string ft = frag#":"#ptx_elt_type; + list regs = !cond( + // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 + // All currently supported geometries use the same fragment format, + // so we only need to consider {fragment, type}. + !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret, + !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret, + !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, + !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, + !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret, + !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret); +} + +class WMMA_NAME_LDST { + string intr = "llvm.nvvm.wmma." + # Frag.geom + # "." # Op + # "." # Frag.frag + # "." # Layout + # !if(WithStride, ".stride", "") + # "." # Frag.ptx_elt_type + ; + // TODO(tra): record name should ideally use the same field order as the intrinsic. + // E.g. string record = !subst("llvm", "int", + // !subst(".", "_", llvm)); + string record = "int_nvvm_wmma_" + # Frag.geom + # "_" # Op + # "_" # Frag.frag + # "_" # Frag.ptx_elt_type + # "_" # Layout + # !if(WithStride, "_stride", ""); +} + +class WMMA_NAME_MMA { + string llvm = "llvm.nvvm.wmma." + # C.geom + # ".mma" + # "." # ALayout + # "." # BLayout + # "." # D.ptx_elt_type // Intrinsic encodes 'd' first. + # "." # C.ptx_elt_type + # !if(Satfinite, ".satfinite", ""); + + string record = !subst(".", "_", + !subst("llvm.", "int_", llvm)); +} + let TargetPrefix = "nvvm" in { def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">, Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], @@ -3889,166 +3952,69 @@ // // WMMA instructions // - // WMMA.LOAD -class NVVM_WMMA_LD_GALSTS - : Intrinsic + : Intrinsic, NoCapture<0>], - "llvm.nvvm.wmma." - # Geometry - # ".load" - # "." # Abc - # "." # Layout - # !if(WithStride, ".stride", "") - # "." # Type>; - -multiclass NVVM_WMMA_LD_GALT { - def _stride: NVVM_WMMA_LD_GALSTS; - def NAME : NVVM_WMMA_LD_GALSTS; -} - -multiclass NVVM_WMMA_LD_GAT { - defm _row: NVVM_WMMA_LD_GALT; - defm _col: NVVM_WMMA_LD_GALT; -} - -multiclass NVVM_WMMA_LD_G { - defm _a_f16: NVVM_WMMA_LD_GAT; - defm _b_f16: NVVM_WMMA_LD_GAT; - defm _c_f16: NVVM_WMMA_LD_GAT; - defm _c_f32: NVVM_WMMA_LD_GAT; -} - -multiclass NVVM_WMMA_LD { - defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">; - defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">; - defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">; -} - -defm int_nvvm_wmma: NVVM_WMMA_LD; + WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>; // WMMA.STORE.D -class NVVM_WMMA_STD_GLSTSEmpty=[]> +class NVVM_WMMA_ST : Intrinsic<[], !listconcat( [llvm_anyptr_ty], - !if(!eq(Type,"f16"), - [regty, regty, regty, regty], - [regty, regty, regty, regty, - regty, regty, regty, regty]), - !if(WithStride, [llvm_i32_ty], Empty)), + Frag.regs, + !if(WithStride, [llvm_i32_ty], [])), [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>], - "llvm.nvvm.wmma." - # Geometry - # ".store.d" - # "." # Layout - # !if(WithStride, ".stride", "") - # "." # Type>; + WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>; -multiclass NVVM_WMMA_STD_GLT { - def _stride: NVVM_WMMA_STD_GLSTS; - def NAME: NVVM_WMMA_STD_GLSTS; +// Create all load/store variants +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach frag = [WMMA_REGS, + WMMA_REGS, + WMMA_REGS, + WMMA_REGS] in { + def WMMA_NAME_LDST<"load", frag, layout, stride>.record + : NVVM_WMMA_LD; + } + foreach frag = [WMMA_REGS, + WMMA_REGS] in { + def WMMA_NAME_LDST<"store", frag, layout, stride>.record + : NVVM_WMMA_ST; + } + } + } } -multiclass NVVM_WMMA_STD_GT { - defm _row: NVVM_WMMA_STD_GLT; - defm _col: NVVM_WMMA_STD_GLT; -} -multiclass NVVM_WMMA_STD_G { - defm _d_f16: NVVM_WMMA_STD_GT; - defm _d_f32: NVVM_WMMA_STD_GT; -} - -multiclass NVVM_WMMA_STD { - defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">; - defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">; - defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">; -} - -defm int_nvvm_wmma: NVVM_WMMA_STD; - // WMMA.MMA -class NVVM_WMMA_MMA_GABDCS - : Intrinsic + : Intrinsic.regs, + WMMA_REGS.regs, + C.regs), [IntrNoMem], - "llvm.nvvm.wmma." - # Geometry - # ".mma" - # "." # ALayout - # "." # BLayout - # "." # DType - # "." # CType - # Satfinite> { -} - -multiclass NVVM_WMMA_MMA_GABDC { - def NAME : NVVM_WMMA_MMA_GABDCS; - def _satfinite: NVVM_WMMA_MMA_GABDCS; -} - -multiclass NVVM_WMMA_MMA_GABD { - defm _f16: NVVM_WMMA_MMA_GABDC; - defm _f32: NVVM_WMMA_MMA_GABDC; -} - -multiclass NVVM_WMMA_MMA_GAB { - defm _f16: NVVM_WMMA_MMA_GABD; - defm _f32: NVVM_WMMA_MMA_GABD; -} - -multiclass NVVM_WMMA_MMA_GA { - defm _col: NVVM_WMMA_MMA_GAB; - defm _row: NVVM_WMMA_MMA_GAB; -} + WMMA_NAME_MMA.llvm>; -multiclass NVVM_WMMA_MMA_G { - defm _col: NVVM_WMMA_MMA_GA; - defm _row: NVVM_WMMA_MMA_GA; +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGS, + WMMA_REGS] in { + foreach frag_d = [WMMA_REGS, + WMMA_REGS] in { + foreach satf = [0, 1] in { + def WMMA_NAME_MMA.record + : NVVM_WMMA_MMA; + } + } + } + } + } } -multiclass NVVM_WMMA_MMA { - defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">; - defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">; - defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">; -} - -defm int_nvvm_wmma : NVVM_WMMA_MMA; - } // let TargetPrefix = "nvvm" Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -26,7 +26,17 @@ return (d==1.0); }]>; - +def AS_match { + code generic = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); + }]; + code shared = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); + }]; + code global = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); + }]; +} //----------------------------------- // Synchronization and shuffle functions @@ -1006,17 +1016,11 @@ //----------------------------------- class ATOMIC_GLOBAL_CHK - : PatFrag; + : PatFrag; class ATOMIC_SHARED_CHK - : PatFrag; + : PatFrag; class ATOMIC_GENERIC_CHK - : PatFrag; + : PatFrag; multiclass F_ATOMIC_2_imp; -// -// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] -// - class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>; +// Generates list of n sequential register names. +class RegSeq { + list ret = !if(n, !listconcat(RegSeq.ret, + [prefix # !add(n, -1)]), + []); +} -class WMMA_LOAD_GALSTOS - : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic - // for this function. - PatFrag IntrMatcher = !cast("INT_WMMA_" - # Geometry # "_load_" - # !subst("c", "c_" # Type, Abc) - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); - dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); - dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47)); +// Helper class that represents a 'fragment' of an NVPTX *MMA instruction. +// In addition to target-independent fields provided by WMMA_REGS, it adds +// the fields commonly used to implement specific PTX instruction -- register +// types and names, constraints, parts of assembly, etc. +class WMMA_REGINFO + : WMMA_REGS { + // NVPTX register types used to carry fragment data. + NVPTXRegClass regclass = !cond( + !eq(PtxEltType, "f16") : Float16x2Regs, + !eq(PtxEltType, "f32") : Float32Regs); - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con((ins SrcOp:$src), StrideArg); + // Instruction input/output arguments for the fragment. + list ptx_regs = !foreach(tmp, regs, regclass); + // List of register names for the fragment -- ["ra0", "ra1",...] + list reg_names = RegSeq.ret; + // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction. + string regstring = "{{$" # !head(reg_names) + # !foldl("", !tail(reg_names), a, b, + !strconcat(a, ", $", b)) + # "}}"; + + // Predicates for particular fragment variant. Technically those are + // per-instruction predicates, but currently all fragments that can be used in + // a given instruction are subject to the same constraints, so an instruction + // can use predicates from any of its fragments. If/when this is no + // longer the case, we can concat all per-fragment predicates to enforce that + // all fragments of the instruction are viable. + list Predicates = !cond( + // fp16 -> fp16/fp32 @ m16n16k16 + !and(!eq(Geom, "m16n16k16"), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60], + + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 + !and(!or(!eq(Geom, "m8n32k16"), + !eq(Geom, "m32n8k16")), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]); + + // template DAGs for instruction inputs/output. + dag Outs = !dag(outs, ptx_regs, reg_names); + dag Ins = !dag(ins, ptx_regs, reg_names); +} + +class BuildPattern { // Build a dag pattern that matches the intrinsic call. // We want a dag that looks like this: // (set , (intrinsic )) where input and @@ -7430,277 +7458,127 @@ !subst(ins, IntrMatcher, tmp))))); // Finally, consatenate both parts together. !con() requires both dags to have // the same operator, so we wrap PatArgs in a (set ...) dag. - let Pattern = [!con(PatOuts, (set PatArgs))]; - let OutOperandList = Outs; + dag ret = !con(PatOuts, (set PatArgs)); +} + +// +// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// + +class WMMA_LOAD_INTR_HELPER + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast(WMMA_NAME_LDST<"load", Frag, Layout, + WithStride>.record); + let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); + let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); +} + +class WMMA_LOAD + : EmptyNVPTXInst, + Requires { + // Pattern that matches the intrinsic for this instruction variant. + PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER; + dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins))); + + let Pattern = [BuildPattern.ret]; + let OutOperandList = Frag.Outs; let InOperandList = Ins; let AsmString = "wmma.load." - # Abc + # Frag.frag # ".sync" # "." # Layout - # "." # Geometry + # "." # Frag.geom # Space - # "." # Type # " \t" - # !if(!eq(Abc#Type, "cf16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + # "." # Frag.ptx_elt_type # " \t" + # Frag.regstring # ", [$src]" # !if(WithStride, ", $ldm", "") # ";"; } -class WMMA_LOAD_INTR_HELPER - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast("int_nvvm_wmma" - # "_" # Geometry # "_load_" - # Abc # "_" # Type # "_" # Layout - # !if(WithStride,"_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - - let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_LOAD_GALSTS { - def _avar: WMMA_LOAD_GALSTOS; - def _areg: WMMA_LOAD_GALSTOS; - def _areg64: WMMA_LOAD_GALSTOS; - def _ari: WMMA_LOAD_GALSTOS; - def _ari64: WMMA_LOAD_GALSTOS; -} - -multiclass WMMA_LOAD_GALSTSh { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. - def _Intr: WMMA_LOAD_INTR_HELPER; - defm NAME: WMMA_LOAD_GALSTS; -} - -multiclass WMMA_LOAD_GALST { - defm _stride: WMMA_LOAD_GALSTSh; - defm NAME: WMMA_LOAD_GALSTSh; -} - -multiclass WMMA_LOAD_GALT { - defm _global: WMMA_LOAD_GALST; - defm _shared: WMMA_LOAD_GALST; - defm NAME: WMMA_LOAD_GALST; -} - -multiclass WMMA_LOAD_GAT { - defm _row: WMMA_LOAD_GALT; - defm _col: WMMA_LOAD_GALT; -} - -multiclass WMMA_LOAD_G { - defm _load_a: WMMA_LOAD_GAT; - defm _load_b: WMMA_LOAD_GAT; - defm _load_c_f16: WMMA_LOAD_GAT; - defm _load_c_f32: WMMA_LOAD_GAT; -} - -defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">; - // // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // -class WMMA_STORE_D_GLSTSO - : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - PatFrag IntrMatcher = !cast("INT_WMMA" - # "_" # Geometry # "_store_d" - # "_" # Type - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, - regclass:$r2, regclass:$r3); - dag InsR47 = (ins regclass:$r4, regclass:$r5, - regclass:$r6, regclass:$r7); - dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47)); - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con(InsR, StrideArg); +class WMMA_STORE_INTR_HELPER + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast(WMMA_NAME_LDST<"store", Frag, Layout, + WithStride>.record); + let Operands = !con((ops node:$dst), + !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names), + !if(WithStride, (ops node:$ldm), (ops))); + let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); +} - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. - dag PatArgs = !foreach(tmp, Ins, - !subst(imem, ADDRvar, - !subst(MEMri64, ADDRri64, - !subst(MEMri, ADDRri, - !subst(ins, IntrMatcher, tmp))))); - let Pattern = [PatArgs]; +class WMMA_STORE + : EmptyNVPTXInst, + Requires { + PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER; + dag Ins = !con((ins DstOp:$src), + Frag.Ins, + !if(WithStride, (ins Int32Regs:$ldm), (ins))); + let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret]; let OutOperandList = (outs); let InOperandList = Ins; let AsmString = "wmma.store.d.sync." # Layout - # "." # Geometry + # "." # Frag.geom # Space - # "." # Type + # "." # Frag.ptx_elt_type # " \t[$src]," - # !if(!eq(Type,"f16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + # Frag.regstring # !if(WithStride, ", $ldm", "") # ";"; - -} - -class WMMA_STORE_INTR_HELPER - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast("int_nvvm_wmma_" - # Geometry - # "_store_d" - # "_" # Type - # "_" # Layout - # !if(WithStride, "_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - - dag Args = !if(!eq(Type,"f16"), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3, - node:$r4, node:$r5, node:$r6, node:$r7)); - dag StrideArg = !if(WithStride, (ops node:$ldm), (ops)); - let Operands = !con(Args, StrideArg); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_STORE_D_GLSTS { - def _avar: WMMA_STORE_D_GLSTSO; - def _areg: WMMA_STORE_D_GLSTSO; - def _areg64: WMMA_STORE_D_GLSTSO; - def _ari: WMMA_STORE_D_GLSTSO; - def _ari64: WMMA_STORE_D_GLSTSO; -} - -multiclass WMMA_STORE_D_GLSTSh { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. - def _Intr: WMMA_STORE_INTR_HELPER; - defm NAME: WMMA_STORE_D_GLSTS; -} - -multiclass WMMA_STORE_D_GLST { - defm _stride: WMMA_STORE_D_GLSTSh; - defm NAME: WMMA_STORE_D_GLSTSh; -} - -multiclass WMMA_STORE_D_GLT { - defm _global: WMMA_STORE_D_GLST; - defm _shared: WMMA_STORE_D_GLST; - defm NAME: WMMA_STORE_D_GLST; -} - -multiclass WMMA_STORE_D_GT { - defm _row: WMMA_STORE_D_GLT; - defm _col: WMMA_STORE_D_GLT; -} - -multiclass WMMA_STORE_D_G { - defm _store_d_f16: WMMA_STORE_D_GT; - defm _store_d_f32: WMMA_STORE_D_GT; } -defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">; +// Create all load/store variants +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach space = [".global", ".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = [WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO] in { + def : WMMA_LOAD; + } + foreach frag = [WMMA_REGINFO, + WMMA_REGINFO] in { + def : WMMA_STORE; + } + } // addr + } // space + } // stride + } // layout +} // geom // WMMA.MMA -class WMMA_MMA_GABDCS +class WMMA_MMA : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - Intrinsic Intr = !cast("int_nvvm_wmma_" - # Geometry - # "_mma" - # "_" # ALayout - # "_" # BLayout - # "_" # DType - # "_" # CType - # !subst(".", "_", Satfinite)); - dag Outs = !if(!eq(DType,"f16"), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3, - d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)); - dag InsExtraCArgs = !if(!eq(CType,"f16"), - (ins), - (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)); - dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, - ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, - ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, - ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7, - c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3), - InsExtraCArgs); + Requires { + //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero; + Intrinsic Intr = !cast(WMMA_NAME_MMA.record); + dag Outs = FragD.Outs; + dag Ins = !con(FragA.Ins, + FragB.Ins, + FragC.Ins); - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. + // Construct the pattern to match corresponding intrinsic call. + // mma does not load/store anything, so we don't need complex operand matching here. dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp)); let Pattern = [!con(PatOuts, (set PatArgs))]; @@ -7709,54 +7587,30 @@ let AsmString = "wmma.mma.sync." # ALayout # "." # BLayout - # "." # Geometry - # "." # DType - # "." # CType - # Satfinite # "\n\t\t" - # !if(!eq(DType,"f16"), - "{{$d0, $d1, $d2, $d3}}, \n\t\t", - "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t") - # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t" - # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t" - # !if(!eq(CType,"f16"), - "{{$c0, $c1, $c2, $c3}};", - "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"); + # "." # FragA.geom + # "." # FragD.ptx_elt_type + # "." # FragC.ptx_elt_type + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; } -multiclass WMMA_MMA_GABDC { - def _satfinite: WMMA_MMA_GABDCS; - def NAME: WMMA_MMA_GABDCS; -} - -multiclass WMMA_MMA_GABD { - defm _f16: WMMA_MMA_GABDC; - defm _f32: WMMA_MMA_GABDC; -} - -multiclass WMMA_MMA_GAB { - defm _f16: WMMA_MMA_GABD; - defm _f32: WMMA_MMA_GABD; -} - -multiclass WMMA_MMA_GA { - defm _col: WMMA_MMA_GAB; - defm _row: WMMA_MMA_GAB; -} - -multiclass WMMA_MMA_G { - defm _col: WMMA_MMA_GA; - defm _row: WMMA_MMA_GA; -} - -defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">; -defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">; -defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">; +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGINFO, + WMMA_REGINFO] in { + foreach frag_d = [WMMA_REGINFO, + WMMA_REGINFO] in { + foreach satf = [0, 1] in { + def : WMMA_MMA, + WMMA_REGINFO, + frag_c, frag_d, layout_a, layout_b, satf>; + } // satf + } // frag_d + } // frag_c + } // layout_b + } // layout_a +} // geom