Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -317,19 +317,46 @@
                                   Value *FalseVal,
                                   InstCombiner::BuilderTy *Builder) {
   const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
-  if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy())
+  if (!IC || !SI.getType()->isIntegerTy())
     return nullptr;
 
   Value *CmpLHS = IC->getOperand(0);
   Value *CmpRHS = IC->getOperand(1);
 
-  if (!match(CmpRHS, m_Zero()))
-    return nullptr;
+  Value *V;
+  unsigned C1Log;
+  bool IsEqualZero;
+  bool NeedAnd = false;
+  if (IC->isEquality()) {
+    if (!match(CmpRHS, m_Zero()))
+      return nullptr;
+
+    const APInt *C1;
+    if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
+      return nullptr;
+
+    V = CmpLHS;
+    C1Log = C1->logBase2();
+    IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
+  } else if (IC->getPredicate() == ICmpInst::ICMP_SLT ||
+             IC->getPredicate() == ICmpInst::ICMP_SGT) {
+    // We also need to recognize (icmp slt (trunc (X)), 0) and
+    // (icmp sgt (trunc (X)), -1).
+    IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT;
+    if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) ||
+        (!IsEqualZero && !match(CmpRHS, m_Zero())))
+      return nullptr;
 
-  Value *X;
-  const APInt *C1;
-  if (!match(CmpLHS, m_And(m_Value(X), m_Power2(C1))))
+    if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V)))))
+      return nullptr;
+
+    // Truncation keeps the low bits, so the sign bit tested above is bit
+    // (TruncBits - 1) of X. It is isolated with an AND further down.
+    C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
+    NeedAnd = true;
+  } else {
     return nullptr;
+  }
 
   const APInt *C2;
   bool OrOnTrueVal = false;
@@ -340,25 +367,31 @@
   if (!OrOnFalseVal && !OrOnTrueVal)
     return nullptr;
 
-  Value *V = CmpLHS;
   Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
 
-  unsigned C1Log = C1->logBase2();
   unsigned C2Log = C2->logBase2();
 
-  ICmpInst::Predicate Pred = IC->getPredicate();
-  bool NeedXor = (Pred == ICmpInst::ICMP_NE && OrOnFalseVal) ||
-                 (Pred == ICmpInst::ICMP_EQ && OrOnTrueVal);
+  bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
   bool NeedShift = C1Log != C2Log;
   bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() !=
                        V->getType()->getIntegerBitWidth();
 
   // Make sure we don't create more instructions than we save.
   Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
-  if ((NeedShift + NeedXor + NeedZExtTrunc) >
+  // The AND built for the trunc form is an extra instruction too; leaving it
+  // out lets the fold grow the code, e.g. when IC has other uses.
+  if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
       (IC->hasOneUse() + Or->hasOneUse()))
     return nullptr;
 
+  if (NeedAnd) {
+    // Insert the AND instruction on the input to the truncate.
+    V = Builder->CreateAnd(
+        V, ConstantInt::get(V->getType(),
+                            APInt::getOneBitSet(
+                                V->getType()->getScalarSizeInBits(), C1Log)));
+  }
+
   if (C2Log > C1Log) {
     V = Builder->CreateZExtOrTrunc(V, Y->getType());
     V = Builder->CreateShl(V, C2Log - C1Log);
Index: test/Transforms/InstCombine/select-with-bitwise-ops.ll
===================================================================
--- test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -297,11 +297,10 @@
 
 define i32 @test68(i32 %x, i32 %y) {
 ; CHECK-LABEL: @test68(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[TMP1]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[CMP]], i32 [[Y]], i32 [[OR]]
-; CHECK-NEXT:    ret i32 [[SELECT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = or i32 [[TMP2]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %and = and i32 %x, 128
   %cmp = icmp eq i32 %and, 0