diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3579,16 +3579,94 @@
   return foldICmpBinOpEqualityWithConstant(Cmp, BO, C);
 }
 
+/// Fold `icmp Pred (usub.sat X, C), C2` into one compare of `X + Offset`
+/// against a constant.
+///
+/// Bails out (returns nullptr) unless \p II has a single use, its second
+/// operand matches m_APInt, and the two constraints on X combine into one
+/// exact ConstantRange. Otherwise returns the replacement ICmpInst.
+static Instruction *
+foldICmpUSubSatWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *II,
+                            const APInt &C, InstCombiner::BuilderTy &Builder) {
+  // This transform may end up producing more than one instruction for the
+  // intrinsic, so limit it to one user of the intrinsic.
+  if (!II->hasOneUse())
+    return nullptr;
+
+  // Let Y = usub_sat(X, C) pred C2
+  //  => Y = (X < C ? 0 : (X - C)) pred C2
+  //  => Y = (X < C) ? (0 pred C2) : ((X - C) pred C2)
+  //
+  // When (0 pred C2) is true, then
+  //    Y = (X < C) ? true : ((X - C) pred C2)
+  // => Y = (X < C) || ((X - C) pred C2)
+  // else
+  //    Y = (X < C) ? false : ((X - C) pred C2)
+  // => Y = !(X < C) && ((X - C) pred C2)
+  // => Y = (X >= C) && ((X - C) pred C2)
+  Value *Op0 = II->getOperand(0);
+  Value *Op1 = II->getOperand(1);
+
+  // Check (0 pred C2)
+  auto [NewPred, LogicalOp] =
+      ICmpInst::compare(APInt::getZero(C.getBitWidth()), C, Pred)
+          ? std::make_pair(ICmpInst::ICMP_ULT, Instruction::BinaryOps::Or)
+          : std::make_pair(ICmpInst::ICMP_UGE, Instruction::BinaryOps::And);
+
+  const APInt *COp1;
+  // This transform only works when the usub_sat has an integral constant or
+  // splat vector as the second operand.
+  if (!match(Op1, m_APInt(COp1)))
+    return nullptr;
+
+  ConstantRange C1 = ConstantRange::makeExactICmpRegion(NewPred, *COp1);
+  // Convert '(X - C) pred C2' into 'X pred C2' shifted by C.
+  ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C);
+  C2 = C2.add(*COp1);
+
+  std::optional<ConstantRange> Combination;
+  if (LogicalOp == Instruction::BinaryOps::Or)
+    Combination = C1.exactUnionWith(C2);
+  else /* LogicalOp == Instruction::BinaryOps::And */
+    Combination = C1.exactIntersectWith(C2);
+
+  // Without one exact range there is no single compare to emit; give up.
+  if (!Combination)
+    return nullptr;
+
+  CmpInst::Predicate EquivPred;
+  APInt EquivInt;
+  APInt EquivOffset;
+
+  Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset);
+
+  return new ICmpInst(
+      EquivPred,
+      Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)),
+      ConstantInt::get(Op1->getType(), EquivInt));
+}
+
 /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
 Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
                                                              IntrinsicInst *II,
                                                              const APInt &C) {
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+
+  // Handle folds that apply for any kind of icmp.
+  switch (II->getIntrinsicID()) {
+  default:
+    break;
+  case Intrinsic::usub_sat:
+    if (auto *Folded = foldICmpUSubSatWithConstant(Pred, II, C, Builder))
+      return Folded;
+    break;
+  }
+
   if (Cmp.isEquality())
     return foldICmpEqIntrinsicWithConstant(Cmp, II, C);
 
   Type *Ty = II->getType();
   unsigned BitWidth = C.getBitWidth();
-  ICmpInst::Predicate Pred = Cmp.getPredicate();
   switch (II->getIntrinsicID()) {
   case Intrinsic::ctpop: {
     // (ctpop X > BitWidth - 1) --> X == -1
diff --git a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
--- a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
@@ -10,8 +10,7 @@
 define i1 @icmp_eq_basic_positive(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_eq_basic_positive
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 2)
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SUB]], 5
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[ARG]], 7
 ; CHECK-NEXT: ret i1 [[CMP]]
 ;
   %sub = call
i8 @llvm.usub.sat.i8(i8 %arg, i8 2) @@ -22,8 +21,7 @@ define i1 @icmp_ne_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ne_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 8) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[SUB]], 9 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[ARG]], 17 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 8) @@ -34,8 +32,7 @@ define i1 @icmp_ule_basic_positive(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_ule_basic_positive ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 6) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], 10 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 6) @@ -46,8 +43,7 @@ define i1 @icmp_ult_basic_positive(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_ult_basic_positive ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 5) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[SUB]], 2 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], 7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 5) @@ -58,8 +54,7 @@ define i1 @icmp_uge_basic_positive(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_uge_basic_positive ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 4) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[ARG]], 7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 4) @@ -70,8 +65,7 @@ define i1 @icmp_ugt_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ugt_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 1) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt 
i16 [[ARG]], 4 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 1) @@ -82,8 +76,8 @@ define i1 @icmp_sle_basic_positive(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_sle_basic_positive ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 10) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SUB]], 9 +; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[ARG]], 2147483638 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[TMP1]], -2147483639 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 10) @@ -94,8 +88,8 @@ define i1 @icmp_slt_basic_positive(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_slt_basic_positive ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 24) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[SUB]], 5 +; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[ARG]], 9223372036854775784 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[TMP1]], -9223372036854775803 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 24) @@ -106,8 +100,8 @@ define i1 @icmp_sge_basic_positive(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_sge_basic_positive ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 1) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[SUB]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[ARG]], -5 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[TMP1]], 124 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 1) @@ -118,8 +112,8 @@ define i1 @icmp_sgt_basic_positive(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_sgt_basic_positive ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 2) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i16 [[SUB]], 5 +; CHECK-NEXT: [[TMP1:%.*]] = add i16 [[ARG]], -8 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[TMP1]], 32762 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 
@llvm.usub.sat.i16(i16 %arg, i16 2) @@ -133,8 +127,7 @@ define i1 @icmp_eq_basic_negative(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_eq_basic_negative ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 -20) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SUB]], 5 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[ARG]], -15 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 -20) @@ -145,8 +138,7 @@ define i1 @icmp_ne_basic_negative(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ne_basic_negative ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 -80) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[SUB]], 9 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 [[ARG]], -71 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 -80) @@ -157,8 +149,7 @@ define i1 @icmp_ule_basic_negative(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_ule_basic_negative ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -6) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ARG]], -2 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -6) @@ -169,8 +160,7 @@ define i1 @icmp_ult_basic_negative(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_ult_basic_negative ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 -10) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[SUB]], 2 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -8 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 -10) @@ -181,8 +171,7 @@ define i1 @icmp_uge_basic_negative(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_uge_basic_negative ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 -4) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[SUB]], 1 +; CHECK-NEXT: 
[[CMP:%.*]] = icmp ugt i8 [[ARG]], -3 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 -4) @@ -193,8 +182,7 @@ define i1 @icmp_ugt_basic_negative(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_ugt_basic_negative ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[ARG]], i16 -10) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[ARG]], -7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 -10) @@ -205,8 +193,7 @@ define i1 @icmp_sle_basic_negative(i32 %arg) { ; CHECK-LABEL: define i1 @icmp_sle_basic_negative ; CHECK-SAME: (i32 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[ARG]], i32 -10) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SUB]], 9 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[ARG]], -1 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i32 @llvm.usub.sat.i32(i32 %arg, i32 -10) @@ -217,8 +204,7 @@ define i1 @icmp_slt_basic_negative(i64 %arg) { ; CHECK-LABEL: define i1 @icmp_slt_basic_negative ; CHECK-SAME: (i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[ARG]], i64 -24) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[SUB]], 5 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[ARG]], -19 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i64 @llvm.usub.sat.i64(i64 %arg, i64 -24) @@ -229,8 +215,7 @@ define i1 @icmp_sge_basic_negative(i8 %arg) { ; CHECK-LABEL: define i1 @icmp_sge_basic_negative ; CHECK-SAME: (i8 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[ARG]], i8 -10) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[SUB]], 3 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[ARG]], -7 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i8 @llvm.usub.sat.i8(i8 %arg, i8 -10) @@ -241,8 +226,7 @@ define i1 @icmp_sgt_basic_negative(i16 %arg) { ; CHECK-LABEL: define i1 @icmp_sgt_basic_negative ; CHECK-SAME: (i16 [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call i16 
@llvm.usub.sat.i16(i16 [[ARG]], i16 -20) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[SUB]], 5 +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[ARG]], -15 ; CHECK-NEXT: ret i1 [[CMP]] ; %sub = call i16 @llvm.usub.sat.i16(i16 %arg, i16 -20) @@ -290,8 +274,7 @@ define <2 x i1> @icmp_eq_vector_positive_equal(<2 x i8> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_positive_equal ; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> %arg, <2 x i8> ) @@ -314,8 +297,7 @@ define <2 x i1> @icmp_ne_vector_positive_equal(<2 x i16> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_positive_equal ; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> [[ARG]], <2 x i16> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i16> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> %arg, <2 x i16> ) @@ -338,8 +320,7 @@ define <2 x i1> @icmp_ule_vector_positive_equal(<2 x i32> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_positive_equal ; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> [[ARG]], <2 x i32> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i32> @llvm.usub.sat.v2i32(<2 x i32> %arg, <2 x i32> ) @@ -362,8 +343,8 @@ define <2 x i1> @icmp_sgt_vector_positive_equal(<2 x i64> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_positive_equal ; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> 
[[ARG]], <2 x i64> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <2 x i64> [[SUB]], +; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i64> [[ARG]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i64> [[TMP1]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> %arg, <2 x i64> ) @@ -389,8 +370,7 @@ define <2 x i1> @icmp_eq_vector_negative_equal(<2 x i8> %arg) { ; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_negative_equal ; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) { -; CHECK-NEXT: [[SUB:%.*]] = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> ) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[SUB]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %sub = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> %arg, <2 x i8> )