Index: include/polly/CodeGen/IslExprBuilder.h
===================================================================
--- include/polly/CodeGen/IslExprBuilder.h
+++ include/polly/CodeGen/IslExprBuilder.h
@@ -166,6 +166,12 @@
   /// was enabled.
   llvm::Value *getOverflowState() const;
 
+  /// @brief Create the address of the memory location accessed by @p Expr.
+  ///
+  /// @param Expr The isl access expression to translate; it is consumed.
+  /// @return A pointer to the accessed memory location.
+  llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
+
 private:
   Scop &S;
 
@@ -203,7 +209,6 @@
   llvm::Value *createId(__isl_take isl_ast_expr *Expr);
   llvm::Value *createInt(__isl_take isl_ast_expr *Expr);
   llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr);
-  llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
 
   /// @brief Create a binary operation @p Opc and track overflows if requested.
   ///
Index: include/polly/CodeGen/IslNodeBuilder.h
===================================================================
--- include/polly/CodeGen/IslNodeBuilder.h
+++ include/polly/CodeGen/IslNodeBuilder.h
@@ -348,6 +348,20 @@
   ///
   virtual __isl_give isl_union_map *
   getScheduleForAstNode(__isl_take isl_ast_node *Node);
+
+private:
+  /// @brief Generate code for the copy statement @p Stmt (see
+  ///        ScopStmt::isCopyStmt).
+  ///
+  /// Loads the value addressed by the read access of @p Stmt and stores it to
+  /// the location addressed by its write access. @p Stmt and its memory
+  /// accesses are deleted afterwards.
+  ///
+  /// @param Stmt        The copy statement that contains the accesses.
+  /// @param NewAccesses The hash table that contains remappings from memory
+  ///                    ids to new access expressions.
+  void generateCopyStmt(ScopStmt *Stmt,
+                        __isl_keep isl_id_to_ast_expr *NewAccesses);
 };
 
 #endif
Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -969,6 +969,13 @@
   /// Create an overapproximating ScopStmt for the region @p R.
   ScopStmt(Scop &parent, Region &R);
 
+  /// Create a copy statement for the memory access @p MemA.
+  ///
+  /// The statement has the iteration domain of the statement that contains
+  /// @p MemA. It reads the element given by the original access relation of
+  /// @p MemA and writes it to the element given by its current one.
+  ScopStmt(Scop &parent, MemoryAccess *MemA);
+
   /// Initialize members after all MemoryAccesses have been added.
   void init(LoopInfo &LI);
 
@@ -1162,6 +1169,10 @@
   /// @brief Return true if this statement represents a single basic block.
   bool isBlockStmt() const { return BB != nullptr; }
 
+  /// @brief Return true if this is a copy statement, i.e., a statement that
+  ///        represents neither a basic block nor a region.
+  bool isCopyStmt() const { return BB == nullptr && R == nullptr; }
+
   /// @brief Get the region represented by this ScopStmt (if any).
   ///
   /// @return The region represented by this ScopStmt, or null if the statement
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -1440,6 +1440,52 @@
   BaseName = getIslCompatibleName("Stmt_", &bb, "");
 }
 
