diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -549,6 +549,7 @@
 ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true,
                                    AssumptionCache *AC = nullptr,
                                    const Instruction *CtxI = nullptr,
+                                   const DominatorTree *DT = nullptr,
                                    unsigned Depth = 0);
 /// Return true if this function can prove that the instruction I will
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7031,6 +7031,7 @@
 ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo,
                                          AssumptionCache *AC,
                                          const Instruction *CtxI,
+                                         const DominatorTree *DT,
                                          unsigned Depth) {
   assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
@@ -7069,7 +7070,7 @@
       assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
              "must be an assume intrinsic");
-      if (!isValidAssumeForContext(I, CtxI, nullptr))
+      if (!isValidAssumeForContext(I, CtxI, DT))
         continue;
       Value *Arg = I->getArgOperand(0);
       ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
@@ -7077,7 +7078,7 @@
       if (!Cmp || Cmp->getOperand(0) != V)
         continue;
       ConstantRange RHS = computeConstantRange(Cmp->getOperand(1), UseInstrInfo,
-                                               AC, I, Depth + 1);
+                                               AC, I, DT, Depth + 1);
       CR = CR.intersectWith(
           ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
     }
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -827,7 +827,8 @@
 /// Idx. \p Idx must access a valid vector element.
 static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
                                               Value *Idx, Instruction *CtxI,
-                                              AssumptionCache &AC) {
+                                              AssumptionCache &AC,
+                                              const DominatorTree &DT) {
   if (auto *C = dyn_cast<ConstantInt>(Idx)) {
     if (C->getValue().ult(VecTy->getNumElements()))
       return ScalarizationResult::safe();
@@ -841,7 +842,7 @@
   ConstantRange IdxRange(IntWidth, true);
   if (isGuaranteedNotToBePoison(Idx, &AC)) {
-    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, 0)))
+    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, &DT)))
       return ScalarizationResult::safe();
     return ScalarizationResult::unsafe();
   }
@@ -909,7 +910,7 @@
       SrcAddr != SI->getPointerOperand()->stripPointerCasts())
     return false;
-  auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC);
+  auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
   if (ScalarizableIdx.isUnsafe() ||
       isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                            MemoryLocation::get(SI), AA))
     return false;
@@ -987,7 +988,7 @@
     else if (LastCheckedInst->comesBefore(UI))
       LastCheckedInst = UI;
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC);
+    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
     if (!ScalarIdx.isSafe()) {
       // TODO: Freeze index if it is safe to do so.
       return false;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extract-insert-store-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extract-insert-store-scalarization.ll
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extract-insert-store-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extract-insert-store-scalarization.ll
@@ -74,8 +74,8 @@
 ; CHECK-NEXT: [[MUL:%.*]] = fmul double 2.000000e+01, [[EXT_0]]
 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <225 x double> [[LV]], i64 [[IDX_2]]
 ; CHECK-NEXT: [[SUB:%.*]] = fsub double [[EXT_1]], [[MUL]]
-; CHECK-NEXT: [[INS:%.*]] = insertelement <225 x double> [[LV]], double [[SUB]], i64 [[IDX_1]]
-; CHECK-NEXT: store <225 x double> [[INS]], <225 x double>* [[A]], align 8
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[A]], i64 0, i64 [[IDX_1]]
+; CHECK-NEXT: store double [[SUB]], double* [[TMP0]], align 8
 ; CHECK-NEXT: [[C_2:%.*]] = call i1 @cond()
 ; CHECK-NEXT: br i1 [[C_2]], label [[LOOP]], label [[EXIT]]
 ; CHECK: exit:
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -114,9 +114,9 @@
 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
 ; CHECK-NEXT: br i1 [[C_1:%.*]], label [[LOOP:%.*]], label [[EXIT:%.*]]
 ; CHECK: loop:
-; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
 ; CHECK-NEXT: call void @maythrow()
-; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT: [[R:%.*]] = load i32, i32* [[TMP0]], align 4
 ; CHECK-NEXT: [[C_2:%.*]] = call i1 @cond()
 ; CHECK-NEXT: br i1 [[C_2]], label [[LOOP]], label [[EXIT]]
 ; CHECK: exit:
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -2103,7 +2103,7 @@
     // Check the depth cutoff results in a conservative result (full set) by
     // passing Depth == MaxDepth == 6.
-    ConstantRange CR3 = computeConstantRange(X2, true, &AC, I, 6);
+    ConstantRange CR3 = computeConstantRange(X2, true, &AC, I, nullptr, 6);
     EXPECT_TRUE(CR3.isFullSet());
   }
   {
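
Note (not part of the patch): the minimal sketch below illustrates, under the assumption of the post-patch signature shown above, how a caller that has a DominatorTree could query an index range. Passing &DT lets computeConstantRange accept llvm.assume calls that dominate the context instruction from another block via isValidAssumeForContext; with a null DominatorTree that check is largely limited to assumes in the same block as the context instruction. The helper indexKnownInBounds and its parameters are hypothetical, not an LLVM API.

// Hypothetical usage sketch -- not part of this patch or of the LLVM API.
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Returns true if Idx is known to lie in [0, NumElts) at CtxI. NumElts is
// assumed to be non-zero and representable in Idx's integer type.
static bool indexKnownInBounds(Value *Idx, unsigned NumElts,
                               AssumptionCache &AC, Instruction *CtxI,
                               const DominatorTree &DT) {
  // Forward the dominator tree so computeConstantRange can use assumes that
  // dominate CtxI from other blocks, mirroring canScalarizeAccess above.
  ConstantRange IdxRange =
      computeConstantRange(Idx, /*UseInstrInfo=*/true, &AC, CtxI, &DT);
  unsigned BitWidth = IdxRange.getBitWidth();
  ConstantRange ValidRange(APInt(BitWidth, 0), APInt(BitWidth, NumElts));
  return ValidRange.contains(IdxRange);
}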