diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -19,7 +19,6 @@
 // Current limitations:
 // - It does not yet handle integer ranges. We do support "literal constants",
 //   but that's off by default under an option.
-// - Only 1 argument per function is specialised,
 // - The cost-model could be further looked into (it mainly focuses on inlining
 //   benefits),
 // - We are not yet caching analysis results, but profiling and checking where
@@ -210,35 +209,39 @@
   // are any new constant values for the call instruction via
   // stack variables.
   for (auto *F : WorkList) {
-    // TODO: Generalize for any read only arguments.
-    if (F->arg_size() != 1)
-      continue;
-
-    auto &Arg = *F->arg_begin();
-    if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
-      continue;
 
     for (auto *User : F->users()) {
+
       auto *Call = dyn_cast<CallInst>(User);
       if (!Call)
-        break;
-      auto *ArgOp = Call->getArgOperand(0);
-      auto *ArgOpType = ArgOp->getType();
-      auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
-      if (!ConstVal)
-        break;
+        continue;
 
-      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
-                                     GlobalValue::InternalLinkage, ConstVal,
-                                     "funcspec.arg");
+      bool Changed = false;
+      for (const Use &U : Call->args()) {
+        unsigned Idx = Call->getArgOperandNo(&U);
+        Value *ArgOp = Call->getArgOperand(Idx);
+        Type *ArgOpType = ArgOp->getType();
 
-      if (ArgOpType != ConstVal->getType())
-        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
+        if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
+          continue;
+
+        auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
+        if (!ConstVal)
+          continue;
 
-      Call->setArgOperand(0, GV);
+        Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                       GlobalValue::InternalLinkage, ConstVal,
+                                       "funcspec.arg");
+        if (ArgOpType != ConstVal->getType())
+          GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
+
+        Call->setArgOperand(Idx, GV);
+        Changed = true;
+      }
 
       // Add the changed CallInst to Solver Worklist
-      Solver.visitCall(*Call);
+      if (Changed)
+        Solver.visitCall(*Call);
     }
   }
 }
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
@@ -2,50 +2,58 @@
 ; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=3 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS3
 ; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=4 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS4
 
-@Global = internal constant i32 1, align 4
+@low = internal constant i32 0, align 4
+@high = internal constant i32 6, align 4
 
-define internal void @recursiveFunc(i32* nocapture readonly %arg) {
-  %temp = alloca i32, align 4
-  %arg.load = load i32, i32* %arg, align 4
-  %arg.cmp = icmp slt i32 %arg.load, 4
-  br i1 %arg.cmp, label %block6, label %ret.block
+define internal void @recursiveFunc(i32* nocapture readonly %lo, i32 %step, i32* nocapture readonly %hi) {
+  %lo.temp = alloca i32, align 4
+  %hi.temp = alloca i32, align 4
+  %lo.load = load i32, i32* %lo, align 4
+  %hi.load = load i32, i32* %hi, align 4
+  %cmp = icmp ne i32 %lo.load, %hi.load
+  br i1 %cmp, label %block6, label %ret.block
 
 block6:
-  call void @print_val(i32 %arg.load)
-  %arg.add = add nsw i32 %arg.load, 1
-  store i32 %arg.add, i32* %temp, align 4
-  call void @recursiveFunc(i32* nonnull %temp)
+  call void @print_val(i32 %lo.load, i32 %hi.load)
+  %add = add nsw i32 %lo.load, %step
+  %sub = sub nsw i32 %hi.load, %step
+  store i32 %add, i32* %lo.temp, align 4
+  store i32 %sub, i32* %hi.temp, align 4
+  call void @recursiveFunc(i32* nonnull %lo.temp, i32 %step, i32* nonnull %hi.temp)
   br label %ret.block
 
 ret.block:
  ret void
 }
 
-; ITERS2: @funcspec.arg.3 = internal constant i32 3
-; ITERS3: @funcspec.arg.5 = internal constant i32 4
+; ITERS2: @funcspec.arg.4 = internal constant i32 2
+; ITERS2: @funcspec.arg.5 = internal constant i32 4
+
+; ITERS3: @funcspec.arg.7 = internal constant i32 3
+; ITERS3: @funcspec.arg.8 = internal constant i32 3
 
 define i32 @main() {
 ; ITERS2-LABEL: @main(
-; ITERS2-NEXT:    call void @print_val(i32 1)
-; ITERS2-NEXT:    call void @print_val(i32 2)
-; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.3)
+; ITERS2-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS2-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.4, i32 1, i32* nonnull @funcspec.arg.5)
 ; ITERS2-NEXT:    ret i32 0
 ;
 ; ITERS3-LABEL: @main(
-; ITERS3-NEXT:    call void @print_val(i32 1)
-; ITERS3-NEXT:    call void @print_val(i32 2)
-; ITERS3-NEXT:    call void @print_val(i32 3)
-; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.5)
+; ITERS3-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS3-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS3-NEXT:    call void @print_val(i32 2, i32 4)
+; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.7, i32 1, i32* nonnull @funcspec.arg.8)
 ; ITERS3-NEXT:    ret i32 0
 ;
 ; ITERS4-LABEL: @main(
-; ITERS4-NEXT:    call void @print_val(i32 1)
-; ITERS4-NEXT:    call void @print_val(i32 2)
-; ITERS4-NEXT:    call void @print_val(i32 3)
+; ITERS4-NEXT:    call void @print_val(i32 0, i32 6)
+; ITERS4-NEXT:    call void @print_val(i32 1, i32 5)
+; ITERS4-NEXT:    call void @print_val(i32 2, i32 4)
 ; ITERS4-NEXT:    ret i32 0
 ;
-  call void @recursiveFunc(i32* nonnull @Global)
+  call void @recursiveFunc(i32* nonnull @low, i32 1, i32* nonnull @high)
   ret i32 0
 }
 
-declare dso_local void @print_val(i32)
+declare dso_local void @print_val(i32, i32)