+ScopStmt::ScopStmt(Scop &parent, MemoryAccess *MemA)
+    : Parent(parent), InvalidDomain(nullptr), BB(nullptr), R(nullptr),
+      Build(nullptr) {
+  BaseName = getIslCompatibleName("Stmt_", MemA->getBaseName(), "");
+  auto *Stmt = MemA->getStatement();
+  auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this);
+  // The copy statement has the iteration domain of the statement that
+  // contains MemA, renamed to the id of this statement.
+  Domain = isl_set_set_tuple_id(Stmt->getDomain(), isl_id_copy(Id));
+
+  std::vector<const SCEV *> Subscripts;
+  for (unsigned i = 0; i < MemA->getNumSubscripts(); i++)
+    Subscripts.push_back(MemA->getSubscript(i));
+  // The size of the outermost dimension is left unspecified (nullptr).
+  std::vector<const SCEV *> Sizes;
+  auto *SAI = MemA->getScopArrayInfo();
+  Sizes.push_back(nullptr);
+  for (unsigned i = 1; i < SAI->getNumberOfDimensions(); i++)
+    Sizes.push_back(SAI->getDimensionSize(i));
+
+  // Add an affine array access of kind AccType that mirrors MemA, but whose
+  // access relation is AccessRelation with its domain tuple renamed to the
+  // id of this statement.
+  auto AddCopyAccess = [&](MemoryAccess::AccessType AccType,
+                           __isl_take isl_map *AccessRelation) {
+    auto *Access = new MemoryAccess(
+        this, MemA->getAccessInstruction(), AccType, MemA->getBaseAddr(),
+        MemA->getElementType(), true, Subscripts, Sizes,
+        MemA->getAccessValue(), ScopArrayInfo::MemoryKind::MK_Array,
+        MemA->getBaseName());
+    Access->setNewAccessRelation(
+        isl_map_set_tuple_id(AccessRelation, isl_dim_in, isl_id_copy(Id)));
+    addAccess(Access);
+  };
+
+  // Read the element from the location described by the original access
+  // relation of MemA ...
+  AddCopyAccess(MemoryAccess::AccessType::READ,
+                MemA->getOriginalAccessRelation());
+  // ... and write it to the location described by its current access
+  // relation.
+  AddCopyAccess(MemoryAccess::AccessType::MUST_WRITE,
+                MemA->getAccessRelation());
+  isl_id_free(Id);
+}
+
 void ScopStmt::init(LoopInfo &LI) {
   assert(!Domain && "init must be called only once");
 
Index: lib/CodeGen/IslNodeBuilder.cpp
===================================================================
--- lib/CodeGen/IslNodeBuilder.cpp
+++ lib/CodeGen/IslNodeBuilder.cpp
@@ -774,6 +774,34 @@
   isl_ast_expr_free(Expr);
 }
 
+/// Generate code for the copy statement @p Stmt: load the value addressed by
+/// its read access and store it to the location addressed by its write
+/// access. @p Stmt and its memory accesses are deleted afterwards.
+void IslNodeBuilder::generateCopyStmt(
+    ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) {
+  assert(Stmt->isCopyStmt() && "Expected a copy statement");
+  Value *LoadValue = nullptr;
+  Value *StoreAddr = nullptr;
+  for (auto *MemA : *Stmt) {
+    isl_ast_expr *AccessExpr =
+        isl_id_to_ast_expr_get(NewAccesses, MemA->getId());
+    assert(AccessExpr && "Copy statement accesses need new access expressions");
+    if (MemA->isRead())
+      LoadValue = ExprBuilder.create(AccessExpr);
+    else
+      StoreAddr = ExprBuilder.createAccessAddress(AccessExpr);
+    // The accesses of a copy statement are allocated with plain new when the
+    // statement is created, so they are released here.
+    delete MemA;
+  }
+  assert(LoadValue && StoreAddr &&
+         "A copy statement needs both a read and a write access");
+  Builder.CreateStore(LoadValue, StoreAddr);
+  // NOTE(review): this assumes that isl emits each copy statement in a single
+  // AST node; a second node would use the statement after it was deleted.
+  delete Stmt;
+}
+
 void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
   LoopToScevMapT LTS;
   isl_id *Id;
@@ -788,12 +816,18 @@
 
   Stmt = (ScopStmt *)isl_id_get_user(Id);
   auto *NewAccesses = createNewAccesses(Stmt, User);
-  createSubstitutions(Expr, Stmt, LTS);
 
-  if (Stmt->isBlockStmt())
-    BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
-  else
-    RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+  if (Stmt->isCopyStmt()) {
+    generateCopyStmt(Stmt, NewAccesses);
+    isl_ast_expr_free(Expr);
+  } else {
+    createSubstitutions(Expr, Stmt, LTS);
+
+    if (Stmt->isBlockStmt())
+      BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
+    else
+      RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+  }
 
   isl_id_to_ast_expr_free(NewAccesses);
   isl_ast_node_free(User);
Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -628,22 +628,93 @@
 /// @param OldMemAccessRelStr The map describing access relation
 ///                           that should be replaced.
 /// @param NewMemAccessRelStr The map describing new access relation.
