Index: llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
===================================================================
--- llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
+++ llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
@@ -83,6 +83,8 @@
                OptimizationRemarkEmitter *ORE_);
 
 private:
+  void vectorizeLoads(BasicBlock *BB);
+
   /// Collect store and getelementptr instructions and organize them
   /// according to the underlying object of their pointer operands. We sort the
   /// instructions by their underlying objects to reduce the cost of
Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryLocation.h"
@@ -83,6 +84,7 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Vectorize.h"
 #include <algorithm>
@@ -5070,6 +5072,7 @@
 
   // Scan the blocks in the function in post order.
   for (auto BB : post_order(&F.getEntryBlock())) {
+    vectorizeLoads(BB);
     collectSeedInstructions(BB);
 
     // Vectorize trees that end at stores.
@@ -5239,6 +5242,83 @@
   return Changed;
 }
 
+void SLPVectorizerPass::vectorizeLoads(BasicBlock *BB) {
+  SmallVector<WeakTrackingVH, 4> DeadLoads;
+  for (Instruction &I : *BB) {
+    // Match regular loads.
+    auto *Load = dyn_cast<LoadInst>(&I);
+    if (!Load || Load->isVolatile() || Load->isAtomic())
+      continue;
+
+    // Match a scalar load of a bitcasted vector pointer.
+    // TODO: Extend this to match GEP with 0 or other offset.
+    Instruction *PtrOp;
+    Value *SrcPtr;
+    if (!match(Load->getPointerOperand(),
+               m_CombineAnd(m_Instruction(PtrOp), m_BitCast(m_Value(SrcPtr)))))
+      continue;
+
+    // TODO: Extend this to allow widening of a sub-vector (not scalar) load.
+    auto *PtrOpTy = dyn_cast<PointerType>(PtrOp->getType());
+    auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType());
+    if (!PtrOpTy || !SrcPtrTy || PtrOpTy->getElementType()->isVectorTy() ||
+        !SrcPtrTy->getElementType()->isVectorTy())
+      continue;
+
+    // Do not create a vector load of an unsupported type.
+    auto *VecTy = cast<VectorType>(SrcPtrTy->getElementType());
+    unsigned VecSize = VecTy->getPrimitiveSizeInBits();
+    if (VecSize > TTI->getRegisterBitWidth(true))
+      continue;
+
+    // Check safety of replacing the scalar load with a larger vector load.
+    unsigned Alignment = Load->getAlignment();
+    if (!isSafeToLoadUnconditionally(SrcPtr, SrcPtrTy->getElementType(),
+                                     Alignment, *DL, Load, DT))
+      continue;
+
+    // Original pattern: load (bitcast VecPtr to ScalarPtr)
+    int OldCost = TTI->getMemoryOpCost(Instruction::Load,
+                                       SrcPtrTy->getElementType(), Alignment,
+                                       Load->getPointerAddressSpace());
+    OldCost += TTI->getCastInstrCost(Instruction::BitCast, PtrOpTy, SrcPtrTy);
+
+    // If needed, bitcast the vector type to match the load (scalar element).
+    Type *VecLoadTy = VecTy;
+    if (VecTy->getVectorElementType() != Load->getType()) {
+      unsigned NumElts = VecSize / Load->getType()->getPrimitiveSizeInBits();
+      VecLoadTy = VectorType::get(Load->getType(), NumElts);
+    }
+
+    // New pattern: extractelt (load [bitcast] VecPtr), 0
+    int NewCost = TTI->getMemoryOpCost(Instruction::Load, VecLoadTy, Alignment,
+                                       Load->getPointerAddressSpace());
+    NewCost += TTI->getVectorInstrCost(Instruction::ExtractElement,
+                                       SrcPtrTy->getElementType(), 0);
+    if (VecLoadTy != VecTy)
+      NewCost += TTI->getCastInstrCost(Instruction::BitCast,
+                                       VecLoadTy->getPointerTo(), SrcPtrTy);
+
+    // We can aggressively convert to the vector form because the backend will
+    // invert this transform if it does not result in a larger performance win.
+    if (OldCost < NewCost)
+      continue;
+
+    // It is safe and profitable to load using the original vector pointer and
+    // extract the scalar value from that:
+    // load (bitcast VecPtr to ScalarPtr) --> extractelt (load VecPtr), 0
+    IRBuilder<> Builder(Load);
+    if (VecLoadTy != VecTy)
+      SrcPtr = Builder.CreateBitCast(SrcPtr, VecLoadTy->getPointerTo());
+
+    LoadInst *VecLd = Builder.CreateAlignedLoad(VecLoadTy, SrcPtr, Alignment);
+    Value *ExtElt = Builder.CreateExtractElement(VecLd, Builder.getInt32(0));
+    Load->replaceAllUsesWith(ExtElt);
+    DeadLoads.push_back(Load);
+  }
+  RecursivelyDeleteTriviallyDeadInstructions(DeadLoads, TLI);
+}
+
 void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) {
   // Initialize the collections. We will make a single pass over the block.
   Stores.clear();
Index: llvm/test/Transforms/SLPVectorizer/X86/load-bitcast-vec.ll
===================================================================
--- llvm/test/Transforms/SLPVectorizer/X86/load-bitcast-vec.ll
+++ llvm/test/Transforms/SLPVectorizer/X86/load-bitcast-vec.ll
@@ -1,12 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -slp-vectorizer -S -mtriple=x86_64-- -mattr=+avx | FileCheck %s --check-prefixes=CHECK,AVX1
+; RUN: opt < %s -slp-vectorizer -S -mtriple=x86_64-- -mattr=+sse2 | FileCheck %s --check-prefixes=CHECK,SSE2
 ; RUN: opt < %s -slp-vectorizer -S -mtriple=x86_64-- -mattr=+avx2 | FileCheck %s --check-prefixes=CHECK,AVX2
 
 define float @matching_scalar(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @matching_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
-; CHECK-NEXT:    [[R:%.*]] = load float, float* [[BC]], align 16
-; CHECK-NEXT:    ret float [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, <4 x float>* [[P:%.*]], align 16
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i32 0
+; CHECK-NEXT:    ret float [[TMP2]]
 ;
   %bc = bitcast <4 x float>* %p to float*
   %r = load float, float* %bc, align 16
@@ -15,9 +15,10 @@
 
 define i32 @nonmatching_scalar(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @nonmatching_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i32*
-; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[BC]], align 16
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <4 x i32>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, <4 x i32>* [[TMP1]], align 16
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %bc = bitcast <4 x float>* %p to i32*
   %r = load i32, i32* %bc, align 16
@@ -26,9 +27,10 @@
 
 define i64 @larger_scalar(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @larger_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i64*
-; CHECK-NEXT:    [[R:%.*]] = load i64, i64* [[BC]], align 16
-; CHECK-NEXT:    ret i64 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <2 x i64>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i64>, <2 x i64>* [[TMP1]], align 16
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i64> [[TMP2]], i32 0
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
   %bc = bitcast <4 x float>* %p to i64*
   %r = load i64, i64* %bc, align 16
@@ -37,20 +39,29 @@
 
 define i8 @smaller_scalar(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @smaller_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 16
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i8>, <16 x i8>* [[TMP1]], align 16
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <16 x i8> [[TMP2]], i32 0
+; CHECK-NEXT:    ret i8 [[TMP3]]
 ;
   %bc = bitcast <4 x float>* %p to i8*
   %r = load i8, i8* %bc, align 16
   ret i8 %r
 }
 
+; Partial negative test - don't create an illegal load for an SSE target.
+
 define i8 @smaller_scalar_256bit_vec(<8 x float>* dereferenceable(32) %p) {
-; CHECK-LABEL: @smaller_scalar_256bit_vec(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 32
-; CHECK-NEXT:    ret i8 [[R]]
+; SSE2-LABEL: @smaller_scalar_256bit_vec(
+; SSE2-NEXT:    [[BC:%.*]] = bitcast <8 x float>* [[P:%.*]] to i8*
+; SSE2-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 32
+; SSE2-NEXT:    ret i8 [[R]]
+;
+; AVX2-LABEL: @smaller_scalar_256bit_vec(
+; AVX2-NEXT:    [[TMP1:%.*]] = bitcast <8 x float>* [[P:%.*]] to <32 x i8>*
+; AVX2-NEXT:    [[TMP2:%.*]] = load <32 x i8>, <32 x i8>* [[TMP1]], align 32
+; AVX2-NEXT:    [[TMP3:%.*]] = extractelement <32 x i8> [[TMP2]], i32 0
+; AVX2-NEXT:    ret i8 [[TMP3]]
 ;
   %bc = bitcast <8 x float>* %p to i8*
   %r = load i8, i8* %bc, align 32
@@ -59,15 +70,18 @@
 
 define i8 @smaller_scalar_less_aligned(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @smaller_scalar_less_aligned(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 4
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i8>, <16 x i8>* [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <16 x i8> [[TMP2]], i32 0
+; CHECK-NEXT:    ret i8 [[TMP3]]
 ;
   %bc = bitcast <4 x float>* %p to i8*
   %r = load i8, i8* %bc, align 4
   ret i8 %r
 }
 
+; Negative test - not enough dereferenceable bytes.
+
 define float @matching_scalar_small_deref(<4 x float>* dereferenceable(15) %p) {
 ; CHECK-LABEL: @matching_scalar_small_deref(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
@@ -79,6 +93,8 @@
   ret float %r
 }
 
+; Negative test - don't transform volatile ops.
+
 define float @matching_scalar_volatile(<4 x float>* dereferenceable(16) %p) {
 ; CHECK-LABEL: @matching_scalar_volatile(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
@@ -90,6 +106,8 @@
   ret float %r
 }
 
+; Negative test - not bitcast from vector type.
+
 define float @nonvector(double* dereferenceable(16) %p) {
 ; CHECK-LABEL: @nonvector(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast double* [[P:%.*]] to float*