Changes relative to the reviewed revision of this patch (free text before the
first "Index:" line; patch and git apply skip it):
 * Restored the template arguments (std::shared_ptr<isl_ctx>, DenseMap<...>,
   getNodeAs<Region>, ReversePostOrderTraversal<Region *>) that were missing.
 * propagateInvalidStmtDomains keeps every invalid domain in InvalidDomainMap:
   the empty-domain shortcut no longer assigns one isl_set to all statements,
   a block's invalid domain is stored back after it has been processed,
   successor invalid domains are merged in the map instead of on ScopStmt,
   missing entries start out empty, and failures in nested regions propagate.
 * InvalidDomainMap is declared right before propagateInvalidStmtDomains is
   called; buildDomains frees it on failure and hands each statement the
   invalid domain of its (entry) block on success. The eager initialization of
   the entry statement is kept, so the IsOnlyNonAffineRegion exit still leaves
   that statement with an empty invalid domain.
 * addScopStmt(BasicBlock *) maps every instruction of the block, so
   getStmtFor(BasicBlock *) (which looks up BB->front()) keeps working. The
   call-site rewrites to getStmtFor(&BB->front()) / getStmtFor(&BB->back())
   and the removal of the "Parent.getStmtFor(AccessVal) == this" assert are
   therefore not needed and were dropped.
 * The doc comment of getSize() is kept and the Doxygen block of
   propagateInvalidStmtDomains uses "///" on every line.
 * NOTE(review): the successor loop of propagateInvalidStmtDomains was not
   part of the reviewed diff; its context lines below are reconstructed and
   must be checked against the tree. Code outside these hunks that still sets
   ScopStmt invalid domains directly (presumably where a block's domain is
   first created) is superseded by the hand-over at the end of buildDomains.

Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -1644,8 +1644,8 @@
   /// delete the last object that creates isl objects with the context.
   std::shared_ptr<isl_ctx> IslCtx;
 
-  /// A map from basic blocks to SCoP statements.
-  DenseMap<BasicBlock *, ScopStmt *> StmtMap;
+  /// A map from instructions to SCoP statements.
+  DenseMap<Instruction *, ScopStmt *> StmtMap;
 
   /// A map from basic blocks to their domains.
   DenseMap<BasicBlock *, isl_set *> DomainMap;
@@ -1877,12 +1877,16 @@
   /// of error statements and those only reachable via error statements will be
   /// replaced by an empty set. Later those will be removed completely.
   ///
-  /// @param R  The currently traversed region.
-  /// @param DT The DominatorTree for the current function.
-  /// @param LI The LoopInfo for the current function.
+  /// @param R                The currently traversed region.
+  /// @param DT               The DominatorTree for the current function.
+  /// @param LI               The LoopInfo for the current function.
+  /// @param InvalidDomainMap Map from basic blocks to their invalid domains.
+  ///                         Read and updated while traversing @p R.
   ///
   /// @returns True if there was no problem and false otherwise.
-  bool propagateInvalidStmtDomains(Region *R, DominatorTree &DT, LoopInfo &LI);
+  bool propagateInvalidStmtDomains(
+      Region *R, DominatorTree &DT, LoopInfo &LI,
+      DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap);
 
   /// Compute the domain for each basic block in @p R.
   ///
@@ -2507,6 +2511,10 @@
   /// Get an isl string representing the invalid context.
   std::string getInvalidContextStr() const;
 
+  /// Return the ScopStmt an instruction belongs to, or nullptr if it
+  /// does not belong to any statement in this Scop.
+  ScopStmt *getStmtFor(Instruction *Inst) const;
+
   /// Return the ScopStmt for the given @p BB or nullptr if there is
   /// none.
   ScopStmt *getStmtFor(BasicBlock *BB) const;
@@ -2520,13 +2528,7 @@
   /// simplifications.
   ScopStmt *getStmtFor(RegionNode *RN) const;
 
