diff --git a/llvm/include/llvm/IR/ReplaceConstant.h b/llvm/include/llvm/IR/ReplaceConstant.h
--- a/llvm/include/llvm/IR/ReplaceConstant.h
+++ b/llvm/include/llvm/IR/ReplaceConstant.h
@@ -23,6 +23,31 @@
 /// it before \p Instr.
 Instruction *createReplacementInstr(ConstantExpr *CE, Instruction *Instr);
 
+/// The given instruction \p I contains given constant expression \p CE as one
+/// of its operands, possibly nested within other constant expressions. Convert
+/// all such constant expression operands into corresponding instructions,
+/// insert them before \p I, update operands of \p I accordingly, and if
+/// required, add all such converted instructions to \p Insts.
+/// If \p I is a PHI node, instructions go into the matching incoming block.
+void convertConstantExprsToInstructions(
+    Instruction *I, ConstantExpr *CE,
+    SmallPtrSetImpl<Instruction *> *Insts = nullptr);
+
+/// Given an instruction \p I and its operands \p Operands, convert all the
+/// constant expression operands from \p Operands into corresponding
+/// instructions, insert them before \p I, update operands of \p I accordingly,
+/// and if required, add all such converted instructions to \p Insts.
+/// If \p I is a PHI node, instructions go into the matching incoming block.
+void convertConstantExprsToInstructions(
+    Instruction *I, SmallPtrSetImpl<Value *> &Operands,
+    SmallPtrSetImpl<Instruction *> *Insts = nullptr);
+
+/// Given an instruction \p I which uses given constant expression \p CE as
+/// operand, either directly or nested within other constant expressions, return
+/// all such constant expression operands of \p I which contain \p CE.
+SmallPtrSet<Value *, 4> getConstantExprOperands(Instruction *I,
+                                                 ConstantExpr *CE);
+
 } // end namespace llvm
 
 #endif // LLVM_IR_REPLACECONSTANT_H
diff --git a/llvm/lib/IR/ReplaceConstant.cpp b/llvm/lib/IR/ReplaceConstant.cpp
--- a/llvm/lib/IR/ReplaceConstant.cpp
+++ b/llvm/lib/IR/ReplaceConstant.cpp
@@ -68,4 +68,86 @@
     llvm_unreachable("Unhandled constant expression!\n");
   }
 }
+
+void convertConstantExprsToInstructions(Instruction *I, ConstantExpr *CE,
+                                        SmallPtrSetImpl<Instruction *> *Insts) {
+  // Get all operands of I which use CE as operand either directly or nested
+  // within other constant expressions.
+  SmallPtrSet<Value *, 4> CEOperands = getConstantExprOperands(I, CE);
+
+  // Convert the constant expression operands of I (from the set CEOperands)
+  // to instructions.
+  convertConstantExprsToInstructions(I, CEOperands, Insts);
+}
+
+void convertConstantExprsToInstructions(Instruction *I,
+                                        SmallPtrSetImpl<Value *> &Operands,
+                                        SmallPtrSetImpl<Instruction *> *Insts) {
+  for (Use &U : I->operands()) {
+    auto *CE = dyn_cast<ConstantExpr>(U.get());
+    if (!CE || !Operands.contains(CE))
+      continue;
+
+    Instruction *NI = CE->getAsInstruction();
+    if (auto *Phi = dyn_cast<PHINode>(I)) {
+      // The value must be available at the end of the incoming block, so the
+      // new instruction is placed there rather than before the PHI itself.
+      BasicBlock *BB = Phi->getIncomingBlock(U);
+      BasicBlock::iterator InsertPt = BB->getFirstInsertionPt();
+      assert(InsertPt != BB->end() && "Incoming block has no insertion point");
+      NI->insertBefore(&*InsertPt);
+      // NI only dominates edges leaving BB. Rewrite just the entries for BB
+      // (a PHI must carry the same value for repeated entries of one block);
+      // entries for other blocks get their own instruction on later visits.
+      for (unsigned Idx = 0, E = Phi->getNumIncomingValues(); Idx != E; ++Idx)
+        if (Phi->getIncomingBlock(Idx) == BB &&
+            Phi->getIncomingValue(Idx) == CE)
+          Phi->setIncomingValue(Idx, NI);
+    } else {
+      NI->insertBefore(I);
+      I->replaceUsesOfWith(CE, NI);
+    }
+    if (Insts)
+      Insts->insert(NI);
+
+    // The operands of the new instruction may themselves be constant
+    // expressions; convert them as well, inserting them before NI.
+    SmallPtrSet<Value *, 4> NestedOperands(CE->op_begin(), CE->op_end());
+    convertConstantExprsToInstructions(NI, NestedOperands, Insts);
+
+    CE->removeDeadConstantUsers();
+  }
+}
+
+// Return all the constant expression operands of I which contain CE.
+SmallPtrSet<Value *, 4> getConstantExprOperands(Instruction *I,
+                                                 ConstantExpr *CE) {
+  SmallPtrSet<Value *, 4> CEOperands;
+
+  for (Use &U : I->operands()) {
+    auto *Root = dyn_cast<ConstantExpr>(U.get());
+    if (!Root)
+      continue;
+
+    // DFS over the constant expression DAG rooted at this operand; Visited
+    // keeps shared sub-expressions from being re-walked (exponential blowup).
+    SmallPtrSet<ConstantExpr *, 8> Visited;
+    SmallVector<ConstantExpr *, 8> Stack{Root};
+    while (!Stack.empty()) {
+      ConstantExpr *Cur = Stack.pop_back_val();
+      if (!Visited.insert(Cur).second)
+        continue;
+      if (Cur == CE) {
+        CEOperands.insert(Root);
+        break;
+      }
+      for (Value *Op : Cur->operands())
+        if (auto *OpCE = dyn_cast<ConstantExpr>(Op))
+          Stack.push_back(OpCE);
+    }
+  }
+
+  return CEOperands;
+}
+
 } // namespace llvm