Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -483,61 +483,6 @@
   return Node;
 }
 
-/// Get the position of a dimension with a non-zero coefficient.
-///
-/// Check that isl constraint @p Constraint has only one non-zero
-/// coefficient for dimensions that have type @p DimType. If this is true,
-/// return the position of the dimension corresponding to the non-zero
-/// coefficient and negative value, otherwise.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimType    The type of the dimensions.
-/// @return           The position of the dimension in case the isl
-///                   constraint satisfies the requirements, a negative
-///                   value, otherwise.
-static int getMatMulConstraintDim(isl::constraint Constraint,
-                                  isl::dim DimType) {
-  int DimPos = -1;
-  auto LocalSpace = Constraint.get_local_space();
-  int LocalSpaceDimNum = LocalSpace.dim(DimType);
-  for (int i = 0; i < LocalSpaceDimNum; i++) {
-    auto Val = Constraint.get_coefficient_val(DimType, i);
-    if (Val.is_zero())
-      continue;
-    if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
-        (DimType == isl::dim::in && !Val.is_negone()))
-      return -1;
-    DimPos = i;
-  }
-  return DimPos;
-}
-
-/// Check the form of the isl constraint.
-///
-/// Check that the @p DimInPos input dimension of the isl constraint
-/// @p Constraint has a coefficient that is equal to negative one, the @p
-/// DimOutPos has a coefficient that is equal to one and others
-/// have coefficients equal to zero.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimInPos   The input dimension of the isl constraint.
-/// @param DimOutPos  The output dimension of the isl constraint.
-/// @return           isl_stat_ok in case the isl constraint satisfies
-///                   the requirements, isl_stat_error otherwise.
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
-                                          int &DimInPos, int &DimOutPos) {
-  auto Val = Constraint.get_constant_val();
-  if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
-    return isl_stat_error;
-  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
-  if (DimInPos < 0)
-    return isl_stat_error;
-  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
-  if (DimOutPos < 0)
-    return isl_stat_error;
-  return isl_stat_ok;
-}
-
 /// Permute the two dimensions of the isl map.
 ///
 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@@ -585,30 +530,49 @@
 ///                  second output dimension.
 /// @return          True in case @p AccMap has the expected form and false,
 ///                  otherwise.
-static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
-  int DimInPos[] = {FirstPos, SecondPos};
-  auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
-    auto Constraints = BasicMap.get_constraint_list();
-    if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
-      return isl::stat::error;
-    for (int i = 0; i < 2; i++) {
-      auto Constraint =
-          isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
-      int InPos, OutPos;
-      if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
-              isl_stat_error ||
-          OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
-        return isl::stat::error;
-      DimInPos[OutPos] = InPos;
-    }
-    return isl::stat::ok;
-  };
-  if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
-      DimInPos[1] < 0)
+static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
+                               int &SecondPos) {
+
+  isl::space Space = AccMap.get_space();
+  isl::map Universe = isl::map::universe(Space);
+
+  if (Space.dim(isl::dim::out) != 2)
     return false;
-  FirstPos = DimInPos[0];
-  SecondPos = DimInPos[1];
-  return true;
+
+  // MatMul has the form:
+  //   for (i = 0; i < N; i++)
+  //     for (j = 0; j < M; j++)
+  //       for (k = 0; k < P; k++)
+  //         C[i, j] += A[i, k] * B[k, j]
+  //
+  // Permutation of three outer loops: 3! = 6 possibilities.
+  int FirstDims[] = {0, 0, 1, 1, 2, 2};
+  int SecondDims[] = {1, 2, 2, 0, 0, 1};
+
+  // Only the part of AccMap inside the statement domain @p Domain matters and
+  // it does not depend on the candidate, so restrict it once up front.
+  AccMap = AccMap.intersect_domain(Domain);
+  for (int i = 0; i < 6; i += 1) {
+    auto PossibleMatMul =
+        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
+            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1)
+            .intersect_domain(Domain);
+    // AccMap must coincide with the candidate on the whole domain. Otherwise
+    // the access is partial (or has another shape), and partial writes must
+    // be rejected.
+    if (!AccMap.is_equal(PossibleMatMul))
+      continue;
+    // Honor positions fixed by previously matched accesses. Check both before
+    // assigning either, so a rejected candidate never clobbers FirstPos.
+    if ((FirstPos != -1 && FirstPos != FirstDims[i]) ||
+        (SecondPos != -1 && SecondPos != SecondDims[i]))
+      continue;
+    FirstPos = FirstDims[i];
+    SecondPos = SecondDims[i];
+    return true;
+  }
+
+  return false;
 }
 
 /// Does the memory access represent a non-scalar operand of the matrix
