Use llvm::isa/cast/dyn_cast for MemAccInst.

Specialize llvm::simplify_type for polly::MemAccInst so that LLVM's casting
templates (isa<>, cast<>, dyn_cast<>, dyn_cast_or_null<>) accept a MemAccInst
directly, and move the now-redundant is*/as* helpers behind `private:`.

While converting ScopDetection, map Context.hasStores to StoreInst and
Context.hasLoads to LoadInst; the replaced lines assigned isLoad() to
hasStores and isStore() to hasLoads.

NOTE(review): this patch was reconstructed from a flattened copy whose line
breaks and template arguments had been stripped. The restored template
arguments and blank context lines are best-effort; check them against the
original revision before applying.

Index: include/polly/Support/ScopHelper.h
===================================================================
--- include/polly/Support/ScopHelper.h
+++ include/polly/Support/ScopHelper.h
@@ -205,6 +205,10 @@
 
   bool isNull() const { return !I; }
   bool isInstruction() const { return I; }
+
+  llvm::Instruction *asInstruction() const { return I; }
+
+private:
   bool isLoad() const { return I && llvm::isa<llvm::LoadInst>(I); }
   bool isStore() const { return I && llvm::isa<llvm::StoreInst>(I); }
   bool isCallInst() const { return I && llvm::isa<llvm::CallInst>(I); }
@@ -214,7 +218,6 @@
     return I && llvm::isa<llvm::MemTransferInst>(I);
   }
 
-  llvm::Instruction *asInstruction() const { return I; }
   llvm::LoadInst *asLoad() const { return llvm::cast<llvm::LoadInst>(I); }
   llvm::StoreInst *asStore() const { return llvm::cast<llvm::StoreInst>(I); }
   llvm::CallInst *asCallInst() const { return llvm::cast<llvm::CallInst>(I); }
@@ -228,5 +231,22 @@
     return llvm::cast<llvm::MemTransferInst>(I);
   }
 };
+} // namespace polly
+
+namespace llvm {
+/// @brief Specialize simplify_type for MemAccInst to enable dyn_cast and cast
+/// from a MemAccInst object.
+///
+/// The simplified value is the stored instruction pointer as returned by
+/// asInstruction(); it is null for an empty MemAccInst (see isNull()).
+template <> struct simplify_type<polly::MemAccInst> {
+  typedef Instruction *SimpleType;
+  static SimpleType getSimplifiedValue(polly::MemAccInst &I) {
+    return I.asInstruction();
+  }
+};
+} // namespace llvm
+
+namespace polly {
 
 /// @brief Check if the PHINode has any incoming Invoke edge.
Index: lib/Analysis/ScopDetection.cpp
===================================================================
--- lib/Analysis/ScopDetection.cpp
+++ lib/Analysis/ScopDetection.cpp
@@ -980,8 +980,8 @@
 
   // Check the access function.
   if (auto MemInst = MemAccInst::dyn_cast(Inst)) {
-    Context.hasStores |= MemInst.isLoad();
-    Context.hasLoads |= MemInst.isStore();
+    Context.hasStores |= isa<StoreInst>(MemInst);
+    Context.hasLoads |= isa<LoadInst>(MemInst);
     if (!MemInst.isSimple())
       return invalid<ReportNonSimpleMemoryAccess>(Context, /*Assert=*/true,
                                                   &Inst);
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -616,9 +616,7 @@
 }
 
 void MemoryAccess::buildMemIntrinsicAccessRelation() {
-  auto MAI = MemAccInst(getAccessInstruction());
-  (void)MAI;
-  assert(MAI.isMemIntrinsic());
+  assert(isa<MemIntrinsic>(MemAccInst(getAccessInstruction())));
   assert(Subscripts.size() == 2 && Sizes.size() == 0);
 
   auto *SubscriptPWA = Statement->getPwAff(Subscripts[0]);
@@ -646,7 +644,7 @@
   ScalarEvolution *SE = Statement->getParent()->getSE();
 
   auto MAI = MemAccInst(getAccessInstruction());
-  if (MAI.isMemIntrinsic())
+  if (isa<MemIntrinsic>(MAI))
     return;
 
   Value *Ptr = MAI.getPointerOperand();
@@ -2613,8 +2611,8 @@
     if (!MA->isRead())
       HasWriteAccess.insert(MA->getBaseAddr());
     MemAccInst Acc(MA->getAccessInstruction());
-    if (MA->isRead() && Acc.isMemTransferInst())
-      PtrToAcc[Acc.asMemTransferInst()->getSource()] = MA;
+    if (MA->isRead() && isa<MemTransferInst>(Acc))
+      PtrToAcc[cast<MemTransferInst>(Acc)->getSource()] = MA;
     else
       PtrToAcc[Acc.getPointerOperand()] = MA;
     AST.add(Acc);
@@ -3849,7 +3847,7 @@
   const SCEVUnknown *BasePointer =
       dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFunction));
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   if (isa<GetElementPtrInst>(Address) || isa<BitCastInst>(Address)) {
     auto *NewAddress = Address;
@@ -3902,7 +3900,7 @@
   Type *ElementType = Val->getType();
   unsigned ElementSize = DL->getTypeAllocSize(ElementType);
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L);
   const SCEVUnknown *BasePointer =
