diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -143,6 +143,8 @@
   Cost estimateBranchInst(BranchInst &I);
 
   Constant *visitInstruction(Instruction &I) { return nullptr; }
+  Constant *visitFreezeInst(FreezeInst &I);
+  Constant *visitCallBase(CallBase &I);
   Constant *visitLoadInst(LoadInst &I);
   Constant *visitGetElementPtrInst(GetElementPtrInst &I);
   Constant *visitSelectInst(SelectInst &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -222,6 +222,40 @@
   return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
 }
 
+// Fold a freeze of a known constant. This is only sound when the constant
+// cannot be undef or poison; otherwise no constant is reported.
+Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
+  if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
+    return LastVisited->second;
+  return nullptr;
+}
+
+// Try to constant fold a call. Returns nullptr when the callee is not a known
+// function, when calls to it cannot be folded, or when some operand is neither
+// a constant nor mapped to one by findConstantFor.
+Constant *InstCostVisitor::visitCallBase(CallBase &I) {
+  Function *F = I.getCalledFunction();
+  if (!F || !canConstantFoldCallTo(&I, F))
+    return nullptr;
+
+  SmallVector<Constant *> Operands;
+  Operands.reserve(I.getNumOperands());
+
+  // Stop one short of getNumOperands(): the last operand of a call is the
+  // callee, which is not an input to the fold.
+  for (unsigned Idx = 0, E = I.getNumOperands() - 1; Idx != E; ++Idx) {
+    Value *V = I.getOperand(Idx);
+    auto *C = dyn_cast<Constant>(V);
+    if (!C)
+      C = findConstantFor(V, KnownConstants);
+    if (!C)
+      return nullptr;
+    Operands.push_back(C);
+  }
+
+  return ConstantFoldCall(&I, F, Operands);
+}
+
 Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
   if (isa(LastVisited->second))
     return nullptr;
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -227,13 +227,21 @@
   const char *ModuleString = R"(
     @g = constant [2 x i32] zeroinitializer, align 4
 
-    define i32 @foo(i8 %a, i1 %cond, ptr %b) {
+    declare i32 @llvm.smax.i32(i32, i32)
+    declare i32 @bar(i32)
+
+    define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
       %cmp = icmp eq i8 %a, 10
       %ext = zext i1 %cmp to i32
       %sel = select i1 %cond, i32 %ext, i32 1
       %gep = getelementptr i32, ptr %b, i32 %sel
       %ld = load i32, ptr %gep
-      ret i32 %ld
+      %fr = freeze i32 %ld
+      %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
+      %call = call i32 @bar(i32 %smax)
+      %fr2 = freeze i32 %c
+      %add = add i32 %call, %fr2
+      ret i32 %add
     }
   )";
 
@@ -245,6 +253,7 @@
   GlobalVariable *GV = M.getGlobalVariable("g");
   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
   Constant *True = ConstantInt::getTrue(M.getContext());
+  Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
 
   auto BlockIter = F->front().begin();
   Instruction &Icmp = *BlockIter++;
@@ -252,6 +261,8 @@
   Instruction &Select = *BlockIter++;
   Instruction &Gep = *BlockIter++;
   Instruction &Load = *BlockIter++;
+  Instruction &Freeze = *BlockIter++;
+  Instruction &Smax = *BlockIter++;
 
   // icmp + zext
   Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
@@ -265,9 +276,14 @@
   EXPECT_EQ(Bonus, Ref);
   EXPECT_TRUE(Bonus > 0);
 
-  // gep + load
-  Ref = getInstCost(Gep) + getInstCost(Load);
+  // gep + load + freeze + smax
+  Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
+        getInstCost(Smax);
   Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
   EXPECT_EQ(Bonus, Ref);
   EXPECT_TRUE(Bonus > 0);
+
+  // Specializing on undef must not fold the freeze, so there is no bonus.
+  Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
+  EXPECT_TRUE(Bonus == 0);
 }