-  /// Return the ScopStmt an instruction belongs to, or nullptr if it
-  /// does not belong to any statement in this Scop.
-  ScopStmt *getStmtFor(Instruction *Inst) const {
-    return getStmtFor(Inst->getParent());
-  }
-
   /// Return the number of statements in the SCoP.
   size_t getSize() const { return Stmts.size(); }
 
   /// @name Statements Iterators
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -2647,9 +2647,24 @@
   // with an empty set. Additionally, we will record for each block under which
   // parameter combination it would be reached via an error block in its
   // InvalidDomain. This information is needed during load hoisting.
-  if (!propagateInvalidStmtDomains(R, DT, LI))
+  // The map owns the isl sets it holds until they are handed to statements.
+  DenseMap<BasicBlock *, isl_set *> InvalidDomainMap;
+  if (!propagateInvalidStmtDomains(R, DT, LI, InvalidDomainMap)) {
+    // No statement takes over the invalid domains; release them.
+    for (auto &BBAndInvalidDomain : InvalidDomainMap)
+      isl_set_free(BBAndInvalidDomain.second);
     return false;
+  }
 
+  // Hand the computed invalid domains over to the statements. A region
+  // statement uses the invalid domain of its entry block.
+  for (ScopStmt &Stmt : Stmts) {
+    BasicBlock *StmtBB = Stmt.isRegionStmt() ? Stmt.getRegion()->getEntry()
+                                             : Stmt.getBasicBlock();
+    if (isl_set *InvalidDomain = InvalidDomainMap.lookup(StmtBB))
+      Stmt.setInvalidDomain(InvalidDomain);
+  }
+
   return true;
 }
 
@@ -2703,8 +2718,9 @@
   return Dom;
 }
 
-bool Scop::propagateInvalidStmtDomains(Region *R, DominatorTree &DT,
-                                       LoopInfo &LI) {
+bool Scop::propagateInvalidStmtDomains(
+    Region *R, DominatorTree &DT, LoopInfo &LI,
+    DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap) {
 
   ReversePostOrderTraversal<Region *> RTraversal(R);
   for (auto *RN : RTraversal) {
@@ -2713,18 +2729,25 @@
     if (RN->isSubRegion()) {
       Region *SubRegion = RN->getNodeAs<Region>();
       if (!isNonAffineSubRegion(SubRegion)) {
-        propagateInvalidStmtDomains(SubRegion, DT, LI);
+        // Bail out as well if the nested region gave up (too complex).
+        if (!propagateInvalidStmtDomains(SubRegion, DT, LI, InvalidDomainMap))
+          return false;
         continue;
       }
     }
 
     bool ContainsErrorBlock = containsErrorBlock(RN, getRegion(), LI, DT);
     BasicBlock *BB = getRegionNodeBasicBlock(RN);
-    ScopStmt *Stmt = getStmtFor(BB);
     isl_set *&Domain = DomainMap[BB];
     assert(Domain && "Cannot propagate a nullptr");
 
-    auto *InvalidDomain = Stmt->getInvalidDomain();
+    // Take ownership of the invalid domain recorded for BB so far; it is
+    // stored back once BB has been processed. A block that no invalid
+    // predecessor has reached yet starts with an empty invalid domain.
+    isl_set *InvalidDomain = InvalidDomainMap.lookup(BB);
+    InvalidDomainMap.erase(BB);
+    if (!InvalidDomain)
+      InvalidDomain = isl_set_empty(isl_set_get_space(Domain));
     bool IsInvalidBlock =
         ContainsErrorBlock || isl_set_is_subset(Domain, InvalidDomain);
 
@@ -2740,48 +2763,56 @@
     }
 
     if (isl_set_is_empty(InvalidDomain)) {
-      Stmt->setInvalidDomain(InvalidDomain);
+      InvalidDomainMap[BB] = InvalidDomain;
       continue;
     }
 
     auto *BBLoop = getRegionNodeLoop(RN, LI);
     auto *TI = BB->getTerminator();
     unsigned NumSuccs = RN->isSubRegion() ? 1 : TI->getNumSuccessors();
     for (unsigned u = 0; u < NumSuccs; u++) {
       auto *SuccBB = getRegionNodeSuccessor(RN, TI, u);
       auto *SuccStmt = getStmtFor(SuccBB);
 
       // Skip successors outside the SCoP.
       if (!SuccStmt)
         continue;
 
       // Skip backedges.
       if (DT.dominates(SuccBB, BB))
         continue;
 
       auto *SuccBBLoop = SuccStmt->getSurroundingLoop();
       auto *AdjustedInvalidDomain = adjustDomainDimensions(
           *this, isl_set_copy(InvalidDomain), BBLoop, SuccBBLoop);
-      auto *SuccInvalidDomain = SuccStmt->getInvalidDomain();
-      SuccInvalidDomain =
-          isl_set_union(SuccInvalidDomain, AdjustedInvalidDomain);
+
+      // Merge into the invalid domain recorded for SuccBB so far. A successor
+      // without an entry has not been reached from an invalid predecessor
+      // yet, so the adjusted domain becomes its invalid domain.
+      isl_set *SuccInvalidDomain = InvalidDomainMap.lookup(SuccBB);
+      if (SuccInvalidDomain)
+        SuccInvalidDomain =
+            isl_set_union(SuccInvalidDomain, AdjustedInvalidDomain);
+      else
+        SuccInvalidDomain = AdjustedInvalidDomain;
       SuccInvalidDomain = isl_set_coalesce(SuccInvalidDomain);
       unsigned NumConjucts = isl_set_n_basic_set(SuccInvalidDomain);
-      SuccStmt->setInvalidDomain(SuccInvalidDomain);
+      InvalidDomainMap[SuccBB] = SuccInvalidDomain;
 
       // Check if the maximal number of domain disjunctions was reached.
       // In case this happens we will bail.
       if (NumConjucts < MaxDisjunctsInDomain)
         continue;
 
       isl_set_free(InvalidDomain);
       invalidate(COMPLEXITY, TI->getDebugLoc());
       return false;
     }
 
-    Stmt->setInvalidDomain(InvalidDomain);
+    // Store the invalid domain of BB back for its statement and later use.
+    InvalidDomainMap[BB] = InvalidDomain;
   }
 
   return true;
 }
 
