diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
@@ -60,8 +61,8 @@
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}

   bool run();

@@ -71,6 +72,7 @@
   const TargetTransformInfo &TTI;
   const DominatorTree &DT;
   AAResults &AA;
+  AssumptionCache &AC;

   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
@@ -774,8 +776,16 @@

 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) {
-  return Idx->getValue().ult(VecTy->getNumElements());
+static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
+                               Instruction *CtxI, AssumptionCache &AC) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx))
+    return C->getValue().ult(VecTy->getNumElements());
+
+  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
+  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  ConstantRange ValidIndices(Zero, MaxElts);
+  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
+  return ValidIndices.contains(IdxRange);
 }

 // Combine patterns like:
@@ -796,10 +806,10 @@
   //       TargetTransformInfo.
   Instruction *Source;
   Value *NewElement;
-  ConstantInt *Idx;
+  Value *Idx;
   if (!match(SI->getValueOperand(),
              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
-                         m_ConstantInt(Idx))))
+                         m_Value(Idx))))
     return false;

   if (auto *Load = dyn_cast<LoadInst>(Source)) {
@@ -810,8 +820,8 @@
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx) ||
+        !canScalarizeAccess(VecTy, Idx, Load, AC) ||
         SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
       return false;
@@ -835,8 +845,8 @@
 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  ConstantInt *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx))))
+  Value *Idx;
+  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
     return false;

   auto *LI = cast<LoadInst>(I.getOperand(0));
@@ -848,7 +858,7 @@
   if (!FixedVT)
     return false;

-  if (!canScalarizeAccess(FixedVT, Idx))
+  if (!canScalarizeAccess(FixedVT, Idx, &I, AC))
     return false;

   InstructionCost OriginalCost = TTI.getMemoryOpCost(
@@ -962,6 +972,7 @@
   }

   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
@@ -976,10 +987,11 @@
   bool runOnFunction(Function &F) override {
     if (skipFunction(F))
       return false;
+    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA);
+    VectorCombine Combiner(F, TTI, DT, AA, AC);
     return Combiner.run();
   }
 };
@@ -989,6 +1001,7 @@

 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                       "Optimize scalar/vector ops", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                     "Optimize scalar/vector ops", false, false)
@@ -998,10 +1011,11 @@

 PreservedAnalyses VectorCombinePass::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
+  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA);
+  VectorCombine Combiner(F, TTI, DT, AA, AC);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
@@ -13,8 +13,8 @@
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 225
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast [225 x double]* [[A:%.*]] to <225 x double>*
-; CHECK-NEXT:    [[TMP4:%.*]] = load <225 x double>, <225 x double>* [[TMP3]], align 8
-; CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <225 x double> [[TMP4]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP3]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    [[MATRIXEXT:%.*]] = load double, double* [[TMP4]], align 8
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i32 [[I:%.*]] to i64
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP0]], [[CONV2]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = icmp ult i64 [[TMP5]], 225
@@ -25,8 +25,8 @@
 ; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[MATRIXEXT]], [[MATRIXEXT4]]
 ; CHECK-NEXT:    [[MATRIXEXT7:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[SUB:%.*]] = fsub double [[MATRIXEXT7]], [[MUL]]
-; CHECK-NEXT:    [[MATINS:%.*]] = insertelement <225 x double> [[TMP8]], double [[SUB]], i64 [[TMP1]]
-; CHECK-NEXT:    store <225 x double> [[MATINS]], <225 x double>* [[TMP7]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP7]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    store double [[SUB]], double* [[TMP9]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -112,8 +112,8 @@
 ; CHECK-NEXT:    [[TMP6:%.*]] = add nuw nsw i64 [[TMP2]], [[CONV_US]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult i64 [[TMP6]], 225
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP7]])
-; CHECK-NEXT:    [[TMP8:%.*]] = load <225 x double>, <225 x double>* [[TMP0]], align 8
-; CHECK-NEXT:    [[MATRIXEXT_US:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP0]], i64 0, i64 [[TMP6]]
+; CHECK-NEXT:    [[MATRIXEXT_US:%.*]] = load double, double* [[TMP8]], align 8
 ; CHECK-NEXT:    [[MATRIXEXT8_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[MUL_US:%.*]] = fmul double [[MATRIXEXT_US]], [[MATRIXEXT8_US]]
 ; CHECK-NEXT:    [[MATRIXEXT11_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP6]]
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+; RUN: opt -vector-combine -enable-new-pm=false -mtriple=arm64-apple-darwinos -S %s | FileCheck %s

 define i32 @load_extract_idx_0(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_0(
@@ -95,8 +96,8 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -130,8 +131,8 @@
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -160,8 +161,8 @@
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
diff --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
--- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll
+++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
@@ -130,9 +130,8 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -191,10 +190,9 @@
 define void @insert_store_nonconst_index_known_valid_by_and(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -225,10 +223,9 @@
 define void @insert_store_nonconst_index_known_valid_by_urem(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
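
For reference, a minimal sketch (not part of the patch; function and value names are illustrative, mirroring the tests above) of the input this change now handles: the icmp/assume pair bounds %idx, computeConstantRange proves via the AssumptionCache that the index range [0,4) lies inside the vector, and the load/extractelement pair can be rewritten as a scalar load.

; Before -vector-combine:
define i32 @example(<4 x i32>* %x, i64 %idx) {
entry:
  %cmp = icmp ult i64 %idx, 4
  call void @llvm.assume(i1 %cmp)
  %lv = load <4 x i32>, <4 x i32>* %x
  %r = extractelement <4 x i32> %lv, i64 %idx
  ret i32 %r
}

declare void @llvm.assume(i1)

; After -vector-combine (matching the CHECK lines above), the vector load and
; extract become:
;   %p = getelementptr inbounds <4 x i32>, <4 x i32>* %x, i32 0, i64 %idx
;   %r = load i32, i32* %p, align 1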