Index: llvm/include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfo.h +++ llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1058,6 +1058,7 @@ /// \returns The type to use in a loop expansion of a memcpy call. Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, + unsigned SrcAddrSpace, unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const; /// \param[out] OpsOut The operand types to copy RemainingBytes of memory. @@ -1069,6 +1070,8 @@ void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, unsigned RemainingBytes, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const; @@ -1369,11 +1372,15 @@ virtual Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, Type *ExpectedType) = 0; virtual Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const = 0; virtual void getMemcpyLoopResidualLoweringType( SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, - unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const = 0; + unsigned RemainingBytes, + unsigned SrcAddrSpace, unsigned DestAddrSpace, + unsigned SrcAlign, unsigned DestAlign) const = 0; virtual bool areInlineCompatible(const Function *Caller, const Function *Callee) const = 0; virtual bool @@ -1814,16 +1821,22 @@ return Impl.getOrCreateResultFromMemIntrinsic(Inst, ExpectedType); } Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, + unsigned SrcAddrSpace, unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const override { - return Impl.getMemcpyLoopLoweringType(Context, Length, SrcAlign, DestAlign); + return Impl.getMemcpyLoopLoweringType(Context, Length, + SrcAddrSpace, DestAddrSpace, + SrcAlign, DestAlign); } void 
getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, unsigned RemainingBytes, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const override { Impl.getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes, + SrcAddrSpace, DestAddrSpace, SrcAlign, DestAlign); } bool areInlineCompatible(const Function *Caller, Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -536,6 +536,7 @@ } Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, + unsigned SrcAddrSpace, unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const { return Type::getInt8Ty(Context); } @@ -543,6 +544,8 @@ void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, unsigned RemainingBytes, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const { for (unsigned i = 0; i != RemainingBytes; ++i) Index: llvm/lib/Analysis/TargetTransformInfo.cpp =================================================================== --- llvm/lib/Analysis/TargetTransformInfo.cpp +++ llvm/lib/Analysis/TargetTransformInfo.cpp @@ -771,16 +771,23 @@ Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const { - return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAlign, + return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace, + DestAddrSpace, SrcAlign, DestAlign); } void TargetTransformInfo::getMemcpyLoopResidualLoweringType( SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, - unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const { + unsigned RemainingBytes, + unsigned SrcAddrSpace, + unsigned DestAddrSpace, + 
unsigned SrcAlign, unsigned DestAlign) const { TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes, + SrcAddrSpace, DestAddrSpace, SrcAlign, DestAlign); } Index: llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp =================================================================== --- llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -36,16 +36,16 @@ Function *ParentFunc = PreLoopBB->getParent(); LLVMContext &Ctx = PreLoopBB->getContext(); + unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); + unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); + Type *TypeOfCopyLen = CopyLen->getType(); Type *LoopOpType = - TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign); + TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAS, DstAS, SrcAlign, DestAlign); unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType); uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize; - unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); - unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); - if (LoopEndCount != 0) { // Split PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "memcpy-split"); @@ -99,6 +99,7 @@ SmallVector<Type *, 5> RemainingOps; TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes, + SrcAS, DstAS, SrcAlign, DestAlign); for (auto OpTy : RemainingOps) { @@ -144,15 +145,15 @@ Function *ParentFunc = PreLoopBB->getParent(); LLVMContext &Ctx = PreLoopBB->getContext(); + unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); + unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); Type *LoopOpType = - TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign); + TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAS, DstAS, SrcAlign, DestAlign); unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType); IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); - unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); - 
unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS); PointerType *DstOpType = PointerType::get(LoopOpType, DstAS); if (SrcAddr->getType() != SrcOpType) {