Index: polly/trunk/include/polly/CodeGen/IslExprBuilder.h =================================================================== --- polly/trunk/include/polly/CodeGen/IslExprBuilder.h +++ polly/trunk/include/polly/CodeGen/IslExprBuilder.h @@ -166,6 +166,17 @@ /// was enabled. llvm::Value *getOverflowState() const; + /// Create LLVM-IR that computes the memory location of an access expression. + /// + /// For a given isl_ast_expr[ession] of type isl_ast_op_access this function + /// creates IR that computes the address the access expression refers to. + /// + /// @param Expr The ast expression of type isl_ast_op_access + /// for which we generate LLVM-IR. + /// + /// @return The llvm::Value* containing the result of the computation. + llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); + private: Scop &S; @@ -203,7 +214,6 @@ llvm::Value *createId(__isl_take isl_ast_expr *Expr); llvm::Value *createInt(__isl_take isl_ast_expr *Expr); llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr); - llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); /// Create a binary operation @p Opc and track overflows if requested. /// Index: polly/trunk/include/polly/CodeGen/IslNodeBuilder.h =================================================================== --- polly/trunk/include/polly/CodeGen/IslNodeBuilder.h +++ polly/trunk/include/polly/CodeGen/IslNodeBuilder.h @@ -375,6 +375,21 @@ /// virtual __isl_give isl_union_map * getScheduleForAstNode(__isl_take isl_ast_node *Node); + +private: + /// Create code for a copy statement. + /// + /// A copy statement is expected to have one write memory access and one read + /// memory access (in this very order). Data is loaded from the location + /// described by the read memory access and written to the location described + /// by the write memory access. @p NewAccesses contains for each access + /// the isl ast expression that describes the location accessed.
+ /// + /// @param Stmt The copy statement that contains the accesses. + /// @param NewAccesses The hash table that contains remappings from memory + /// ids to new access expressions. + void generateCopyStmt(ScopStmt *Stmt, + __isl_keep isl_id_to_ast_expr *NewAccesses); }; #endif Index: polly/trunk/include/polly/ScheduleOptimizer.h =================================================================== --- polly/trunk/include/polly/ScheduleOptimizer.h +++ polly/trunk/include/polly/ScheduleOptimizer.h @@ -88,7 +88,7 @@ /// /// @return True, if we believe @p NewSchedule is an improvement for @p S. static bool isProfitableSchedule(polly::Scop &S, - __isl_keep isl_union_map *NewSchedule); + __isl_keep isl_schedule *NewSchedule); /// Isolate a set of partial tile prefixes. /// Index: polly/trunk/include/polly/ScopInfo.h =================================================================== --- polly/trunk/include/polly/ScopInfo.h +++ polly/trunk/include/polly/ScopInfo.h @@ -689,6 +689,19 @@ ArrayRef Subscripts, ArrayRef Sizes, Value *AccessValue, ScopArrayInfo::MemoryKind Kind, StringRef BaseName); + + /// Create a new MemoryAccess that corresponds to @p AccRel. + /// + /// Along with @p Stmt and @p AccType it uses information about dimension + /// lengths of the accessed array, the type of the accessed array elements, + /// and the name of the accessed array, which are derived from the array + /// identified by the output tuple id of @p AccRel. + /// + /// @param Stmt The parent statement. + /// @param AccType Whether read or write access. + /// @param AccRel The access relation that describes the memory access. + MemoryAccess(ScopStmt *Stmt, AccessType AccType, __isl_take isl_map *AccRel); + ~MemoryAccess(); /// Add a new incoming block/value pairs for this PHI/ExitPHI access. @@ -1083,6 +1096,16 @@ /// Create an overapproximating ScopStmt for the region @p R. ScopStmt(Scop &parent, Region &R); + /// Create a copy statement. + /// + /// @param parent The parent Scop.
+ /// @param SourceRel The source location. + /// @param TargetRel The target location. + /// @param Domain The original domain under which copy statement would + /// be executed. + ScopStmt(Scop &parent, __isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, __isl_take isl_set *Domain); + /// Initialize members after all MemoryAccesses have been added. void init(LoopInfo &LI); @@ -1217,10 +1240,14 @@ /// Get the schedule function of this ScopStmt. /// - /// @return The schedule function of this ScopStmt. + /// @return The schedule function of this ScopStmt, if it does not contain + /// extension nodes, and nullptr, otherwise. __isl_give isl_map *getSchedule() const; /// Get an isl string representing this schedule. + /// + /// @return An isl string representing this schedule, if it does not contain + /// extension nodes, and an empty string, otherwise. std::string getScheduleStr() const; /// Get the invalid domain for this statement. @@ -1245,6 +1272,9 @@ /// Return true if this statement represents a single basic block. bool isBlockStmt() const { return BB != nullptr; } + /// Return true if this is a copy statement. + bool isCopyStmt() const { return BB == nullptr && R == nullptr; } + /// Get the region represented by this ScopStmt (if any). /// /// @return The region represented by this ScopStmt, or null if the statement @@ -1448,6 +1478,9 @@ /// Max loop depth. unsigned MaxLoopDepth; + /// Number of copy statements. + unsigned CopyStmtsNum; + typedef std::list StmtSet; /// The statements in this Scop. StmtSet Stmts; @@ -1615,11 +1648,6 @@ Scop(Region &R, ScalarEvolution &SE, LoopInfo &LI, ScopDetection::DetectionContext &DC); - /// Add the access function to all MemoryAccess objects of the Scop - /// created in this pass. - void addAccessFunction(MemoryAccess *Access) { - AccessFunctions.emplace_back(Access); - } //@} /// Initialize this ScopBuilder. @@ -1927,6 +1955,30 @@ public: ~Scop(); + /// Get the count of copy statements added to this Scop.
+ /// + /// @return The count of copy statements added to this Scop. + unsigned getCopyStmtsNum() const { return CopyStmtsNum; } + + /// Create a new copy statement. + /// + /// A new statement will be created and added to the statement list. + /// + /// @param SourceRel The source location. + /// @param TargetRel The target location. + /// @param Domain The original domain under which copy statement would + /// be executed. + /// @return The newly created copy statement. + ScopStmt *addScopStmt(__isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, + __isl_take isl_set *Domain); + + /// Add the access function to all MemoryAccess objects of the Scop + /// created in this pass. + void addAccessFunction(MemoryAccess *Access) { + AccessFunctions.emplace_back(Access); + } + ScalarEvolution *getSE() const; /// Get the count of parameters used in this Scop. @@ -2349,6 +2401,9 @@ __isl_give isl_union_map *getAccesses(); /// Get the schedule of all the statements in the SCoP. + /// + /// @return The schedule of all the statements in the SCoP, if the schedule of + /// the Scop does not contain extension nodes, and nullptr, otherwise. __isl_give isl_union_map *getSchedule() const; /// Get a schedule tree describing the schedule of all statements. @@ -2380,6 +2435,11 @@ /// Find the ScopArrayInfo associated with an isl Id /// that has name @p Name. ScopArrayInfo *getArrayInfoByName(const std::string BaseName); + + /// Check whether @p Schedule contains extension nodes. + /// + /// @return true if @p Schedule contains extension nodes. + static bool containsExtensionNode(__isl_keep isl_schedule *Schedule); }; /// Print Scop scop to raw_ostream O.
Index: polly/trunk/lib/Analysis/DependenceInfo.cpp =================================================================== --- polly/trunk/lib/Analysis/DependenceInfo.cpp +++ polly/trunk/lib/Analysis/DependenceInfo.cpp @@ -153,6 +153,8 @@ // to match the new access domains, thus we need // [Stmt[i0, i1] -> MemAcc_A[i0 + i1]] -> [0, i0, 2, i1, 0] isl_map *Schedule = Stmt.getSchedule(); + assert(Schedule && "Schedules that contain extension nodes require " + "special handling."); Schedule = isl_map_apply_domain( Schedule, isl_map_reverse(isl_map_domain_map(isl_map_copy(accdom)))); @@ -162,7 +164,10 @@ } else { accdom = tag(accdom, MA, Level); if (Level > Dependences::AL_Statement) { - isl_map *Schedule = tag(Stmt.getSchedule(), MA, Level); + auto *StmtScheduleMap = Stmt.getSchedule(); + assert(StmtScheduleMap && "Schedules that contain extension nodes " + "require special handling."); + isl_map *Schedule = tag(StmtScheduleMap, MA, Level); *StmtSchedule = isl_union_map_add_map(*StmtSchedule, Schedule); } } @@ -610,6 +615,8 @@ StmtScat = Stmt.getSchedule(); else StmtScat = isl_map_copy((*NewSchedule)[&Stmt]); + assert(StmtScat && + "Schedules that contain extension nodes require special handling."); if (!ScheduleSpace) ScheduleSpace = isl_space_range(isl_map_get_space(StmtScat)); Index: polly/trunk/lib/Analysis/PolyhedralInfo.cpp =================================================================== --- polly/trunk/lib/Analysis/PolyhedralInfo.cpp +++ polly/trunk/lib/Analysis/PolyhedralInfo.cpp @@ -134,6 +134,8 @@ unsigned int MaxDim = SS->getNumIterators(); DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n"); auto *ScheduleMap = SS->getSchedule(); + assert(ScheduleMap && + "Schedules that contain extension nodes require special handling."); ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1, MaxDim - CurrDim - 1); Index: polly/trunk/lib/Analysis/ScopInfo.cpp =================================================================== ---
polly/trunk/lib/Analysis/ScopInfo.cpp +++ polly/trunk/lib/Analysis/ScopInfo.cpp @@ -857,6 +857,28 @@ Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this); } +MemoryAccess::MemoryAccess(ScopStmt *Stmt, AccessType AccType, + __isl_take isl_map *AccRel) + : Kind(ScopArrayInfo::MemoryKind::MK_Array), AccType(AccType), + RedType(RT_NONE), Statement(Stmt), InvalidDomain(nullptr), + AccessInstruction(nullptr), IsAffine(true), AccessRelation(nullptr), + NewAccessRelation(AccRel) { + auto *ArrayInfoId = isl_map_get_tuple_id(NewAccessRelation, isl_dim_out); + auto *SAI = ScopArrayInfo::getFromId(ArrayInfoId); + Sizes.push_back(nullptr); + for (unsigned i = 1; i < SAI->getNumberOfDimensions(); i++) + Sizes.push_back(SAI->getDimensionSize(i)); + ElementType = SAI->getElementType(); + BaseAddr = SAI->getBasePtr(); + BaseName = SAI->getName(); + static const std::string TypeStrings[] = {"", "_Read", "_Write", "_MayWrite"}; + const std::string Access = TypeStrings[AccType] + utostr(Stmt->size()) + "_"; + + std::string IdName = + getIslCompatibleName(Stmt->getBaseName(), Access, BaseName); + Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this); +} + void MemoryAccess::realignParams() { auto *Ctx = Statement->getParent()->getContext(); InvalidDomain = isl_set_gist_params(InvalidDomain, isl_set_copy(Ctx)); @@ -1040,6 +1062,10 @@ isl_aff_zero_on_domain(isl_local_space_from_space(getDomainSpace()))); } auto *Schedule = getParent()->getSchedule(); + if (!Schedule) { + isl_set_free(Domain); + return nullptr; + } Schedule = isl_union_map_intersect_domain( Schedule, isl_union_set_from_set(isl_set_copy(Domain))); if (isl_union_map_is_empty(Schedule)) { @@ -1430,6 +1456,25 @@ BaseName = getIslCompatibleName("Stmt_", &bb, ""); } +ScopStmt::ScopStmt(Scop &parent, __isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, __isl_take isl_set *NewDomain) + : Parent(parent), InvalidDomain(nullptr), Domain(NewDomain), BB(nullptr), + R(nullptr),
+ Build(nullptr) { + BaseName = getIslCompatibleName("CopyStmt_", "", + std::to_string(parent.getCopyStmtsNum())); + auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this); + Domain = isl_set_set_tuple_id(Domain, isl_id_copy(Id)); + TargetRel = isl_map_set_tuple_id(TargetRel, isl_dim_in, isl_id_copy(Id)); + auto *Access = + new MemoryAccess(this, MemoryAccess::AccessType::MUST_WRITE, TargetRel); + parent.addAccessFunction(Access); + addAccess(Access); + SourceRel = isl_map_set_tuple_id(SourceRel, isl_dim_in, Id); + Access = new MemoryAccess(this, MemoryAccess::AccessType::READ, SourceRel); + parent.addAccessFunction(Access); + addAccess(Access); +} + void ScopStmt::init(LoopInfo &LI) { assert(!Domain && "init must be called only once"); @@ -1576,6 +1621,8 @@ std::string ScopStmt::getScheduleStr() const { auto *S = getSchedule(); + if (!S) + return ""; auto Str = stringFromIslObj(S); isl_map_free(S); return Str; @@ -3041,9 +3088,10 @@ ScopDetection::DetectionContext &DC) : SE(&ScalarEvolution), R(R), IsOptimized(false), HasSingleExitEdge(R.getExitingBlock()), HasErrorBlock(false), - MaxLoopDepth(0), DC(DC), IslCtx(isl_ctx_alloc(), isl_ctx_free), - Context(nullptr), Affinator(this, LI), AssumedContext(nullptr), - InvalidContext(nullptr), Schedule(nullptr) { + MaxLoopDepth(0), CopyStmtsNum(0), DC(DC), + IslCtx(isl_ctx_alloc(), isl_ctx_free), Context(nullptr), + Affinator(this, LI), AssumedContext(nullptr), InvalidContext(nullptr), + Schedule(nullptr) { if (IslOnErrorAbort) isl_options_set_on_error(getIslCtx(), ISL_ON_ERROR_ABORT); buildContext(); @@ -3922,8 +3970,27 @@ return getAccessesOfType([](MemoryAccess &MA) { return true; }); } +// Check whether @p Node is an extension node. +// +// @return isl_bool_error if @p Node is an extension node, else isl_bool_true.
+static isl_bool isNotExtNode(__isl_keep isl_schedule_node *Node, void *User) { + if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) + return isl_bool_error; + else + return isl_bool_true; +} + +bool Scop::containsExtensionNode(__isl_keep isl_schedule *Schedule) { + return isl_schedule_foreach_schedule_node_top_down(Schedule, isNotExtNode, + nullptr) == isl_stat_error; +} + __isl_give isl_union_map *Scop::getSchedule() const { auto *Tree = getScheduleTree(); + if (containsExtensionNode(Tree)) { + isl_schedule_free(Tree); + return nullptr; + } auto *S = isl_schedule_get_map(Tree); isl_schedule_free(Tree); return S; @@ -4059,6 +4126,14 @@ } } +ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel, + __isl_take isl_map *TargetRel, + __isl_take isl_set *Domain) { + Stmts.emplace_back(*this, SourceRel, TargetRel, Domain); + CopyStmtsNum++; + return &(Stmts.back()); +} + void Scop::buildSchedule(LoopInfo &LI) { Loop *L = getLoopSurroundingScop(*this, LI); LoopStackTy LoopStack({LoopStackElementTy(L, nullptr, 0)}); Index: polly/trunk/lib/CodeGen/BlockGenerators.cpp =================================================================== --- polly/trunk/lib/CodeGen/BlockGenerators.cpp +++ polly/trunk/lib/CodeGen/BlockGenerators.cpp @@ -681,7 +681,9 @@ void BlockGenerator::invalidateScalarEvolution(Scop &S) { for (auto &Stmt : S) - if (Stmt.isBlockStmt()) + if (Stmt.isCopyStmt()) + continue; + else if (Stmt.isBlockStmt()) for (auto &Inst : *Stmt.getBasicBlock()) SE.forgetValue(&Inst); else if (Stmt.isRegionStmt()) Index: polly/trunk/lib/CodeGen/IRBuilder.cpp =================================================================== --- polly/trunk/lib/CodeGen/IRBuilder.cpp +++ polly/trunk/lib/CodeGen/IRBuilder.cpp @@ -61,7 +61,8 @@ SetVector BasePtrs; for (ScopStmt &Stmt : S) for (MemoryAccess *MA : Stmt) - BasePtrs.insert(MA->getBaseAddr()); + if (!Stmt.isCopyStmt()) + BasePtrs.insert(MA->getBaseAddr()); std::string AliasScopeStr = "polly.alias.scope."; for (Value
*BasePtr : BasePtrs) Index: polly/trunk/lib/CodeGen/IslAst.cpp =================================================================== --- polly/trunk/lib/CodeGen/IslAst.cpp +++ polly/trunk/lib/CodeGen/IslAst.cpp @@ -593,8 +593,7 @@ P = isl_ast_node_print(RootNode, P, Options); AstStr = isl_printer_get_str(P); - isl_union_map *Schedule = - isl_union_map_intersect_domain(S.getSchedule(), S.getDomains()); + auto *Schedule = S.getScheduleTree(); DEBUG({ dbgs() << S.getContextStr() << "\n"; @@ -609,7 +608,7 @@ free(AstStr); isl_ast_expr_free(RunCondition); - isl_union_map_free(Schedule); + isl_schedule_free(Schedule); isl_ast_node_free(RootNode); isl_printer_free(P); } Index: polly/trunk/lib/CodeGen/IslNodeBuilder.cpp =================================================================== --- polly/trunk/lib/CodeGen/IslNodeBuilder.cpp +++ polly/trunk/lib/CodeGen/IslNodeBuilder.cpp @@ -767,6 +767,23 @@ isl_ast_expr_free(Expr); } +void IslNodeBuilder::generateCopyStmt( + ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) { + assert(Stmt->size() == 2 && "A copy statement has exactly two accesses"); + auto WriteAccess = Stmt->begin(); + auto ReadAccess = std::next(WriteAccess); + assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite()); + assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() && + "Accesses use the same data type"); + assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind()); + auto *AccessExpr = + isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId()); + auto *LoadValue = ExprBuilder.create(AccessExpr); + AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId()); + auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr); + Builder.CreateStore(LoadValue, StoreAddr); +} + void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) { LoopToScevMapT LTS; isl_id *Id; @@ -781,12 +798,17 @@ Stmt = (ScopStmt *)isl_id_get_user(Id); auto *NewAccesses = createNewAccesses(Stmt, User); - createSubstitutions(Expr, Stmt, LTS); + if (Stmt->isCopyStmt()) { +
+ generateCopyStmt(Stmt, NewAccesses); + isl_ast_expr_free(Expr); + } else { + createSubstitutions(Expr, Stmt, LTS); - if (Stmt->isBlockStmt()) - BlockGen.copyStmt(*Stmt, LTS, NewAccesses); - else - RegionGen.copyStmt(*Stmt, LTS, NewAccesses); + if (Stmt->isBlockStmt()) + BlockGen.copyStmt(*Stmt, LTS, NewAccesses); + else + RegionGen.copyStmt(*Stmt, LTS, NewAccesses); + } isl_id_to_ast_expr_free(NewAccesses); isl_ast_node_free(User); Index: polly/trunk/lib/Exchange/JSONExporter.cpp =================================================================== --- polly/trunk/lib/Exchange/JSONExporter.cpp +++ polly/trunk/lib/Exchange/JSONExporter.cpp @@ -294,6 +294,8 @@ int Index = 0; for (ScopStmt &Stmt : S) { Json::Value Schedule = JScop["statements"][Index]["schedule"]; + assert(!Schedule.asString().empty() && + "Schedules that contain extension nodes require special handling."); isl_map *Map = isl_map_read_from_str(S.getIslCtx(), Schedule.asCString()); isl_space *Space = Stmt.getDomainSpace(); Index: polly/trunk/lib/Transform/DeadCodeElimination.cpp =================================================================== --- polly/trunk/lib/Transform/DeadCodeElimination.cpp +++ polly/trunk/lib/Transform/DeadCodeElimination.cpp @@ -92,6 +92,8 @@ // no point in trying to remove them from the live-out set.
__isl_give isl_union_set *DeadCodeElim::getLiveOut(Scop &S) { isl_union_map *Schedule = S.getSchedule(); + assert(Schedule && + "Schedules that contain extension nodes require special handling."); isl_union_map *WriteIterations = isl_union_map_reverse(S.getMustWrites()); isl_union_map *WriteTimes = isl_union_map_apply_range(WriteIterations, isl_union_map_copy(Schedule)); Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp =================================================================== --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp @@ -660,6 +660,76 @@ return IdentifiedAccess; } +/// Add constraints to the @p Dim dimension of @p ExtMap. +/// +/// If @p ExtMap has the following form [O0, O1, O2]->[I0, I1, I2], +/// the following constraint will be added +/// Bound * OM <= IM <= Bound * (OM + 1) - 1, +/// where M is @p Dim and Bound is @p Bound. +/// +/// @param ExtMap The isl map to be modified. +/// @param Dim The input and output dimension to be constrained. +/// @param Bound The value that is used to specify the constraint.
+/// @return The modified isl map. +static __isl_give isl_map * +addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim, + unsigned Bound) { + assert(Bound != 0); + auto *ExtMapSpace = isl_map_get_space(ExtMap); + auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace); + auto *Constr = + isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace)); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, + -static_cast<int>(Bound)); + ExtMap = isl_map_add_constraint(ExtMap, Constr); + Constr = isl_constraint_alloc_inequality(ConstrSpace); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1); + Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound); + Constr = isl_constraint_set_constant_si(Constr, Bound - 1); + return isl_map_add_constraint(ExtMap, Constr); +} + +/// Create an extension map that is specific to the matrix multiplication +/// pattern. +/// +/// Create a map of the following form: +/// { [O0, O1, O2]->[I1, I2, I3] : +/// FirstOutputDimBound * O0 <= I1 <= FirstOutputDimBound * (O0 + 1) - 1 +/// and SecondOutputDimBound * O1 <= I2 <= SecondOutputDimBound * (O1 + 1) - 1 +/// and ThirdOutputDimBound * O2 <= I3 <= ThirdOutputDimBound * (O2 + 1) - 1} +/// where FirstOutputDimBound is @p FirstOutputDimBound, +/// SecondOutputDimBound is @p SecondOutputDimBound, +/// ThirdOutputDimBound is @p ThirdOutputDimBound +/// +/// @param Ctx The isl context. +/// @param FirstOutputDimBound, +/// SecondOutputDimBound, +/// ThirdOutputDimBound The bounds; a zero bound fixes that output dim to 0. +/// @return The specified map.
+static __isl_give isl_map * +getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound, + unsigned SecondOutputDimBound, unsigned ThirdOutputDimBound) { + auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3); + auto *ExtensionMap = isl_map_universe(NewRelSpace); + if (!FirstOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 0, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 0, + FirstOutputDimBound); + if (!SecondOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 1, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 1, + SecondOutputDimBound); + if (!ThirdOutputDimBound) + ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, 2, 0); + else + ExtensionMap = addExtensionMapMatMulDimConstraint(ExtensionMap, 2, + ThirdOutputDimBound); + return ExtensionMap; +} + /// Create an access relation that is specific to the matrix /// multiplication pattern. /// @@ -758,6 +828,14 @@ return isl_map_apply_range(MapOldIndVar, AccessRel); } +static __isl_give isl_schedule_node * +createExtensionNode(__isl_take isl_schedule_node *Node, + __isl_take isl_map *ExtensionMap) { + auto *Extension = isl_union_map_from_map(ExtensionMap); + auto *NewNode = isl_schedule_node_from_extension(Extension); + return isl_schedule_node_graft_before(Node, NewNode); +} + /// Apply the packing transformation. /// /// The packing transformation can be described as a data-layout @@ -772,9 +850,9 @@ /// @param MicroParams, MacroParams Parameters of the BLIS kernel /// to be taken into account. /// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar, - MicroKernelParamsTy MicroParams, - MacroKernelParamsTy MacroParams) { +static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern( + __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar, + MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) { auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in); auto *Stmt = static_cast(isl_id_get_user(InputDimsId)); isl_id_free(InputDimsId); @@ -782,8 +860,12 @@ MemoryAccess *MemAccessB = identifyAccessB(Stmt); if (!MemAccessA || !MemAccessB) { isl_map_free(MapOldIndVar); - return; + return Node; } + Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); + Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); + Node = isl_schedule_node_parent(Node); + Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0); auto *AccRel = getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6); unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr; @@ -791,14 +873,34 @@ auto *SAI = Stmt->getParent()->createScopArrayInfo( MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); + auto *OldAcc = MemAccessA->getAccessRelation(); MemAccessA->setNewAccessRelation(AccRel); + auto *ExtMap = + getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc); + ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1); + auto *Domain = Stmt->getDomain(); + auto *NewStmt = Stmt->getParent()->addScopStmt( + OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain)); + ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId()); + Node = createExtensionNode(Node, ExtMap); + Node = isl_schedule_node_child(Node, 0); AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7); FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr; SecondDimSize =
MicroParams.Nr; SAI = Stmt->getParent()->createScopArrayInfo( MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); + OldAcc = MemAccessB->getAccessRelation(); MemAccessB->setNewAccessRelation(AccRel); + ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc); + ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1); + ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1); + NewStmt = Stmt->getParent()->addScopStmt( + OldAcc, MemAccessB->getAccessRelation(), Domain); + ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId()); + Node = createExtensionNode(Node, ExtMap); + Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); + return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); } /// Get a relation mapping induction variables produced by schedule @@ -842,9 +944,8 @@ Node, MicroKernelParams, MacroKernelParams); if (!MapOldIndVar) return Node; - optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams, - MacroKernelParams); - return Node; + return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams, + MacroKernelParams); } bool ScheduleTreeOptimizer::isMatrMultPattern( @@ -901,7 +1002,7 @@ } bool ScheduleTreeOptimizer::isProfitableSchedule( - Scop &S, __isl_keep isl_union_map *NewSchedule) { + Scop &S, __isl_keep isl_schedule *NewSchedule) { // To understand if the schedule has been optimized we check if the schedule // has changed at all.
// TODO: We can improve this by tracking if any necessarily beneficial @@ -911,9 +1012,15 @@ // optimizations, by comparing (yet to be defined) performance metrics // before/after the scheduling optimizer // (e.g., #stride-one accesses) + if (Scop::containsExtensionNode(NewSchedule)) + return true; + auto *NewScheduleMap = isl_schedule_get_map(NewSchedule); isl_union_map *OldSchedule = S.getSchedule(); - bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule); + assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes " + "that make Scop::getSchedule() return nullptr."); + bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap); isl_union_map_free(OldSchedule); + isl_union_map_free(NewScheduleMap); return changed; } @@ -1090,10 +1197,8 @@ auto *TTI = &getAnalysis().getTTI(F); isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI); - isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule); - if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) { - isl_union_map_free(NewScheduleMap); + if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) { isl_schedule_free(NewSchedule); return false; } @@ -1104,7 +1209,6 @@ if (OptimizedScops) S.dump(); - isl_union_map_free(NewScheduleMap); return false; } Index: polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll =================================================================== --- polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll +++ polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll @@ -12,11 +12,34 @@ ; CHECK: double Packed_A[ { [] -> [(1024)] } ][ { [] -> [(4)] } ]; // Element size 8 ; CHECK: double Packed_B[ { [] -> [(3072)] } ][ { [] -> [(8)] } ]; // Element size 8 ; -; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg6[i0, i2] }; -; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and
-3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) }; +; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] }; +; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) }; ; -; CHECK: { Stmt_bb14[i0, i1, i2] -> MemRef_arg7[i2, i1] }; -; CHECK: new: { Stmt_bb14[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) }; +; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg7[i2, i1] }; +; CHECK: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) }; +; +; CHECK: CopyStmt_0 +; CHECK: Domain := +; CHECK: { CopyStmt_0[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 }; +; CHECK: Schedule := +; CHECK: ; +; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { CopyStmt_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) }; +; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { CopyStmt_0[i0, i1, i2] -> MemRef_arg6[i0, i2] }; +; CHECK: CopyStmt_1 +; CHECK: Domain := +; CHECK: { CopyStmt_1[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 }; +; CHECK: Schedule := +; CHECK: ; +; CHECK: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { CopyStmt_1[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1
<= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) }; +; CHECK: ReadAccess := [Reduction Type: NONE] [Scalar: 0] +; CHECK: null; +; CHECK: new: { CopyStmt_1[i0, i1, i2] -> MemRef_arg7[i2, i1] }; ; target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-unknown" @@ -35,10 +58,10 @@ %tmp12 = load double, double* %tmp11, align 8 %tmp13 = fmul double %tmp12, %arg4 store double %tmp13, double* %tmp11, align 8 - br label %bb14 + br label %Copy_0 -bb14: ; preds = %bb14, %bb9 - %tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %bb14 ] +Copy_0: ; preds = %Copy_0, %bb9 + %tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %Copy_0 ] %tmp16 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp, i64 %tmp15 %tmp17 = load double, double* %tmp16, align 8 %tmp18 = fmul double %tmp17, %arg3 @@ -50,9 +73,9 @@ store double %tmp23, double* %tmp11, align 8 %tmp24 = add nuw nsw i64 %tmp15, 1 %tmp25 = icmp ne i64 %tmp24, 1024 - br i1 %tmp25, label %bb14, label %bb26 + br i1 %tmp25, label %Copy_0, label %bb26 -bb26: ; preds = %bb14 +bb26: ; preds = %Copy_0 %tmp27 = add nuw nsw i64 %tmp10, 1 %tmp28 = icmp ne i64 %tmp27, 1056 br i1 %tmp28, label %bb9, label %bb29