Index: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1786,6 +1786,19 @@
     return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
   }
 
+  // (X + 1) urem Op1 --> ((X + 1) == Op1) ? 0 : (X + 1), when X u< Op1.
+  // X u< Op1 means X + 1 cannot wrap and X + 1 u<= Op1, so the remainder is
+  // either X + 1 itself or 0. Op0 is frozen because the result uses it twice.
+  if (match(Op0, m_Add(m_Value(X), m_One()))) {
+    Value *Val =
+        simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I));
+    if (Val && match(Val, m_One())) {
+      Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen");
+      Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1);
+      return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0);
+    }
+  }
+
   return nullptr;
 }
 
Index: llvm/test/Transforms/InstCombine/urem-via-cmp-select.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InstCombine/urem-via-cmp-select.ll
@@ -0,0 +1,78 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; https://alive2.llvm.org/ce/z/UNmz9j
+define noundef i64 @urem_assume(i64 noundef %x, i64 noundef %n) {
+; CHECK-LABEL: @urem_assume(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[X:%.*]], [[N:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[ADD:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[ADD]], [[N]]
+; CHECK-NEXT:    [[OUT:%.*]] = select i1 [[TMP0]], i64 0, i64 [[ADD]]
+; CHECK-NEXT:    ret i64 [[OUT]]
+;
+start:
+  %cmp = icmp ult i64 %x, %n
+  tail call void @llvm.assume(i1 %cmp)
+  %add = add nuw i64 %x, 1
+  %out = urem i64 %add, %n
+  ret i64 %out
+}
+
+; https://alive2.llvm.org/ce/z/uo7HMz
+define noundef i64 @urem_assume_without_nuw(i64 noundef %x, i64 noundef %n) {
+; CHECK-LABEL: @urem_assume_without_nuw(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[X:%.*]], [[N:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[ADD:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[ADD]], [[N]]
+; CHECK-NEXT:    [[OUT:%.*]] = select i1 [[TMP0]], i64 0, i64 [[ADD]]
+; CHECK-NEXT:    ret i64 [[OUT]]
+;
+start:
+  %cmp = icmp ult i64 %x, %n
+  tail call void @llvm.assume(i1 %cmp)
+  %add = add i64 %x, 1
+  %out = urem i64 %add, %n
+  ret i64 %out
+}
+
+; Negative test: x == n does not imply x u< n.
+define noundef i64 @urem_assume_eq(i64 noundef %x, i64 noundef %n) {
+; CHECK-LABEL: @urem_assume_eq(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[X:%.*]], [[N:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[ADD:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[OUT:%.*]] = urem i64 [[ADD]], [[N]]
+; CHECK-NEXT:    ret i64 [[OUT]]
+;
+start:
+  %cmp = icmp eq i64 %x, %n
+  tail call void @llvm.assume(i1 %cmp)
+  %add = add i64 %x, 1
+  %out = urem i64 %add, %n
+  ret i64 %out
+}
+
+; Negative test: x != n does not imply x u< n.
+define noundef i64 @urem_assume_ne(i64 noundef %x, i64 noundef %n) {
+; CHECK-LABEL: @urem_assume_ne(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[X:%.*]], [[N:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[ADD:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[OUT:%.*]] = urem i64 [[ADD]], [[N]]
+; CHECK-NEXT:    ret i64 [[OUT]]
+;
+start:
+  %cmp = icmp ne i64 %x, %n
+  tail call void @llvm.assume(i1 %cmp)
+  %add = add i64 %x, 1
+  %out = urem i64 %add, %n
+  ret i64 %out
+}
+
+declare void @llvm.assume(i1 noundef)