Index: ../include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- ../include/llvm/Analysis/TargetTransformInfoImpl.h +++ ../include/llvm/Analysis/TargetTransformInfoImpl.h @@ -22,6 +22,7 @@ #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" +#include "llvm/Analysis/VectorUtils.h" namespace llvm { @@ -415,21 +416,28 @@ (Ptr == nullptr ? 0 : Ptr->getType()->getPointerAddressSpace()); auto GTI = gep_type_begin(PointerType::get(PointeeType, AS), Operands); for (auto I = Operands.begin(); I != Operands.end(); ++I, ++GTI) { + // We assume that the cost of Scalar GEP with constant index and the + // cost of Vector GEP with splat constant index are the same. + const ConstantInt *ConstIdx = dyn_cast(*I); + if (!ConstIdx) + if (auto Splat = getSplatValue(*I)) + ConstIdx = dyn_cast(Splat); if (isa(*GTI)) { int64_t ElementSize = DL.getTypeAllocSize(GTI.getIndexedType()); - if (const ConstantInt *ConstIdx = dyn_cast(*I)) { + if (ConstIdx) BaseOffset += ConstIdx->getSExtValue() * ElementSize; - } else { + else { // Needs scale register. - if (Scale != 0) { + if (Scale != 0) // No addressing mode takes two scale registers. return TTI::TCC_Basic; - } Scale = ElementSize; } } else { StructType *STy = cast(*GTI); - uint64_t Field = cast(*I)->getZExtValue(); + // For structures the index is always a splat or scalar constant + assert(ConstIdx && "Unexpected GEP index"); + uint64_t Field = ConstIdx->getZExtValue(); BaseOffset += DL.getStructLayout(STy)->getElementOffset(Field); } } Index: ../include/llvm/Analysis/VectorUtils.h =================================================================== --- ../include/llvm/Analysis/VectorUtils.h +++ ../include/llvm/Analysis/VectorUtils.h @@ -85,7 +85,7 @@ /// \brief Get splat value if the input is a splat vector or return nullptr. 
/// The value may be extracted from a splat constants vector or from /// a sequence of instructions that broadcast a single value into a vector. -Value *getSplatValue(Value *V); +const Value *getSplatValue(const Value *V); /// \brief Compute a map of integer instructions to their minimum legal type /// size. Index: ../lib/Analysis/VectorUtils.cpp =================================================================== --- ../lib/Analysis/VectorUtils.cpp +++ ../lib/Analysis/VectorUtils.cpp @@ -417,9 +417,10 @@ /// the input value is (1) a splat constants vector or (2) a sequence /// of instructions that broadcast a single value into a vector. /// -llvm::Value *llvm::getSplatValue(Value *V) { - if (auto *CV = dyn_cast(V)) - return CV->getSplatValue(); +const llvm::Value *llvm::getSplatValue(const Value *V) { + + if (auto *C = dyn_cast(V)) + return C->getSplatValue(); auto *ShuffleInst = dyn_cast(V); if (!ShuffleInst) Index: ../lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- ../lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ ../lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3301,18 +3301,18 @@ // extract the spalt value and use it as a uniform base. // In all other cases the function returns 'false'. 
// -static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index, +static bool getUniformBase(const Value *& Ptr, SDValue& Base, SDValue& Index, SelectionDAGBuilder* SDB) { SelectionDAG& DAG = SDB->DAG; LLVMContext &Context = *DAG.getContext(); assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type"); - GetElementPtrInst *GEP = dyn_cast(Ptr); + const GetElementPtrInst *GEP = dyn_cast(Ptr); if (!GEP || GEP->getNumOperands() > 2) return false; - Value *GEPPtr = GEP->getPointerOperand(); + const Value *GEPPtr = GEP->getPointerOperand(); if (!GEPPtr->getType()->isVectorTy()) Ptr = GEPPtr; else if (!(Ptr = getSplatValue(GEPPtr))) @@ -3348,7 +3348,7 @@ SDLoc sdl = getCurSDLoc(); // llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask) - Value *Ptr = I.getArgOperand(1); + const Value *Ptr = I.getArgOperand(1); SDValue Src0 = getValue(I.getArgOperand(0)); SDValue Mask = getValue(I.getArgOperand(3)); EVT VT = Src0.getValueType(); @@ -3362,10 +3362,10 @@ SDValue Base; SDValue Index; - Value *BasePtr = Ptr; + const Value *BasePtr = Ptr; bool UniformBase = getUniformBase(BasePtr, Base, Index, this); - Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; + const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; MachineMemOperand *MMO = DAG.getMachineFunction(). 
getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), MachineMemOperand::MOStore, VT.getStoreSize(), @@ -3425,7 +3425,7 @@ SDLoc sdl = getCurSDLoc(); // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) - Value *Ptr = I.getArgOperand(0); + const Value *Ptr = I.getArgOperand(0); SDValue Src0 = getValue(I.getArgOperand(3)); SDValue Mask = getValue(I.getArgOperand(2)); @@ -3442,7 +3442,7 @@ SDValue Root = DAG.getRoot(); SDValue Base; SDValue Index; - Value *BasePtr = Ptr; + const Value *BasePtr = Ptr; bool UniformBase = getUniformBase(BasePtr, Base, Index, this); bool ConstantMemory = false; if (UniformBase && Index: ../test/Analysis/CostModel/X86/vector_gep.ll =================================================================== --- ../test/Analysis/CostModel/X86/vector_gep.ll +++ ../test/Analysis/CostModel/X86/vector_gep.ll @@ -0,0 +1,17 @@ +; RUN: opt < %s -cost-model -analyze -mtriple=x86_64-linux-unknown-unknown -mattr=+avx512f | FileCheck %s + +%struct.S = type { [1000 x i32] } + + +declare <4 x i32> @llvm.masked.gather.v4i32(<4 x i32*>, i32, <4 x i1>, <4 x i32>) + +define <4 x i32> @foov(<4 x %struct.S*> %s, i64 %base){ + %temp = insertelement <4 x i64> undef, i64 %base, i32 0 + %vector = shufflevector <4 x i64> %temp, <4 x i64> undef, <4 x i32> zeroinitializer +;CHECK: cost of 0 for instruction: {{.*}} getelementptr inbounds %struct.S + %B = getelementptr inbounds %struct.S, <4 x %struct.S*> %s, <4 x i32> zeroinitializer, <4 x i32> zeroinitializer +;CHECK: cost of 0 for instruction: {{.*}} getelementptr inbounds [1000 x i32] + %arrayidx = getelementptr inbounds [1000 x i32], <4 x [1000 x i32]*> %B, <4 x i64> zeroinitializer, <4 x i64> %vector + %res = call <4 x i32> @llvm.masked.gather.v4i32(<4 x i32*> %arrayidx, i32 4, <4 x i1> , <4 x i32> undef) + ret <4 x i32> %res +}