diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h --- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h +++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h @@ -171,6 +171,7 @@ SmallPtrSet Specializations; SmallPtrSet FullySpecialized; DenseMap FunctionMetrics; + DenseMap NumSpecs; public: FunctionSpecializer( diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -74,6 +74,11 @@ "Force function specialization for every call site with a constant " "argument")); +static cl::opt MaxTotalClones( + "funcspec-max-total-clones", cl::init(8), cl::Hidden, cl::desc( + "The maximum number of clones allowed for a single function " + "specialization across multiple iterations")); + static cl::opt MaxClones( "funcspec-max-clones", cl::init(3), cl::Hidden, cl::desc( "The maximum number of clones allowed for a single function " @@ -695,13 +700,13 @@ AllSpecs[Index].CallSites.push_back(&CS); } else { // Calculate the specialisation gain. - Cost Score = 0 - SpecCost; + Cost Score = 0; InstCostVisitor Visitor = getInstCostVisitorFor(F); for (ArgInfo &A : S.Args) Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); // Discard unprofitable specialisations. - if (!ForceSpecialization && Score <= 0) + if (!ForceSpecialization && Score <= SpecCost) continue; // Create a new specialisation entry. @@ -766,6 +771,11 @@ // Mark all the specialized functions Specializations.insert(Clone); + + // Update the cost model. + for (const ArgInfo &A : S.Args) + ++NumSpecs[A.Formal]; + ++NumSpecsCreated; return Clone; @@ -774,6 +784,10 @@ /// Compute a bonus for replacing argument \p A with constant \p C. 
Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, InstCostVisitor &Visitor) { + // No bonus once the clone budget is spent (count may overshoot the cap). + if (NumSpecs.lookup(A) >= MaxTotalClones) + return 0; + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " << C->getNameOrAsOperand() << "\n"); diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -42,7 +42,7 @@ "Number of instructions replaced with (simpler) instruction"); static cl::opt FuncSpecMaxIters( - "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc( + "funcspec-max-iters", cl::init(10), cl::Hidden, cl::desc( "The maximum number of iterations function specialization is run")); static void findReturnsToZap(Function &F, diff --git a/llvm/test/Transforms/FunctionSpecialization/recursive-penalty.ll b/llvm/test/Transforms/FunctionSpecialization/recursive-penalty.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/recursive-penalty.ll @@ -0,0 +1,64 @@ +; REQUIRES: asserts +; RUN: opt -passes="ipsccp,inline,instcombine,simplifycfg" -S \ +; RUN: -funcspec-min-function-size=23 -funcspec-max-iters=100 \ +; RUN: -debug-only=function-specialization < %s 2>&1 | FileCheck %s + +; Make sure the number of specializations created does not +; grow linearly with the number of iterations (funcspec-max-iters). 
+ +; CHECK: FnSpecialization: Created 8 specializations in module + +@Global = internal constant i32 1, align 4 + +define internal void @recursiveFunc(ptr readonly %arg) { + %temp = alloca i32, align 4 + %arg.load = load i32, ptr %arg, align 4 + %arg.cmp = icmp slt i32 %arg.load, 10000 + br i1 %arg.cmp, label %loop1, label %ret.block + +loop1: + br label %loop2 + +loop2: + br label %loop3 + +loop3: + br label %loop4 + +loop4: + br label %block6 + +block6: + call void @print_val(i32 %arg.load) + %arg.add = add nsw i32 %arg.load, 1 + store i32 %arg.add, ptr %temp, align 4 + call void @recursiveFunc(ptr %temp) + br label %loop4.end + +loop4.end: + %exit_cond1 = call i1 @exit_cond() + br i1 %exit_cond1, label %loop4, label %loop3.end + +loop3.end: + %exit_cond2 = call i1 @exit_cond() + br i1 %exit_cond2, label %loop3, label %loop2.end + +loop2.end: + %exit_cond3 = call i1 @exit_cond() + br i1 %exit_cond3, label %loop2, label %loop1.end + +loop1.end: + %exit_cond4 = call i1 @exit_cond() + br i1 %exit_cond4, label %loop1, label %ret.block + +ret.block: + ret void +} + +define i32 @main() { + call void @recursiveFunc(ptr @Global) + ret i32 0 +} + +declare dso_local void @print_val(i32) +declare dso_local i1 @exit_cond()