Convergent call lowering for NVPTX.

What the hunks below do:
  * TargetLowering::CallLoweringInfo gains an IsConvergent bit (initialized
    to false) and a chaining setter, setConvergent(bool Value = true).
  * SelectionDAGBuilder sets that bit from CS.isConvergent() when it builds
    the CallLoweringInfo for a call site.
  * NVPTX adds two ISD opcodes, PrintConvergentCall and
    PrintConvergentCallUni. LowerCall selects them instead of PrintCall /
    PrintCallUni when CLI.IsConvergent is set.
  * The hand-written PrintCall*/PrintCallUni* instruction defs are folded
    into a CALL<OpcStr, OpNode> multiclass. It is instantiated once per
    call flavour, and the convergent flavours are wrapped in
    "let isConvergent=1".
  * A new lit test, convergent-mir-call.ll, checks which MIR call
    instruction is selected for convergent and non-convergent calls.

Index: llvm/trunk/include/llvm/Target/TargetLowering.h
===================================================================
--- llvm/trunk/include/llvm/Target/TargetLowering.h
+++ llvm/trunk/include/llvm/Target/TargetLowering.h
@@ -2348,6 +2348,7 @@
     bool IsInReg : 1;
     bool DoesNotReturn : 1;
     bool IsReturnValueUsed : 1;
+    bool IsConvergent : 1;
 
     // IsTailCall should be modified by implementations of
     // TargetLowering::LowerCall that perform tail call conversions.
@@ -2366,10 +2367,11 @@
     SmallVector<ISD::InputArg, 32> Ins;
 
     CallLoweringInfo(SelectionDAG &DAG)
-      : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
-      IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
-      IsTailCall(false), NumFixedArgs(-1), CallConv(CallingConv::C),
-      DAG(DAG), CS(nullptr), IsPatchPoint(false) {}
+        : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
+          IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
+          IsConvergent(false), IsTailCall(false), NumFixedArgs(-1),
+          CallConv(CallingConv::C), DAG(DAG), CS(nullptr), IsPatchPoint(false) {
+    }
 
     CallLoweringInfo &setDebugLoc(SDLoc dl) {
       DL = dl;
@@ -2441,6 +2443,11 @@
       return *this;
     }
 
+    CallLoweringInfo &setConvergent(bool Value = true) {
+      IsConvergent = Value;
+      return *this;
+    }
+
     CallLoweringInfo &setSExtResult(bool Value = true) {
       RetSExt = Value;
       return *this;
Index: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5562,9 +5562,11 @@
     isTailCall = false;
 
   TargetLowering::CallLoweringInfo CLI(DAG);
-  CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot())
-    .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
-    .setTailCall(isTailCall);
+  CLI.setDebugLoc(getCurSDLoc())
+      .setChain(getRoot())
+      .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
+      .setTailCall(isTailCall)
+      .setConvergent(CS.isConvergent());
   std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
 
   if (Result.first.getNode()) {
Index: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -34,7 +34,9 @@
   DeclareRet,
   DeclareScalarRet,
   PrintCall,
+  PrintConvergentCall,
   PrintCallUni,
+  PrintConvergentCallUni,
   CallArgBegin,
   CallArg,
   LastCallArg,
Index: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -314,8 +314,12 @@
     return "NVPTXISD::DeclareRetParam";
   case NVPTXISD::PrintCall:
     return "NVPTXISD::PrintCall";
+  case NVPTXISD::PrintConvergentCall:
+    return "NVPTXISD::PrintConvergentCall";
   case NVPTXISD::PrintCallUni:
     return "NVPTXISD::PrintCallUni";
+  case NVPTXISD::PrintConvergentCallUni:
+    return "NVPTXISD::PrintConvergentCallUni";
   case NVPTXISD::LoadParam:
     return "NVPTXISD::LoadParam";
   case NVPTXISD::LoadParamV2:
@@ -1439,8 +1443,12 @@
   SDValue PrintCallOps[] = {
     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag
   };
-  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
-                      dl, PrintCallVTs, PrintCallOps);
+  // We model convergent calls as separate opcodes.
+  unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall;
+  if (CLI.IsConvergent)
+    Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
+                                              : NVPTXISD::PrintConvergentCall;
+  Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
   InFlag = Chain.getValue(1);
 
   // Ops to print out the function name
Index: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1701,9 +1701,15 @@
 def PrintCall :
   SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCall :
+  SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def PrintCallUni :
   SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCallUni :
+  SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def StoreParam :
   SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
@@ -1821,53 +1827,44 @@
                 []>;
 
 let isCall=1 in {
-  def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
-    "call ", [(PrintCall (i32 0))]>;
-  def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
-    "call (retval0), ", [(PrintCall (i32 1))]>;
-  def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1), ", [(PrintCall (i32 2))]>;
-  def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>;
-  def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>;
-  def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2, retval3, retval4), ",
-    [(PrintCall (i32 5))]>;
-  def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2, retval3, retval4, retval5), ",
-    [(PrintCall (i32 6))]>;
-  def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
-    [(PrintCall (i32 7))]>;
-  def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
-    "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
-    "retval7), ",
-    [(PrintCall (i32 8))]>;
-
-  def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins),
-    "call.uni ", [(PrintCallUni (i32 0))]>;
-  def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0), ", [(PrintCallUni (i32 1))]>;
-  def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>;
-  def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>;
-  def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>;
-  def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2, retval3, retval4), ",
-    [(PrintCallUni (i32 5))]>;
-  def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ",
-    [(PrintCallUni (i32 6))]>;
-  def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
-    [(PrintCallUni (i32 7))]>;
-  def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins),
-    "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
-    "retval7), ",
-    [(PrintCallUni (i32 8))]>;
+  multiclass CALL<string OpcStr, SDNode OpNode> {
+    def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+    def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
+    def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
+    def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
+    def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
+      [(OpNode (i32 4))]>;
+    def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
+      [(OpNode (i32 5))]>;
+    def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                         "retval5), "),
+      [(OpNode (i32 6))]>;
+    def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                         "retval5, retval6), "),
+      [(OpNode (i32 7))]>;
+    def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
+      !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+                         "retval5, retval6, retval7), "),
+      [(OpNode (i32 8))]>;
+  }
+}
+
+defm Call : CALL<"call", PrintCall>;
+defm CallUni : CALL<"call.uni", PrintCallUni>;
+
+// Convergent call instructions. These are identical to regular calls, except
+// they have the isConvergent bit set.
+let isConvergent=1 in {
+  defm ConvergentCall : CALL<"call", PrintConvergentCall>;
+  defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>;
 }
 
 def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">;
Index: llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll
===================================================================
--- llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll
+++ llvm/trunk/test/CodeGen/NVPTX/convergent-mir-call.ll
@@ -0,0 +1,27 @@
+; RUN: llc -mtriple nvptx64-nvidia-cuda -stop-after machine-cp -o - < %s 2>&1 | FileCheck %s
+
+; Check that convergent calls are emitted using convergent MIR instructions,
+; while non-convergent calls are not.
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare void @conv() convergent
+declare void @not_conv()
+
+define void @test(void ()* %f) {
+  ; CHECK: ConvergentCallUniPrintCall
+  ; CHECK-NEXT: @conv
+  call void @conv()
+
+  ; CHECK: CallUniPrintCall
+  ; CHECK-NEXT: @not_conv
+  call void @not_conv()
+
+  ; CHECK: ConvergentCallPrintCall
+  call void %f() convergent
+
+  ; CHECK: CallPrintCall
+  call void %f()
+
+  ret void
+}