@@ -3690,9 +3721,11 @@
     // Remove the statement because it is unnecessary.
     if (Stmt.isRegionStmt())
       for (BasicBlock *BB : Stmt.getRegion()->blocks())
-        StmtMap.erase(BB);
+        for (Instruction &Inst : *BB)
+          StmtMap.erase(&Inst);
     else
-      StmtMap.erase(Stmt.getBasicBlock());
+      for (Instruction &Inst : *Stmt.getBasicBlock())
+        StmtMap.erase(&Inst);
 
     StmtIt = Stmts.erase(StmtIt);
   }
@@ -4666,7 +4699,11 @@
   assert(BB && "Unexpected nullptr!");
   Stmts.emplace_back(*this, *BB, SurroundingLoop, Instructions);
   auto *Stmt = &Stmts.back();
-  StmtMap[BB] = Stmt;
+  // Map every instruction of BB, not only those listed in Instructions, so
+  // that getStmtFor(BasicBlock *), which looks up BB->front(), and lookups
+  // of any other instruction of BB keep finding this statement.
+  for (Instruction &Inst : *BB)
+    StmtMap[&Inst] = Stmt;
 }
 
 void Scop::addScopStmt(Region *R, Loop *SurroundingLoop) {
@@ -4674,7 +4711,8 @@
   Stmts.emplace_back(*this, *R, SurroundingLoop);
   auto *Stmt = &Stmts.back();
   for (BasicBlock *BB : R->blocks())
-    StmtMap[BB] = Stmt;
+    for (Instruction &Inst : *BB)
+      StmtMap[&Inst] = Stmt;
 }
 
 ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
@@ -4822,8 +4860,12 @@
     }
   }
 
+ScopStmt *Scop::getStmtFor(Instruction *Inst) const {
+  return StmtMap.lookup(Inst);
+}
+
 ScopStmt *Scop::getStmtFor(BasicBlock *BB) const {
-  auto StmtMapIt = StmtMap.find(BB);
+  auto StmtMapIt = StmtMap.find(&BB->front());
   if (StmtMapIt == StmtMap.end())
     return nullptr;
   return StmtMapIt->second;