[VectorCombine] Scalarize variable-index vector loads/stores when legal.

scalarizeLoadExtract and the load/insertelement/store fold now accept a
non-constant index. canScalarizeAccess accepts such an index only if it is
known not to be poison and computeConstantRange (using llvm.assume facts
from the AssumptionCache) proves it lies within the vector. A possibly-poison
index is rejected: the vector instruction would merely produce poison,
whereas the scalarized inbounds GEP + load/store would be undefined behavior.

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
@@ -60,8 +61,8 @@
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
 
   bool run();
 
@@ -71,6 +72,7 @@
   const TargetTransformInfo &TTI;
   const DominatorTree &DT;
   AAResults &AA;
+  AssumptionCache &AC;
 
   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
@@ -774,8 +776,29 @@
 
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) {
-  return Idx->getValue().ult(VecTy->getNumElements());
+///
+/// A non-constant \p Idx is accepted only if it is known not to be poison at
+/// \p CtxI and its value range, computed with assumptions from \p AC, lies
+/// entirely within [0, number of elements of \p VecTy).
+static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
+                               Instruction *CtxI, AssumptionCache &AC,
+                               const DominatorTree &DT) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx))
+    return C->getValue().ult(VecTy->getNumElements());
+
+  // A poison index merely makes insertelement/extractelement produce poison,
+  // but it would make the scalarized inbounds GEP produce a poison pointer,
+  // and a load/store through that pointer is UB. Reject such indices.
+  if (!isGuaranteedNotToBePoison(Idx, &AC, CtxI, &DT))
+    return false;
+
+  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+  APInt Zero(IntWidth, 0);
+  APInt MaxElts(IntWidth, VecTy->getNumElements());
+  ConstantRange ValidIndices(Zero, MaxElts);
+  ConstantRange IdxRange =
+      computeConstantRange(Idx, /*UseInstrInfo=*/true, &AC, CtxI);
+  return ValidIndices.contains(IdxRange);
 }
 
 // Combine patterns like:
@@ -796,10 +819,10 @@
   // TargetTransformInfo.
   Instruction *Source;
   Value *NewElement;
-  ConstantInt *Idx;
+  Value *Idx;
   if (!match(SI->getValueOperand(),
              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
-                         m_ConstantInt(Idx))))
+                         m_Value(Idx))))
     return false;
 
   if (auto *Load = dyn_cast<LoadInst>(Source)) {
@@ -810,7 +833,7 @@
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx) ||
+        !canScalarizeAccess(VecTy, Idx, Load, AC, DT) ||
         SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
@@ -835,8 +858,8 @@
 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  ConstantInt *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx))))
+  Value *Idx;
+  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
     return false;
 
   auto *LI = cast<LoadInst>(I.getOperand(0));
@@ -848,7 +871,7 @@
   if (!FixedVT)
     return false;
 
-  if (!canScalarizeAccess(FixedVT, Idx))
+  if (!canScalarizeAccess(FixedVT, Idx, &I, AC, DT))
     return false;
 
   InstructionCost OriginalCost = TTI.getMemoryOpCost(
@@ -962,6 +985,7 @@
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
@@ -976,10 +1000,11 @@
   bool runOnFunction(Function &F) override {
     if (skipFunction(F))
       return false;
+    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA);
+    VectorCombine Combiner(F, TTI, DT, AA, AC);
     return Combiner.run();
   }
 };
@@ -989,6 +1014,7 @@
 
 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                       "Optimize scalar/vector ops", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                     "Optimize scalar/vector ops", false, false)
@@ -998,10 +1024,11 @@
 
 PreservedAnalyses VectorCombinePass::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
+  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA);
+  VectorCombine Combiner(F, TTI, DT, AA, AC);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+; RUN: opt -vector-combine -enable-new-pm=false -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
 
 define i32 @load_extract_idx_0(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_0(
@@ -95,8 +96,8 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
diff --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
--- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll
+++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
@@ -130,9 +130,8 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry: