Index: include/llvm/IR/PatternMatch.h
===================================================================
--- include/llvm/IR/PatternMatch.h
+++ include/llvm/IR/PatternMatch.h
@@ -1272,6 +1272,51 @@
   return m_Intrinsic(Op0, Op1);
 }
 
+/// Matcher for the branch-free signum idiom
+///   (X >>s (BitWidth - 1)) | (-X >>u (BitWidth - 1))
+/// which evaluates to -1, 0 or 1 according to the sign of X. On success the
+/// shared operand X is forwarded to the sub-matcher \c Val.
+template <typename Opnd_t> struct Signum_match {
+  Opnd_t Val;
+  Signum_match(const Opnd_t &V) : Val(V) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    unsigned TypeSize = V->getType()->getScalarSizeInBits();
+    // A scalar size of 0 means there is no bit width to shift by; bail out.
+    if (TypeSize == 0)
+      return false;
+
+    unsigned ShiftWidth = TypeSize - 1;
+    Value *OpL = nullptr, *OpR = nullptr;
+
+    // This is the representation of signum we match (shown for i64):
+    //
+    //  signum(x) == (x >>s 63) | (-x >>u 63)
+    //
+    // An i1 value is its own signum, so it's correct to match
+    //
+    //  signum(x) == (x >>s 0)  | (-x >>u 0)
+    //
+    // for i1 values.
+
+    auto LHS = m_AShr(m_Value(OpL), m_SpecificInt(ShiftWidth));
+    auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth));
+    auto Signum = m_Or(LHS, RHS);
+
+    return Signum.match(V) && OpL == OpR && Val.match(OpL);
+  }
+};
+
+/// \brief Matches a signum pattern.
+///
+/// signum(x) =
+///      x >  0  ->  1
+///      x == 0  ->  0
+///      x <  0  -> -1
+template <typename Val_t> inline Signum_match<Val_t> m_Signum(const Val_t &V) {
+  return Signum_match<Val_t>(V);
+}
+
 } // end namespace PatternMatch
 } // end namespace llvm
 
Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1145,6 +1145,14 @@
 
   switch (LHSI->getOpcode()) {
   case Instruction::Trunc:
+    if (RHS->isOne() && RHSV.getBitWidth() > 1) {
+      // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI->getOperand(0), m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
     if (ICI.isEquality() && LHSI->hasOneUse()) {
       // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
       // of the high bits truncated out of x are known.
@@ -1467,6 +1475,16 @@
     break;
 
   case Instruction::Or: {
+    if (RHS->isOne()) {
+      // icmp slt signum(V) 1 --> icmp slt V, 1
+      // signum(V) is one of -1, 0, 1, so signum(V) < 1 exactly when V < 1.
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI, m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
+
     if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse())
       break;
     Value *P, *Q;
Index: test/Transforms/InstCombine/compare-signs.ll
===================================================================
--- test/Transforms/InstCombine/compare-signs.ll
+++ test/Transforms/InstCombine/compare-signs.ll
@@ -56,3 +56,43 @@
 ; CHECK-NOT: zext
 ; CHECK: ret i32 %2
 }
+
+define i1 @test4a(i32 %a) {
+; CHECK-LABEL: @test4a(
+ entry:
+; CHECK: %c = icmp slt i32 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i32 %a, 31
+  %na = sub i32 0, %a
+  %r = lshr i32 %na, 31
+  %signum = or i32 %l, %r
+  %c = icmp slt i32 %signum, 1
+  ret i1 %c
+}
+
+define i1 @test4b(i64 %a) {
+; CHECK-LABEL: @test4b(
+ entry:
+; CHECK: %c = icmp slt i64 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i64 %a, 63
+  %na = sub i64 0, %a
+  %r = lshr i64 %na, 63
+  %signum = or i64 %l, %r
+  %c = icmp slt i64 %signum, 1
+  ret i1 %c
+}
+
+define i1 @test4c(i64 %a) {
+; CHECK-LABEL: @test4c(
+ entry:
+; CHECK: %c = icmp slt i64 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i64 %a, 63
+  %na = sub i64 0, %a
+  %r = lshr i64 %na, 63
+  %signum = or i64 %l, %r
+  %signum.trunc = trunc i64 %signum to i32
+  %c = icmp slt i32 %signum.trunc, 1
+  ret i1 %c
+}