Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -2561,6 +2561,10 @@
   /// none.
   ScopStmt *getStmtFor(BasicBlock *BB) const;
 
+  /// Return the list of ScopStmts that represent the given @p BB, or an empty
+  /// list if there is none.
+  std::vector<ScopStmt *> getStmtListFor(BasicBlock *BB) const;
+
   /// Return the last statement representing @p BB.
   ///
   /// Of the sequence of statements that represent a @p BB, this is the last one
Index: lib/Analysis/PolyhedralInfo.cpp
===================================================================
--- lib/Analysis/PolyhedralInfo.cpp
+++ lib/Analysis/PolyhedralInfo.cpp
@@ -127,22 +127,23 @@
   assert(CurrDim >= 0 && "Loop in region should have at least depth one");
 
   for (auto *BB : L->blocks()) {
-    auto *SS = S->getStmtFor(BB);
-    if (!SS)
-      continue;
-
-    unsigned int MaxDim = SS->getNumIterators();
-    DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n");
-    auto *ScheduleMap = SS->getSchedule();
-    assert(ScheduleMap &&
-           "Schedules that contain extension nodes require special handling.");
-
-    ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1,
-                                      MaxDim - CurrDim - 1);
-    ScheduleMap =
-        isl_map_set_tuple_id(ScheduleMap, isl_dim_in, SS->getDomainId());
-    Schedule =
-        isl_union_map_union(Schedule, isl_union_map_from_map(ScheduleMap));
+    // A BB may be represented by a sequence of statements; the schedule of
+    // every one of them has to be included.
+    for (auto *SS : S->getStmtListFor(BB)) {
+      unsigned int MaxDim = SS->getNumIterators();
+      DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n");
+      auto *ScheduleMap = SS->getSchedule();
+      assert(
+          ScheduleMap &&
+          "Schedules that contain extension nodes require special handling.");
+
+      ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1,
+                                        MaxDim - CurrDim - 1);
+      ScheduleMap =
+          isl_map_set_tuple_id(ScheduleMap, isl_dim_in, SS->getDomainId());
+      Schedule =
+          isl_union_map_union(Schedule, isl_union_map_from_map(ScheduleMap));
+    }
   }
 
   Schedule = isl_union_map_coalesce(Schedule);
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -4909,6 +4909,13 @@
   return StmtMapIt->second.front();
 }
 
+std::vector<ScopStmt *> Scop::getStmtListFor(BasicBlock *BB) const {
+  auto StmtMapIt = StmtMap.find(BB);
+  if (StmtMapIt == StmtMap.end())
+    return {};
+  return StmtMapIt->second;
+}
+
 ScopStmt *Scop::getStmtFor(RegionNode *RN) const {
   if (RN->isSubRegion())
     return getStmtFor(RN->getNodeAs<Region>());