diff --git a/llvm/include/llvm/IR/ReplaceConstant.h b/llvm/include/llvm/IR/ReplaceConstant.h
--- a/llvm/include/llvm/IR/ReplaceConstant.h
+++ b/llvm/include/llvm/IR/ReplaceConstant.h
@@ -16,6 +16,8 @@
 
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Instruction.h"
+#include <map>
+#include <vector>
 
 namespace llvm {
 
@@ -23,6 +25,42 @@
 /// it before \p Instr.
 Instruction *createReplacementInstr(ConstantExpr *CE, Instruction *Instr);
 
+/// The given instruction \p I contains given constant expression \p CE as one
+/// of its operands, possibly nested within constant expression trees. Convert
+/// all reachable paths from constant expression operands of \p I to \p CE into
+/// corresponding instructions, insert them before \p I, update operands of \p I
+/// accordingly, and if required, return all such converted instructions at
+/// \p Insts. If \p I is a PHI node, the instructions for an incoming value are
+/// inserted at the first insertion point of the corresponding incoming block
+/// instead.
+void convertConstantExprsToInstructions(
+    Instruction *I, ConstantExpr *CE,
+    SmallPtrSetImpl<Instruction *> *Insts = nullptr);
+
+/// The given instruction \p I contains constant expression CE within the
+/// constant expression trees of its constant expression operands, and
+/// \p CEPaths holds all the reachable paths (to CE) from such constant
+/// expression trees of \p I, keyed by the operand use each path starts from.
+/// Convert constant expressions within these paths into corresponding
+/// instructions, insert them before \p I, update operands of \p I accordingly,
+/// and if required, return all such converted instructions at \p Insts. If
+/// \p I is a PHI node, the instructions for an incoming value are inserted at
+/// the first insertion point of the corresponding incoming block instead.
+void convertConstantExprsToInstructions(
+    Instruction *I,
+    std::map<Use *, std::vector<std::vector<ConstantExpr *>>> &CEPaths,
+    SmallPtrSetImpl<Instruction *> *Insts = nullptr);
+
+/// Given an instruction \p I which uses given constant expression \p CE as
+/// operand, either directly or nested within other constant expressions,
+/// collect all reachable paths from the constant expression operands of \p I
+/// to \p CE into \p CEPaths. Each path starts at an operand of \p I and ends at
+/// \p CE; the paths are keyed by that operand's use, and operands from which
+/// \p CE is not reachable get no entry.
+void collectConstantExprPaths(
+    Instruction *I, ConstantExpr *CE,
+    std::map<Use *, std::vector<std::vector<ConstantExpr *>>> &CEPaths);
+
 } // end namespace llvm
 
 #endif // LLVM_IR_REPLACECONSTANT_H
diff --git a/llvm/lib/IR/ReplaceConstant.cpp b/llvm/lib/IR/ReplaceConstant.cpp
--- a/llvm/lib/IR/ReplaceConstant.cpp
+++ b/llvm/lib/IR/ReplaceConstant.cpp
@@ -68,4 +68,166 @@
     llvm_unreachable("Unhandled constant expression!\n");
   }
 }
