diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -354,89 +354,23 @@ return !is_contained(I->operands(), CI); } -/// Return true if the specified value is the same when the return would exit -/// as it was when the initial iteration of the recursive function was executed. -/// -/// We currently handle static constants and arguments that are not modified as -/// part of the recursion. -static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { - if (isa(V)) return true; // Static constants are always dyn consts - - // Check to see if this is an immutable argument, if so, the value - // will be available to initialize the accumulator. - if (Argument *Arg = dyn_cast(V)) { - // Figure out which argument number this is... - unsigned ArgNo = 0; - Function *F = CI->getParent()->getParent(); - for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) - ++ArgNo; - - // If we are passing this argument into call as the corresponding - // argument operand, then the argument is dynamically constant. - // Otherwise, we cannot transform this function safely. - if (CI->getArgOperand(ArgNo) == Arg) - return true; - } - - // Switch cases are always constant integers. If the value is being switched - // on and the return is only reachable from one of its cases, it's - // effectively constant. - if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor()) - if (SwitchInst *SI = dyn_cast(UniquePred->getTerminator())) - if (SI->getCondition() == V) - return SI->getDefaultDest() != RI->getParent(); - - // Not a constant or immutable argument, we can't safely transform. 
- return false; -} - -/// Check to see if the function containing the specified tail call consistently -/// returns the same runtime-constant value at all exit points except for -/// IgnoreRI. If so, return the returned value. -static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { - Function *F = CI->getParent()->getParent(); - Value *ReturnedValue = nullptr; - - for (BasicBlock &BBI : *F) { - ReturnInst *RI = dyn_cast(BBI.getTerminator()); - if (RI == nullptr || RI == IgnoreRI) continue; - - // We can only perform this transformation if the value returned is - // evaluatable at the start of the initial invocation of the function, - // instead of at the end of the evaluation. - // - Value *RetOp = RI->getOperand(0); - if (!isDynamicConstant(RetOp, CI, RI)) - return nullptr; - - if (ReturnedValue && RetOp != ReturnedValue) - return nullptr; // Cannot transform if differing values are returned. - ReturnedValue = RetOp; - } - return ReturnedValue; -} +static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { + if (!I->isAssociative() || !I->isCommutative()) + return false; -/// If the specified instruction can be transformed using accumulator recursion -/// elimination, return the constant which is the start of the accumulator -/// value. Otherwise return null. -static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { - if (!I->isAssociative() || !I->isCommutative()) return nullptr; assert(I->getNumOperands() == 2 && "Associative/commutative operations should have 2 args!"); // Exactly one operand should be the result of the call instruction. if ((I->getOperand(0) == CI && I->getOperand(1) == CI) || (I->getOperand(0) != CI && I->getOperand(1) != CI)) - return nullptr; + return false; // The only user of this instruction we allow is a single return instruction. 
if (!I->hasOneUse() || !isa(I->user_back())) - return nullptr; + return false; - // Ok, now we have to check all of the other return instructions in this - // function. If they return non-constants or differing values, then we cannot - // transform the function safely. - return getCommonReturnValue(cast(I->user_back()), CI); + return true; } static Instruction *firstNonDbg(BasicBlock::iterator I) { @@ -470,6 +404,16 @@ // to either propagate RetPN or select a new return value. SmallVector RetSelects; + // The following members are state shared across eliminations that use + // accumulator recursion. They are populated by insertAccumulator the first + // time we find an elimination that requires an accumulator. + + // PHI node to store our current accumulated value. + PHINode *AccPN = nullptr; + + // The instruction doing the accumulating. + Instruction *AccumulatorRecursionInstr = nullptr; + TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) @@ -480,7 +424,7 @@ void createTailRecurseLoopHeader(CallInst *CI); - PHINode *insertAccumulator(Value *AccumulatorRecursionEliminationInitVal); + void insertAccumulator(Instruction *AccRecInstr); bool eliminateCall(CallInst *CI); @@ -608,47 +552,44 @@ DTU.recalculate(*NewEntry->getParent()); } -PHINode *TailRecursionEliminator::insertAccumulator( - Value *AccumulatorRecursionEliminationInitVal) { +void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { + assert(!AccPN && "Trying to insert multiple accumulators"); + + AccumulatorRecursionInstr = AccRecInstr; + + // Start by inserting a new PHI node for the accumulator.
pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB); - PHINode *AccPN = PHINode::Create( - AccumulatorRecursionEliminationInitVal->getType(), - std::distance(PB, PE) + 1, "accumulator.tr", &HeaderBB->front()); + AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1, + "accumulator.tr", &HeaderBB->front()); // Loop over all of the predecessors of the tail recursion block. For the - // real entry into the function we seed the PHI with the initial value, - // computed earlier. For any other existing branches to this block (due to - // other tail recursions eliminated) the accumulator is not modified. + // real entry into the function we seed the PHI with the identity constant for + // the accumulation operation. For any other existing branches to this block + // (due to other tail recursions eliminated) the accumulator is not modified. // Because we haven't added the branch in the current block to HeaderBB yet, // it will not show up as a predecessor. for (pred_iterator PI = PB; PI != PE; ++PI) { BasicBlock *P = *PI; - if (P == &F.getEntryBlock()) - AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); - else + if (P == &F.getEntryBlock()) { + Constant *Identity = ConstantExpr::getBinOpIdentity( + AccRecInstr->getOpcode(), AccRecInstr->getType()); + AccPN->addIncoming(Identity, P); + } else { AccPN->addIncoming(AccPN, P); + } } - return AccPN; + ++NumAccumAdded; } bool TailRecursionEliminator::eliminateCall(CallInst *CI) { ReturnInst *Ret = cast(CI->getParent()->getTerminator()); - // If we are introducing accumulator recursion to eliminate operations after - // the call instruction that are both associative and commutative, the initial - // value for the accumulator is placed in this variable. If this value is set - // then we actually perform accumulator recursion elimination instead of - // simple tail recursion elimination. 
If the operation is an LLVM instruction - // (eg: "add") then it is recorded in AccumulatorRecursionInstr. - Value *AccumulatorRecursionEliminationInitVal = nullptr; - Instruction *AccumulatorRecursionInstr = nullptr; - // Ok, we found a potential tail call. We can currently only transform the // tail call if all of the instructions between the call and the return are // movable to above the call itself, leaving the call next to the return. // Check that this is the case now. + Instruction *AccRecInstr = nullptr; BasicBlock::iterator BBI(CI); for (++BBI; &*BBI != Ret; ++BBI) { if (canMoveAboveCall(&*BBI, CI, AA)) @@ -657,15 +598,13 @@ // If we can't move the instruction above the call, it might be because it // is an associative and commutative operation that could be transformed // using accumulator recursion elimination. Check to see if this is the - // case, and if so, remember the initial accumulator value for later. - if ((AccumulatorRecursionEliminationInitVal = - canTransformAccumulatorRecursion(&*BBI, CI))) { - // Yes, this is accumulator recursion. Remember which instruction - // accumulates. - AccumulatorRecursionInstr = &*BBI; - } else { - return false; // Otherwise, we cannot eliminate the tail recursion! - } + // case, and if so, remember which instruction accumulates for later. + if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI)) + return false; // We cannot eliminate the tail recursion! + + // Yes, this is accumulator recursion. Remember which instruction + // accumulates. + AccRecInstr = &*BBI; } BasicBlock *BB = Ret->getParent(); @@ -690,37 +629,18 @@ for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB); - // If we are introducing an accumulator variable to eliminate the recursion, - // do so now. Note that we _know_ that no subsequent tail recursion - // eliminations will happen on this function because of the way the - // accumulator recursion predicate is set up. 
- // - if (AccumulatorRecursionEliminationInitVal) { - PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal); - - Instruction *AccRecInstr = AccumulatorRecursionInstr; - - // Add an incoming argument for the current block, which is computed by - // our associative and commutative accumulator instruction. - AccPN->addIncoming(AccRecInstr, BB); + if (AccRecInstr) { + insertAccumulator(AccRecInstr); - // Next, rewrite the accumulator recursion instruction so that it does not - // use the result of the call anymore, instead, use the PHI node we just + // Rewrite the accumulator recursion instruction so that it does not use + // the result of the call anymore, instead, use the PHI node we just // inserted. AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); - - // Finally, rewrite any return instructions in the program to return the PHI - // node instead of the "initval" that they do currently. This loop will - // actually rewrite the return value we are destroying, but that's ok. - for (BasicBlock &BBI : F) - if (ReturnInst *RI = dyn_cast(BBI.getTerminator())) - RI->setOperand(0, AccPN); - ++NumAccumAdded; } // Update our return value tracking if (RetPN) { - if (Ret->getReturnValue() == CI || AccumulatorRecursionEliminationInitVal) { + if (Ret->getReturnValue() == CI || AccRecInstr) { // Defer selecting a return value RetPN->addIncoming(RetPN, BB); RetKnownPN->addIncoming(RetKnownPN, BB); @@ -735,6 +655,9 @@ RetPN->addIncoming(SI, BB); RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB); } + + if (AccPN) + AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB); } // Now that all of the PHI nodes are in place, remove the call and @@ -829,6 +752,24 @@ RetKnownPN->dropAllReferences(); RetKnownPN->eraseFromParent(); + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // return in the function, and return its result instead. 
+ Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast(BB.getTerminator()); + if (!RI) + continue; + + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + RI->getOperand(0)); + AccRecInstrNew->insertBefore(RI); + RI->setOperand(0, AccRecInstrNew); + } + } } else { // We need to insert a select instruction before any return left in the // function to select our stored return value if we have one. @@ -839,8 +780,23 @@ SelectInst *SI = SelectInst::Create( RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); + RetSelects.push_back(SI); RI->setOperand(0, SI); } + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // of the selects we inserted, and select its result instead. + Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (SelectInst *SI : RetSelects) { + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + SI->getFalseValue()); + AccRecInstrNew->insertBefore(SI); + SI->setFalseValue(AccRecInstrNew); + } + } } } } diff --git a/llvm/test/Transforms/TailCallElim/accum_recursion.ll b/llvm/test/Transforms/TailCallElim/accum_recursion.ll --- a/llvm/test/Transforms/TailCallElim/accum_recursion.ll +++ b/llvm/test/Transforms/TailCallElim/accum_recursion.ll @@ -3,73 +3,222 @@ define i32 @test1_factorial(i32 %x) { entry: - %tmp.1 = icmp sgt i32 %x, 0 ; [#uses=1] + %tmp.1 = icmp sgt i32 %x, 0 br i1 %tmp.1, label %then, label %else -then: ; preds = %entry - %tmp.6 = add i32 %x, -1 ; [#uses=1] - %tmp.4 = call i32 @test1_factorial( i32 %tmp.6 ) ; [#uses=1] - %tmp.7 = mul i32 %tmp.4, %x ; [#uses=1] - ret i32 %tmp.7 -else: ; preds = %entry +then: + %tmp.6 = add i32 %x, -1 + %recurse = call i32 @test1_factorial( i32 %tmp.6 ) + %accumulate = 
mul i32 %recurse, %x + ret i32 %accumulate +else: ret i32 1 } ; CHECK-LABEL: define i32 @test1_factorial( -; CHECK: phi i32 -; CHECK-NOT: call i32 +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 1, %entry ], [ %accumulate, %then ] +; CHECK: then: +; CHECK-NOT: %recurse +; CHECK: %accumulate = mul i32 %accumulator.tr, %x.tr ; CHECK: else: +; CHECK: %accumulator.ret.tr = mul i32 %accumulator.tr, 1 +; CHECK: ret i32 %accumulator.ret.tr -; This is a more aggressive form of accumulator recursion insertion, which -; requires noticing that X doesn't change as we perform the tailcall. +; The base case returns an argument rather than a constant. The accumulator is +; seeded with the identity (0) and %x is combined with it at the return. define i32 @test2_mul(i32 %x, i32 %y) { entry: - %tmp.1 = icmp eq i32 %y, 0 ; [#uses=1] + %tmp.1 = icmp eq i32 %y, 0 br i1 %tmp.1, label %return, label %endif -endif: ; preds = %entry - %tmp.8 = add i32 %y, -1 ; [#uses=1] - %tmp.5 = call i32 @test2_mul( i32 %x, i32 %tmp.8 ) ; [#uses=1] - %tmp.9 = add i32 %tmp.5, %x ; [#uses=1] - ret i32 %tmp.9 -return: ; preds = %entry +endif: + %tmp.8 = add i32 %y, -1 + %recurse = call i32 @test2_mul( i32 %x, i32 %tmp.8 ) + %accumulate = add i32 %recurse, %x + ret i32 %accumulate +return: ret i32 %x } ; CHECK-LABEL: define i32 @test2_mul( -; CHECK: phi i32 -; CHECK-NOT: call i32 +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %endif ] +; CHECK: endif: +; CHECK-NOT: %recurse +; CHECK: %accumulate = add i32 %accumulator.tr, %x +; CHECK: return: +; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, %x +; CHECK: ret i32 %accumulator.ret.tr define i64 @test3_fib(i64 %n) nounwind readnone { -; CHECK-LABEL: @test3_fib( entry: -; CHECK: tailrecurse: -; CHECK: %accumulator.tr = phi i64 [ %n, %entry ], [ %3, %bb1 ] -; CHECK: %n.tr = phi i64 [ %n, %entry ], [ %2, %bb1 ] switch i64 %n, label %bb1 [ -; CHECK: switch i64 %n.tr, label %bb1 [ i64 0, label %bb2 i64 1, label %bb2 ] bb1: -; CHECK: bb1: %0 = add i64 %n, -1 -; CHECK: %0 = add i64 %n.tr, -1 - %1 = tail call i64 @test3_fib(i64 %0) nounwind -; CHECK: %1 = tail
call i64 @test3_fib(i64 %0) - %2 = add i64 %n, -2 -; CHECK: %2 = add i64 %n.tr, -2 - %3 = tail call i64 @test3_fib(i64 %2) nounwind -; CHECK-NOT: tail call i64 @test3_fib - %4 = add nsw i64 %3, %1 -; CHECK: add nsw i64 %accumulator.tr, %1 - ret i64 %4 -; CHECK: br label %tailrecurse + %recurse1 = tail call i64 @test3_fib(i64 %0) nounwind + %1 = add i64 %n, -2 + %recurse2 = tail call i64 @test3_fib(i64 %1) nounwind + %accumulate = add nsw i64 %recurse2, %recurse1 + ret i64 %accumulate bb2: -; CHECK: bb2: ret i64 %n -; CHECK: ret i64 %accumulator.tr } + +; CHECK-LABEL: define i64 @test3_fib( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i64 [ 0, %entry ], [ %accumulate, %bb1 ] +; CHECK: bb1: +; CHECK-NOT: %recurse2 +; CHECK: %accumulate = add nsw i64 %accumulator.tr, %recurse1 +; CHECK: bb2: +; CHECK: %accumulator.ret.tr = add nsw i64 %accumulator.tr, %n.tr +; CHECK: ret i64 %accumulator.ret.tr + +define i32 @test4_base_case_call() local_unnamed_addr { +entry: + %base = call i32 @test4_helper() + switch i32 %base, label %sw.default [ + i32 1, label %cleanup + i32 5, label %cleanup + i32 7, label %cleanup + ] + +sw.default: + %recurse = call i32 @test4_base_case_call() + %accumulate = add nsw i32 %recurse, 1 + br label %cleanup + +cleanup: + %retval.0 = phi i32 [ %accumulate, %sw.default ], [ %base, %entry ], [ %base, %entry ], [ %base, %entry ] + ret i32 %retval.0 +} + +declare i32 @test4_helper() + +; CHECK-LABEL: define i32 @test4_base_case_call( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %sw.default ] +; CHECK: sw.default: +; CHECK-NOT: %recurse +; CHECK: %accumulate = add nsw i32 %accumulator.tr, 1 +; CHECK: cleanup: +; CHECK: %accumulator.ret.tr = add nsw i32 %accumulator.tr, %base +; CHECK: ret i32 %accumulator.ret.tr + +define i32 @test5_base_case_load(i32* nocapture %A, i32 %n) local_unnamed_addr { +entry: + %cmp = icmp eq i32 %n, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: + %base = load 
i32, i32* %A, align 4 + ret i32 %base + +if.end: + %idxprom = zext i32 %n to i64 + %arrayidx1 = getelementptr inbounds i32, i32* %A, i64 %idxprom + %load = load i32, i32* %arrayidx1, align 4 + %sub = add i32 %n, -1 + %recurse = tail call i32 @test5_base_case_load(i32* %A, i32 %sub) + %accumulate = add i32 %recurse, %load + ret i32 %accumulate +} + +; CHECK-LABEL: define i32 @test5_base_case_load( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %if.end ] +; CHECK: if.then: +; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, %base +; CHECK: ret i32 %accumulator.ret.tr +; CHECK: if.end: +; CHECK-NOT: %recurse +; CHECK: %accumulate = add i32 %accumulator.tr, %load + +define i32 @test6_multiple_returns(i32 %x, i32 %y) local_unnamed_addr { +entry: + switch i32 %x, label %default [ + i32 0, label %case0 + i32 99, label %case99 + ] + +case0: + %helper = call i32 @test6_helper() + ret i32 %helper + +case99: + %sub1 = add i32 %x, -1 + %recurse1 = call i32 @test6_multiple_returns(i32 %sub1, i32 %y) + ret i32 18 + +default: + %sub2 = add i32 %x, -1 + %recurse2 = call i32 @test6_multiple_returns(i32 %sub2, i32 %y) + %accumulate = add i32 %recurse2, %y + ret i32 %accumulate +} + +declare i32 @test6_helper() + +; CHECK-LABEL: define i32 @test6_multiple_returns( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ %accumulator.tr, %case99 ], [ 0, %entry ], [ %accumulate, %default ] +; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %case99 ], [ %ret.tr, %default ] +; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %case99 ], [ %ret.known.tr, %default ] +; CHECK: case0: +; CHECK: %accumulator.ret.tr2 = add i32 %accumulator.tr, %helper +; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %accumulator.ret.tr2 +; CHECK: case99: +; CHECK-NOT: %recurse +; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, 18 +; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 
%accumulator.ret.tr +; CHECK: default: +; CHECK-NOT: %recurse +; CHECK: %accumulate = add i32 %accumulator.tr, %y + +; It is only safe to transform one accumulator per function, make sure we don't +; try to remove more. + +define i32 @test7_multiple_accumulators(i32 %a) local_unnamed_addr { +entry: + %tobool = icmp eq i32 %a, 0 + br i1 %tobool, label %return, label %if.end + +if.end: + %and = and i32 %a, 1 + %tobool1 = icmp eq i32 %and, 0 + %sub = add nsw i32 %a, -1 + br i1 %tobool1, label %if.end3, label %if.then2 + +if.then2: + %recurse1 = tail call i32 @test7_multiple_accumulators(i32 %sub) + %accumulate1 = add nsw i32 %recurse1, 1 + br label %return + +if.end3: + %recurse2 = tail call i32 @test7_multiple_accumulators(i32 %sub) + %accumulate2 = mul nsw i32 %recurse2, 2 + br label %return + +return: + %retval.0 = phi i32 [ %accumulate1, %if.then2 ], [ %accumulate2, %if.end3 ], [ 0, %entry ] + ret i32 %retval.0 +} + +; CHECK-LABEL: define i32 @test7_multiple_accumulators( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate1, %if.then2 ] +; CHECK: if.then2: +; CHECK-NOT: %recurse1 +; CHECK: %accumulate1 = add nsw i32 %accumulator.tr, 1 +; CHECK: if.end3: +; CHECK: %recurse2 +; CHECK: %accumulator.ret.tr = add nsw i32 %accumulator.tr, %accumulate2 +; CHECK: ret i32 %accumulator.ret.tr +; CHECK: return: +; CHECK: %accumulator.ret.tr1 = add nsw i32 %accumulator.tr, 0 +; CHECK: ret i32 %accumulator.ret.tr1