Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -203,6 +203,13 @@
   /// comments for a detailed explanation of the cost values.
   int getUserCost(const User *U) const;
 
+  /// Determines whether the given address computation can be folded into the
+  /// target's addressing modes, or whether it should be considered complex.
+  ///
+  /// \p Stride points to the stride value in bytes if it is a compile-time
+  /// constant, or is nullptr otherwise.
+  bool isComplexStridedAddressComputation(const APInt *Stride) const;
+
   /// \brief Return true if branch divergence exists.
   ///
   /// Branch divergence has a significantly negative impact on GPU performance
@@ -778,6 +785,7 @@
                                     FastMathFlags FMF) = 0;
   virtual int getCallInstrCost(Function *F, Type *RetTy,
                                ArrayRef<Type *> Tys) = 0;
+  virtual bool isComplexStridedAddressComputation(const APInt *Stride) = 0;
   virtual unsigned getNumberOfParts(Type *Tp) = 0;
   virtual int getAddressComputationCost(Type *Ty, bool IsComplex) = 0;
   virtual unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) = 0;
@@ -1025,6 +1033,9 @@
                        ArrayRef<Type *> Tys) override {
     return Impl.getCallInstrCost(F, RetTy, Tys);
   }
+  bool isComplexStridedAddressComputation(const APInt *Stride) override {
+    return Impl.isComplexStridedAddressComputation(Stride);
+  }
   unsigned getNumberOfParts(Type *Tp) override {
     return Impl.getNumberOfParts(Tp);
   }
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -368,6 +368,10 @@
     return 1;
   }
 
+  bool isComplexStridedAddressComputation(const APInt *Stride) {
+    return true;
+  }
+
   unsigned getNumberOfParts(Type *Tp) { return 0; }
 
   unsigned getAddressComputationCost(Type *Tp, bool) { return 0; }
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -918,6 +918,20 @@
     return 10;
   }
 
+  bool isComplexStridedAddressComputation(const APInt *Stride) {
+    int MaxMergeDistance = 64;
+
+    if (!Stride)
+      return true;
+
+    // Huge step value - give up.
+    if (Stride->getBitWidth() > 64)
+      return true;
+
+    int64_t StepVal = Stride->getSExtValue();
+    return (StepVal > MaxMergeDistance);
+  }
+
   unsigned getNumberOfParts(Type *Tp) {
     std::pair<int, MVT> LT = getTLI()->getTypeLegalizationCost(DL, Tp);
     return LT.first;
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -384,6 +384,11 @@
   return Cost;
 }
 
+bool TargetTransformInfo::isComplexStridedAddressComputation(
+    const APInt *Stride) const {
+  return TTIImpl->isComplexStridedAddressComputation(Stride);
+}
+
 unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const {
   return TTIImpl->getNumberOfParts(Tp);
 }
Index: lib/Target/X86/X86TargetTransformInfo.h
===================================================================
--- lib/Target/X86/X86TargetTransformInfo.h
+++ lib/Target/X86/X86TargetTransformInfo.h
@@ -71,6 +71,7 @@
                             unsigned AddressSpace);
   int getGatherScatterOpCost(unsigned Opcode, Type *DataTy, Value *Ptr,
                              bool VariableMask, unsigned Alignment);
+  bool isComplexStridedAddressComputation(const APInt *Stride);
   int getAddressComputationCost(Type *PtrTy, bool IsComplex);
 
   int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
Index: lib/Target/X86/X86TargetTransformInfo.cpp
===================================================================
--- lib/Target/X86/X86TargetTransformInfo.cpp
+++ lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1473,6 +1473,14 @@
   return Cost+LT.first;
 }
 
+bool X86TTIImpl::isComplexStridedAddressComputation(const APInt *Stride) {
+  // Strided address computations are hidden by the addressing modes of X86
+  // regardless of the stride value. Even when the stride is not a
+  // compile-time constant we don't observe any significant address
+  // computation cost.
+  return false;
+}
+
 int X86TTIImpl::getAddressComputationCost(Type *Ty, bool IsComplex) {
   // Address computations in vectorized code with non-consecutive addresses will
   // likely result in more instructions compared to scalar code where the
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6588,7 +6588,8 @@
 static bool isLikelyComplexAddressComputation(Value *Ptr,
                                               LoopVectorizationLegality *Legal,
                                               ScalarEvolution *SE,
-                                              const Loop *TheLoop) {
+                                              const Loop *TheLoop,
+                                              const TargetTransformInfo &TTI) {
   auto *Gep = dyn_cast<GetElementPtrInst>(Ptr);
   if (!Gep)
     return true;
@@ -6605,8 +6606,6 @@
 
   // Now we know we have a GEP ptr, %inv, %ind, %inv. Make sure that the step
   // can likely be merged into the address computation.
-  unsigned MaxMergeDistance = 64;
-
   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Ptr));
   if (!AddRec)
     return true;
@@ -6616,17 +6615,11 @@
   // Calculate the pointer stride and check if it is consecutive.
   const auto *C = dyn_cast<SCEVConstant>(Step);
   if (!C)
-    return true;
-
-  const APInt &APStepVal = C->getAPInt();
-
-  // Huge step value - give up.
-  if (APStepVal.getBitWidth() > 64)
-    return true;
+    return TTI.isComplexStridedAddressComputation(nullptr);
 
-  int64_t StepVal = APStepVal.getSExtValue();
+  const APInt &APStrideVal = C->getAPInt();
 
-  return StepVal > MaxMergeDistance;
+  return TTI.isComplexStridedAddressComputation(&APStrideVal);
 }
 
 static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
@@ -6854,7 +6847,7 @@
 
     // True if the memory instruction's address computation is complex.
     bool IsComplexComputation =
-        isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop);
+        isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop, TTI);
 
     // Get the cost of the scalar memory instruction and address computation.
     Cost += VF * TTI.getAddressComputationCost(PtrTy, IsComplexComputation);
Index: test/Transforms/LoopVectorize/X86/strided_load_cost.ll
===================================================================
--- test/Transforms/LoopVectorize/X86/strided_load_cost.ll
+++ test/Transforms/LoopVectorize/X86/strided_load_cost.ll
@@ -0,0 +1,54 @@
+; This test checks that the given loop is still beneficial for vectorization
+; even if it contains a scalarized load (gather on AVX2).
+; RUN: opt < %s -loop-vectorize -S -o - | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Function Attrs: norecurse nounwind readonly uwtable
+define i32 @matrix_row_col([100 x i32]* nocapture readonly %data, i32 %i, i32 %j) local_unnamed_addr #0 {
+entry:
+  %idxprom = sext i32 %i to i64
+  %idxprom5 = sext i32 %j to i64
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret i32 %add7
+
+for.body:                                         ; preds = %for.body, %entry
+; the loop gets vectorized
+; first consecutive load as vector load
+; CHECK: %wide.load = load <8 x i32>
+; second strided load scalarized
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+; CHECK: load i32
+
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %sum.015 = phi i32 [ 0, %entry ], [ %add7, %for.body ]
+  %arrayidx2 = getelementptr inbounds [100 x i32], [100 x i32]* %data, i64 %idxprom, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx2, align 4, !tbaa !1
+  %arrayidx6 = getelementptr inbounds [100 x i32], [100 x i32]* %data, i64 %indvars.iv, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4, !tbaa !1
+  %mul = mul nsw i32 %1, %0
+  %add = add i32 %sum.015, 4
+  %add7 = add i32 %add, %mul
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 100
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+attributes #0 = { "target-cpu"="core-avx2" "target-features"="+avx,+avx2,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3" }
+
+!llvm.ident = !{!0}
+
+!0 = !{!"clang version 4.0.0 (cfe/trunk 284570)"}
+!1 = !{!2, !2, i64 0}
+!2 = !{!"int", !3, i64 0}
+!3 = !{!"omnipotent char", !4, i64 0}
+!4 = !{!"Simple C/C++ TBAA"}
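
Note: the following is a minimal, standalone sketch (not part of the patch) of the
contract the new TTI hook follows, mirroring the BasicTTIImpl default added above:
return true when the strided address computation is expected to be complex, i.e.
not foldable into the target's addressing modes. The free function, the main()
driver, and the sample stride values are illustrative assumptions only; a target
that can always fold strided addresses, like the X86 override in this patch,
would simply return false.

#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

// Mirrors the BasicTTIImpl default: report the computation as complex when the
// stride is unknown, wider than 64 bits, or larger than the distance that
// typical addressing modes are assumed to fold.
static bool defaultIsComplexStridedAddressComputation(const APInt *Stride) {
  const int MaxMergeDistance = 64;

  if (!Stride) // Non-compile-time-constant stride - be conservative.
    return true;

  if (Stride->getBitWidth() > 64) // Huge step value - give up.
    return true;

  return Stride->getSExtValue() > MaxMergeDistance;
}

int main() {
  APInt SmallStride(64, 16);   // 16-byte stride: expected to fold.
  APInt LargeStride(64, 4096); // 4 KiB stride: treated as complex.

  assert(!defaultIsComplexStridedAddressComputation(&SmallStride));
  assert(defaultIsComplexStridedAddressComputation(&LargeStride));
  assert(defaultIsComplexStridedAddressComputation(nullptr));
  return 0;
}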