[TTI] Add shouldPrefetchAddressSpace hook

Let targets choose which address spaces LoopDataPrefetch may issue
prefetches for, via TargetTransformInfo::shouldPrefetchAddressSpace().
BasicTTIImpl forwards the query to the subtarget's
MCSubtargetInfo::shouldPrefetchAddressSpace(). Both default
implementations (TargetTransformInfoImplBase and MCSubtargetInfo) return
!AS, i.e. only address space 0 is prefetched. This preserves the previous
LoopDataPrefetch behaviour of skipping every non-zero address space.

LoopDataPrefetch also stops hard-coding address space 0 for the prefetch
address. It now expands it as an i8 pointer in the address space of the
computed prefetch SCEV (NextLSCEV).

Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1041,6 +1041,9 @@
   /// \return True if prefetching should also be done for writes.
   bool enableWritePrefetching() const;

+  /// \return True if the target wants prefetches in address space \p AS.
+  bool shouldPrefetchAddressSpace(unsigned AS) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -1702,6 +1705,9 @@
   /// \return True if prefetching should also be done for writes.
   virtual bool enableWritePrefetching() const = 0;

+  /// \return True if the target wants prefetches in address space \p AS.
+  virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2228,6 +2234,11 @@
     return Impl.enableWritePrefetching();
   }

+  /// \return True if the target wants prefetches in address space \p AS.
+  bool shouldPrefetchAddressSpace(unsigned AS) const override {
+    return Impl.shouldPrefetchAddressSpace(AS);
+  }
+
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -475,6 +475,7 @@
   }
   unsigned getMaxPrefetchIterationsAhead() const { return UINT_MAX; }
   bool enableWritePrefetching() const { return false; }
+  bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }

   unsigned getMaxInterleaveFactor(unsigned VF) const { return 1; }

Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -683,6 +683,10 @@
     return getST()->enableWritePrefetching();
   }

+  virtual bool shouldPrefetchAddressSpace(unsigned AS) const {
+    return getST()->shouldPrefetchAddressSpace(AS);
+  }
+
   /// @}

   /// \name Vector TTI Implementations
Index: llvm/include/llvm/MC/MCSubtargetInfo.h
===================================================================
--- llvm/include/llvm/MC/MCSubtargetInfo.h
+++ llvm/include/llvm/MC/MCSubtargetInfo.h
@@ -282,6 +282,9 @@
                                         unsigned NumStridedMemAccesses,
                                         unsigned NumPrefetches,
                                         bool HasCall) const;
+
+  /// \return True if the target wants prefetches in address space \p AS.
+  virtual bool shouldPrefetchAddressSpace(unsigned AS) const;
 };

 } // end namespace llvm
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -702,6 +702,10 @@
   return TTIImpl->enableWritePrefetching();
 }

+bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
+  return TTIImpl->shouldPrefetchAddressSpace(AS);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
Index: llvm/lib/MC/MCSubtargetInfo.cpp
===================================================================
--- llvm/lib/MC/MCSubtargetInfo.cpp
+++ llvm/lib/MC/MCSubtargetInfo.cpp
@@ -366,3 +366,7 @@
                                                bool HasCall) const {
   return 1;
 }
+
+bool MCSubtargetInfo::shouldPrefetchAddressSpace(unsigned AS) const {
+  return !AS;
+}
Index: llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
+++ llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -336,7 +336,7 @@
       } else continue;

       unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
-      if (PtrAddrSpace)
+      if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace))
         continue;
       NumMemAccesses++;
       if (L->isLoopInvariant(PtrValue))
@@ -396,7 +396,8 @@
     if (!SCEVE.isSafeToExpand(NextLSCEV))
       continue;

-    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
+    unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace();
+    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
     Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);

     IRBuilder<> Builder(P.InsertPt);