diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -206,10 +206,9 @@
   /// is a function argument.
   Constant *getConstantStackValue(CallInst *Call, Value *Val);
 
-  /// Iterate over the argument tracked functions see if there
-  /// are any new constant values for the call instruction via
-  /// stack variables.
-  void promoteConstantStackValues();
+  /// See if there are any new constant values for the callers of \p F via
+  /// stack variables and promote them to global variables.
+  void promoteConstantStackValues(Function *F);
 
   /// Clean up fully specialized functions.
   void removeDeadFunctions();
@@ -217,9 +216,6 @@
   /// Remove any ssa_copy intrinsics that may have been introduced.
   void cleanUpSSA();
 
-  // Compute the code metrics for function \p F.
-  CodeMetrics &analyzeFunction(Function *F);
-
   /// @brief Find potential specialization opportunities.
   /// @param F Function to specialize
   /// @param SpecCost Cost of specializing a function. Final score is benefit
@@ -238,9 +234,6 @@
   /// @return The new, cloned function
   Function *createSpecialization(Function *F, const SpecSig &S);
 
-  /// Compute and return the cost of specializing function \p F.
-  Cost getSpecializationCost(Function *F);
-
   /// Determine if it is possible to specialise the function for constant values
   /// of the formal parameter \p A.
   bool isArgumentInteresting(Argument *A);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -397,49 +397,37 @@
 // ret void
 // }
 //
-void FunctionSpecializer::promoteConstantStackValues() {
-  // Iterate over the argument tracked functions see if there
-  // are any new constant values for the call instruction via
-  // stack variables.
-  for (Function &F : M) {
-    if (!Solver.isArgumentTrackedFunction(&F))
+// See if there are any new constant values for the callers of \p F via
+// stack variables and promote them to global variables.
+void FunctionSpecializer::promoteConstantStackValues(Function *F) {
+  for (User *U : F->users()) {
+
+    auto *Call = dyn_cast<CallInst>(U);
+    if (!Call)
       continue;
 
-    for (auto *User : F.users()) {
+    if (!Solver.isBlockExecutable(Call->getParent()))
+      continue;
 
-      auto *Call = dyn_cast<CallInst>(User);
-      if (!Call)
-        continue;
+    for (const Use &U : Call->args()) {
+      unsigned Idx = Call->getArgOperandNo(&U);
+      Value *ArgOp = Call->getArgOperand(Idx);
+      Type *ArgOpType = ArgOp->getType();
 
-      if (!Solver.isBlockExecutable(Call->getParent()))
+      if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
         continue;
 
-      bool Changed = false;
-      for (const Use &U : Call->args()) {
-        unsigned Idx = Call->getArgOperandNo(&U);
-        Value *ArgOp = Call->getArgOperand(Idx);
-        Type *ArgOpType = ArgOp->getType();
-
-        if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
-          continue;
-
-        auto *ConstVal = getConstantStackValue(Call, ArgOp);
-        if (!ConstVal)
-          continue;
-
-        Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
-                                       GlobalValue::InternalLinkage, ConstVal,
-                                       "funcspec.arg");
-        if (ArgOpType != ConstVal->getType())
-          GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
+      auto *ConstVal = getConstantStackValue(Call, ArgOp);
+      if (!ConstVal)
+        continue;
 
-        Call->setArgOperand(Idx, GV);
-        Changed = true;
-      }
+      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                     GlobalValue::InternalLinkage, ConstVal,
+                                     "funcspec.arg");
+      if (ArgOpType != ConstVal->getType())
+        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
 
-      // Add the changed CallInst to Solver Worklist
-      if (Changed)
-        Solver.visitCall(*Call);
+      Call->setArgOperand(Idx, GV);
     }
   }
 }
@@ -504,17 +492,37 @@
     if (!isCandidateFunction(&F))
       continue;
 
-    Cost SpecCost = getSpecializationCost(&F);
-    if (!SpecCost.isValid()) {
-      LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost for "
-                        << F.getName() << "\n");
-      continue;
+    auto [It, Inserted] = FunctionMetrics.try_emplace(&F);
+    CodeMetrics &Metrics = It->second;
+    // Analyze the function.
+    if (Inserted) {
+      SmallPtrSet<const Value *, 32> EphValues;
+      CodeMetrics::collectEphemeralValues(&F, &GetAC(F), EphValues);
+      for (BasicBlock &BB : F)
+        Metrics.analyzeBasicBlock(&BB, GetTTI(F), EphValues);
     }
 
+    // If the code metrics reveal that we shouldn't duplicate the function,
+    // or if the code size implies that this function is easy to get inlined,
+    // then we shouldn't specialize it.
+    if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
+        (!ForceSpecialization && !F.hasFnAttribute(Attribute::NoInline) &&
+         Metrics.NumInsts < MinFunctionSize))
+      continue;
+
+    // TODO: For now only consider recursive functions when running multiple
+    // times. This should change if specialization on literal constants gets
+    // enabled.
+    if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant)
+      continue;
+
     LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
