diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -18,6 +18,7 @@ #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" @@ -3335,6 +3336,63 @@ return nullptr; } +static Instruction * +foldICmpUSubSatWithConstant(bool ConstantComparisonWithZero, + ICmpInst::Predicate Pred, IntrinsicInst *II, + const APInt &C, InstCombiner::BuilderTy &Builder) { + // These transforms may end up producing more than one instruction for the + // intrinsic, so limit them to one user. + if (!II->hasOneUse()) + return nullptr; + + // Let Y = usub.sat(X, C) pred C2 + // => Y = (X < C ? 0 : (X - C)) pred C2 + // => Y = (X < C) ? (0 pred C2) : ((X - C) pred C2) + // + // When (0 pred C2) [ConstantComparisonWithZero] is true, then + // Y = (X < C) ? true : ((X - C) pred C2) + // => Y = (X < C) || ((X - C)) pred C2 + // else + // Y = (X < C) ? false : ((X - C) pred C2) + // => Y = !(X < C) && ((X - C) pred C2) + // => Y = (X >= C) && ((X - C) pred C2) + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + + auto [NewPred, LogicalOp] = + ConstantComparisonWithZero + ? std::make_pair(ICmpInst::ICMP_ULT, Instruction::BinaryOps::Or) + : std::make_pair(ICmpInst::ICMP_UGE, Instruction::BinaryOps::And); + + const APInt *COp1 = nullptr; + if (match(Op1, m_APInt(COp1))) { + ConstantRange C1 = ConstantRange::makeExactICmpRegion(NewPred, *COp1); + // Convert '(X - C) pred C2' into 'X pred (C + C2)'. + ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C); + C2 = C2.add(*COp1); + + std::optional Combination = std::nullopt; + if (LogicalOp == Instruction::BinaryOps::Or) + Combination = C1.exactUnionWith(C2); + else /* LogicalOp == Instruction::BinaryOp::And */ + Combination = C1.exactIntersectWith(C2); + + if (Combination) { + CmpInst::Predicate EquivPred; + APInt EquivInt; + APInt EquivOffset; + Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset); + + return new ICmpInst( + EquivPred, + Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)), + ConstantInt::get(Op1->getType(), EquivInt)); + } + } + + return nullptr; +} + /// Fold an equality icmp with LLVM intrinsic and constant operand. Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { @@ -3427,12 +3485,30 @@ return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1)); break; case Intrinsic::usub_sat: { - // usub.sat(a, b) == 0 -> a <= b + // This is a special case of foldICmpIntrinsicWithConstant. + // + // Here we have two cases: + // - usub.sat(a, b) == c (or) + // - usub.sat(a, b) != c + // + // When c == 0, this simplifies to + // - usub.sat(a, b) == 0 -> a <= b + // - usub.sat(a, b) != 0 -> a > b + // else we have: + // - usub.sat(a, b) == c -> (a >= b) && (a - b) == c (as 0 == c is false) + // - usub.sat(a, b) != c -> (a < b) || (a - b) != c (as 0 != c is true) + // which is handled by foldICmpUSubSatWithConstant. + if (C.isZero()) { ICmpInst::Predicate NewPred = Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1)); } + + if (auto *Folded = foldICmpUSubSatWithConstant(Pred != ICmpInst::ICMP_EQ, + Pred, II, C, Builder)) + return Folded; + break; } default: @@ -3657,6 +3733,14 @@ II->getArgOperand(1)); } break; + case Intrinsic::usub_sat: { + if (auto *Folded = foldICmpUSubSatWithConstant( + ICmpInst::compare(APInt::getZero(C.getBitWidth()), C, Pred), Pred, + II, C, Builder)) + return Folded; + + break; + } default: break; } diff --git a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll --- a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll +++ b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll @@ -10,8 +10,7 @@ define i1 @icmp_eq_basic_positive(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_eq_basic_positive ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 2) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SUB]], 5 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[ARG]], 7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 2) @@ -22,8 +21,7 @@ define i1 @icmp_ne_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ne_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 8) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[SUB]], 9 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[ARG]], 17 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 8) @@ -34,8 +32,7 @@ define i1 @icmp_ule_basic_positive(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_ule_basic_positive ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 6) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], 10 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 6) @@ -57,8 +54,7 @@ define i1 @icmp_uge_basic_positive(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_uge_basic_positive ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 4) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[ARG]], 7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 4) @@ -69,8 +65,7 @@ define i1 @icmp_ugt_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ugt_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 1) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[ARG]], 4 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 1) @@ -81,8 +76,8 @@ define i1 @icmp_sle_basic_positive(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_sle_basic_positive ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 10) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SUB]], 9 +; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[ARG]], 2147483638 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[TMP1]], -2147483639 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 10) @@ -93,8 +88,8 @@ define i1 @icmp_slt_basic_positive(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_slt_basic_positive ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 24) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[SUB]], 5 +; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[ARG]], 9223372036854775784 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[TMP1]], -9223372036854775803 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 24) @@ -105,8 +100,8 @@ define i1 @icmp_sge_basic_positive(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_sge_basic_positive ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 1) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[SUB]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[ARG]], -5 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[TMP1]], 124 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 1) @@ -117,8 +112,8 @@ define i1 @icmp_sgt_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_sgt_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 2) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i16 [[SUB]], 5 +; CHECK-NEXT: [[TMP1:%.*]] = add i16 [[ARG]], -8 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[TMP1]], 32762 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 2) @@ -152,8 +147,7 @@ define i1 @icmp_ule_basic_negative(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_ule_basic_negative ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -6) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], -2 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -6) @@ -195,8 +189,7 @@ define i1 @icmp_sle_basic_negative(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_sle_basic_negative ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -10) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 9 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[ARG]], -1 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -10) @@ -207,8 +200,7 @@ define i1 @icmp_slt_basic_negative(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_slt_basic_negative ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 -24) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[SUB]], 5 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -19 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 -24) @@ -586,8 +578,7 @@ define <2 x i1> @icmp_eq_vector_positive_equal(<2 x i8> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_positive_equal ; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> %arg, <2 x i8> ) @@ -662,8 +653,7 @@ define <2 x i1> @icmp_ne_vector_positive_equal(<2 x i16> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_positive_equal ; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> [[ARG]], <2 x i16> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> %arg, <2 x i16> ) @@ -686,8 +676,7 @@ define <2 x i1> @icmp_ule_vector_positive_equal(<2 x i32> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_positive_equal ; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> [[ARG]], <2 x i32> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> %arg, <2 x i32> ) @@ -710,8 +699,8 @@ define <2 x i1> @icmp_sgt_vector_positive_equal(<2 x i64> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_positive_equal ; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> [[ARG]], <2 x i64> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <2 x i64> [[SUB]], +; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i64> [[ARG]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i64> [[TMP1]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> %arg, <2 x i64> )