-static void replaceMemAccRelation(ScopStmt *Stmt,
-                                  __isl_take isl_map *OldMemAccessRel,
-                                  __isl_take isl_map *NewMemAccessRel) {
+/// @return The memory access whose access relation was replaced, or nullptr
+///         if no access of @p Stmt matches @p OldMemAccessRel.
+static MemoryAccess *
+replaceMemAccRelation(ScopStmt *Stmt, __isl_take isl_map *OldMemAccessRel,
+                      __isl_take isl_map *NewMemAccessRel) {
   for (auto *MemA : *Stmt) {
     auto *AccessRelation = MemA->getAccessRelation();
     auto *MemId = isl_map_get_tuple_id(AccessRelation, isl_dim_out);
     OldMemAccessRel = isl_map_set_tuple_id(OldMemAccessRel, isl_dim_out, MemId);
     if (isl_map_is_equal(OldMemAccessRel, AccessRelation)) {
-      MemA->setNewAccessRelation(isl_map_copy(NewMemAccessRel));
       isl_map_free(AccessRelation);
-      break;
+      isl_map_free(OldMemAccessRel);
+      MemA->setNewAccessRelation(NewMemAccessRel);
+      return MemA;
     }
     isl_map_free(AccessRelation);
   }
   isl_map_free(OldMemAccessRel);
   isl_map_free(NewMemAccessRel);
+  return nullptr;
+}
+
+/// @brief Add constraints to the @p DimNum dimension of @p ExtMap.
+///
+/// If @p ExtMap has the form [O0, O1, O2] -> [I0, I1, I2], the constraint
+///   Bound * OM <= IM <= Bound * (OM + 1) - 1
+/// is added, where M is @p DimNum and Bound is @p Bound.
+///
+/// @param ExtMap The isl map to be modified.
+/// @param DimNum The input and output dimension to be constrained.
+/// @param Bound  The non-zero value that is used to specify the constraint.
+/// @return The modified isl map.
+static __isl_give isl_map *
+addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned DimNum,
+                                   unsigned Bound) {
+  assert(Bound != 0 && "Bound must be non-zero");
+  // isl's *_si setters take a signed int, so convert once instead of forming
+  // -Bound through unsigned wrap-around.
+  int IntBound = static_cast<int>(Bound);
+  auto *ConstrSpace = isl_local_space_from_space(isl_map_get_space(ExtMap));
+
+  // IM - Bound * OM >= 0, i.e., Bound * OM <= IM.
+  auto *Constr =
+      isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, DimNum, 1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, DimNum, -IntBound);
+  ExtMap = isl_map_add_constraint(ExtMap, Constr);
+
+  // -IM + Bound * OM + Bound - 1 >= 0, i.e., IM <= Bound * (OM + 1) - 1.
+  Constr = isl_constraint_alloc_inequality(ConstrSpace);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, DimNum, -1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, DimNum, IntBound);
+  Constr = isl_constraint_set_constant_si(Constr, IntBound - 1);
+  return isl_map_add_constraint(ExtMap, Constr);
+}
+
+/// @brief Create an extension map that is specific for the matrix
+///        multiplication pattern.
+///
+/// Create a map of the form
+///   { [O0, O1, O2] -> [I0, I1, I2] : Bk * Ok <= Ik <= Bk * (Ok + 1) - 1 }
+/// for k = 0, 1, 2, where B0, B1 and B2 are @p FirstOutputDimBound,
+/// @p SecondOutputDimBound and @p ThirdOutputDimBound. For a zero bound Bk,
+/// the constraint on Ik is replaced by Ik = 0.
+///
+/// @param Ctx                  The isl context.
+/// @param FirstOutputDimBound  The bound B0 of the first output dimension.
+/// @param SecondOutputDimBound The bound B1 of the second output dimension.
+/// @param ThirdOutputDimBound  The bound B2 of the third output dimension.
+/// @return The specified map.
+static __isl_give isl_map *getMatMulExt(isl_ctx *Ctx,
+                                        unsigned FirstOutputDimBound,
+                                        unsigned SecondOutputDimBound,
+                                        unsigned ThirdOutputDimBound) {
+  auto *ExtensionMap = isl_map_universe(isl_space_alloc(Ctx, 0, 3, 3));
+  const unsigned Bounds[] = {FirstOutputDimBound, SecondOutputDimBound,
+                             ThirdOutputDimBound};
+  for (unsigned Dim = 0; Dim < 3; ++Dim) {
+    if (Bounds[Dim] == 0)
+      ExtensionMap = isl_map_fix_si(ExtensionMap, isl_dim_out, Dim, 0);
+    else
+      ExtensionMap =
+          addExtensionMapMatMulDimConstraint(ExtensionMap, Dim, Bounds[Dim]);
+  }
+  return ExtensionMap;
 }
 
 /// @brief Create an access relation that is specific for matrix multiplication
@@ -692,6 +763,35 @@
   return isl_map_apply_range(MapOldIndVar, AccessRel);
 }
 
+/// @brief Insert an extension node that describes a copy statement for
+///        @p MemA.
+///
+/// The extension node is grafted before @p Node.
+///
+/// @param Node         The schedule node to be modified.
+/// @param ExtensionMap The isl map that describes the schedule of the copy
+///                     statement; its output tuple id is set to the domain
+///                     id of the new copy statement.
+/// @param MemA         The memory access that should be packed.
+/// @return The modified schedule node if @p MemA is not a nullptr and
+///         @p Node otherwise.
+static __isl_give isl_schedule_node *
+createCopyStmt(__isl_take isl_schedule_node *Node,
+               __isl_take isl_map *ExtensionMap, MemoryAccess *MemA) {
+  if (!MemA) {
+    isl_map_free(ExtensionMap);
+    return Node;
+  }
+  // NOTE(review): the new statement is not added to the Scop. It is deleted
+  // by IslNodeBuilder::generateCopyStmt and leaks if no code is generated.
+  auto *NewStmt = new ScopStmt(*MemA->getStatement()->getParent(), MemA);
+  ExtensionMap =
+      isl_map_set_tuple_id(ExtensionMap, isl_dim_out, NewStmt->getDomainId());
+  auto *Extension = isl_union_map_from_map(ExtensionMap);
+  auto *NewNode = isl_schedule_node_from_extension(Extension);
+  return isl_schedule_node_graft_before(Node, NewNode);
+}
+
 /// @brief Set a ScopArrayInfo memory object of @p AccRel to a new one with
 ///        @p Name and @p DimSizes.
 /// @param AccRel The access relation to be modified.
@@ -724,9 +824,11 @@
 /// @param MicKerParams, MacKerParams Parameters of the BLIS kernel
 ///                                   to be taken into account.
+/// @param Node The schedule node relative to which the copy statements are
+///             inserted.
 /// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
-                                             MicroKernelParamsTy MicKerParams,
-                                             MacroKernelParamsTy MacKerParams) {
+static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
+    __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
+    MicroKernelParamsTy MicKerParams, MacroKernelParamsTy MacKerParams) {
   auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
   auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
   isl_id_free(InputDimsId);
@@ -736,14 +838,32 @@
   unsigned FirstDimSize = MacKerParams.Mc * MacKerParams.Kc / MicKerParams.Mr;
   unsigned SecondDimSize = MicKerParams.Mr;
   AccRel = setScopInfo(AccRel, {FirstDimSize, SecondDimSize}, "Packed_A");
-  replaceMemAccRelation(Stmt, OriginalRel, AccRel);
+  auto *MemA = replaceMemAccRelation(Stmt, OriginalRel, AccRel);
+  // Go five levels up from Node, split the band found there after its first
+  // two members and graft the copy statement for Packed_A before the band
+  // that holds the remaining members.
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = isl_schedule_node_parent(Node);
+  Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0);
+  auto *ExtMap =
+      getMatMulExt(Stmt->getIslCtx(), MacKerParams.Mc, 0, MacKerParams.Kc);
+  ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
+  Node = isl_schedule_node_child(createCopyStmt(Node, ExtMap, MemA), 0);
   AccRel = getMatMulAccRel(MapOldIndVar, MacKerParams.Kc, 4, 7);
   OriginalRel = getMatMulPatternOriginalAccessRelation(Stmt, 2, 1);
   FirstDimSize = MacKerParams.Nc * MacKerParams.Kc / MicKerParams.Nr;
   SecondDimSize = MicKerParams.Nr;
   AccRel = setScopInfo(AccRel, {FirstDimSize, SecondDimSize}, "Packed_B");
