Index: llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp =================================================================== --- llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -360,8 +360,10 @@ /// /// We currently handle static constants and arguments that are not modified as /// part of the recursion. -static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { - if (isa<Constant>(V)) return true; // Static constants are always dyn consts +static Value *getDynamicConstantOrReplacement(Value *V, CallInst *CI, + ReturnInst *RI) { + if (isa<Constant>(V)) + return V; // Static constants are always dyn consts // Check to see if this is an immutable argument, if so, the value // will be available to initialize the accumulator. @@ -376,7 +378,7 @@ // argument operand, then the argument is dynamically constant. // Otherwise, we cannot transform this function safely. if (CI->getArgOperand(ArgNo) == Arg) - return true; + return Arg; } // Switch cases are always constant integers. If the value is being switched @@ -385,10 +387,10 @@ if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor()) if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator())) if (SI->getCondition() == V) - return SI->getDefaultDest() != RI->getParent(); + return SI->findCaseDest(RI->getParent()); // Not a constant or immutable argument, we can't safely transform. - return false; + return nullptr; } /// Check to see if the function containing the specified tail call consistently @@ -406,13 +408,13 @@ // evaluatable at the start of the initial invocation of the function, // instead of at the end of the evaluation. 
// - Value *RetOp = RI->getOperand(0); - if (!isDynamicConstant(RetOp, CI, RI)) + Value *V = getDynamicConstantOrReplacement(RI->getReturnValue(), CI, RI); + if (!V) return nullptr; - if (ReturnedValue && RetOp != ReturnedValue) + if (ReturnedValue && V != ReturnedValue) return nullptr; // Cannot transform if differing values are returned. - ReturnedValue = RetOp; + ReturnedValue = V; } return ReturnedValue; } @@ -505,15 +507,21 @@ // the call instruction that are both associative and commutative, the initial // value for the accumulator is placed in this variable. If this value is set // then we actually perform accumulator recursion elimination instead of - // simple tail recursion elimination. If the operation is an LLVM instruction - // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then - // we are handling the case when the return instruction returns a constant C - // which is different to the constant returned by other return instructions - // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a - // special case of accumulator recursion, the operation being "return C". + // simple tail recursion elimination. Value *AccumulatorRecursionEliminationInitVal = nullptr; + + // If the operation is an LLVM instruction (e.g. "add") then it is recorded in + // AccumulatorRecursionInstr. Instruction *AccumulatorRecursionInstr = nullptr; + // If not, then we are handling the case when the return instruction returns + // a constant C which is different to the constant returned by other return + // instructions (which is recorded in AccumulatorRecursionEliminationInitVal). + // This is a special case of accumulator recursion, the operation being + // "return C". In this case store the returned constant C in + // AccumulatorRecursionEliminationReturnedConstant. + Value *AccumulatorRecursionEliminationReturnedConstant = nullptr; + // Ok, we found a potential tail call. 
We can currently only transform the // tail call if all of the instructions between the call and the return are // movable to above the call itself, leaving the call next to the return. @@ -548,8 +556,12 @@ // One case remains that we are able to handle: the current return // instruction returns a constant, and all other return instructions // return a different constant. - if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret)) - return false; // Current return instruction does not return a constant. + AccumulatorRecursionEliminationReturnedConstant = + getDynamicConstantOrReplacement(Ret->getReturnValue(), CI, Ret); + + if (!AccumulatorRecursionEliminationReturnedConstant) + return false; + // Check that all other return instructions return a common constant. If // so, record it in AccumulatorRecursionEliminationInitVal. AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI); @@ -660,7 +672,7 @@ } else { // Add an incoming argument for the current block, which is just the // constant returned by the current return instruction. 
- AccPN->addIncoming(Ret->getReturnValue(), BB); + AccPN->addIncoming(AccumulatorRecursionEliminationReturnedConstant, BB); } // Finally, rewrite any return instructions in the program to return the PHI Index: llvm/test/Transforms/TailCallElim/accum_recursion.ll =================================================================== --- llvm/test/Transforms/TailCallElim/accum_recursion.ll +++ llvm/test/Transforms/TailCallElim/accum_recursion.ll @@ -40,36 +40,28 @@ ; CHECK-NOT: call i32 ; CHECK: return: - -define i64 @test3_fib(i64 %n) nounwind readnone { -; CHECK-LABEL: @test3_fib( +define i32 @test3_switch() local_unnamed_addr { entry: -; CHECK: tailrecurse: -; CHECK: %accumulator.tr = phi i64 [ %n, %entry ], [ %3, %bb1 ] -; CHECK: %n.tr = phi i64 [ %n, %entry ], [ %2, %bb1 ] - switch i64 %n, label %bb1 [ -; CHECK: switch i64 %n.tr, label %bb1 [ - i64 0, label %bb2 - i64 1, label %bb2 + %call = call i32 @test3_call() + switch i32 %call, label %sw.default [ + i32 1, label %cleanup ] -bb1: -; CHECK: bb1: - %0 = add i64 %n, -1 -; CHECK: %0 = add i64 %n.tr, -1 - %1 = tail call i64 @test3_fib(i64 %0) nounwind -; CHECK: %1 = tail call i64 @test3_fib(i64 %0) - %2 = add i64 %n, -2 -; CHECK: %2 = add i64 %n.tr, -2 - %3 = tail call i64 @test3_fib(i64 %2) nounwind -; CHECK-NOT: tail call i64 @test3_fib - %4 = add nsw i64 %3, %1 -; CHECK: add nsw i64 %accumulator.tr, %1 - ret i64 %4 -; CHECK: br label %tailrecurse +sw.default: + %call1 = call i32 @test3_switch() + %add = add nsw i32 %call1, 1 + br label %cleanup -bb2: -; CHECK: bb2: - ret i64 %n -; CHECK: ret i64 %accumulator.tr +cleanup: + %retval.0 = phi i32 [ %add, %sw.default ], [ %call, %entry ] + ret i32 %retval.0 } + +declare i32 @test3_call() + +; CHECK-LABEL: define i32 @test3_switch( +; CHECK: tailrecurse: +; CHECK: %accumulator.tr = phi i32 [ 1, %entry ], +; CHECK: sw.default: +; CHECK-NOT: call i32 +; CHECK: cleanup: