diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -412,6 +412,10 @@
   // even taking non-uniform arguments
   bool isAlwaysUniform(const Value *V) const;
 
+  /// Query the target whether the specified address space cast from FromAS to
+  /// ToAS is valid.
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
+
   /// Returns the address space ID for a target's 'flat' address space. Note
   /// this is not necessarily the same as addrspace(0), which LLVM sometimes
   /// refers to as the generic address space. The flat address space is a
@@ -1680,6 +1684,7 @@
   virtual bool useGPUDivergenceAnalysis() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isAlwaysUniform(const Value *V) = 0;
+  virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const = 0;
  virtual unsigned getFlatAddressSpace() = 0;
   virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                           Intrinsic::ID IID) const = 0;
@@ -2061,6 +2066,10 @@
     return Impl.isAlwaysUniform(V);
   }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const override {
+    return Impl.isValidAddrSpaceCast(FromAS, ToAS);
+  }
+
   unsigned getFlatAddressSpace() override { return Impl.getFlatAddressSpace(); }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -95,6 +95,10 @@
 
   bool isAlwaysUniform(const Value *V) const { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return FromAS == ToAS;
+  }
+
   unsigned getFlatAddressSpace() const { return -1; }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@
 
   bool isAlwaysUniform(const Value *V) { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return FromAS == ToAS;
+  }
+
   unsigned getFlatAddressSpace() {
     // Return an invalid address space.
     return -1;
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -529,6 +529,8 @@
   SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                              unsigned Depth = 0,
                              bool AllowMultipleUsers = false) = 0;
+
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -274,6 +274,11 @@
   return TTIImpl->isAlwaysUniform(V);
 }
 
+bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS,
+                                                     unsigned ToAS) const {
+  return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 unsigned TargetTransformInfo::getFlatAddressSpace() const {
   return TTIImpl->getFlatAddressSpace();
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -168,6 +168,21 @@
   bool isSourceOfDivergence(const Value *V) const;
   bool isAlwaysUniform(const Value *V) const;
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    if (ToAS == AMDGPUAS::FLAT_ADDRESS) {
+      switch (FromAS) {
+      case AMDGPUAS::GLOBAL_ADDRESS:
+      case AMDGPUAS::CONSTANT_ADDRESS:
+      case AMDGPUAS::LOCAL_ADDRESS:
+      case AMDGPUAS::PRIVATE_ADDRESS:
+        return true;
+      default:
+        break;
+      }
+    }
+    return FromAS == ToAS;
+  }
+
   unsigned getFlatAddressSpace() const {
     // Don't bother running InferAddressSpaces pass on graphics shaders which
     // don't use flat addressing.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -260,8 +260,8 @@
 // instruction.
 class PointerReplacer {
 public:
-  PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
-      : IC(IC), Root(Root) {}
+  PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS)
+      : IC(IC), Root(Root), FromAS(SrcAS) {}
 
   bool collectUsers();
   void replacePointer(Value *V);
@@ -274,11 +274,17 @@
     return I == &Root || Worklist.contains(I);
   }
 
+  bool isValidAddrSpaceCast(const Instruction *I, unsigned FromAS) const {
+    const auto *ASC = dyn_cast<AddrSpaceCastInst>(I);
+    return ASC && IC.isValidAddrSpaceCast(FromAS, ASC->getDestAddressSpace());
+  }
+
   SmallPtrSet<Instruction *, 32> ValuesToRevisit;
   SmallSetVector<Instruction *, 4> Worklist;
   MapVector<Value *, Value *> WorkMap;
   InstCombinerImpl &IC;
   Instruction &Root;
+  unsigned FromAS;
 };
 
 } // end anonymous namespace
@@ -342,6 +348,8 @@
       if (MI->isVolatile())
         return false;
       Worklist.insert(Inst);
+    } else if (isValidAddrSpaceCast(Inst, FromAS)) {
+      Worklist.insert(Inst);
     } else if (Inst->isLifetimeStartOrEnd()) {
       continue;
     } else {
@@ -427,6 +435,21 @@
     IC.eraseInstFromFunction(*MemCpy);
 
     WorkMap[MemCpy] = NewI;
+  } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+    auto *V = getReplacement(ASC->getPointerOperand());
+    assert(V && "Operand not replaced");
+    assert(isValidAddrSpaceCast(ASC, V->getType()->getPointerAddressSpace()) &&
+           "Invalid address space cast!");
+    auto *NewV = V;
+    if (V->getType()->getPointerAddressSpace() !=
+        ASC->getType()->getPointerAddressSpace()) {
+      auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
+      NewI->takeName(ASC);
+      IC.InsertNewInstWith(NewI, *ASC);
+      NewV = NewI;
+    }
+    IC.replaceInstUsesWith(*ASC, NewV);
+    IC.eraseInstFromFunction(*ASC);
   } else {
     llvm_unreachable("should never reach here");
   }
@@ -519,7 +542,8 @@
       return NewI;
     }
 