+
+void convertConstantExprsToInstructions(Instruction *I, ConstantExpr *CE,
+                                        SmallPtrSetImpl<Instruction *> *Insts) {
+  // Collect all reachable paths to CE from constant expression operands of I.
+  std::map<Use *, std::vector<std::vector<ConstantExpr *>>> CEPaths;
+  collectConstantExprPaths(I, CE, CEPaths);
+
+  // Convert all constant expressions to instructions which are collected at
+  // CEPaths.
+  convertConstantExprsToInstructions(I, CEPaths, Insts);
+}
+
+/// Return the constant expressions in \p Wanted that are reachable from
+/// \p Roots through constant expression and aggregate operands, in post-order
+/// of a depth-first walk: each constant expression comes after all constants
+/// it (transitively) uses, so iterating the result backwards visits users
+/// before the constants they use.
+static std::vector<ConstantExpr *>
+collectInPostOrder(const std::vector<ConstantExpr *> &Roots,
+                   const SmallPtrSetImpl<ConstantExpr *> &Wanted) {
+  std::vector<ConstantExpr *> PostOrder;
+  SmallPtrSet<Constant *, 16> Seen;
+  // Explicit DFS stack of (constant, index of its next operand to visit).
+  std::vector<std::pair<Constant *, unsigned>> Stack;
+  for (ConstantExpr *Root : Roots) {
+    if (!Seen.insert(Root).second)
+      continue;
+    Stack.emplace_back(Root, 0u);
+    while (!Stack.empty()) {
+      Constant *C = Stack.back().first;
+      unsigned OpNo = Stack.back().second;
+      if (OpNo != C->getNumOperands()) {
+        ++Stack.back().second;
+        // Only expressions and aggregates are walked: any other constant with
+        // operands is, or refers only to, global values and blocks, so it
+        // cannot link two converted constant expressions.
+        auto *Op = cast<Constant>(C->getOperand(OpNo));
+        if ((isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) &&
+            Seen.insert(Op).second)
+          Stack.emplace_back(Op, 0u);
+        continue;
+      }
+      Stack.pop_back();
+      if (auto *CE = dyn_cast<ConstantExpr>(C))
+        if (Wanted.count(CE))
+          PostOrder.push_back(CE);
+    }
+  }
+  return PostOrder;
+}
+
+void convertConstantExprsToInstructions(
+    Instruction *I,
+    std::map<Use *, std::vector<std::vector<ConstantExpr *>>> &CEPaths,
+    SmallPtrSetImpl<Instruction *> *Insts) {
+  // Instructions created so far, keyed by the block they were inserted into
+  // and then by the constant expression they replace. A sub-expression shared
+  // by several paths (or operands) is thus converted once per block, while a
+  // PHI node still gets its own copy in each incoming block.
+  std::map<BasicBlock *, std::map<ConstantExpr *, Instruction *>> Converted;
+  SmallPtrSet<ConstantExpr *, 16> ConvertedCEs;
+  std::vector<ConstantExpr *> PathHeads;
+
+  for (Use &U : I->operands()) {
+    // The operand U is either not a constant expression operand or the
+    // constant expression paths do not belong to U, ignore U.
+    auto PathsIt = CEPaths.find(&U);
+    if (PathsIt == CEPaths.end())
+      continue;
+
+    // If the instruction I is a PHI instruction, then fix the instruction
+    // insertion point to the entry of the incoming basic block for operand U.
+    Instruction *InsertPt = I;
+    if (auto *Phi = dyn_cast<PHINode>(I)) {
+      BasicBlock *BB = Phi->getIncomingBlock(U);
+      InsertPt = &(*(BB->getFirstInsertionPt()));
+    }
+    auto &BlockCache = Converted[InsertPt->getParent()];
+
+    // Go through the paths associated with operand U, and convert all the
+    // constant expressions along all paths to corresponding instructions.
+    for (auto &Path : PathsIt->second) {
+      // Every path starts at the operand U of I, so the instruction to update
+      // (Parent) and the insertion point (BI) restart from I for each path.
+      Instruction *Parent = I;
+      Instruction *BI = InsertPt;
+      for (ConstantExpr *CE : Path) {
+        Instruction *&NI = BlockCache[CE];
+        if (!NI) {
+          NI = CE->getAsInstruction();
+          NI->insertBefore(BI);
+          ConvertedCEs.insert(CE);
+          if (Insts)
+            Insts->insert(NI);
+        }
+
+        if (Parent != I) {
+          Parent->replaceUsesOfWith(CE, NI);
+        } else if (U.get() == CE) {
+          // Update only this use: another operand of I may hold the same
+          // constant but need a different instruction, e.g. a PHI incoming
+          // value from another block.
+          U.set(NI);
+        }
+        Parent = BI = NI;
+      }
+      if (!Path.empty())
+        PathHeads.push_back(Path.front());
+    }
+  }
+
+  // Dead constant users are removed only once every path is converted: doing
+  // it after each replacement can destroy constant expressions that later
+  // paths still have to visit. Users are cleaned up before the constants they
+  // use, so no constant expression is touched after it has been destroyed.
+  std::vector<ConstantExpr *> Order =
+      collectInPostOrder(PathHeads, ConvertedCEs);
+  for (auto It = Order.rbegin(), E = Order.rend(); It != E; ++It)
+    (*It)->removeDeadConstantUsers();
+}
+
+void collectConstantExprPaths(
+    Instruction *I, ConstantExpr *CE,
+    std::map<Use *, std::vector<std::vector<ConstantExpr *>>> &CEPaths) {
+  for (Use &U : I->operands()) {
+    // If the operand U is not a constant expression operand, then ignore it.
+    auto *CE2 = dyn_cast<ConstantExpr>(U.get());
+    if (!CE2)
+      continue;
+
+    // Holds all reachable paths from CE2 to CE.
+    std::vector<std::vector<ConstantExpr *>> Paths;
+
+    // Collect all reachable paths from CE2 to CE by a depth-first walk over
+    // constant expression operands; every stack entry is a partial path.
+    std::vector<std::vector<ConstantExpr *>> Stack;
+    Stack.push_back({CE2});
+    while (!Stack.empty()) {
+      std::vector<ConstantExpr *> TPath = std::move(Stack.back());
+      Stack.pop_back();
+      auto *CE3 = TPath.back();
+
+      if (CE3 == CE) {
+        Paths.push_back(std::move(TPath));
+        continue;
+      }
+
+      for (auto &UU : CE3->operands()) {
+        if (auto *CE4 = dyn_cast<ConstantExpr>(UU.get())) {
+          std::vector<ConstantExpr *> NPath(TPath.begin(), TPath.end());
+          NPath.push_back(CE4);
+          Stack.push_back(std::move(NPath));
+        }
+      }
+    }
+
+    // Associate all the collected paths with U, and save it.
+    if (!Paths.empty())
+      CEPaths[&U] = std::move(Paths);
+  }
+}
+
 } // namespace llvm