diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -574,10 +574,10 @@
   /// modes that operate across loop iterations.
   bool shouldFavorBackedgeIndex(const Loop *L) const;
 
-  /// Return true if the target supports masked load.
-  bool isLegalMaskedStore(Type *DataType) const;
   /// Return true if the target supports masked store.
-  bool isLegalMaskedLoad(Type *DataType) const;
+  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) const;
+  /// Return true if the target supports masked load.
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) const;
 
   /// Return true if the target supports nontemporal store.
   bool isLegalNTStore(Type *DataType, Align Alignment) const;
@@ -1209,8 +1209,8 @@
                           TargetLibraryInfo *LibInfo) = 0;
   virtual bool shouldFavorPostInc() const = 0;
   virtual bool shouldFavorBackedgeIndex(const Loop *L) const = 0;
-  virtual bool isLegalMaskedStore(Type *DataType) = 0;
-  virtual bool isLegalMaskedLoad(Type *DataType) = 0;
+  virtual bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) = 0;
+  virtual bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) = 0;
   virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalMaskedScatter(Type *DataType) = 0;
@@ -1496,11 +1496,11 @@
   bool shouldFavorBackedgeIndex(const Loop *L) const override {
     return Impl.shouldFavorBackedgeIndex(L);
   }
-  bool isLegalMaskedStore(Type *DataType) override {
-    return Impl.isLegalMaskedStore(DataType);
+  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) override {
+    return Impl.isLegalMaskedStore(DataType, Alignment);
   }
-  bool isLegalMaskedLoad(Type *DataType) override {
-    return Impl.isLegalMaskedLoad(DataType);
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) override {
+    return Impl.isLegalMaskedLoad(DataType, Alignment);
   }
   bool isLegalNTStore(Type *DataType, Align Alignment) override {
     return Impl.isLegalNTStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -243,9 +243,11 @@
 
   bool shouldFavorBackedgeIndex(const Loop *L) const { return false; }
 
-  bool isLegalMaskedStore(Type *DataType) { return false; }
+  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) {
+    return false;
+  }
 
-  bool isLegalMaskedLoad(Type *DataType) { return false; }
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) { return false; }
 
   bool isLegalNTStore(Type *DataType, Align Alignment) {
     // By default, assume nontemporal memory stores are available for stores
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -288,12 +288,14 @@
   return TTIImpl->shouldFavorBackedgeIndex(L);
 }
 
-bool TargetTransformInfo::isLegalMaskedStore(Type *DataType) const {
-  return TTIImpl->isLegalMaskedStore(DataType);
+bool TargetTransformInfo::isLegalMaskedStore(Type *DataType,
+                                             MaybeAlign Alignment) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment);
 }
 
-bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType) const {
-  return TTIImpl->isLegalMaskedLoad(DataType);
+bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType,
+                                            MaybeAlign Alignment) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
--- a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -851,17 +851,24 @@
     switch (II->getIntrinsicID()) {
     default:
       break;
-    case Intrinsic::masked_load:
+    case Intrinsic::masked_load: {
       // Scalarize unsupported vector masked load
-      if (TTI->isLegalMaskedLoad(CI->getType()))
+      unsigned Alignment =
+          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
+      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
         return false;
       scalarizeMaskedLoad(CI, ModifiedDT);
       return true;
-    case Intrinsic::masked_store:
-      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
+    }
+    case Intrinsic::masked_store: {
+      unsigned Alignment =
+          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
+      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
+                                  MaybeAlign(Alignment)))
         return false;
       scalarizeMaskedStore(CI, ModifiedDT);
       return true;
+    }
     case Intrinsic::masked_gather:
       if (TTI->isLegalMaskedGather(CI->getType()))
         return false;
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -153,8 +153,10 @@
     return ST->getMaxInterleaveFactor();
   }
 
-  bool isLegalMaskedLoad(Type *DataTy);
-  bool isLegalMaskedStore(Type *DataTy) { return isLegalMaskedLoad(DataTy); }
+  bool isLegalMaskedLoad(Type *DataTy, MaybeAlign Alignment);
+  bool isLegalMaskedStore(Type *DataTy, MaybeAlign Alignment) {
+    return isLegalMaskedLoad(DataTy, Alignment);
+  }
 
   int getMemcpyCost(const Instruction *I);
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -491,7 +491,7 @@
   return BaseT::getAddressComputationCost(Ty, SE, Ptr);
 }
 
-bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy) {
+bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, MaybeAlign Alignment) {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
 
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -185,8 +185,8 @@
   bool isLSRCostLess(TargetTransformInfo::LSRCost &C1,
                      TargetTransformInfo::LSRCost &C2);
   bool canMacroFuseCmp();
-  bool isLegalMaskedLoad(Type *DataType);
-  bool isLegalMaskedStore(Type *DataType);
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment);
+  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment);
   bool isLegalNTLoad(Type *DataType, Align Alignment);
   bool isLegalNTStore(Type *DataType, Align Alignment);
   bool isLegalMaskedGather(Type *DataType);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -2417,8 +2417,9 @@
   unsigned NumElem = SrcVTy->getVectorNumElements();
   VectorType *MaskTy =
       VectorType::get(Type::getInt8Ty(SrcVTy->getContext()), NumElem);
-  if ((IsLoad && !isLegalMaskedLoad(SrcVTy)) ||
-      (IsStore && !isLegalMaskedStore(SrcVTy)) || !isPowerOf2_32(NumElem)) {
+  if ((IsLoad && !isLegalMaskedLoad(SrcVTy, MaybeAlign(Alignment))) ||
+      (IsStore && !isLegalMaskedStore(SrcVTy, MaybeAlign(Alignment))) ||
+      !isPowerOf2_32(NumElem)) {
     // Scalarization
     int MaskSplitCost = getScalarizationOverhead(MaskTy, false, true);
     int ScalarCompareCost = getCmpSelInstrCost(
@@ -3213,7 +3214,7 @@
   return ST->hasMacroFusion() || ST->hasBranchFusion();
 }
 
-bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy) {
+bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, MaybeAlign Alignment) {
   if (!ST->hasAVX())
     return false;
 
@@ -3236,8 +3237,8 @@
          ((IntWidth == 8 || IntWidth == 16) && ST->hasBWI());
 }
 
-bool X86TTIImpl::isLegalMaskedStore(Type *DataType) {
-  return isLegalMaskedLoad(DataType);
+bool X86TTIImpl::isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) {
+  return isLegalMaskedLoad(DataType, Alignment);
 }
 
 bool X86TTIImpl::isLegalNTLoad(Type *DataType, Align Alignment) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1190,14 +1190,16 @@
 
   /// Returns true if the target machine supports masked store operation
   /// for the given \p DataType and kind of access to \p Ptr.
-  bool isLegalMaskedStore(Type *DataType, Value *Ptr) {
-    return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedStore(DataType);
+  bool isLegalMaskedStore(Type *DataType, Value *Ptr, unsigned Alignment) {
+    return Legal->isConsecutivePtr(Ptr) &&
+           TTI.isLegalMaskedStore(DataType, MaybeAlign(Alignment));
   }
 
   /// Returns true if the target machine supports masked load operation
   /// for the given \p DataType and kind of access to \p Ptr.
-  bool isLegalMaskedLoad(Type *DataType, Value *Ptr) {
-    return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedLoad(DataType);
+  bool isLegalMaskedLoad(Type *DataType, Value *Ptr, unsigned Alignment) {
+    return Legal->isConsecutivePtr(Ptr) &&
+           TTI.isLegalMaskedLoad(DataType, MaybeAlign(Alignment));
   }
 
   /// Returns true if the target machine supports masked scatter operation
@@ -4551,6 +4553,7 @@
       return false;
     auto *Ptr = getLoadStorePointerOperand(I);
     auto *Ty = getMemInstValueType(I);
+    unsigned Alignment = getLoadStoreAlignment(I);
     // We have already decided how to vectorize this instruction, get that
     // result.
     if (VF > 1) {
@@ -4560,8 +4563,8 @@
       return WideningDecision == CM_Scalarize;
     }
     return isa<LoadInst>(I) ?
-        !(isLegalMaskedLoad(Ty, Ptr)  || isLegalMaskedGather(Ty))
-      : !(isLegalMaskedStore(Ty, Ptr) || isLegalMaskedScatter(Ty));
+        !(isLegalMaskedLoad(Ty, Ptr, Alignment) || isLegalMaskedGather(Ty))
+      : !(isLegalMaskedStore(Ty, Ptr, Alignment) || isLegalMaskedScatter(Ty));
   }
   case Instruction::UDiv:
   case Instruction::SDiv:
@@ -4604,8 +4607,9 @@
          "Masked interleave-groups for predicated accesses are not enabled.");
 
   auto *Ty = getMemInstValueType(I);
-  return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty)
-                          : TTI.isLegalMaskedStore(Ty);
+  unsigned Alignment = getLoadStoreAlignment(I);
+  return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty, MaybeAlign(Alignment))
+                          : TTI.isLegalMaskedStore(Ty, MaybeAlign(Alignment));
 }
 
 bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(Instruction *I,