Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -334,9 +334,10 @@
   /// Rewrite intrinsic call \p II such that \p OldV will be replaced with \p
   /// NewV, which has a different address space. This should happen for every
   /// operand index that collectFlatAddressOperands returned for the intrinsic.
-  /// \returns true if the intrinsic was handled.
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const;
+  /// \returns nullptr if the intrinsic was not handled. Otherwise, returns the
+  /// new value (which may be the original \p II with modified operands).
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const;

   /// Test whether calls to a function lower to actual program function
   /// calls.
@@ -1215,8 +1216,9 @@
   virtual unsigned getFlatAddressSpace() = 0;
   virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                           Intrinsic::ID IID) const = 0;
-  virtual bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                                Value *NewV) const = 0;
+  virtual Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
+                                                  Value *OldV,
+                                                  Value *NewV) const = 0;
   virtual bool isLoweredToCall(const Function *F) = 0;
   virtual void getUnrollingPreferences(Loop *L, ScalarEvolution &,
                                        UnrollingPreferences &UP) = 0;
@@ -1507,8 +1509,8 @@
     return Impl.collectFlatAddressOperands(OpIndexes, IID);
   }

-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const override {
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const override {
     return Impl.rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
   }

Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -90,9 +90,9 @@
     return false;
   }

-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const {
-    return false;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const {
+    return nullptr;
   }

   bool isLoweredToCall(const Function *F) {
Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -222,9 +222,9 @@
     return false;
   }

-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                        Value *OldV, Value *NewV) const {
-    return false;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const {
+    return nullptr;
   }

   bool isLegalAddImmediate(int64_t imm) {
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -212,9 +212,8 @@
   return TTIImpl->collectFlatAddressOperands(OpIndexes, IID);
 }

-bool TargetTransformInfo::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                                           Value *OldV,
-                                                           Value *NewV) const {
+Value *TargetTransformInfo::rewriteIntrinsicWithAddressSpace(
+    IntrinsicInst *II, Value *OldV, Value *NewV) const {
   return TTIImpl->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
 }
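Note: the switch from bool to Value * lets a target replace an intrinsic call outright instead of only mutating it in place, which the AMDGPU changes below rely on to fold address-space queries to constants. A minimal IR sketch of the kind of rewrite this enables (illustrative only, not part of this patch; assumes the amdgcn.is.shared handling added below):

define i1 @is_shared_of_global_sketch(i8 addrspace(1)* %ptr) {
  ; Once %cast is known to point to global memory, the is.shared query can
  ; only be answered by returning a new constant value, not by editing the
  ; call in place.
  %cast = addrspacecast i8 addrspace(1)* %ptr to i8*
  %result = call i1 @llvm.amdgcn.is.shared(i8* %cast)
  ; expected after infer-address-spaces: ret i1 false
  ret i1 %result
}

declare i1 @llvm.amdgcn.is.shared(i8* nocapture)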
Index: llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -211,8 +211,8 @@
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const;
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                        Value *OldV, Value *NewV) const;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const;

   unsigned getVectorSplitCost() { return 0; }

Index: llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -868,8 +868,9 @@
   }
 }

-bool GCNTTIImpl::rewriteIntrinsicWithAddressSpace(
-    IntrinsicInst *II, Value *OldV, Value *NewV) const {
+Value *GCNTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
+                                                    Value *OldV,
+                                                    Value *NewV) const {
   auto IntrID = II->getIntrinsicID();
   switch (IntrID) {
   case Intrinsic::amdgcn_atomic_inc:
@@ -879,7 +880,7 @@
   case Intrinsic::amdgcn_ds_fmax: {
     const ConstantInt *IsVolatile = cast<ConstantInt>(II->getArgOperand(4));
     if (!IsVolatile->isZero())
-      return false;
+      return nullptr;
     Module *M = II->getParent()->getParent()->getParent();
     Type *DestTy = II->getType();
     Type *SrcTy = NewV->getType();
     Function *NewDecl =
         Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy});
     II->setArgOperand(0, NewV);
     II->setCalledFunction(NewDecl);
-    return true;
+    return II;
   }
   case Intrinsic::amdgcn_is_shared:
   case Intrinsic::amdgcn_is_private: {
@@ -897,12 +898,25 @@
     LLVMContext &Ctx = NewV->getType()->getContext();
     ConstantInt *NewVal = (TrueAS == NewAS) ?
       ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
-    II->replaceAllUsesWith(NewVal);
-    II->eraseFromParent();
-    return true;
+    return NewVal;
+  }
+  case Intrinsic::ptrmask: {
+    unsigned OldAS = OldV->getType()->getPointerAddressSpace();
+    unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+    if (!getTLI()->isNoopAddrSpaceCast(OldAS, NewAS))
+      return nullptr;
+
+    Module *M = II->getParent()->getParent()->getParent();
+    Value *MaskOp = II->getArgOperand(1);
+    Type *MaskTy = MaskOp->getType();
+    Function *NewDecl = Intrinsic::getDeclaration(M, Intrinsic::ptrmask,
+                                                  {NewV->getType(), MaskTy});
+    CallInst *NewCall = CallInst::Create(NewDecl->getFunctionType(), NewDecl,
+                                         {NewV, MaskOp}, "", II);
+    return NewCall;
+  }
   default:
-    return false;
+    return nullptr;
   }
 }
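Note: the isNoopAddrSpaceCast guard is what keeps the ptrmask rewrite sound. llvm.ptrmask masks the raw bit pattern of a pointer, so the rewrite is only valid when casting between the two address spaces does not change those bits. A hedged IR sketch of a case that must be skipped (illustrative only; on AMDGPU a flat pointer is 64 bits while a private addrspace(5) pointer is 32 bits):

define i8 @ptrmask_cast_private_to_flat_sketch(i8 addrspace(5)* %src.ptr, i64 %mask) {
  %cast = addrspacecast i8 addrspace(5)* %src.ptr to i8*
  ; Expected to stay in the flat address space: the 64-bit mask can clear
  ; bits that a 32-bit private pointer does not even have, so no
  ; addrspace(5) ptrmask is formed.
  %masked = call i8* @llvm.ptrmask.p0i8.i64(i8* %cast, i64 %mask)
  %load = load i8, i8* %masked
  ret i8 %load
}

declare i8* @llvm.ptrmask.p0i8.i64(i8*, i64)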
Index: llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -175,6 +175,11 @@

   bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;

+  Value *cloneInstructionWithNewAddressSpace(
+      Instruction *I, unsigned NewAddrSpace,
+      const ValueToValueMapTy &ValueWithNewAddrSpace,
+      SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+
   // Changes the flat address expressions in function F to point to specific
   // address spaces if InferredAddrSpace says so. Postorder is the postorder of
   // all flat expressions in the use-def graph of function F.
@@ -218,20 +223,24 @@
 // TODO: Currently, we consider only phi, bitcast, addrspacecast, and
 // getelementptr operators.
 static bool isAddressExpression(const Value &V) {
-  if (!isa<Operator>(V))
+  const Operator *Op = dyn_cast<Operator>(&V);
+  if (!Op)
     return false;

-  const Operator &Op = cast<Operator>(V);
-  switch (Op.getOpcode()) {
+  switch (Op->getOpcode()) {
   case Instruction::PHI:
-    assert(Op.getType()->isPointerTy());
+    assert(Op->getType()->isPointerTy());
     return true;
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
   case Instruction::GetElementPtr:
     return true;
   case Instruction::Select:
-    return Op.getType()->isPointerTy();
+    return Op->getType()->isPointerTy();
+  case Instruction::Call: {
+    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
+    return II && II->getIntrinsicID() == Intrinsic::ptrmask;
+  }
   default:
     return false;
   }
@@ -254,12 +263,17 @@
     return {Op.getOperand(0)};
   case Instruction::Select:
     return {Op.getOperand(1), Op.getOperand(2)};
+  case Instruction::Call: {
+    const IntrinsicInst &II = cast<IntrinsicInst>(Op);
+    assert(II.getIntrinsicID() == Intrinsic::ptrmask &&
+           "unexpected intrinsic call");
+    return {II.getArgOperand(0)};
+  }
   default:
     llvm_unreachable("Unexpected instruction type.");
   }
 }

-// TODO: Move logic to TTI?
 bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
                                                   Value *OldV,
                                                   Value *NewV) const {
@@ -275,8 +289,17 @@
     II->setCalledFunction(NewDecl);
     return true;
   }
-  default:
-    return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+  case Intrinsic::ptrmask:
+    // This is handled as an address expression, not as a use memory operation.
+    return false;
+  default: {
+    Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+    if (!Rewrite)
+      return false;
+    if (Rewrite != II)
+      II->replaceAllUsesWith(Rewrite);
+    return true;
+  }
   }
 }

@@ -285,6 +308,7 @@
     DenseSet<Value *> &Visited) const {
   auto IID = II->getIntrinsicID();
   switch (IID) {
+  case Intrinsic::ptrmask:
   case Intrinsic::objectsize:
     appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
                                                  PostorderStack, Visited);
@@ -438,10 +462,13 @@
 // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
 // from a pointer whose type already matches. Therefore, this function returns a
 // Value* instead of an Instruction*.
-static Value *cloneInstructionWithNewAddressSpace(
+//
+// This may also return nullptr in the case the instruction could not be
+// rewritten.
+Value *InferAddressSpaces::cloneInstructionWithNewAddressSpace(
     Instruction *I, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace,
-    SmallVectorImpl<const Use *> *UndefUsesToFix) {
+    SmallVectorImpl<const Use *> *UndefUsesToFix) const {
   Type *NewPtrType =
       I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);

@@ -456,6 +483,23 @@
     return Src;
   }

+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    // Technically the intrinsic ID is a pointer typed argument, so specially
+    // handle calls early.
+    assert(II->getIntrinsicID() == Intrinsic::ptrmask);
+    Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef(
+        II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
+        UndefUsesToFix);
+    Value *Rewrite =
+        TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
+    if (Rewrite) {
+      assert(Rewrite != II && "cannot modify this pointer operation in place");
+      return Rewrite;
+    }
+
+    return nullptr;
+  }
+
   // Computes the converted pointer operands.
   SmallVector<Value *, 4> NewPointerOperands;
   for (const Use &OperandUse : I->operands()) {
@@ -591,7 +635,7 @@
   if (Instruction *I = dyn_cast<Instruction>(V)) {
     Value *NewV = cloneInstructionWithNewAddressSpace(
         I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
-    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+    if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {
       if (NewI->getParent() == nullptr) {
         NewI->insertBefore(I);
         NewI->takeName(I);
@@ -879,8 +923,10 @@
   for (Value* V : Postorder) {
     unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
     if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
-      ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace(
-          V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+      Value *New = cloneValueWithNewAddressSpace(
+          V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+      if (New)
+        ValueWithNewAddrSpace[V] = New;
     }
   }

@@ -890,7 +936,10 @@
   // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.
   for (const Use *UndefUse : UndefUsesToFix) {
     User *V = UndefUse->getUser();
-    User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V));
+    User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V));
+    if (!NewV)
+      continue;
+
     unsigned OperandNo = UndefUse->getOperandNo();
     assert(isa<UndefValue>(NewV->getOperand(OperandNo)));
     NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get()));
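Note: because isAddressExpression/getPointerOperands now treat llvm.ptrmask as an address expression, inference also flows through its result to downstream users rather than stopping at the call. A hedged sketch of that flow (illustrative only, not one of the tests below; assuming the AMDGPU rules above, both the call and the GEP should end up in addrspace(1)):

define i8 @ptrmask_feeds_gep_sketch(i8 addrspace(1)* %src.ptr, i64 %mask) {
  %cast = addrspacecast i8 addrspace(1)* %src.ptr to i8*
  %masked = call i8* @llvm.ptrmask.p0i8.i64(i8* %cast, i64 %mask)
  ; The masked result is itself a pointer expression, so the GEP and load
  ; below are expected to be inferred through it as well.
  %gep = getelementptr i8, i8* %masked, i64 8
  %load = load i8, i8* %gep
  ret i8 %load
}

declare i8* @llvm.ptrmask.p0i8.i64(i8*, i64)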
Index: llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
===================================================================
--- llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -42,9 +42,8 @@

 define i8 @ptrmask_cast_global_to_flat(i8 addrspace(1)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_global_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(1)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* [[SRC_PTR:%.*]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast i8 addrspace(1)* %src.ptr to i8*
@@ -55,9 +54,8 @@

 define i8 @ptrmask_cast_999_to_flat(i8 addrspace(999)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_999_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(999)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(999)* @llvm.ptrmask.p999i8.i64(i8 addrspace(999)* [[SRC_PTR:%.*]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(999)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast i8 addrspace(999)* %src.ptr to i8*
@@ -121,8 +119,8 @@

 define i8 @ptrmask_cast_global_to_flat_global(i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_global_to_flat_global(
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* addrspacecast (i8 addrspace(1)* @gv to i8*), i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* @gv, i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %masked = call i8* @llvm.ptrmask.p0i8.i64(i8* addrspacecast (i8 addrspace(1)* @gv to i8*), i64 %mask)
@@ -132,10 +130,9 @@

 define i8 @multi_ptrmask_cast_global_to_flat(i8 addrspace(1)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @multi_ptrmask_cast_global_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(1)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, i8 addrspace(1)* [[SRC_PTR]], align 1
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, i8 addrspace(1)* [[SRC_PTR:%.*]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* [[SRC_PTR]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    [[ADD:%.*]] = add i8 [[LOAD0]], [[LOAD1]]
 ; CHECK-NEXT:    ret i8 [[ADD]]
 ;