Index: ../include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- ../include/llvm/Analysis/TargetTransformInfo.h +++ ../include/llvm/Analysis/TargetTransformInfo.h @@ -310,12 +310,11 @@ bool HasBaseReg, int64_t Scale, unsigned AddrSpace = 0) const; - /// \brief Return true if the target works with masked instruction - /// AVX2 allows masks for consecutive load and store for i32 and i64 elements. - /// AVX-512 architecture will also allow masks for non-consecutive memory - /// accesses. - bool isLegalMaskedStore(Type *DataType, int Consecutive) const; - bool isLegalMaskedLoad(Type *DataType, int Consecutive) const; + /// \brief Return true if the target supports masked load/store. + /// AVX2 and AVX-512 targets allow masks for consecutive load and store for + /// 32- and 64-bit elements. + bool isLegalMaskedStore(Type *DataType) const; + bool isLegalMaskedLoad(Type *DataType) const; /// \brief Return the cost of the scaling factor used in the addressing /// mode represented by AM for this target, for a load/store @@ -568,8 +567,8 @@ int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace) = 0; - virtual bool isLegalMaskedStore(Type *DataType, int Consecutive) = 0; - virtual bool isLegalMaskedLoad(Type *DataType, int Consecutive) = 0; + virtual bool isLegalMaskedStore(Type *DataType) = 0; + virtual bool isLegalMaskedLoad(Type *DataType) = 0; virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace) = 0; @@ -693,11 +692,11 @@ return Impl.isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace); } - bool isLegalMaskedStore(Type *DataType, int Consecutive) override { - return Impl.isLegalMaskedStore(DataType, Consecutive); + bool isLegalMaskedStore(Type *DataType) override { + return Impl.isLegalMaskedStore(DataType); } - bool isLegalMaskedLoad(Type *DataType, int Consecutive) override { - return 
Impl.isLegalMaskedLoad(DataType, Consecutive); + bool isLegalMaskedLoad(Type *DataType) override { + return Impl.isLegalMaskedLoad(DataType); } int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, Index: ../include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- ../include/llvm/Analysis/TargetTransformInfoImpl.h +++ ../include/llvm/Analysis/TargetTransformInfoImpl.h @@ -209,9 +209,9 @@ return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1); } - bool isLegalMaskedStore(Type *DataType, int Consecutive) { return false; } + bool isLegalMaskedStore(Type *DataType) { return false; } - bool isLegalMaskedLoad(Type *DataType, int Consecutive) { return false; } + bool isLegalMaskedLoad(Type *DataType) { return false; } int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace) { Index: ../lib/Analysis/TargetTransformInfo.cpp =================================================================== --- ../lib/Analysis/TargetTransformInfo.cpp +++ ../lib/Analysis/TargetTransformInfo.cpp @@ -113,14 +113,12 @@ Scale, AddrSpace); } -bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, - int Consecutive) const { - return TTIImpl->isLegalMaskedStore(DataType, Consecutive); +bool TargetTransformInfo::isLegalMaskedStore(Type *DataType) const { + return TTIImpl->isLegalMaskedStore(DataType); } -bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, - int Consecutive) const { - return TTIImpl->isLegalMaskedLoad(DataType, Consecutive); +bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType) const { + return TTIImpl->isLegalMaskedLoad(DataType); } int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, Index: ../lib/CodeGen/CodeGenPrepare.cpp =================================================================== --- ../lib/CodeGen/CodeGenPrepare.cpp +++ 
../lib/CodeGen/CodeGenPrepare.cpp @@ -1384,7 +1384,7 @@ } case Intrinsic::masked_load: { // Scalarize unsupported vector masked load - if (!TTI->isLegalMaskedLoad(CI->getType(), 1)) { + if (!TTI->isLegalMaskedLoad(CI->getType())) { ScalarizeMaskedLoad(CI); ModifiedDT = true; return true; @@ -1392,7 +1392,7 @@ return false; } case Intrinsic::masked_store: { - if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(), 1)) { + if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) { ScalarizeMaskedStore(CI); ModifiedDT = true; return true; Index: ../lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- ../lib/Target/X86/X86TargetTransformInfo.h +++ ../lib/Target/X86/X86TargetTransformInfo.h @@ -88,8 +88,8 @@ int getIntImmCost(unsigned Opcode, unsigned Idx, const APInt &Imm, Type *Ty); int getIntImmCost(Intrinsic::ID IID, unsigned Idx, const APInt &Imm, Type *Ty); - bool isLegalMaskedLoad(Type *DataType, int Consecutive); - bool isLegalMaskedStore(Type *DataType, int Consecutive); + bool isLegalMaskedLoad(Type *DataType); + bool isLegalMaskedStore(Type *DataType); bool areInlineCompatible(const Function *Caller, const Function *Callee) const; Index: ../lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- ../lib/Target/X86/X86TargetTransformInfo.cpp +++ ../lib/Target/X86/X86TargetTransformInfo.cpp @@ -899,8 +899,8 @@ unsigned NumElem = SrcVTy->getVectorNumElements(); VectorType *MaskTy = VectorType::get(Type::getInt8Ty(getGlobalContext()), NumElem); - if ((Opcode == Instruction::Load && !isLegalMaskedLoad(SrcVTy, 1)) || - (Opcode == Instruction::Store && !isLegalMaskedStore(SrcVTy, 1)) || + if ((Opcode == Instruction::Load && !isLegalMaskedLoad(SrcVTy)) || + (Opcode == Instruction::Store && !isLegalMaskedStore(SrcVTy)) || !isPowerOf2_32(NumElem)) { // Scalarization int MaskSplitCost = getScalarizationOverhead(MaskTy, false, true); @@ -1189,19 
+1189,16 @@ return X86TTIImpl::getIntImmCost(Imm, Ty); } -bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, int Consecutive) { - int DataWidth = DataTy->getPrimitiveSizeInBits(); +bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy) { + Type *ScalarTy = DataTy->getScalarType(); + int DataWidth = ScalarTy->isPointerTy() ? DL.getPointerSizeInBits() : + ScalarTy->getPrimitiveSizeInBits(); - // Todo: AVX512 allows gather/scatter, works with strided and random as well - if ((DataWidth < 32) || (Consecutive == 0)) - return false; - if (ST->hasAVX512() || ST->hasAVX2()) - return true; - return false; + return (DataWidth >= 32 && ST->hasAVX2()); } -bool X86TTIImpl::isLegalMaskedStore(Type *DataType, int Consecutive) { - return isLegalMaskedLoad(DataType, Consecutive); +bool X86TTIImpl::isLegalMaskedStore(Type *DataType) { + return isLegalMaskedLoad(DataType); } bool X86TTIImpl::areInlineCompatible(const Function *Caller, Index: ../lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- ../lib/Transforms/Vectorize/LoopVectorize.cpp +++ ../lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1226,12 +1226,12 @@ /// Returns true if the target machine supports masked store operation /// for the given \p DataType and kind of access to \p Ptr. bool isLegalMaskedStore(Type *DataType, Value *Ptr) { - return TTI->isLegalMaskedStore(DataType, isConsecutivePtr(Ptr)); + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedStore(DataType); } /// Returns true if the target machine supports masked load operation /// for the given \p DataType and kind of access to \p Ptr. bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { - return TTI->isLegalMaskedLoad(DataType, isConsecutivePtr(Ptr)); + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); } /// Returns true if vector representation of the instruction \p I /// requires mask.