Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -660,6 +660,10 @@
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if the target supports widening calls to the intrinsic \p
+  /// IID with scalable vectors, i.e. without having to scalarize the call.
+  bool isLegalScalableVectorIntrinsic(Intrinsic::ID IID) const;
+
   /// Return true if the target has a unified operation to calculate division
   /// and remainder. If so, the additional implicit multiplication and
   /// subtraction required to calculate a remainder from division are free. This
@@ -1516,6 +1520,7 @@
   virtual bool isLegalMaskedGather(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalScalableVectorIntrinsic(Intrinsic::ID IID) = 0;
   virtual bool hasDivRemOp(Type *DataType, bool IsSigned) = 0;
   virtual bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) = 0;
   virtual bool prefersVectorizedAddressing() = 0;
@@ -1896,6 +1901,10 @@
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalScalableVectorIntrinsic(Intrinsic::ID IID) override {
+    return Impl.isLegalScalableVectorIntrinsic(IID);
+  }
+
   bool hasDivRemOp(Type *DataType, bool IsSigned) override {
     return Impl.hasDivRemOp(DataType, IsSigned);
   }
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -258,6 +258,8 @@
     return false;
   }
 
+  bool isLegalScalableVectorIntrinsic(Intrinsic::ID IID) const { return false; }
+
   bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -409,6 +409,11 @@
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalScalableVectorIntrinsic(
+    Intrinsic::ID IID) const {
+  return TTIImpl->isLegalScalableVectorIntrinsic(IID);
+}
+
 bool TargetTransformInfo::hasDivRemOp(Type *DataType, bool IsSigned) const {
   return TTIImpl->hasDivRemOp(DataType, IsSigned);
 }
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -256,6 +256,8 @@
     return isLegalMaskedGatherScatter(DataType);
   }
 
+  bool isLegalScalableVectorIntrinsic(Intrinsic::ID IID) const;
+
   bool isLegalNTStore(Type *DataType, Align Alignment) {
     // NOTE: The logic below is mostly geared towards LV, which calls it with
     //       vectors with 2 elements. We might want to improve that, if other
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1717,3 +1717,48 @@
   }
   return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
 }
+
+bool AArch64TTIImpl::isLegalScalableVectorIntrinsic(Intrinsic::ID IID) const {
+  switch (IID) {
+  case Intrinsic::abs: // Begin integer bit-manipulation.
+  case Intrinsic::bswap:
+  case Intrinsic::bitreverse:
+  case Intrinsic::ctpop:
+  case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::fshl:
+  case Intrinsic::fshr:
+  case Intrinsic::smax:
+  case Intrinsic::smin:
+  case Intrinsic::umax:
+  case Intrinsic::umin:
+  case Intrinsic::sadd_sat:
+  case Intrinsic::ssub_sat:
+  case Intrinsic::uadd_sat:
+  case Intrinsic::usub_sat:
+  case Intrinsic::smul_fix:
+  case Intrinsic::smul_fix_sat:
+  case Intrinsic::umul_fix:
+  case Intrinsic::umul_fix_sat:
+  case Intrinsic::sqrt: // Begin floating-point.
+  case Intrinsic::fabs:
+  case Intrinsic::minnum:
+  case Intrinsic::maxnum:
+  case Intrinsic::minimum:
+  case Intrinsic::maximum:
+  case Intrinsic::floor:
+  case Intrinsic::ceil:
+  case Intrinsic::trunc:
+  case Intrinsic::rint:
+  case Intrinsic::nearbyint:
+  case Intrinsic::round:
+  case Intrinsic::roundeven:
+  case Intrinsic::fma:
+  case Intrinsic::fmuladd:
+    return true;
+  default:
+    // We can fall back on scalarization for fixed width vectors, but not for
+    // scalable vectors.
+    return BaseT::isLegalScalableVectorIntrinsic(IID);
+  }
+}
Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1508,13 +1508,18 @@
 
   /// Returns true if the target machine supports all of the reduction
   /// variables found for the given VF.
-  bool canVectorizeReductions(ElementCount VF) {
+  bool canVectorizeReductions(ElementCount VF) const {
     return (all_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
       RecurrenceDescriptor RdxDesc = Reduction.second;
       return TTI.isLegalToVectorizeReduction(RdxDesc, VF);
     }));
   }
 
+  /// Returns true if we can widen all instructions in the loop using a maximum
+  /// scalable vectorization factor MaxVF. If the loop is illegal the function
+  /// returns an appropriate error remark in Msg.
+  bool canWidenLoopWithScalableVectors(ElementCount MaxVF, StringRef &Msg) const;
+
   /// Returns true if \p I is an instruction that will be scalarized with
   /// predication. Such instructions include conditional stores and
   /// instructions that may divide by zero.
@@ -5638,6 +5643,45 @@
   return false;
 }
 
