Index: lib/Transforms/IPO/FunctionAttrs.cpp
===================================================================
--- lib/Transforms/IPO/FunctionAttrs.cpp
+++ lib/Transforms/IPO/FunctionAttrs.cpp
@@ -935,49 +935,37 @@
   return MadeChange;
 }
 
-/// Removes convergent attributes where we can prove that none of the SCC's
-/// callees are themselves convergent. Returns true if successful at removing
-/// the attribute.
+/// Remove the convergent attribute from all functions in the SCC if every
+/// callsite within the SCC is not convergent (except for calls to functions
+/// within the SCC). Returns true if changes were made.
 static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
-  // Determines whether a function can be made non-convergent, ignoring all
-  // other functions in SCC. (A function can *actually* be made non-convergent
-  // only if all functions in its SCC can be made convergent.)
-  auto CanRemoveConvergent = [&](Function *F) {
-    if (!F->isConvergent())
-      return true;
-
-    // Can't remove convergent from declarations.
-    if (F->isDeclaration())
-      return false;
-
-    for (Instruction &I : instructions(*F))
-      if (auto CS = CallSite(&I)) {
-        // Can't remove convergent if any of F's callees -- ignoring functions
-        // in the SCC itself -- are convergent. This needs to consider both
-        // function calls and intrinsic calls. We also assume indirect calls
-        // might call a convergent function.
-        // FIXME: We should revisit this when we put convergent onto calls
-        // instead of functions so that indirect calls which should be
-        // convergent are required to be marked as such.
-        Function *Callee = CS.getCalledFunction();
-        if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
-          return false;
-      }
-
-    return true;
-  };
-
-  // We can remove the convergent attr from functions in the SCC if they all
-  // can be made non-convergent (because they call only non-convergent
-  // functions, other than each other).
-  if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
+  // No point checking if none of SCCNodes is convergent.
+  if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); }))
     return false;
 
-  // If we got here, all of the SCC's callees are non-convergent. Therefore all
-  // of the SCC's functions can be marked as non-convergent.
+  // Can't remove convergent from function declarations.
+  if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); }))
+    return false;
+
+  // Can't remove convergent if any of our functions has a convergent call to a
+  // function not in the SCC.
+  for (Function *F : SCCNodes)
+    for (Instruction &I : instructions(*F)) {
+      CallSite CS(&I);
+      // Bail if CS is a convergent call to a function not in the SCC.
+      if (CS && CS.isConvergent() &&
+          SCCNodes.count(CS.getCalledFunction()) == 0)
+        return false;
+    }
+
+  // If we got here, all of the calls the SCC makes to functions not in the SCC
+  // are non-convergent. Therefore all of the SCC's functions can also be made
+  // non-convergent. We'll remove the attr from the callsites in
+  // InstCombineCalls.
   for (Function *F : SCCNodes) {
     if (F->isConvergent())
-      DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
+      DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
+                   << "\n");
     F->setNotConvergent();
   }
   return true;
Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2033,7 +2033,15 @@
   if (!isa<Function>(Callee) && transformConstExprCastCall(CS))
     return nullptr;
 
-  if (Function *CalleeF = dyn_cast<Function>(Callee))
+  if (Function *CalleeF = dyn_cast<Function>(Callee)) {
+    // Remove the convergent attr on calls when the callee is not convergent.
+    if (CS.isConvergent() && !CalleeF->isConvergent()) {
+      DEBUG(dbgs() << "Removing convergent attr from instr "
+                   << *CS.getInstruction() << "\n");
+      CS.setNotConvergent();
+      return CS.getInstruction();
+    }
+
     // If the call and callee calling conventions don't match, this call must
     // be unreachable, as the call is undefined.
     if (CalleeF->getCallingConv() != CS.getCallingConv() &&
@@ -2058,6 +2066,7 @@
                                     Constant::getNullValue(CalleeF->getType()));
       return nullptr;
     }
