Index: include/polly/ScopInfo.h =================================================================== --- include/polly/ScopInfo.h +++ include/polly/ScopInfo.h @@ -1317,6 +1317,9 @@ /// @brief A map from basic blocks to their domains. DenseMap DomainMap; + /// @brief Domains of error blocks before they were replaced by empty sets. + DenseMap ErrorDomainMap; + /// Constraints on parameters. isl_set *Context; @@ -1537,6 +1540,9 @@ /// @see isIgnored() void simplifySCoP(bool RemoveIgnoredStmts, DominatorTree &DT, LoopInfo &LI); + /// @brief Union of param contexts of error blocks that can reach @p Stmt. + __isl_give isl_set *getErrorCtxReachingStmt(ScopStmt &Stmt); + /// @brief Create equivalence classes for required invariant accesses. /// /// These classes will consolidate multiple required invariant loads from the Index: lib/Analysis/ScopInfo.cpp =================================================================== --- lib/Analysis/ScopInfo.cpp +++ lib/Analysis/ScopInfo.cpp @@ -2127,7 +2127,10 @@ auto *CurrentDomain = DomainMap[ErrorChildBlock]; auto *Empty = isl_set_empty(isl_set_get_space(CurrentDomain)); DomainMap[ErrorChildBlock] = Empty; - isl_set_free(CurrentDomain); + if (!ErrorDomainMap.count(ErrorChildBlock)) + ErrorDomainMap[ErrorChildBlock] = CurrentDomain; + else + isl_set_free(CurrentDomain); } }; @@ -2932,6 +2935,8 @@ for (auto It : DomainMap) isl_set_free(It.second); + for (auto It : ErrorDomainMap) + isl_set_free(It.second); // Free the alias groups for (MinMaxVectorPairTy &MinMaxAccessPair : MinMaxAliasGroups) { @@ -3044,10 +3049,35 @@ return nullptr; } +isl_set *Scop::getErrorCtxReachingStmt(ScopStmt &Stmt) { + auto *ErrorCtx = isl_set_empty(getParamSpace()); + + SmallPtrSet FinishedBlocks; + SmallVector RemainingBlocks; + RemainingBlocks.push_back(Stmt.getEntryBlock()); + while (!RemainingBlocks.empty()) { + auto *BB = RemainingBlocks.pop_back_val(); + if (!R.contains(BB) || !FinishedBlocks.insert(BB).second) + continue; + RemainingBlocks.append(pred_begin(BB), pred_end(BB)); + + auto It = ErrorDomainMap.find(BB); + if (It ==
ErrorDomainMap.end()) + continue; + + auto *ErrorDomain = isl_set_copy(It->second); + ErrorCtx = isl_set_union(ErrorCtx, isl_set_params(ErrorDomain)); + } + + return ErrorCtx; +} + void Scop::addInvariantLoads(ScopStmt &Stmt, MemoryAccessList &InvMAs) { - // Get the context under which the statement is executed. + // Get the context under which the statement is executed, excluding the + // parameter contexts of error blocks that can reach this statement. isl_set *DomainCtx = isl_set_params(Stmt.getDomain()); + DomainCtx = isl_set_subtract(DomainCtx, getErrorCtxReachingStmt(Stmt)); DomainCtx = isl_set_remove_redundancies(DomainCtx); DomainCtx = isl_set_detect_equalities(DomainCtx); DomainCtx = isl_set_coalesce(DomainCtx);