Changeset View
Changeset View
Standalone View
Standalone View
lib/Transform/ScheduleOptimizer.cpp
Show First 20 Lines • Show All 477 Lines • ▼ Show 20 Lines | for (int i = Dims - 1; i >= 0; i--) | ||||
if (Node.band_member_get_coincident(i)) { | if (Node.band_member_get_coincident(i)) { | ||||
Node = prevectSchedBand(Node, i, PrevectorWidth); | Node = prevectSchedBand(Node, i, PrevectorWidth); | ||||
break; | break; | ||||
} | } | ||||
return Node; | return Node; | ||||
} | } | ||||
/// Get the position of a dimension with a non-zero coefficient. | |||||
/// | |||||
/// Check that isl constraint @p Constraint has only one non-zero | |||||
/// coefficient for dimensions that have type @p DimType. If this is true, | |||||
/// return the position of the dimension corresponding to the non-zero | |||||
/// coefficient and negative value, otherwise. | |||||
/// | |||||
/// @param Constraint The isl constraint to be checked. | |||||
/// @param DimType The type of the dimensions. | |||||
/// @return The position of the dimension in case the isl | |||||
/// constraint satisfies the requirements, a negative | |||||
/// value, otherwise. | |||||
static int getMatMulConstraintDim(isl::constraint Constraint, | |||||
isl::dim DimType) { | |||||
int DimPos = -1; | |||||
auto LocalSpace = Constraint.get_local_space(); | |||||
int LocalSpaceDimNum = LocalSpace.dim(DimType); | |||||
for (int i = 0; i < LocalSpaceDimNum; i++) { | |||||
auto Val = Constraint.get_coefficient_val(DimType, i); | |||||
if (Val.is_zero()) | |||||
continue; | |||||
if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) || | |||||
(DimType == isl::dim::in && !Val.is_negone())) | |||||
return -1; | |||||
DimPos = i; | |||||
} | |||||
return DimPos; | |||||
} | |||||
/// Check the form of the isl constraint. | |||||
/// | |||||
/// Check that the @p DimInPos input dimension of the isl constraint | |||||
/// @p Constraint has a coefficient that is equal to negative one, the @p | |||||
/// DimOutPos has a coefficient that is equal to one and others | |||||
/// have coefficients equal to zero. | |||||
/// | |||||
/// @param Constraint The isl constraint to be checked. | |||||
/// @param DimInPos The input dimension of the isl constraint. | |||||
/// @param DimOutPos The output dimension of the isl constraint. | |||||
/// @return isl_stat_ok in case the isl constraint satisfies | |||||
/// the requirements, isl_stat_error otherwise. | |||||
static isl_stat isMatMulOperandConstraint(isl::constraint Constraint, | |||||
int &DimInPos, int &DimOutPos) { | |||||
auto Val = Constraint.get_constant_val(); | |||||
if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero()) | |||||
return isl_stat_error; | |||||
DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in); | |||||
if (DimInPos < 0) | |||||
return isl_stat_error; | |||||
DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out); | |||||
if (DimOutPos < 0) | |||||
return isl_stat_error; | |||||
return isl_stat_ok; | |||||
} | |||||
/// Permute the two dimensions of the isl map. | /// Permute the two dimensions of the isl map. | ||||
/// | /// | ||||
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that | /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that | ||||
/// have type @p DimType. | /// have type @p DimType. | ||||
/// | /// | ||||
/// @param Map The isl map to be modified. | /// @param Map The isl map to be modified. | ||||
/// @param DimType The type of the dimensions. | /// @param DimType The type of the dimensions. | ||||
/// @param DstPos The first dimension. | /// @param DstPos The first dimension. | ||||
Show All 31 Lines | |||||
/// | /// | ||||
/// @param AccMap The access relation to be checked. | /// @param AccMap The access relation to be checked. | ||||
/// @param FirstPos The index of the input dimension that is mapped to | /// @param FirstPos The index of the input dimension that is mapped to | ||||
/// the first output dimension. | /// the first output dimension. | ||||
/// @param SecondPos The index of the input dimension that is mapped to the | /// @param SecondPos The index of the input dimension that is mapped to the | ||||
/// second output dimension. | /// second output dimension. | ||||
/// @return True in case @p AccMap has the expected form and false, | /// @return True in case @p AccMap has the expected form and false, | ||||
/// otherwise. | /// otherwise. | ||||
static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) { | static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, | ||||
int DimInPos[] = {FirstPos, SecondPos}; | int &SecondPos) { | ||||
auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat { | |||||
auto Constraints = BasicMap.get_constraint_list(); | isl::space Space = AccMap.get_space(); | ||||
if (isl_constraint_list_n_constraint(Constraints.get()) != 2) | isl::map Universe = isl::map::universe(Space); | ||||
return isl::stat::error; | |||||
for (int i = 0; i < 2; i++) { | if (Space.dim(isl::dim::out) != 2) | ||||
auto Constraint = | |||||
isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i)); | |||||
int InPos, OutPos; | |||||
if (isMatMulOperandConstraint(Constraint, InPos, OutPos) == | |||||
isl_stat_error || | |||||
OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos)) | |||||
return isl::stat::error; | |||||
DimInPos[OutPos] = InPos; | |||||
} | |||||
return isl::stat::ok; | |||||
}; | |||||
if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 || | |||||
DimInPos[1] < 0) | |||||
return false; | return false; | ||||
FirstPos = DimInPos[0]; | |||||
SecondPos = DimInPos[1]; | // MatMul has the form: | ||||
// for (i = 0; i < N; i++) | |||||
// for (j = 0; j < M; j++) | |||||
// for (k = 0; k < P; k++) | |||||
// C[i, j] += A[i, k] * B[k, j] | |||||
// | |||||
// Permutation of three outer loops: 3! = 6 possibilities. | |||||
int FirstDims[] = {0, 0, 1, 1, 2, 2}; | |||||
int SecondDims[] = {1, 2, 2, 0, 0, 1}; | |||||
for (int i = 0; i < 6; i += 1) { | |||||
auto PossibleMatMul = | |||||
Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) | |||||
.equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); | |||||
AccMap = AccMap.intersect_domain(Domain); | |||||
PossibleMatMul = PossibleMatMul.intersect_domain(Domain); | |||||
// If AccMap spans entire domain (Non-partial write), | |||||
// compute FirstPos and SecondPos. | |||||
// If AccMap != PossibleMatMul here (the two maps have been gisted at | |||||
// this point), it means that the writes are not complete, or in other | |||||
// words, it is a Partial write and Partial writes must be rejected. | |||||
if (AccMap.is_equal(PossibleMatMul)) { | |||||
if (FirstPos != -1 && FirstPos != FirstDims[i]) | |||||
continue; | |||||
FirstPos = FirstDims[i]; | |||||
if (SecondPos != -1 && SecondPos != SecondDims[i]) | |||||
continue; | |||||
SecondPos = SecondDims[i]; | |||||
return true; | return true; | ||||
} | } | ||||
} | |||||
return false; | |||||
} | |||||
/// Does the memory access represent a non-scalar operand of the matrix | /// Does the memory access represent a non-scalar operand of the matrix | ||||
/// multiplication. | /// multiplication. | ||||
/// | /// | ||||
/// Check that the memory access @p MemAccess is the read access to a non-scalar | /// Check that the memory access @p MemAccess is the read access to a non-scalar | ||||
/// operand of the matrix multiplication or its result. | /// operand of the matrix multiplication or its result. | ||||
/// | /// | ||||
/// @param MemAccess The memory access to be checked. | /// @param MemAccess The memory access to be checked. | ||||
/// @param MMI Parameters of the matrix multiplication operands. | /// @param MMI Parameters of the matrix multiplication operands. | ||||
/// @return True in case the memory access represents the read access | /// @return True in case the memory access represents the read access | ||||
/// to a non-scalar operand of the matrix multiplication and | /// to a non-scalar operand of the matrix multiplication and | ||||
/// false, otherwise. | /// false, otherwise. | ||||
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, | static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, | ||||
MatMulInfoTy &MMI) { | MatMulInfoTy &MMI) { | ||||
if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) | if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) | ||||
return false; | return false; | ||||
auto AccMap = MemAccess->getLatestAccessRelation(); | auto AccMap = MemAccess->getLatestAccessRelation(); | ||||
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC && | isl::set StmtDomain = MemAccess->getStatement()->getDomain(); | ||||
isl_map_n_basic_map(AccMap.get()) == 1) { | if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { | ||||
MMI.ReadFromC = MemAccess; | MMI.ReadFromC = MemAccess; | ||||
return true; | return true; | ||||
} | } | ||||
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A && | if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { | ||||
isl_map_n_basic_map(AccMap.get()) == 1) { | |||||
MMI.A = MemAccess; | MMI.A = MemAccess; | ||||
return true; | return true; | ||||
} | } | ||||
if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B && | if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { | ||||
isl_map_n_basic_map(AccMap.get()) == 1) { | |||||
MMI.B = MemAccess; | MMI.B = MemAccess; | ||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
/// Check accesses to operands of the matrix multiplication. | /// Check accesses to operands of the matrix multiplication. | ||||
/// | /// | ||||
▲ Show 20 Lines • Show All 103 Lines • ▼ Show 20 Lines | static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D, | ||||
auto Accesses = getAccessesInOrder(*Stmt); | auto Accesses = getAccessesInOrder(*Stmt); | ||||
for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) { | for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) { | ||||
auto *MemAccessPtr = *MemA; | auto *MemAccessPtr = *MemA; | ||||
if (!MemAccessPtr->isLatestArrayKind()) | if (!MemAccessPtr->isLatestArrayKind()) | ||||
continue; | continue; | ||||
if (!MemAccessPtr->isWrite()) | if (!MemAccessPtr->isWrite()) | ||||
return false; | return false; | ||||
auto AccMap = MemAccessPtr->getLatestAccessRelation(); | auto AccMap = MemAccessPtr->getLatestAccessRelation(); | ||||
if (isl_map_n_basic_map(AccMap.get()) != 1 || | if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j)) | ||||
!isMatMulOperandAcc(AccMap, MMI.i, MMI.j)) | |||||
return false; | return false; | ||||
MMI.WriteToC = MemAccessPtr; | MMI.WriteToC = MemAccessPtr; | ||||
break; | break; | ||||
} | } | ||||
if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k)) | if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k)) | ||||
return false; | return false; | ||||
▲ Show 20 Lines • Show All 761 Lines • Show Last 20 Lines |