Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp =================================================================== --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp @@ -785,8 +785,8 @@ /// the matrix multiplication pattern. /// /// Create an access relation of the following form: -/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [O5 + K * OI, OJ], -/// where K is @p Coeff, I is @p FirstDim, J is @p SecondDim. +/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ] +/// where I is @p FirstDim, J is @p SecondDim. /// /// It can be used, for example, to create relations that helps to consequently /// access elements of operands of a matrix multiplication after creation of @@ -804,25 +804,17 @@ /// @param MapOldIndVar The relation, which maps original induction variables /// to the ones, which are produced by schedule /// transformations. -/// @param Coeff The coefficient that is used to define the specified access -/// relation. /// @param FirstDim, SecondDim The input dimensions that are used to define /// the specified access relation. /// @return The specified access relation. __isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar, - unsigned Coeff, unsigned FirstDim, - unsigned SecondDim) { + unsigned FirstDim, unsigned SecondDim) { auto *Ctx = isl_map_get_ctx(MapOldIndVar); - auto *AccessRelSpace = isl_space_alloc(Ctx, 0, 9, 2); - auto *AccessRel = isl_map_universe(isl_space_copy(AccessRelSpace)); - auto *ConstrSpace = isl_local_space_from_space(AccessRelSpace); - auto *Constr = isl_constraint_alloc_equality(ConstrSpace); - Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, 0, -1); - Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, 5, 1); - Constr = - isl_constraint_set_coefficient_si(Constr, isl_dim_in, FirstDim, Coeff); - AccessRel = isl_map_add_constraint(AccessRel, Constr); - AccessRel = isl_map_equate(AccessRel, isl_dim_in, SecondDim, isl_dim_out, 1); + auto *AccessRelSpace = isl_space_alloc(Ctx, 0, 9, 3); + auto *AccessRel = isl_map_universe(AccessRelSpace); + AccessRel = isl_map_equate(AccessRel, isl_dim_in, FirstDim, isl_dim_out, 0); + AccessRel = isl_map_equate(AccessRel, isl_dim_in, 5, isl_dim_out, 1); + AccessRel = isl_map_equate(AccessRel, isl_dim_in, SecondDim, isl_dim_out, 2); return isl_map_apply_range(MapOldIndVar, AccessRel); } @@ -869,12 +861,13 @@ Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); Node = isl_schedule_node_parent(Node); Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0); - auto *AccRel = - getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 7); - unsigned FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr; - unsigned SecondDimSize = MicroParams.Nr; + auto *AccRel = getMatMulAccRel(isl_map_copy(MapOldIndVar), 3, 7); + unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr; + unsigned SecondDimSize = MacroParams.Kc; + unsigned ThirdDimSize = MicroParams.Nr; auto *SAI = Stmt->getParent()->createScopArrayInfo( - MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize}); + MemAccessB->getElementType(), "Packed_B", + {FirstDimSize, SecondDimSize, ThirdDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); auto *OldAcc = MemAccessB->getAccessRelation(); MemAccessB->setNewAccessRelation(AccRel); @@ -898,11 +891,12 @@ // Create a copy statement that corresponds to the memory access // to the matrix A, the first operand of the matrix multiplication. Node = isl_schedule_node_child(Node, 0); - AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 6); - FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr; - SecondDimSize = MicroParams.Mr; + AccRel = getMatMulAccRel(MapOldIndVar, 4, 6); + FirstDimSize = MacroParams.Mc / MicroParams.Mr; + ThirdDimSize = MicroParams.Mr; SAI = Stmt->getParent()->createScopArrayInfo( - MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize}); + MemAccessA->getElementType(), "Packed_A", + {FirstDimSize, SecondDimSize, ThirdDimSize}); AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId()); OldAcc = MemAccessA->getAccessRelation(); MemAccessA->setNewAccessRelation(AccRel); Index: polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll =================================================================== --- polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll +++ polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll @@ -9,14 +9,14 @@ ; C[i][j] += alpha * A[i][k] * B[k][j]; ; } ; -; CHECK: double Packed_B[ { [] -> [(512)] } ][ { [] -> [(8)] } ]; // Element size 8 -; CHECK-NEXT: double Packed_A[ { [] -> [(6144)] } ][ { [] -> [(4)] } ]; // Element size 8 +; CHECK: double Packed_B[ { [] -> [(2)] } ][ { [] -> [(256)] } ][ { [] -> [(8)] } ]; // Element size 8 +; CHECK-NEXT: double Packed_A[ { [] -> [(24)] } ][ { [] -> [(256)] } ][ { [] -> [(4)] } ]; // Element size 8 ; ; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] }; -; CHECK-NEXT: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 96*floor((i0)/96) <= 4*floor((o0)/256) <= i0 - 96*floor((i0)/96) }; +; CHECK-NEXT: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 4*floor((-i0 + o2)/4) = -i0 + o2 and 0 <= o1 <= 255 and 0 <= o2 <= 3 and -3 + i0 - 4o0 <= 96*floor((i0)/96) <= i0 - 4o0 }; ; ; CHECK: { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg7[i2, i1] }; -; CHECK-NEXT: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 16*floor((i1)/16) <= 8*floor((o0)/256) <= i1 - 16*floor((i1)/16) }; +; CHECK-NEXT: new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 8*floor((-i1 + o2)/8) = -i1 + o2 and 0 <= o1 <= 255 and 0 <= o2 <= 7 and -7 + i1 - 8o0 <= 16*floor((i1)/16) <= i1 - 8o0 }; ; ; CHECK: CopyStmt_0 ; CHECK-NEXT: Domain := @@ -25,7 +25,7 @@ ; CHECK-NEXT: ; ; CHECK-NEXT: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] ; CHECK-NEXT: null; -; CHECK-NEXT: new: { CopyStmt_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 16*floor((i1)/16) <= 8*floor((o0)/256) <= i1 - 16*floor((i1)/16) }; +; CHECK-NEXT: new: { CopyStmt_0[i0, i1, i2] -> Packed_B[o0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 8*floor((-i1 + o2)/8) = -i1 + o2 and 0 <= o1 <= 255 and 0 <= o2 <= 7 and -7 + i1 - 8o0 <= 16*floor((i1)/16) <= i1 - 8o0 }; ; CHECK-NEXT: ReadAccess := [Reduction Type: NONE] [Scalar: 0] ; CHECK-NEXT: null; ; CHECK-NEXT: new: { CopyStmt_0[i0, i1, i2] -> MemRef_arg7[i2, i1] }; @@ -36,7 +36,7 @@ ; CHECK-NEXT: ; ; CHECK-NEXT: MustWriteAccess := [Reduction Type: NONE] [Scalar: 0] ; CHECK-NEXT: null; -; CHECK-NEXT: new: { CopyStmt_1[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 96*floor((i0)/96) <= 4*floor((o0)/256) <= i0 - 96*floor((i0)/96) }; +; CHECK-NEXT: new: { CopyStmt_1[i0, i1, i2] -> Packed_A[o0, o1, o2] : 256*floor((-i2 + o1)/256) = -i2 + o1 and 4*floor((-i0 + o2)/4) = -i0 + o2 and 0 <= o1 <= 255 and 0 <= o2 <= 3 and -3 + i0 - 4o0 <= 96*floor((i0)/96) <= i0 - 4o0 }; ; CHECK-NEXT: ReadAccess := [Reduction Type: NONE] [Scalar: 0] ; CHECK-NEXT: null; ; CHECK-NEXT: new: { CopyStmt_1[i0, i1, i2] -> MemRef_arg6[i0, i2] };