Index: include/llvm/Analysis/ConstantFolding.h
===================================================================
--- include/llvm/Analysis/ConstantFolding.h
+++ include/llvm/Analysis/ConstantFolding.h
@@ -19,6 +19,8 @@
 #ifndef LLVM_ANALYSIS_CONSTANTFOLDING_H
 #define LLVM_ANALYSIS_CONSTANTFOLDING_H
 
+#include "llvm/ADT/DenseMap.h"
+
 namespace llvm {
 class APInt;
 template <typename T> class ArrayRef;
@@ -46,12 +48,26 @@
 Constant *ConstantFoldInstruction(Instruction *I, const DataLayout &DL,
                                   const TargetLibraryInfo *TLI = nullptr);
 
+/// ConstantFoldInstruction overload that allows persisting the folded
+/// expression cache across calls.
+Constant *
+ConstantFoldInstruction(Instruction *I, const DataLayout &DL,
+                        const TargetLibraryInfo *TLI,
+                        SmallDenseMap<Constant *, Constant *> &FoldedOps);
+
 /// ConstantFoldConstant - Fold the constant using the specified DataLayout.
 /// This function always returns a non-null constant: Either the folding result,
 /// or the original constant if further folding is not possible.
 Constant *ConstantFoldConstant(const Constant *C, const DataLayout &DL,
                                const TargetLibraryInfo *TLI = nullptr);
 
+/// ConstantFoldConstant overload that allows persisting the folded expression
+/// cache across calls.
+Constant *
+ConstantFoldConstant(const Constant *C, const DataLayout &DL,
+                     const TargetLibraryInfo *TLI,
+                     SmallDenseMap<Constant *, Constant *> &FoldedOps);
+
 /// ConstantFoldInstOperands - Attempt to constant fold an instruction with the
 /// specified operands. If successful, the constant result is returned, if not,
 /// null is returned. Note that this function can fail when attempting to
Index: lib/Analysis/ConstantFolding.cpp
===================================================================
--- lib/Analysis/ConstantFolding.cpp
+++ lib/Analysis/ConstantFolding.cpp
@@ -1078,28 +1078,17 @@
 namespace {
 
 Constant *
-ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL,
-                         const TargetLibraryInfo *TLI,
-                         SmallDenseMap<Constant *, Constant *> &FoldedOps) {
-  if (!isa<ConstantVector>(C) && !isa<ConstantExpr>(C))
-    return const_cast<Constant *>(C);
+ConstantFoldConstantUncached(const Constant *C, const DataLayout &DL,
+                             const TargetLibraryInfo *TLI,
+                             SmallDenseMap<Constant *, Constant *> &FoldedOps) {
+  assert(isa<ConstantVector>(C) || isa<ConstantExpr>(C));
 
   SmallVector<Constant *, 8> Ops;
   for (const Use &OldU : C->operands()) {
-    Constant *OldC = cast<Constant>(&OldU);
-    Constant *NewC = OldC;
-    // Recursively fold the ConstantExpr's operands. If we have already folded
-    // a ConstantExpr, we don't have to process it again.
-    if (isa<ConstantVector>(OldC) || isa<ConstantExpr>(OldC)) {
-      auto It = FoldedOps.find(OldC);
-      if (It == FoldedOps.end()) {
-        NewC = ConstantFoldConstantImpl(OldC, DL, TLI, FoldedOps);
-        FoldedOps.insert({OldC, NewC});
-      } else {
-        NewC = It->second;
-      }
-    }
-    Ops.push_back(NewC);
+    // Recursively fold the operands through the caching entry point, so a
+    // subexpression shared by several operands is only folded once.
+    Constant *OldC = cast<Constant>(&OldU);
+    Ops.push_back(ConstantFoldConstant(OldC, DL, TLI, FoldedOps));
   }
 
   if (auto *CE = dyn_cast<ConstantExpr>(C)) {
@@ -1118,11 +1107,17 @@
 
 Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL,
                                         const TargetLibraryInfo *TLI) {
+  SmallDenseMap<Constant *, Constant *> FoldedOps;
+  return ConstantFoldInstruction(I, DL, TLI, FoldedOps);
+}
+
+Constant *llvm::ConstantFoldInstruction(
+    Instruction *I, const DataLayout &DL, const TargetLibraryInfo *TLI,
+    SmallDenseMap<Constant *, Constant *> &FoldedOps) {
   // Handle PHI nodes quickly here...
   if (auto *PN = dyn_cast<PHINode>(I)) {
     Constant *CommonValue = nullptr;
 
-    SmallDenseMap<Constant *, Constant *> FoldedOps;
     for (Value *Incoming : PN->incoming_values()) {
       // If the incoming value is undef then skip it.  Note that while we could
       // skip the value if it is equal to the phi node itself we choose not to
@@ -1135,7 +1130,7 @@
       if (!C)
         return nullptr;
       // Fold the PHI's operands.
-      C = ConstantFoldConstantImpl(C, DL, TLI, FoldedOps);
+      C = ConstantFoldConstant(C, DL, TLI, FoldedOps);
       // If the incoming value is a different constant to
       // the one we saw previously, then give up.
       if (CommonValue && C != CommonValue)
@@ -1152,12 +1147,11 @@
   if (!all_of(I->operands(), [](Use &U) { return isa<Constant>(U); }))
     return nullptr;
 
-  SmallDenseMap<Constant *, Constant *> FoldedOps;
   SmallVector<Constant *, 8> Ops;
   for (const Use &OpU : I->operands()) {
     auto *Op = cast<Constant>(&OpU);
     // Fold the Instruction's operands.
-    Op = ConstantFoldConstantImpl(Op, DL, TLI, FoldedOps);
+    Op = ConstantFoldConstant(Op, DL, TLI, FoldedOps);
     Ops.push_back(Op);
   }
 
@@ -1186,8 +1180,24 @@
 
 Constant *llvm::ConstantFoldConstant(const Constant *C, const DataLayout &DL,
                                      const TargetLibraryInfo *TLI) {
   SmallDenseMap<Constant *, Constant *> FoldedOps;
-  return ConstantFoldConstantImpl(C, DL, TLI, FoldedOps);
+  return ConstantFoldConstant(C, DL, TLI, FoldedOps);
+}
+
+Constant *
+llvm::ConstantFoldConstant(const Constant *C, const DataLayout &DL,
+                           const TargetLibraryInfo *TLI,
+                           SmallDenseMap<Constant *, Constant *> &FoldedOps) {
+  if (!isa<ConstantVector>(C) && !isa<ConstantExpr>(C))
+    return const_cast<Constant *>(C);
+
+  auto It = FoldedOps.find(C);
+  if (It != FoldedOps.end())
+    return It->second;
+
+  Constant *NewC = ConstantFoldConstantUncached(C, DL, TLI, FoldedOps);
+  FoldedOps.insert({const_cast<Constant *>(C), NewC});
+  return NewC;
 }
 
 Constant *llvm::ConstantFoldInstOperands(Instruction *I,
Index: lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- lib/Transforms/InstCombine/InstructionCombining.cpp
+++ lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3591,7 +3591,7 @@
   Worklist.push_back(BB);
 
   SmallVector<Instruction *, 128> InstrsForInstCombineWorklist;
-  DenseMap<Constant *, Constant *> FoldedConstants;
+  SmallDenseMap<Constant *, Constant *> FoldedConstants;
 
   do {
     BB = Worklist.pop_back_val();
@@ -3606,7 +3606,8 @@
       // ConstantProp instruction if trivially constant.
       if (!Inst->use_empty() &&
           (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0))))
-        if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI)) {
+        if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI,
+                                                  FoldedConstants)) {
           LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *Inst
                             << '\n');
           Inst->replaceAllUsesWith(C);
@@ -3622,11 +3623,8 @@
         if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U))
           continue;
 
-        auto *C = cast<Constant>(U);
-        Constant *&FoldRes = FoldedConstants[C];
-        if (!FoldRes)
-          FoldRes = ConstantFoldConstant(C, DL, TLI);
-
+        Constant *C = cast<Constant>(U);
+        Constant *FoldRes = ConstantFoldConstant(C, DL, TLI, FoldedConstants);
         if (FoldRes != C) {
           LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst
                             << "\n    Old = " << *C