-  replaceMemAccRelation(Stmt, OriginalRel, AccRel);
-  return;
+  MemA = replaceMemAccRelation(Stmt, OriginalRel, AccRel);
+  ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacKerParams.Nc, MacKerParams.Kc);
+  // Reorder the input dimensions from [O0, O1, O2] to [O0, O2, O1]. Both
+  // calls consume ExtMap, so their results must be assigned back.
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1);
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
+  Node = createCopyStmt(Node, ExtMap, MemA);
+  Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
+  return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
 }
 
 /// @brief Get a relation mapping induction variables produced by schedule
@@ -788,9 +908,8 @@
       Node, MicroKernelParams, MacroKernelParams);
   if (!MapOldIndVar)
     return Node;
-  optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
-                                   MacroKernelParams);
-  return Node;
+  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
+                                          MacroKernelParams);
 }
 
 bool ScheduleTreeOptimizer::isMatrMultPattern(
@@ -1036,18 +1155,36 @@
   auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   isl_schedule *NewSchedule =
       ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
-  isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
-  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
-    isl_union_map_free(NewScheduleMap);
-    isl_schedule_free(NewSchedule);
-    return false;
+  // NOTE(review): the copy statements of the matrix multiplication
+  // optimization are introduced through extension nodes, which
+  // isl_schedule_get_map presumably cannot handle. Only such schedules skip
+  // the profitability check; all other schedules are still checked.
+  bool HasExtensionNode = false;
+  isl_schedule_foreach_schedule_node_top_down(
+      NewSchedule,
+      [](__isl_keep isl_schedule_node *Node, void *User) -> isl_bool {
+        if (isl_schedule_node_get_type(Node) != isl_schedule_node_extension)
+          return isl_bool_true;
+        *static_cast<bool *>(User) = true;
+        return isl_bool_false;
+      },
+      &HasExtensionNode);
+
+  if (!HasExtensionNode) {
+    isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
+    bool IsProfitable =
+        ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap);
+    isl_union_map_free(NewScheduleMap);
+    if (!IsProfitable) {
+      isl_schedule_free(NewSchedule);
+      return false;
+    }
   }
 
   S.setScheduleTree(NewSchedule);
   S.markAsOptimized();
 
-  isl_union_map_free(NewScheduleMap);
   return false;
 }
 
Index: test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
===================================================================
--- test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
+++ test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
@@ -17,23 +17,48 @@
 ; CHECK: %Packed_B = alloca [3072 x [8 x double]]
 ; CHECK: br label %polly.split_new_and_old
 ;
