Index: llvm/lib/Transforms/Vectorize/VectorCombine.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -46,6 +47,79 @@
     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
     cl::desc("Disable binop extract to shuffle transforms"));
 
+static bool vectorizeLoad(Instruction &I, const TargetTransformInfo &TTI,
+                          const DominatorTree &DT) {
+  // Match regular loads.
+  auto *Load = dyn_cast<LoadInst>(&I);
+  if (!Load || !Load->isSimple())
+    return false;
+
+  // Match a scalar load of a bitcasted vector pointer.
+  // TODO: Extend this to match GEP with 0 or other offset.
+  Instruction *PtrOp;
+  Value *SrcPtr;
+  if (!match(Load->getPointerOperand(),
+             m_CombineAnd(m_Instruction(PtrOp), m_BitCast(m_Value(SrcPtr)))))
+    return false;
+
+  // TODO: Extend this to allow widening of a sub-vector (not scalar) load.
+  auto *PtrOpTy = dyn_cast<PointerType>(PtrOp->getType());
+  auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType());
+  if (!PtrOpTy || !SrcPtrTy)
+    return false;
+
+  Type *ScalarTy = PtrOpTy->getElementType();
+  auto *VectorTy = dyn_cast<VectorType>(SrcPtrTy->getElementType());
+  if (ScalarTy->isVectorTy() || !VectorTy)
+    return false;
+
+  // Check safety of replacing the scalar load with a larger vector load.
+  Align Alignment = Load->getAlign();
+  const DataLayout &DL = I.getModule()->getDataLayout();
+  if (!isSafeToLoadUnconditionally(SrcPtr, VectorTy, Alignment, DL, Load, &DT))
+    return false;
+
+  // Original pattern: load (bitcast VecPtr to ScalarPtr)
+  int OldCost = TTI.getMemoryOpCost(Instruction::Load, ScalarTy, Alignment,
+                                    Load->getPointerAddressSpace());
+  OldCost += TTI.getCastInstrCost(Instruction::BitCast, PtrOpTy, SrcPtrTy);
+
+  // If needed, bitcast the vector type to match the load (scalar element).
+  // Do not create a vector load of an unsupported type.
+  unsigned VecSize = VectorTy->getPrimitiveSizeInBits();
+  Type *VecLoadTy = VectorTy;
+  if (VectorTy->getElementType() != Load->getType()) {
+    unsigned NumElts = VecSize / Load->getType()->getPrimitiveSizeInBits();
+    VecLoadTy = VectorType::get(Load->getType(), NumElts);
+  }
+
+  // New pattern: extractelt (load [bitcast] VecPtr), 0
+  int NewCost = 0;
+  if (VecLoadTy != VectorTy)
+    NewCost += TTI.getCastInstrCost(Instruction::BitCast,
+                                    VecLoadTy->getPointerTo(), SrcPtrTy);
+  NewCost += TTI.getMemoryOpCost(Instruction::Load, VecLoadTy, Alignment,
+                                 Load->getPointerAddressSpace());
+  NewCost += TTI.getVectorInstrCost(Instruction::ExtractElement, VecLoadTy, 0);
+
+  // We can aggressively convert to the vector form because the backend will
+  // invert this transform if it does not result in a performance win.
+  if (OldCost < NewCost)
+    return false;
+
+  // It is safe and profitable to load using the original vector pointer and
+  // extract the scalar value from that:
+  // load (bitcast VecPtr to ScalarPtr) --> extractelt (load VecPtr), 0
+  IRBuilder<> Builder(Load);
+  if (VecLoadTy != VectorTy)
+    SrcPtr = Builder.CreateBitCast(SrcPtr, VecLoadTy->getPointerTo());
+
+  LoadInst *VecLd = Builder.CreateAlignedLoad(VecLoadTy, SrcPtr, Alignment);
+  Value *ExtElt = Builder.CreateExtractElement(VecLd, Builder.getInt32(0));
+  Load->replaceAllUsesWith(ExtElt);
+  ExtElt->takeName(&I);
+  return true;
+}
 
 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
 /// vector operation(s) followed by extract. Return true if the existing
