Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp
@@ -618,6 +618,212 @@
   return {Mc, Nc, Kc};
 }
 
+/// @brief Identify a memory access through the shape of its memory access
+/// relation.
+///
+/// Identify the unique memory access in @p Stmt whose access relation is
+/// equal to @p ExpectedAccessRelation once output tuple ids are ignored.
+///
+/// @param Stmt The SCoP statement that contains the memory accesses under
+/// consideration.
+/// @param ExpectedAccessRelation The access relation that identifies
+/// the memory access. It is consumed by this function.
+/// @return The memory access of @p Stmt whose memory access relation is equal
+/// to @p ExpectedAccessRelation (an isl error never counts as equal); nullptr
+/// in case there is no or more than one such access.
+MemoryAccess *
+identifyAccessByAccessRelation(ScopStmt *Stmt,
+                               __isl_take isl_map *ExpectedAccessRelation) {
+  if (isl_map_has_tuple_id(ExpectedAccessRelation, isl_dim_out))
+    ExpectedAccessRelation =
+        isl_map_reset_tuple_id(ExpectedAccessRelation, isl_dim_out);
+  MemoryAccess *IdentifiedAccess = nullptr;
+  for (auto *Access : *Stmt) {
+    auto *AccessRelation = Access->getAccessRelation();
+    AccessRelation = isl_map_reset_tuple_id(AccessRelation, isl_dim_out);
+    auto IsEqual = isl_map_is_equal(ExpectedAccessRelation, AccessRelation);
+    isl_map_free(AccessRelation);
+    if (IsEqual != isl_bool_true)
+      continue;
+    if (IdentifiedAccess) {
+      isl_map_free(ExpectedAccessRelation);
+      return nullptr;
+    }
+    IdentifiedAccess = Access;
+  }
+  isl_map_free(ExpectedAccessRelation);
+  return IdentifiedAccess;
+}
+
+/// @brief Create an access relation that is specific to the matrix
+/// multiplication pattern.
+///
+/// Create an access relation of the following form:
+/// Stmt[O0, O1, O2]->[OI, OJ],
+/// where I is @p I and J is @p J.
+///
+/// @param Stmt The SCoP statement for which to generate the access relation.
+/// @param I The index of the input dimension that is mapped to the first output
+/// dimension.
+/// @param J The index of the input dimension that is mapped to the second
+/// output dimension.
+/// @return The specified access relation.
+__isl_give isl_map *
+getMatMulPatternOriginalAccessRelation(ScopStmt *Stmt, unsigned I, unsigned J) {
+  auto *AccessRelSpace = isl_space_alloc(Stmt->getIslCtx(), 0, 3, 2);
+  auto *AccessRel = isl_map_universe(AccessRelSpace);
+  AccessRel = isl_map_equate(AccessRel, isl_dim_in, I, isl_dim_out, 0);
+  AccessRel = isl_map_equate(AccessRel, isl_dim_in, J, isl_dim_out, 1);
+  AccessRel = isl_map_set_tuple_id(AccessRel, isl_dim_in, Stmt->getDomainId());
+  return AccessRel;
+}
+
+/// @brief Identify the memory access that corresponds to the access
+/// to the first operand of the matrix multiplication.
+///
+/// Identify the memory access that corresponds to the access to the matrix A
+/// of C = A x B, i.e., the access with relation Stmt[O0, O1, O2] -> [O0, O2].
+///
+/// @param Stmt The SCoP statement that contains the memory accesses
+/// under consideration.
+/// @return The memory access of @p Stmt that corresponds to the access
+/// to the first operand of the matrix multiplication, or nullptr.
+MemoryAccess *identifyAccessA(ScopStmt *Stmt) {
+  auto *OriginalRel = getMatMulPatternOriginalAccessRelation(Stmt, 0, 2);
+  return identifyAccessByAccessRelation(Stmt, OriginalRel);
+}
+
+/// @brief Identify the memory access that corresponds to the access
+/// to the second operand of the matrix multiplication.
+///
+/// Identify the memory access that corresponds to the access to the matrix B
+/// of C = A x B, i.e., the access with relation Stmt[O0, O1, O2] -> [O2, O1].
+///
+/// @param Stmt The SCoP statement that contains the memory accesses
+/// under consideration.
+/// @return The memory access of @p Stmt that corresponds to the access
+/// to the second operand of the matrix multiplication, or nullptr.
+MemoryAccess *identifyAccessB(ScopStmt *Stmt) {
+  return identifyAccessByAccessRelation(
+      Stmt, getMatMulPatternOriginalAccessRelation(Stmt, 2, 1));
+}
+
+/// @brief Create an access relation that is specific to
+/// the matrix multiplication pattern.
+///
+/// The created relation has the form
+/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [0, O5 + K * OI, OJ],
+/// with K = @p Coeff, I = @p FirstDim and J = @p SecondDim.
+///
+/// Such relations help to access the elements of the operands of a matrix
+/// multiplication consecutively once the BLIS micro and macro kernels have
+/// been created.
+///
+/// @see ScheduleTreeOptimizer::createMicroKernel
+/// @see ScheduleTreeOptimizer::createMacroKernel
+///
+/// The relation is then applied to the range of @p MapOldIndVar, which maps
+/// the original induction variables to the ones produced by the schedule
+/// transformations. This way the relation is expressed in terms of the new
+/// schedule dimensions while its domain stays in the original iteration
+/// space.
+///
+/// @param MapOldIndVar The relation mapping the original induction variables
+///                     to the ones produced by the schedule transformations;
+///                     it is consumed.
+/// @param Coeff The coefficient K used to define the specified access
+///              relation.
+/// @param FirstDim, SecondDim The input dimensions I and J used to define
+///                            the specified access relation.
+/// @return The composition of @p MapOldIndVar with the specified relation.
+__isl_give isl_map *getMatMulAccRel(__isl_take isl_map *MapOldIndVar,
+                                    unsigned Coeff, unsigned FirstDim,
+                                    unsigned SecondDim) {
+  auto *Ctx = isl_map_get_ctx(MapOldIndVar);
+  auto *AccessRelSpace = isl_space_alloc(Ctx, 0, 9, 3);
+  auto *AccessRel = isl_map_universe(isl_space_copy(AccessRelSpace));
+  auto *ConstrSpace = isl_local_space_from_space(AccessRelSpace);
+  auto *Constr = isl_constraint_alloc_equality(ConstrSpace);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, 1, -1);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, 5, 1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, FirstDim, Coeff);
+  AccessRel = isl_map_add_constraint(AccessRel, Constr);
+  AccessRel = isl_map_fix_si(AccessRel, isl_dim_out, 0, 0);
+  AccessRel = isl_map_equate(AccessRel, isl_dim_in, SecondDim, isl_dim_out, 2);
+  return isl_map_apply_range(MapOldIndVar, AccessRel);
+}
+
+/// @brief Apply the packing transformation.
+///
+/// The packing transformation can be described as a data-layout
+/// transformation that requires to introduce a new array, copy data
+/// to the array, and change memory access locations of the compute kernel
+/// to reference the array. Here, only the arrays Packed_A and Packed_B are
+/// created and the accesses to A and B are redirected to them; no copy code
+/// is emitted. Nothing is changed if these accesses cannot be identified.
+///
+/// @param MapOldIndVar The relation, which maps original induction variables
+///                     to the ones, which are produced by schedule
+///                     transformations. It is consumed.
+/// @param MicroParams, MacroParams Parameters of the BLIS kernel
+///                                 to be taken into account.
+static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
+                                             MicroKernelParamsTy MicroParams,
+                                             MacroKernelParamsTy MacroParams) {
+  auto *InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
+  auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
+  isl_id_free(InputDimsId);
+  MemoryAccess *MemAccessA = identifyAccessA(Stmt);
+  MemoryAccess *MemAccessB = identifyAccessB(Stmt);
+  if (!MemAccessA || !MemAccessB) {
+    isl_map_free(MapOldIndVar);
+    return;
+  }
+  auto *AccRel =
+      getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
+  unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
+  unsigned SecondDimSize = MicroParams.Mr;
+  auto *SAI = Stmt->getParent()->createScopArrayInfo(
+      MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
+  AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  MemAccessA->setNewAccessRelation(AccRel);
+  AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
+  FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
+  SecondDimSize = MicroParams.Nr;
+  SAI = Stmt->getParent()->createScopArrayInfo(
+      MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
+  AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  MemAccessB->setNewAccessRelation(AccRel);
+}
+
+/// @brief Get a relation mapping the original induction variables to the ones
+/// produced by schedule transformations.
+///
+/// @param Node The schedule node produced as the result of creation
+/// of the BLIS kernels. It is not freed by this function.
+/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
+/// (currently not used).
+/// @return The relation mapping original induction variables to the ones
+/// produced by schedule transformations, keeping at most the innermost nine.
+/// @see ScheduleTreeOptimizer::createMicroKernel
+/// @see ScheduleTreeOptimizer::createMacroKernel
+/// @see getMacroKernelParams
+__isl_give isl_map *
+getInductionVariablesSubstitution(__isl_keep isl_schedule_node *Node,
+                                  MicroKernelParamsTy MicroKernelParams,
+                                  MacroKernelParamsTy MacroKernelParams) {
+  auto *Child = isl_schedule_node_get_child(Node, 0);
+  auto *UnMapOldIndVar = isl_schedule_node_get_prefix_schedule_union_map(Child);
+  isl_schedule_node_free(Child);
+  auto *MapOldIndVar = isl_map_from_union_map(UnMapOldIndVar);
+  if (isl_map_dim(MapOldIndVar, isl_dim_out) > 9)
+    MapOldIndVar =
+        isl_map_project_out(MapOldIndVar, isl_dim_out, 0,
+                            isl_map_dim(MapOldIndVar, isl_dim_out) - 9);
+  return MapOldIndVar;
+}
+
 __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
     __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
   assert(TTI && "The target transform info should be provided.");
@@ -625,6 +831,15 @@
   auto MacroKernelParams = getMacroKernelParams(MicroKernelParams);
   Node = createMacroKernel(Node, MacroKernelParams);
   Node = createMicroKernel(Node, MicroKernelParams);
+  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
+      MacroKernelParams.Kc == 1)
+    return Node;
+  auto *MapOldIndVar = getInductionVariablesSubstitution(
+      Node, MicroKernelParams, MacroKernelParams);
+  if (!MapOldIndVar)
+    return Node;
+  optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
+                                   MacroKernelParams);
   return Node;
 }
 
