Index: llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -256,6 +256,49 @@
 INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
                     false, false)
 
+// Helpers to inspect and rewrite pointer types and vector-of-pointer types.
+static unsigned getPtrOrVecOfPtrsAddressSpace(Type *Ty) {
+  if (Ty->isVectorTy()) {
+    Ty = cast<VectorType>(Ty)->getElementType();
+  }
+  assert(Ty->isPointerTy());
+  return Ty->getPointerAddressSpace();
+}
+
+static bool isPtrOrVecOfPtrsType(Type *Ty) {
+  if (Ty->isVectorTy()) {
+    Ty = cast<VectorType>(Ty)->getElementType();
+  }
+  return Ty->isPointerTy();
+}
+
+static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) {
+  if (!Ty->isVectorTy()) {
+    assert(Ty->isPointerTy());
+    return PointerType::getWithSamePointeeType(cast<PointerType>(Ty),
+                                               NewAddrSpace);
+  }
+
+  Type *PT = cast<VectorType>(Ty)->getElementType();
+  assert(PT->isPointerTy());
+
+  Type *NPT =
+      PointerType::getWithSamePointeeType(cast<PointerType>(PT), NewAddrSpace);
+  return VectorType::get(NPT, cast<VectorType>(Ty)->getElementCount());
+}
+
+static bool hasSameElementOfPtrOrVecPtrs(Type *Ty1, Type *Ty2) {
+  assert(isPtrOrVecOfPtrsType(Ty1) && isPtrOrVecOfPtrsType(Ty2));
+  assert(Ty1->isVectorTy() == Ty2->isVectorTy());
+  if (Ty1->isVectorTy()) {
+    Ty1 = cast<VectorType>(Ty1)->getElementType();
+    Ty2 = cast<VectorType>(Ty2)->getElementType();
+  }
+
+  assert(Ty1->isPointerTy() && Ty2->isPointerTy());
+  return cast<PointerType>(Ty1)->hasSameElementTypeAs(cast<PointerType>(Ty2));
+}
+
 // Check whether that's no-op pointer bicast using a pair of
 // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over
 // different address spaces.
@@ -279,8 +321,9 @@
   // arithmetic may also be undefined after invalid pointer reinterpret cast.
   // However, as we confirm through the target hooks that it's a no-op
   // addrspacecast, it doesn't matter since the bits should be the same.
-  unsigned P2IOp0AS = P2I->getOperand(0)->getType()->getPointerAddressSpace();
-  unsigned I2PAS = I2P->getType()->getPointerAddressSpace();
+  unsigned P2IOp0AS =
+      getPtrOrVecOfPtrsAddressSpace(P2I->getOperand(0)->getType());
+  unsigned I2PAS = getPtrOrVecOfPtrsAddressSpace(I2P->getType());
   return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()),
                               I2P->getOperand(0)->getType(), I2P->getType(),
                               DL) &&
@@ -301,14 +344,14 @@
   switch (Op->getOpcode()) {
   case Instruction::PHI:
-    assert(Op->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(Op->getType()));
     return true;
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
   case Instruction::GetElementPtr:
     return true;
   case Instruction::Select:
-    return Op->getType()->isPointerTy();
+    return isPtrOrVecOfPtrsType(Op->getType());
   case Instruction::Call: {
     const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
     return II && II->getIntrinsicID() == Intrinsic::ptrmask;
@@ -373,6 +416,26 @@
   case Intrinsic::ptrmask:
     // This is handled as an address expression, not as a use memory operation.
     return false;
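+  // Masked gather/scatter are overloaded on the data type and the pointer
+  // vector type, so the call must be remangled for the new address space.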
+  case Intrinsic::masked_gather: {
+    Type *RetTy = II->getType();
+    Type *NewPtrTy = NewV->getType();
+    Function *NewDecl =
+        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy});
+    II->setArgOperand(0, NewV);
+    II->setCalledFunction(NewDecl);
+    return true;
+  }
+  case Intrinsic::masked_scatter: {
+    Type *ValueTy = II->getOperand(0)->getType();
+    Type *NewPtrTy = NewV->getType();
+    Function *NewDecl =
+        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy});
+    II->setArgOperand(1, NewV);
+    II->setCalledFunction(NewDecl);
+    return true;
+  }
   default: {
     Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
     if (!Rewrite)
@@ -394,6 +455,16 @@
     appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
                                                  PostorderStack, Visited);
     break;
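+  // The pointer vector is operand 0 of a masked gather and operand 1 of a
+  // masked scatter.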
+  case Intrinsic::masked_gather:
+    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+                                                 PostorderStack, Visited);
+    break;
+  case Intrinsic::masked_scatter:
+    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1),
+                                                 PostorderStack, Visited);
+    break;
   default:
     SmallVector<int, 2> OpIndexes;
     if (TTI->collectFlatAddressOperands(OpIndexes, IID)) {
@@ -412,7 +481,7 @@
 void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
     Value *V, PostorderStackTy &PostorderStack,
     DenseSet<Value *> &Visited) const {
-  assert(V->getType()->isPointerTy());
+  assert(isPtrOrVecOfPtrsType(V->getType()));
 
   // Generic addressing expressions may be hidden in nested constant
   // expressions.
@@ -424,7 +493,7 @@
     return;
   }
 
-  if (V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
+  if (getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace &&
       isAddressExpression(*V, *DL, TTI)) {
     if (Visited.insert(V).second) {
       PostorderStack.emplace_back(V, false);
@@ -460,8 +529,7 @@
   // addressing calculations may also be faster.
   for (Instruction &I : instructions(F)) {
     if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
-      if (!GEP->getType()->isVectorTy())
-        PushPtrOperand(GEP->getPointerOperand());
+      PushPtrOperand(GEP->getPointerOperand());
     } else if (auto *LI = dyn_cast<LoadInst>(&I))
       PushPtrOperand(LI->getPointerOperand());
     else if (auto *SI = dyn_cast<StoreInst>(&I))
       PushPtrOperand(SI->getPointerOperand());
@@ -481,13 +549,11 @@
       collectRewritableIntrinsicOperands(II, PostorderStack, Visited);
     else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
-      // FIXME: Handle vectors of pointers
-      if (Cmp->getOperand(0)->getType()->isPointerTy()) {
+      if (isPtrOrVecOfPtrsType(Cmp->getOperand(0)->getType())) {
         PushPtrOperand(Cmp->getOperand(0));
         PushPtrOperand(Cmp->getOperand(1));
       }
     } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {
-      if (!ASC->getType()->isVectorTy())
-        PushPtrOperand(ASC->getPointerOperand());
+      PushPtrOperand(ASC->getPointerOperand());
     } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
       if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
         PushPtrOperand(
@@ -501,7 +568,7 @@
     // If the operands of the expression on the top are already explored,
     // adds that expression to the resultant postorder.
     if (PostorderStack.back().getInt()) {
-      if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace)
+      if (getPtrOrVecOfPtrsAddressSpace(TopVal->getType()) == FlatAddrSpace)
         Postorder.push_back(TopVal);
       PostorderStack.pop_back();
       continue;
@@ -529,8 +596,7 @@
     SmallVectorImpl<const Use *> *UndefUsesToFix) {
   Value *Operand = OperandUse.get();
 
-  Type *NewPtrTy = PointerType::getWithSamePointeeType(
-      cast<PointerType>(Operand->getType()), NewAddrSpace);
+  Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace);
 
   if (Constant *C = dyn_cast<Constant>(Operand))
     return ConstantExpr::getAddrSpaceCast(C, NewPtrTy);
@@ -543,8 +609,7 @@
   if (I != PredicatedAS.end()) {
     // Insert an addrspacecast on that operand before the user.
     unsigned NewAS = I->second;
-    Type *NewPtrTy = PointerType::getWithSamePointeeType(
-        cast<PointerType>(Operand->getType()), NewAS);
+    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS);
     auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy);
     NewI->insertBefore(Inst);
     NewI->setDebugLoc(Inst->getDebugLoc());
@@ -572,15 +637,14 @@
     const ValueToValueMapTy &ValueWithNewAddrSpace,
     const PredicatedAddrSpaceMapTy &PredicatedAS,
     SmallVectorImpl<const Use *> *UndefUsesToFix) const {
-  Type *NewPtrType = PointerType::getWithSamePointeeType(
-      cast<PointerType>(I->getType()), NewAddrSpace);
+  Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace);
 
   if (I->getOpcode() == Instruction::AddrSpaceCast) {
     Value *Src = I->getOperand(0);
     // Because `I` is flat, the source address space must be specific.
     // Therefore, the inferred address space must be the source space, according
     // to our algorithm.
-    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+    assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace);
     if (Src->getType() != NewPtrType)
       return new BitCastInst(Src, NewPtrType);
     return Src;
@@ -607,8 +671,7 @@
   if (AS != UninitializedAddressSpace) {
     // For the assumed address space, insert an `addrspacecast` to make that
     // explicit.
-    Type *NewPtrTy = PointerType::getWithSamePointeeType(
-        cast<PointerType>(I->getType()), AS);
+    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS);
     auto *NewI = new AddrSpaceCastInst(I, NewPtrTy);
     NewI->insertAfter(I);
     return NewI;
@@ -617,7 +680,7 @@
   // Computes the converted pointer operands.
   SmallVector<Value *, 4> NewPointerOperands;
   for (const Use &OperandUse : I->operands()) {
-    if (!OperandUse.get()->getType()->isPointerTy())
+    if (!isPtrOrVecOfPtrsType(OperandUse.get()->getType()))
       NewPointerOperands.push_back(nullptr);
     else
       NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef(
@@ -629,7 +692,7 @@
   case Instruction::BitCast:
     return new BitCastInst(NewPointerOperands[0], NewPtrType);
   case Instruction::PHI: {
-    assert(I->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(I->getType()));
     PHINode *PHI = cast<PHINode>(I);
     PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues());
     for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) {
@@ -648,7 +711,7 @@
     return NewGEP;
   }
   case Instruction::Select:
-    assert(I->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(I->getType()));
     return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
                               NewPointerOperands[2], "", nullptr, I);
   case Instruction::IntToPtr: {
@@ -674,16 +737,16 @@
     ConstantExpr *CE, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
     const TargetTransformInfo *TTI) {
-  Type *TargetType = CE->getType()->isPointerTy()
-                         ? PointerType::getWithSamePointeeType(
-                               cast<PointerType>(CE->getType()), NewAddrSpace)
-                         : CE->getType();
+  Type *TargetType =
+      isPtrOrVecOfPtrsType(CE->getType())
+          ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace)
+          : CE->getType();
 
   if (CE->getOpcode() == Instruction::AddrSpaceCast) {
     // Because CE is flat, the source address space must be specific.
     // Therefore, the inferred address space must be the source space according
     // to our algorithm.
-    assert(CE->getOperand(0)->getType()->getPointerAddressSpace() ==
+    assert(getPtrOrVecOfPtrsAddressSpace(CE->getOperand(0)->getType()) ==
            NewAddrSpace);
     return ConstantExpr::getBitCast(CE->getOperand(0), TargetType);
   }
@@ -697,7 +760,7 @@
   if (CE->getOpcode() == Instruction::IntToPtr) {
     assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
     Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
-    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+    assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace);
     return ConstantExpr::getBitCast(Src, TargetType);
   }
 
@@ -753,7 +816,7 @@
     const PredicatedAddrSpaceMapTy &PredicatedAS,
     SmallVectorImpl<const Use *> *UndefUsesToFix) const {
   // All values in Postorder are flat address expressions.
-  assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
+  assert(getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace &&
          isAddressExpression(*V, *DL, TTI));
 
   if (Instruction *I = dyn_cast<Instruction>(V)) {
@@ -898,12 +961,14 @@
     Value *Src1 = Op.getOperand(2);
 
     auto I = InferredAddrSpace.find(Src0);
-    unsigned Src0AS = (I != InferredAddrSpace.end()) ?
-      I->second : Src0->getType()->getPointerAddressSpace();
+    unsigned Src0AS = (I != InferredAddrSpace.end())
+                          ? I->second
+                          : getPtrOrVecOfPtrsAddressSpace(Src0->getType());
 
     auto J = InferredAddrSpace.find(Src1);
-    unsigned Src1AS = (J != InferredAddrSpace.end()) ?
-      J->second : Src1->getType()->getPointerAddressSpace();
+    unsigned Src1AS = (J != InferredAddrSpace.end())
+                          ? J->second
+                          : getPtrOrVecOfPtrsAddressSpace(Src1->getType());
 
     auto *C0 = dyn_cast<Constant>(Src0);
     auto *C1 = dyn_cast<Constant>(Src1);
@@ -932,7 +997,7 @@
     auto I = InferredAddrSpace.find(PtrOperand);
     unsigned OperandAS;
     if (I == InferredAddrSpace.end()) {
-      OperandAS = PtrOperand->getType()->getPointerAddressSpace();
+      OperandAS = getPtrOrVecOfPtrsAddressSpace(PtrOperand->getType());
       if (OperandAS == FlatAddrSpace) {
         // Check AC for assumption dominating V.
        unsigned AS = getPredicatedAddrSpace(V, PtrOperand);
@@ -1057,7 +1122,7 @@
                                                        unsigned NewAS) const {
   assert(NewAS != UninitializedAddressSpace);
 
-  unsigned SrcAS = C->getType()->getPointerAddressSpace();
+  unsigned SrcAS = getPtrOrVecOfPtrsAddressSpace(C->getType());
   if (SrcAS == NewAS || isa<ConstantPointerNull>(C))
     return true;
 
@@ -1075,7 +1140,7 @@
      return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)),
                                        NewAS);

    if (Op->getOpcode() == Instruction::IntToPtr &&
-        Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
+        getPtrOrVecOfPtrsAddressSpace(Op->getType()) == FlatAddrSpace)
      return true;
  }
@@ -1111,7 +1176,7 @@
    if (NewAddrSpace == UninitializedAddressSpace)
      continue;
 
-    if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
+    if (getPtrOrVecOfPtrsAddressSpace(V->getType()) != NewAddrSpace) {
      Value *New =
          cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace,
                                        PredicatedAS, &UndefUsesToFix);
@@ -1168,7 +1233,7 @@
      I = skipToNextUser(I, E);
 
      if (isSimplePointerUseValidToReplace(
-              *TTI, U, V->getType()->getPointerAddressSpace())) {
+              *TTI, U, getPtrOrVecOfPtrsAddressSpace(V->getType()))) {
        // If V is used as the pointer operand of a compatible memory operation,
        // sets the pointer operand to NewV. This replacement does not change
        // the element type, so the resultant load/store is still valid.
@@ -1199,13 +1264,13 @@
        //   into
        //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
 
-        unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+        unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType());
        int SrcIdx = U.getOperandNo();
        int OtherIdx = (SrcIdx == 0) ? 1 : 0;
        Value *OtherSrc = Cmp->getOperand(OtherIdx);
 
        if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
-          if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+          if (getPtrOrVecOfPtrsAddressSpace(OtherNewV->getType()) == NewAS) {
            Cmp->setOperand(OtherIdx, OtherNewV);
            Cmp->setOperand(SrcIdx, NewV);
            continue;
@@ -1224,11 +1289,10 @@
      }
 
      if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
-        unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-        if (ASC->getDestAddressSpace() == NewAS) {
-          if (!cast<PointerType>(ASC->getType())
-                   ->hasSameElementTypeAs(
-                       cast<PointerType>(NewV->getType()))) {
+        unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType());
+        if (getPtrOrVecOfPtrsAddressSpace(ASC->getType()) == NewAS) {
+          if (!hasSameElementOfPtrOrVecPtrs(ASC->getType(),
+                                            NewV->getType())) {
            BasicBlock::iterator InsertPos;
            if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
              InsertPos = std::next(NewVInst->getIterator());
Index: llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
===================================================================
--- llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
+++ llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
@@ -147,9 +147,8 @@
   ret i1 %cmp
 }
 
-; TODO: Should be handled
 ; CHECK-LABEL: @icmp_flat_flat_from_group_vector(
-; CHECK: %cmp = icmp eq <2 x ptr> %cast0, %cast1
+; CHECK: %cmp = icmp eq <2 x ptr addrspace(3)> %group.ptr.0, %group.ptr.1
 define <2 x i1> @icmp_flat_flat_from_group_vector(<2 x ptr addrspace(3)> %group.ptr.0, <2 x ptr addrspace(3)> %group.ptr.1) #0 {
   %cast0 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.0 to <2 x ptr>
   %cast1 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.1 to <2 x ptr>
Index: llvm/test/Transforms/InferAddressSpaces/masked-gather-scatter.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InferAddressSpaces/masked-gather-scatter.ll
@@ -0,0 +1,29 @@
+; RUN: opt -S -passes=infer-address-spaces -assume-default-is-flat-addrspace %s | FileCheck %s
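+
+; Both intrinsics are overloaded on the data type and on the pointer vector
+; type, so once the address space of %ptrs is inferred, the calls must be
+; remangled from v4p0 to v4p1.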
+
+; CHECK-LABEL: @masked_gather_inferas(
+; CHECK: tail call <4 x i32> @llvm.masked.gather.v4i32.v4p1
+define <4 x i32> @masked_gather_inferas(ptr addrspace(1) %out, <4 x i64> %index) {
+entry:
+  %out.1 = addrspacecast ptr addrspace(1) %out to ptr
+  %ptrs = getelementptr inbounds i32, ptr %out.1, <4 x i64> %index
+  %value = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> poison)
+  ret <4 x i32> %value
+}
+
+; CHECK-LABEL: @masked_scatter_inferas(
+; CHECK: tail call void @llvm.masked.scatter.v4i32.v4p1
+define void @masked_scatter_inferas(ptr addrspace(1) %out, <4 x i64> %index, <4 x i32> %value) {
+entry:
+  %out.1 = addrspacecast ptr addrspace(1) %out to ptr
+  %ptrs = getelementptr inbounds i32, ptr %out.1, <4 x i64> %index
+  tail call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> %value, <4 x ptr> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x i32>)
+
+declare void @llvm.masked.scatter.v4i32.v4p0(<4 x i32>, <4 x ptr>, i32 immarg, <4 x i1>)
\ No newline at end of file
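
For reference, the updated expectation in icmp.ll corresponds to the following
before/after on the vector compare (a sketch of the pass output, using the
names from the test; the addrspacecasts become dead once the compare operands
are rewritten):

    ; before
    %cast0 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.0 to <2 x ptr>
    %cast1 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.1 to <2 x ptr>
    %cmp = icmp eq <2 x ptr> %cast0, %cast1

    ; after infer-address-spaces
    %cmp = icmp eq <2 x ptr addrspace(3)> %group.ptr.0, %group.ptr.1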