Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp =================================================================== --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp @@ -483,61 +483,6 @@ return Node; } -/// Get the position of a dimension with a non-zero coefficient. -/// -/// Check that isl constraint @p Constraint has only one non-zero -/// coefficient for dimensions that have type @p DimType. If this is true, -/// return the position of the dimension corresponding to the non-zero -/// coefficient and negative value, otherwise. -/// -/// @param Constraint The isl constraint to be checked. -/// @param DimType The type of the dimensions. -/// @return The position of the dimension in case the isl -/// constraint satisfies the requirements, a negative -/// value, otherwise. -static int getMatMulConstraintDim(isl::constraint Constraint, - isl::dim DimType) { - int DimPos = -1; - auto LocalSpace = Constraint.get_local_space(); - int LocalSpaceDimNum = LocalSpace.dim(DimType); - for (int i = 0; i < LocalSpaceDimNum; i++) { - auto Val = Constraint.get_coefficient_val(DimType, i); - if (Val.is_zero()) - continue; - if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) || - (DimType == isl::dim::in && !Val.is_negone())) - return -1; - DimPos = i; - } - return DimPos; -} - -/// Check the form of the isl constraint. -/// -/// Check that the @p DimInPos input dimension of the isl constraint -/// @p Constraint has a coefficient that is equal to negative one, the @p -/// DimOutPos has a coefficient that is equal to one and others -/// have coefficients equal to zero. -/// -/// @param Constraint The isl constraint to be checked. -/// @param DimInPos The input dimension of the isl constraint. -/// @param DimOutPos The output dimension of the isl constraint. -/// @return isl_stat_ok in case the isl constraint satisfies -/// the requirements, isl_stat_error otherwise. 
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint, - int &DimInPos, int &DimOutPos) { - auto Val = Constraint.get_constant_val(); - if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero()) - return isl_stat_error; - DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in); - if (DimInPos < 0) - return isl_stat_error; - DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out); - if (DimOutPos < 0) - return isl_stat_error; - return isl_stat_ok; -} - /// Permute the two dimensions of the isl map. /// /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that @@ -585,30 +530,49 @@ /// second output dimension. /// @return True in case @p AccMap has the expected form and false, /// otherwise. -static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) { - int DimInPos[] = {FirstPos, SecondPos}; - auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat { - auto Constraints = BasicMap.get_constraint_list(); - if (isl_constraint_list_n_constraint(Constraints.get()) != 2) - return isl::stat::error; - for (int i = 0; i < 2; i++) { - auto Constraint = - isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i)); - int InPos, OutPos; - if (isMatMulOperandConstraint(Constraint, InPos, OutPos) == - isl_stat_error || - OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos)) - return isl::stat::error; - DimInPos[OutPos] = InPos; - } - return isl::stat::ok; - }; - if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 || - DimInPos[1] < 0) +static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, + int &SecondPos) { + + isl::space Space = AccMap.get_space(); + isl::map Universe = isl::map::universe(Space); + + if (Space.dim(isl::dim::out) != 2) return false; - FirstPos = DimInPos[0]; - SecondPos = DimInPos[1]; - return true; + + // MatMul has the form: + // for (i = 0; i < N; i++) + // for (j = 0; j < M; j++) + // for (k = 0; k < P; 
k++) + // C[i, j] += A[i, k] * B[k, j] + // + // Permutation of three outer loops: 3! = 6 possibilities. + int FirstDims[] = {0, 0, 1, 1, 2, 2}; + int SecondDims[] = {1, 2, 2, 0, 0, 1}; + AccMap = AccMap.intersect_domain(Domain); + for (int i = 0; i < 6; i += 1) { + auto PossibleMatMul = + Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) + .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); + + PossibleMatMul = PossibleMatMul.intersect_domain(Domain); + + // Both maps are restricted to Domain here. Equality means the access has + // this candidate form on the whole domain. An access that matches no + // candidate (e.g. a partial write) is rejected after the loop. FirstPos + // and SecondPos are only written once both are known to be consistent, so + // a failed match leaves the caller's values unchanged. + if (AccMap.is_equal(PossibleMatMul)) { + if (FirstPos != -1 && FirstPos != FirstDims[i]) + continue; + if (SecondPos != -1 && SecondPos != SecondDims[i]) + continue; + FirstPos = FirstDims[i]; + SecondPos = SecondDims[i]; + return true; + } + } + + return false; } /// Does the memory access represent a non-scalar operand of the matrix @@ -627,18 +591,16 @@ if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) return false; auto AccMap = MemAccess->getLatestAccessRelation(); - if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC && - isl_map_n_basic_map(AccMap.get()) == 1) { + isl::set StmtDomain = MemAccess->getStatement()->getDomain(); + if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { MMI.ReadFromC = MemAccess; return true; } - if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A && - isl_map_n_basic_map(AccMap.get()) == 1) { + if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { MMI.A = MemAccess; return true; } - if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B && - isl_map_n_basic_map(AccMap.get()) == 1) { + if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { MMI.B = MemAccess; return true; } @@ -758,8 
+720,7 @@ if (!MemAccessPtr->isWrite()) return false; auto AccMap = MemAccessPtr->getLatestAccessRelation(); - if (isl_map_n_basic_map(AccMap.get()) != 1 || - !isMatMulOperandAcc(AccMap, MMI.i, MMI.j)) + if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j)) return false; MMI.WriteToC = MemAccessPtr; break; Index: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll =================================================================== --- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll +++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll @@ -0,0 +1,59 @@ +; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s +; REQUIRES: asserts +; +; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) { +; for (int i = 0; i < 2; i+=1) +; for (int j = 0; j < 2; j+=1) +; for (int k = 0; k < 784; k+=1) +; C[i][j] += A[i][k] * B[k][j]; +;} +; +; Check that the pattern matching detects the matrix multiplication pattern +; when the AccMap cannot be reduced to a single disjunct. 
+; +; CHECK: The matrix multiplication pattern was detected +; +; ModuleID = 'pattern_matching_based_opts_splitmap.ll' +; +; Function Attrs: noinline nounwind uwtable +define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) { +entry: + br label %for.body + +for.body: ; preds = %entry, %for.inc21 + %i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ] + br label %for.body3 + +for.body3: ; preds = %for.body, %for.inc18 + %j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ] + br label %for.body6 + +for.body6: ; preds = %for.body3, %for.body6 + %k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ] + %arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k + %tmp6 = load double, double* %arrayidx8, align 8 + %arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j + %tmp10 = load double, double* %arrayidx12, align 8 + %mul = fmul double %tmp6, %tmp10 + %arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j + %tmp14 = load double, double* %arrayidx16, align 8 + %add = fadd double %tmp14, %mul + store double %add, double* %arrayidx16, align 8 + %add17 = add nsw i64 %k, 1 + %cmp5 = icmp slt i64 %add17, 784 + br i1 %cmp5, label %for.body6, label %for.inc18 + +for.inc18: ; preds = %for.body6 + %add19 = add nsw i64 %j, 1 + %cmp2 = icmp slt i64 %add19, 2 + br i1 %cmp2, label %for.body3, label %for.inc21 + +for.inc21: ; preds = %for.inc18 + %add22 = add nsw i64 %i, 1 + %cmp = icmp slt i64 %add22, 2 + br i1 %cmp, label %for.body, label %for.end23 + +for.end23: ; preds = %for.inc21 + ret void +} + Index: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop =================================================================== --- 
polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop +++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop @@ -0,0 +1,46 @@ +{ + "arrays" : [ + { + "name" : "MemRef_A", + "sizes" : [ "*", "784" ], + "type" : "double" + }, + { + "name" : "MemRef_B", + "sizes" : [ "*", "2" ], + "type" : "double" + }, + { + "name" : "MemRef_C", + "sizes" : [ "*", "2" ], + "type" : "double" + } + ], + "context" : "{ : }", + "name" : "%for.body---%for.end23", + "statements" : [ + { + "accesses" : [ + { + "kind" : "read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }" + }, + { + "kind" : "read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }" + }, + { + "kind" : "read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }" + }, + { + "kind" : "write", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }" + } + ], + "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }", + "name" : "Stmt_for_body6", + "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }" + } + ] +} Index: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed =================================================================== --- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed +++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed @@ -0,0 +1,46 @@ +{ + "arrays" : [ + { + "name" : "MemRef_A", + "sizes" : [ "*", "784" ], + "type" : "double" + }, + { + "name" : "MemRef_B", + "sizes" : [ "*", "2" ], + "type" : "double" + }, + { + "name" : "MemRef_C", + "sizes" : [ "*", "2" ], + "type" : "double" + } + ], + "context" : "{ : }", + "name" : "%for.body---%for.end23", + "statements" : [ + { + "accesses" : [ + { + "kind" : 
"read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }" + }, + { + "kind" : "read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }" + }, + { + "kind" : "read", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }" + }, + { + "kind" : "write", + "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }" + } + ], + "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }", + "name" : "Stmt_for_body6", + "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }" + } + ] +}