-                      << F.getName() << " is " << SpecCost << "\n");
+                      << F.getName() << " is " << Metrics.NumInsts << "\n");
+
+    if (Inserted && Metrics.isRecursive)
+      promoteConstantStackValues(&F);
 
-    if (!findSpecializations(&F, SpecCost, AllSpecs, SM)) {
+    if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) {
       LLVM_DEBUG(
           dbgs() << "FnSpecialization: No possible specializations found for "
                  << F.getName() << "\n");
@@ -622,7 +630,10 @@
   // Rerun the solver to notify the users of the modified callsites.
   Solver.solveWhileResolvedUndefs();
 
-  promoteConstantStackValues();
+  for (Function *F : OriginalFuncs)
+    if (FunctionMetrics[F].isRecursive)
+      promoteConstantStackValues(F);
+
   return true;
 }
 
@@ -637,20 +648,6 @@
   FullySpecialized.clear();
 }
 
-// Compute the code metrics for function \p F.
-CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) {
-  auto I = FunctionMetrics.insert({F, CodeMetrics()});
-  CodeMetrics &Metrics = I.first->second;
-  if (I.second) {
-    // The code metrics were not cached.
-    SmallPtrSet<const Value *, 32> EphValues;
-    CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
-    for (BasicBlock &BB : *F)
-      Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
-  }
-  return Metrics;
-}
-
 /// Clone the function \p F and remove the ssa_copy intrinsics added by
 /// the SCCPSolver in the cloned version.
 static Function *cloneCandidateFunction(Function *F) {
@@ -802,23 +799,6 @@
   return Clone;
 }
 
