diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -460,6 +460,16 @@ SmallVector ArgumentPHIs; bool RemovableCallsMustBeMarkedTail = false; + // PHI node to store our return value. + PHINode *RetPN = nullptr; + + // i1 PHI node to track if we have a valid return value stored in RetPN. + PHINode *RetKnownPN = nullptr; + + // Vector of select instructions we inserted. These selects use RetKnownPN + // to either propagate RetPN or select a new return value. + SmallVector RetSelects; + TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) @@ -577,6 +587,21 @@ PN->addIncoming(&*I, NewEntry); ArgumentPHIs.push_back(PN); } + + // If the function doesn't return void, create the RetPN and RetKnownPN PHI + // nodes to track our return value. We initialize RetPN with undef and + // RetKnownPN with false since we can't know our return value at function + // entry. + Type *RetType = F.getReturnType(); + if (!RetType->isVoidTy()) { + Type *BoolType = Type::getInt1Ty(F.getContext()); + RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos); + RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos); + + RetPN->addIncoming(UndefValue::get(RetType), NewEntry); + RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry); + } + // The entry block was changed from HeaderBB to NewEntry. // The forward DominatorTree needs to be recalculated when the EntryBB is // changed. In this corner-case we recalculate the entire tree. @@ -616,11 +641,7 @@ // value for the accumulator is placed in this variable. If this value is set // then we actually perform accumulator recursion elimination instead of // simple tail recursion elimination. 
If the operation is an LLVM instruction - // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then - // we are handling the case when the return instruction returns a constant C - // which is different to the constant returned by other return instructions - // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a - // special case of accumulator recursion, the operation being "return C". + // (eg: "add") then it is recorded in AccumulatorRecursionInstr. Value *AccumulatorRecursionEliminationInitVal = nullptr; Instruction *AccumulatorRecursionInstr = nullptr; @@ -647,26 +668,6 @@ } } - // We can only transform call/return pairs that either ignore the return value - // of the call and return void, ignore the value of the call and return a - // constant, return the value returned by the tail call, or that are being - // accumulator recursion variable eliminated. - if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI && - !isa(Ret->getReturnValue()) && - AccumulatorRecursionEliminationInitVal == nullptr && - !getCommonReturnValue(nullptr, CI)) { - // One case remains that we are able to handle: the current return - // instruction returns a constant, and all other return instructions - // return a different constant. - if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret)) - return false; // Current return instruction does not return a constant. - // Check that all other return instructions return a common constant. If - // so, record it in AccumulatorRecursionEliminationInitVal. 
- AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI); - if (!AccumulatorRecursionEliminationInitVal) - return false; - } - BasicBlock *BB = Ret->getParent(); using namespace ore; @@ -698,20 +699,15 @@ PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal); Instruction *AccRecInstr = AccumulatorRecursionInstr; - if (AccRecInstr) { - // Add an incoming argument for the current block, which is computed by - // our associative and commutative accumulator instruction. - AccPN->addIncoming(AccRecInstr, BB); - - // Next, rewrite the accumulator recursion instruction so that it does not - // use the result of the call anymore, instead, use the PHI node we just - // inserted. - AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); - } else { - // Add an incoming argument for the current block, which is just the - // constant returned by the current return instruction. - AccPN->addIncoming(Ret->getReturnValue(), BB); - } + + // Add an incoming argument for the current block, which is computed by + // our associative and commutative accumulator instruction. + AccPN->addIncoming(AccRecInstr, BB); + + // Next, rewrite the accumulator recursion instruction so that it does not + // use the result of the call anymore, instead, use the PHI node we just + // inserted. + AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); // Finally, rewrite any return instructions in the program to return the PHI // node instead of the "initval" that they do currently. 
This loop will @@ -722,6 +718,25 @@ ++NumAccumAdded; } + // Update our return value tracking + if (RetPN) { + if (Ret->getReturnValue() == CI || AccumulatorRecursionEliminationInitVal) { + // Defer selecting a return value + RetPN->addIncoming(RetPN, BB); + RetKnownPN->addIncoming(RetKnownPN, BB); + } else { + // We found a return value we want to use, insert a select instruction to + // select it if we don't already know what our return value will be and + // store the result in our return value PHI node. + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret); + RetSelects.push_back(SI); + + RetPN->addIncoming(SI, BB); + RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB); + } + } + // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); @@ -804,6 +819,30 @@ PN->eraseFromParent(); } } + + if (RetPN) { + if (RetSelects.empty()) { + // If we didn't insert any select instructions, then we know we didn't + // store a return value and we can remove the PHI nodes we inserted. + RetPN->dropAllReferences(); + RetPN->eraseFromParent(); + + RetKnownPN->dropAllReferences(); + RetKnownPN->eraseFromParent(); + } else { + // We need to insert a select instruction before any return left in the + // function to select our stored return value if we have one. 
+ for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast(BB.getTerminator()); + if (!RI) + continue; + + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); + RI->setOperand(0, SI); + } + } + } } bool TailRecursionEliminator::eliminate(Function &F, diff --git a/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll b/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll --- a/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll +++ b/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll @@ -1,20 +1,112 @@ ; RUN: opt < %s -tailcallelim -verify-dom-info -S | FileCheck %s ; PR7328 ; PR7506 -define i32 @foo(i32 %x) { -; CHECK-LABEL: define i32 @foo( -; CHECK: %accumulator.tr = phi i32 [ 1, %entry ], [ 0, %body ] +define i32 @test1_constants(i32 %x) { entry: %cond = icmp ugt i32 %x, 0 ; [#uses=1] br i1 %cond, label %return, label %body body: ; preds = %entry %y = add i32 %x, 1 ; [#uses=1] - %tmp = call i32 @foo(i32 %y) ; [#uses=0] -; CHECK-NOT: call + %recurse = call i32 @test1_constants(i32 %y) ; [#uses=0] ret i32 0 -; CHECK: ret i32 %accumulator.tr return: ; preds = %entry ret i32 1 } + +; CHECK-LABEL: define i32 @test1_constants( +; CHECK: tailrecurse: +; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %body ] +; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %body ] +; CHECK: body: +; CHECK-NOT: %recurse +; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 0 +; CHECK-NOT: ret +; CHECK: return: +; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 1 +; CHECK: ret i32 %current.ret.tr1 + +define i32 @test2_non_constants(i32 %x) { +entry: + %cond = icmp ugt i32 %x, 0 + br i1 %cond, label %return, label %body + +body: + %y = add i32 %x, 1 + %helper1 = call i32 @test2_helper() + %recurse = call i32 @test2_non_constants(i32 %y) + ret i32 %helper1 + +return: + %helper2 = call i32 @test2_helper() + ret i32 
%helper2 +} + +declare i32 @test2_helper() + +; CHECK-LABEL: define i32 @test2_non_constants( +; CHECK: tailrecurse: +; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %body ] +; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %body ] +; CHECK: body: +; CHECK-NOT: %recurse +; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper1 +; CHECK-NOT: ret +; CHECK: return: +; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper2 +; CHECK: ret i32 %current.ret.tr1 + +define i32 @test3_mixed(i32 %x) { +entry: + switch i32 %x, label %default [ + i32 0, label %case0 + i32 1, label %case1 + i32 2, label %case2 + ] + +case0: + %helper1 = call i32 @test3_helper() + br label %return + +case1: + %y1 = add i32 %x, -1 + %recurse1 = call i32 @test3_mixed(i32 %y1) + br label %return + +case2: + %y2 = add i32 %x, -1 + %helper2 = call i32 @test3_helper() + %recurse2 = call i32 @test3_mixed(i32 %y2) + br label %return + +default: + %y3 = urem i32 %x, 3 + %recurse3 = call i32 @test3_mixed(i32 %y3) + br label %return + +return: + %retval = phi i32 [ %recurse3, %default ], [ %helper2, %case2 ], [ 9, %case1 ], [ %helper1, %case0 ] + ret i32 %retval +} + +declare i32 @test3_helper() + +; CHECK-LABEL: define i32 @test3_mixed( +; CHECK: tailrecurse: +; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %case1 ], [ %current.ret.tr1, %case2 ], [ %ret.tr, %default ] +; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %case1 ], [ true, %case2 ], [ %ret.known.tr, %default ] +; CHECK: case1: +; CHECK-NOT: %recurse +; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 9 +; CHECK: br label %tailrecurse +; CHECK: case2: +; CHECK-NOT: %recurse +; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper2 +; CHECK: br label %tailrecurse +; CHECK: default: +; CHECK-NOT: %recurse +; CHECK: br label %tailrecurse +; CHECK: return: +; CHECK: %current.ret.tr2 = select i1 %ret.known.tr, i32 
%ret.tr, i32 %helper1 +; CHECK: ret i32 %current.ret.tr2 diff --git a/llvm/test/Transforms/TailCallElim/basic.ll b/llvm/test/Transforms/TailCallElim/basic.ll --- a/llvm/test/Transforms/TailCallElim/basic.ll +++ b/llvm/test/Transforms/TailCallElim/basic.ll @@ -46,8 +46,16 @@ ; plunked it into the demo script, so maybe they care about it. define i32 @test3(i32 %c) { ; CHECK: i32 @test3 +; CHECK: tailrecurse: +; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %else ] +; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %else ] +; CHECK: else: ; CHECK-NOT: call -; CHECK: ret i32 0 +; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 0 +; CHECK-NOT: ret +; CHECK: return: +; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 0 +; CHECK: ret i32 %current.ret.tr1 entry: %tmp.1 = icmp eq i32 %c, 0 ; [#uses=1] br i1 %tmp.1, label %return, label %else