Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -481,61 +481,6 @@
   return Node;
 }
 
-/// Get the position of a dimension with a non-zero coefficient.
-///
-/// Check that isl constraint @p Constraint has only one non-zero
-/// coefficient for dimensions that have type @p DimType. If this is true,
-/// return the position of the dimension corresponding to the non-zero
-/// coefficient and negative value, otherwise.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimType    The type of the dimensions.
-/// @return           The position of the dimension in case the isl
-///                   constraint satisfies the requirements, a negative
-///                   value, otherwise.
-static int getMatMulConstraintDim(isl::constraint Constraint,
-                                  isl::dim DimType) {
-  int DimPos = -1;
-  auto LocalSpace = Constraint.get_local_space();
-  int LocalSpaceDimNum = LocalSpace.dim(DimType);
-  for (int i = 0; i < LocalSpaceDimNum; i++) {
-    auto Val = Constraint.get_coefficient_val(DimType, i);
-    if (Val.is_zero())
-      continue;
-    if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
-        (DimType == isl::dim::in && !Val.is_negone()))
-      return -1;
-    DimPos = i;
-  }
-  return DimPos;
-}
-
-/// Check the form of the isl constraint.
-///
-/// Check that the @p DimInPos input dimension of the isl constraint
-/// @p Constraint has a coefficient that is equal to negative one, the @p
-/// DimOutPos has a coefficient that is equal to one and others
-/// have coefficients equal to zero.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimInPos   The input dimension of the isl constraint.
-/// @param DimOutPos  The output dimension of the isl constraint.
-/// @return           isl_stat_ok in case the isl constraint satisfies
-///                   the requirements, isl_stat_error otherwise.
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
-                                          int &DimInPos, int &DimOutPos) {
-  auto Val = Constraint.get_constant_val();
-  if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
-    return isl_stat_error;
-  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
-  if (DimInPos < 0)
-    return isl_stat_error;
-  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
-  if (DimOutPos < 0)
-    return isl_stat_error;
-  return isl_stat_ok;
-}
-
 /// Permute the two dimensions of the isl map.
 ///
 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@@ -583,30 +528,54 @@
 ///                  second output dimension.
+/// @param Domain    The domain of the statement that performs the access;
+///                  the access is compared only within this domain.
 /// @return          True in case @p AccMap has the expected form and false,
 ///                  otherwise.
-static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
-  int DimInPos[] = {FirstPos, SecondPos};
-  auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
-    auto Constraints = BasicMap.get_constraint_list();
-    if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
-      return isl::stat::error;
-    for (int i = 0; i < 2; i++) {
-      auto Constraint =
-          isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
-      int InPos, OutPos;
-      if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
-              isl_stat_error ||
-          OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
-        return isl::stat::error;
-      DimInPos[OutPos] = InPos;
-    }
-    return isl::stat::ok;
-  };
-  if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
-      DimInPos[1] < 0)
+static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
+                               int &SecondPos) {
+  isl::space Space = AccMap.get_space();
+  isl::map Universe = isl::map::universe(Space);
+
+  // The permutation tables below only address input dimensions 0..2, so
+  // reject anything that cannot be an access inside a three-deep loop nest.
+  if (Space.dim(isl::dim::out) != 2 || Space.dim(isl::dim::in) < 3)
     return false;
-  FirstPos = DimInPos[0];
-  SecondPos = DimInPos[1];
-  return true;
+
+  // Restrict the access to the statement domain once; this is loop-invariant.
+  AccMap = AccMap.intersect_domain(Domain);
+
+  // MatMul has the form:
+  // for (i = 0; i < N; i++)
+  //   for (j = 0; j < M; j++)
+  //     for (k = 0; k < P; k++)
+  //       C[i, j] += A[i, k] * B[k, j]
+  //
+  // Permutation of three outer loops: 3! = 6 possibilities.
+  int FirstDims[] = {0, 0, 1, 1, 2, 2};
+  int SecondDims[] = {1, 2, 2, 0, 0, 1};
+  for (int i = 0; i < 6; i += 1) {
+    auto PossibleMatMul =
+        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
+            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1)
+            .intersect_domain(Domain);
+
+    // Both maps are restricted to the statement domain at this point. If
+    // they differ, the access does not cover the entire domain for this
+    // permutation (e.g. it is a partial write), and partial accesses must
+    // be rejected.
+    if (!AccMap.is_equal(PossibleMatMul))
+      continue;
+
+    // Commit the positions only when both agree with what the caller has
+    // already fixed, so a failed match never clobbers FirstPos/SecondPos.
+    if ((FirstPos != -1 && FirstPos != FirstDims[i]) ||
+        (SecondPos != -1 && SecondPos != SecondDims[i]))
+      continue;
+    FirstPos = FirstDims[i];
+    SecondPos = SecondDims[i];
+    return true;
+  }
+
+  return false;
 }
 
 /// Does the memory access represent a non-scalar operand of the matrix
@@ -625,18 +594,16 @@
   if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
     return false;
   auto AccMap = MemAccess->getLatestAccessRelation();
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
     MMI.ReadFromC = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
     MMI.A = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
     MMI.B = MemAccess;
     return true;
   }
@@ -752,8 +719,7 @@
     if (!MemAccessPtr->isWrite())
       return false;
     auto AccMap = MemAccessPtr->getLatestAccessRelation();
-    if (isl_map_n_basic_map(AccMap.get()) != 1 ||
-        !isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
+    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
       return false;
     MMI.WriteToC = MemAccessPtr;
     break;