@@ -423,6 +497,7 @@
     for (Instruction &I : BB) {
       if (isa<DbgInfoIntrinsic>(I))
        continue;
+      MadeChange |= vectorizeLoad(I, TTI, DT);
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
      MadeChange |= scalarizeBinop(I, TTI);
Index: llvm/test/Transforms/VectorCombine/X86/load-bitcast-vec.ll
===================================================================
--- llvm/test/Transforms/VectorCombine/X86/load-bitcast-vec.ll
+++ llvm/test/Transforms/VectorCombine/X86/load-bitcast-vec.ll
@@ -6,8 +6,8 @@
 
 define float @matching_scalar(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @matching_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
-; CHECK-NEXT:    [[R:%.*]] = load float, float* [[BC]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, <4 x float>* [[P:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x float> [[TMP1]], i32 0
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %bc = bitcast <4 x float>* %p to float*
@@ -17,8 +17,9 @@
 
 define i32 @nonmatching_scalar(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @nonmatching_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i32*
-; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[BC]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <4 x i32>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, <4 x i32>* [[TMP1]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %bc = bitcast <4 x float>* %p to i32*
@@ -28,8 +29,9 @@
 
 define i64 @larger_scalar(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @larger_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i64*
-; CHECK-NEXT:    [[R:%.*]] = load i64, i64* [[BC]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <2 x i64>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i64>, <2 x i64>* [[TMP1]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <2 x i64> [[TMP2]], i32 0
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
   %bc = bitcast <4 x float>* %p to i64*
@@ -39,8 +41,9 @@
 
 define i8 @smaller_scalar(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @smaller_scalar(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i8>, <16 x i8>* [[TMP1]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <16 x i8> [[TMP2]], i32 0
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %bc = bitcast <4 x float>* %p to i8*
@@ -49,10 +52,16 @@
 }
 
 define i8 @smaller_scalar_256bit_vec(<8 x float>* align 32 dereferenceable(32) %p) {
-; CHECK-LABEL: @smaller_scalar_256bit_vec(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 32
-; CHECK-NEXT:    ret i8 [[R]]
+; SSE-LABEL: @smaller_scalar_256bit_vec(
+; SSE-NEXT:    [[BC:%.*]] = bitcast <8 x float>* [[P:%.*]] to i8*
+; SSE-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 32
+; SSE-NEXT:    ret i8 [[R]]
+;
+; AVX-LABEL: @smaller_scalar_256bit_vec(
+; AVX-NEXT:    [[TMP1:%.*]] = bitcast <8 x float>* [[P:%.*]] to <32 x i8>*
+; AVX-NEXT:    [[TMP2:%.*]] = load <32 x i8>, <32 x i8>* [[TMP1]], align 32
+; AVX-NEXT:    [[R:%.*]] = extractelement <32 x i8> [[TMP2]], i32 0
+; AVX-NEXT:    ret i8 [[R]]
 ;
   %bc = bitcast <8 x float>* %p to i8*
   %r = load i8, i8* %bc, align 32
@@ -61,8 +70,9 @@
 
 define i8 @smaller_scalar_less_aligned(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @smaller_scalar_less_aligned(
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to i8*
-; CHECK-NEXT:    [[R:%.*]] = load i8, i8* [[BC]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x float>* [[P:%.*]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i8>, <16 x i8>* [[TMP1]], align 4
+; CHECK-NEXT:    [[R:%.*]] = extractelement <16 x i8> [[TMP2]], i32 0
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %bc = bitcast <4 x float>* %p to i8*
@@ -70,6 +80,8 @@
   ret i8 %r
 }
 
+; negative test - not enough dereferenceable bytes
+
 define float @matching_scalar_small_deref(<4 x float>* align 16 dereferenceable(15) %p) {
 ; CHECK-LABEL: @matching_scalar_small_deref(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
@@ -81,6 +93,8 @@
   ret float %r
 }
 
+; negative test - do not modify volatile
+
 define float @matching_scalar_volatile(<4 x float>* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @matching_scalar_volatile(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x float>* [[P:%.*]] to float*
@@ -92,6 +106,8 @@
   ret float %r
 }
 
+; negative test - not bitcast from vector
+
 define float @nonvector(double* align 16 dereferenceable(16) %p) {
 ; CHECK-LABEL: @nonvector(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast double* [[P:%.*]] to float*