diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -897,7 +897,7 @@
 bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                                            const SCEV *BECount) {
-  // We can only handle non-volatile memsets with a constant size.
-  if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+  // We can only handle non-volatile memsets.
+  if (MSI->isVolatile())
     return false;
 
   // If we're not allowed to hack on memset, we fail.
@@ -913,20 +913,66 @@
   if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
     return false;
 
-  // Reject memsets that are so large that they overflow an unsigned.
-  uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-  if ((SizeInBytes >> 32) != 0)
-    return false;
+  const SCEV *StrideSCEV = Ev->getOperand(1);
+  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
 
-  // Check to see if the stride matches the size of the memset.  If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-  if (!ConstStride)
-    return false;
+  bool IsNegStride;
+  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());
 
-  APInt Stride = ConstStride->getAPInt();
-  if (SizeInBytes != Stride && SizeInBytes != -Stride)
-    return false;
+  if (IsConstantSize) {
+    // Memset size is constant.
+    // Reject memsets that are so large that they overflow an unsigned.
+    LLVM_DEBUG(dbgs() << "  memset size is constant\n");
+    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+    if ((SizeInBytes >> 32) != 0)
+      return false;
+
+    // Check to see if the stride matches the size of the memset. If so, then
+    // we know that every byte is touched in the loop.
+    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(StrideSCEV);
+    if (!ConstStride)
+      return false;
+
+    APInt Stride = ConstStride->getAPInt();
+    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+      return false;
+
+    IsNegStride = SizeInBytes == -Stride;
+  } else {
+    // Memset size is non-constant.
+    // Check if the pointer stride matches the memset size. The pass only
+    // handles a memset length and stride that are invariant in the current
+    // loop, and to be conservative it does not promote pointers that are not
+    // in address space zero.
+    // The stride and size SCEVs must match exactly (after normalizing a
+    // negative stride); no folding based on loop guards is attempted yet.
+    LLVM_DEBUG(dbgs() << "  memset size is non-constant\n");
+    if (Pointer->getType()->getPointerAddressSpace() != 0) {
+      LLVM_DEBUG(dbgs() << "  pointer is not in address space zero\n");
+      return false;
+    }
+    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop) ||
+        !SE->isLoopInvariant(StrideSCEV, CurLoop)) {
+      LLVM_DEBUG(dbgs() << "  memset size or stride is not a loop-invariant, "
+                        << "abort\n");
+      return false;
+    }
+
+    // Compare the positive-direction stride with the memset size.
+    IsNegStride = StrideSCEV->isNonConstantNegative();
+    const SCEV *PositiveStrideSCEV =
+        IsNegStride ? SE->getNegativeSCEV(StrideSCEV) : StrideSCEV;
+    LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
+                      << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
+                      << "\n");
+
+    if (PositiveStrideSCEV != MemsetSizeSCEV) {
+      // TODO: Fold expressions covered by the loop guard at loop entry and
+      // compare again; for now a mismatch aborts the transformation.
+      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+      return false;
+    }
+  }
 
   // Verify that the memset value is loop invariant.  If not, we can't promote
   // the memset.
@@ -936,7 +982,6 @@
 
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
-  bool IsNegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
                                  MaybeAlign(MSI->getDestAlignment()),
                                  SplatValue, MSI, MSIs, Ev, BECount,
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes="function(loop(loop-idiom,loop-deletion),simplifycfg)" < %s -S | FileCheck %s
+; The C code to generate this testcase:
+; void test(int n, int m, int *ar)
+; {
+; long i;
+; for (i=0; i