diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
@@ -60,8 +61,9 @@
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {}
+                const DominatorTree &DT, AAResults &AA,
+                AssumptionCache *AC = nullptr)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
 
   bool run();
 
@@ -71,6 +73,8 @@
   const TargetTransformInfo &TTI;
   const DominatorTree &DT;
   AAResults &AA;
+  // May be null; the constructor defaults it to nullptr.
+  AssumptionCache *AC;
 
   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
@@ -774,8 +778,29 @@
 
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) {
-  return Idx->getValue().ult(VecTy->getNumElements());
+///
+/// A constant \p Idx is checked directly. A variable \p Idx is accepted only
+/// if it is guaranteed not to be poison at \p CtxI and every value in its
+/// constant range (which may use assumptions from \p AC valid at \p CtxI) is
+/// a valid element index. \p AC may be null.
+static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
+                               Instruction *CtxI, AssumptionCache *AC,
+                               const DominatorTree &DT) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx))
+    return C->getValue().ult(VecTy->getNumElements());
+
+  // With a poison index the original extractelement/insertelement merely
+  // yields poison, but the scalarized access dereferences a GEP built from
+  // that index, which is immediate UB. Bail out unless it cannot be poison.
+  if (!isGuaranteedNotToBePoison(Idx, AC, CtxI, &DT))
+    return false;
+
+  // Compare the range's unsigned maximum with the element count instead of
+  // building [0, NumElements) at the index width, which would truncate the
+  // element count whenever it does not fit in the index type.
+  ConstantRange IdxRange =
+      computeConstantRange(Idx, /*UseInstrInfo=*/true, AC, CtxI);
+  return IdxRange.getUnsignedMax().ult(VecTy->getNumElements());
 }
 
 // Combine patterns like:
@@ -796,10 +821,10 @@
   // TargetTransformInfo.
   Instruction *Source;
   Value *NewElement;
-  ConstantInt *Idx;
+  Value *Idx;
   if (!match(SI->getValueOperand(),
              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
-                         m_ConstantInt(Idx))))
+                         m_Value(Idx))))
     return false;
 
   if (auto *Load = dyn_cast<LoadInst>(Source)) {
@@ -810,7 +835,7 @@
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx) ||
+        !canScalarizeAccess(VecTy, Idx, Load, AC, DT) ||
         SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
@@ -835,8 +860,8 @@
 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  ConstantInt *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx))))
+  Value *Idx;
+  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
     return false;
 
   auto *LI = cast<LoadInst>(I.getOperand(0));
@@ -848,7 +873,7 @@
   if (!FixedVT)
     return false;
 
-  if (!canScalarizeAccess(FixedVT, Idx))
+  if (!canScalarizeAccess(FixedVT, Idx, &I, AC, DT))
     return false;
 
   InstructionCost OriginalCost = TTI.getMemoryOpCost(
@@ -998,10 +1023,11 @@
 
 PreservedAnalyses VectorCombinePass::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
+  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA);
+  VectorCombine Combiner(F, TTI, DT, AA, &AC);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;