@@ -627,18 +591,16 @@
   if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
     return false;
   auto AccMap = MemAccess->getLatestAccessRelation();
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
     MMI.ReadFromC = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
     MMI.A = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
     MMI.B = MemAccess;
     return true;
   }
@@ -758,8 +720,7 @@
     if (!MemAccessPtr->isWrite())
       return false;
     auto AccMap = MemAccessPtr->getLatestAccessRelation();
-    if (isl_map_n_basic_map(AccMap.get()) != 1 ||
-        !isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
+    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
       return false;
     MMI.WriteToC = MemAccessPtr;
     break;
Index: test/ScheduleOptimizer/pattern-matching-based-opts_15.ll
===================================================================
--- /dev/null
+++ test/ScheduleOptimizer/pattern-matching-based-opts_15.ll
@@ -0,0 +1,106 @@
+; RUN: opt %loadPolly -polly-import-jscop \
+; RUN: -sroa -simplifycfg -loop-rotate -loop-simplify \
+; RUN: -polly-delicm -polly-simplify -polly-opt-isl \
+; RUN: -polly-pattern-matching-based-opts=true \
+; RUN: -pass-remarks-analysis=polly-detect \
+; RUN: -polly-delicm-partial-writes -debug < %s 2>&1 \
+; RUN: | FileCheck %s
+; REQUIRES: asserts
+;
+; Check that the pattern matching detects the matrix multiplication pattern
+; when the AccMap cannot be reduced to a single disjunct.
+;
+; CHECK: The matrix multiplication pattern was detected
+;
+source_filename = "__compute_module"
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux_gnu"
+
+define void @"cluster_2[_XlaCompiledKernel=true,_XlaNumConstantArgs=0].v5"(i8* align 16 dereferenceable(16) %retval, i8* noalias %run_options, i8** noalias %params, i8** noalias %temps, i64* noalias %prof_counters) {
+entry:
+  %accum_address = alloca float ; scalar accumulator for the innermost (reduction) loop
+  %invar_address.reduction = alloca i64
+  %invar_address.rhs1 = alloca i64
+  %invar_address.lhs0 = alloca i64
+  %0 = getelementptr inbounds i8*, i8** %params, i64 1
+  %1 = load i8*, i8** %0, !tbaa !0, !dereferenceable !3, !align !4
+  %2 = bitcast i8* %1 to [2 x [784 x float]]* ; params[1], read below as %2[lhs0][reduction]
+  %3 = getelementptr inbounds i8*, i8** %params, i64 0
+  %4 = load i8*, i8** %3, !tbaa !5, !dereferenceable !3, !align !4
+  %5 = bitcast i8* %4 to [784 x [2 x float]]* ; params[0], read below as %5[reduction][rhs1]
+  %6 = bitcast i8* %retval to [2 x [2 x float]]* ; result, written below as %6[lhs0][rhs1]
+  store i64 0, i64* %invar_address.lhs0
+  br label %loop_header.lhs0
+
+  loop_header.lhs0:                                 ; preds = %loop_exit.rhs1, %entry
+    %indvar.lhs0 = load i64, i64* %invar_address.lhs0
+    %7 = icmp uge i64 %indvar.lhs0, 2 ; outer loop: lhs0 in [0, 2)
+    br i1 %7, label %loop_exit.lhs0, label %loop_body.lhs0
+
+  loop_body.lhs0:                                   ; preds = %loop_header.lhs0
+    store i64 0, i64* %invar_address.rhs1
+    br label %loop_header.rhs1
+
+  loop_header.rhs1:                                 ; preds = %loop_exit.reduction, %loop_body.lhs0
+    %indvar.rhs1 = load i64, i64* %invar_address.rhs1
+    %8 = icmp uge i64 %indvar.rhs1, 2 ; middle loop: rhs1 in [0, 2)
+    br i1 %8, label %loop_exit.rhs1, label %loop_body.rhs1
+
+  loop_body.rhs1:                                   ; preds = %loop_header.rhs1
+    store i64 0, i64* %invar_address.reduction
+    store float 0.000000e+00, float* %accum_address ; reset the accumulator for each (lhs0, rhs1)
+    br label %loop_header.reduction
+
+  loop_header.reduction:                            ; preds = %loop_body.reduction, %loop_body.rhs1
+    %indvar.reduction = load i64, i64* %invar_address.reduction
+    %9 = icmp uge i64 %indvar.reduction, 784 ; inner loop: reduction in [0, 784)
+    br i1 %9, label %loop_exit.reduction, label %loop_body.reduction
+
+  loop_body.reduction:                              ; preds = %loop_header.reduction
+    %10 = getelementptr inbounds [2 x [784 x float]], [2 x [784 x float]]* %2, i64 0, i64 %indvar.lhs0, i64 %indvar.reduction ; &%2[lhs0][reduction]
+    %11 = load float, float* %10, !tbaa !7, !invariant.load !9, !noalias !10
+    %12 = getelementptr inbounds [784 x [2 x float]], [784 x [2 x float]]* %5, i64 0, i64 %indvar.reduction, i64 %indvar.rhs1 ; &%5[reduction][rhs1]
+    %13 = load float, float* %12, !tbaa !13, !invariant.load !9, !noalias !10
+    %14 = fmul fast float %11, %13
+    %15 = load float, float* %accum_address
+    %16 = fadd fast float %15, %14 ; accum += %2[lhs0][reduction] * %5[reduction][rhs1]
+    store float %16, float* %accum_address
+    %invar.inc2 = add nuw nsw i64 %indvar.reduction, 1
+    store i64 %invar.inc2, i64* %invar_address.reduction
+    br label %loop_header.reduction
+
+  loop_exit.reduction:                              ; preds = %loop_header.reduction
+    %17 = load float, float* %accum_address
+    %18 = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %6, i64 0, i64 %indvar.lhs0, i64 %indvar.rhs1 ; &%6[lhs0][rhs1]
+    store float %17, float* %18, !tbaa !15, !alias.scope !10 ; %6[lhs0][rhs1] = accum
+    %invar.inc1 = add nuw nsw i64 %indvar.rhs1, 1
+    store i64 %invar.inc1, i64* %invar_address.rhs1
+    br label %loop_header.rhs1
+
+  loop_exit.rhs1:                                   ; preds = %loop_header.rhs1
+    %invar.inc = add nuw nsw i64 %indvar.lhs0, 1
+    store i64 %invar.inc, i64* %invar_address.lhs0
+    br label %loop_header.lhs0
+
+  loop_exit.lhs0:                                   ; preds = %loop_header.lhs0
+    %prof_counter_computation = getelementptr i64, i64* %prof_counters, i64 0 ; result is not used
+    ret void
+}
+
+!0 = !{!1, !1, i64 0}
+!1 = !{!"pointer-to element_type: F32 dimensions: 2 dimensions: 784 layout { minor_to_major: 1 minor_to_major: 0 }", !2}
+!2 = !{!"XLA TBAA"}
+!3 = !{i64 6272}
+!4 = !{i64 16}
+!5 = !{!6, !6, i64 0}
+!6 = !{!"pointer-to element_type: F32 dimensions: 784 dimensions: 2 layout { minor_to_major: 1 minor_to_major: 0 }", !2}
+!7 = !{!8, !8, i64 0}
+!8 = !{!"element_type: F32 dimensions: 2 dimensions: 784 layout { minor_to_major: 1 minor_to_major: 0 }", !2}
+!9 = !{}
+!10 = !{!11}
+!11 = !{!"buffer: 2", !12}
+!12 = distinct !{!12}
+!13 = !{!14, !14, i64 0}
+!14 = !{!"element_type: F32 dimensions: 784 dimensions: 2 layout { minor_to_major: 1 minor_to_major: 0 }", !2}
+!15 = !{!16, !16, i64 0}
+!16 = !{!"element_type: F32 dimensions: 2 dimensions: 2 layout { minor_to_major: 1 minor_to_major: 0 }", !2}