Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -481,61 +481,6 @@
   return Node;
 }
 
-/// Get the position of a dimension with a non-zero coefficient.
-///
-/// Check that isl constraint @p Constraint has only one non-zero
-/// coefficient for dimensions that have type @p DimType. If this is true,
-/// return the position of the dimension corresponding to the non-zero
-/// coefficient and negative value, otherwise.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimType    The type of the dimensions.
-/// @return           The position of the dimension in case the isl
-///                   constraint satisfies the requirements, a negative
-///                   value, otherwise.
-static int getMatMulConstraintDim(isl::constraint Constraint,
-                                  isl::dim DimType) {
-  int DimPos = -1;
-  auto LocalSpace = Constraint.get_local_space();
-  int LocalSpaceDimNum = LocalSpace.dim(DimType);
-  for (int i = 0; i < LocalSpaceDimNum; i++) {
-    auto Val = Constraint.get_coefficient_val(DimType, i);
-    if (Val.is_zero())
-      continue;
-    if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
-        (DimType == isl::dim::in && !Val.is_negone()))
-      return -1;
-    DimPos = i;
-  }
-  return DimPos;
-}
-
-/// Check the form of the isl constraint.
-///
-/// Check that the @p DimInPos input dimension of the isl constraint
-/// @p Constraint has a coefficient that is equal to negative one, the @p
-/// DimOutPos has a coefficient that is equal to one and others
-/// have coefficients equal to zero.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimInPos   The input dimension of the isl constraint.
-/// @param DimOutPos  The output dimension of the isl constraint.
-/// @return           isl_stat_ok in case the isl constraint satisfies
-///                   the requirements, isl_stat_error otherwise.
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
-                                          int &DimInPos, int &DimOutPos) {
-  auto Val = Constraint.get_constant_val();
-  if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
-    return isl_stat_error;
-  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
-  if (DimInPos < 0)
-    return isl_stat_error;
-  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
-  if (DimOutPos < 0)
-    return isl_stat_error;
-  return isl_stat_ok;
-}
-
 /// Permute the two dimensions of the isl map.
 ///
 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@@ -571,6 +516,55 @@
   return Map;
 }
 
+/// Match an access relation against one candidate MatMul operand access.
+///
+/// The candidate @p PossibleMatMul maps input dimension @p ExpectedFirstPos
+/// to the first and @p ExpectedSecondPos to the second output dimension.
+/// @p AccMap matches if it is equal to the candidate and the candidate's
+/// positions agree with @p FirstPos and @p SecondPos wherever those have
+/// already been assigned (i.e., are not -1). @p FirstPos and @p SecondPos
+/// are written only if the match succeeds, so a rejected candidate never
+/// leaves them partially updated.
+///
+/// @param AccMap            The access relation to be checked.
+/// @param PossibleMatMul    One of the six candidate access relations of the
+///                          MatMul pattern.
+/// @param ExpectedFirstPos  The input dimension that @p PossibleMatMul maps
+///                          to the first output dimension.
+/// @param ExpectedSecondPos The input dimension that @p PossibleMatMul maps
+///                          to the second output dimension.
+/// @param FirstPos          The index of the input dimension that is mapped
+///                          to the first output dimension, or -1 if it has
+///                          not been assigned yet.
+/// @param SecondPos         The index of the input dimension that is mapped
+///                          to the second output dimension, or -1 if it has
+///                          not been assigned yet.
+/// @return                  True in case @p AccMap matches the candidate and
+///                          the positions are consistent, false otherwise.
+static bool computeMatchedDimensions(isl::map AccMap, isl::map PossibleMatMul,
+                                     int ExpectedFirstPos,
+                                     int ExpectedSecondPos, int &FirstPos,
+                                     int &SecondPos) {
+  // Reject candidates that contradict positions fixed by earlier matches.
+  // Both positions are checked before either one is written, so the outputs
+  // stay untouched whenever the candidate is rejected.
+  if (FirstPos != -1 && FirstPos != ExpectedFirstPos)
+    return false;
+  if (SecondPos != -1 && SecondPos != ExpectedSecondPos)
+    return false;
+
+  // The caller restricts both maps to the statement domain, so AccMap only
+  // equals the candidate if it performs exactly this access on every point
+  // of the domain. A partial access that does not cover the whole domain
+  // differs from every candidate and is therefore rejected.
+  if (AccMap.is_equal(PossibleMatMul)) {
+    FirstPos = ExpectedFirstPos;
+    SecondPos = ExpectedSecondPos;
+    return true;
+  }
+  return false;
+}
+
 /// Check the form of the access relation.
 ///
 /// Check that the access relation @p AccMap has the form M[i][j], where i
