Index: include/polly/ScheduleOptimizer.h
===================================================================
--- include/polly/ScheduleOptimizer.h
+++ include/polly/ScheduleOptimizer.h
@@ -244,12 +244,12 @@
   /// and interchanging.
   ///
   /// @param Node The schedule node to be modified.
-  /// @param Mr The parameter of the BLIS micro-kernel.
-  /// @param Nr The parameter of the BLIS micro-kernel.
+  /// @param MacroKernelParams Parameters of the BLIS macro-kernel.
   ///
   /// @see ScheduleTreeOptimizer::optimizeMatMulPattern
   static __isl_give isl_schedule_node *
-  createMacroKernel(__isl_take isl_schedule_node *Node, int Mr, int Nr);
+  createMacroKernel(__isl_take isl_schedule_node *Node,
+                    llvm::ArrayRef<int> MacroKernelParams);
 };
 
 #endif
Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -530,10 +530,118 @@
   return Node;
 }
 
-__isl_give isl_schedule_node *
-ScheduleTreeOptimizer::createMacroKernel(__isl_take isl_schedule_node *Node,
-                                         int Mr, int Nr) {
+__isl_give isl_schedule_node *ScheduleTreeOptimizer::createMacroKernel(
+    __isl_take isl_schedule_node *Node, llvm::ArrayRef<int> MacroKernelParams) {
   assert(isl_schedule_node_get_type(Node) == isl_schedule_node_band);
+  if (MacroKernelParams[0] == 1 && MacroKernelParams[1] == 1 &&
+      MacroKernelParams[2] == 1)
+    return Node;
+  Node = tileNode(Node, "1nd level tiling", MacroKernelParams, 1);
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = permuteBandNodeDimensions(Node, 1, 2);
+  return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
+}
+
+/// @brief Replace all specified access relations
+///
+/// Replace all access relations of memory accesses of the ScpStmt
+/// that are equal to OldMemAccessRel. New memory accesses have a form
+/// of NewMemAccessRel.
+///
+/// It should be noted that in the current implementation
+/// of replaceMemAccRelations identifiers of ranges are not taken
+/// into account. That is why an access relation could be replaced
+/// even if an identifier of its range differs from the one that
+/// is assigned to a range of OldMemAccessRel. However, new access
+/// relations preserve identifiers of replaced ones.
+///
+/// @param ScpStmt The SCoP statement that contains memory accesses
+///                under consideration.
+/// @param OldMemAccessRel The map, which describes access relations
+///                        that should be replaced.
+/// @param NewMemAccessRel The map, which describes new access relations.
+static void replaceMemAccRelations(ScopStmt *ScpStmt,
+                                   __isl_take isl_map *OldMemAccessRel,
+                                   __isl_take isl_map *NewMemAccessRel) {
+  isl_map *AccessRelation;
+  for (auto MemA = ScpStmt->begin(); MemA != ScpStmt->end(); MemA++) {
+    AccessRelation = (*MemA)->getAccessRelation();
+    auto *MemId = isl_map_get_tuple_id(AccessRelation, isl_dim_out);
+    OldMemAccessRel = isl_map_set_tuple_id(OldMemAccessRel, isl_dim_out, MemId);
+    if (isl_map_is_equal(OldMemAccessRel, AccessRelation)) {
+      MemId = isl_map_get_tuple_id(AccessRelation, isl_dim_out);
+      auto *MemAccessRel = isl_map_set_tuple_id(isl_map_copy(NewMemAccessRel),
+                                                isl_dim_out, MemId);
+      (*MemA)->setNewAccessRelation(MemAccessRel);
+    }
+    isl_map_free(AccessRelation);
+  }
+  isl_map_free(OldMemAccessRel);
+  isl_map_free(NewMemAccessRel);
+}
+
+/// @brief Replace all access relations passed as string variables
+///
+/// An overloaded variant of replaceMemAccRelations, which replaces all access
+/// relations of memory accesses of the ScpStmt that are equal
+/// to OldMemAccessRel with NewMemAccessRel. MapToOldIndVar is used to translate
+/// NewMemAccessRel into the representation, which uses original induction
+/// variables.
+///
+/// @param ScpStmt The SCoP statement that contains memory accesses
+///                under consideration.
+/// @param OldMemAccessRelStr The map, which describes access relations
+///                           that should be replaced.
+/// @param NewMemAccessRelStr The map, which describes new access relations.
+/// @param MapToOldIndVar The relation, which maps induction variables produced
+///                       by schedule transformations to the original ones.
+static void replaceMemAccRelations(ScopStmt *ScpStmt,
+                                   const char *OldMemAccessRelStr,
+                                   const char *NewMemAccessRelStr,
+                                   __isl_take isl_map *MapToOldIndVar) {
+  auto *Ctx = ScpStmt->getIslCtx();
+  auto *OldMemAccessRel = isl_map_read_from_str(Ctx, OldMemAccessRelStr);
+  auto *NewMemAccessRel = isl_map_read_from_str(Ctx, NewMemAccessRelStr);
+  NewMemAccessRel = isl_map_apply_range(MapToOldIndVar, NewMemAccessRel);
+  OldMemAccessRel =
+      isl_map_set_tuple_id(OldMemAccessRel, isl_dim_in, ScpStmt->getDomainId());
+  replaceMemAccRelations(ScpStmt, OldMemAccessRel, NewMemAccessRel);
+}
+
+/// @brief Get parameters of the BLIS micro kernel
+///
+/// The description of the utilized algorithm can be found in
+/// (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
+///
+/// @param TTI Target Transform Info.
+/// @param MicroKernelParams Parameters of the BLIS micro kernel, which
+///                          are to be computed.
+static void getMicroKernelParams(const llvm::TargetTransformInfo *TTI,
+                                 std::array<int, 2> &MicroKernelParams) {
+  assert(TTI && "The target transform info should be provided.");
+  // Nvec - Number of double-precision floating-point numbers that can be held
+  // by a vector register. Use 2 by default.
+  auto Nvec = TTI->getRegisterBitWidth(true) / 64;
+  if (Nvec == 0)
+    Nvec = 2;
+  int Nr =
+      ceil(sqrt(Nvec * LatencyVectorFma * ThrougputVectorFma) / Nvec) * Nvec;
+  int Mr = ceil(Nvec * LatencyVectorFma * ThrougputVectorFma / Nr);
+  MicroKernelParams = {Mr, Nr};
+}
+
+/// @brief Get parameters of the BLIS macro kernel
+///
+/// The description of the utilized algorithm can be found in
+/// (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
+///
+/// @param MicroKernelParams Parameters of the BLIS micro kernel.
+/// @param MacroKernelParams Parameters of the BLIS macro kernel, which
+///                          are to be computed.
+static void getMacroKernelParams(const std::array<int, 2> &MicroKernelParams,
+                                 std::array<int, 3> &MacroKernelParams) {
+  int Mr = MicroKernelParams[0];
+  int Nr = MicroKernelParams[1];
   // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
   // it requires information about the first two levels of a cache to determine
   // all the parameters of a macro-kernel. It also checks that an associativity
@@ -543,7 +651,7 @@
         CacheLevelAssociativityDegrees.size() >= 2 && CacheLevelSizes[0] > 0 &&
         CacheLevelSizes[1] > 0 && CacheLevelAssociativityDegrees[0] > 2 &&
         CacheLevelAssociativityDegrees[1] > 2))