-/// Compute and return the cost of specializing function \p F.
-Cost FunctionSpecializer::getSpecializationCost(Function *F) {
-  CodeMetrics &Metrics = analyzeFunction(F);
-  // If the code metrics reveal that we shouldn't duplicate the function, we
-  // shouldn't specialize it. Set the specialization cost to Invalid.
-  // Or if the lines of codes implies that this function is easy to get
-  // inlined so that we shouldn't specialize it.
-  if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
-      (!ForceSpecialization && !F->hasFnAttribute(Attribute::NoInline) &&
-       Metrics.NumInsts < MinFunctionSize))
-    return InstructionCost::getInvalid();
-
-  // Otherwise, set the specialization cost to be the cost of all the
-  // instructions in the function.
-  return Metrics.NumInsts;
-}
-
 /// Compute a bonus for replacing argument \p A with constant \p C.
 Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
                                                  InstCostVisitor &Visitor) {
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
deleted file mode 100644
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
+++ /dev/null
@@ -1,59 +0,0 @@
-; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=2 -S < %s | FileCheck %s --check-prefix=ITERS2
-; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=3 -S < %s | FileCheck %s --check-prefix=ITERS3
-; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=4 -S < %s | FileCheck %s --check-prefix=ITERS4
-
-@low = internal constant i32 0, align 4
-@high = internal constant i32 6, align 4
-
-define internal void @recursiveFunc(ptr nocapture readonly %lo, i32 %step, ptr nocapture readonly %hi) {
-  %lo.temp = alloca i32, align 4
-  %hi.temp = alloca i32, align 4
-  %lo.load = load i32, ptr %lo, align 4
-  %hi.load = load i32, ptr %hi, align 4
-  %cmp = icmp ne i32 %lo.load, %hi.load
-  br i1 %cmp, label %block6, label %ret.block
-
-block6:
-  call void @print_val(i32 %lo.load, i32 %hi.load)
-  %add = add nsw i32 %lo.load, %step
-  %sub = sub nsw i32 %hi.load, %step
-  store i32 %add, ptr %lo.temp, align 4
-  store i32 %sub, ptr %hi.temp, align 4
-  call void @recursiveFunc(ptr nonnull %lo.temp, i32 %step, ptr nonnull %hi.temp)
-  br label %ret.block
-
-ret.block:
-  ret void
-}
-
-; ITERS2: @funcspec.arg.4 = internal constant i32 2
-; ITERS2: @funcspec.arg.5 = internal constant i32 4
-
-; ITERS3: @funcspec.arg.7 = internal constant i32 3
-; ITERS3: @funcspec.arg.8 = internal constant i32 3
-
-define i32 @main() {
-; ITERS2-LABEL: @main(
-; ITERS2-NEXT:    call void @print_val(i32 0, i32 6)
-; ITERS2-NEXT:    call void @print_val(i32 1, i32 5)
-; ITERS2-NEXT:    call void @recursiveFunc(ptr nonnull @funcspec.arg.4, i32 1, ptr nonnull @funcspec.arg.5)
-; ITERS2-NEXT:    ret i32 0
-;
-; ITERS3-LABEL: @main(
-; ITERS3-NEXT:    call void @print_val(i32 0, i32 6)
-; ITERS3-NEXT:    call void @print_val(i32 1, i32 5)
-; ITERS3-NEXT:    call void @print_val(i32 2, i32 4)
-; ITERS3-NEXT:    call void @recursiveFunc(ptr nonnull @funcspec.arg.7, i32 1, ptr nonnull @funcspec.arg.8)
-; ITERS3-NEXT:    ret i32 0
-;
-; ITERS4-LABEL: @main(
-; ITERS4-NEXT:    call void @print_val(i32 0, i32 6)
-; ITERS4-NEXT:    call void @print_val(i32 1, i32 5)
-; ITERS4-NEXT:    call void @print_val(i32 2, i32 4)
-; ITERS4-NEXT:    ret i32 0
-;
-  call void @recursiveFunc(ptr nonnull @low, i32 1, ptr nonnull @high)
-  ret i32 0
-}
-
-declare dso_local void @print_val(i32, i32)
diff --git a/llvm/test/Transforms/FunctionSpecialization/promoteContantStackValues.ll b/llvm/test/Transforms/FunctionSpecialization/promoteContantStackValues.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/promoteContantStackValues.ll
@@ -0,0 +1,95 @@
+; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=1 -S < %s | FileCheck %s --check-prefix=ITERS1
+; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=2 -S < %s | FileCheck %s --check-prefix=ITERS2
+; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=3 -S < %s | FileCheck %s --check-prefix=ITERS3
+; RUN: opt -passes="ipsccp,inline,instcombine" -force-specialization -funcspec-max-iters=4 -S < %s | FileCheck %s --check-prefix=ITERS4
+
+@low = internal constant i32 0, align 4
+@high = internal constant i32 6, align 4
+
+define internal void @recursiveFunc(ptr nocapture readonly %lo, i32 %step, ptr nocapture readonly %hi) {
+  %lo.temp = alloca i32, align 4
+  %hi.temp = alloca i32, align 4
+  %lo.load = load i32, ptr %lo, align 4
+  %hi.load = load i32, ptr %hi, align 4
+  %cmp = icmp ne i32 %lo.load, %hi.load
+  br i1 %cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %lo.load, i32 %hi.load)
+  %add = add nsw i32 %lo.load, %step
+  %sub = sub nsw i32 %hi.load, %step
+  store i32 %add, ptr %lo.temp, align 4
+  store i32 %sub, ptr %hi.temp, align 4
+  call void @recursiveFunc(ptr nonnull %lo.temp, i32 %step, ptr nonnull %hi.temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+; ITERS1: @funcspec.arg = internal constant i32 0
+; ITERS1: @funcspec.arg.1 = internal constant i32 6
+; ITERS1: @funcspec.arg.3 = internal constant i32 1
+; ITERS1: @funcspec.arg.4 = internal constant i32 5
+
+; ITERS2: @funcspec.arg = internal constant i32 0
+; ITERS2: @funcspec.arg.1 = internal constant i32 6
+; ITERS2: @funcspec.arg.3 = internal constant i32 1
+; ITERS2: @funcspec.arg.4 = internal constant i32 5
+; ITERS2: @funcspec.arg.6 = internal constant i32 2
+; ITERS2: @funcspec.arg.7 = internal constant i32 4
+
+; ITERS3: @funcspec.arg = internal constant i32 0
+; ITERS3: @funcspec.arg.1 = internal constant i32 6
+; ITERS3: @funcspec.arg.3 = internal constant i32 1
+; ITERS3: @funcspec.arg.4 = internal constant i32 5
+; ITERS3: @funcspec.arg.6 = internal constant i32 2
+; ITERS3: @funcspec.arg.7 = internal constant i32 4
+; ITERS3: @funcspec.arg.9 = internal constant i32 3
+; ITERS3: @funcspec.arg.10 = internal constant i32 3
+
+; ITERS4: @funcspec.arg = internal constant i32 0
+; ITERS4: @funcspec.arg.1 = internal constant i32 6
+; ITERS4: @funcspec.arg.3 = internal constant i32 1
+; ITERS4: @funcspec.arg.4 = internal constant i32 5
+; ITERS4: @funcspec.arg.6 = internal constant i32 2
+; ITERS4: @funcspec.arg.7 = internal constant i32 4
+; ITERS4: @funcspec.arg.9 = internal constant i32 3
+; ITERS4: @funcspec.arg.10 = internal constant i32 3
+
+define i32 @main() {
+; ITERS1-LABEL: @main(
+; ITERS1-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS1-NEXT:    call void @recursiveFunc(ptr nonnull @funcspec.arg.3, i32 1, ptr nonnull @funcspec.arg.4)
+; ITERS1-NEXT:    ret i32 0
+;
+; ITERS2-LABEL: @main(
+; ITERS2-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS2-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS2-NEXT:    call void @recursiveFunc(ptr nonnull @funcspec.arg.6, i32 1, ptr nonnull @funcspec.arg.7)
+; ITERS2-NEXT:    ret i32 0
+;
+; ITERS3-LABEL: @main(
+; ITERS3-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS3-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS3-NEXT:    call void @print_val(i32 2, i32 4)
+; ITERS3-NEXT:    call void @recursiveFunc(ptr nonnull @funcspec.arg.9, i32 1, ptr nonnull @funcspec.arg.10)
+; ITERS3-NEXT:    ret i32 0
+;
+; ITERS4-LABEL: @main(
+; ITERS4-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS4-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS4-NEXT:    call void @print_val(i32 2, i32 4)
+; ITERS4-NEXT:    ret i32 0
+;
+  %lo.temp = alloca i32, align 4
+  %hi.temp = alloca i32, align 4
+  %lo.load = load i32, ptr @low, align 4
+  %hi.load = load i32, ptr @high, align 4
+  store i32 %lo.load, ptr %lo.temp, align 4
+  store i32 %hi.load, ptr %hi.temp, align 4
+  call void @recursiveFunc(ptr nonnull %lo.temp, i32 1, ptr nonnull %hi.temp)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32, i32)