diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -11,7 +11,6 @@
 // are propagated to the callee by specializing the function.
 //
 // Current limitations:
-// - It does not handle specialization of recursive functions,
 // - It does not yet handle integer ranges.
 // - Only 1 argument per function is specialised,
 // - The cost-model could be further looked into,
@@ -68,9 +67,149 @@
     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
     cl::desc("Make function specialization available for literal constant."));
 
+// Helper to check if \p LV is either a constant or a constant
+// range with a single element. This should cover exactly the same cases as the
+// old ValueLatticeElement::isConstant() and is intended to be used in the
+// transition to ValueLatticeElement.
+static bool isConstant(const ValueLatticeElement &LV) {
+  return LV.isConstant() ||
+         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
+}
+
 // Helper to check if \p LV is either overdefined or a constant int.
 static bool isOverdefined(const ValueLatticeElement &LV) {
-  return !LV.isUnknownOrUndef() && !LV.isConstant();
+  return !LV.isUnknownOrUndef() && !isConstant(LV);
+}
+
+// Return the constant value of \p Alloca if it is written by exactly one
+// non-volatile store of a constant and is otherwise only used by \p Call,
+// either directly or through a bitcast whose single user is \p Call.
+// Return nullptr in all other cases.
+static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
+  Value *StoreValue = nullptr;
+  for (auto *User : Alloca->users()) {
+    // We can't use llvm::isAllocaPromotable() as that would fail because of
+    // the usage in the CallInst, which is what we check here.
+    if (User == Call)
+      continue;
+    if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
+      if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
+        return nullptr;
+      continue;
+    }
+
+    if (auto *Store = dyn_cast<StoreInst>(User)) {
+      // Bail out on a duplicate or a volatile store.
+      if (StoreValue || Store->isVolatile())
+        return nullptr;
+      StoreValue = Store->getValueOperand();
+      continue;
+    }
+    // Bail if there is any other unknown usage.
+    return nullptr;
+  }
+  return dyn_cast_or_null<Constant>(StoreValue);
+}
+
+// A constant stack value is an integer AllocaInst that has a single constant
+// value stored to it. Return that constant if \p Val (looking through pointer
+// casts) is such an alloca, or \p Val itself if it is a ConstantInt.
+static Constant *getConstantStackValue(CallInst *Call, Value *Val,
+                                       SCCPSolver &Solver) {
+  if (!Val)
+    return nullptr;
+  Val = Val->stripPointerCasts();
+  if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
+    return ConstVal;
+  auto *Alloca = dyn_cast<AllocaInst>(Val);
+  if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
+    return nullptr;
+  return getPromotableAlloca(Alloca, Call);
+}
+
+// To support specializing recursive functions, it is important to propagate
+// constant arguments because after a first iteration of specialisation, a
+// reduced example may look like this:
+//
+//     define internal void @RecursiveFn(i32* arg1) {
+//       %temp = alloca i32, align 4
+//       store i32 2, i32* %temp, align 4
+//       call void @RecursiveFn.1(i32* nonnull %temp)
+//       ret void
+//     }
+//
+// Before a next iteration, we need to propagate the constant like so
+// which allows further specialization in next iterations.
+//
+//     @funcspec.arg = internal constant i32 2
+//
+//     define internal void @RecursiveFn(i32* arg1) {
+//       call void @RecursiveFn.1(i32* nonnull @funcspec.arg)
+//       ret void
+//     }
+//
+static void constantArgPropagation(SmallVectorImpl<Function *> &WorkList,
+                                   Module &M, SCCPSolver &Solver) {
+  // Iterate over the argument-tracked functions and see if there
+  // are any new constant values for the call instruction via
+  // stack variables.
+  for (auto *F : WorkList) {
+    // TODO: Generalize for any read only arguments.
+    if (F->arg_size() != 1)
+      continue;
+
+    auto &Arg = *F->arg_begin();
+    if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
+      continue;
+
+    for (auto *User : F->users()) {
+      // Skip, rather than stop at, users that are not calls or that do not
+      // pass a constant stack value: later call sites may still qualify.
+      auto *Call = dyn_cast<CallInst>(User);
+      if (!Call)
+        continue;
+      auto *ArgOp = Call->getArgOperand(0);
+      auto *ArgOpType = ArgOp->getType();
+      auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
+      if (!ConstVal)
+        continue;
+
+      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                     GlobalValue::InternalLinkage, ConstVal,
+                                     "funcspec.arg");
+
+      // Bitcast the new global if its pointer type differs from ArgOpType.
+      if (ArgOpType != GV->getType())
+        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
+
+      Call->setArgOperand(0, GV);
+
+      // Add the changed CallInst to Solver Worklist
+      Solver.visitCall(*Call);
+    }
+  }
+}
+
+// ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
+// interfere with the constantArgPropagation optimization.
+static void removeSSACopy(Function &F) {
+  for (BasicBlock &BB : F) {
+    for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
+      Instruction *Inst = &*BI++;
+      auto *II = dyn_cast<IntrinsicInst>(Inst);
+      if (!II)
+        continue;
+      if (II->getIntrinsicID() != Intrinsic::ssa_copy)
+        continue;
+      Inst->replaceAllUsesWith(II->getOperand(0));
+      Inst->eraseFromParent();
+    }
+  }
+}
+
+static void removeSSACopy(Module &M) {
+  for (Function &F : M)
+    removeSSACopy(F);
 }
 
 class FunctionSpecializer {
@@ -115,9 +254,14 @@
     for (auto *SpecializedFunc : CurrentSpecializations) {
       SpecializedFuncs.insert(SpecializedFunc);
 
-      // TODO: If we want to support specializing specialized functions,
-      // initialize here the state of the newly created functions, marking
-      // them argument-tracked and executable.
+      // Initialize the state of the newly created functions, marking them
+      // argument-tracked and executable.
+      if (SpecializedFunc->hasExactDefinition() &&
+          !SpecializedFunc->hasFnAttribute(Attribute::Naked))
+        Solver.addTrackedFunction(SpecializedFunc);
+      Solver.addArgumentTrackedFunction(SpecializedFunc);
+      FuncDecls.push_back(SpecializedFunc);
+      Solver.markBlockExecutable(&SpecializedFunc->front());
 
       // Replace the function arguments for the specialized functions.
       for (Argument &Arg : SpecializedFunc->args())
@@ -138,12 +282,22 @@
     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
     if (isOverdefined(IV))
       return false;
-    auto *Const = IV.isConstant() ? Solver.getConstant(IV)
-                                  : UndefValue::get(V->getType());
+    auto *Const =
+        isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
     V->replaceAllUsesWith(Const);
 
-    // TODO: Update the solver here if we want to specialize specialized
-    // functions.
+    for (auto *U : Const->users())
+      if (auto *I = dyn_cast<Instruction>(U))
+        if (Solver.isBlockExecutable(I->getParent()))
+          Solver.visit(I);
+
+    // Forget the lattice value before erasing, so I is never used once freed.
+    if (auto *I = dyn_cast<Instruction>(V)) {
+      if (I->isSafeToRemove()) {
+        Solver.removeLatticeValueFor(I);
+        I->eraseFromParent();
+      }
+    }
     return true;
   }
 
@@ -152,6 +306,15 @@
   // also in the cost model.
   unsigned NbFunctionsSpecialized = 0;
 
+  /// Clone the function \p F and remove the ssa_copy intrinsics added by
+  /// the SCCPSolver in the cloned version.
+  Function *cloneCandidateFunction(Function *F) {
+    ValueToValueMapTy EmptyMap;
+    Function *Clone = CloneFunction(F, EmptyMap);
+    removeSSACopy(*Clone);
+    return Clone;
+  }
+
   /// This function decides whether to specialize function \p F based on the
   /// known constant values its arguments can take on. Specialization is
   /// performed on the first interesting argument. Specializations based on
@@ -214,8 +377,7 @@
     for (auto *C : Constants) {
       // Clone the function. We leave the ValueToValueMap empty to allow
       // IPSCCP to propagate the constant arguments.
-      ValueToValueMapTy EmptyMap;
-      Function *Clone = CloneFunction(F, EmptyMap);
+      Function *Clone = cloneCandidateFunction(F);
       Argument *ClonedArg = Clone->arg_begin() + A.getArgNo();
 
       // Rewrite calls to the function so that they call the clone instead.
@@ -231,9 +393,10 @@
       NbFunctionsSpecialized++;
     }
 
-    // TODO: if we want to support specialize specialized functions, and if
-    // the function has been completely specialized, the original function is
-    // no longer needed, so we would need to mark it unreachable here.
+    // If the function has been completely specialized, the original function
+    // is no longer needed. Mark it unreachable.
+    if (!IsPartial)
+      Solver.markFunctionUnreachable(F);
 
     // FIXME: Only one argument per function.
     return true;
@@ -528,24 +691,6 @@
   }
 };
 
