Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -1642,8 +1642,8 @@
   /// delete the last object that creates isl objects with the context.
   std::shared_ptr<isl_ctx> IslCtx;
 
-  /// A map from basic blocks to SCoP statements.
-  DenseMap<BasicBlock *, ScopStmt *> StmtMap;
+  /// A map from instructions to SCoP statements.
+  DenseMap<Instruction *, ScopStmt *> StmtMap;
 
   /// A map from basic blocks to their domains.
   DenseMap<BasicBlock *, isl::set> DomainMap;
@@ -2507,24 +2507,13 @@
   /// Get an isl string representing the invalid context.
   std::string getInvalidContextStr() const;
 
-  /// Return the ScopStmt for the given @p BB or nullptr if there is
-  /// none.
-  ScopStmt *getStmtFor(BasicBlock *BB) const;
-
-  /// Return the ScopStmt that represents the Region @p R, or nullptr if
-  /// it is not represented by any statement in this Scop.
-  ScopStmt *getStmtFor(Region *R) const;
-
-  /// Return the ScopStmt that represents @p RN; can return nullptr if
-  /// the RegionNode is not within the SCoP or has been removed due to
-  /// simplifications.
-  ScopStmt *getStmtFor(RegionNode *RN) const;
-
   /// Return the ScopStmt an instruction belongs to, or nullptr if it
   /// does not belong to any statement in this Scop.
-  ScopStmt *getStmtFor(Instruction *Inst) const {
-    return getStmtFor(Inst->getParent());
-  }
+  ///
+  /// Only instructions registered when a statement is created are mapped:
+  /// for a basic-block statement those in its instruction list, for a
+  /// region statement every instruction of every block in the region.
+  ScopStmt *getStmtFor(Instruction *Inst) const;
 
   /// Return the number of statements in the SCoP.
   size_t getSize() const { return Stmts.size(); }
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -1324,7 +1324,6 @@
     MAL.emplace_front(Access);
   } else if (Access->isValueKind() && Access->isWrite()) {
     Instruction *AccessVal = cast<Instruction>(Access->getAccessValue());
-    assert(Parent.getStmtFor(AccessVal) == this);
     assert(!ValueWrites.lookup(AccessVal));
 
     ValueWrites[AccessVal] = Access;
@@ -3693,9 +3692,11 @@
     // Remove the statement because it is unnecessary.
     if (Stmt.isRegionStmt())
       for (BasicBlock *BB : Stmt.getRegion()->blocks())
-        StmtMap.erase(BB);
+        for (Instruction &Inst : *BB)
+          StmtMap.erase(&Inst);
     else
-      StmtMap.erase(Stmt.getBasicBlock());
+      for (Instruction &Inst : *Stmt.getBasicBlock())
+        StmtMap.erase(&Inst);
 
     StmtIt = Stmts.erase(StmtIt);
   }
@@ -4679,7 +4680,8 @@
   assert(BB && "Unexpected nullptr!");
   Stmts.emplace_back(*this, *BB, SurroundingLoop, Instructions);
   auto *Stmt = &Stmts.back();
-  StmtMap[BB] = Stmt;
+  for (Instruction *Inst : Instructions)
+    StmtMap[Inst] = Stmt;
 }
 
 void Scop::addScopStmt(Region *R, Loop *SurroundingLoop) {
@@ -4687,7 +4689,8 @@
   Stmts.emplace_back(*this, *R, SurroundingLoop);
   auto *Stmt = &Stmts.back();
   for (BasicBlock *BB : R->blocks())
-    StmtMap[BB] = Stmt;
+    for (Instruction &Inst : *BB)
+      StmtMap[&Inst] = Stmt;
 }
 
 ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
@@ -4835,25 +4838,13 @@
   }
 }
 
-ScopStmt *Scop::getStmtFor(BasicBlock *BB) const {
-  auto StmtMapIt = StmtMap.find(BB);
+ScopStmt *Scop::getStmtFor(Instruction *Inst) const {
+  auto StmtMapIt = StmtMap.find(Inst);
   if (StmtMapIt == StmtMap.end())
     return nullptr;
   return StmtMapIt->second;
 }
 
-ScopStmt *Scop::getStmtFor(RegionNode *RN) const {
-  if (RN->isSubRegion())
-    return getStmtFor(RN->getNodeAs<Region>());
-  return getStmtFor(RN->getNodeAs<BasicBlock>());
-}
-
-ScopStmt *Scop::getStmtFor(Region *R) const {
-  ScopStmt *Stmt = getStmtFor(R->getEntry());
-  assert(!Stmt || Stmt->getRegion() == R);
-  return Stmt;
-}
-
 int Scop::getRelativeLoopDepth(const Loop *L) const {
   if (!L || !R.contains(L))
     return -1;