diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1358,19 +1358,28 @@ /// Fold equality-comparison between zero and any (maybe truncated) right-shift /// by one-less-than-bitwidth into a sign test on the original value. -Instruction *foldSignBitTest(ICmpInst &I) { +Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { + Instruction *Val; ICmpInst::Predicate Pred; - Value *X; - Constant *C; - if (!I.isEquality() || - !match(&I, m_ICmp(Pred, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))), - m_Zero()))) + if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero()))) return nullptr; - Type *XTy = X->getType(); - unsigned XBitWidth = XTy->getScalarSizeInBits(); - if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, - APInt(XBitWidth, XBitWidth - 1)))) + Value *X; + Type *XTy; + + Constant *C; + if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) { + XTy = X->getType(); + unsigned XBitWidth = XTy->getScalarSizeInBits(); + if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(XBitWidth, XBitWidth - 1)))) + return nullptr; + } else if (isa<BinaryOperator>(Val) && + (X = reassociateShiftAmtsOfTwoSameDirectionShifts( + cast<BinaryOperator>(Val), SQ.getWithInstruction(Val), + /*AnalyzeForSignBitExtraction=*/true))) { + XTy = X->getType(); + } else return nullptr; return ICmpInst::Create(Instruction::ICmp, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -390,6 +390,9 @@ Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); + Value *reassociateShiftAmtsOfTwoSameDirectionShifts( 
BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction = false); Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract( BinaryOperator &OldAShr); Instruction *visitAShr(BinaryOperator &I); @@ -912,6 +915,7 @@ Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); + Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -25,10 +25,12 @@ // we should rewrite it as // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) // This is valid for any shift, but they must be identical. -static Instruction * -reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, - const SimplifyQuery &SQ, - InstCombiner::BuilderTy &Builder) { +// +// AnalyzeForSignBitExtraction indicates that we will only analyze whether this +// pattern has any 2 right-shifts that sum to 1 less than original bit width. +Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction) { // Look for a shift of some instruction, ignore zext of shift amount if any. Instruction *Sh0Op0; Value *ShAmt0; @@ -56,14 +58,25 @@ if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; - // The shift opcodes must be identical. + // We are only looking for signbit extraction if we have two right shifts. + bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && + match(Sh1, m_Shr(m_Value(), m_Value())); + // ... and if it's not two right-shifts, we know the answer already. 
+ if (AnalyzeForSignBitExtraction && !HadTwoRightShifts) + return nullptr; + + // The shift opcodes must be identical, unless we are just checking whether + // this pattern can be interpreted as a sign-bit-extraction. Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); - if (ShiftOpcode != Sh1->getOpcode()) + bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode(); + if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction) return nullptr; // If we saw truncation, we'll need to produce extra instruction, - // and for that one of the operands of the shift must be one-use. - if (Trunc && !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + // and for that one of the operands of the shift must be one-use, + // unless of course we don't actually plan to produce any instructions here. + if (Trunc && !AnalyzeForSignBitExtraction && + !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) return nullptr; // Can we fold (ShAmt0+ShAmt1) ? @@ -80,14 +93,22 @@ return nullptr; // FIXME: could perform constant-folding. // If there was a truncation, and we have a right-shift, we can only fold if - // we are left with the original sign bit. + // we are left with the original sign bit. Likewise, if we were just checking + // that this is a signbit extraction, this is the place to check it. // FIXME: zero shift amount is also legal here, but we can't *easily* check // more than one predicate so it's not really worth it. - if (Trunc && ShiftOpcode != Instruction::BinaryOps::Shl && - !match(NewShAmt, - m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, - APInt(NewShAmtBitWidth, XBitWidth - 1)))) - return nullptr; + if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) { + // If it's not a sign bit extraction, then we're done. + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(NewShAmtBitWidth, XBitWidth - 1)))) + return nullptr; + // If it is, and that was the question, return the base value. 
+ if (AnalyzeForSignBitExtraction) + return X; + } + + assert(IdenticalShOpcodes && "Should not get here with different shifts."); // All good, we can do this fold. NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); @@ -287,8 +308,8 @@ if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; - if (Instruction *NewShift = - reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ, Builder)) + if (auto *NewShift = cast_or_null<Instruction>( + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))) return NewShift; // (C1 shift (A add C2)) -> (C1 shift C2) shift A) diff --git a/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll b/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll --- a/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll +++ b/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll @@ -45,7 +45,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -107,7 +107,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -138,7 +138,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: 
[[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 32, %nbits @@ -169,7 +169,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits @@ -200,7 +200,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 32, %nbits @@ -231,7 +231,7 @@ ; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]]) ; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]]) ; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]]) -; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0 +; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0 ; CHECK-NEXT: ret i1 [[ISNEG]] ; %num_low_bits_to_skip = sub i32 64, %nbits