+  }
 
   if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
     // If CS does not return void then replaceAllUsesWith undef.
Index: test/Transforms/FunctionAttrs/convergent.ll
===================================================================
--- test/Transforms/FunctionAttrs/convergent.ll
+++ test/Transforms/FunctionAttrs/convergent.ll
@@ -1,4 +1,4 @@
-; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s
+; RUN: opt -functionattrs -S < %s | FileCheck %s
 
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
@@ -24,16 +24,37 @@
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @extern()
 define i32 @extern() convergent {
+  %a = call i32 @k() convergent
+  ret i32 %a
+}
+
+; Convergent should not be removed on the function here. Although the call is
+; not explicitly convergent, it picks up the convergent attr from the callee.
+;
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @extern_non_convergent_call()
+define i32 @extern_non_convergent_call() convergent {
   %a = call i32 @k()
   ret i32 %a
 }
 
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @call_extern()
-define i32 @call_extern() convergent {
-  %a = call i32 @extern()
-  ret i32 %a
+; CHECK-NEXT: define i32 @indirect_convergent_call(
+define i32 @indirect_convergent_call(i32 ()* %f) convergent {
+  %a = call i32 %f() convergent
+  ret i32 %a
+}
+; Give indirect_non_convergent_call the norecurse attribute so we get a
+; "Function Attrs" comment in the output.
+;
+; CHECK: Function Attrs
+; CHECK-NOT: convergent
+; CHECK-NEXT: define i32 @indirect_non_convergent_call(
+define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
+  %a = call i32 %f()
+  ret i32 %a
 }
 
 ; CHECK: Function Attrs
@@ -45,25 +66,16 @@
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @intrinsic()
 define i32 @intrinsic() convergent {
+  ; Implicitly convergent, because the intrinsic is convergent.
   call void @llvm.cuda.syncthreads()
   ret i32 0
 }
 
-@xyz = global i32 ()* null
-; CHECK: Function Attrs
-; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @functionptr()
-define i32 @functionptr() convergent {
-  %1 = load i32 ()*, i32 ()** @xyz
-  %2 = call i32 %1()
-  ret i32 %2
-}
-
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive1()
 define i32 @recursive1() convergent {
-  %a = call i32 @recursive2()
+  %a = call i32 @recursive2() convergent
   ret i32 %a
 }
 
@@ -71,7 +83,7 @@
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive2()
 define i32 @recursive2() convergent {
-  %a = call i32 @recursive1()
+  %a = call i32 @recursive1() convergent
   ret i32 %a
 }
 
@@ -79,7 +91,7 @@
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @noopt()
 define i32 @noopt() convergent optnone noinline {
-  %a = call i32 @noopt_friend()
+  %a = call i32 @noopt_friend() convergent
   ret i32 0
 }
 
Index: test/Transforms/InstCombine/convergent.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/convergent.ll
@@ -0,0 +1,33 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+declare i32 @k() convergent
+declare i32 @f()
+
+define i32 @extern() {
+  ; Convergent attr shouldn't be removed here; k is convergent.
+  ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
+  %a = call i32 @k() convergent
+  ret i32 %a
+}
+
+define i32 @extern_no_attr() {
+  ; Convergent attr shouldn't be added here, even though k is convergent.
+  ; CHECK: call i32 @k(){{$}}
+  %a = call i32 @k()
+  ret i32 %a
+}
+
+define i32 @no_extern() {
+  ; Convergent should be removed here, as the target is not convergent.
+  ; CHECK: call i32 @f(){{$}}
+  %a = call i32 @f() convergent
+  ret i32 %a
+}
+
+define i32 @indirect_call(i32 ()* %f) {
+  ; CHECK: call i32 %f() [[CONVERGENT_ATTR]]
+  %a = call i32 %f() convergent
+  ret i32 %a
+}
+
+; CHECK: [[CONVERGENT_ATTR]] = { convergent }