diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -307,6 +307,24 @@ } }; +// Folds SCEV expressions using conditions guaranteed by the loop entry guard. +class SCEVFolder : public SCEVRewriteVisitor<SCEVFolder> { + const Loop *CurLoop; + +public: + SCEVFolder(ScalarEvolution &SE, const Loop *CurLoop) + : SCEVRewriteVisitor<SCEVFolder>(SE), CurLoop(CurLoop) {} + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + // If the loop entry guard proves Expr >= 0, sext and zext are equivalent, + // so rewrite it as zext. Otherwise leave the expression unchanged. + if (!SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr, + SE.getZero(Expr->getType()))) + return Expr; + return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType()); + } +}; + } // end anonymous namespace char LoopIdiomRecognizeLegacyPass::ID = 0; @@ -968,12 +986,23 @@ << "\n"); if (PositiveStrideSCEV != MemsetSizeSCEV) { - // TODO: folding can be done to the SCEVs - // The folding is to fold expressions that is covered by the loop guard - // at loop entry. After the folding, compare again and proceed - // optimization if equal.
- LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); - return false; + // The SCEVs may still match once expressions covered by the loop guard + // at loop entry are folded. Fold both sides, compare again, and proceed + // with the optimization if they are equal. + SCEVFolder Folder(*SE, CurLoop); + const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV); + const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV); + + LLVM_DEBUG(dbgs() << " Try to fold SCEV with respect to loop guard\n" + << " FoldedMemsetSize: " << *FoldedMemsetSize + << "\n" + << " FoldedPositiveStride: " << *FoldedPositiveStride + << "\n"); + + if (FoldedPositiveStride != FoldedMemsetSize) { + LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); + return false; + } } } diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll @@ -17,9 +17,12 @@ ; CHECK-NEXT: memset size is not a loop-invariant, abort ; CHECK: loop-idiom Scanning: F[MemsetSize_Stride_Mismatch] Countable Loop %for.body ; CHECK-NEXT: memset size is non-constant -; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64)) -; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))) -; CHECK-NEXT: SCEV don't match, abort +; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64)) +; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))) +; CHECK-NEXT: Try to fold SCEV with respect to loop guard +; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64)) +; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64))) +; CHECK-NEXT: SCEV don't match, abort ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader ; CHECK-NEXT: memset size is non-constant ; CHECK-NEXT: pointer is not in address space zero, abort diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll --- a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll @@ -107,4 +107,110 @@ ret void } +; The C code to
generate this testcase: +; void test(int n, int m, int o, int *ar) +; { +; for (int i=0; i