Index: llvm/include/llvm/Analysis/TargetFolder.h =================================================================== --- llvm/include/llvm/Analysis/TargetFolder.h +++ llvm/include/llvm/Analysis/TargetFolder.h @@ -74,36 +74,29 @@ return nullptr; } - Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override { + Value *FoldBinOp(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS) const { auto *LC = dyn_cast(LHS); auto *RC = dyn_cast(RHS); if (LC && RC) - return Fold(ConstantExpr::getUDiv(LC, RC, IsExact)); + return ConstantFoldBinaryOpOperands(Opcode, LC, RC, DL); return nullptr; } + Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override { + return FoldBinOp(Instruction::UDiv, LHS, RHS); + } + Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return Fold(ConstantExpr::getSDiv(LC, RC, IsExact)); - return nullptr; + return FoldBinOp(Instruction::SDiv, LHS, RHS); } Value *FoldURem(Value *LHS, Value *RHS) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return Fold(ConstantExpr::getURem(LC, RC)); - return nullptr; + return FoldBinOp(Instruction::URem, LHS, RHS); } Value *FoldSRem(Value *LHS, Value *RHS) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return Fold(ConstantExpr::getSRem(LC, RC)); - return nullptr; + return FoldBinOp(Instruction::SRem, LHS, RHS); } Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override { Index: llvm/include/llvm/IR/ConstantFolder.h =================================================================== --- llvm/include/llvm/IR/ConstantFolder.h +++ llvm/include/llvm/IR/ConstantFolder.h @@ -63,36 +63,29 @@ return nullptr; } - Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override { + Value *FoldBinOp(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS) const { auto *LC = dyn_cast(LHS); auto *RC = dyn_cast(RHS); if (LC && RC) - return ConstantExpr::getUDiv(LC, RC, IsExact); + return ConstantFoldBinaryInstruction(Opcode, LC, RC); return nullptr; } + Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override { + return FoldBinOp(Instruction::UDiv, LHS, RHS); + } + Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return ConstantExpr::getSDiv(LC, RC, IsExact); - return nullptr; + return FoldBinOp(Instruction::SDiv, LHS, RHS); } Value *FoldURem(Value *LHS, Value *RHS) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return ConstantExpr::getURem(LC, RC); - return nullptr; + return FoldBinOp(Instruction::URem, LHS, RHS); } Value *FoldSRem(Value *LHS, Value *RHS) const override { - auto *LC = dyn_cast(LHS); - auto *RC = dyn_cast(RHS); - if (LC && RC) - return ConstantExpr::getSRem(LC, RC); - return nullptr; + return FoldBinOp(Instruction::SRem, LHS, RHS); } Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override { Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -9540,14 +9540,7 @@ } return nullptr; } - case scUDivExpr: { - const SCEVUDivExpr *SU = cast(V); - if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) - if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) - if (LHS->getType() == RHS->getType()) - return ConstantExpr::getUDiv(LHS, RHS); - return nullptr; - } + case scUDivExpr: case scSMaxExpr: case scUMaxExpr: case scSMinExpr: Index: llvm/lib/IR/ConstantFold.cpp =================================================================== --- llvm/lib/IR/ConstantFold.cpp +++ llvm/lib/IR/ConstantFold.cpp @@ -2218,9 +2218,15 @@ : cast(CurrIdx->getType())->getNumElements(), Factor); - NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor); + NewIdxs[i] = + ConstantFoldBinaryInstruction(Instruction::SRem, CurrIdx, Factor); - Constant *Div = ConstantExpr::getSDiv(CurrIdx, Factor); + Constant *Div = + ConstantFoldBinaryInstruction(Instruction::SDiv, CurrIdx, Factor); + + // We're working on either ConstantInt or vectors of ConstantInt, + // so these should always fold. + assert(NewIdxs[i] != nullptr && Div != nullptr && "Should have folded"); unsigned CommonExtendedWidth = std::max(PrevIdx->getType()->getScalarSizeInBits(), Index: llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1543,7 +1543,10 @@ !ShAmtC->containsConstantExpression()) { // Canonicalize a shift amount constant operand to modulo the bit-width. Constant *WidthC = ConstantInt::get(Ty, BitWidth); - Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); + Constant *ModuloC = + ConstantFoldBinaryOpOperands(Instruction::URem, ShAmtC, WidthC, DL); + if (!ModuloC) + return nullptr; if (ModuloC != ShAmtC) return replaceOperand(*II, 2, ModuloC);