Index: lib/Transforms/IPO/FunctionAttrs.cpp
===================================================================
--- lib/Transforms/IPO/FunctionAttrs.cpp
+++ lib/Transforms/IPO/FunctionAttrs.cpp
@@ -935,49 +935,70 @@
   return MadeChange;
 }
 
-/// Removes convergent attributes where we can prove that none of the SCC's
-/// callees are themselves convergent. Returns true if successful at removing
-/// the attribute.
+/// Remove convergent attributes on calls where we can prove that the callee is
+/// not a convergent function, and remove convergent attributes on the SCC's
+/// functions where we can prove that none of the SCC's callees are themselves
+/// convergent. Returns true if any changes were made.
 static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
-  // Determines whether a function can be made non-convergent, ignoring all
-  // other functions in SCC. (A function can *actually* be made non-convergent
-  // only if all functions in its SCC can be made convergent.)
-  auto CanRemoveConvergent = [&](Function *F) {
-    if (!F->isConvergent())
-      return true;
-
-    // Can't remove convergent from declarations.
-    if (F->isDeclaration())
-      return false;
-
-    for (Instruction &I : instructions(*F))
-      if (auto CS = CallSite(&I)) {
-        // Can't remove convergent if any of F's callees -- ignoring functions
-        // in the SCC itself -- are convergent. This needs to consider both
-        // function calls and intrinsic calls. We also assume indirect calls
-        // might call a convergent function.
-        // FIXME: We should revisit this when we put convergent onto calls
-        // instead of functions so that indirect calls which should be
-        // convergent are required to be marked as such.
-        Function *Callee = CS.getCalledFunction();
-        if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
-          return false;
-      }
-
-    return true;
-  };
-
-  // We can remove the convergent attr from functions in the SCC if they all
-  // can be made non-convergent (because they call only non-convergent
-  // functions, other than each other).
-  if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
+  // Can't remove convergent from function declarations.
+  if (llvm::any_of(SCCNodes,
+                   [](Function *F) { return F && F->isDeclaration(); }))
     return false;
 
-  // If we got here, all of the SCC's callees are non-convergent. Therefore all
-  // of the SCC's functions can be marked as non-convergent.
+  bool MadeChange = false;
+
+  // Remove the convergent attribute from calls where the callee is not
+  // convergent. Keep track of whether, after doing this, there exists a
+  // convergent call to a function not in the SCC. Also keep a list of
+  // convergent calls to functions within the SCC.
+  bool HasConvergentNonSelfCall = false;
+  SmallVector<CallSite, 16> ConvergentSelfCalls;
+  for (Function *F : SCCNodes) {
+    for (Instruction &I : instructions(*F))
+      if (auto CS = CallSite(&I)) {
+        if (!CS.isConvergent())
+          continue;
+        Function *Callee = CS.getCalledFunction();
+        // Remove the convergent attr from calls to non-convergent functions.
+        if (Callee && !Callee->isConvergent()) {
+          DEBUG(dbgs() << "Removing convergent attr from instr "
+                       << *CS.getInstruction() << "\n");
+          CS.setNotConvergent();
+          MadeChange = true;
+          continue;
+        }
+        if (Callee && SCCNodes.count(Callee) > 0) {
+          ConvergentSelfCalls.push_back(CS);
+          continue;
+        }
+        HasConvergentNonSelfCall = true;
+      }
+  }
+
+  // If one of the SCC's functions contains a convergent call to a function not
+  // in the SCC, then any functions in the SCC which are convergent must remain
+  // convergent.
+  if (HasConvergentNonSelfCall)
+    return MadeChange;
+
+  // If none of the SCC's functions is convergent, the loops below cannot
+  // change anything, so report only the call-site changes made above.
+  if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); }))
+    return MadeChange;
+
+  // If we got here, all of the calls the SCC makes to functions not in the SCC
+  // are non-convergent. Therefore all of the self-calls within the SCC can be
+  // marked as non-convergent, and all of the SCC's functions can also be made
+  // non-convergent; at least one of them is convergent, so we return true.
+  for (CallSite &CS : ConvergentSelfCalls) {
+    DEBUG(dbgs() << "Removing convergent attr from instr "
+                 << *CS.getInstruction() << "\n");
+    CS.setNotConvergent();
+  }
   for (Function *F : SCCNodes) {
     if (F->isConvergent())
-      DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
+      DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
+                   << "\n");
     F->setNotConvergent();
   }
   return true;
Index: test/Transforms/FunctionAttrs/convergent.ll
===================================================================
--- test/Transforms/FunctionAttrs/convergent.ll
+++ test/Transforms/FunctionAttrs/convergent.ll
@@ -4,7 +4,9 @@
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @nonleaf()
 define i32 @nonleaf() convergent {
-  %a = call i32 @leaf()
+  ; We should remove the convergent attr from the call.
+  ; CHECK: call i32 @leaf(){{$}}
+  %a = call i32 @leaf() convergent
   ret i32 %a
 }
 
@@ -24,10 +26,43 @@
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @extern()
 define i32 @extern() convergent {
+  ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
+  %a = call i32 @k() convergent
+  ret i32 %a
+}
+
+; Convergent should not be removed on the function here. Although the call is
+; not convergent, it picks up the convergent attr from the callee.
+;
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @extern_non_convergent_call()
+define i32 @extern_non_convergent_call() convergent {
   %a = call i32 @k()
   ret i32 %a
 }
 
+; For indirect calls, we can't look at the callee to figure out whether or not
+; it's convergent. We have to rely on the call instr.
+;
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @indirect_convergent_call(
+define i32 @indirect_convergent_call(i32 ()* %f) convergent {
+  %a = call i32 %f() convergent
+  ret i32 %a
+}
+; The indirect call is not marked convergent, so convergent should be removed
+; from the function. norecurse is added so "Function Attrs" is still printed.
+;
+; CHECK: Function Attrs
+; CHECK-NOT: convergent
+; CHECK-NEXT: define i32 @indirect_non_convergent_call(
+define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
+  %a = call i32 %f()
+  ret i32 %a
+}
+
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @call_extern()
@@ -49,21 +84,12 @@
   ret i32 0
 }
 
-@xyz = global i32 ()* null
-; CHECK: Function Attrs
-; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @functionptr()
-define i32 @functionptr() convergent {
-  %1 = load i32 ()*, i32 ()** @xyz
-  %2 = call i32 %1()
-  ret i32 %2
-}
-
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive1()
 define i32 @recursive1() convergent {
-  %a = call i32 @recursive2()
+  ; CHECK: call i32 @recursive2(){{$}}
+  %a = call i32 @recursive2() convergent
   ret i32 %a
 }
 
@@ -71,7 +97,8 @@
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive2()
 define i32 @recursive2() convergent {
-  %a = call i32 @recursive1()
+  ; CHECK: call i32 @recursive1(){{$}}
+  %a = call i32 @recursive1() convergent
   ret i32 %a
 }
 
@@ -79,7 +106,8 @@
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @noopt()
 define i32 @noopt() convergent optnone noinline {
-  %a = call i32 @noopt_friend()
+  ; CHECK: call i32 @noopt_friend() [[CONVERGENT_ATTR]]
+  %a = call i32 @noopt_friend() convergent
   ret i32 0
 }
 
@@ -92,3 +120,5 @@
   %a = call i32 @noopt()
   ret i32 0
 }
+
+; CHECK: [[CONVERGENT_ATTR]] = { convergent }