diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
@@ -3335,6 +3336,70 @@
   return nullptr;
 }
 
+static Instruction *
+foldICmpUSubSatWithConstant(bool ConstantComparisonWithZero,
+                            ICmpInst::Predicate Pred, IntrinsicInst *II,
+                            const APInt &C, InstCombiner::BuilderTy &Builder) {
+  // Let Y = usub.sat(X, C) pred C2
+  //  => Y = (X < C ? 0 : (X - C)) pred C2
+  //  => Y = (X < C) ? (0 pred C2) : ((X - C) pred C2)
+  //
+  // When (0 pred C2) [ConstantComparisonWithZero] is true, then
+  //    Y = (X < C) ? true : ((X - C) pred C2)
+  // => Y = (X < C) || ((X - C) pred C2)
+  // else
+  //    Y = (X < C) ? false : ((X - C) pred C2)
+  // => Y = !(X < C) && ((X - C) pred C2)
+  // => Y = (X >= C) && ((X - C) pred C2)
+  Value *Op0 = II->getOperand(0);
+  Value *Op1 = II->getOperand(1);
+
+  auto [NewPred, LogicalOp] =
+      ConstantComparisonWithZero
+          ? std::make_pair(ICmpInst::ICMP_ULT, Instruction::BinaryOps::Or)
+          : std::make_pair(ICmpInst::ICMP_UGE, Instruction::BinaryOps::And);
+
+  Constant *COp1 = nullptr;
+  if (match(Op1, m_ImmConstant(COp1))) {
+    // Try to see if we can generate a single icmp directly.
+    if (isa<ConstantInt>(COp1) || COp1->getSplatValue()) {
+      const APInt &COp1Int = COp1->getUniqueInteger();
+      ConstantRange C1 = ConstantRange::makeExactICmpRegion(NewPred, COp1Int);
+      // Convert '(X - C) pred C2' into 'X pred (C + C2)'.
+      ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C);
+      C2 = C2.add(COp1Int);
+
+      std::optional<ConstantRange> Combination = std::nullopt;
+      if (LogicalOp == Instruction::BinaryOps::Or)
+        Combination = C1.exactUnionWith(C2);
+      else /* LogicalOp == Instruction::BinaryOps::And */
+        Combination = C1.exactIntersectWith(C2);
+
+      if (Combination.has_value()) {
+        CmpInst::Predicate EquivalentPred;
+        APInt EquivalentInt;
+        Combination->getEquivalentICmp(EquivalentPred, EquivalentInt);
+
+        return new ICmpInst(EquivalentPred, Op0,
+                            ConstantInt::get(COp1->getType(), EquivalentInt));
+      }
+    }
+
+    // Two uses of X can cause issues with undef, so freeze the value if it is
+    // not known to be noundef.
+    if (!isGuaranteedNotToBeUndefOrPoison(Op0))
+      Op0 = Builder.CreateFreeze(Op0);
+
+    // Fall back to generating the two icmps and the logical op.
+    Constant *CondConstant = ConstantInt::get(COp1->getType(), C);
+    return BinaryOperator::Create(
+        LogicalOp, Builder.CreateICmp(NewPred, Op0, Op1),
+        Builder.CreateICmp(Pred, Builder.CreateSub(Op0, Op1), CondConstant));
+  }
+
+  return nullptr;
+}
+
 /// Fold an equality icmp with LLVM intrinsic and constant operand.
 Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
     ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
@@ -3427,12 +3492,35 @@
       return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1));
     break;
   case Intrinsic::usub_sat: {
-    // usub.sat(a, b) == 0 -> a <= b
+    // These transforms may end up producing more than one instruction for the
+    // intrinsic, so limit them to one user.
+    if (!II->hasOneUse())
+      break;
+
+    // This is a special case of foldICmpIntrinsicWithConstant.
+    //
+    // Here we have two cases:
+    //  - usub.sat(a, b) == c (or)
+    //  - usub.sat(a, b) != c
+    //
+    // When c == 0, this simplifies to
+    //  - usub.sat(a, b) == 0 -> a <= b
+    //  - usub.sat(a, b) != 0 -> a > b
+    // else we have:
+    //  - usub.sat(a, b) == c -> (a >= b) && (a - b) == c (as 0 == c is false)
+    //  - usub.sat(a, b) != c -> (a < b) || (a - b) != c (as 0 != c is true)
+    // which is handled by foldICmpUSubSatWithConstant.
+
     if (C.isZero()) {
       ICmpInst::Predicate NewPred =
           Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
       return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
     }
+
+    if (auto *Folded = foldICmpUSubSatWithConstant(Pred != ICmpInst::ICMP_EQ,
+                                                   Pred, II, C, Builder))
+      return Folded;
+
     break;
   }
   default:
@@ -3657,6 +3745,19 @@
                         II->getArgOperand(1));
     }
     break;
+  case Intrinsic::usub_sat: {
+    // These transforms may end up producing more than one instruction for the
+    // intrinsic, so limit them to one user.
+    if (!II->hasOneUse())
+      break;
+
+    if (auto *Folded = foldICmpUSubSatWithConstant(
+            ICmpInst::compare(APInt::getZero(C.getBitWidth()), C, Pred), Pred,
+            II, C, Builder))
+      return Folded;
+
+    break;
+  }
   default:
     break;
   }
diff --git a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
--- a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
@@ -10,8 +10,7 @@
 define i1 @icmp_eq_basic_positive(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_eq_basic_positive
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 2)
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SUB]], 5
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[ARG]], 7
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 2)
@@ -22,8 +21,7 @@
 define i1 @icmp_ne_basic_positive(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ne_basic_positive
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 8)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[SUB]], 9
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[ARG]], 17
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 8)
@@ -34,8 +32,7 @@
 define i1 @icmp_ule_basic_positive(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ule_basic_positive
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 6)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], 10
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 6)
@@ -57,8 +54,7 @@
 define i1 @icmp_uge_basic_positive(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_uge_basic_positive
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 4)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[SUB]], 3
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[ARG]], 7
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 4)
@@ -69,8 +65,7 @@
 define i1 @icmp_ugt_basic_positive(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ugt_basic_positive
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 1)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[SUB]], 3
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[ARG]], 4
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 1)
@@ -81,8 +76,7 @@
 define i1 @icmp_sle_basic_positive(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sle_basic_positive
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 10)
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SUB]], 9
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], -2147483639
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 10)
@@ -93,8 +87,7 @@
 define i1 @icmp_slt_basic_positive(i64 %arg) {
 ; CHECK-LABEL: define i1 @icmp_slt_basic_positive
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 24)
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[SUB]], 5
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -9223372036854775803
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 24)
@@ -105,8 +98,7 @@
 define i1 @icmp_sge_basic_positive(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sge_basic_positive
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 1)
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[SUB]], 3
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[ARG]], 124
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 1)
@@ -117,8 +109,7 @@
 define i1 @icmp_sgt_basic_positive(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sgt_basic_positive
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 2)
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i16 [[SUB]], 5
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[ARG]], 32762
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 2)
@@ -152,8 +143,7 @@
 define i1 @icmp_ule_basic_negative(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ule_basic_negative
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -6)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], -2
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -6)
@@ -195,8 +185,7 @@
 define i1 @icmp_sle_basic_negative(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sle_basic_negative
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -10)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 9
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[ARG]], -1
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -10)
@@ -207,8 +196,7 @@
 define i1 @icmp_slt_basic_negative(i64 %arg) {
 ; CHECK-LABEL: define i1 @icmp_slt_basic_negative
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 -24)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[SUB]], 5
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -19
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 -24)
@@ -297,7 +285,7 @@
 ; CHECK-LABEL: define i1 @icmp_ult_multiuse_positive
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
 ; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 5)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], 6
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[SUB]], 0
 ; CHECK-NEXT: [[TMP1:%.*]] = and i64 [[SUB]], 1
 ; CHECK-NEXT: [[SUB_TR:%.*]] = icmp ne i64 [[TMP1]], 0
 ; CHECK-NEXT: [[ADD_NARROW:%.*]] = xor i1 [[SUB_TR]], [[CMP]]
@@ -474,7 +462,7 @@
 ; CHECK-LABEL: define i1 @icmp_ult_multiuse_negative
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
 ; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 -5)
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -4
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[SUB]], 0
 ; CHECK-NEXT: [[TMP1:%.*]] = and i64 [[SUB]], 1
 ; CHECK-NEXT: [[SUB_TR:%.*]] = icmp ne i64 [[TMP1]], 0
 ; CHECK-NEXT: [[ADD_NARROW:%.*]] = xor i1 [[SUB_TR]], [[CMP]]
@@ -586,8 +574,7 @@
 define <2 x i1> @icmp_eq_vector_positive_equal(<2 x i8> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_positive_equal
 ; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> )
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[SUB]],
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]],
 ; CHECK-NEXT: ret <2 x i1> [[CMP]]
 ;
   %sub = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> %arg, <2 x i8> )
@@ -662,8 +649,7 @@
 define <2 x i1> @icmp_ne_vector_positive_equal(<2 x i16> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_positive_equal
 ; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> [[ARG]], <2 x i16> )
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[SUB]],
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[ARG]],
 ; CHECK-NEXT: ret <2 x i1> [[CMP]]
 ;
   %sub = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> %arg, <2 x i16> )
@@ -686,8 +672,7 @@
 define <2 x i1> @icmp_ule_vector_positive_equal(<2 x i32> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_positive_equal
 ; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> [[ARG]], <2 x i32> )
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[SUB]],
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[ARG]],
 ; CHECK-NEXT: ret <2 x i1> [[CMP]]
 ;
   %sub = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> %arg, <2 x i32> )
@@ -710,8 +695,7 @@
 define <2 x i1> @icmp_sgt_vector_positive_equal(<2 x i64> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_positive_equal
 ; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> [[ARG]], <2 x i64> )
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <2 x i64> [[SUB]],
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i64> [[ARG]],
 ; CHECK-NEXT: ret <2 x i1> [[CMP]]
 ;
   %sub = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> %arg, <2 x i64> )
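
; Illustrative sketch (not part of the patch or its test file): when the
; single-icmp fold in foldICmpUSubSatWithConstant does not apply -- for
; example when the usub.sat operand is a non-splat vector constant -- the
; fallback path freezes the operand and emits two icmps joined by the logical
; op. Function and value names below are hypothetical.

declare <2 x i8> @llvm.usub.sat.v2i8(<2 x i8>, <2 x i8>)

define <2 x i1> @fallback_sketch(<2 x i8> %x) {
  %sub = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> %x, <2 x i8> <i8 3, i8 4>)
  %cmp = icmp eq <2 x i8> %sub, <i8 1, i8 1>
  ret <2 x i1> %cmp
}

; Expected immediate output of the helper, following
; "(x >= b) && ((x - b) == c)"; later InstCombine iterations may simplify it
; further:
;   %fr  = freeze <2 x i8> %x
;   %c1  = icmp uge <2 x i8> %fr, <i8 3, i8 4>
;   %d   = sub <2 x i8> %fr, <i8 3, i8 4>
;   %c2  = icmp eq <2 x i8> %d, <i8 1, i8 1>
;   %cmp = and <2 x i1> %c1, %c2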