Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -672,12 +672,19 @@ static bool isMatMulOperandAcc(__isl_keep isl_map *AccMap, int &FirstPos, int &SecondPos) { int DimInPos[] = {FirstPos, SecondPos}; - if (isl_map_foreach_basic_map(AccMap, isMatMulOperandBasicMap, + if (isl_map_dim(AccMap, isl_dim_out) != 2) + return false; + auto *NonPartialAccMap = isl_map_drop_constraints_not_involving_dims( + isl_map_copy(AccMap), isl_dim_out, 0, 2); + if (isl_map_foreach_basic_map(NonPartialAccMap, isMatMulOperandBasicMap, static_cast<void *>(DimInPos)) != isl_stat_ok || - DimInPos[0] < 0 || DimInPos[1] < 0) + DimInPos[0] < 0 || DimInPos[1] < 0) { + isl_map_free(NonPartialAccMap); return false; + } FirstPos = DimInPos[0]; SecondPos = DimInPos[1]; + isl_map_free(NonPartialAccMap); return true; } @@ -694,7 +701,7 @@ /// false, otherwise. static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, MatMulInfoTy &MMI) { - if (!MemAccess->isArrayKind() || !MemAccess->isRead()) + if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) return false; isl_map *AccMap = MemAccess->getAccessRelation(); if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC && @@ -749,15 +756,22 @@ MMI.k, OutDimNum - 1); for (auto *MemA = Stmt->begin(); MemA != Stmt->end() - 1; MemA++) { auto *MemAccessPtr = *MemA; - if (MemAccessPtr->isArrayKind() && MemAccessPtr != MMI.WriteToC && + if (MemAccessPtr->isLatestArrayKind() && !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) && !(MemAccessPtr->isStrideZero(isl_map_copy(MapI)) && MemAccessPtr->isStrideZero(isl_map_copy(MapJ)) && MemAccessPtr->isStrideZero(isl_map_copy(MapK)))) { - isl_map_free(MapI); - isl_map_free(MapJ); - isl_map_free(MapK); - return false; + auto *AccessIntersection = + isl_map_intersect(MemAccessPtr->getLatestAccessRelation(), + MMI.WriteToC->getLatestAccessRelation()); + if 
(isl_map_is_empty(AccessIntersection)) { + isl_map_free(AccessIntersection); + isl_map_free(MapI); + isl_map_free(MapJ); + isl_map_free(MapK); + return false; + } + isl_map_free(AccessIntersection); } } isl_map_free(MapI); @@ -841,7 +855,7 @@ return false; for (auto *MemA = Stmt->end() - 1; MemA != Stmt->begin(); MemA--) { auto *MemAccessPtr = *MemA; - if (!MemAccessPtr->isArrayKind()) + if (!MemAccessPtr->isLatestArrayKind()) continue; if (!MemAccessPtr->isWrite()) return false; Index: test/ScheduleOptimizer/pattern-matching-based-opts_11.ll =================================================================== --- /dev/null +++ test/ScheduleOptimizer/pattern-matching-based-opts_11.ll @@ -0,0 +1,54 @@ +; RUN: opt %loadPolly -polly -polly-delicm \ +; RUN: -polly-delicm-overapproximate-writes -polly-pattern-matching-based-opts \ +; RUN: -polly-opt-isl -debug < %s 2>&1 | FileCheck %s +; +; Check that the pattern matching detects the matrix multiplication pattern +; in case memory accesses were modified by DeLICM. 
+; +; CHECK: The matrix multiplication pattern was detected +; + +define internal fastcc void @kernel_gemm([1024 x double]* nocapture %C, [1024 x double]* nocapture readonly %A, [1024 x double]* nocapture readonly %B) unnamed_addr { +entry: + br label %entry.split + +entry.split: ; preds = %entry + br label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.inc20, %entry.split + %indvars.iv7 = phi i64 [ 0, %entry.split ], [ %indvars.iv.next8, %for.inc20 ] + br label %for.cond4.preheader + +for.cond4.preheader: ; preds = %for.inc17, %for.cond1.preheader + %indvars.iv4 = phi i64 [ 0, %for.cond1.preheader ], [ %indvars.iv.next5, %for.inc17 ] + %arrayidx16 = getelementptr inbounds [1024 x double], [1024 x double]* %C, i64 %indvars.iv7, i64 %indvars.iv4 + %.pre = load double, double* %arrayidx16, align 8 + br label %for.body6 + +for.body6: ; preds = %for.body6, %for.cond4.preheader + %0 = phi double [ %.pre, %for.cond4.preheader ], [ %add, %for.body6 ] + %indvars.iv = phi i64 [ 0, %for.cond4.preheader ], [ %indvars.iv.next, %for.body6 ] + %arrayidx8 = getelementptr inbounds [1024 x double], [1024 x double]* %A, i64 %indvars.iv7, i64 %indvars.iv + %1 = load double, double* %arrayidx8, align 8 + %arrayidx12 = getelementptr inbounds [1024 x double], [1024 x double]* %B, i64 %indvars.iv, i64 %indvars.iv4 + %2 = load double, double* %arrayidx12, align 8 + %mul = fmul double %1, %2 + %add = fadd double %0, %mul + store double %add, double* %arrayidx16, align 8 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %exitcond = icmp eq i64 %indvars.iv.next, 1024 + br i1 %exitcond, label %for.inc17, label %for.body6 + +for.inc17: ; preds = %for.body6 + %indvars.iv.next5 = add nuw nsw i64 %indvars.iv4, 1 + %exitcond6 = icmp eq i64 %indvars.iv.next5, 1024 + br i1 %exitcond6, label %for.inc20, label %for.cond4.preheader + +for.inc20: ; preds = %for.inc17 + %indvars.iv.next8 = add nuw nsw i64 %indvars.iv7, 1 + %exitcond9 = icmp eq i64 %indvars.iv.next8, 1024 + br i1 
%exitcond9, label %for.end22, label %for.cond1.preheader + +for.end22: ; preds = %for.inc20 + ret void +}