Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -1644,8 +1644,8 @@
   /// delete the last object that creates isl objects with the context.
   std::shared_ptr<isl_ctx> IslCtx;
 
-  /// A map from basic blocks to SCoP statements.
-  DenseMap<BasicBlock *, ScopStmt *> StmtMap;
+  /// A map from instructions to the SCoP statements they belong to.
+  DenseMap<Instruction *, ScopStmt *> StmtMap;
 
   /// A map from basic blocks to their domains.
   DenseMap<BasicBlock *, isl_set *> DomainMap;
@@ -1818,14 +1818,18 @@
   /// block in the @p FinishedExitBlocks set so we can later skip edges from
   /// within the region to that block.
   ///
   /// @param BB The block for which the domain is currently propagated.
   /// @param BBLoop The innermost affine loop surrounding @p BB.
   /// @param FinishedExitBlocks Set of region exits the domain was set for.
   /// @param LI The LoopInfo for the current function.
+  /// @param InvalidDomainMap Map from basic blocks to their invalid domains;
+  ///                         an exit block that receives a domain here gets
+  ///                         an empty invalid domain.
   ///
   void propagateDomainConstraintsToRegionExit(
       BasicBlock *BB, Loop *BBLoop,
-      SmallPtrSetImpl<BasicBlock *> &FinishedExitBlocks, LoopInfo &LI);
+      SmallPtrSetImpl<BasicBlock *> &FinishedExitBlocks, LoopInfo &LI,
+      DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap);
 
   /// Compute the union of predecessor domains for @p BB.
   ///
@@ -1853,13 +1855,16 @@
 
   /// Compute the branching constraints for each basic block in @p R.
   ///
   /// @param R The region we currently build branching conditions for.
   /// @param DT The DominatorTree for the current function.
   /// @param LI The LoopInfo for the current function.
+  /// @param InvalidDomainMap Map from basic blocks to their invalid domains;
+  ///                         new entries are initialized to empty sets.
   ///
   /// @returns True if there was no problem and false otherwise.
-  bool buildDomainsWithBranchConstraints(Region *R, DominatorTree &DT,
-                                         LoopInfo &LI);
+  bool buildDomainsWithBranchConstraints(
+      Region *R, DominatorTree &DT, LoopInfo &LI,
+      DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap);
 
   /// Propagate the domain constraints through the region @p R.
   ///
@@ -1877,12 +1882,16 @@
   /// of error statements and those only reachable via error statements will be
   /// replaced by an empty set. Later those will be removed completely.
   ///
   /// @param R The currently traversed region.
   /// @param DT The DominatorTree for the current function.
   /// @param LI The LoopInfo for the current function.
+  /// @param InvalidDomainMap Map from basic blocks to their invalid domains;
+  ///                         updated in place during the propagation.
   ///
   /// @returns True if there was no problem and false otherwise.
-  bool propagateInvalidStmtDomains(Region *R, DominatorTree &DT, LoopInfo &LI);
+  bool propagateInvalidStmtDomains(
+      Region *R, DominatorTree &DT, LoopInfo &LI,
+      DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap);
 
   /// Compute the domain for each basic block in @p R.
   ///
@@ -2507,6 +2516,16 @@
   /// Get an isl string representing the invalid context.
   std::string getInvalidContextStr() const;
 
+  /// Return the first ScopStmt of @p BB, i.e. the statement of the first
+  /// instruction in @p BB that is mapped to a statement, or nullptr if there
+  /// is none.
+  ScopStmt *getFirstStmtFor(BasicBlock *BB) const;
+
+  /// Return the last ScopStmt of @p BB, i.e. the statement of the last
+  /// instruction in @p BB that is mapped to a statement, or nullptr if there
+  /// is none.
+  ScopStmt *getLastStmtFor(BasicBlock *BB) const;
+
-  /// Return the ScopStmt for the given @p BB or nullptr if there is
-  /// none.
+  /// Return the (first) ScopStmt for the given @p BB or nullptr if there is
+  /// none. Equivalent to getFirstStmtFor(BB).
   ScopStmt *getStmtFor(BasicBlock *BB) const;
@@ -2521,12 +2536,10 @@
   ScopStmt *getStmtFor(RegionNode *RN) const;
 
   /// Return the ScopStmt an instruction belongs to, or nullptr if it
   /// does not belong to any statement in this Scop.
-  ScopStmt *getStmtFor(Instruction *Inst) const {
-    return getStmtFor(Inst->getParent());
-  }
+  ScopStmt *getStmtFor(Instruction *Inst) const;
 
   /// Return the number of statements in the SCoP.
   size_t getSize() const { return Stmts.size(); }
 
   /// @name Statements Iterators
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -2620,16 +2620,20 @@
     L = L->getParentLoop();
   }
 
-  // Initialize the invalid domain.
-  auto *EntryStmt = getStmtFor(EntryBB);
-  EntryStmt->setInvalidDomain(isl_set_empty(isl_set_get_space(S)));
+  // A map from basic blocks to their invalid domains.
+  DenseMap<BasicBlock *, isl_set *> InvalidDomainMap;
+  InvalidDomainMap[EntryBB] = isl_set_empty(isl_set_get_space(S));
 
   DomainMap[EntryBB] = S;
 
-  if (IsOnlyNonAffineRegion)
+  if (IsOnlyNonAffineRegion) {
+    // This early exit bypasses the hand-over of invalid domains at the end of
+    // this function, so give the entry statement its invalid domain here.
+    getStmtFor(EntryBB)->setInvalidDomain(InvalidDomainMap[EntryBB]);
     return !containsErrorBlock(R->getNode(), *R, LI, DT);
+  }
 
-  if (!buildDomainsWithBranchConstraints(R, DT, LI))
+  if (!buildDomainsWithBranchConstraints(R, DT, LI, InvalidDomainMap))
     return false;
 
   if (!propagateDomainConstraints(R, DT, LI))
@@ -2647,9 +2647,17 @@
   // with an empty set. Additionally, we will record for each block under which
   // parameter combination it would be reached via an error block in its
   // InvalidDomain. This information is needed during load hoisting.
-  if (!propagateInvalidStmtDomains(R, DT, LI))
+  if (!propagateInvalidStmtDomains(R, DT, LI, InvalidDomainMap))
     return false;
 
+  // Hand the invalid domains over to the statements. Each statement takes its
+  // own reference (several statements may share an entry block); afterwards
+  // the map's references are released.
+  for (ScopStmt &Stmt : Stmts)
+    Stmt.setInvalidDomain(isl_set_copy(InvalidDomainMap[Stmt.getEntryBlock()]));
+  for (auto &BBAndInvalidDomain : InvalidDomainMap)
+    isl_set_free(BBAndInvalidDomain.second);
+
   return true;
 }
 
@@ -2703,8 +2707,9 @@
   return Dom;
 }
 
-bool Scop::propagateInvalidStmtDomains(Region *R, DominatorTree &DT,
-                                       LoopInfo &LI) {
+bool Scop::propagateInvalidStmtDomains(
+    Region *R, DominatorTree &DT, LoopInfo &LI,
+    DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap) {
   ReversePostOrderTraversal<Region *> RTraversal(R);
   for (auto *RN : RTraversal) {
 
@@ -2713,18 +2719,20 @@
     if (RN->isSubRegion()) {
       Region *SubRegion = RN->getNodeAs<Region>();
       if (!isNonAffineSubRegion(SubRegion)) {
-        propagateInvalidStmtDomains(SubRegion, DT, LI);
+        if (!propagateInvalidStmtDomains(SubRegion, DT, LI, InvalidDomainMap))
+          return false;
         continue;
       }
     }
 
     bool ContainsErrorBlock = containsErrorBlock(RN, getRegion(), LI, DT);
     BasicBlock *BB = getRegionNodeBasicBlock(RN);
-    ScopStmt *Stmt = getStmtFor(BB);
     isl_set *&Domain = DomainMap[BB];
     assert(Domain && "Cannot propagate a nullptr");
 
-    auto *InvalidDomain = Stmt->getInvalidDomain();
+    // InvalidDomain takes over the map's set; the (possibly updated) set is
+    // stored back into InvalidDomainMap[BB] below.
+    auto *InvalidDomain = InvalidDomainMap[BB];
     bool IsInvalidBlock =
         ContainsErrorBlock || isl_set_is_subset(Domain, InvalidDomain);
 
@@ -2740,7 +2745,7 @@
     }
 
     if (isl_set_is_empty(InvalidDomain)) {
-      Stmt->setInvalidDomain(InvalidDomain);
+      InvalidDomainMap[BB] = InvalidDomain;
       continue;
     }
 
@@ -2749,7 +2754,7 @@
     unsigned NumSuccs = RN->isSubRegion() ? 1 : TI->getNumSuccessors();
     for (unsigned u = 0; u < NumSuccs; u++) {
       auto *SuccBB = getRegionNodeSuccessor(RN, TI, u);
-      auto *SuccStmt = getStmtFor(SuccBB);
+      auto *SuccStmt = getFirstStmtFor(SuccBB);
 
       // Skip successors outside the SCoP.
       if (!SuccStmt)
@@ -2762,12 +2767,12 @@
       auto *SuccBBLoop = SuccStmt->getSurroundingLoop();
       auto *AdjustedInvalidDomain = adjustDomainDimensions(
           *this, isl_set_copy(InvalidDomain), BBLoop, SuccBBLoop);
-      auto *SuccInvalidDomain = SuccStmt->getInvalidDomain();
+      auto *SuccInvalidDomain = InvalidDomainMap[SuccBB];
       SuccInvalidDomain =
           isl_set_union(SuccInvalidDomain, AdjustedInvalidDomain);
       SuccInvalidDomain = isl_set_coalesce(SuccInvalidDomain);
       unsigned NumConjucts = isl_set_n_basic_set(SuccInvalidDomain);
-      SuccStmt->setInvalidDomain(SuccInvalidDomain);
+      InvalidDomainMap[SuccBB] = SuccInvalidDomain;
 
       // Check if the maximal number of domain disjunctions was reached.
       // In case this happens we will bail.
@@ -2778,16 +2783,17 @@
       invalidate(COMPLEXITY, TI->getDebugLoc());
       return false;
     }
 
-    Stmt->setInvalidDomain(InvalidDomain);
+    InvalidDomainMap[BB] = InvalidDomain;
   }
 
   return true;
 }
 
 void Scop::propagateDomainConstraintsToRegionExit(
     BasicBlock *BB, Loop *BBLoop,
-    SmallPtrSetImpl<BasicBlock *> &FinishedExitBlocks, LoopInfo &LI) {
+    SmallPtrSetImpl<BasicBlock *> &FinishedExitBlocks, LoopInfo &LI,
+    DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap) {
 
   // Check if the block @p BB is the entry of a region. If so we propagate it's
   // domain to the exit block of the region. Otherwise we are done.
@@ -2812,7 +2816,8 @@
   auto *Domain = DomainMap[BB];
   assert(Domain && "Cannot propagate a nullptr");
 
-  auto *ExitStmt = getStmtFor(ExitBB);
+  // Like any other successor, ExitBB is entered through its first statement.
+  auto *ExitStmt = getFirstStmtFor(ExitBB);
   auto *ExitBBLoop = ExitStmt->getSurroundingLoop();
 
   // Since the dimensions of @p BB and @p ExitBB might be different we have to
@@ -2827,13 +2831,18 @@
       ExitDomain ? isl_set_union(AdjustedDomain, ExitDomain) : AdjustedDomain;
 
   // Initialize the invalid domain.
-  ExitStmt->setInvalidDomain(isl_set_empty(isl_set_get_space(ExitDomain)));
+  // ExitBB may already have an invalid domain, just as it may already have a
+  // domain (see the union above); free it instead of leaking it.
+  isl_set *&ExitInvalidDomain = InvalidDomainMap[ExitBB];
+  isl_set_free(ExitInvalidDomain);
+  ExitInvalidDomain = isl_set_empty(isl_set_get_space(ExitDomain));
 
   FinishedExitBlocks.insert(ExitBB);
 }
 
-bool Scop::buildDomainsWithBranchConstraints(Region *R, DominatorTree &DT,
-                                             LoopInfo &LI) {
+bool Scop::buildDomainsWithBranchConstraints(
+    Region *R, DominatorTree &DT, LoopInfo &LI,
+    DenseMap<BasicBlock *, isl_set *> &InvalidDomainMap) {
   // To create the domain for each block in R we iterate over all blocks and
   // subregions in R and propagate the conditions under which the current region
   // element is executed. To this end we iterate in reverse post order over R as
@@ -2854,7 +2860,8 @@
     if (RN->isSubRegion()) {
       Region *SubRegion = RN->getNodeAs<Region>();
       if (!isNonAffineSubRegion(SubRegion)) {
-        if (!buildDomainsWithBranchConstraints(SubRegion, DT, LI))
+        if (!buildDomainsWithBranchConstraints(SubRegion, DT, LI,
+                                               InvalidDomainMap))
           return false;
         continue;
       }
@@ -2877,7 +2884,8 @@
     auto *BBLoop = getRegionNodeLoop(RN, LI);
     // Propagate the domain from BB directly to blocks that have a superset
     // domain, at the moment only region exit nodes of regions that start in BB.
-    propagateDomainConstraintsToRegionExit(BB, BBLoop, FinishedExitBlocks, LI);
+    propagateDomainConstraintsToRegionExit(BB, BBLoop, FinishedExitBlocks, LI,
+                                           InvalidDomainMap);
 
     // If all successors of BB have been set a domain through the propagation
     // above we do not need to build condition sets but can just skip this
@@ -2897,7 +2905,9 @@
     SmallVector<isl_set *, 8> ConditionSets;
+    // The branch condition belongs to BB's terminator, i.e. to the end of BB,
+    // so the last statement of BB provides the context for the condition sets.
     if (RN->isSubRegion())
       ConditionSets.push_back(isl_set_copy(Domain));
-    else if (!buildConditionSets(*getStmtFor(BB), TI, BBLoop, Domain,
+    else if (!buildConditionSets(*getLastStmtFor(BB), TI, BBLoop, Domain,
                                  ConditionSets))
       return false;
 
@@ -2909,8 +2917,8 @@
     for (unsigned u = 0, e = ConditionSets.size(); u < e; u++) {
       isl_set *CondSet = ConditionSets[u];
       BasicBlock *SuccBB = getRegionNodeSuccessor(RN, TI, u);
 
-      auto *SuccStmt = getStmtFor(SuccBB);
+      auto *SuccStmt = getFirstStmtFor(SuccBB);
       // Skip blocks outside the region.
       if (!SuccStmt) {
         isl_set_free(CondSet);
@@ -2942,7 +2950,7 @@
         SuccDomain = isl_set_coalesce(isl_set_union(SuccDomain, CondSet));
       } else {
         // Initialize the invalid domain.
-        SuccStmt->setInvalidDomain(isl_set_empty(isl_set_get_space(CondSet)));
+        InvalidDomainMap[SuccBB] = isl_set_empty(isl_set_get_space(CondSet));
         SuccDomain = CondSet;
       }
 
@@ -2974,7 +2982,7 @@
   // The region info of this function.
   auto &RI = *R.getRegionInfo();
 
-  auto *BBLoop = getStmtFor(BB)->getSurroundingLoop();
+  auto *BBLoop = getFirstStmtFor(BB)->getSurroundingLoop();
 
   // A domain to collect all predecessor domains, thus all conditions under
   // which the block is executed. To this end we start with the empty domain.
@@ -3010,7 +3018,7 @@
     }
 
     auto *PredBBDom = getDomainConditions(PredBB);
-    auto *PredBBLoop = getStmtFor(PredBB)->getSurroundingLoop();
+    auto *PredBBLoop = getLastStmtFor(PredBB)->getSurroundingLoop();
     PredBBDom = adjustDomainDimensions(*this, PredBBDom, PredBBLoop, BBLoop);
 
     PredDom = isl_set_union(PredDom, PredBBDom);
@@ -3690,9 +3699,11 @@
     // Remove the statement because it is unnecessary.
     if (Stmt.isRegionStmt())
       for (BasicBlock *BB : Stmt.getRegion()->blocks())
-        StmtMap.erase(BB);
+        for (Instruction &Inst : *BB)
+          StmtMap.erase(&Inst);
     else
-      StmtMap.erase(Stmt.getBasicBlock());
+      for (Instruction &Inst : *Stmt.getBasicBlock())
+        StmtMap.erase(&Inst);
 
     StmtIt = Stmts.erase(StmtIt);
   }
@@ -4666,7 +4677,11 @@
   assert(BB && "Unexpected nullptr!");
   Stmts.emplace_back(*this, *BB, SurroundingLoop, Instructions);
   auto *Stmt = &Stmts.back();
-  StmtMap[BB] = Stmt;
+  // Map every instruction of BB, not only those in Instructions, so that
+  // getStmtFor(Inst) resolves all of them; this matches the Region overload
+  // below and the per-instruction erasure in removeStmts.
+  for (Instruction &Inst : *BB)
+    StmtMap[&Inst] = Stmt;
 }
 
 void Scop::addScopStmt(Region *R, Loop *SurroundingLoop) {
@@ -4674,7 +4686,8 @@
   Stmts.emplace_back(*this, *R, SurroundingLoop);
   auto *Stmt = &Stmts.back();
   for (BasicBlock *BB : R->blocks())
-    StmtMap[BB] = Stmt;
+    for (Instruction &Inst : *BB)
+      StmtMap[&Inst] = Stmt;
 }
 
 ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
@@ -4822,13 +4835,31 @@
   }
 }
 
-ScopStmt *Scop::getStmtFor(BasicBlock *BB) const {
-  auto StmtMapIt = StmtMap.find(BB);
+ScopStmt *Scop::getStmtFor(Instruction *Inst) const {
+  auto StmtMapIt = StmtMap.find(Inst);
   if (StmtMapIt == StmtMap.end())
     return nullptr;
   return StmtMapIt->second;
 }
 
+ScopStmt *Scop::getFirstStmtFor(BasicBlock *BB) const {
+  for (Instruction &Inst : *BB)
+    if (ScopStmt *Stmt = getStmtFor(&Inst))
+      return Stmt;
+  return nullptr;
+}
+
+ScopStmt *Scop::getLastStmtFor(BasicBlock *BB) const {
+  for (auto It = BB->rbegin(), ItE = BB->rend(); It != ItE; ++It)
+    if (ScopStmt *Stmt = getStmtFor(&*It))
+      return Stmt;
+  return nullptr;
+}
+
+ScopStmt *Scop::getStmtFor(BasicBlock *BB) const {
+  return getFirstStmtFor(BB);
+}
+
 ScopStmt *Scop::getStmtFor(RegionNode *RN) const {
   if (RN->isSubRegion())
     return getStmtFor(RN->getNodeAs<Region>());