diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3580,8 +3580,9 @@
 }
 
 static Instruction *
-foldICmpUSubSatWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *II,
-                            const APInt &C, InstCombiner::BuilderTy &Builder) {
+foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
+                                     IntrinsicInst *II, const APInt &C,
+                                     InstCombiner::BuilderTy &Builder) {
   // This transform may end up producing more than one instruction for the
   // intrinsic, so limit it to one user of the intrinsic.
   if (!II->hasOneUse())
@@ -3598,25 +3599,62 @@
   // Y = (X < C) ? false : ((X - C) pred C2)
   //  => Y = !(X < C) && ((X - C) pred C2)
   //  => Y = (X >= C) && ((X - C) pred C2)
+  //======================================
+  // Let Y = uadd_sat(X, C) pred C2
+  //  => Y = ((X + C < UINT_MAX) ? (X + C) : UINT_MAX) pred C2
+  // Now,
+  // Let Z = (X + C < UINT_MAX)
+  // Then Z = (X < UINT_MAX - C)
+  //  => Z = (X < (UINT_MAX + ~C + 1))  [as -C == (~C + 1)]
+  //  => Z = (X < (~C + 0))             [as (UINT_MAX + 1) == 0]
+  //  => Z = (X < ~C)
+  // Therefore
+  // Y = ((X < ~C) ? (X + C) : UINT_MAX) pred C2
+  //  => Y = (X < ~C) ? ((X + C) pred C2) : (UINT_MAX pred C2)
+  //
+  // When (UINT_MAX pred C2) is true, then
+  //    Y = (X < ~C) ? ((X + C) pred C2) : true
+  //  => Y = !(X < ~C) ? true : ((X + C) pred C2)
+  //  => Y = !(X < ~C) || ((X + C) pred C2)
+  //  => Y = (X >= ~C) || ((X + C) pred C2)
+  // else
+  //    Y = (X < ~C) ? ((X + C) pred C2) : false
+  //  => Y = (X < ~C) && ((X + C) pred C2)
   Value *Op0 = II->getOperand(0);
   Value *Op1 = II->getOperand(1);
 
-  // Check (0 pred C2)
-  auto [NewPred, LogicalOp] =
-      ICmpInst::compare(APInt::getZero(C.getBitWidth()), C, Pred)
-          ? std::make_pair(ICmpInst::ICMP_ULT, Instruction::BinaryOps::Or)
-          : std::make_pair(ICmpInst::ICMP_UGE, Instruction::BinaryOps::And);
-
   const APInt *COp1;
-  // This transform only works when the usub_sat has an integral constant or
+  // This transform only works when the intrinsic has an integral constant or
   // splat vector as the second operand.
   if (!match(Op1, m_APInt(COp1)))
     return nullptr;
 
-  ConstantRange C1 = ConstantRange::makeExactICmpRegion(NewPred, *COp1);
-  // Convert '(X - C) pred C2' into 'X pred C2' shifted by C.
+  bool IntrinsicIsUAddSat = II->getIntrinsicID() == Intrinsic::uadd_sat;
+  APInt SaturatingValue = IntrinsicIsUAddSat
+                              ? APInt::getAllOnes(C.getBitWidth())
+                              : APInt::getZero(C.getBitWidth());
+
+  // Check either (0 pred C2) or (UINT_MAX pred C2)
+  bool SaturatingValueCheck = ICmpInst::compare(SaturatingValue, C, Pred);
+
+  ICmpInst::Predicate LHSPred;
+  if (IntrinsicIsUAddSat)
+    LHSPred = SaturatingValueCheck ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
+  else
+    LHSPred = SaturatingValueCheck ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
+
+  Instruction::BinaryOps LogicalOp = SaturatingValueCheck
+                                         ? Instruction::BinaryOps::Or
+                                         : Instruction::BinaryOps::And;
+
+  ConstantRange C1 = ConstantRange::makeExactICmpRegion(
+      LHSPred, IntrinsicIsUAddSat ? ~*COp1 : *COp1);
   ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C);
-  C2 = C2.add(*COp1);
+
+  if (IntrinsicIsUAddSat)
+    C2 = C2.sub(*COp1);
+  else
+    C2 = C2.add(*COp1);
 
   std::optional<ConstantRange> Combination;
   if (LogicalOp == Instruction::BinaryOps::Or)
@@ -3649,8 +3687,10 @@
   switch (II->getIntrinsicID()) {
   default:
     break;
+  case Intrinsic::uadd_sat:
   case Intrinsic::usub_sat:
-    if (auto *Folded = foldICmpUSubSatWithConstant(Pred, II, C, Builder))
+    if (auto *Folded =
+            foldICmpUSubSatOrUAddSatWithConstant(Pred, II, C, Builder))
       return Folded;
     break;
   }
diff --git a/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll b/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
@@ -0,0 +1,262 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Tests for InstCombineCompares.cpp::foldICmpUSubSatOrUAddSatWithConstant
+; - uadd_sat case
+
+; ==============================================================================
+; Basic tests with one user
+; ==============================================================================
+define i1 @icmp_eq_basic(i8 %arg) {
+; CHECK-LABEL: define i1 @icmp_eq_basic
+; CHECK-SAME: (i8 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[ARG]], 3
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 2)
+  %cmp = icmp eq i8 %add, 5
+  ret i1 %cmp
+}
+
+define i1 @icmp_ne_basic(i16 %arg) {
+; CHECK-LABEL: define i1 @icmp_ne_basic
+; CHECK-SAME: (i16 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i16 [[ARG]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 8)
+  %cmp = icmp ne i16 %add, 9
+  ret i1 %cmp
+}
+
+define i1 @icmp_ule_basic(i32 %arg) {
+; CHECK-LABEL: define i1 @icmp_ule_basic
+; CHECK-SAME: (i32 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ARG]], 2
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i32 @llvm.uadd.sat.i32(i32 %arg, i32 2)
+  %cmp = icmp ule i32 %add, 3
+  ret i1 %cmp
+}
+
+define i1 @icmp_ult_basic(i64 %arg) {
+; CHECK-LABEL: define i1 @icmp_ult_basic
+; CHECK-SAME: (i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[ARG]], 15
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i64 @llvm.uadd.sat.i64(i64 %arg, i64 5)
+  %cmp = icmp ult i64 %add, 20
+  ret i1 %cmp
+}
+
+define i1 @icmp_uge_basic(i8 %arg) {
+; CHECK-LABEL: define i1 @icmp_uge_basic
+; CHECK-SAME: (i8 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[ARG]], 3
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 4)
+  %cmp = icmp uge i8 %add, 8
+  ret i1 %cmp
+}
+
+define i1 @icmp_ugt_basic(i16 %arg) {
+; CHECK-LABEL: define i1 @icmp_ugt_basic
+; CHECK-SAME: (i16 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i16 [[ARG]], 2
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 1)
+  %cmp = icmp ugt i16 %add, 3
+  ret i1 %cmp
+}
+
+define i1 @icmp_sle_basic(i32 %arg) {
+; CHECK-LABEL: define i1 @icmp_sle_basic
+; CHECK-SAME: (i32 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[ARG]], 2147483637
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i32 @llvm.uadd.sat.i32(i32 %arg, i32 10)
+  %cmp = icmp sle i32 %add, 8
+  ret i1 %cmp
+}
+
+define i1 @icmp_slt_basic(i64 %arg) {
+; CHECK-LABEL: define i1 @icmp_slt_basic
+; CHECK-SAME: (i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[ARG]], 9223372036854775783
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i64 @llvm.uadd.sat.i64(i64 %arg, i64 24)
+  %cmp = icmp slt i64 %add, 5
+  ret i1 %cmp
+}
+
+define i1 @icmp_sge_basic(i8 %arg) {
+; CHECK-LABEL: define i1 @icmp_sge_basic
+; CHECK-SAME: (i8 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[ARG]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[TMP1]], 124
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 1)
+  %cmp = icmp sge i8 %add, 4
+  ret i1 %cmp
+}
+
+define i1 @icmp_sgt_basic(i16 %arg) {
+; CHECK-LABEL: define i1 @icmp_sgt_basic
+; CHECK-SAME: (i16 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i16 [[ARG]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[TMP1]], 32762
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 2)
+  %cmp = icmp sgt i16 %add, 5
+  ret i1 %cmp
+}
+
+; ==============================================================================
+; Tests with more than one user
+; ==============================================================================
+define i1 @icmp_eq_multiuse(i8 %arg) {
+; CHECK-LABEL: define i1 @icmp_eq_multiuse
+; CHECK-SAME: (i8 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[ARG]], i8 2)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[ADD]], 5
+; CHECK-NEXT:    call void @use.i8(i8 [[ADD]])
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 2)
+  %cmp = icmp eq i8 %add, 5
+  call void @use.i8(i8 %add)
+  ret i1 %cmp
+}
+
+; ==============================================================================
+; Tests with vector types
+; ==============================================================================
+define <2 x i1> @icmp_eq_vector_equal(<2 x i8> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_equal
+; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> %arg, <2 x i8> )
+  %cmp = icmp eq <2 x i8> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_eq_vector_unequal(<2 x i8> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_unequal
+; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> )
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[ADD]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> %arg, <2 x i8> )
+  %cmp = icmp eq <2 x i8> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_ne_vector_equal(<2 x i16> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_equal
+; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i16> [[ARG]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> %arg, <2 x i16> )
+  %cmp = icmp ne <2 x i16> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_ne_vector_unequal(<2 x i16> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_unequal
+; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> [[ARG]], <2 x i16> )
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i16> [[ADD]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> %arg, <2 x i16> )
+  %cmp = icmp ne <2 x i16> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_ule_vector_equal(<2 x i32> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_equal
+; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[ARG]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32> %arg, <2 x i32> )
+  %cmp = icmp ult <2 x i32> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_ule_vector_unequal(<2 x i32> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_unequal
+; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32> [[ARG]], <2 x i32> )
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[ADD]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32> %arg, <2 x i32> )
+  %cmp = icmp ult <2 x i32> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_sgt_vector_equal(<2 x i64> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_equal
+; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i64> [[ARG]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %arg, <2 x i64> )
+  %cmp = icmp sgt <2 x i64> %add,
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @icmp_sgt_vector_unequal(<2 x i64> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_unequal
+; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[ARG]], <2 x i64> )
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i64> [[ADD]],
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %arg, <2 x i64> )
+  %cmp = icmp sgt <2 x i64> %add,
+  ret <2 x i1> %cmp
+}
+
+; ==============================================================================
+; Tests with vector types and multiple uses
+; ==============================================================================
+define <2 x i1> @icmp_eq_vector_multiuse_equal(<2 x i8> %arg) {
+; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_multiuse_equal
+; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> )
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[ADD]],
+; CHECK-NEXT:    call void @use.v2i8(<2 x i8> [[ADD]])
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %add = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> %arg, <2 x i8> )
+  %cmp = icmp eq <2 x i8> %add,
+  call void @use.v2i8(<2 x i8> %add)
+  ret <2 x i1> %cmp
+}
+
+declare i8 @llvm.uadd.sat.i8(i8, i8)
+declare i16 @llvm.uadd.sat.i16(i16, i16)
+declare i32 @llvm.uadd.sat.i32(i32, i32)
+declare i64 @llvm.uadd.sat.i64(i64, i64)
+
+declare <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64>, <2 x i64>)
+declare <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32>, <2 x i32>)
+declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>)
+declare <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8>, <2 x i8>)
+
+declare void @use.i8(i8)
+declare void @use.v2i8(<2 x i8>)
diff --git a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
--- a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
@@ -1,7 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
-; Tests for InstCombineCompares.cpp::foldICmpUSubSatWithConstant
+; Tests for InstCombineCompares.cpp::foldICmpUSubSatOrUAddSatWithConstant
+; - usub_sat case
 ; https://github.com/llvm/llvm-project/issues/58342
 
 ; ==============================================================================
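
Note (illustrative sketch, not part of the patch): the (X < ~C) step in the new comment block and the shifted-range arithmetic can be sanity-checked exhaustively at i8 width. The standalone C++ program below assumes the eq / C=2 / C2=5 instance from @icmp_eq_basic; uadd_sat8 is a hypothetical local model of llvm.uadd.sat.i8, not an LLVM API.

// Exhaustive i8 check: icmp eq (uadd_sat X, 2), 5  ==>  icmp eq X, 3
#include <cassert>
#include <cstdint>

// Hypothetical local model of llvm.uadd.sat.i8.
static uint8_t uadd_sat8(uint8_t X, uint8_t C) {
  unsigned Sum = unsigned(X) + unsigned(C);
  return Sum > 0xFF ? 0xFF : uint8_t(Sum);
}

int main() {
  const uint8_t C = 2, C2 = 5;
  for (unsigned V = 0; V <= 0xFF; ++V) {
    uint8_t X = uint8_t(V);
    // Original form vs. the folded form from @icmp_eq_basic's CHECK line.
    bool Orig = uadd_sat8(X, C) == C2;
    bool Folded = (X == uint8_t(C2 - C));
    assert(Orig == Folded);
    // Non-saturating guard derived in the comment:
    // X + C < UINT_MAX holds exactly when X < ~C (here, X < 253).
    bool NoSat = unsigned(X) + unsigned(C) < 0xFFu;
    assert(NoSat == (X < uint8_t(~C)));
  }
  return 0;
}

For this instance the new code computes C2 = [5,6) sub 2 = [3,4) and C1 = ult ~2 = [0,253); since UINT_MAX pred C2 is false for eq, the logical op is And, and the intersection [3,4) is exactly the icmp eq i8 %arg, 3 that the CHECK lines expect.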