Index: polly/trunk/include/polly/Support/ScopHelper.h =================================================================== --- polly/trunk/include/polly/Support/ScopHelper.h +++ polly/trunk/include/polly/Support/ScopHelper.h @@ -242,6 +242,10 @@ bool isNull() const { return !I; } bool isInstruction() const { return I; } + + llvm::Instruction *asInstruction() const { return I; } + +private: bool isLoad() const { return I && llvm::isa(I); } bool isStore() const { return I && llvm::isa(I); } bool isCallInst() const { return I && llvm::isa(I); } @@ -251,7 +255,6 @@ return I && llvm::isa(I); } - llvm::Instruction *asInstruction() const { return I; } llvm::LoadInst *asLoad() const { return llvm::cast(I); } llvm::StoreInst *asStore() const { return llvm::cast(I); } llvm::CallInst *asCallInst() const { return llvm::cast(I); } @@ -265,6 +268,20 @@ return llvm::cast(I); } }; +} + +namespace llvm { +/// @brief Specialize simplify_type for MemAccInst to enable dyn_cast and cast +/// from a MemAccInst object. +template <> struct simplify_type { + typedef Instruction *SimpleType; + static SimpleType getSimplifiedValue(polly::MemAccInst &I) { + return I.asInstruction(); + } +}; +} + +namespace polly { /// @brief Check if the PHINode has any incoming Invoke edge. /// Index: polly/trunk/lib/Analysis/ScopDetection.cpp =================================================================== --- polly/trunk/lib/Analysis/ScopDetection.cpp +++ polly/trunk/lib/Analysis/ScopDetection.cpp @@ -989,8 +989,8 @@ // Check the access function. if (auto MemInst = MemAccInst::dyn_cast(Inst)) { - Context.hasStores |= MemInst.isStore(); - Context.hasLoads |= MemInst.isLoad(); + Context.hasStores |= isa(MemInst); + Context.hasLoads |= isa(MemInst); if (!MemInst.isSimple()) return invalid(Context, /*Assert=*/true, &Inst); Index: polly/trunk/lib/Analysis/ScopInfo.cpp =================================================================== --- polly/trunk/lib/Analysis/ScopInfo.cpp +++ polly/trunk/lib/Analysis/ScopInfo.cpp @@ -616,9 +616,7 @@ } void MemoryAccess::buildMemIntrinsicAccessRelation() { - auto MAI = MemAccInst(getAccessInstruction()); - (void)MAI; - assert(MAI.isMemIntrinsic()); + assert(isa(getAccessInstruction())); assert(Subscripts.size() == 2 && Sizes.size() == 0); auto *SubscriptPWA = Statement->getPwAff(Subscripts[0]); @@ -646,7 +644,7 @@ ScalarEvolution *SE = Statement->getParent()->getSE(); auto MAI = MemAccInst(getAccessInstruction()); - if (MAI.isMemIntrinsic()) + if (isa(MAI)) return; Value *Ptr = MAI.getPointerOperand(); @@ -2613,8 +2611,8 @@ if (!MA->isRead()) HasWriteAccess.insert(MA->getBaseAddr()); MemAccInst Acc(MA->getAccessInstruction()); - if (MA->isRead() && Acc.isMemTransferInst()) - PtrToAcc[Acc.asMemTransferInst()->getSource()] = MA; + if (MA->isRead() && isa(Acc)) + PtrToAcc[cast(Acc)->getSource()] = MA; else PtrToAcc[Acc.getPointerOperand()] = MA; AST.add(Acc); @@ -3850,7 +3848,7 @@ const SCEVUnknown *BasePointer = dyn_cast(SE->getPointerBase(AccessFunction)); enum MemoryAccess::AccessType Type = - Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; + isa(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; if (auto *BitCast = dyn_cast(Address)) { auto *Src = BitCast->getOperand(0); @@ -3905,7 +3903,7 @@ Type *ElementType = Val->getType(); unsigned ElementSize = DL->getTypeAllocSize(ElementType); enum MemoryAccess::AccessType Type = - Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; + isa(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L); const SCEVUnknown *BasePointer = @@ -3942,10 +3940,12 @@ MemAccInst Inst, Loop *L, Region *R, const ScopDetection::BoxedLoopsSetTy *BoxedLoops, const InvariantLoadsSetTy &ScopRIL) { - if (!Inst.isMemIntrinsic()) + auto *MemIntr = dyn_cast_or_null(Inst); + + if (MemIntr == nullptr) return false; - auto *LengthVal = SE->getSCEVAtScope(Inst.asMemIntrinsic()->getLength(), L); + auto *LengthVal = SE->getSCEVAtScope(MemIntr->getLength(), L); assert(LengthVal); // Check if the length val is actually affine or if we overapproximate it @@ -3957,7 +3957,7 @@ if (!LengthIsAffine) LengthVal = nullptr; - auto *DestPtrVal = Inst.asMemIntrinsic()->getDest(); + auto *DestPtrVal = MemIntr->getDest(); assert(DestPtrVal); auto *DestAccFunc = SE->getSCEVAtScope(DestPtrVal, L); assert(DestAccFunc); @@ -3968,10 +3968,11 @@ IntegerType::getInt8Ty(DestPtrVal->getContext()), false, {DestAccFunc, LengthVal}, {}, Inst.getValueOperand()); - if (!Inst.isMemTransferInst()) + auto *MemTrans = dyn_cast(MemIntr); + if (!MemTrans) return true; - auto *SrcPtrVal = Inst.asMemTransferInst()->getSource(); + auto *SrcPtrVal = MemTrans->getSource(); assert(SrcPtrVal); auto *SrcAccFunc = SE->getSCEVAtScope(SrcPtrVal, L); assert(SrcAccFunc); @@ -3989,30 +3990,31 @@ MemAccInst Inst, Loop *L, Region *R, const ScopDetection::BoxedLoopsSetTy *BoxedLoops, const InvariantLoadsSetTy &ScopRIL) { - if (!Inst.isCallInst()) + auto *CI = dyn_cast_or_null(Inst); + + if (CI == nullptr) return false; - auto &CI = *Inst.asCallInst(); - if (CI.doesNotAccessMemory() || isIgnoredIntrinsic(&CI)) + if (CI->doesNotAccessMemory() || isIgnoredIntrinsic(CI)) return true; bool ReadOnly = false; - auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI.getContext()), 0); - auto *CalledFunction = CI.getCalledFunction(); + auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI->getContext()), 0); + auto *CalledFunction = CI->getCalledFunction(); switch (AA->getModRefBehavior(CalledFunction)) { case llvm::FMRB_UnknownModRefBehavior: llvm_unreachable("Unknown mod ref behaviour cannot be represented."); case llvm::FMRB_DoesNotAccessMemory: return true; case llvm::FMRB_OnlyReadsMemory: - GlobalReads.push_back(&CI); + GlobalReads.push_back(CI); return true; case llvm::FMRB_OnlyReadsArgumentPointees: ReadOnly = true; // Fall through case llvm::FMRB_OnlyAccessesArgumentPointees: auto AccType = ReadOnly ? MemoryAccess::READ : MemoryAccess::MAY_WRITE; - for (const auto &Arg : CI.arg_operands()) { + for (const auto &Arg : CI->arg_operands()) { if (!Arg->getType()->isPointerTy()) continue; @@ -4022,7 +4024,7 @@ auto *ArgBasePtr = cast(SE->getPointerBase(ArgSCEV)); addArrayAccess(Inst, AccType, ArgBasePtr->getValue(), - ArgBasePtr->getType(), false, {AF}, {}, &CI); + ArgBasePtr->getType(), false, {AF}, {}, CI); } return true; } @@ -4038,7 +4040,7 @@ Value *Val = Inst.getValueOperand(); Type *ElementType = Val->getType(); enum MemoryAccess::AccessType Type = - Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; + isa(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE; const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L); const SCEVUnknown *BasePointer =