diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1003,6 +1003,12 @@
   }
 };
 
+struct is_irem_op {
+  bool isOpType(unsigned Opcode) {
+    return Opcode == Instruction::SRem || Opcode == Instruction::URem;
+  }
+};
+
 /// Matches shift operations.
 template <typename LHS, typename RHS>
 inline BinOpPred_match<LHS, RHS, is_shift_op> m_Shift(const LHS &L,
@@ -1038,6 +1044,13 @@
   return BinOpPred_match<LHS, RHS, is_idiv_op>(L, R);
 }
 
+/// Matches integer remainder operations.
+template <typename LHS, typename RHS>
+inline BinOpPred_match<LHS, RHS, is_irem_op> m_IRem(const LHS &L,
+                                                    const RHS &R) {
+  return BinOpPred_match<LHS, RHS, is_irem_op>(L, R);
+}
+
 //===----------------------------------------------------------------------===//
 // Class that matches exact binary ops.
 //
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1317,6 +1317,27 @@
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
+/// If we have:
+///   icmp eq/ne (urem/srem %x, %y), 0
+/// iff %y is a power-of-two (or zero), we can replace this with a bit test:
+///   icmp eq/ne (and %x, (add %y, -1)), 0
+Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
+  // This fold is only valid for equality predicates.
+  if (!I.isEquality())
+    return nullptr;
+  ICmpInst::Predicate Pred;
+  Value *X, *Y, *Zero;
+  if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
+                        m_CombineAnd(m_Zero(), m_Value(Zero)))))
+    return nullptr;
+  if (!isKnownToBeAPowerOfTwo(Y, /*OrZero=*/true, 0, &I))
+    return nullptr;
+  // Y need not be a constant, so forming the mask may add an instruction.
+  Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
+  Value *Masked = Builder.CreateAnd(X, Mask);
+  return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero);
+}
+
 // Handle icmp pred X, 0
 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
   CmpInst::Predicate Pred = Cmp.getPredicate();
@@ -1335,6 +1356,9 @@
     }
   }
 
+  if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp))
+    return New;
+
   // Given:
   //   icmp eq/ne (urem %x, %y), 0
   // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -839,6 +839,7 @@
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
   Instruction *foldICmpBinOp(ICmpInst &Cmp);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
+  Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
   Instruction *foldICmpWithZero(ICmpInst &Cmp);
 
   Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1319,6 +1319,8 @@
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
+    // Op1 is not required to be a constant, so forming the mask (Op1 - 1)
+    // may increase instruction count.
     Constant *N1 = Constant::getAllOnesValue(Ty);
     Value *Add = Builder.CreateAdd(Op1, N1);
     return BinaryOperator::CreateAnd(Op0, Add);
diff --git a/llvm/test/Transforms/InstCombine/rem.ll b/llvm/test/Transforms/InstCombine/rem.ll
--- a/llvm/test/Transforms/InstCombine/rem.ll
+++ b/llvm/test/Transforms/InstCombine/rem.ll
@@ -201,8 +201,8 @@
 
 define i1 @test3a(i32 %A) {
 ; CHECK-LABEL: @test3a(
-; CHECK-NEXT:    [[B1:%.*]] = and i32 [[A:%.*]], 7
-; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[B1]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 7
+; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %B = srem i32 %A, -8
@@ -212,8 +212,8 @@
 
 define <2 x i1> @test3a_vec(<2 x i32> %A) {
 ; CHECK-LABEL: @test3a_vec(
-; CHECK-NEXT:    [[B1:%.*]] = and <2 x i32> [[A:%.*]], <i32 7, i32 7>
-; CHECK-NEXT:    [[C:%.*]] = icmp ne <2 x i32> [[B1]], zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], <i32 7, i32 7>
+; CHECK-NEXT:    [[C:%.*]] = icmp ne <2 x i32> [[TMP1]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[C]]
 ;
   %B = srem <2 x i32> %A, <i32 -8, i32 -8>
@@ -681,8 +681,8 @@
 
 define i1 @test25(i32 %A) {
 ; CHECK-LABEL: @test25(
-; CHECK-NEXT:    [[B:%.*]] = srem i32 [[A:%.*]], -2147483648
-; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[B]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %B = srem i32 %A, 2147483648 ; signbit
@@ -692,8 +692,8 @@
 
 define <2 x i1> @test25_vec(<2 x i32> %A) {
 ; CHECK-LABEL: @test25_vec(
-; CHECK-NEXT:    [[B:%.*]] = srem <2 x i32> [[A:%.*]], <i32 -2147483648, i32 -2147483648>
-; CHECK-NEXT:    [[C:%.*]] = icmp ne <2 x i32> [[B]], zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], <i32 2147483647, i32 2147483647>
+; CHECK-NEXT:    [[C:%.*]] = icmp ne <2 x i32> [[TMP1]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[C]]
 ;
   %B = srem <2 x i32> %A, <i32 2147483648, i32 2147483648>
@@ -703,9 +703,10 @@
 
 define i1 @test26(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test26(
-; CHECK-NEXT:    [[C:%.*]] = shl nuw i32 1, [[B:%.*]]
-; CHECK-NEXT:    [[D:%.*]] = srem i32 [[A:%.*]], [[C]]
-; CHECK-NEXT:    [[E:%.*]] = icmp ne i32 [[D]], 0
+; CHECK-NEXT:    [[NOTMASK:%.*]] = shl nsw i32 -1, [[B:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[NOTMASK]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[A:%.*]]
+; CHECK-NEXT:    [[E:%.*]] = icmp ne i32 [[TMP2]], 0
 ; CHECK-NEXT:    ret i1 [[E]]
 ;
   %C = shl i32 1, %B ; not a constant
@@ -729,8 +730,8 @@
 
 define i1 @test28(i32 %A) {
 ; CHECK-LABEL: @test28(
-; CHECK-NEXT:    [[B:%.*]] = srem i32 [[A:%.*]], -2147483648
-; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[B]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %B = srem i32 %A, 2147483648 ; signbit