diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -256,6 +256,48 @@ INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", false, false) +/* Returns the address space of Ty, or of its element type when Ty is a vector; asserts that the (element) type is a pointer. */ static unsigned getPtrOrVecOfPtrsAddressSpace(Type *Ty) { + if (Ty->isVectorTy()) { + Ty = cast(Ty)->getElementType(); + } + assert(Ty->isPointerTy()); + return Ty->getPointerAddressSpace(); +} + +/* Returns true when Ty is a pointer, or a vector whose element type is a pointer. */ static bool isPtrOrVecOfPtrsType(Type *Ty) { + if (Ty->isVectorTy()) { + Ty = cast(Ty)->getElementType(); + } + return Ty->isPointerTy(); +} + +/* Rebuilds Ty in NewAddrSpace: a pointer is recreated with PointerType::getWithSamePointeeType; a vector of pointers becomes a vector with the same element count whose element is the recreated pointer. Asserts that the (element) type is a pointer. */ static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) { + if (!Ty->isVectorTy()) { + assert(Ty->isPointerTy()); + return PointerType::getWithSamePointeeType(cast(Ty), + NewAddrSpace); + } + + Type *PT = cast(Ty)->getElementType(); + assert(PT->isPointerTy()); + + Type *NPT = + PointerType::getWithSamePointeeType(cast(PT), NewAddrSpace); + return VectorType::get(NPT, cast(Ty)->getElementCount()); +} + +/* Returns the result of hasSameElementTypeAs on the pointer types of Ty1 and Ty2, using their element types when they are vectors. Asserts that both are pointers or vectors of pointers and that they agree on being vectors. */ static bool hasSameElementOfPtrOrVecPtrs(Type *Ty1, Type *Ty2) { + assert(isPtrOrVecOfPtrsType(Ty1) && isPtrOrVecOfPtrsType(Ty2)); + assert(Ty1->isVectorTy() == Ty2->isVectorTy()); + if (Ty1->isVectorTy()) { + Ty1 = cast(Ty1)->getElementType(); + Ty2 = cast(Ty2)->getElementType(); + } + + assert(Ty1->isPointerTy() && Ty2->isPointerTy()); + return cast(Ty1)->hasSameElementTypeAs(cast(Ty2)); +} + // Check whether that's no-op pointer bicast using a pair of // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over // different address spaces. @@ -279,8 +321,9 @@ // arithmetic may also be undefined after invalid pointer reinterpret cast. // However, as we confirm through the target hooks that it's a no-op // addrspacecast, it doesn't matter since the bits should be the same. 
- unsigned P2IOp0AS = P2I->getOperand(0)->getType()->getPointerAddressSpace(); - unsigned I2PAS = I2P->getType()->getPointerAddressSpace(); + unsigned P2IOp0AS = + getPtrOrVecOfPtrsAddressSpace(P2I->getOperand(0)->getType()); + unsigned I2PAS = getPtrOrVecOfPtrsAddressSpace(I2P->getType()); return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()), I2P->getOperand(0)->getType(), I2P->getType(), DL) && @@ -301,14 +344,14 @@ switch (Op->getOpcode()) { case Instruction::PHI: - assert(Op->getType()->isPointerTy()); + assert(isPtrOrVecOfPtrsType(Op->getType())); return true; case Instruction::BitCast: case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return true; case Instruction::Select: - return Op->getType()->isPointerTy(); + return isPtrOrVecOfPtrsType(Op->getType()); case Instruction::Call: { const IntrinsicInst *II = dyn_cast(&V); return II && II->getIntrinsicID() == Intrinsic::ptrmask; @@ -412,7 +455,7 @@ void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack( Value *V, PostorderStackTy &PostorderStack, DenseSet &Visited) const { - assert(V->getType()->isPointerTy()); + assert(isPtrOrVecOfPtrsType(V->getType())); // Generic addressing expressions may be hidden in nested constant // expressions. @@ -424,7 +467,7 @@ return; } - if (V->getType()->getPointerAddressSpace() == FlatAddrSpace && + if (getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace && isAddressExpression(*V, *DL, TTI)) { if (Visited.insert(V).second) { PostorderStack.emplace_back(V, false); @@ -460,8 +503,7 @@ // addressing calculations may also be faster. 
for (Instruction &I : instructions(F)) { if (auto *GEP = dyn_cast(&I)) { - if (!GEP->getType()->isVectorTy()) - PushPtrOperand(GEP->getPointerOperand()); + PushPtrOperand(GEP->getPointerOperand()); } else if (auto *LI = dyn_cast(&I)) PushPtrOperand(LI->getPointerOperand()); else if (auto *SI = dyn_cast(&I)) @@ -481,13 +523,12 @@ collectRewritableIntrinsicOperands(II, PostorderStack, Visited); else if (ICmpInst *Cmp = dyn_cast(&I)) { // FIXME: Handle vectors of pointers - if (Cmp->getOperand(0)->getType()->isPointerTy()) { + if (isPtrOrVecOfPtrsType(Cmp->getOperand(0)->getType())) { PushPtrOperand(Cmp->getOperand(0)); PushPtrOperand(Cmp->getOperand(1)); } } else if (auto *ASC = dyn_cast(&I)) { - if (!ASC->getType()->isVectorTy()) - PushPtrOperand(ASC->getPointerOperand()); + PushPtrOperand(ASC->getPointerOperand()); } else if (auto *I2P = dyn_cast(&I)) { if (isNoopPtrIntCastPair(cast(I2P), *DL, TTI)) PushPtrOperand( @@ -501,7 +542,7 @@ // If the operands of the expression on the top are already explored, // adds that expression to the resultant postorder. if (PostorderStack.back().getInt()) { - if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace) + if (getPtrOrVecOfPtrsAddressSpace(TopVal->getType()) == FlatAddrSpace) Postorder.push_back(TopVal); PostorderStack.pop_back(); continue; @@ -529,8 +570,7 @@ SmallVectorImpl *UndefUsesToFix) { Value *Operand = OperandUse.get(); - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast(Operand->getType()), NewAddrSpace); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace); if (Constant *C = dyn_cast(Operand)) return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); @@ -543,8 +583,7 @@ if (I != PredicatedAS.end()) { // Insert an addrspacecast on that operand before the user. 
unsigned NewAS = I->second; - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast(Operand->getType()), NewAS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS); auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); NewI->insertBefore(Inst); NewI->setDebugLoc(Inst->getDebugLoc()); @@ -572,15 +611,14 @@ const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl *UndefUsesToFix) const { - Type *NewPtrType = PointerType::getWithSamePointeeType( - cast(I->getType()), NewAddrSpace); + Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace); if (I->getOpcode() == Instruction::AddrSpaceCast) { Value *Src = I->getOperand(0); // Because `I` is flat, the source address space must be specific. // Therefore, the inferred address space must be the source space, according // to our algorithm. - assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace); if (Src->getType() != NewPtrType) return new BitCastInst(Src, NewPtrType); return Src; @@ -607,8 +645,7 @@ if (AS != UninitializedAddressSpace) { // For the assumed address space, insert an `addrspacecast` to make that // explicit. - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast(I->getType()), AS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS); auto *NewI = new AddrSpaceCastInst(I, NewPtrTy); NewI->insertAfter(I); return NewI; @@ -617,7 +654,7 @@ // Computes the converted pointer operands. 
SmallVector NewPointerOperands; for (const Use &OperandUse : I->operands()) { - if (!OperandUse.get()->getType()->isPointerTy()) + if (!isPtrOrVecOfPtrsType(OperandUse.get()->getType())) NewPointerOperands.push_back(nullptr); else NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( @@ -629,7 +666,7 @@ case Instruction::BitCast: return new BitCastInst(NewPointerOperands[0], NewPtrType); case Instruction::PHI: { - assert(I->getType()->isPointerTy()); + assert(isPtrOrVecOfPtrsType(I->getType())); PHINode *PHI = cast(I); PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { @@ -648,7 +685,7 @@ return NewGEP; } case Instruction::Select: - assert(I->getType()->isPointerTy()); + assert(isPtrOrVecOfPtrsType(I->getType())); return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], NewPointerOperands[2], "", nullptr, I); case Instruction::IntToPtr: { @@ -674,16 +711,16 @@ ConstantExpr *CE, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL, const TargetTransformInfo *TTI) { - Type *TargetType = CE->getType()->isPointerTy() - ? PointerType::getWithSamePointeeType( - cast(CE->getType()), NewAddrSpace) - : CE->getType(); + Type *TargetType = + isPtrOrVecOfPtrsType(CE->getType()) + ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace) + : CE->getType(); if (CE->getOpcode() == Instruction::AddrSpaceCast) { // Because CE is flat, the source address space must be specific. // Therefore, the inferred address space must be the source space according // to our algorithm. 
- assert(CE->getOperand(0)->getType()->getPointerAddressSpace() == + assert(getPtrOrVecOfPtrsAddressSpace(CE->getOperand(0)->getType()) == NewAddrSpace); return ConstantExpr::getBitCast(CE->getOperand(0), TargetType); } @@ -697,8 +734,8 @@ if (CE->getOpcode() == Instruction::Select) { Constant *Src0 = CE->getOperand(1); Constant *Src1 = CE->getOperand(2); - if (Src0->getType()->getPointerAddressSpace() == - Src1->getType()->getPointerAddressSpace()) { + if (getPtrOrVecOfPtrsAddressSpace(Src0->getType()) == + getPtrOrVecOfPtrsAddressSpace(Src1->getType())) { return ConstantExpr::getSelect( CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType), @@ -709,7 +746,7 @@ if (CE->getOpcode() == Instruction::IntToPtr) { assert(isNoopPtrIntCastPair(cast(CE), *DL, TTI)); Constant *Src = cast(CE->getOperand(0))->getOperand(0); - assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace); return ConstantExpr::getBitCast(Src, TargetType); } @@ -765,7 +802,7 @@ const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl *UndefUsesToFix) const { // All values in Postorder are flat address expressions. - assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace && + assert(getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace && isAddressExpression(*V, *DL, TTI)); if (Instruction *I = dyn_cast(V)) { @@ -910,12 +947,14 @@ Value *Src1 = Op.getOperand(2); auto I = InferredAddrSpace.find(Src0); - unsigned Src0AS = (I != InferredAddrSpace.end()) ? - I->second : Src0->getType()->getPointerAddressSpace(); + unsigned Src0AS = (I != InferredAddrSpace.end()) + ? I->second + : getPtrOrVecOfPtrsAddressSpace(Src0->getType()); auto J = InferredAddrSpace.find(Src1); - unsigned Src1AS = (J != InferredAddrSpace.end()) ? - J->second : Src1->getType()->getPointerAddressSpace(); + unsigned Src1AS = (J != InferredAddrSpace.end()) + ? 
J->second + : getPtrOrVecOfPtrsAddressSpace(Src1->getType()); auto *C0 = dyn_cast(Src0); auto *C1 = dyn_cast(Src1); @@ -944,7 +983,7 @@ auto I = InferredAddrSpace.find(PtrOperand); unsigned OperandAS; if (I == InferredAddrSpace.end()) { - OperandAS = PtrOperand->getType()->getPointerAddressSpace(); + OperandAS = getPtrOrVecOfPtrsAddressSpace(PtrOperand->getType()); if (OperandAS == FlatAddrSpace) { // Check AC for assumption dominating V. unsigned AS = getPredicatedAddrSpace(V, PtrOperand); @@ -1069,7 +1108,7 @@ unsigned NewAS) const { assert(NewAS != UninitializedAddressSpace); - unsigned SrcAS = C->getType()->getPointerAddressSpace(); + unsigned SrcAS = getPtrOrVecOfPtrsAddressSpace(C->getType()); if (SrcAS == NewAS || isa(C)) return true; @@ -1087,7 +1126,7 @@ return isSafeToCastConstAddrSpace(cast(Op->getOperand(0)), NewAS); if (Op->getOpcode() == Instruction::IntToPtr && - Op->getType()->getPointerAddressSpace() == FlatAddrSpace) + getPtrOrVecOfPtrsAddressSpace(Op->getType()) == FlatAddrSpace) return true; } @@ -1123,7 +1162,7 @@ if (NewAddrSpace == UninitializedAddressSpace) continue; - if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { + if (getPtrOrVecOfPtrsAddressSpace(V->getType()) != NewAddrSpace) { Value *New = cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, &UndefUsesToFix); @@ -1180,7 +1219,7 @@ I = skipToNextUser(I, E); if (isSimplePointerUseValidToReplace( - *TTI, U, V->getType()->getPointerAddressSpace())) { + *TTI, U, getPtrOrVecOfPtrsAddressSpace(V->getType()))) { // If V is used as the pointer operand of a compatible memory operation, // sets the pointer operand to NewV. This replacement does not change // the element type, so the resultant load/store is still valid. 
@@ -1211,13 +1250,13 @@ // into // %cmp = icmp eq float addrspace(3)* %new_p, %new_q - unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType()); int SrcIdx = U.getOperandNo(); int OtherIdx = (SrcIdx == 0) ? 1 : 0; Value *OtherSrc = Cmp->getOperand(OtherIdx); if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) { - if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) { + if (getPtrOrVecOfPtrsAddressSpace(OtherNewV->getType()) == NewAS) { Cmp->setOperand(OtherIdx, OtherNewV); Cmp->setOperand(SrcIdx, NewV); continue; @@ -1236,11 +1275,10 @@ } if (AddrSpaceCastInst *ASC = dyn_cast(CurUser)) { - unsigned NewAS = NewV->getType()->getPointerAddressSpace(); - if (ASC->getDestAddressSpace() == NewAS) { - if (!cast(ASC->getType()) - ->hasSameElementTypeAs( - cast(NewV->getType()))) { + unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType()); + if (getPtrOrVecOfPtrsAddressSpace(ASC->getType()) == NewAS) { + if (!hasSameElementOfPtrOrVecPtrs(ASC->getType(), + NewV->getType())) { BasicBlock::iterator InsertPos; if (Instruction *NewVInst = dyn_cast(NewV)) InsertPos = std::next(NewVInst->getIterator());