diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -412,6 +412,10 @@
   // even taking non-uniform arguments
   bool isAlwaysUniform(const Value *V) const;
 
+  /// Query the target whether the specified address space cast from FromAS to
+  /// ToAS is valid.
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
+
   /// Returns the address space ID for a target's 'flat' address space. Note
   /// this is not necessarily the same as addrspace(0), which LLVM sometimes
   /// refers to as the generic address space. The flat address space is a
@@ -1680,6 +1684,7 @@
   virtual bool useGPUDivergenceAnalysis() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isAlwaysUniform(const Value *V) = 0;
+  virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const = 0;
   virtual unsigned getFlatAddressSpace() = 0;
   virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                           Intrinsic::ID IID) const = 0;
@@ -2061,6 +2066,10 @@
     return Impl.isAlwaysUniform(V);
   }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const override {
+    return Impl.isValidAddrSpaceCast(FromAS, ToAS);
+  }
+
   unsigned getFlatAddressSpace() override { return Impl.getFlatAddressSpace(); }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -95,6 +95,10 @@
   bool isAlwaysUniform(const Value *V) const { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return FromAS == ToAS;
+  }
+
   unsigned getFlatAddressSpace() const { return -1; }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@
   bool isAlwaysUniform(const Value *V) { return false; }
 
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+    return FromAS == ToAS;
+  }
+
   unsigned getFlatAddressSpace() {
     // Return an invalid address space.
     return -1;
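Note on the hunks above: every default implementation is deliberately conservative and only accepts a cast that stays within one address space (FromAS == ToAS). A target on which several address spaces alias the same underlying memory can override the hook. The free-standing function below is a minimal sketch, not part of this patch; the "GPU-like" numbering (0 = flat, 1 = global) is an assumption for illustration only:

  // Illustrative only (hypothetical target): addrspace(0) ("flat") and
  // addrspace(1) ("global") address the same memory, so casts between
  // them are always valid; everything else keeps the conservative default.
  static bool isValidAddrSpaceCastGPULike(unsigned FromAS, unsigned ToAS) {
    auto IsFlatOrGlobal = [](unsigned AS) { return AS == 0 || AS == 1; };
    return FromAS == ToAS || (IsFlatOrGlobal(FromAS) && IsFlatOrGlobal(ToAS));
  }

In-tree, a target would express this as a const member function of its TargetTransformInfo implementation with the same signature as the hook declared above.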
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -529,6 +529,8 @@
   SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                              unsigned Depth = 0,
                              bool AllowMultipleUsers = false) = 0;
+
+  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -266,6 +266,11 @@
   return TTIImpl->isAlwaysUniform(V);
 }
 
+bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS,
+                                                     unsigned ToAS) const {
+  return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 unsigned TargetTransformInfo::getFlatAddressSpace() const {
   return TTIImpl->getFlatAddressSpace();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -262,17 +262,25 @@
   PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
       : IC(IC), Root(Root) {}
 
-  bool collectUsers();
+  bool collectUsers(unsigned SrcAS);
   void replacePointer(Value *V);
 
 private:
-  bool collectUsersRecursive(Instruction &I);
+  bool collectUsersRecursive(unsigned SrcAS, Instruction &I);
   void replace(Instruction *I);
   Value *getReplacement(Value *I);
   bool isAvailable(Instruction *I) const {
     return I == &Root || Worklist.contains(I);
   }
 
+  bool isValidAddrSpaceCast(const Instruction *I, unsigned FromAS) const {
+    const auto *ASC = dyn_cast<AddrSpaceCastInst>(I);
+    if (!ASC)
+      return false;
+    unsigned ToAS = ASC->getDestAddressSpace();
+    return IC.isValidAddrSpaceCast(FromAS, ToAS);
+  }
+
   SmallPtrSet<Instruction *, 4> ValuesToRevisit;
   SmallSetVector<Instruction *, 4> Worklist;
   MapVector<Value *, Value *> WorkMap;
@@ -281,8 +289,8 @@
 };
 } // end anonymous namespace
 
-bool PointerReplacer::collectUsers() {
-  if (!collectUsersRecursive(Root))
+bool PointerReplacer::collectUsers(unsigned SrcAS) {
+  if (!collectUsersRecursive(SrcAS, Root))
     return false;
 
   // Ensure that all outstanding (indirect) users of I
@@ -294,7 +302,7 @@
   return true;
 }
 
-bool PointerReplacer::collectUsersRecursive(Instruction &I) {
+bool PointerReplacer::collectUsersRecursive(unsigned SrcAS, Instruction &I) {
   for (auto *U : I.users()) {
     auto *Inst = cast<Instruction>(&*U);
     if (auto *Load = dyn_cast<LoadInst>(Inst)) {
@@ -318,7 +326,7 @@
       }
 
       Worklist.insert(PHI);
-      if (!collectUsersRecursive(*PHI))
+      if (!collectUsersRecursive(SrcAS, *PHI))
         return false;
     } else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
       if (!isa<Instruction>(SI->getTrueValue()) ||
@@ -331,16 +339,18 @@
         continue;
       }
       Worklist.insert(SI);
-      if (!collectUsersRecursive(*SI))
+      if (!collectUsersRecursive(SrcAS, *SI))
        return false;
     } else if (isa<GetElementPtrInst>(Inst)) {
       Worklist.insert(Inst);
-      if (!collectUsersRecursive(*Inst))
+      if (!collectUsersRecursive(SrcAS, *Inst))
         return false;
     } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) {
       if (MI->isVolatile())
         return false;
       Worklist.insert(Inst);
+    } else if (isValidAddrSpaceCast(Inst, SrcAS)) {
+      Worklist.insert(Inst);
     } else if (Inst->isLifetimeStartOrEnd()) {
       continue;
     } else {
@@ -426,6 +436,21 @@
     IC.eraseInstFromFunction(*MemCpy);
 
     WorkMap[MemCpy] = NewI;
+  } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+    auto *V = getReplacement(ASC->getPointerOperand());
+    assert(V && "Operand not replaced");
+    assert(isValidAddrSpaceCast(ASC, V->getType()->getPointerAddressSpace()) &&
+           "Invalid address space cast!");
+    auto *NewV = V;
+    if (V->getType()->getPointerAddressSpace() !=
+        ASC->getType()->getPointerAddressSpace()) {
+      auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
+      NewI->takeName(ASC);
+      IC.InsertNewInstWith(NewI, *ASC);
+      NewV = NewI;
+    }
+    IC.replaceInstUsesWith(*ASC, NewV);
+    IC.eraseInstFromFunction(*ASC);
   } else {
     llvm_unreachable("should never reach here");
   }
@@ -519,7 +544,8 @@
   }
 
   PointerReplacer PtrReplacer(*this, AI);
-  if (PtrReplacer.collectUsers()) {
+  if (PtrReplacer.collectUsers(
+          TheSrc->getType()->getPointerAddressSpace())) {
     for (Instruction *Delete : ToDelete)
       eraseInstFromFunction(*Delete);
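In IR terms, the change above lets PointerReplacer walk through an addrspacecast whenever the target reports the cast as valid. A sketch of the net effect, reconstructed from the first test added below (@g1 and @readonly_callee come from that test file):

  ; Before: the addrspace(1) alloca is a readonly copy of @g1,
  ; reachable only through a flat pointer.
    %alloca = alloca [32 x i8], addrspace(1)
    call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 32, i1 false)
    %p = addrspacecast ptr addrspace(1) %alloca to ptr
    %v = call i8 @readonly_callee(ptr %p)

  ; After (per the CHECK lines below): alloca, memcpy, and cast all fold away.
    %v = call i8 @readonly_callee(ptr nonnull @g1)

Note the guard in replace(): a new AddrSpaceCastInst is created only when the replacement pointer's address space differs from the cast's destination; here @g1 already lives in the destination address space, so no cast is emitted at all.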
replaced"); + assert(isValidAddrSpaceCast(ASC, V->getType()->getPointerAddressSpace()) && + "Invalid address space cast!"); + auto *NewV = V; + if (V->getType()->getPointerAddressSpace() != + ASC->getType()->getPointerAddressSpace()) { + auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); + NewI->takeName(ASC); + IC.InsertNewInstWith(NewI, *ASC); + NewV = NewI; + } + IC.replaceInstUsesWith(*ASC, NewV); + IC.eraseInstFromFunction(*ASC); } else { llvm_unreachable("should never reach here"); } @@ -519,7 +544,8 @@ } PointerReplacer PtrReplacer(*this, AI); - if (PtrReplacer.collectUsers()) { + if (PtrReplacer.collectUsers( + TheSrc->getType()->getPointerAddressSpace())) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -196,6 +196,10 @@ return std::nullopt; } +bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const { + return TTI.isValidAddrSpaceCast(FromAS, ToAS); +} + Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { return llvm::emitGEPOffset(&Builder, DL, GEP); } diff --git a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll --- a/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll +++ b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll @@ -428,6 +428,40 @@ ret i8 %load } +declare i8 @readonly_callee(ptr readonly nocapture) + +define i8 @call_readonly_remove_alloca() { +; CHECK-LABEL: @call_readonly_remove_alloca( +; CHECK-NEXT: [[V:%.*]] = call i8 @readonly_callee(ptr nonnull @g1) +; CHECK-NEXT: ret i8 [[V]] +; + %alloca = alloca [32 x i8], addrspace(1) + call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 32, i1 false) + %p = addrspacecast ptr addrspace(1) %alloca to ptr + %v = call i8 @readonly_callee(ptr %p) + ret i8 %v +} + +define i8 @call_readonly_keep_alloca2() { +; CHECK-LABEL: @call_readonly_keep_alloca2( +; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 1, addrspace(1) +; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(16) @g1, i64 16, i1 false) +; CHECK-NEXT: [[A1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 16 +; CHECK-NEXT: call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) noundef align 1 dereferenceable(16) [[A1]], ptr addrspace(1) noundef align 16 dereferenceable(16) @g2, i64 16, i1 false) +; CHECK-NEXT: [[P:%.*]] = addrspacecast ptr addrspace(1) [[ALLOCA]] to ptr +; CHECK-NEXT: [[V:%.*]] = call i8 @readonly_callee(ptr [[P]]) +; CHECK-NEXT: ret i8 [[V]] +; + %alloca = alloca [32 x i8], addrspace(1) + call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 16, i1 false) + %a1 = getelementptr inbounds [32 x i8], ptr addrspace(1) %alloca, i32 0, i32 16 + call void @llvm.memcpy.p1.p1.i64(ptr addrspace(1) %a1, ptr addrspace(1) @g2, i64 16, i1 false) + %p = addrspacecast ptr addrspace(1) %alloca to ptr + %v = call i8 @readonly_callee(ptr %p) + ret i8 %v +} + declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1), ptr, i64, i1) declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1) declare void @llvm.memcpy.p0.p1.i64(ptr, ptr addrspace(1), i64, i1) +declare void @llvm.memcpy.p1.p1.i64(ptr addrspace(1), ptr addrspace(1), i64, i1)