NOTE(review): Text above the first "Index:" header is leading garbage that patch tools skip; it documents this patch.

Purpose
  This patch replaces Scop's BasicBlock-keyed StmtMap with an Instruction-keyed one.
  It removes the getStmtFor(BasicBlock *), getStmtFor(Region *) and getStmtFor(RegionNode *) overloads.
  It turns getStmtFor(Instruction *) into an out-of-line StmtMap lookup.
  While domains are being built, it tracks invalid domains in a local InvalidDomainMap.
  That map is passed by reference into propagateInvalidStmtDomains.

Review findings (please confirm before landing)
  1. Template arguments appear stripped and line breaks collapsed.
     Examples: "DenseMap StmtMap;", "std::shared_ptr IslCtx;", "cast(Access->getAccessValue())",
     "RN->getNodeAs()", "ReversePostOrderTraversal RTraversal(R);".
     The patch cannot be applied or compiled in this form; regenerate it from the working tree.
  2. ScopInfo.h: the added "//" line before "/// @returns" was presumably meant to be "///".
     As written it interrupts the run of doc-comment lines.
  3. buildDomains: only EntryBB gets an InvalidDomainMap entry.
     The following loop evaluates InvalidDomainMap[Stmt.getBasicBlock()] for every statement.
     DenseMap::operator[] inserts a default-constructed value (null for a pointer type) for all other blocks.
     The entry block's set is handed to the statement while still being stored in the map.
     The removed code passed a freshly built isl_set_empty(...) to setInvalidDomain.
     If setInvalidDomain takes ownership, the set now has two owners.
     getBasicBlock() is also called for region statements (removeStmts distinguishes isRegionStmt()).
     Confirm that this is valid, or use the statement's entry block.
  4. propagateInvalidStmtDomains: InvalidDomain is read from InvalidDomainMap[BB] and updated locally.
     The result is never written back with InvalidDomainMap[BB] = InvalidDomain.
     Instead, both exit paths of the loop body re-apply every map entry to every statement.
     That costs O(#blocks x #stmts) per traversal, and the locally computed InvalidDomain appears to be dropped.
     If it is an owned isl_set, it is leaked.
     Expected: store into the map in the loop, then transfer to the statements once after propagation.
  5. The code between hunks @@ -2749 and @@ -2779 is unchanged and not shown here.
     If it still updates successors via SuccStmt->getInvalidDomain()/setInvalidDomain(),
     those updates are not seen when the successor is visited, because that visit reads InvalidDomainMap[BB].
     TODO confirm.
  6. addScopStmt(BB, ...) maps only the instructions in `Instructions`.
     The lookups, however, use &SuccBB->front(), &BB->front(), &ExitBB->back() and &PredBB->back().
     They also mix front() and back() for the same "statement of this block" question.
     If those instructions are not in `Instructions` (TODO confirm what it contains, e.g. PHIs or terminators):
       - getStmtFor returns nullptr;
       - a successor is skipped as "outside the SCoP";
       - getStmtFor(...)->getSurroundingLoop() dereferences null.
     Consider one helper that resolves a block's statement consistently.
  7. ScopStmt::addAccess drops assert(Parent.getStmtFor(AccessVal) == this).
     Confirm that this is intentional, or keep the check for values that are present in StmtMap.
  8. removeStmts erases every instruction of a block, but addScopStmt(BB, ...) inserted only `Instructions`.
     This is harmless today, since erasing an absent key is a no-op.
     It would remove another statement's entries if a block ever hosts more than one statement.
  9. The @@ -2507 hunk also deletes "/// Return the number of statements in the SCoP." above getSize().
     This looks unintended.

Index: include/polly/ScopInfo.h =================================================================== --- include/polly/ScopInfo.h +++ include/polly/ScopInfo.h @@ -1644,8 +1644,8 @@ /// delete the last object that creates isl objects with the context. std::shared_ptr IslCtx; - /// A map from basic blocks to SCoP statements. - DenseMap StmtMap; + /// A map from instructions to SCoP statements. + DenseMap StmtMap; /// A map from basic blocks to their domains. DenseMap DomainMap; @@ -1877,12 +1877,16 @@ /// of error statements and those only reachable via error statements will be /// replaced by an empty set. Later those will be removed completely. /// - /// @param R The currently traversed region. - /// @param DT The DominatorTree for the current function. - /// @param LI The LoopInfo for the current function. - /// + /// @param R The currently traversed region. + /// @param DT The DominatorTree for the current function. + /// @param LI The LoopInfo for the current function. + /// @param InvalidDomainMap BB to InvalidDomain map for the BB of current + /// region. + // /// @returns True if there was no problem and false otherwise. - bool propagateInvalidStmtDomains(Region *R, DominatorTree &DT, LoopInfo &LI); + bool propagateInvalidStmtDomains( + Region *R, DominatorTree &DT, LoopInfo &LI, + DenseMap &InvalidDomainMap); /// Compute the domain for each basic block in @p R. /// @@ -2507,26 +2511,10 @@ /// Get an isl string representing the invalid context. std::string getInvalidContextStr() const; - /// Return the ScopStmt for the given @p BB or nullptr if there is - /// none. - ScopStmt *getStmtFor(BasicBlock *BB) const; - - /// Return the ScopStmt that represents the Region @p R, or nullptr if - /// it is not represented by any statement in this Scop. - ScopStmt *getStmtFor(Region *R) const; - - /// Return the ScopStmt that represents @p RN; can return nullptr if - /// the RegionNode is not within the SCoP or has been removed due to - /// simplifications. 
- ScopStmt *getStmtFor(RegionNode *RN) const; - /// Return the ScopStmt an instruction belongs to, or nullptr if it - /// does not belong to any statement in this Scop. - ScopStmt *getStmtFor(Instruction *Inst) const { - return getStmtFor(Inst->getParent()); - } + /// does not belong to any statement in this Scop. + ScopStmt *getStmtFor(Instruction *Inst) const; - /// Return the number of statements in the SCoP. size_t getSize() const { return Stmts.size(); } /// @name Statements Iterators Index: lib/Analysis/ScopInfo.cpp =================================================================== --- lib/Analysis/ScopInfo.cpp +++ lib/Analysis/ScopInfo.cpp @@ -1324,7 +1324,6 @@ MAL.emplace_front(Access); } else if (Access->isValueKind() && Access->isWrite()) { Instruction *AccessVal = cast(Access->getAccessValue()); - assert(Parent.getStmtFor(AccessVal) == this); assert(!ValueWrites.lookup(AccessVal)); ValueWrites[AccessVal] = Access; @@ -2620,9 +2619,13 @@ L = L->getParentLoop(); } + /// A map from basic blocks to their invalid domains. + DenseMap InvalidDomainMap; + // Initialize the invalid domain. - auto *EntryStmt = getStmtFor(EntryBB); - EntryStmt->setInvalidDomain(isl_set_empty(isl_set_get_space(S))); + InvalidDomainMap[EntryBB] = isl_set_empty(isl_set_get_space(S)); + for (ScopStmt &Stmt : Stmts) + Stmt.setInvalidDomain(InvalidDomainMap[Stmt.getBasicBlock()]); DomainMap[EntryBB] = S; @@ -2647,7 +2650,7 @@ // with an empty set. Additionally, we will record for each block under which // parameter combination it would be reached via an error block in its // InvalidDomain. This information is needed during load hoisting. 
- if (!propagateInvalidStmtDomains(R, DT, LI)) + if (!propagateInvalidStmtDomains(R, DT, LI, InvalidDomainMap)) return false; return true; @@ -2703,8 +2706,10 @@ return Dom; } -bool Scop::propagateInvalidStmtDomains(Region *R, DominatorTree &DT, - LoopInfo &LI) { +bool Scop::propagateInvalidStmtDomains( + Region *R, DominatorTree &DT, LoopInfo &LI, + DenseMap &InvalidDomainMap) { + ReversePostOrderTraversal RTraversal(R); for (auto *RN : RTraversal) { @@ -2713,18 +2718,17 @@ if (RN->isSubRegion()) { Region *SubRegion = RN->getNodeAs(); if (!isNonAffineSubRegion(SubRegion)) { - propagateInvalidStmtDomains(SubRegion, DT, LI); + propagateInvalidStmtDomains(SubRegion, DT, LI, InvalidDomainMap); continue; } } bool ContainsErrorBlock = containsErrorBlock(RN, getRegion(), LI, DT); BasicBlock *BB = getRegionNodeBasicBlock(RN); - ScopStmt *Stmt = getStmtFor(BB); isl_set *&Domain = DomainMap[BB]; assert(Domain && "Cannot propagate a nullptr"); - auto *InvalidDomain = Stmt->getInvalidDomain(); + auto *InvalidDomain = InvalidDomainMap[BB]; bool IsInvalidBlock = ContainsErrorBlock || isl_set_is_subset(Domain, InvalidDomain); @@ -2740,7 +2744,8 @@ } if (isl_set_is_empty(InvalidDomain)) { - Stmt->setInvalidDomain(InvalidDomain); + for (ScopStmt &Stmt : Stmts) + Stmt.setInvalidDomain(InvalidDomainMap[Stmt.getBasicBlock()]); continue; } @@ -2749,7 +2754,7 @@ unsigned NumSuccs = RN->isSubRegion() ? 1 : TI->getNumSuccessors(); for (unsigned u = 0; u < NumSuccs; u++) { auto *SuccBB = getRegionNodeSuccessor(RN, TI, u); - auto *SuccStmt = getStmtFor(SuccBB); + auto *SuccStmt = getStmtFor(&SuccBB->front()); // Skip successors outside the SCoP. 
if (!SuccStmt) @@ -2779,7 +2784,8 @@ return false; } - Stmt->setInvalidDomain(InvalidDomain); + for (ScopStmt &Stmt : Stmts) + Stmt.setInvalidDomain(InvalidDomainMap[Stmt.getBasicBlock()]); } return true; @@ -2812,7 +2818,7 @@ auto *Domain = DomainMap[BB]; assert(Domain && "Cannot propagate a nullptr"); - auto *ExitStmt = getStmtFor(ExitBB); + auto *ExitStmt = getStmtFor(&ExitBB->back()); auto *ExitBBLoop = ExitStmt->getSurroundingLoop(); // Since the dimensions of @p BB and @p ExitBB might be different we have to @@ -2909,8 +2915,8 @@ for (unsigned u = 0, e = ConditionSets.size(); u < e; u++) { isl_set *CondSet = ConditionSets[u]; BasicBlock *SuccBB = getRegionNodeSuccessor(RN, TI, u); + auto *SuccStmt = getStmtFor(&SuccBB->front()); - auto *SuccStmt = getStmtFor(SuccBB); // Skip blocks outside the region. if (!SuccStmt) { isl_set_free(CondSet); @@ -2974,7 +2980,7 @@ // The region info of this function. auto &RI = *R.getRegionInfo(); - auto *BBLoop = getStmtFor(BB)->getSurroundingLoop(); + auto *BBLoop = getStmtFor(&BB->front())->getSurroundingLoop(); // A domain to collect all predecessor domains, thus all conditions under // which the block is executed. To this end we start with the empty domain. @@ -3010,7 +3016,8 @@ } auto *PredBBDom = getDomainConditions(PredBB); - auto *PredBBLoop = getStmtFor(PredBB)->getSurroundingLoop(); + auto *PredBBLoop = getStmtFor(&PredBB->back())->getSurroundingLoop(); + PredBBDom = adjustDomainDimensions(*this, PredBBDom, PredBBLoop, BBLoop); PredDom = isl_set_union(PredDom, PredBBDom); @@ -3690,9 +3697,11 @@ // Remove the statement because it is unnecessary. 
if (Stmt.isRegionStmt()) for (BasicBlock *BB : Stmt.getRegion()->blocks()) - StmtMap.erase(BB); + for (Instruction &Inst : *BB) + StmtMap.erase(&Inst); else - StmtMap.erase(Stmt.getBasicBlock()); + for (Instruction &Inst : *Stmt.getBasicBlock()) + StmtMap.erase(&Inst); StmtIt = Stmts.erase(StmtIt); } @@ -4666,7 +4675,8 @@ assert(BB && "Unexpected nullptr!"); Stmts.emplace_back(*this, *BB, SurroundingLoop, Instructions); auto *Stmt = &Stmts.back(); - StmtMap[BB] = Stmt; + for (Instruction *Inst : Instructions) + StmtMap[Inst] = Stmt; } void Scop::addScopStmt(Region *R, Loop *SurroundingLoop) { @@ -4674,7 +4684,8 @@ Stmts.emplace_back(*this, *R, SurroundingLoop); auto *Stmt = &Stmts.back(); for (BasicBlock *BB : R->blocks()) - StmtMap[BB] = Stmt; + for (Instruction &Inst : *BB) + StmtMap[&Inst] = Stmt; } ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel, @@ -4822,25 +4833,13 @@ } } -ScopStmt *Scop::getStmtFor(BasicBlock *BB) const { - auto StmtMapIt = StmtMap.find(BB); +ScopStmt *Scop::getStmtFor(Instruction *Inst) const { + auto StmtMapIt = StmtMap.find(Inst); if (StmtMapIt == StmtMap.end()) return nullptr; return StmtMapIt->second; } -ScopStmt *Scop::getStmtFor(RegionNode *RN) const { - if (RN->isSubRegion()) - return getStmtFor(RN->getNodeAs()); - return getStmtFor(RN->getNodeAs()); -} - -ScopStmt *Scop::getStmtFor(Region *R) const { - ScopStmt *Stmt = getStmtFor(R->getEntry()); - assert(!Stmt || Stmt->getRegion() == R); - return Stmt; -} - int Scop::getRelativeLoopDepth(const Loop *L) const { if (!L || !R.contains(L)) return -1;