diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -56,8 +56,8 @@ class VectorCombine { public: VectorCombine(Function &F, const TargetTransformInfo &TTI, - const DominatorTree &DT) - : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {} + const DominatorTree &DT, AAResults &AA) + : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {} bool run(); @@ -66,6 +66,7 @@ IRBuilder<> Builder; const TargetTransformInfo &TTI; const DominatorTree &DT; + AAResults &AA; bool vectorizeLoadInsert(Instruction &I); ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0, @@ -83,6 +84,7 @@ bool foldBitcastShuf(Instruction &I); bool scalarizeBinopOrCmp(Instruction &I); bool foldExtractedCmps(Instruction &I); + bool foldSingleElementStore(Instruction &I); }; } // namespace @@ -752,6 +754,63 @@ return true; } +// Check if memory loc modified between two instrs in the same BB +static bool isMemModifiedBetween(BasicBlock::iterator Begin, + BasicBlock::iterator End, + const MemoryLocation &Loc, AAResults &AA) { + for (BasicBlock::iterator BBI = Begin; BBI != End; ++BBI) + if (isModSet(AA.getModRefInfo(&*BBI, Loc))) + return true; + return false; +} + +// Combine patterns like: +// %0 = load <4 x i32>, <4 x i32>* %a +// %1 = insertelement <4 x i32> %0, i32 %b, i32 1 +// store <4 x i32> %1, <4 x i32>* %a +// to: +// %0 = bitcast <4 x i32>* %a to i32* +// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1 +// store i32 %b, i32* %1 +bool VectorCombine::foldSingleElementStore(Instruction &I) { + StoreInst *SI = dyn_cast(&I); + if (SI == nullptr || !SI->isSimple() || + !SI->getValueOperand()->getType()->isVectorTy()) + return false; + + // TODO: Combine more complicated patterns (multiple insert) by referencing + // TargetTransformInfo. + Instruction *Source; + Value *NewElement; + Constant *Idx; + if (!match(SI->getValueOperand(), + m_InsertElt(m_Instruction(Source), m_Value(NewElement), + m_Constant(Idx)))) + return false; + + if (auto *Load = dyn_cast(Source)) { + Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts(); + // Don't optimize for atomic/volatile load or stores. + if (!Load->isSimple() || Load->getParent() != SI->getParent() || + SrcAddr != SI->getPointerOperand()->stripPointerCasts() || + isMemModifiedBetween(Load->getIterator(), SI->getIterator(), + MemoryLocation::get(SI), AA)) + return false; + + Type *ElePtrType = NewElement->getType()->getPointerTo(); + Value *ElePtr = + Builder.CreatePointerCast(SI->getPointerOperand(), ElePtrType); + Value *GEP = Builder.CreateInBoundsGEP(NewElement->getType(), ElePtr, Idx); + StoreInst *NSI = Builder.CreateStore(NewElement, GEP); + NSI->copyMetadata(*SI, {LLVMContext::MD_nontemporal}); + replaceValue(I, *NSI); + I.eraseFromParent(); + return true; + } + + return false; +} + /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. bool VectorCombine::run() { @@ -771,7 +830,7 @@ // Walk the block forwards to enable simple iterative chains of transforms. // TODO: It could be more efficient to remove dead instructions // iteratively in this loop rather than waiting until the end. - for (Instruction &I : BB) { + for (Instruction &I : make_early_inc_range(BB)) { if (isa(I)) continue; Builder.SetInsertPoint(&I); @@ -780,6 +839,7 @@ MadeChange |= foldBitcastShuf(I); MadeChange |= scalarizeBinopOrCmp(I); MadeChange |= foldExtractedCmps(I); + MadeChange |= foldSingleElementStore(I); } } @@ -817,7 +877,8 @@ return false; auto &TTI = getAnalysis().getTTI(F); auto &DT = getAnalysis().getDomTree(); - VectorCombine Combiner(F, TTI, DT); + auto &AA = getAnalysis().getAAResults(); + VectorCombine Combiner(F, TTI, DT, AA); return Combiner.run(); } }; @@ -838,7 +899,8 @@ FunctionAnalysisManager &FAM) { TargetTransformInfo &TTI = FAM.getResult(F); DominatorTree &DT = FAM.getResult(F); - VectorCombine Combiner(F, TTI, DT); + AAResults &AA = FAM.getResult(F); + VectorCombine Combiner(F, TTI, DT, AA); if (!Combiner.run()) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/test/Transforms/InstCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/X86/load-insert-store.ll rename from llvm/test/Transforms/InstCombine/load-insert-store.ll rename to llvm/test/Transforms/VectorCombine/X86/load-insert-store.ll --- a/llvm/test/Transforms/InstCombine/load-insert-store.ll +++ b/llvm/test/Transforms/VectorCombine/X86/load-insert-store.ll @@ -1,12 +1,13 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt -S -instcombine < %s | FileCheck %s +; RUN: opt -S -vector-combine -data-layout=e < %s | FileCheck %s +; RUN: opt -S -vector-combine -data-layout=E < %s | FileCheck %s define void @insert_store(<16 x i8>* %q, i8 zeroext %s) { ; CHECK-LABEL: @insert_store( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16 -; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 3 -; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <16 x i8>* [[Q:%.*]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i32 3 +; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP1]], align 1 ; CHECK-NEXT: ret void ; entry: @@ -16,19 +17,18 @@ ret void } -define void @single_shuffle_store(<4 x i32>* %a, i32 %b) { -; CHECK-LABEL: @single_shuffle_store( +define void @insert_store_i16(<8 x i16>* %q, i16 zeroext %s) { +; CHECK-LABEL: @insert_store_i16( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, <4 x i32>* [[A:%.*]], align 16 -; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32> [[TMP0]], i32 [[B:%.*]], i32 1 -; CHECK-NEXT: store <4 x i32> [[TMP1]], <4 x i32>* [[A]], align 16, !nontemporal !0 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i16>* [[Q:%.*]] to i16* +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, i16* [[TMP0]], i32 3 +; CHECK-NEXT: store i16 [[S:%.*]], i16* [[TMP1]], align 2 ; CHECK-NEXT: ret void ; entry: - %0 = load <4 x i32>, <4 x i32>* %a - %1 = insertelement <4 x i32> %0, i32 %b, i32 1 - %2 = shufflevector <4 x i32> %0, <4 x i32> %1, <4 x i32> - store <4 x i32> %2, <4 x i32>* %a, !nontemporal !0 + %0 = load <8 x i16>, <8 x i16>* %q + %vecins = insertelement <8 x i16> %0, i16 %s, i32 3 + store <8 x i16> %vecins, <8 x i16>* %q ret void } @@ -69,6 +69,9 @@ ret void } +; We can't transform if any instr could modify memory in between. +; Here p and q may alias, so we can't remove the load. +; r is impossible to alias with others, so it's safe to transform. define void @insert_store_mem_modify(<16 x i8>* %p, <16 x i8>* %q, <16 x i8>* noalias %r, i8 %s) { ; CHECK-LABEL: @insert_store_mem_modify( ; CHECK-NEXT: entry: @@ -76,10 +79,10 @@ ; CHECK-NEXT: store <16 x i8> zeroinitializer, <16 x i8>* [[Q:%.*]], align 16 ; CHECK-NEXT: [[INS:%.*]] = insertelement <16 x i8> [[LD]], i8 [[S:%.*]], i32 3 ; CHECK-NEXT: store <16 x i8> [[INS]], <16 x i8>* [[P]], align 16 -; CHECK-NEXT: [[LD2:%.*]] = load <16 x i8>, <16 x i8>* [[Q]], align 16 ; CHECK-NEXT: store <16 x i8> zeroinitializer, <16 x i8>* [[R:%.*]], align 16 -; CHECK-NEXT: [[INS2:%.*]] = insertelement <16 x i8> [[LD2]], i8 [[S]], i32 7 -; CHECK-NEXT: store <16 x i8> [[INS2]], <16 x i8>* [[Q]], align 16 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <16 x i8>* [[Q]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i32 7 +; CHECK-NEXT: store i8 [[S]], i8* [[TMP1]], align 1 ; CHECK-NEXT: ret void ; entry: @@ -95,4 +98,36 @@ ret void } +; Check cases when calls may modify memory +define void @insert_store_with_call(<16 x i8>* %p, <16 x i8>* %q, i8 %s) { +; CHECK-LABEL: @insert_store_with_call( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[LD:%.*]] = load <16 x i8>, <16 x i8>* [[P:%.*]], align 16 +; CHECK-NEXT: call void @maywrite(<16 x i8>* [[P]]) +; CHECK-NEXT: [[INS:%.*]] = insertelement <16 x i8> [[LD]], i8 [[S:%.*]], i32 3 +; CHECK-NEXT: store <16 x i8> [[INS]], <16 x i8>* [[P]], align 16 +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: call void @nowrite(<16 x i8>* [[P]]) +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <16 x i8>* [[P]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i32 7 +; CHECK-NEXT: store i8 [[S]], i8* [[TMP1]], align 1 +; CHECK-NEXT: ret void +; +entry: + %ld = load <16 x i8>, <16 x i8>* %p + call void @maywrite(<16 x i8>* %p) + %ins = insertelement <16 x i8> %ld, i8 %s, i32 3 + store <16 x i8> %ins, <16 x i8>* %p + call void @foo() ; Barrier + %ld2 = load <16 x i8>, <16 x i8>* %p + call void @nowrite(<16 x i8>* %p) + %ins2 = insertelement <16 x i8> %ld2, i8 %s, i32 7 + store <16 x i8> %ins2, <16 x i8>* %p + ret void +} + +declare void @foo() +declare void @maywrite(<16 x i8>*) +declare void @nowrite(<16 x i8>*) readonly + !0 = !{}