-/// Function to clean up the left over intrinsics from SCCP util.
-static void cleanup(Module &M) {
-  for (Function &F : M) {
-    for (BasicBlock &BB : F) {
-      for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
-        Instruction *Inst = &*BI++;
-        if (auto *II = dyn_cast<IntrinsicInst>(Inst)) {
-          if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
-            Value *Op = II->getOperand(0);
-            Inst->replaceAllUsesWith(Op);
-            Inst->eraseFromParent();
-          }
-        }
-      }
-    }
-  }
-}
-
 bool llvm::runFunctionSpecialization(
     Module &M, const DataLayout &DL,
     std::function<TargetLibraryInfo &(Function &)> GetTLI,
@@ -637,14 +782,18 @@
   unsigned I = 0;
   while (FuncSpecializationMaxIters != I++ &&
          FS.specializeFunctions(FuncDecls, CurrentSpecializations)) {
-    // TODO: run the solver here for the specialized functions only if we want
-    // to specialize recursively.
+    // Run the solver for the specialized functions.
+    RunSCCPSolver(CurrentSpecializations);
+
+    // Replace stack-allocated constant arguments with constant globals so the
+    // next iteration can specialize on them.
+    constantArgPropagation(FuncDecls, M, Solver);
 
     CurrentSpecializations.clear();
     Changed = true;
   }
 
   // Clean up the IR by removing ssa_copy intrinsics.
-  cleanup(M);
+  removeSSACopy(M);
   return Changed;
 }
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
@@ -1,26 +1,10 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -function-specialization -inline -instcombine -S < %s | FileCheck %s
-
-; TODO: this is a case that would be interesting to support, but we don't yet
-; at the moment.
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS2
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=3 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS3
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=4 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS4
 
 @Global = internal constant i32 1, align 4
 
 define internal void @recursiveFunc(i32* nocapture readonly %arg) {
-; CHECK-LABEL: @recursiveFunc(
-; CHECK-NEXT:    [[TEMP:%.*]] = alloca i32, align 4
-; CHECK-NEXT:    [[ARG_LOAD:%.*]] = load i32, i32* [[ARG:%.*]], align 4
-; CHECK-NEXT:    [[ARG_CMP:%.*]] = icmp slt i32 [[ARG_LOAD]], 4
-; CHECK-NEXT:    br i1 [[ARG_CMP]], label [[BLOCK6:%.*]], label [[RET_BLOCK:%.*]]
-; CHECK:       block6:
-; CHECK-NEXT:    call void @print_val(i32 [[ARG_LOAD]])
-; CHECK-NEXT:    [[ARG_ADD:%.*]] = add nsw i32 [[ARG_LOAD]], 1
-; CHECK-NEXT:    store i32 [[ARG_ADD]], i32* [[TEMP]], align 4
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull [[TEMP]])
-; CHECK-NEXT:    br label [[RET_BLOCK]]
-; CHECK:       ret.block:
-; CHECK-NEXT:    ret void
-;
   %temp = alloca i32, align 4
   %arg.load = load i32, i32* %arg, align 4
   %arg.cmp = icmp slt i32 %arg.load, 4
@@ -37,10 +21,28 @@
   ret void
 }
 
+; ITERS2: @funcspec.arg.3 = internal constant i32 3
+; ITERS3: @funcspec.arg.5 = internal constant i32 4
+
 define i32 @main() {
-; CHECK-LABEL: @main(
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull @Global)
-; CHECK-NEXT:    ret i32 0
+; ITERS2-LABEL: @main(
+; ITERS2-NEXT:    call void @print_val(i32 1)
+; ITERS2-NEXT:    call void @print_val(i32 2)
+; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.3)
+; ITERS2-NEXT:    ret i32 0
+;
+; ITERS3-LABEL: @main(
+; ITERS3-NEXT:    call void @print_val(i32 1)
+; ITERS3-NEXT:    call void @print_val(i32 2)
+; ITERS3-NEXT:    call void @print_val(i32 3)
+; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.5)
+; ITERS3-NEXT:    ret i32 0
+;
+; ITERS4-LABEL: @main(
+; ITERS4-NEXT:    call void @print_val(i32 1)
+; ITERS4-NEXT:    call void @print_val(i32 2)
+; ITERS4-NEXT:    call void @print_val(i32 3)
+; ITERS4-NEXT:    ret i32 0
 ;
   call void @recursiveFunc(i32* nonnull @Global)
   ret i32 0
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll
@@ -0,0 +1,32 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Volatile store preventing recursive specialisation:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca i32, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store volatile i32 %arg.add, i32* %temp, align 4
+  call void @recursiveFunc(i32* nonnull %temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll
@@ -0,0 +1,34 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Duplicate store preventing recursive specialisation:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca i32, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store i32 %arg.add, i32* %temp, align 4
+  store i32 %arg.add, i32* %temp, align 4
+  call void @recursiveFunc(i32* nonnull %temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll
@@ -0,0 +1,33 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Alloca is not an integer type:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca float, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store float 2.000000e+00, float* %temp, align 4
+  %bc = bitcast float* %temp to i32*
+  call void @recursiveFunc(i32* nonnull %bc)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)