@@ -583,30 +577,49 @@
 ///                  second output dimension.
+/// @param Domain    The domain of the statement executing the access; the
+///                  access must cover all of it to be accepted.
 /// @return          True in case @p AccMap has the expected form and false,
 ///                  otherwise.
-static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
-  int DimInPos[] = {FirstPos, SecondPos};
-  auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
-    auto Constraints = BasicMap.get_constraint_list();
-    if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
-      return isl::stat::error;
-    for (int i = 0; i < 2; i++) {
-      auto Constraint =
-          isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
-      int InPos, OutPos;
-      if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
-              isl_stat_error ||
-          OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
-        return isl::stat::error;
-      DimInPos[OutPos] = InPos;
-    }
-    return isl::stat::ok;
-  };
-  if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
-      DimInPos[1] < 0)
-    return false;
-  FirstPos = DimInPos[0];
-  SecondPos = DimInPos[1];
-  return true;
+static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
+                               int &SecondPos) {
+  isl::space Space = AccMap.get_space();
+
+  // Every candidate below accesses a two-dimensional array.
+  if (Space.dim(isl::dim::out) != 2)
+    return false;
+
+  // MatMul has the form:
+  //   for (i = 0; i < N; i++)
+  //     for (j = 0; j < M; j++)
+  //       for (k = 0; k < P; k++)
+  //         C[i, j] += A[i, k] * B[k, j]
+  //
+  // An operand may be indexed by any ordered pair of distinct dimensions
+  // among the first three input dimensions, which gives 3 * 2 = 6
+  // candidates; e.g. { Stmt[i0, i1, i2] -> MemRef_A[i2, i0] } corresponds
+  // to FirstPos = 2 and SecondPos = 0.
+  const int FirstDims[] = {0, 0, 1, 1, 2, 2};
+  const int SecondDims[] = {1, 2, 2, 0, 0, 1};
+  const int NumInDims = Space.dim(isl::dim::in);
+
+  // Compare the relations only on the statement domain, so that an access
+  // which does not cover the whole domain (a partial access) never equals
+  // a candidate and is rejected.
+  AccMap = AccMap.intersect_domain(Domain);
+  isl::map Universe = isl::map::universe(Space).intersect_domain(Domain);
+
+  for (int Idx = 0; Idx < 6; ++Idx) {
+    // Skip candidates that refer to input dimensions the statement lacks;
+    // equating those would be an out-of-range isl operation.
+    if (FirstDims[Idx] >= NumInDims || SecondDims[Idx] >= NumInDims)
+      continue;
+    isl::map PossibleMatMul =
+        Universe.equate(isl::dim::in, FirstDims[Idx], isl::dim::out, 0)
+            .equate(isl::dim::in, SecondDims[Idx], isl::dim::out, 1);
+    if (computeMatchedDimensions(AccMap, PossibleMatMul, FirstDims[Idx],
+                                 SecondDims[Idx], FirstPos, SecondPos))
+      return true;
+  }
+  return false;
 }
 
 /// Does the memory access represent a non-scalar operand of the matrix
@@ -625,18 +638,18 @@
   if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
     return false;
   auto AccMap = MemAccess->getLatestAccessRelation();
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
+  // Check the free slot first: isMatMulOperandAcc may assign MMI.i/j/k, so
+  // only run it for an operand slot that can still be filled.
+  if (!MMI.ReadFromC && isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j)) {
     MMI.ReadFromC = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (!MMI.A && isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k)) {
     MMI.A = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (!MMI.B && isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j)) {
     MMI.B = MemAccess;
     return true;
   }
@@ -752,8 +765,7 @@
     if (!MemAccessPtr->isWrite())
       return false;
     auto AccMap = MemAccessPtr->getLatestAccessRelation();
-    if (isl_map_n_basic_map(AccMap.get()) != 1 ||
-        !isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
+    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
       return false;
     MMI.WriteToC = MemAccessPtr;
     break;