-    return Node;
+    return;
   int Cbr = floor((CacheLevelAssociativityDegrees[0] - 1) /
                   (1 + static_cast<double>(Mr) / Nr));
   int Kc =
@@ -556,28 +664,73 @@
                CacheLevelSizes[1];
   int Mc = floor(Mr / Cac);
   int Nc = floor((Nr * (CacheLevelAssociativityDegrees[1] - 2)) / Cbc);
-  int MacroKernelParams[] = {Mc, Nc, Kc};
-  Node = tileNode(Node, "1nd level tiling", MacroKernelParams, 1);
-  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
-  Node = permuteBandNodeDimensions(Node, 1, 2);
-  return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
+  MacroKernelParams = {Mc, Nc, Kc};
+}
+
+/// @brief Perform the packing transformation
+///
+/// The description of the utilized algorithm can be found in
+/// (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
+///
+/// @param ScpStmt The SCoP statement that contains memory accesses
+///                under consideration.
+/// @param MapToOldIndVar The relation, which maps induction variables produced
+///                       by schedule transformations to the original ones.
+/// @param MicroKernelParams Parameters of the BLIS micro kernel.
+/// @param MacroKernelParams Parameters of the BLIS macro kernel.
+static void
+optimizeDataLayoutMatrMulPattern(ScopStmt *ScpStmt,
+                                 __isl_take isl_map *MapToOldIndVar,
+                                 const std::array<int, 2> &MicroKernelParams,
+                                 const std::array<int, 3> &MacroKernelParams) {
+  // After replacement access relations should have the following form:
+  // {[o0, o1, o2, o3, o4, o5, o6, o7, o8] -> [Mr * (o5 + Kc * o3) + o6]}
+  std::string NewAccRelationStr =
+      "{[o0, o1, o2, o3, o4, o5, o6, o7, o8] -> [" +
+      std::to_string(MicroKernelParams[0]) + "o5 + " +
+      std::to_string(MicroKernelParams[0] * MacroKernelParams[2]) + "o3 + o6]}";
+  replaceMemAccRelations(ScpStmt, "{[i0, i1, i2] -> [i0, i2]}",
+                         NewAccRelationStr.c_str(),
+                         isl_map_copy(MapToOldIndVar));
+  // After replacement access relations should have the following form:
+  // {[o0, o1, o2, o3, o4, o5, o6, o7, o8] -> [Nr * (o5 + Kc * o3) + o7]}
+  NewAccRelationStr =
+      "{[o0, o1, o2, o3, o4, o5, o6, o7, o8] -> [" +
+      std::to_string(MicroKernelParams[1]) + "o5 + " +
+      std::to_string(MicroKernelParams[1] * MacroKernelParams[2]) + "o3 + o7]}";
+  replaceMemAccRelations(ScpStmt, "{[i0, i1, i2] -> [i2, i1]}",
+                         NewAccRelationStr.c_str(), MapToOldIndVar);
 }
 
 __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
     __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
   assert(TTI && "The target transform info should be provided.");