@@ -3939,10 +3937,12 @@
     MemAccInst Inst, Loop *L, Region *R,
     const ScopDetection::BoxedLoopsSetTy *BoxedLoops,
     const InvariantLoadsSetTy &ScopRIL) {
-  if (!Inst.isMemIntrinsic())
+  auto *MemIntr = dyn_cast_or_null<MemIntrinsic>(Inst);
+
+  if (!MemIntr)
     return false;
 
-  auto *LengthVal = SE->getSCEVAtScope(Inst.asMemIntrinsic()->getLength(), L);
+  auto *LengthVal = SE->getSCEVAtScope(MemIntr->getLength(), L);
   assert(LengthVal);
 
   // Check if the length val is actually affine or if we overapproximate it
@@ -3954,7 +3954,7 @@
   if (!LengthIsAffine)
     LengthVal = nullptr;
 
-  auto *DestPtrVal = Inst.asMemIntrinsic()->getDest();
+  auto *DestPtrVal = MemIntr->getDest();
   assert(DestPtrVal);
   auto *DestAccFunc = SE->getSCEVAtScope(DestPtrVal, L);
   assert(DestAccFunc);
@@ -3965,10 +3965,11 @@
       IntegerType::getInt8Ty(DestPtrVal->getContext()), false,
       {DestAccFunc, LengthVal}, {}, Inst.getValueOperand());
 
-  if (!Inst.isMemTransferInst())
+  auto *MemTrans = dyn_cast<MemTransferInst>(MemIntr);
+  if (!MemTrans)
     return true;
 
-  auto *SrcPtrVal = Inst.asMemTransferInst()->getSource();
+  auto *SrcPtrVal = MemTrans->getSource();
   assert(SrcPtrVal);
   auto *SrcAccFunc = SE->getSCEVAtScope(SrcPtrVal, L);
   assert(SrcAccFunc);
@@ -3986,30 +3987,31 @@
     MemAccInst Inst, Loop *L, Region *R,
     const ScopDetection::BoxedLoopsSetTy *BoxedLoops,
     const InvariantLoadsSetTy &ScopRIL) {
-  if (!Inst.isCallInst())
+  auto *CI = dyn_cast_or_null<CallInst>(Inst);
+
+  if (!CI)
     return false;
 
-  auto &CI = *Inst.asCallInst();
-  if (CI.doesNotAccessMemory() || isIgnoredIntrinsic(&CI))
+  if (CI->doesNotAccessMemory() || isIgnoredIntrinsic(CI))
     return true;
 
   bool ReadOnly = false;
-  auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI.getContext()), 0);
-  auto *CalledFunction = CI.getCalledFunction();
+  auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI->getContext()), 0);
+  auto *CalledFunction = CI->getCalledFunction();
   switch (AA->getModRefBehavior(CalledFunction)) {
   case llvm::FMRB_UnknownModRefBehavior:
     llvm_unreachable("Unknown mod ref behaviour cannot be represented.");
   case llvm::FMRB_DoesNotAccessMemory:
     return true;
   case llvm::FMRB_OnlyReadsMemory:
-    GlobalReads.push_back(&CI);
+    GlobalReads.push_back(CI);
     return true;
   case llvm::FMRB_OnlyReadsArgumentPointees:
     ReadOnly = true;
   // Fall through
   case llvm::FMRB_OnlyAccessesArgumentPointees:
     auto AccType = ReadOnly ? MemoryAccess::READ : MemoryAccess::MAY_WRITE;
-    for (const auto &Arg : CI.arg_operands()) {
+    for (const auto &Arg : CI->arg_operands()) {
       if (!Arg->getType()->isPointerTy())
         continue;
 
@@ -4019,7 +4021,7 @@
       auto *ArgBasePtr = cast<SCEVUnknown>(SE->getPointerBase(ArgSCEV));
 
       addArrayAccess(Inst, AccType, ArgBasePtr->getValue(),
-                     ArgBasePtr->getType(), false, {AF}, {}, &CI);
+                     ArgBasePtr->getType(), false, {AF}, {}, CI);
     }
     return true;
   }
@@ -4035,7 +4037,7 @@
   Value *Val = Inst.getValueOperand();
   Type *ElementType = Val->getType();
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L);
   const SCEVUnknown *BasePointer =