-; CHECK:polly.stmt.bb14398: ; preds = %polly.stmt.bb14379
-; CHECK: %arg3.s2a.reload399 = load double, double* %arg3.s2a
-; CHECK: %polly.access.cast.Packed_A400 = bitcast [1024 x [4 x double]]* %Packed_A to double*
-; CHECK: %243 = mul nsw i64 256, %polly.indvar95
-; CHECK: %244 = add nsw i64 %243, %polly.indvar107
-; CHECK: %polly.access.add.Packed_A401 = add nsw i64 0, %244
-; CHECK: %polly.access.mul.Packed_A402 = mul nsw i64 %polly.access.add.Packed_A401, 4
-; CHECK: %polly.access.add.Packed_A403 = add nsw i64 %polly.access.mul.Packed_A402, 2
-; CHECK: %polly.access.Packed_A404 = getelementptr double, double* %polly.access.cast.Packed_A400, i64 %polly.access.add.Packed_A403
-; CHECK: %tmp17_p_scalar_405 = load double, double* %polly.access.Packed_A404, align 8
-; CHECK: %p_tmp18406 = fmul double %tmp17_p_scalar_405, %arg3.s2a.reload399
-; CHECK: %polly.access.cast.Packed_B407 = bitcast [3072 x [8 x double]]* %Packed_B to double*
-; CHECK %245 = mul nsw i64 256, %polly.indvar101
-; CHECK %246 = add nsw i64 %245, %polly.indvar107
-; CHECK %polly.access.add.Packed_B408 = add nsw i64 0, %246
-; CHECK %polly.access.mul.Packed_B409 = mul nsw i64 %polly.access.add.Packed_B408, 8
-; CHECK %polly.access.add.Packed_B410 = add nsw i64 %polly.access.mul.Packed_B409, 0
+; CHECK: %polly.access.cast.arg6101 = bitcast [1024 x double]* %arg6 to double*
+; CHECK: %polly.access.mul.arg6102 = mul nsw i64 %polly.indvar89, 1024
+; CHECK: %polly.access.add.arg6103 = add nsw i64 %polly.access.mul.arg6102, %polly.indvar97
+; CHECK: %polly.access.arg6104 = getelementptr double, double* %polly.access.cast.arg6101, i64 %polly.access.add.arg6103
+; CHECK: %polly.access.arg6104.load = load double, double* %polly.access.arg6104
+; CHECK: %polly.access.cast.Packed_A = bitcast [1024 x [4 x double]]* %Packed_A to double*
+; CHECK: %26 = add nsw i64 %polly.indvar89, 4
+; CHECK: %pexp.pdiv_r = urem i64 %26, 4
+; CHECK: %27 = mul nsw i64 64, %pexp.pdiv_r
+; CHECK: %28 = sub nsw i64 0, %27
+; CHECK: %29 = mul nsw i64 1024, %polly.indvar77
+; CHECK: %30 = sub nsw i64 %28, %29
+; CHECK: %31 = mul nsw i64 256, %polly.indvar83
+; CHECK: %32 = sub nsw i64 %30, %31
+; CHECK: %33 = mul nsw i64 64, %polly.indvar89
+; CHECK: %34 = add nsw i64 %32, %33
+; CHECK: %35 = add nsw i64 %34, %polly.indvar97
+; CHECK: %polly.access.add.Packed_A = add nsw i64 0, %35
+; CHECK: %polly.access.mul.Packed_A = mul nsw i64 %polly.access.add.Packed_A, 4
+; CHECK: %36 = add nsw i64 %polly.indvar89, 4
+; CHECK: %pexp.pdiv_r105 = urem i64 %36, 4
+; CHECK: %polly.access.add.Packed_A106 = add nsw i64 %polly.access.mul.Packed_A, %pexp.pdiv_r105
+; CHECK: %polly.access.Packed_A = getelementptr double, double* %polly.access.cast.Packed_A, i64 %polly.access.add.Packed_A106
+; CHECK: store double %polly.access.arg6104.load, double* %polly.access.Packed_A
+;
+; CHECK:polly.stmt.bb14: ; preds = %polly.loop_header150
+; CHECK %arg3.s2a.reload = load double, double* %arg3.s2a
+; CHECK %polly.access.cast.Packed_A156 = bitcast [1024 x [4 x double]]* %Packed_A to double*
+; CHECK %62 = mul nsw i64 256, %polly.indvar141
+; CHECK %63 = add nsw i64 %62, %polly.indvar153
+; CHECK %polly.access.add.Packed_A157 = add nsw i64 0, %63
+; CHECK %polly.access.mul.Packed_A158 = mul nsw i64 %polly.access.add.Packed_A157, 4
+; CHECK %polly.access.add.Packed_A159 = add nsw i64 %polly.access.mul.Packed_A158, 0
+; CHECK %polly.access.Packed_A160 = getelementptr double, double* %polly.access.cast.Packed_A156, i64 %polly.access.add.Packed_A159
+; CHECK %tmp17_p_scalar_ = load double, double* %polly.access.Packed_A160, align 8
+; CHECK %p_tmp18 = fmul double %tmp17_p_scalar_, %arg3.s2a.reload
+; CHECK %polly.access.cast.Packed_B161 = bitcast [3072 x [8 x double]]* %Packed_B to double*
+; CHECK %64 = mul nsw i64 256, %polly.indvar147
+; CHECK %65 = add nsw i64 %64, %polly.indvar153
+; CHECK %polly.access.add.Packed_B162 = add nsw i64 0, %65
+; CHECK %polly.access.mul.Packed_B163 = mul nsw i64 %polly.access.add.Packed_B162, 8
+; CHECK %polly.access.add.Packed_B164 = add nsw i64 %polly.access.mul.Packed_B163, 0
 ;
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-unknown"