-  // Nvec - Number of double-precision floating-point numbers that can be hold
-  // by a vector register. Use 2 by default.
-  auto Nvec = TTI->getRegisterBitWidth(true) / 64;
-  if (Nvec == 0)
-    Nvec = 2;
-  int Nr =
-      ceil(sqrt(Nvec * LatencyVectorFma * ThrougputVectorFma) / Nvec) * Nvec;
-  int Mr = ceil(Nvec * LatencyVectorFma * ThrougputVectorFma / Nr);
-  Node = createMacroKernel(Node, Mr, Nr);
-  // Get a micro-kernel.
-  int MicroKernelParams[] = {Mr, Nr};
-  Node = applyRegisterTiling(Node, MicroKernelParams, 1);
+  std::array<int, 2> MicroKernelParams = {1, 1};
+  std::array<int, 3> MacroKernelParams = {1, 1, 1};
+  getMicroKernelParams(TTI, MicroKernelParams);
+  getMacroKernelParams(MicroKernelParams, MacroKernelParams);
+  Node =
+      createMacroKernel(Node, llvm::ArrayRef<int>(MacroKernelParams.data(), 3));
+  Node = applyRegisterTiling(
+      Node, llvm::ArrayRef<int>(MicroKernelParams.data(), 2), 1);
+  // No macro-kernel was created for parameters {1, 1, 1}; skip the packing.
+  if (MacroKernelParams[0] == 1 && MacroKernelParams[1] == 1 &&
+      MacroKernelParams[2] == 1)
+    return Node;
+  // Relate statement iterations to at most the last 9 prefix-schedule dims.
+  auto *Child = isl_schedule_node_get_child(Node, 0);
+  auto *UnMapOldIndVar = isl_schedule_node_get_prefix_schedule_union_map(Child);
+  isl_schedule_node_free(Child);
+  auto *MapToOldIndVar = isl_map_from_union_map(UnMapOldIndVar);
+  if (isl_map_dim(MapToOldIndVar, isl_dim_out) > 9)
+    MapToOldIndVar =
+        isl_map_project_out(MapToOldIndVar, isl_dim_out, 0,
+                            isl_map_dim(MapToOldIndVar, isl_dim_out) - 9);
+  auto InputDimsId = isl_map_get_tuple_id(MapToOldIndVar, isl_dim_in);
+  optimizeDataLayoutMatrMulPattern(
+      static_cast<ScopStmt *>(isl_id_get_user(InputDimsId)), MapToOldIndVar,
+      MicroKernelParams, MacroKernelParams);
+  isl_id_free(InputDimsId);
   return Node;
 }
 