diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -497,6 +497,10 @@
   /// LCSSA PHIs have been created, return the LCSSA PHI available at \p User.
   /// If no PHIs have been created, return the unchanged operand \p OpIdx.
   Value *fixupLCSSAFormFor(Instruction *User, unsigned OpIdx);
+
+  /// Try to simplify to either \p A or \p B if the other argument is a
+  /// constant; otherwise create and return 'or A, B'.
+  Value *foldOrCreateOr(Value *A, Value *B);
 };
 
 /// Helper to remove instructions inserted during SCEV expansion, unless they
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2516,38 +2516,46 @@
   // And select either 1. or 2. depending on whether step is positive or
   // negative. If Step is known to be positive or negative, only create
   // either 1. or 2.
-  Value *Add = nullptr, *Sub = nullptr;
-  bool NeedPosCheck = !SE.isKnownNegative(Step);
-  bool NeedNegCheck = !SE.isKnownPositive(Step);
-
-  if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) {
-    StartValue = InsertNoopCastOfTo(
-        StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace()));
-    Value *NegMulV = Builder.CreateNeg(MulV);
-    if (NeedPosCheck)
-      Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV);
-    if (NeedNegCheck)
-      Sub = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, NegMulV);
-  } else {
+  auto ComputeEndCheck = [&]() -> Value * {
+    // Checking <u 0 is always false.
+    if (!Signed && Start->isZero() && SE.isKnownPositive(Step))
+      return ConstantInt::getFalse(Loc->getContext());
+
+    Value *Add = nullptr, *Sub = nullptr;
+    bool NeedPosCheck = !SE.isKnownNegative(Step);
+    bool NeedNegCheck = !SE.isKnownPositive(Step);
+
+    if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) {
+      StartValue = InsertNoopCastOfTo(
+          StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace()));
+      Value *NegMulV = Builder.CreateNeg(MulV);
+      if (NeedPosCheck)
+        Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV);
+      if (NeedNegCheck)
+        Sub = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, NegMulV);
+    } else {
+      if (NeedPosCheck)
+        Add = Builder.CreateAdd(StartValue, MulV);
+      if (NeedNegCheck)
+        Sub = Builder.CreateSub(StartValue, MulV);
+    }
+
+    Value *EndCompareLT = nullptr;
+    Value *EndCompareGT = nullptr;
+    Value *EndCheck = nullptr;
     if (NeedPosCheck)
-      Add = Builder.CreateAdd(StartValue, MulV);
+      EndCheck = EndCompareLT = Builder.CreateICmp(
+          Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
     if (NeedNegCheck)
-      Sub = Builder.CreateSub(StartValue, MulV);
-  }
-
-  Value *EndCompareLT = nullptr;
-  Value *EndCompareGT = nullptr;
-  Value *EndCheck = nullptr;
-  if (NeedPosCheck)
-    EndCheck = EndCompareLT = Builder.CreateICmp(
-        Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
-  if (NeedNegCheck)
-    EndCheck = EndCompareGT = Builder.CreateICmp(
-        Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
-  if (NeedPosCheck && NeedNegCheck) {
-    // Select the answer based on the sign of Step.
-    EndCheck = Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
-  }
+      EndCheck = EndCompareGT = Builder.CreateICmp(
+          Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
+    if (NeedPosCheck && NeedNegCheck) {
+      // Select the answer based on the sign of Step.
+      EndCheck = Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
+    }
+    return EndCheck;
+  };
+  Value *EndCheck = ComputeEndCheck();
 
   // If the backedge taken count type is larger than the AR type,
   // check that we don't drop any bits by truncating it. If we are
@@ -2560,10 +2568,10 @@
     BackedgeCheck = Builder.CreateAnd(
         BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero));
 
-    EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck);
+    EndCheck = foldOrCreateOr(EndCheck, BackedgeCheck);
   }
 
-  return Builder.CreateOr(EndCheck, OfMul);
+  return foldOrCreateOr(EndCheck, OfMul);
 }
 
 Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
@@ -2633,6 +2641,18 @@
   return User->getOperand(OpIdx);
 }
 