-    PointerReplacer PtrReplacer(*this, AI);
+    PointerReplacer PtrReplacer(*this, AI,
+                                TheSrc->getType()->getPointerAddressSpace());
     if (PtrReplacer.collectUsers()) {
       for (Instruction *Delete : ToDelete)
         eraseInstFromFunction(*Delete);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -196,6 +196,10 @@
   return std::nullopt;
 }
 
+bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+  return TTI.isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
   return llvm::emitGEPOffset(&Builder, DL, GEP);
 }
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
--- a/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/memcpy-from-constant.ll
@@ -137,11 +137,9 @@
 ; Alloca is written through a flat pointer
 define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
 ; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(31) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(31) [[ARG:%.*]], i64 31, i1 false)
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(5) [[ALLOCA]], i32 0, i32 [[IDX:%.*]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(5) [[GEP]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr [32 x i8], ptr addrspace(4) [[ARG:%.*]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(4) [[GEP]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %alloca = alloca [32 x i8], align 4, addrspace(5)
@@ -155,9 +153,7 @@
 ; Alloca is only addressed through flat pointer.
 define i8 @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(ptr addrspace(4) noalias readonly align 4 dereferenceable(32) %arg, i32 %idx) {
 ; CHECK-LABEL: @memcpy_constant_arg_ptr_to_alloca_addrspacecast_to_flat2(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(5)
-; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 1 dereferenceable(32) [[ALLOCA_CAST_ASC]], ptr addrspace(4) noundef align 4 dereferenceable(32) [[ARG:%.*]], i64 32, i1 false)
+; CHECK-NEXT:    [[ALLOCA_CAST_ASC:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[IDX:%.*]] to i64
 ; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds [32 x i8], ptr [[ALLOCA_CAST_ASC]], i64 0, i64 [[TMP1]]
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[GEP]], align 1
@@ -202,9 +198,7 @@
 define amdgpu_kernel void @byref_infloop_addrspacecast(ptr %scratch, ptr addrspace(4) byref(%struct.ty) align 4 %arg) local_unnamed_addr #1 {
 ; CHECK-LABEL: @byref_infloop_addrspacecast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [4 x i32], align 4, addrspace(5)
-; CHECK-NEXT:    [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(5) [[ALLOCA]] to ptr
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p4.i64(ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], ptr addrspace(4) noundef align 4 dereferenceable(16) [[ARG:%.*]], i64 16, i1 false)
+; CHECK-NEXT:    [[ADDRSPACECAST_ALLOCA:%.*]] = addrspacecast ptr addrspace(4) [[ARG:%.*]] to ptr
 ; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 4 dereferenceable(16) [[SCRATCH:%.*]], ptr noundef nonnull align 4 dereferenceable(16) [[ADDRSPACECAST_ALLOCA]], i64 16, i1 false)
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
--- a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
@@ -428,6 +428,40 @@
   ret i8 %load
 }
 
+declare i8 @readonly_callee(ptr readonly nocapture)
+
+define i8 @call_readonly_remove_alloca() {
+; CHECK-LABEL: @call_readonly_remove_alloca(
+; CHECK-NEXT:    [[V:%.*]] = call i8 @readonly_callee(ptr nonnull @g1)
+; CHECK-NEXT:    ret i8 [[V]]
+;
+  %alloca = alloca [32 x i8], addrspace(1)
+  call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 32, i1 false)
+  %p = addrspacecast ptr addrspace(1) %alloca to ptr
+  %v = call i8 @readonly_callee(ptr %p)
+  ret i8 %v
+}
+
+define i8 @call_readonly_keep_alloca2() {
+; CHECK-LABEL: @call_readonly_keep_alloca2(
+; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [32 x i8], align 1, addrspace(1)
+; CHECK-NEXT:    call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(16) @g1, i64 16, i1 false)
+; CHECK-NEXT:    [[A1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 16
+; CHECK-NEXT:    call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[A1]], ptr addrspace(1) noundef align 16 dereferenceable(16) @g2, i64 16, i1 false)
+; CHECK-NEXT:    [[P:%.*]] = addrspacecast ptr addrspace(1) [[ALLOCA]] to ptr
+; CHECK-NEXT:    [[V:%.*]] = call i8 @readonly_callee(ptr [[P]])
+; CHECK-NEXT:    ret i8 [[V]]
+;
+  %alloca = alloca [32 x i8], addrspace(1)
+  call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 16, i1 false)
+  %a1 = getelementptr inbounds [32 x i8], ptr addrspace(1) %alloca, i32 0, i32 16
+  call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) %a1, ptr addrspace(1) @g2, i64 16, i1 false)
+  %p = addrspacecast ptr addrspace(1) %alloca to ptr
+  %v = call i8 @readonly_callee(ptr %p)
+  ret i8 %v
+}
+
 declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1), ptr, i64, i1)
 declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
 declare void @llvm.memcpy.p0.p1.i64(ptr, ptr addrspace(1), i64, i1)
+declare void @llvm.memcpy.p1.p1.i64(ptr addrspace(1), ptr addrspace(1), i64, i1)