+bool LoopVectorizationCostModel::canWidenLoopWithScalableVectors(
+    ElementCount MaxVF, StringRef &Msg) const {
+  // Test that the loop-vectorizer can legalize all operations for eligible
+  // vectorization factors up to MaxVF.
+
+  // Disable scalable vectorization if the loop contains unsupported reductions.
+  if (!canVectorizeReductions(MaxVF)) {
+    Msg = "Scalable vectorization not supported for the reduction "
+          "operations found in this loop.";
+    return false;
+  }
+
+  // Iterate through all instructions in the loop ensuring that it is legal to
+  // vectorize with a scalable VF.
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    for (Instruction &I : *BB) {
+      if (auto *CI = dyn_cast<CallInst>(&I)) {
+        Intrinsic::ID VecID = getVectorIntrinsicIDForCall(CI, TLI);
+
+        // First check if it's always legal to widen this intrinsic regardless
+        // of the scalable VF, i.e. we don't have to worry about scalarizing
+        // the intrinsic.
+        if (VecID && TTI.isLegalScalableVectorIntrinsic(VecID))
+          continue;
+
+        // At this point we have no guarantee that we can widen this call
+        // unless we have mappings in the vector function database.
+        if (VFDatabase::getMappings(*CI).empty()) {
+          Msg = "Scalable vectorization not supported for the call "
+                "instructions found in this loop";
+          return false;
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
 ElementCount
 LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
   if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors) {
@@ -5657,17 +5701,9 @@
   auto MaxScalableVF = ElementCount::getScalable(
       std::numeric_limits<ElementCount::ScalarTy>::max());
 
-  // Disable scalable vectorization if the loop contains unsupported reductions.
-  // Test that the loop-vectorizer can legalize all operations for this MaxVF.
-  // FIXME: While for scalable vectors this is currently sufficient, this should
-  // be replaced by a more detailed mechanism that filters out specific VFs,
-  // instead of invalidating vectorization for a whole set of VFs based on the
-  // MaxVF.
-  if (!canVectorizeReductions(MaxScalableVF)) {
-    reportVectorizationInfo(
-        "Scalable vectorization not supported for the reduction "
-        "operations found in this loop.",
-        "ScalableVFUnfeasible", ORE, TheLoop);
+  StringRef Msg;
+  if (!canWidenLoopWithScalableVectors(MaxScalableVF, Msg)) {
+    reportVectorizationInfo(Msg, "ScalableVFUnfeasible", ORE, TheLoop);
     return ElementCount::getScalable(0);
   }
 
Index: llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
+++ llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
@@ -1,4 +1,6 @@
-; RUN: opt -S -loop-vectorize -force-vector-interleave=1 -instcombine -mattr=+sve -mtriple aarch64-unknown-linux-gnu -scalable-vectorization=on < %s | FileCheck %s
+; RUN: opt -S -loop-vectorize -force-vector-interleave=1 -instcombine -mattr=+sve -mtriple aarch64-unknown-linux-gnu -scalable-vectorization=on \
+; RUN:   -pass-remarks-missed=loop-vectorize < %s 2>%t | FileCheck %s
+; RUN: cat %t | FileCheck %s --check-prefix=CHECK-REMARKS
 
 define void @vec_load(i64 %N, double* nocapture %a, double* nocapture readonly %b) {
 ; CHECK-LABEL: @vec_load
@@ -95,9 +97,35 @@
   ret void
 }
 
+; The loop below calls llvm.sin.f32, which is neither a scalable-legal
+; intrinsic nor present in the vector function database, so scalable
+; vectorization must be rejected with a remark and the vectorizer should
+; fall back on a fixed-width VF.
+; CHECK-REMARKS: Scalable vectorization not supported for the call instructions found in this loop
+define void @vec_sin_no_mapping(float* noalias nocapture %dst, float* noalias nocapture readonly %src, i64 %n) {
+; CHECK: @vec_sin_no_mapping
+; CHECK: call fast <2 x float> @llvm.sin.v2f32
+; CHECK-NOT: @llvm.sin.nxv2f32
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.07 = phi i64 [ %inc, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds float, float* %src, i64 %i.07
+  %0 = load float, float* %arrayidx, align 4
+  %1 = tail call fast float @llvm.sin.f32(float %0)
+  %arrayidx1 = getelementptr inbounds float, float* %dst, i64 %i.07
+  store float %1, float* %arrayidx1, align 4
+  %inc = add nuw nsw i64 %i.07, 1
+  %exitcond.not = icmp eq i64 %inc, %n
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !1
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret void
+}
+
 declare double @foo(double)
 declare i64 @bar(i64*)
 declare double @llvm.sin.f64(double)
+declare float @llvm.sin.f32(float)
+declare float @llvm.sqrt.f32(float)
 
 declare <vscale x 2 x double> @foo_vec(<vscale x 2 x double>)
 declare <vscale x 2 x i64> @bar_vec(<vscale x 2 x i64*>)