diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
 
@@ -123,6 +124,7 @@ struct CachingVPExpander {
   Function &F;
   const TargetTransformInfo &TTI;
+  const DataLayout &DL;
 
   /// \returns A (fixed length) vector with ascending integer indices
   /// (<0, 1, ..., NumElems-1>).
@@ -172,9 +174,18 @@
   VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
   bool UsingTTIOverrides;
 
+  /// Lower this llvm.vp.(load|store|gather|scatter) to a non-vp instruction.
+  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
+                                            VPIntrinsic &VPI);
+
+  /// Lower an unmasked llvm.vp.(load|store) whose vector length cannot be
+  /// ignored into a chain of contiguous power-of-two accesses.
+  Value *expandPredicationInUnfoldedLoadStore(IRBuilder<> &Builder,
+                                              VPIntrinsic &VPI);
+
 public:
-  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
-      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}
+  CachingVPExpander(Function &F, const TargetTransformInfo &TTI,
+                    const DataLayout &DL)
+      : F(F), TTI(TTI), DL(DL), UsingTTIOverrides(anyExpandVPOverridesSet()) {}
 
   bool expandVectorPredication();
 };
@@ -383,6 +394,285 @@
   return Reduction;
 }
 
+/// Lower this llvm.vp.(load|store|gather|scatter) to a non-vp instruction.
+Value *
+CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
+                                                      VPIntrinsic &VPI) {
+  assert(VPI.canIgnoreVectorLengthParam());
+  auto &I = cast<Instruction>(VPI);
+
+  Value *MaskParam = VPI.getMaskParam();
+  Value *PtrParam = VPI.getMemoryPointerParam();
+  Value *DataParam = VPI.getMemoryDataParam();
+  bool IsUnmasked = isAllTrueMask(MaskParam);
+
+  MaybeAlign AlignOpt = VPI.getPointerAlignment();
+
+  Value *NewMemoryInst = nullptr;
+  switch (VPI.getIntrinsicID()) {
+  default:
+    llvm_unreachable("not a VP memory intrinsic");
+
+  case Intrinsic::vp_store: {
+    if (IsUnmasked) {
+      StoreInst *NewStore = Builder.CreateStore(DataParam, PtrParam, false);
+      if (AlignOpt.hasValue())
+        NewStore->setAlignment(AlignOpt.getValue());
+      NewMemoryInst = NewStore;
+    } else {
+      NewMemoryInst = Builder.CreateMaskedStore(
+          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);
+    }
+  } break;
+
+  case Intrinsic::vp_load: {
+    if (IsUnmasked) {
+      LoadInst *NewLoad = Builder.CreateLoad(VPI.getType(), PtrParam, false);
+      if (AlignOpt.hasValue())
+        NewLoad->setAlignment(AlignOpt.getValue());
+      NewMemoryInst = NewLoad;
+    } else {
+      NewMemoryInst = Builder.CreateMaskedLoad(
+          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);
+    }
+  } break;
+
+  case Intrinsic::vp_scatter: {
+    NewMemoryInst = Builder.CreateMaskedScatter(
+        DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);
+  } break;
+
+  case Intrinsic::vp_gather: {
+    NewMemoryInst = Builder.CreateMaskedGather(
+        VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam,
+        /*PassThru=*/nullptr, I.getName());
+  } break;
+  }
+
+  assert(NewMemoryInst);
+  replaceOperation(*NewMemoryInst, VPI);
+  return NewMemoryInst;
+}
+
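+// For example (illustrative only; the exact intrinsic mangling and alignment
+// depend on the call site), with the EVL already folded into the mask,
+//   %v = call <4 x float> @llvm.vp.load.v4f32.p0v4f32(<4 x float>* %p,
+//                                                     <4 x i1> %m, i32 4)
+// becomes the regular masked intrinsic
+//   %v = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(
+//            <4 x float>* %p, i32 4, <4 x i1> %m, <4 x float> undef)
+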
+// The following are helper functions for loading and storing subvectors with
+// variable offsets. There is currently no support for shuffles with
+// non-constant masks, so these operations have to be done lane by lane.
+
+// Create a load into Dest from the subvector of Src given by a variable Offset
+// and constant Width. Src is a pointer; Dest is a fixed-width vector; Offset
+// and Width are specified in lanes.
+static Value *LoadSubvector(Value *Dest, Value *Src, Value *Offset,
+                            unsigned Width, MaybeAlign EltAlign,
+                            Type *OffsetTy, Instruction *InsertPt) {
+  assert(OffsetTy->isIntegerTy() && "Offset must be an integer type!");
+  assert(Src->getType()->isPointerTy() && "Source must be a pointer!");
+  assert(Dest->getType()->isVectorTy() && "Destination must be a vector!");
+  Type *EltTy = Dest->getType()->getScalarType();
+  IRBuilder<> Builder(InsertPt);
+  Builder.SetCurrentDebugLocation(InsertPt->getDebugLoc());
+  Value *SrcEltPtr = Builder.CreatePointerCast(
+      Src, EltTy->getPointerTo(Src->getType()->getPointerAddressSpace()));
+  auto *SubvecSrc = Builder.CreateInBoundsGEP(EltTy, SrcEltPtr, Offset);
+  Value *VResult = Dest;
+  for (unsigned i = 0; i < Width; ++i) {
+    Value *Idx = ConstantInt::get(OffsetTy, i);
+    auto *EltOffset = Builder.CreateAdd(Offset, Idx);
+    auto *EltPtr = Builder.CreateInBoundsGEP(EltTy, SubvecSrc, Idx);
+    Value *EltLoad = Builder.CreateAlignedLoad(EltTy, EltPtr, EltAlign);
+    VResult = Builder.CreateInsertElement(VResult, EltLoad, EltOffset);
+  }
+  return VResult;
+}
+
+// Create a store into Dest of the subvector of Val given by a variable Offset
+// and constant Width. Dest is a pointer; Val is a fixed-width vector; Offset
+// and Width are specified in lanes.
+static void StoreSubvector(Value *Val, Value *Dest, Value *Offset,
+                           unsigned Width, MaybeAlign EltAlign,
+                           Type *OffsetTy, Instruction *InsertPt) {
+  assert(OffsetTy->isIntegerTy() && "Offset must be an integer type!");
+  assert(Dest->getType()->isPointerTy() && "Destination must be a pointer!");
+  assert(Val->getType()->isVectorTy() && "Value must be a vector!");
+  Type *EltTy = Val->getType()->getScalarType();
+  IRBuilder<> Builder(InsertPt);
+  Builder.SetCurrentDebugLocation(InsertPt->getDebugLoc());
+  Value *DestEltPtr = Builder.CreatePointerCast(
+      Dest, EltTy->getPointerTo(Dest->getType()->getPointerAddressSpace()));
+  auto *SubvecDest = Builder.CreateInBoundsGEP(EltTy, DestEltPtr, Offset);
+  for (unsigned i = 0; i < Width; ++i) {
+    Value *Idx = ConstantInt::get(OffsetTy, i);
+    auto *EltOffset = Builder.CreateAdd(Offset, Idx);
+    auto *EltPtr = Builder.CreateInBoundsGEP(EltTy, SubvecDest, Idx);
+    Value *Elt = Builder.CreateExtractElement(Val, EltOffset);
+    Builder.CreateAlignedStore(Elt, EltPtr, EltAlign);
+  }
+}
+
+// We can split a vector load or store with variable length into contiguous
+// conditional accesses of power-of-2 width, one for each active bit in the
+// length value. The offsets of the accesses can be computed unconditionally
+// using bitmasks of the length. The resulting logic looks like this:
+// PreBB:
+//   // ... before intrinsic call
+//   goto HeadBB;
+// HeadBB:
+//   if (Length == VectorWidth)
+//     goto ShortBB;
+//   else
+//     goto LongBB;
+// ShortBB:
+//   // load/store full vector
+//   goto PostBB;
+// LongBB:
+//   for (int i = 0; i < LengthBits; ++i) {
+//     if (hasBitSet(Length, i))
+//       // load/store subvector of width 2^i
+//   }
+//   goto PostBB;
+// PostBB:
+//   // after the intrinsic call ...
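+//
+// For example (illustrative): with VectorWidth = 16 and Length = 11 = 0b1011,
+// the loop below tests the width-8, -4, -2 and -1 branches and takes:
+//   width 8 at offset (11 & 0b0000) = 0   (bit 3 of Length is set)
+//   width 2 at offset (11 & 0b1100) = 8   (bit 1 is set)
+//   width 1 at offset (11 & 0b1110) = 10  (bit 0 is set)
+// covering lanes 0..10; the width-4 branch (bit 2) is not taken.
+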
+Value *
+CachingVPExpander::expandPredicationInUnfoldedLoadStore(IRBuilder<> &Builder,
+                                                        VPIntrinsic &VPI) {
+  assert(!VPI.canIgnoreVectorLengthParam());
+  unsigned OC = *VPI.getFunctionalOpcode();
+
+  auto &I = cast<Instruction>(VPI);
+
+  Value *VLParam = VPI.getVectorLengthParam();
+  Value *PtrParam = VPI.getMemoryPointerParam();
+  Value *DataParam = VPI.getMemoryDataParam();
+  Value *MaskParam = VPI.getMaskParam();
+  assert(isAllTrueMask(MaskParam));
+
+  MaybeAlign AlignOpt = VPI.getPointerAlignment();
+
+  char const *Prefix;
+  switch (OC) {
+  default:
+    llvm_unreachable("not a VP load or store");
+  case Instruction::Load:
+    Prefix = "vp.load.";
+    break;
+  case Instruction::Store:
+    Prefix = "vp.store.";
+    break;
+  }
+
+  bool IsLoad = (OC == Instruction::Load);
+
+  auto *VecTy = IsLoad ? cast<FixedVectorType>(VPI.getType())
+                       : cast<FixedVectorType>(DataParam->getType());
+  unsigned VecNumElts = VecTy->getNumElements();
+  Type *VecEltTy = VecTy->getElementType();
+  Type *VLTy = VLParam->getType();
+
+  Builder.SetCurrentDebugLocation(I.getDebugLoc());
+
+  // Note that a constant VL cannot be handled by a plain full-width access:
+  // canIgnoreVectorLengthParam() returned false, which for a fixed vector
+  // means VL < VecNumElts. The general expansion below covers this case; the
+  // comparisons fold to constants and the dead branches are trivially
+  // removable.
+
+  Instruction *ShortTerm, *LongTerm, *ThenTerm;
+  Value *Pred;
+  const Align BranchAlignment =
+      commonAlignment(AlignOpt.valueOrOne(), DL.getTypeStoreSize(VecEltTy));
+
+  Value *VResult = IsLoad ? UndefValue::get(VecTy) : nullptr;
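+
+  // Emit the control flow sketched in the block comment above: branch to a
+  // single full-width access when VL equals the vector width, and otherwise
+  // to the chain of power-of-two subvector accesses.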
+  Pred = Builder.CreateICmpEQ(VLParam, ConstantInt::get(VLTy, VecNumElts));
+  if (!IsLoad) {
+    // Stores produce no value; keep the old terminator around as a non-null
+    // placeholder for replaceOperation() below.
+    VResult = I.getParent()->getTerminator();
+  }
+  SplitBlockAndInsertIfThenElse(Pred, &I, &ShortTerm, &LongTerm);
+  ShortTerm->getParent()->setName(Twine(Prefix) + "short");
+  LongTerm->getParent()->setName(Twine(Prefix) + "long");
+  I.getParent()->setName(Twine(Prefix) + "exit");
+
+  unsigned LastBranchBit = Log2_64_Ceil(VecNumElts);
+  unsigned BranchMask = maskTrailingOnes<unsigned>(LastBranchBit);
+  unsigned BranchBit = LastBranchBit;
+  while (BranchBit--) { // Post-decrement to avoid comparing against 0u - 1.
+    unsigned BranchOffsetMask =
+        maskTrailingOnes<unsigned>(BranchBit + 1) ^ BranchMask;
+    unsigned BranchWidth = 1u << BranchBit;
+    Value *BranchWidthValue = ConstantInt::get(VLTy, BranchWidth);
+    Value *BranchOffsetMaskValue = ConstantInt::get(VLTy, BranchOffsetMask);
+
+    BasicBlock *IfBB = LongTerm->getParent();
+    Builder.SetInsertPoint(LongTerm);
+    Pred = Builder.CreateICmpNE(Builder.CreateAnd(VLParam, BranchWidthValue),
+                                ConstantInt::get(VLTy, 0));
+    ThenTerm =
+        SplitBlockAndInsertIfThen(Pred, LongTerm, /*Unreachable=*/false);
+    ThenTerm->getParent()->setName(Twine(Prefix) + "branch");
+    LongTerm->getParent()->setName(Twine(Prefix) + "long");
+    Builder.SetInsertPoint(ThenTerm);
+    Value *BranchOffsetValue =
+        Builder.CreateAnd(VLParam, BranchOffsetMaskValue);
+
+    if (IsLoad) {
+      Value *BranchVResult =
+          LoadSubvector(VResult, PtrParam, BranchOffsetValue, BranchWidth,
+                        BranchAlignment, VLTy, ThenTerm);
+      Builder.SetInsertPoint(LongTerm);
+      PHINode *ThenPhi = Builder.CreatePHI(VecTy, 2);
+      ThenPhi->addIncoming(BranchVResult, ThenTerm->getParent());
+      ThenPhi->addIncoming(VResult, IfBB);
+      VResult = ThenPhi;
+    } else {
+      StoreSubvector(DataParam, PtrParam, BranchOffsetValue, BranchWidth,
+                     BranchAlignment, VLTy, ThenTerm);
+    }
+  }
+
+  Builder.SetInsertPoint(ShortTerm);
+  Value *ShortVResult = nullptr;
+  if (IsLoad) {
+    LoadInst *ShortLoad = Builder.CreateLoad(VPI.getType(), PtrParam, false);
+    if (AlignOpt.hasValue())
+      ShortLoad->setAlignment(AlignOpt.getValue());
+    ShortVResult = ShortLoad;
+  } else {
+    StoreInst *ShortStore = Builder.CreateStore(DataParam, PtrParam, false);
+    if (AlignOpt.hasValue())
+      ShortStore->setAlignment(AlignOpt.getValue());
+  }
+
+  if (IsLoad) {
+    Builder.SetInsertPoint(&I);
+    PHINode *Phi = Builder.CreatePHI(VecTy, 2);
+    Phi->addIncoming(VResult, LongTerm->getParent());
+    Phi->addIncoming(ShortVResult, ShortTerm->getParent());
+    VResult = Phi;
+  }
+
+  replaceOperation(*VResult, VPI);
+  return VResult;
+}
+
 void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
   LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
 
@@ -430,6 +720,25 @@
   LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
   LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
 
+  // If the mask is trivial and there is no lowering for the corresponding
+  // masked intrinsic, use the alternate EVL-only expansion.
+  switch (VPI.getIntrinsicID()) {
+  default:
+    break;
+  case Intrinsic::vp_load:
+    if (isAllTrueMask(OldMaskParam) &&
+        !TTI.isLegalMaskedLoad(VPI.getType(),
+                               VPI.getPointerAlignment().valueOrOne()))
+      return &VPI;
+    break;
+  case Intrinsic::vp_store:
+    if (isAllTrueMask(OldMaskParam) &&
+        !TTI.isLegalMaskedStore(VPI.getMemoryDataParam()->getType(),
+                                VPI.getPointerAlignment().valueOrOne()))
+      return &VPI;
+    break;
+  }
+
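+  // Keeping the EVL intact here lets expandPredicationInUnfoldedLoadStore()
+  // emit a few contiguous power-of-two accesses instead of scalarizing the
+  // unsupported masked intrinsic lane by lane.
+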
   // Convert the %evl predication into vector mask predication.
   ElementCount ElemCount = VPI.getStaticVectorLength();
   Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
@@ -459,6 +768,20 @@
   if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
     return expandPredicationInReduction(Builder, *VPRI);
 
+  switch (VPI.getIntrinsicID()) {
+  default:
+    break;
+  case Intrinsic::vp_load:
+  case Intrinsic::vp_store:
+    // An all-true mask with a live EVL takes the unfolded expansion; the
+    // masked-intrinsic lowering handles everything else.
+    if (!VPI.canIgnoreVectorLengthParam() && isAllTrueMask(VPI.getMaskParam()))
+      return expandPredicationInUnfoldedLoadStore(Builder, VPI);
+    return expandPredicationInMemoryIntrinsic(Builder, VPI);
+  case Intrinsic::vp_gather:
+  case Intrinsic::vp_scatter:
+    // The EVL of gathers and scatters has always been folded into the mask
+    // by this point.
+    return expandPredicationInMemoryIntrinsic(Builder, VPI);
+  }
+
   return &VPI;
 }
 
@@ -572,7 +895,8 @@
   bool runOnFunction(Function &F) override {
     const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
-    CachingVPExpander VPExpander(F, *TTI);
+    const auto &DL = F.getParent()->getDataLayout();
+    CachingVPExpander VPExpander(F, *TTI, DL);
     return VPExpander.expandVectorPredication();
   }
 
@@ -598,7 +922,8 @@
 PreservedAnalyses
 ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
-  CachingVPExpander VPExpander(F, TTI);
+  const auto &DL = F.getParent()->getDataLayout();
+  CachingVPExpander VPExpander(F, TTI, DL);
   if (!VPExpander.expandVectorPredication())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;