Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -898,6 +898,131 @@
   return nullptr;
 }
 
+/// General pattern:
+///   X & Y
+///
+/// Where Y is checking that all the high bits (covered by a mask 4294967168)
+/// are uniform, i.e.  %arg & 4294967168  can be either  4294967168  or  0
+/// Pattern can be one of:
+///   %t = add        i32 %arg,    128
+///   %r = icmp   ult i32 %t,      256
+/// Or
+///   %t0 = shl       i32 %arg,    24
+///   %t1 = ashr      i32 %t0,     24
+///   %r  = icmp  eq  i32 %t1,     %arg
+/// Or
+///   %t0 = trunc     i32 %arg  to i8
+///   %t1 = sext      i8  %t0   to i32
+///   %r  = icmp  eq  i32 %t1,     %arg
+/// This pattern is a signed truncation check.
+///
+/// And X is checking that some bit in that same mask is zero.
+/// I.e. can be one of:
+///   %r = icmp sgt i32   %arg,    -1
+/// Or
+///   %t = and      i32   %arg,    2147483648
+///   %r = icmp eq  i32   %t,      0
+///
+/// Since all bits in that mask must be the same, and one of them must be zero,
+/// what we are really checking is that all the masked bits are zero.
+/// So this should be transformed to:
+///   %r = icmp ult i32 %arg, 128
+/// \returns the new 'icmp ult', or nullptr if the pattern was not matched.
+static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
+                                        Instruction &CxtI,
+                                        InstCombiner::BuilderTy &Builder) {
+  // The icmps need not be immediate operands of CxtI; only check the opcode.
+  assert(CxtI.getOpcode() == Instruction::And && "expected an 'and'");
+
+  // Match  icmp ult (add %arg, C01), C1   (C1 == C01 << 1; powers of two)
+  auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X,
+                                            APInt &SignBitMask) -> bool {
+    CmpInst::Predicate Pred;
+    const APInt *I01, *I1; // powers of two; I1 == I01 << 1
+    if (!(match(ICmp,
+                m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) &&
+          Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1))
+      return false;
+    // Which bit is the new sign bit as per the 'signed truncation' pattern?
+    SignBitMask = *I01;
+    return true;
+  };
+
+  // One icmp needs to be 'signed truncation check'.
+  // We need to match this first, else we will mismatch commutative cases.
+  Value *X1;
+  APInt HighestBit;
+  ICmpInst *OtherICmp;
+  if (tryToMatchSignedTruncationCheck(ICmp1, X1, HighestBit))
+    OtherICmp = ICmp0;
+  else if (tryToMatchSignedTruncationCheck(ICmp0, X1, HighestBit))
+    OtherICmp = ICmp1;
+  else
+    return nullptr;
+
+  assert(HighestBit.isPowerOf2() && "expected to be power of two (non-zero)");
+
+  // Try to match/decompose into:  icmp eq (X & Mask), 0
+  auto tryToDecompose = [](ICmpInst *ICmp, Value *&X,
+                           APInt &UnsetBitsMask) -> bool {
+    CmpInst::Predicate Pred = ICmp->getPredicate();
+    // Can it be decomposed into  icmp eq (X & Mask), 0  ?
+    if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                   Pred, X, UnsetBitsMask,
+                                   /*LookThroughTrunc=*/false) &&
+        Pred == ICmpInst::ICMP_EQ)
+      return true;
+    // Is it  icmp eq (X & Mask), 0  already?
+    const APInt *Mask;
+    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
+        Pred == ICmpInst::ICMP_EQ) {
+      UnsetBitsMask = *Mask;
+      return true;
+    }
+    return false;
+  };
+
+  // And the other icmp needs to be decomposable into a bit test.
+  Value *X0;
+  APInt UnsetBitsMask;
+  if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask))
+    return nullptr;
+
+  assert(!UnsetBitsMask.isNullValue() && "empty mask makes no sense.");
+
+  // Are they working on the same value?
+  Value *X;
+  if (X1 == X0) {
+    // Ok as is.
+    X = X1;
+  } else if (match(X0, m_Trunc(m_Specific(X1)))) {
+    UnsetBitsMask = UnsetBitsMask.zext(X1->getType()->getScalarSizeInBits());
+    X = X1;
+  } else
+    return nullptr;
+
+  // So which bits should be uniform as per the 'signed truncation check'?
+  // (all the bits starting with (i.e. including) HighestBit)
+  APInt SignBitsMask = ~(HighestBit - 1U);
+
+  // UnsetBitsMask must have some common bits with SignBitsMask,
+  if (!UnsetBitsMask.intersects(SignBitsMask))
+    return nullptr;
+
+  // Does UnsetBitsMask contain any bits outside of SignBitsMask?
+  if (!UnsetBitsMask.isSubsetOf(SignBitsMask)) {
+    APInt OtherHighestBit = (~UnsetBitsMask) + 1U;
+    if (!OtherHighestBit.isPowerOf2())
+      return nullptr;
+    HighestBit = APIntOps::umin(HighestBit, OtherHighestBit);
+  }
+  // Else, if it does not, then all is ok as-is.
+
+  // %r = icmp ult %X, SignBit
+  return Builder.CreateICmpULT(X, ConstantInt::get(X->getType(), HighestBit),
+                               CxtI.getName() + ".simplified");
+}
+
 /// Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                     Instruction &CxtI) {
