Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -596,8 +596,8 @@ /// second output dimension. /// @return True in case @p AccMap has the expected form and false, /// otherwise. -static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, - int &SecondPos) { +static bool isMatMulOperandAcc(isl::set Domain, isl::set Context, + isl::map AccMap, int &FirstPos, int &SecondPos) { isl::space Space = AccMap.get_space(); isl::map Universe = isl::map::universe(Space); @@ -618,8 +618,9 @@ Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); - AccMap = AccMap.intersect_domain(Domain); - PossibleMatMul = PossibleMatMul.intersect_domain(Domain); + AccMap = AccMap.intersect_domain(Domain).intersect_params(Context); + PossibleMatMul = PossibleMatMul.intersect_domain(Domain).intersect_params(Context); + // If AccMap spans entire domain (Non-partial write), // compute FirstPos and SecondPos. @@ -657,15 +658,17 @@ return false; auto AccMap = MemAccess->getLatestAccessRelation(); isl::set StmtDomain = MemAccess->getStatement()->getDomain(); - if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { + isl::set Context = MemAccess->getStatement()->getParent()->getContext(); + if (isMatMulOperandAcc(StmtDomain, Context, AccMap, MMI.i, MMI.j) && + !MMI.ReadFromC) { MMI.ReadFromC = MemAccess; return true; } - if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { + if (isMatMulOperandAcc(StmtDomain, Context, AccMap, MMI.i, MMI.k) && !MMI.A) { MMI.A = MemAccess; return true; } - if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { + if (isMatMulOperandAcc(StmtDomain, Context, AccMap, MMI.k, MMI.j) && !MMI.B) { MMI.B = MemAccess; return true; } @@ -785,7 +788,8 @@ if (!MemAccessPtr->isWrite()) return false; auto AccMap = MemAccessPtr->getLatestAccessRelation(); - if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j)) + if (!isMatMulOperandAcc(Stmt->getDomain(), Stmt->getParent()->getContext(), + AccMap, MMI.i, MMI.j)) return false; MMI.WriteToC = MemAccessPtr; break; @@ -1152,6 +1156,7 @@ ExtMap = ExtMap.intersect_range(Domain); ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId()); Node = createExtensionNode(Node, ExtMap); + return Node.child(0).child(0).child(0).child(0).child(0); }