diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -4706,8 +4706,7 @@ return nullptr; } -static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { assert(isa(ICmp.getOperand(0)) && "Expected cast for operand 0"); auto *CastOp0 = cast(ICmp.getOperand(0)); Value *X; @@ -4716,25 +4715,37 @@ bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; bool IsSignedCmp = ICmp.isSigned(); - if (auto *CastOp1 = dyn_cast(ICmp.getOperand(1))) { - // If the signedness of the two casts doesn't agree (i.e. one is a sext - // and the other is a zext), then we can't handle this. - // TODO: This is too strict. We can handle some predicates (equality?). - if (CastOp0->getOpcode() != CastOp1->getOpcode()) - return nullptr; + + // icmp Pred (ext X), (ext Y) + Value *Y; + if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) { + bool IsZext0 = isa(ICmp.getOperand(0)); + bool IsZext1 = isa(ICmp.getOperand(1)); + + // If we have mismatched casts, treat the zext of a non-negative source as + // a sext to simulate matching casts. Otherwise, we are done. + // TODO: Can we handle some predicates (equality) without non-negative? + if (IsZext0 != IsZext1) { + if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || + (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) + IsSignedExt = true; + else + return nullptr; + } // Not an extension from the same type? - Value *Y = CastOp1->getOperand(0); Type *XTy = X->getType(), *YTy = Y->getType(); if (XTy != YTy) { // One of the casts must have one use because we are creating a new cast. - if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse()) return nullptr; // Extend the narrower operand to the type of the wider operand. + CastInst::CastOps CastOpcode = + IsSignedExt ? Instruction::SExt : Instruction::ZExt; if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) - X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + X = Builder.CreateCast(CastOpcode, X, YTy); else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) - Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + Y = Builder.CreateCast(CastOpcode, Y, XTy); else return nullptr; } @@ -4852,7 +4863,7 @@ if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) return R; - return foldICmpWithZextOrSext(ICmp, Builder); + return foldICmpWithZextOrSext(ICmp); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -661,7 +661,8 @@ Constant *RHSC); Instruction *foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred); - Instruction *foldICmpWithCastOp(ICmpInst &ICI); + Instruction *foldICmpWithCastOp(ICmpInst &ICmp); + Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp); Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); diff --git a/llvm/test/Transforms/InstCombine/icmp-ext-ext.ll b/llvm/test/Transforms/InstCombine/icmp-ext-ext.ll --- a/llvm/test/Transforms/InstCombine/icmp-ext-ext.ll +++ b/llvm/test/Transforms/InstCombine/icmp-ext-ext.ll @@ -250,9 +250,7 @@ define i1 @zext_sext_sgt_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @zext_sext_sgt_known_nonneg( ; CHECK-NEXT: [[N:%.*]] = udiv i8 127, [[X:%.*]] -; CHECK-NEXT: [[A:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[B:%.*]] = sext i8 [[Y:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp sgt i8 [[N]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %n = udiv i8 127, %x @@ -265,9 +263,7 @@ define i1 @zext_sext_ugt_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @zext_sext_ugt_known_nonneg( ; CHECK-NEXT: [[N:%.*]] = and i8 [[X:%.*]], 127 -; CHECK-NEXT: [[A:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[B:%.*]] = sext i8 [[Y:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp ugt i8 [[N]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %n = and i8 %x, 127 @@ -280,9 +276,7 @@ define i1 @zext_sext_eq_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @zext_sext_eq_known_nonneg( ; CHECK-NEXT: [[N:%.*]] = lshr i8 [[X:%.*]], 1 -; CHECK-NEXT: [[A:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[B:%.*]] = sext i8 [[Y:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[N]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %n = lshr i8 %x, 1 @@ -295,9 +289,8 @@ define i1 @zext_sext_sle_known_nonneg_op0_narrow(i8 %x, i16 %y) { ; CHECK-LABEL: @zext_sext_sle_known_nonneg_op0_narrow( ; CHECK-NEXT: [[N:%.*]] = and i8 [[X:%.*]], 12 -; CHECK-NEXT: [[A:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[B:%.*]] = sext i16 [[Y:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp sle i32 [[A]], [[B]] +; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[N]] to i16 +; CHECK-NEXT: [[C:%.*]] = icmp sle i16 [[TMP1]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %n = and i8 %x, 12 @@ -310,9 +303,8 @@ define i1 @zext_sext_ule_known_nonneg_op0_wide(i9 %x, i8 %y) { ; CHECK-LABEL: @zext_sext_ule_known_nonneg_op0_wide( ; CHECK-NEXT: [[N:%.*]] = urem i9 [[X:%.*]], 254 -; CHECK-NEXT: [[A:%.*]] = zext i9 [[N]] to i32 -; CHECK-NEXT: [[B:%.*]] = sext i8 [[Y:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp ule i32 [[A]], [[B]] +; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i9 +; CHECK-NEXT: [[C:%.*]] = icmp ule i9 [[N]], [[TMP1]] ; CHECK-NEXT: ret i1 [[C]] ; %n = urem i9 %x, 254 @@ -324,10 +316,8 @@ define i1 @sext_zext_slt_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @sext_zext_slt_known_nonneg( -; CHECK-NEXT: [[A:%.*]] = sext i8 [[X:%.*]] to i32 ; CHECK-NEXT: [[N:%.*]] = and i8 [[Y:%.*]], 126 -; CHECK-NEXT: [[B:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp slt i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp sgt i8 [[N]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %a = sext i8 %x to i32 @@ -339,10 +329,8 @@ define i1 @sext_zext_ult_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @sext_zext_ult_known_nonneg( -; CHECK-NEXT: [[A:%.*]] = sext i8 [[X:%.*]] to i32 ; CHECK-NEXT: [[N:%.*]] = lshr i8 [[Y:%.*]], 6 -; CHECK-NEXT: [[B:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp ugt i8 [[N]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %a = sext i8 %x to i32 @@ -354,10 +342,8 @@ define i1 @sext_zext_ne_known_nonneg(i8 %x, i8 %y) { ; CHECK-LABEL: @sext_zext_ne_known_nonneg( -; CHECK-NEXT: [[A:%.*]] = sext i8 [[X:%.*]] to i32 ; CHECK-NEXT: [[N:%.*]] = udiv i8 [[Y:%.*]], 6 -; CHECK-NEXT: [[B:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[A]], [[B]] +; CHECK-NEXT: [[C:%.*]] = icmp ne i8 [[N]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %a = sext i8 %x to i32 @@ -369,10 +355,9 @@ define <2 x i1> @sext_zext_sge_known_nonneg_op0_narrow(<2 x i5> %x, <2 x i8> %y) { ; CHECK-LABEL: @sext_zext_sge_known_nonneg_op0_narrow( -; CHECK-NEXT: [[A:%.*]] = sext <2 x i5> [[X:%.*]] to <2 x i32> ; CHECK-NEXT: [[N:%.*]] = mul nsw <2 x i8> [[Y:%.*]], [[Y]] -; CHECK-NEXT: [[B:%.*]] = zext <2 x i8> [[N]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = icmp sge <2 x i32> [[A]], [[B]] +; CHECK-NEXT: [[TMP1:%.*]] = sext <2 x i5> [[X:%.*]] to <2 x i8> +; CHECK-NEXT: [[C:%.*]] = icmp sle <2 x i8> [[N]], [[TMP1]] ; CHECK-NEXT: ret <2 x i1> [[C]] ; %a = sext <2 x i5> %x to <2 x i32> @@ -384,10 +369,9 @@ define i1 @sext_zext_uge_known_nonneg_op0_wide(i16 %x, i8 %y) { ; CHECK-LABEL: @sext_zext_uge_known_nonneg_op0_wide( -; CHECK-NEXT: [[A:%.*]] = sext i16 [[X:%.*]] to i32 ; CHECK-NEXT: [[N:%.*]] = and i8 [[Y:%.*]], 12 -; CHECK-NEXT: [[B:%.*]] = zext i8 [[N]] to i32 -; CHECK-NEXT: [[C:%.*]] = icmp uge i32 [[A]], [[B]] +; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[N]] to i16 +; CHECK-NEXT: [[C:%.*]] = icmp ule i16 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[C]] ; %a = sext i16 %x to i32