diff --git a/llvm/include/llvm/Analysis/LoopInfo.h b/llvm/include/llvm/Analysis/LoopInfo.h
--- a/llvm/include/llvm/Analysis/LoopInfo.h
+++ b/llvm/include/llvm/Analysis/LoopInfo.h
@@ -872,6 +872,17 @@
   /// unrolling pass is run more than once (which it generally is).
   void setLoopAlreadyUnrolled();
 
+  /// Add llvm.loop.unroll.enable to this loop's loop id metadata.
+  ///
+  /// Remove all existing llvm.loop.unroll.* metadata (including an
+  /// llvm.loop.unroll.disable previously added by setLoopAlreadyUnrolled) and
+  /// add unroll enable metadata, so later unroll passes may unroll this loop.
+  ///
+  /// NOTE(review): the unroll pass appears to treat llvm.loop.unroll.enable as
+  /// an explicit unroll request (like a pragma), not merely as permission;
+  /// confirm that is the intended effect for every caller.
+  void setLoopCanBeUnrolled();
+
   /// Add llvm.loop.mustprogress to this loop's loop id metadata.
   void setLoopMustProgress();
 
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -545,6 +545,19 @@
   setLoopID(NewLoopID);
 }
 
+void Loop::setLoopCanBeUnrolled() {
+  LLVMContext &Context = getHeader()->getContext();
+
+  MDNode *EnableUnrollMD =
+      MDNode::get(Context, MDString::get(Context, "llvm.loop.unroll.enable"));
+  MDNode *LoopID = getLoopID();
+  // Strip all llvm.loop.unroll.* attributes (such as a prior unroll.disable)
+  // and add the enable hint in their place.
+  MDNode *NewLoopID = makePostTransformationMetadata(
+      Context, LoopID, {"llvm.loop.unroll."}, {EnableUnrollMD});
+  setLoopID(NewLoopID);
+}
+
 void Loop::setLoopMustProgress() {
   LLVMContext &Context = getHeader()->getContext();
 
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -1330,6 +1330,12 @@
                                         LLVMLoopUnrollFollowupRemainder});
-    if (RemainderLoopID)
+    if (RemainderLoopID) {
       RemainderLoop->setLoopID(*RemainderLoopID);
+    } else {
+      // Only choose the remainder's unroll attributes when no explicit
+      // llvm.loop.unroll.followup_{all,remainder} metadata was given; an
+      // explicit followup loop ID must be respected, not overwritten.
+      RemainderLoop->setLoopCanBeUnrolled();
+    }
   }
 
   if (UnrollResult != LoopUnrollResult::FullyUnrolled) {