Index: include/polly/CodeGen/IslExprBuilder.h =================================================================== --- include/polly/CodeGen/IslExprBuilder.h +++ include/polly/CodeGen/IslExprBuilder.h @@ -166,6 +166,18 @@ /// was enabled. llvm::Value *getOverflowState() const; + /// @brief Create LLVM-IR that computes the memory location of an access + /// expression. + /// + /// For a given isl_ast_expr[ession] of type isl_ast_op_access this function + /// creates IR that computes the address the access expression refers to. + /// + /// @param Expr The ast expression that has type isl_ast_op_access + /// for which we generate LLVM-IR. + /// + /// @return The llvm::Value* containing the computed address. + llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); + private: Scop &S; @@ -203,7 +215,6 @@ llvm::Value *createId(__isl_take isl_ast_expr *Expr); llvm::Value *createInt(__isl_take isl_ast_expr *Expr); llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr); - llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); /// @brief Create a binary operation @p Opc and track overflows if requested. /// Index: include/polly/CodeGen/IslNodeBuilder.h =================================================================== --- include/polly/CodeGen/IslNodeBuilder.h +++ include/polly/CodeGen/IslNodeBuilder.h @@ -375,6 +375,21 @@ /// virtual __isl_give isl_union_map * getScheduleForAstNode(__isl_take isl_ast_node *Node); + +private: + /// @brief Create code for a copy statement. + /// + /// A copy statement is expected to have one write memory access and one read + /// memory access (in this very order). Data is loaded from the location + /// described by the read memory access and written to the location described + /// by the write memory access. @p NewAccesses contains for each access + /// the isl ast expression that describes the location accessed. + /// + /// @param Stmt The copy statement that contains the accesses.
+ /// @param NewAccesses The hash table that contains remappings from memory + /// ids to new access expressions. + void generateCopyStmt(ScopStmt *Stmt, + __isl_keep isl_id_to_ast_expr *NewAccesses); }; #endif Index: include/polly/ScheduleOptimizer.h =================================================================== --- include/polly/ScheduleOptimizer.h +++ include/polly/ScheduleOptimizer.h @@ -88,7 +88,7 @@ /// /// @return True, if we believe @p NewSchedule is an improvement for @p S. static bool isProfitableSchedule(polly::Scop &S, - __isl_keep isl_union_map *NewSchedule); + __isl_keep isl_schedule *NewSchedule); /// @brief Isolate a set of partial tile prefixes. /// Index: include/polly/ScopInfo.h =================================================================== --- include/polly/ScopInfo.h +++ include/polly/ScopInfo.h @@ -688,6 +688,19 @@ ArrayRef Subscripts, ArrayRef Sizes, Value *AccessValue, ScopArrayInfo::MemoryKind Kind, StringRef BaseName); + + /// @brief Create a new MemoryAccess that corresponds to @p AccRel. + /// + /// Along with @p Stmt and @p AccType it uses the dimension sizes, the element + /// type and the name of the accessed array, all of which are taken from the + /// ScopArrayInfo object identified by the output tuple id of + /// @p AccRel. + /// + /// @param Stmt The parent statement. + /// @param AccType Whether read or write access. + /// @param AccRel The access relation that describes the memory access. It is stored as the new access relation; no original access relation is set. + MemoryAccess(ScopStmt *Stmt, AccessType AccType, __isl_take isl_map *AccRel); + ~MemoryAccess(); /// @brief Add a new incoming block/value pairs for this PHI/ExitPHI access. @@ -968,6 +981,16 @@ /// Create an overapproximating ScopStmt for the region @p R. ScopStmt(Scop &parent, Region &R); + /// @brief Create a copy statement. + /// + /// @param parent The Scop the copy statement belongs to. + /// @param SourceRel The access relation of the location that is read. + /// @param TargetRel The access relation of the location that is written.
+ /// @param Domain The original domain under which the copy statement would + /// be executed. Its tuple id is replaced by the id of the new statement. + ScopStmt(Scop &parent, __isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, __isl_take isl_set *Domain); + /// Initialize members after all MemoryAccesses have been added. void init(LoopInfo &LI); @@ -1133,10 +1156,14 @@ /// @brief Get the schedule function of this ScopStmt. /// - /// @return The schedule function of this ScopStmt. + /// @return The schedule function of this ScopStmt. If the schedule tree of + /// the parent Scop contains extension nodes, nullptr may be returned. __isl_give isl_map *getSchedule() const; /// @brief Get an isl string representing this schedule. + /// + /// @return An isl string representing this schedule, or an empty string if + /// getSchedule() returns nullptr (e.g. due to extension nodes). std::string getScheduleStr() const; /// @brief Get the invalid domain for this statement. @@ -1161,6 +1188,9 @@ /// @brief Return true if this statement represents a single basic block. bool isBlockStmt() const { return BB != nullptr; } + /// @brief Return true if this is a copy statement (i.e. it represents neither a basic block nor a region). + bool isCopyStmt() const { return BB == nullptr && R == nullptr; } + /// @brief Get the region represented by this ScopStmt (if any). /// /// @return The region represented by this ScopStmt, or null if the statement @@ -1364,6 +1394,9 @@ /// Max loop depth. unsigned MaxLoopDepth; + /// Number of copy statements added via addScopStmt(). + unsigned CopyStmtsNum; + typedef std::list StmtSet; /// The statements in this Scop. StmtSet Stmts; @@ -1531,11 +1564,6 @@ Scop(Region &R, ScalarEvolution &SE, LoopInfo &LI, ScopDetection::DetectionContext &DC); - /// @brief Add the access function to all MemoryAccess objects of the Scop - /// created in this pass. - void addAccessFunction(MemoryAccess *Access) { - AccessFunctions.emplace_back(Access); - } //@} /// @brief Initialize this ScopBuilder. @@ -1844,6 +1872,29 @@ public: ~Scop(); + /// @brief Get the count of copy statements added to this Scop.
+ /// + /// @return The count of copy statements added to this Scop. + unsigned getCopyStmtsNum() { return CopyStmtsNum; } + + /// @brief Create a new copy statement. + /// + /// A new statement will be created and added to the statement list of this Scop. + /// + /// @param SourceRel The access relation of the location that is read. + /// @param TargetRel The access relation of the location that is written. + /// @param Domain The original domain under which the copy statement would + /// be executed. + ScopStmt *addScopStmt(__isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, + __isl_take isl_set *Domain); + + /// @brief Register @p Access in the list of all MemoryAccess objects of + /// this Scop. + void addAccessFunction(MemoryAccess *Access) { + AccessFunctions.emplace_back(Access); + } + ScalarEvolution *getSE() const; /// @brief Get the count of parameters used in this Scop. @@ -2266,6 +2317,10 @@ __isl_give isl_union_map *getAccesses(); /// @brief Get the schedule of all the statements in the SCoP. + /// + /// @return The schedule of all the statements in the SCoP as a union map, + /// or nullptr if the schedule tree returned by getScheduleTree() + /// contains extension nodes. __isl_give isl_union_map *getSchedule() const; /// @brief Get a schedule tree describing the schedule of all statements. @@ -2297,6 +2352,11 @@ /// @brief Find the ScopArrayInfo associated with an isl Id /// that has name @p Name. ScopArrayInfo *getArrayInfoByName(const std::string BaseName); + + /// @brief Check whether @p Schedule contains extension nodes. + /// + /// @return true if @p Schedule contains extension nodes. + static bool containsExtNode(__isl_keep isl_schedule *Schedule); }; /// @brief Print Scop scop to raw_ostream O.
Index: lib/Analysis/DependenceInfo.cpp =================================================================== --- lib/Analysis/DependenceInfo.cpp +++ lib/Analysis/DependenceInfo.cpp @@ -153,6 +153,8 @@ // to match the new access domains, thus we need // [Stmt[i0, i1] -> MemAcc_A[i0 + i1]] -> [0, i0, 2, i1, 0] isl_map *Schedule = Stmt.getSchedule(); + assert(Schedule && "Schedules that contain extension nodes require " + "special handling."); Schedule = isl_map_apply_domain( Schedule, isl_map_reverse(isl_map_domain_map(isl_map_copy(accdom)))); @@ -162,7 +164,10 @@ } else { accdom = tag(accdom, MA, Level); if (Level > Dependences::AL_Statement) { - isl_map *Schedule = tag(Stmt.getSchedule(), MA, Level); + auto *StmtScheduleMap = Stmt.getSchedule(); + assert(StmtScheduleMap && "Schedules that contain extension nodes " + "require special handling."); + isl_map *Schedule = tag(StmtScheduleMap, MA, Level); *StmtSchedule = isl_union_map_add_map(*StmtSchedule, Schedule); } } @@ -597,6 +602,8 @@ StmtScat = Stmt.getSchedule(); else StmtScat = isl_map_copy((*NewSchedule)[&Stmt]); + assert(StmtScat && + "Schedules that contain extension nodes require special handling."); if (!ScheduleSpace) ScheduleSpace = isl_space_range(isl_map_get_space(StmtScat)); Index: lib/Analysis/PolyhedralInfo.cpp =================================================================== --- lib/Analysis/PolyhedralInfo.cpp +++ lib/Analysis/PolyhedralInfo.cpp @@ -134,6 +134,8 @@ unsigned int MaxDim = SS->getNumIterators(); DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n"); auto *ScheduleMap = SS->getSchedule(); + assert(ScheduleMap && + "Schedules that contain extension nodes require special handling."); ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1, MaxDim - CurrDim - 1); Index: lib/Analysis/ScopInfo.cpp =================================================================== --- lib/Analysis/ScopInfo.cpp +++ lib/Analysis/ScopInfo.cpp @@ -832,6 +832,28 @@ Id =
isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this); } +MemoryAccess::MemoryAccess(ScopStmt *Stmt, AccessType AccType, + __isl_take isl_map *AccRel) + : Kind(ScopArrayInfo::MemoryKind::MK_Array), AccType(AccType), + RedType(RT_NONE), Statement(Stmt), InvalidDomain(nullptr), + AccessInstruction(nullptr), IsAffine(true), AccessRelation(nullptr), + NewAccessRelation(AccRel) { + auto *ArrayInfoId = isl_map_get_tuple_id(NewAccessRelation, isl_dim_out); + auto *SAI = ScopArrayInfo::getFromId(ArrayInfoId); + Sizes.push_back(nullptr); + for (unsigned i = 1; i < SAI->getNumberOfDimensions(); i++) + Sizes.push_back(SAI->getDimensionSize(i)); + ElementType = SAI->getElementType(); + BaseAddr = SAI->getBasePtr(); + BaseName = SAI->getName(); + static const std::string TypeStrings[] = {"", "_Read", "_Write", "_MayWrite"}; + const std::string Access = TypeStrings[AccType] + utostr(Stmt->size()) + "_"; + + std::string IdName = + getIslCompatibleName(Stmt->getBaseName(), Access, BaseName); + Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this); +} + void MemoryAccess::realignParams() { auto *Ctx = Statement->getParent()->getContext(); InvalidDomain = isl_set_gist_params(InvalidDomain, isl_set_copy(Ctx)); @@ -977,6 +999,10 @@ isl_aff_zero_on_domain(isl_local_space_from_space(getDomainSpace()))); } auto *Schedule = getParent()->getSchedule(); + if (!Schedule) { + isl_set_free(Domain); + return nullptr; + } Schedule = isl_union_map_intersect_domain( Schedule, isl_union_set_from_set(isl_set_copy(Domain))); if (isl_union_map_is_empty(Schedule)) { @@ -1440,6 +1466,25 @@ BaseName = getIslCompatibleName("Stmt_", &bb, ""); } +ScopStmt::ScopStmt(Scop &parent, __isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, __isl_take isl_set *NewDomain) + : Parent(parent), InvalidDomain(nullptr), Domain(NewDomain), BB(nullptr), + R(nullptr), Build(nullptr) { + BaseName = getIslCompatibleName("Stmt_Copy_", "", + std::to_string(parent.getCopyStmtsNum())); +
auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this); + Domain = isl_set_set_tuple_id(Domain, isl_id_copy(Id)); + TargetRel = isl_map_set_tuple_id(TargetRel, isl_dim_in, isl_id_copy(Id)); + auto *Access = + new MemoryAccess(this, MemoryAccess::AccessType::MUST_WRITE, TargetRel); + parent.addAccessFunction(Access); + addAccess(Access); + SourceRel = isl_map_set_tuple_id(SourceRel, isl_dim_in, Id); + Access = new MemoryAccess(this, MemoryAccess::AccessType::READ, SourceRel); + parent.addAccessFunction(Access); + addAccess(Access); +} + void ScopStmt::init(LoopInfo &LI) { assert(!Domain && "init must be called only once"); @@ -1586,6 +1631,8 @@ std::string ScopStmt::getScheduleStr() const { auto *S = getSchedule(); + if (!S) + return ""; auto Str = stringFromIslObj(S); isl_map_free(S); return Str; @@ -3047,9 +3094,10 @@ ScopDetection::DetectionContext &DC) : SE(&ScalarEvolution), R(R), IsOptimized(false), HasSingleExitEdge(R.getExitingBlock()), HasErrorBlock(false), - MaxLoopDepth(0), DC(DC), IslCtx(isl_ctx_alloc(), isl_ctx_free), - Context(nullptr), Affinator(this, LI), AssumedContext(nullptr), - InvalidContext(nullptr), Schedule(nullptr) { + MaxLoopDepth(0), CopyStmtsNum(0), DC(DC), + IslCtx(isl_ctx_alloc(), isl_ctx_free), Context(nullptr), + Affinator(this, LI), AssumedContext(nullptr), InvalidContext(nullptr), + Schedule(nullptr) { if (IslOnErrorAbort) isl_options_set_on_error(getIslCtx(), ISL_ON_ERROR_ABORT); buildContext(); @@ -3925,8 +3973,27 @@ return getAccessesOfType([](MemoryAccess &MA) { return true; }); } +// @brief Traversal callback that stops at the first extension node. +// +// @return isl_bool_error (aborting the traversal) if @p Node is an extension node, and isl_bool_true otherwise.
+static isl_bool isNotExtNode(__isl_keep isl_schedule_node *Node, void *User) { + if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) + return isl_bool_error; + else + return isl_bool_true; +} + +bool Scop::containsExtNode(__isl_keep isl_schedule *Schedule) { + return isl_schedule_foreach_schedule_node_top_down(Schedule, isNotExtNode, + nullptr) == isl_stat_error; +} + __isl_give isl_union_map *Scop::getSchedule() const { auto *Tree = getScheduleTree(); + if (containsExtNode(Tree)) { + isl_schedule_free(Tree); + return nullptr; + } auto *S = isl_schedule_get_map(Tree); isl_schedule_free(Tree); return S; @@ -4062,6 +4129,14 @@ } } +ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, + __isl_take isl_set *Domain) { + Stmts.emplace_back(*this, SourceRel, TargetRel, Domain); + CopyStmtsNum++; + return &(Stmts.back()); +} + void Scop::buildSchedule(LoopInfo &LI) { Loop *L = getLoopSurroundingScop(*this, LI); LoopStackTy LoopStack({LoopStackElementTy(L, nullptr, 0)}); Index: lib/CodeGen/BlockGenerators.cpp =================================================================== --- lib/CodeGen/BlockGenerators.cpp +++ lib/CodeGen/BlockGenerators.cpp @@ -648,7 +648,9 @@ void BlockGenerator::invalidateScalarEvolution(Scop &S) { for (auto &Stmt : S) - if (Stmt.isBlockStmt()) + if (Stmt.isCopyStmt()) + continue; + else if (Stmt.isBlockStmt()) for (auto &Inst : *Stmt.getBasicBlock()) SE.forgetValue(&Inst); else if (Stmt.isRegionStmt()) Index: lib/CodeGen/IRBuilder.cpp =================================================================== --- lib/CodeGen/IRBuilder.cpp +++ lib/CodeGen/IRBuilder.cpp @@ -61,7 +61,8 @@ SetVector BasePtrs; for (ScopStmt &Stmt : S) for (MemoryAccess *MA : Stmt) - BasePtrs.insert(MA->getBaseAddr()); + if (!Stmt.isCopyStmt()) + BasePtrs.insert(MA->getBaseAddr()); std::string AliasScopeStr = "polly.alias.scope."; for (Value *BasePtr : BasePtrs) Index: lib/CodeGen/IslAst.cpp
=================================================================== --- lib/CodeGen/IslAst.cpp +++ lib/CodeGen/IslAst.cpp @@ -593,8 +593,7 @@ P = isl_ast_node_print(RootNode, P, Options); AstStr = isl_printer_get_str(P); - isl_union_map *Schedule = - isl_union_map_intersect_domain(S.getSchedule(), S.getDomains()); + auto *Schedule = S.getScheduleTree(); DEBUG({ dbgs() << S.getContextStr() << "\n"; @@ -609,7 +608,7 @@ free(AstStr); isl_ast_expr_free(RunCondition); - isl_union_map_free(Schedule); + isl_schedule_free(Schedule); isl_ast_node_free(RootNode); isl_printer_free(P); } Index: lib/CodeGen/IslNodeBuilder.cpp =================================================================== --- lib/CodeGen/IslNodeBuilder.cpp +++ lib/CodeGen/IslNodeBuilder.cpp @@ -766,6 +766,23 @@ isl_ast_expr_free(Expr); } +void IslNodeBuilder::generateCopyStmt( + ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) { + assert(Stmt->size() == 2); + auto WriteAccess = Stmt->begin(); + auto ReadAccess = WriteAccess + 1; + assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite() && "A copy statement has a must-write access followed by a read access"); + assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() && + "Accesses use the same data type"); + assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind()); + auto *AccessExpr = + isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId()); + auto *LoadValue = ExprBuilder.create(AccessExpr); + AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId()); + auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr); + Builder.CreateStore(LoadValue, StoreAddr); +} + void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) { LoopToScevMapT LTS; isl_id *Id; @@ -780,12 +797,17 @@ Stmt = (ScopStmt *)isl_id_get_user(Id); auto *NewAccesses = createNewAccesses(Stmt, User); - createSubstitutions(Expr, Stmt, LTS); + if (Stmt->isCopyStmt()) { + generateCopyStmt(Stmt, NewAccesses); + isl_ast_expr_free(Expr); + } else { + createSubstitutions(Expr, Stmt, LTS); - if
(Stmt->isBlockStmt()) - BlockGen.copyStmt(*Stmt, LTS, NewAccesses); - else - RegionGen.copyStmt(*Stmt, LTS, NewAccesses); + if (Stmt->isBlockStmt()) + BlockGen.copyStmt(*Stmt, LTS, NewAccesses); + else + RegionGen.copyStmt(*Stmt, LTS, NewAccesses); + } isl_id_to_ast_expr_free(NewAccesses); isl_ast_node_free(User); Index: lib/Exchange/JSONExporter.cpp =================================================================== --- lib/Exchange/JSONExporter.cpp +++ lib/Exchange/JSONExporter.cpp @@ -289,6 +289,8 @@ int Index = 0; for (ScopStmt &Stmt : S) { Json::Value Schedule = JScop["statements"][Index]["schedule"]; + assert(!Schedule.asString().empty() && + "Schedules that contain extension nodes require special handling."); isl_map *Map = isl_map_read_from_str(S.getIslCtx(), Schedule.asCString()); isl_space *Space = Stmt.getDomainSpace(); Index: lib/Transform/DeadCodeElimination.cpp =================================================================== --- lib/Transform/DeadCodeElimination.cpp +++ lib/Transform/DeadCodeElimination.cpp @@ -92,6 +92,8 @@ // no point in trying to remove them from the live-out set. __isl_give isl_union_set *DeadCodeElim::getLiveOut(Scop &S) { isl_union_map *Schedule = S.getSchedule(); + assert(Schedule && + "Schedules that contain extension nodes require special handling."); isl_union_map *WriteIterations = isl_union_map_reverse(S.getMustWrites()); isl_union_map *WriteTimes = isl_union_map_apply_range(WriteIterations, isl_union_map_copy(Schedule)); Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -662,6 +662,76 @@ return IdentifiedAccess; } +/// @brief Add constraints to dimension @p Dim of @p ExtMap.
+/// +/// If @p ExtMap has the following form [O0, O1, O2]->[I0, I1, I2], +/// the following constraints will be added +/// Bound * OM <= IM <= Bound * (OM + 1) - 1, +/// where M is @p Dim and Bound is @p Bound. +/// +/// @param ExtMap The isl map to be modified. +/// @param Dim The output dimension to be modified. +/// @param Bound The value that is used to specify the constraint. Must be non-zero. +/// @return The modified isl map +static __isl_give isl_map * +addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim, + unsigned Bound) { + assert(Bound != 0); + auto *ConstrSpace = isl_local_space_from_space(isl_map_get_space(ExtMap)); + const int SignedBound = static_cast<int>(Bound); + auto *Constr = + isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace)); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1); + Constr = + isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, -SignedBound); + ExtMap = isl_map_add_constraint(ExtMap, Constr); + Constr = isl_constraint_alloc_inequality(ConstrSpace); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, SignedBound); + Constr = isl_constraint_set_constant_si(Constr, SignedBound - 1); + return isl_map_add_constraint(ExtMap, Constr); +} + +/// @brief Create an access relation that is specific to the matrix multiplication +/// pattern. +/// +/// Create an access relation of the following form: +/// { [O0, O1, O2]->[I0, I1, I2] : +/// FirstOutputDimBound * O0 <= I0 <= FirstOutputDimBound * (O0 + 1) - 1 +/// and SecondOutputDimBound * O1 <= I1 <= SecondOutputDimBound * (O1 + 1) - 1 +/// and ThirdOutputDimBound * O2 <= I2 <= ThirdOutputDimBound * (O2 + 1) - 1} +/// where FirstOutputDimBound is @p FirstOutputDimBound, +/// SecondOutputDimBound is @p SecondOutputDimBound, +/// ThirdOutputDimBound is @p ThirdOutputDimBound. A bound of zero instead fixes the corresponding output dimension to 0. +/// +/// @param Ctx The isl context.
+/// @param FirstOutputDimBound, +/// SecondOutputDimBound, +/// ThirdOutputDimBound The parameters of the access relation. +/// @return The specified access relation. +static __isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound, + unsigned SecondOutputDimBound, + unsigned ThirdOutputDimBound) { + auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3); + auto *ExtensionMap = isl_map_universe(NewRelSpace); + if (!FirstOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 0, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 0, + FirstOutputDimBound); + if (!SecondOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 1, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 1, + SecondOutputDimBound); + if (!ThirdOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 2, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 2, + ThirdOutputDimBound); + return ExtensionMap; +} + /// @brief Create an access relation that is specific to the matrix /// multiplication pattern. /// @@ -761,6 +831,14 @@ return isl_map_apply_range(MapOldIndVar, AccessRel); } +static __isl_give isl_schedule_node * +createExtensionNode(__isl_take isl_schedule_node *Node, + __isl_take isl_map *ExtensionMap) { + auto *Extension = isl_union_map_from_map(ExtensionMap); + auto *NewNode = isl_schedule_node_from_extension(Extension); + return isl_schedule_node_graft_before(Node, NewNode); +} + /// @brief Apply the packing transformation. /// /// The packing transformation can be described as a data-layout @@ -775,9 +853,9 @@ /// @param MicroParams, MacroParams Parameters of the BLIS kernel /// to be taken into account. /// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar, - MicroKernelParamsTy MicroParams, - MacroKernelParamsTy MacroParams) { +static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern( + __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar, + MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) { auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in); auto *Stmt = static_cast(isl_id_get_user(InputDimsId)); isl_id_free(InputDimsId); @@ -785,8 +863,12 @@ MemoryAccess *MemAccessB = identifyAccessB(Stmt); if (!MemAccessA || !MemAccessB) { isl_map_free(MapOldIndVar); - return; + return Node; } + Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); + Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); + Node = isl_schedule_node_parent(Node); + Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0); auto *AccRel = getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6); unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr; @@ -794,14 +876,34 @@ auto *SAI = Stmt->getParent()->createScopArrayInfo( MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); + auto *OldAcc = MemAccessA->getAccessRelation(); MemAccessA->setNewAccessRelation(AccRel); + auto *ExtMap = + getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc); + ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1); + auto *Domain = Stmt->getDomain(); + auto *NewStmt = Stmt->getParent()->addScopStmt( + OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain)); + ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId()); + Node = createExtensionNode(Node, ExtMap); + Node = isl_schedule_node_child(Node, 0); AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7); FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr; SecondDimSize =
MicroParams.Nr; SAI = Stmt->getParent()->createScopArrayInfo( MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); + OldAcc = MemAccessB->getAccessRelation(); MemAccessB->setNewAccessRelation(AccRel); + ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc); + ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1); + ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1); + NewStmt = Stmt->getParent()->addScopStmt( + OldAcc, MemAccessB->getAccessRelation(), Domain); + ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId()); + Node = createExtensionNode(Node, ExtMap); + Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); + return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); } /// @brief Get a relation mapping induction variables produced by schedule @@ -845,9 +947,8 @@ Node, MicroKernelParams, MacroKernelParams); if (!MapOldIndVar) return Node; - optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams, - MacroKernelParams); - return Node; + return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams, + MacroKernelParams); } bool ScheduleTreeOptimizer::isMatrMultPattern( @@ -904,7 +1005,7 @@ } bool ScheduleTreeOptimizer::isProfitableSchedule( - Scop &S, __isl_keep isl_union_map *NewSchedule) { + Scop &S, __isl_keep isl_schedule *NewSchedule) { // To understand if the schedule has been optimized we check if the schedule // has changed at all.
// TODO: We can improve this by tracking if any necessarily beneficial @@ -914,9 +1015,16 @@ // optimizations, by comparing (yet to be defined) performance metrics // before/after the scheduling optimizer // (e.g., #stride-one accesses) + // Extension nodes are only inserted by the schedule optimizer (see the assert below), so a schedule that contains them is treated as changed. + if (S.containsExtNode(NewSchedule)) + return true; + auto *NewScheduleMap = isl_schedule_get_map(NewSchedule); isl_union_map *OldSchedule = S.getSchedule(); - bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule); + assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes " + "that makes Scop::getSchedule() return nullptr."); + bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap); isl_union_map_free(OldSchedule); + isl_union_map_free(NewScheduleMap); return changed; } @@ -1093,10 +1201,8 @@ auto *TTI = &getAnalysis().getTTI(F); isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI); - isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule); - if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) { - isl_union_map_free(NewScheduleMap); + if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) { isl_schedule_free(NewSchedule); return false; } @@ -1107,7 +1213,6 @@ if (OptimizedScops) S.dump(); - isl_union_map_free(NewScheduleMap); return false; } Index: test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll =================================================================== --- test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll +++ test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll @@ -18,6 +18,29 @@ ; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg7[i2, i1] }; ; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_B[0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 8*floor((-i1 + o2)/8) = -i1 + o2 and 0 <= o2 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o1)/256) <= i1 - 96*floor((i1)/96) }; ; +; CHECK: Stmt_Copy_0 +; CHECK: Domain := +; CHECK: { Stmt_Copy_0[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <=
1023 }; +; CHECK: Schedule := +; CHECK: ; +; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 4*floor((-i0 + o2)/4) = -i0 + o2 and 0 <= o2 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o1)/256) <= i0 - 16*floor((i0)/16) }; +; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] }; +; CHECK: Stmt_Copy_1 +; CHECK: Domain := +; CHECK: { Stmt_Copy_1[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 }; +; CHECK: Schedule := +; CHECK: ; +; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { Stmt_Copy_1[i0, i1, i2] -> Packed_B[0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 8*floor((-i1 + o2)/8) = -i1 + o2 and 0 <= o2 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o1)/256) <= i1 - 96*floor((i1)/96) }; +; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { Stmt_Copy_1[i0, i1, i2] -> MemRef_arg7[i2, i1] }; +; target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-unknown"