Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -464,10 +464,6 @@ /// \return The width of the largest scalar or vector register type. unsigned getRegisterBitWidth(bool Vector) const; - /// \return The bitwidth of the largest vector type that should be used to - /// load/store in the given address space. - unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const; - /// \return The size of a cache line in bytes. unsigned getCacheLineSize() const; @@ -618,6 +614,38 @@ bool areInlineCompatible(const Function *Caller, const Function *Callee) const; + /// \returns The bitwidth of the largest vector type that should be used to + /// load/store in the given address space. + unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const; + + /// \returns True if the load instruction is legal to vectorize. + bool isLegalToVectorizeLoad(LoadInst *LI) const; + + /// \returns True if the store instruction is legal to vectorize. + bool isLegalToVectorizeStore(StoreInst *SI) const; + + /// \returns True if it is legal to vectorize the given load chain. + bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const; + + /// \returns True if it is legal to vectorize the given store chain. + bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const; + + /// \returns The new vector factor value if the target doesn't support \p + /// ChainSizeInBytes loads or has a better vector factor. + unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const; + + /// \returns The new vector factor value if the target doesn't support \p + /// ChainSizeInBytes stores or has a better vector factor. 
+ unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const; + /// @} private: @@ -693,7 +721,6 @@ Type *Ty) = 0; virtual unsigned getNumberOfRegisters(bool Vector) = 0; virtual unsigned getRegisterBitWidth(bool Vector) = 0; - virtual unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) = 0; virtual unsigned getCacheLineSize() = 0; virtual unsigned getPrefetchDistance() = 0; virtual unsigned getMinPrefetchStride() = 0; @@ -746,6 +773,21 @@ Type *ExpectedType) = 0; virtual bool areInlineCompatible(const Function *Caller, const Function *Callee) const = 0; + virtual unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const = 0; + virtual bool isLegalToVectorizeLoad(LoadInst *LI) const = 0; + virtual bool isLegalToVectorizeStore(StoreInst *SI) const = 0; + virtual bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const = 0; + virtual bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const = 0; + virtual unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const = 0; + virtual unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const = 0; }; template @@ -888,10 +930,6 @@ return Impl.getRegisterBitWidth(Vector); } - unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) override { - return Impl.getLoadStoreVecRegBitWidth(AddrSpace); - } - unsigned getCacheLineSize() override { return Impl.getCacheLineSize(); } @@ -991,6 +1029,37 @@ const Function *Callee) const override { return Impl.areInlineCompatible(Caller, Callee); } + unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override { + return Impl.getLoadStoreVecRegBitWidth(AddrSpace); + } + bool isLegalToVectorizeLoad(LoadInst *LI) const override { + return Impl.isLegalToVectorizeLoad(LI); + } + bool 
isLegalToVectorizeStore(StoreInst *SI) const override { + return Impl.isLegalToVectorizeStore(SI); + } + bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const override { + return Impl.isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment, + AddrSpace); + } + bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const override { + return Impl.isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment, + AddrSpace); + } + unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const override { + return Impl.getLoadVectorFactor(VF, LoadSize, ChainSizeInBytes, VecTy); + } + unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const override { + return Impl.getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); + } }; template Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -290,8 +290,6 @@ unsigned getRegisterBitWidth(bool Vector) { return 32; } - unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) { return 128; } - unsigned getCacheLineSize() { return 0; } unsigned getPrefetchDistance() { return 0; } @@ -393,6 +391,36 @@ (Caller->getFnAttribute("target-features") == Callee->getFnAttribute("target-features")); } + + unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const { return 128; } + + bool isLegalToVectorizeLoad(LoadInst *LI) const { return true; } + + bool isLegalToVectorizeStore(StoreInst *SI) const { return true; } + + bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned AddrSpace) const { + return true; + } + + bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, + unsigned Alignment, + unsigned 
AddrSpace) const { + return true; + } + + unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return VF; + } + + unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return VF; + } }; /// \brief CRTP base class for use as a mix-in that aids implementing Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -251,10 +251,6 @@ return TTIImpl->getRegisterBitWidth(Vector); } -unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const { - return TTIImpl->getLoadStoreVecRegBitWidth(AS); -} - unsigned TargetTransformInfo::getCacheLineSize() const { return TTIImpl->getCacheLineSize(); } @@ -423,6 +419,44 @@ return TTIImpl->areInlineCompatible(Caller, Callee); } +unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const { + return TTIImpl->getLoadStoreVecRegBitWidth(AS); +} + +bool TargetTransformInfo::isLegalToVectorizeLoad(LoadInst *LI) const { + return TTIImpl->isLegalToVectorizeLoad(LI); +} + +bool TargetTransformInfo::isLegalToVectorizeStore(StoreInst *SI) const { + return TTIImpl->isLegalToVectorizeStore(SI); +} + +bool TargetTransformInfo::isLegalToVectorizeLoadChain( + unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { + return TTIImpl->isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment, + AddrSpace); +} + +bool TargetTransformInfo::isLegalToVectorizeStoreChain( + unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { + return TTIImpl->isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment, + AddrSpace); +} + +unsigned TargetTransformInfo::getLoadVectorFactor(unsigned VF, + unsigned LoadSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return TTIImpl->getLoadVectorFactor(VF, LoadSize, 
ChainSizeInBytes, VecTy); +} + +unsigned TargetTransformInfo::getStoreVectorFactor(unsigned VF, + unsigned StoreSize, + unsigned ChainSizeInBytes, + VectorType *VecTy) const { + return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); +} + TargetTransformInfo::Concept::~Concept() {} TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} Index: lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h =================================================================== --- lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h +++ lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h @@ -82,7 +82,7 @@ unsigned getNumberOfRegisters(bool Vector); unsigned getRegisterBitWidth(bool Vector); - unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace); + unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const; unsigned getMaxInterleaveFactor(unsigned VF); int getArithmeticInstrCost( Index: lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp =================================================================== --- lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -80,7 +80,7 @@ return Vector ? 
0 : 32; } -unsigned AMDGPUTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) { +unsigned AMDGPUTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const { switch (AddrSpace) { case AMDGPUAS::GLOBAL_ADDRESS: case AMDGPUAS::CONSTANT_ADDRESS: Index: lib/Transforms/Vectorize/LoadStoreVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -429,10 +429,13 @@ std::pair, ArrayRef> Vectorizer::splitOddVectorElts(ArrayRef Chain, unsigned ElementSizeBits) { - unsigned ElemSizeInBytes = ElementSizeBits / 8; - unsigned SizeInBytes = ElemSizeInBytes * Chain.size(); - unsigned NumRight = (SizeInBytes % 4) / ElemSizeInBytes; - unsigned NumLeft = Chain.size() - NumRight; + unsigned ElementSizeBytes = ElementSizeBits / 8; + unsigned SizeBytes = ElementSizeBytes * Chain.size(); + unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes; + if (NumLeft == Chain.size()) + --NumLeft; + else if (NumLeft == 0) + NumLeft = 1; return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); } @@ -540,6 +543,10 @@ if (!LI->isSimple()) continue; + // Skip loads the target reports as not legal to vectorize. + if (!TTI.isLegalToVectorizeLoad(LI)) + continue; + Type *Ty = LI->getType(); if (!VectorType::isValidElementType(Ty->getScalarType())) continue; @@ -565,8 +572,6 @@ })) continue; - // TODO: Target hook to filter types. - // Save the load locations. Value *ObjPtr = GetUnderlyingObject(Ptr, DL); LoadRefs[ObjPtr].push_back(LI); @@ -575,6 +580,10 @@ if (!SI->isSimple()) continue; + // Skip stores the target reports as not legal to vectorize. 
+ if (!TTI.isLegalToVectorizeStore(SI)) + continue; + Type *Ty = SI->getValueOperand()->getType(); if (!VectorType::isValidElementType(Ty->getScalarType())) continue; @@ -719,6 +728,7 @@ unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); unsigned VF = VecRegSize / Sz; unsigned ChainSize = Chain.size(); + unsigned Alignment = getAlignment(S0); if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { InstructionsProcessed->insert(Chain.begin(), Chain.end()); @@ -741,17 +751,11 @@ Chain = NewChain; ChainSize = Chain.size(); - // Store size should be 1B, 2B or multiple of 4B. - // TODO: Target hook for size constraint? + // Check if it's legal to vectorize this chain. If not, split the chain and + // try again. unsigned EltSzInBytes = Sz / 8; unsigned SzInBytes = EltSzInBytes * ChainSize; - if (SzInBytes > 2 && SzInBytes % 4 != 0) { - DEBUG(dbgs() << "LSV: Size should be 1B, 2B " - "or multiple of 4B. Splitting.\n"); - if (SzInBytes == 3) - return vectorizeStoreChain(Chain.slice(0, ChainSize - 1), - InstructionsProcessed); - + if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); return vectorizeStoreChain(Chains.first, InstructionsProcessed) | vectorizeStoreChain(Chains.second, InstructionsProcessed); @@ -765,13 +769,15 @@ else VecTy = VectorType::get(StoreTy, Chain.size()); - // If it's more than the max vector size, break it into two pieces. - // TODO: Target hook to control types to split to. - if (ChainSize > VF) { - DEBUG(dbgs() << "LSV: Vector factor is too big." + // If it's more than the max vector size or the target has a better + // vector factor, break it into two pieces at TargetVF. + unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy); + if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { + DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." 
" Creating two separate arrays.\n"); - return vectorizeStoreChain(Chain.slice(0, VF), InstructionsProcessed) | - vectorizeStoreChain(Chain.slice(VF), InstructionsProcessed); + return vectorizeStoreChain(Chain.slice(0, TargetVF), + InstructionsProcessed) | + vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); } DEBUG({ @@ -784,9 +790,6 @@ // whether we succeed below. InstructionsProcessed->insert(Chain.begin(), Chain.end()); - // Check alignment restrictions. - unsigned Alignment = getAlignment(S0); - // If the store is going to be misaligned, don't vectorize it. if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (S0->getPointerAddressSpace() != 0) @@ -873,6 +876,7 @@ unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); unsigned VF = VecRegSize / Sz; unsigned ChainSize = Chain.size(); + unsigned Alignment = getAlignment(L0); if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { InstructionsProcessed->insert(Chain.begin(), Chain.end()); @@ -895,16 +899,11 @@ Chain = NewChain; ChainSize = Chain.size(); - // Load size should be 1B, 2B or multiple of 4B. - // TODO: Should size constraint be a target hook? + // Check if it's legal to vectorize this chain. If not, split the chain and + // try again. unsigned EltSzInBytes = Sz / 8; unsigned SzInBytes = EltSzInBytes * ChainSize; - if (SzInBytes > 2 && SzInBytes % 4 != 0) { - DEBUG(dbgs() << "LSV: Size should be 1B, 2B " - "or multiple of 4B. Splitting.\n"); - if (SzInBytes == 3) - return vectorizeLoadChain(Chain.slice(0, ChainSize - 1), - InstructionsProcessed); + if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); return vectorizeLoadChain(Chains.first, InstructionsProcessed) | vectorizeLoadChain(Chains.second, InstructionsProcessed); @@ -918,22 +917,20 @@ else VecTy = VectorType::get(LoadTy, Chain.size()); - // If it's more than the max vector size, break it into two pieces. - // TODO: Target hook to control types to split to. 
- if (ChainSize > VF) { - DEBUG(dbgs() << "LSV: Vector factor is too big. " - "Creating two separate arrays.\n"); - return vectorizeLoadChain(Chain.slice(0, VF), InstructionsProcessed) | - vectorizeLoadChain(Chain.slice(VF), InstructionsProcessed); + // If it's more than the max vector size or the target has a better + // vector factor, break it into two pieces at TargetVF. + unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy); + if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { + DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." + " Creating two separate arrays.\n"); + return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) | + vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); } // We won't try again to vectorize the elements of the chain, regardless of // whether we succeed below. InstructionsProcessed->insert(Chain.begin(), Chain.end()); - // Check alignment restrictions. - unsigned Alignment = getAlignment(L0); - // If the load is going to be misaligned, don't vectorize it. if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (L0->getPointerAddressSpace() != 0)