Index: polly/trunk/test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
===================================================================
--- polly/trunk/test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
+++ polly/trunk/test/Isl/CodeGen/MemAccess/mat_mul_pattern_data_layout.ll
@@ -0,0 +1,86 @@
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -polly-target-througput-vector-fma=1 -polly-target-latency-vector-fma=8 -polly-target-cache-level-associativity=8,8 -polly-target-cache-level-sizes=32768,262144 -polly-codegen -S < %s 2>&1 | FileCheck %s
+;
+;    /* C := alpha*A*B + beta*C */
+;    for (i = 0; i < _PB_NI; i++)
+;      for (j = 0; j < _PB_NJ; j++)
+;        {
+;          C[i][j] *= beta;
+;          for (k = 0; k < _PB_NK; ++k)
+;            C[i][j] += alpha * A[i][k] * B[k][j];
+;        }
+;
+; CHECK:define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) #0 {
+; CHECK:bb:
+; CHECK: %arg3.s2a = alloca double
+; CHECK: %arg4.s2a = alloca double
+; CHECK: %Packed_A = alloca [1024 x [4 x double]]
+; CHECK: %Packed_B = alloca [3072 x [8 x double]]
+; CHECK: br label %polly.split_new_and_old
+;
+; CHECK:polly.stmt.bb14398: ; preds = %polly.stmt.bb14379
+; CHECK: %arg3.s2a.reload399 = load double, double* %arg3.s2a
+; CHECK: %polly.access.cast.Packed_A400 = bitcast [1024 x [4 x double]]* %Packed_A to double*
+; CHECK: %243 = mul nsw i64 256, %polly.indvar95
+; CHECK: %244 = add nsw i64 %243, %polly.indvar107
+; CHECK: %polly.access.add.Packed_A401 = add nsw i64 0, %244
+; CHECK: %polly.access.mul.Packed_A402 = mul nsw i64 %polly.access.add.Packed_A401, 4
+; CHECK: %polly.access.add.Packed_A403 = add nsw i64 %polly.access.mul.Packed_A402, 2
+; CHECK: %polly.access.Packed_A404 = getelementptr double, double* %polly.access.cast.Packed_A400, i64 %polly.access.add.Packed_A403
+; CHECK: %tmp17_p_scalar_405 = load double, double* %polly.access.Packed_A404, align 8
+; CHECK: %p_tmp18406 = fmul double %tmp17_p_scalar_405, %arg3.s2a.reload399
+; CHECK: %polly.access.cast.Packed_B407 = bitcast [3072 x [8 x double]]* %Packed_B to double*
+; CHECK: %245 = mul nsw i64 256, %polly.indvar101
+; CHECK: %246 = add nsw i64 %245, %polly.indvar107
+; CHECK: %polly.access.add.Packed_B408 = add nsw i64 0, %246
+; CHECK: %polly.access.mul.Packed_B409 = mul nsw i64 %polly.access.add.Packed_B408, 8
+; CHECK: %polly.access.add.Packed_B410 = add nsw i64 %polly.access.mul.Packed_B409, 0
+;
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-unknown"
+
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) #0 { ; C := alpha*A*B + beta*C with alpha = %arg3, beta = %arg4, C = %arg5, A = %arg6, B = %arg7
+bb:
+  br label %bb8
+
+bb8:                                              ; preds = %bb29, %bb
+  %tmp = phi i64 [ 0, %bb ], [ %tmp30, %bb29 ]    ; i
+  br label %bb9
+
+bb9:                                              ; preds = %bb26, %bb8
+  %tmp10 = phi i64 [ 0, %bb8 ], [ %tmp27, %bb26 ] ; j
+  %tmp11 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp, i64 %tmp10 ; &C[i][j]
+  %tmp12 = load double, double* %tmp11, align 8
+  %tmp13 = fmul double %tmp12, %arg4              ; C[i][j] * beta
+  store double %tmp13, double* %tmp11, align 8
+  br label %bb14
+
+bb14:                                             ; preds = %bb14, %bb9
+  %tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %bb14 ] ; k
+  %tmp16 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp, i64 %tmp15 ; &A[i][k]
+  %tmp17 = load double, double* %tmp16, align 8
+  %tmp18 = fmul double %tmp17, %arg3              ; alpha * A[i][k]
+  %tmp19 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp15, i64 %tmp10 ; &B[k][j]
+  %tmp20 = load double, double* %tmp19, align 8
+  %tmp21 = fmul double %tmp18, %tmp20
+  %tmp22 = load double, double* %tmp11, align 8
+  %tmp23 = fadd double %tmp22, %tmp21             ; C[i][j] + alpha * A[i][k] * B[k][j]
+  store double %tmp23, double* %tmp11, align 8
+  %tmp24 = add nuw nsw i64 %tmp15, 1
+  %tmp25 = icmp ne i64 %tmp24, 1024               ; k-loop runs 1024 iterations
+  br i1 %tmp25, label %bb14, label %bb26
+
+bb26:                                             ; preds = %bb14
+  %tmp27 = add nuw nsw i64 %tmp10, 1
+  %tmp28 = icmp ne i64 %tmp27, 1056               ; j-loop runs 1056 iterations
+  br i1 %tmp28, label %bb9, label %bb29
+
+bb29:                                             ; preds = %bb26
+  %tmp30 = add nuw nsw i64 %tmp, 1
+  %tmp31 = icmp ne i64 %tmp30, 1056               ; i-loop runs 1056 iterations
+  br i1 %tmp31, label %bb8, label %bb32
+
+bb32:                                             ; preds = %bb29
+  ret void
+}
+
+attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }