Index: lib/Transform/MaximalStaticExpansion.cpp =================================================================== --- lib/Transform/MaximalStaticExpansion.cpp +++ lib/Transform/MaximalStaticExpansion.cpp @@ -68,20 +68,32 @@ SmallPtrSetImpl &Reads, Scop &S, isl::union_map &Dependences); - /// Expand a write memory access. + /// Expand the MemoryAccess according to its domain. /// /// @param S The SCop in which the memory access appears in. /// @param MA The memory access that need to be expanded. - ScopArrayInfo *expandWrite(Scop &S, MemoryAccess *MA); + ScopArrayInfo *expandAccordingToDomain(Scop &S, MemoryAccess *MA); - /// Expand the read memory access. + /// Expand the MemoryAccess according to Dependences and already expanded + /// MemoryAccesses. /// - /// @param The SCop in which the memory access appears in. - /// @param The memory access that need to be expanded. + /// @param S The SCop in which the memory access appears in. + /// @param MA The memory access that needs to be expanded. /// @param Dependences The RAW dependences of the SCop. /// @param ExpandedSAI The expanded SAI created during write expansion. - void expandRead(Scop &S, MemoryAccess *MA, isl::union_map &Dependences, - ScopArrayInfo *ExpandedSAI); + /// @param Reverse If true, the Dependences union_map is reversed before + /// intersection. + void expandAccordingToDependences(Scop &S, MemoryAccess *MA, + isl::union_map &Dependences, + ScopArrayInfo *ExpandedSAI, bool Reverse); + + /// Expand PHI memory accesses. + /// + /// @param S The SCop in which the memory access appears in. + /// @param SAI The ScopArrayInfo representing the PHI accesses to expand. + /// @param Dependences The RAW dependences of the SCop.
+ void expandPhi(Scop &S, const ScopArrayInfo *SAI, + isl::union_map &Dependences); }; } // namespace @@ -151,6 +163,29 @@ SmallPtrSetImpl &Reads, Scop &S, isl::union_map &Dependences) { + if (SAI->isValueKind()) { + MemoryAccess *Def = S.getValueDef(SAI); + // getValueDef() is null for read-only values defined outside the SCoP; + // there is no write access to expand them from. + if (!Def) { + emitRemark(SAI->getName() + " is a read-only scalar.", + S.getEnteringBlock()->getFirstNonPHI()); + return false; + } + Writes.insert(Def); + for (auto MA : S.getValueUses(SAI)) + Reads.insert(MA); + return true; + } + if (SAI->isPHIKind()) + return true; + if (SAI->isExitPHIKind()) { + // For now, we are not able to expand ExitPhi. + emitRemark(SAI->getName() + " is a ExitPhi node.", + S.getEnteringBlock()->getFirstNonPHI()); + return false; + } + int NumberWrites = 0; for (ScopStmt &Stmt : S) { for (MemoryAccess *MA : Stmt) { @@ -159,13 +194,6 @@ if (SAI != MA->getLatestScopArrayInfo()) continue; - // For now, we are not able to expand Scalar. - if (MA->isLatestScalarKind()) { - emitRemark(SAI->getName() + " is a Scalar access.", - MA->getAccessInstruction()); - return false; - } - // For now, we are not able to expand MayWrite. if (MA->isMayWrite()) { emitRemark(SAI->getName() + " has a maywrite access.", @@ -235,19 +263,23 @@ return true; } -void MaximalStaticExpander::expandRead(Scop &S, MemoryAccess *MA, - isl::union_map &Dependences, - ScopArrayInfo *ExpandedSAI) { +void MaximalStaticExpander::expandAccordingToDependences( + Scop &S, MemoryAccess *MA, isl::union_map &Dependences, + ScopArrayInfo *ExpandedSAI, bool Reverse) { // Get the current AM. auto CurrentAccessMap = MA->getAccessRelation(); // Get RAW dependences for the current WA.
- auto WriteDomainSet = MA->getAccessRelation().domain(); - auto WriteDomain = isl::union_set(WriteDomainSet); + auto DomainSet = MA->getAccessRelation().domain(); + auto Domain = isl::union_set(DomainSet); - auto CurrentReadWriteDependences = - Dependences.reverse().intersect_domain(WriteDomain); + isl::union_map CurrentReadWriteDependences; + if (Reverse) + CurrentReadWriteDependences = + Dependences.reverse().intersect_domain(Domain); + else + CurrentReadWriteDependences = Dependences.intersect_domain(Domain); // If no dependences, no need to modify anything. if (CurrentReadWriteDependences.is_empty()) { @@ -267,7 +299,8 @@ MA->setNewAccessRelation(NewAccessMap); } -ScopArrayInfo *MaximalStaticExpander::expandWrite(Scop &S, MemoryAccess *MA) { +ScopArrayInfo * +MaximalStaticExpander::expandAccordingToDomain(Scop &S, MemoryAccess *MA) { // Get the current AM. auto CurrentAccessMap = MA->getAccessRelation(); @@ -336,6 +369,17 @@ return ExpandedSAI; } +void MaximalStaticExpander::expandPhi(Scop &S, const ScopArrayInfo *SAI, + isl::union_map &Dependences) { + auto Writes = S.getPHIIncomings(SAI); + auto Read = S.getPHIRead(SAI); + auto ExpandedSAI = expandAccordingToDomain(S, Read); + + for (auto MA : Writes) { + expandAccordingToDependences(S, MA, Dependences, ExpandedSAI, false); + } +} + void MaximalStaticExpander::emitRemark(StringRef Msg, Instruction *Inst) { ORE->emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "ExpansionRejection", Inst) << Msg); @@ -360,13 +404,20 @@ if (!isExpandable(SAI, AllWrites, AllReads, S, Dependences)) continue; - assert(AllWrites.size() == 1); + // MemoryKind::Value or MemoryKind::Array: expand the single write. + if (SAI->isValueKind() || SAI->isArrayKind()) { + assert(AllWrites.size() == 1 && "Expected exactly one write to expand"); - auto TheWrite = *(AllWrites.begin()); - ScopArrayInfo *ExpandedArray = expandWrite(S, TheWrite); + auto TheWrite = *(AllWrites.begin()); + ScopArrayInfo *ExpandedArray = expandAccordingToDomain(S, TheWrite); - for (MemoryAccess *MA : AllReads) - 
expandRead(S, MA, Dependences, ExpandedArray); + for (MemoryAccess *MA : AllReads) + expandAccordingToDependences(S, MA, Dependences, ExpandedArray, true); + } + // MemoryKind::PHI: expand the PHI read, then remap the incoming writes. + else if (SAI->isPHIKind()) { + expandPhi(S, SAI, Dependences); + } } return false; Index: test/MaximalStaticExpansion/non_working_phi.ll =================================================================== --- /dev/null +++ test/MaximalStaticExpansion/non_working_phi.ll @@ -0,0 +1,83 @@ +; RUN: opt %loadPolly -polly-canonicalize -polly-mse -analyze < %s 2>&1 | FileCheck %s +; +; Verify that the accesses are correctly expanded for MemoryKind::PHI +; +; Original source code : +; +; #define Ni 10000 +; #define Nj 10000 +; +; void mse(double A[Ni], double B[Nj]) { +; int i,j; +; double tmp = 6; +; for (i = 0; i < Ni; i++) { +; for (int j = 0; j MemRef_conv_Stmt_for_body_expanded[i0] }; +; CHECK: new: { Stmt_for_body5[i0, i1] -> MemRef_conv_Stmt_for_body_expanded[i0] }; +; CHECK: new: { Stmt_for_end[i0] -> MemRef_conv_Stmt_for_body_expanded[i0] }; + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +; Function Attrs: noinline nounwind uwtable +define void @mse(double* %A, double* %B) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %i = alloca i32, align 4 + %j = alloca i32, align 4 + %tmp = alloca double, align 8 + %j1 = alloca i32, align 4 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double 6.000000e+00, double* %tmp, align 8 + store i32 0, i32* %i, align 4 + br label %for.cond +for.cond: ; preds = %for.inc8, %entry + %0 = load i32, i32* %i, align 4 + %cmp = icmp slt i32 %0, 10000 + br i1 %cmp, label %for.body, label %for.end10 +for.body: ; preds = %for.cond + %1 = load i32, i32* %i, align 4 + %conv = sitofp i32 %1 to double + store double %conv, double* %tmp, align 8 + store i32 0, i32* %j1, align 4 + br label %for.cond2 +for.cond2: ;
preds = %for.inc, %for.body + %2 = load i32, i32* %j1, align 4 + %cmp3 = icmp slt i32 %2, 10000 + br i1 %cmp3, label %for.body5, label %for.end +for.body5: ; preds = %for.cond2 + %3 = load double, double* %tmp, align 8 + %add = fadd double %3, 3.000000e+00 + %4 = load double*, double** %A.addr, align 8 + %5 = load i32, i32* %j1, align 4 + %idxprom = sext i32 %5 to i64 + %arrayidx = getelementptr inbounds double, double* %4, i64 %idxprom + store double %add, double* %arrayidx, align 8 + br label %for.inc +for.inc: ; preds = %for.body5 + %6 = load i32, i32* %j1, align 4 + %inc = add nsw i32 %6, 1 + store i32 %inc, i32* %j1, align 4 + br label %for.cond2 +for.end: ; preds = %for.cond2 + %7 = load double, double* %tmp, align 8 + %8 = load double*, double** %B.addr, align 8 + %9 = load i32, i32* %i, align 4 + %idxprom6 = sext i32 %9 to i64 + %arrayidx7 = getelementptr inbounds double, double* %8, i64 %idxprom6 + store double %7, double* %arrayidx7, align 8 + br label %for.inc8 +for.inc8: ; preds = %for.end + %10 = load i32, i32* %i, align 4 + %inc9 = add nsw i32 %10, 1 + store i32 %inc9, i32* %i, align 4 + br label %for.cond +for.end10: ; preds = %for.cond + ret void +}