+Value *SCEVExpander::foldOrCreateOr(Value *A, Value *B) {
+  if (match(A, m_SpecificInt(0)))
+    return B;
+  if (match(B, m_SpecificInt(0)))
+    return A;
+  if (match(A, m_SpecificInt(1)))
+    return A;
+  if (match(B, m_SpecificInt(1)))
+    return B;
+  return Builder.CreateOr(A, B);
+}
+
 namespace {
 // Search for a SCEV subexpression that is not safe to expand. Any expression
 // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
diff --git a/llvm/test/Transforms/LoopDistribute/scev-inserted-runtime-check.ll b/llvm/test/Transforms/LoopDistribute/scev-inserted-runtime-check.ll
--- a/llvm/test/Transforms/LoopDistribute/scev-inserted-runtime-check.ll
+++ b/llvm/test/Transforms/LoopDistribute/scev-inserted-runtime-check.ll
@@ -17,11 +17,8 @@
 ; CHECK-NEXT:    [[MUL1:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[TMP1]])
 ; CHECK-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL1]], 0
 ; CHECK-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL1]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[MUL_RESULT]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult i32 [[TMP2]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ugt i64 [[TMP0]], 4294967295
-; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = or i1 [[TMP8]], [[MUL_OVERFLOW]]
+; CHECK-NEXT:    [[TMP9:%.*]] = or i1 [[TMP7]], [[MUL_OVERFLOW]]
 ; CHECK-NEXT:    [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 8, i64 [[TMP0]])
 ; CHECK-NEXT:    [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0
 ; CHECK-NEXT:    [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1
@@ -159,11 +156,8 @@
 ; CHECK-NEXT:    [[MUL1:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[TMP1]])
 ; CHECK-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL1]], 0
 ; CHECK-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL1]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[MUL_RESULT]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult i32 [[TMP2]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ugt i64 [[TMP0]], 4294967295
-; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = or i1 [[TMP8]], [[MUL_OVERFLOW]]
+; CHECK-NEXT:    [[TMP9:%.*]] = or i1 [[TMP7]], [[MUL_OVERFLOW]]
 ; CHECK-NEXT:    [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 8, i64 [[TMP0]])
 ; CHECK-NEXT:    [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0
 ; CHECK-NEXT:    [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1
diff --git a/llvm/test/Transforms/LoopVectorize/runtime-check-small-clamped-bounds.ll b/llvm/test/Transforms/LoopVectorize/runtime-check-small-clamped-bounds.ll
--- a/llvm/test/Transforms/LoopVectorize/runtime-check-small-clamped-bounds.ll
+++ b/llvm/test/Transforms/LoopVectorize/runtime-check-small-clamped-bounds.ll
@@ -20,11 +20,8 @@
 ; CHECK:       vector.scevcheck:
 ; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[TMP0]] to i2
-; CHECK-NEXT:    [[TMP2:%.*]] = add i2 0, [[TMP1]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult i2 [[TMP2]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ugt i32 [[TMP0]], 3
-; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    br i1 [[TMP8]], label [[SCALAR_PH]], label [[VECTOR_MEMCHECK:%.*]]
+; CHECK-NEXT:    br i1 [[TMP7]], label [[SCALAR_PH]], label [[VECTOR_MEMCHECK:%.*]]
 ; CHECK:       vector.memcheck:
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP11:%.*]] = zext i32 [[TMP10]] to i64
@@ -107,11 +104,8 @@
 ; CHECK:       vector.scevcheck:
 ; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[TMP0]] to i2
-; CHECK-NEXT:    [[TMP2:%.*]] = add i2 0, [[TMP1]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult i2 [[TMP2]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ugt i32 [[TMP0]], 3
-; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    br i1 [[TMP8]], label [[SCALAR_PH]], label [[VECTOR_MEMCHECK:%.*]]
+; CHECK-NEXT:    br i1 [[TMP7]], label [[SCALAR_PH]], label [[VECTOR_MEMCHECK:%.*]]
 ; CHECK:       vector.memcheck:
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP11:%.*]] = zext i32 [[TMP10]] to i64
@@ -273,11 +267,8 @@
 ; CHECK:       vector.scevcheck:
 ; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[TMP0]] to i2
-; CHECK-NEXT:    [[TMP2:%.*]] = add i2 0, [[TMP1]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult i2 [[TMP2]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ugt i32 [[TMP0]], 3
-; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    br i1 [[TMP8]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    br i1 [[TMP7]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N]], 2
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i32 [[N]], [[N_MOD_VF]]
diff --git a/llvm/test/Transforms/LoopVersioning/wrapping-pointer-versioning.ll b/llvm/test/Transforms/LoopVersioning/wrapping-pointer-versioning.ll
--- a/llvm/test/Transforms/LoopVersioning/wrapping-pointer-versioning.ll
+++ b/llvm/test/Transforms/LoopVersioning/wrapping-pointer-versioning.ll
@@ -34,11 +34,8 @@
 ; LV-NEXT:    [[MUL1:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[TMP1]])
 ; LV-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL1]], 0
 ; LV-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL1]], 1
-; LV-NEXT:    [[TMP2:%.*]] = add i32 0, [[MUL_RESULT]]
-; LV-NEXT:    [[TMP5:%.*]] = icmp ult i32 [[TMP2]], 0
 ; LV-NEXT:    [[TMP7:%.*]] = icmp ugt i64 [[TMP0]], 4294967295
-; LV-NEXT:    [[TMP8:%.*]] = or i1 [[TMP5]], [[TMP7]]
-; LV-NEXT:    [[TMP9:%.*]] = or i1 [[TMP8]], [[MUL_OVERFLOW]]
+; LV-NEXT:    [[TMP9:%.*]] = or i1 [[TMP7]], [[MUL_OVERFLOW]]
 ; LV-NEXT:    [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[TMP0]])
 ; LV-NEXT:    [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0
 ; LV-NEXT:    [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1