Index: llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h =================================================================== --- llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h +++ llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h @@ -40,6 +40,8 @@ const char *Modifier = nullptr); void printLdStCode(const MCInst *MI, int OpNum, raw_ostream &O, const char *Modifier = nullptr); + void printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier = nullptr); void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, const char *Modifier = nullptr); void printProtoIdent(const MCInst *MI, int OpNum, Index: llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp =================================================================== --- llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp +++ llvm/trunk/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp @@ -269,6 +269,20 @@ llvm_unreachable("Empty Modifier"); } +void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier) { + const MCOperand &MO = MI->getOperand(OpNum); + int Imm = (int)MO.getImm(); + if (Modifier == nullptr || strcmp(Modifier, "version") == 0) { + O << Imm; // Just print out PTX version + } else if (strcmp(Modifier, "aligned") == 0) { + // PTX63 requires '.aligned' in the name of the instruction. 
+ if (Imm >= 63) + O << ".aligned"; + } else + llvm_unreachable("Unknown Modifier"); +} + void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, const char *Modifier) { printOperand(MI, OpNum, O); Index: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td =================================================================== --- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td +++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1548,6 +1548,10 @@ let PrintMethod = "printLdStCode"; } +def MmaCode : Operand { + let PrintMethod = "printMmaCode"; +} + def SDTWrapper : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>, SDTCisPtrTy<0>]>; def Wrapper : SDNode<"NVPTXISD::Wrapper", SDTWrapper>; Index: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -38,6 +38,24 @@ }]; } +// A node that will be replaced with the current PTX version. +class PTX { + SDNodeXForm PTXVerXform = SDNodeXFormgetPTXVersion(), SDLoc(N)); + }]>; + // (i32 0) will be XForm'ed to the currently used PTX version. + dag version = (PTXVerXform (i32 0)); +} +def ptx : PTX; + +// Generates list of n sequential register names. +// E.g. RegSeq<3,"r">.ret -> ["r0", "r1", "r2" ] +class RegSeq { + list ret = !if(n, !listconcat(RegSeq.ret, + [prefix # !add(n, -1)]), + []); +} + //----------------------------------- // Synchronization and shuffle functions //----------------------------------- @@ -7384,14 +7402,6 @@ NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;", [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>; -class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>; -// Generates list of n sequential register names. -class RegSeq { - list ret = !if(n, !listconcat(RegSeq.ret, - [prefix # !add(n, -1)]), - []); -} - // Helper class that represents a 'fragment' of an NVPTX *MMA instruction. 
// In addition to target-independent fields provided by WMMA_REGS, it adds // the fields commonly used to implement specific PTX instruction -- register @@ -7408,6 +7418,7 @@ // List of register names for the fragment -- ["ra0", "ra1",...] list reg_names = RegSeq.ret; + // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction. string regstring = "{{$" # !head(reg_names) # !foldl("", !tail(reg_names), a, b, @@ -7437,61 +7448,65 @@ dag Ins = !dag(ins, ptx_regs, reg_names); } -class BuildPattern { +// Convert dag of arguments into a dag to match given intrinsic. +class BuildPatternI { // Build a dag pattern that matches the intrinsic call. - // We want a dag that looks like this: - // (set , (intrinsic )) where input and - // output arguments are named patterns that would match corresponding - // input/output arguments of the instruction. - // - // First we construct (set ) from instruction's outs dag by - // replacing dag operator 'outs' with 'set'. - dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); - // Similarly, construct (intrinsic ) sub-dag from - // instruction's input arguments, only now we also need to replace operands - // with patterns that would match them and the operator 'ins' with the - // intrinsic. - dag PatArgs = !foreach(tmp, Ins, - !subst(imem, ADDRvar, - !subst(MEMri64, ADDRri64, - !subst(MEMri, ADDRri, - !subst(ins, IntrMatcher, tmp))))); - // Finally, consatenate both parts together. !con() requires both dags to have - // the same operator, so we wrap PatArgs in a (set ...) dag. - dag ret = !con(PatOuts, (set PatArgs)); + dag ret = !foreach(tmp, Ins, + !subst(imem, ADDRvar, + !subst(MEMri64, ADDRri64, + !subst(MEMri, ADDRri, + !subst(ins, Intr, tmp))))); +} + +// Same as above, but uses PatFrag instead of an Intrinsic. +class BuildPatternPF { + // Build a dag pattern that matches the intrinsic call. 
+ dag ret = !foreach(tmp, Ins, + !subst(imem, ADDRvar, + !subst(MEMri64, ADDRri64, + !subst(MEMri, ADDRri, + !subst(ins, Intr, tmp))))); +} + +// Common WMMA-related fields used for building patterns for all MMA instructions. +class WMMA_INSTR _Args> + : NVPTXInst<(outs), (ins), "?", []> { + Intrinsic Intr = !cast(_Intr); + // Concatenate all arguments into a single dag. + dag Args = !foldl((ins), _Args, a, b, !con(a,b)); + // Pre-build the pattern to match (intrinsic arg0, arg1, ...). + dag IntrinsicPattern = BuildPatternI(Intr), Args>.ret; } // // wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // -class WMMA_LOAD_INTR_HELPER - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast(WMMA_NAME_LDST<"load", Frag, Layout, - WithStride>.record); - let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, - !eq(Space, ".global"): AS_match.global, - 1: AS_match.generic); -} - class WMMA_LOAD - : EmptyNVPTXInst, + : WMMA_INSTR.record, + [!con((ins SrcOp:$src), + !if(WithStride, (ins Int32Regs:$ldm), (ins)))]>, Requires { - // Pattern that matches the intrinsic for this instruction variant. - PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER; - dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins))); + // Load/store intrinsics are overloaded on pointer's address space. + // To match the right intrinsic, we need to build AS-constrained PatFrag. + // Operands is a dag equivalent in shape to Args, but using (ops node:$name, .....). + dag PFOperands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); + // Build PatFrag that only matches particular address space. + PatFrag IntrFrag = PatFrag; + // Build AS-constrained pattern. 
+ let IntrinsicPattern = BuildPatternPF.ret; - let Pattern = [BuildPattern.ret]; let OutOperandList = Frag.Outs; - let InOperandList = Ins; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); let AsmString = "wmma.load." # Frag.frag # ".sync" + # "${ptx:aligned}" # "." # Layout # "." # Frag.geom # Space @@ -7505,87 +7520,79 @@ // // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // -class WMMA_STORE_INTR_HELPER - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast(WMMA_NAME_LDST<"store", Frag, Layout, - WithStride>.record); - let Operands = !con((ops node:$dst), - !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names), - !if(WithStride, (ops node:$ldm), (ops))); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, - !eq(Space, ".global"): AS_match.global, - 1: AS_match.generic); -} - -class WMMA_STORE - : EmptyNVPTXInst, +class WMMA_STORE_D + : WMMA_INSTR.record, + [!con((ins DstOp:$dst), + Frag.Ins, + !if(WithStride, (ins Int32Regs:$ldm), (ins)))]>, Requires { - PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER; - dag Ins = !con((ins DstOp:$src), - Frag.Ins, - !if(WithStride, (ins Int32Regs:$ldm), (ins))); - let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret]; + + // Load/store intrinsics are overloaded on pointer's address space. + // To match the right intrinsic, we need to build AS-constrained PatFrag. + // Operands is a dag equivalent in shape to Args, but using (ops node:$name, .....). + dag PFOperands = !con((ops node:$dst), + !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names), + !if(WithStride, (ops node:$ldm), (ops))); + // Build PatFrag that only matches particular address space. + PatFrag IntrFrag = PatFrag; + // Build AS-constrained pattern. 
+ let IntrinsicPattern = BuildPatternPF.ret; + + let InOperandList = !con(Args, (ins MmaCode:$ptx)); let OutOperandList = (outs); - let InOperandList = Ins; - let AsmString = "wmma.store.d.sync." - # Layout + let AsmString = "wmma.store.d.sync" + # "${ptx:aligned}" + # "." # Layout # "." # Frag.geom # Space # "." # Frag.ptx_elt_type - # " \t[$src]," + # " \t[$dst]," # Frag.regstring # !if(WithStride, ", $ldm", "") # ";"; } // Create all load/store variants -foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout = ["row", "col"] in { - foreach stride = [0, 1] in { - foreach space = [".global", ".shared", ""] in { - foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { - foreach frag = [WMMA_REGINFO, - WMMA_REGINFO, - WMMA_REGINFO, - WMMA_REGINFO] in { - def : WMMA_LOAD; - } - foreach frag = [WMMA_REGINFO, - WMMA_REGINFO] in { - def : WMMA_STORE; - } - } // addr - } // space - } // stride - } // layout -} // geom +defset list MMA_LDSTs = { + foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach space = [".global", ".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = [WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO] in { + def : WMMA_LOAD; + } + foreach frag = [WMMA_REGINFO, + WMMA_REGINFO] in { + def : WMMA_STORE_D; + } + } // addr + } // space + } // stride + } // layout + } // geom +} // defset // WMMA.MMA class WMMA_MMA - : EmptyNVPTXInst, + : WMMA_INSTR.record, + [FragA.Ins, FragB.Ins, FragC.Ins]>, Requires { - //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero; - Intrinsic Intr = !cast(WMMA_NAME_MMA.record); - dag Outs = FragD.Outs; - dag Ins = !con(FragA.Ins, - FragB.Ins, - FragC.Ins); - - // Construct the pattern to match corresponding intrinsic call. - // mma does not load/store anything, so we don't need complex operand matching here. 
- dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); - dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp)); - let Pattern = [!con(PatOuts, (set PatArgs))]; - let OutOperandList = Outs; - let InOperandList = Ins; - let AsmString = "wmma.mma.sync." - # ALayout + let OutOperandList = FragD.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "wmma.mma.sync" + # "${ptx:aligned}" + # "." # ALayout # "." # BLayout # "." # FragA.geom # "." # FragD.ptx_elt_type @@ -7597,20 +7604,34 @@ # FragC.regstring # ";"; } -foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout_a = ["row", "col"] in { - foreach layout_b = ["row", "col"] in { - foreach frag_c = [WMMA_REGINFO, - WMMA_REGINFO] in { - foreach frag_d = [WMMA_REGINFO, - WMMA_REGINFO] in { - foreach satf = [0, 1] in { - def : WMMA_MMA, - WMMA_REGINFO, - frag_c, frag_d, layout_a, layout_b, satf>; - } // satf - } // frag_d - } // frag_c - } // layout_b - } // layout_a -} // geom +defset list MMAs = { + foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGINFO, + WMMA_REGINFO] in { + foreach frag_d = [WMMA_REGINFO, + WMMA_REGINFO] in { + foreach satf = [0, 1] in { + def : WMMA_MMA, + WMMA_REGINFO, + frag_c, frag_d, layout_a, layout_b, satf>; + } // satf + } // frag_d + } // frag_c + } // layout_b + } // layout_a + } // geom +} // defset + +// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a +// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with +// the instruction record. +class WMMA_PAT + : Pat; + +// Build intrinsic->instruction patterns for all MMA instructions. 
+foreach mma = !listconcat(MMAs, MMA_LDSTs) in + def : WMMA_PAT; Index: llvm/trunk/test/CodeGen/NVPTX/wmma.py =================================================================== --- llvm/trunk/test/CodeGen/NVPTX/wmma.py +++ llvm/trunk/test/CodeGen/NVPTX/wmma.py @@ -3,9 +3,12 @@ # RUN: python %s > %t.ll # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll +# RUN: python %s --ptx=63 > %t-ptx63.ll +# RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll from __future__ import print_function +import argparse from itertools import product from string import Template @@ -64,7 +67,7 @@ } """ intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}" - instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}" + instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" for geom, abc, layout, space, stride, itype in product( known_geoms, @@ -76,6 +79,7 @@ params = { "abc" : abc, + "aligned" : ".aligned" if ptx_version >= 63 else "", "layout" : layout, "space" : space, "stride" : stride, @@ -135,7 +139,7 @@ } """ intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}" - instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}" + instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" for geom, abc, layout, space, stride, itype in product( known_geoms, @@ -147,6 +151,7 @@ params = { "abc" : abc, + "aligned" : ".aligned" if ptx_version >= 63 else "", "layout" : layout, "space" : space, "stride" : stride, @@ -191,7 +196,7 @@ } """ intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}" - instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}" + instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}" 
for geom, alayout, blayout, ctype, dtype, satf in product( known_geoms, @@ -202,6 +207,7 @@ [".satfinite", ""]): params = { + "aligned" : ".aligned" if ptx_version >= 63 else "", "alayout" : alayout, "blayout" : blayout, "ctype" : ctype, @@ -230,4 +236,9 @@ gen_wmma_store_tests() gen_wmma_mma_tests() +parser = argparse.ArgumentParser() +parser.add_argument('--ptx', type=int, default=60) +args = parser.parse_args() +ptx_version = args.ptx + main()