@@ -937,6 +1062,9 @@
   if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder))
     return V;
 
+  if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder))
+    return V;
+
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); ConstantInt *LHSC = dyn_cast(LHS->getOperand(1)); @@ -1304,6 +1432,7 @@ return nullptr; } + static Instruction *foldOrToXor(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { assert(I.getOpcode() == Instruction::Or); Index: test/Transforms/InstCombine/signed-truncation-check.ll =================================================================== --- test/Transforms/InstCombine/signed-truncation-check.ll +++ test/Transforms/InstCombine/signed-truncation-check.ll @@ -38,11 +38,8 @@ define i1 @positive_with_signbit(i32 %arg) { ; CHECK-LABEL: @positive_with_signbit( -; CHECK-NEXT: [[T1:%.*]] = icmp sgt i32 [[ARG:%.*]], -1 -; CHECK-NEXT: [[T2:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 256 -; CHECK-NEXT: [[T4:%.*]] = and i1 [[T1]], [[T3]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG:%.*]], 128 +; CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %t1 = icmp sgt i32 %arg, -1 %t2 = add i32 %arg, 128 @@ -53,12 +50,8 @@ define i1 @positive_with_mask(i32 %arg) { ; CHECK-LABEL: @positive_with_mask( -; CHECK-NEXT: [[T1:%.*]] = and i32 [[ARG:%.*]], 1107296256 -; CHECK-NEXT: [[T2:%.*]] = icmp eq i32 [[T1]], 0 -; CHECK-NEXT: [[T3:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T4:%.*]] = icmp ult i32 [[T3]], 256 -; CHECK-NEXT: [[T5:%.*]] = and i1 [[T2]], [[T4]] -; CHECK-NEXT: ret i1 [[T5]] +; CHECK-NEXT: [[T5_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG:%.*]], 128 +; CHECK-NEXT: ret i1 [[T5_SIMPLIFIED]] ; %t1 = and i32 %arg, 1107296256 %t2 = icmp eq i32 %t1, 0 @@ -70,11 +63,8 @@ define i1 @positive_with_icmp(i32 %arg) { ; CHECK-LABEL: @positive_with_icmp( -; CHECK-NEXT: [[T1:%.*]] = icmp ult i32 [[ARG:%.*]], 512 -; CHECK-NEXT: [[T2:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 256 -; CHECK-NEXT: [[T4:%.*]] = and i1 [[T1]], [[T3]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG:%.*]], 128 +; 
CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %t1 = icmp ult i32 %arg, 512 %t2 = add i32 %arg, 128 @@ -86,11 +76,8 @@ ; Still the same define i1 @positive_with_aggressive_icmp(i32 %arg) { ; CHECK-LABEL: @positive_with_aggressive_icmp( -; CHECK-NEXT: [[T1:%.*]] = icmp ult i32 [[ARG:%.*]], 128 -; CHECK-NEXT: [[T2:%.*]] = add i32 [[ARG]], 256 -; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 512 -; CHECK-NEXT: [[T4:%.*]] = and i1 [[T1]], [[T3]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG:%.*]], 128 +; CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %t1 = icmp ult i32 %arg, 128 %t2 = add i32 %arg, 256 @@ -107,11 +94,8 @@ define <2 x i1> @positive_vec_splat(<2 x i32> %arg) { ; CHECK-LABEL: @positive_vec_splat( -; CHECK-NEXT: [[T1:%.*]] = icmp sgt <2 x i32> [[ARG:%.*]], -; CHECK-NEXT: [[T2:%.*]] = add <2 x i32> [[ARG]], -; CHECK-NEXT: [[T3:%.*]] = icmp ult <2 x i32> [[T2]], -; CHECK-NEXT: [[T4:%.*]] = and <2 x i1> [[T1]], [[T3]] -; CHECK-NEXT: ret <2 x i1> [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult <2 x i32> [[ARG:%.*]], +; CHECK-NEXT: ret <2 x i1> [[T4_SIMPLIFIED]] ; %t1 = icmp sgt <2 x i32> %arg, %t2 = add <2 x i32> %arg, @@ -249,11 +233,8 @@ define i1 @commutative() { ; CHECK-LABEL: @commutative( ; CHECK-NEXT: [[ARG:%.*]] = call i32 @gen32() -; CHECK-NEXT: [[T1:%.*]] = icmp sgt i32 [[ARG]], -1 -; CHECK-NEXT: [[T2:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 256 -; CHECK-NEXT: [[T4:%.*]] = and i1 [[T3]], [[T1]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG]], 128 +; CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %arg = call i32 @gen32() %t1 = icmp sgt i32 %arg, -1 @@ -266,11 +247,8 @@ define i1 @commutative_with_icmp() { ; CHECK-LABEL: @commutative_with_icmp( ; CHECK-NEXT: [[ARG:%.*]] = call i32 @gen32() -; CHECK-NEXT: [[T1:%.*]] = icmp ult i32 [[ARG]], 512 -; CHECK-NEXT: [[T2:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 256 -; 
CHECK-NEXT: [[T4:%.*]] = and i1 [[T3]], [[T1]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG]], 128 +; CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %arg = call i32 @gen32() %t1 = icmp ult i32 %arg, 512 @@ -286,12 +264,8 @@ define i1 @positive_trunc_signbit(i32 %arg) { ; CHECK-LABEL: @positive_trunc_signbit( -; CHECK-NEXT: [[T1:%.*]] = trunc i32 [[ARG:%.*]] to i8 -; CHECK-NEXT: [[T2:%.*]] = icmp sgt i8 [[T1]], -1 -; CHECK-NEXT: [[T3:%.*]] = add i32 [[ARG]], 128 -; CHECK-NEXT: [[T4:%.*]] = icmp ult i32 [[T3]], 256 -; CHECK-NEXT: [[T5:%.*]] = and i1 [[T2]], [[T4]] -; CHECK-NEXT: ret i1 [[T5]] +; CHECK-NEXT: [[T5_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG:%.*]], 128 +; CHECK-NEXT: ret i1 [[T5_SIMPLIFIED]] ; %t1 = trunc i32 %arg to i8 %t2 = icmp sgt i8 %t1, -1 @@ -304,11 +278,8 @@ define i1 @positive_trunc_base(i32 %arg) { ; CHECK-LABEL: @positive_trunc_base( ; CHECK-NEXT: [[T1:%.*]] = trunc i32 [[ARG:%.*]] to i16 -; CHECK-NEXT: [[T2:%.*]] = icmp sgt i16 [[T1]], -1 -; CHECK-NEXT: [[T3:%.*]] = add i16 [[T1]], 128 -; CHECK-NEXT: [[T4:%.*]] = icmp ult i16 [[T3]], 256 -; CHECK-NEXT: [[T5:%.*]] = and i1 [[T2]], [[T4]] -; CHECK-NEXT: ret i1 [[T5]] +; CHECK-NEXT: [[T5_SIMPLIFIED:%.*]] = icmp ult i16 [[T1]], 128 +; CHECK-NEXT: ret i1 [[T5_SIMPLIFIED]] ; %t1 = trunc i32 %arg to i16 %t2 = icmp sgt i16 %t1, -1 @@ -357,8 +328,8 @@ ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: [[T3:%.*]] = icmp ult i32 [[T2]], 256 ; CHECK-NEXT: call void @use1(i1 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = and i1 [[T1]], [[T3]] -; CHECK-NEXT: ret i1 [[T4]] +; CHECK-NEXT: [[T4_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG]], 128 +; CHECK-NEXT: ret i1 [[T4_SIMPLIFIED]] ; %t1 = icmp sgt i32 %arg, -1 call void @use1(i1 %t1) @@ -380,8 +351,8 @@ ; CHECK-NEXT: call void @use32(i32 [[T3]]) ; CHECK-NEXT: [[T4:%.*]] = icmp ult i32 [[T3]], 256 ; CHECK-NEXT: call void @use1(i1 [[T4]]) -; CHECK-NEXT: [[T5:%.*]] = and i1 [[T2]], [[T4]] -; CHECK-NEXT: ret i1 [[T5]] +; CHECK-NEXT: 
[[T5_SIMPLIFIED:%.*]] = icmp ult i32 [[ARG]], 128 +; CHECK-NEXT: ret i1 [[T5_SIMPLIFIED]] ; %t1 = and i32 %arg, 603979776 ; some bit within the target 4294967168 mask. call void @use32(i32 %t1)