diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -271,7 +271,7 @@ case 2: return &MD->getOperand(1); default: - llvm_unreachable("loop metadata has 0 or 1 operand"); + llvm_unreachable("loop metadata has 0 or 1 operands"); } } @@ -287,7 +287,7 @@ case 2: if (ConstantInt *IntMD = mdconst::extract_or_null(MD->getOperand(1).get())) - return IntMD->getZExtValue(); + return IntMD->getZExtValue() != 0; return true; } llvm_unreachable("unexpected number of options"); diff --git a/polly/include/polly/CodeGen/IRBuilder.h b/polly/include/polly/CodeGen/IRBuilder.h --- a/polly/include/polly/CodeGen/IRBuilder.h +++ b/polly/include/polly/CodeGen/IRBuilder.h @@ -25,6 +25,7 @@ namespace polly { class Scop; +struct BandAttr; /// Helper class to annotate newly generated SCoPs with metadata. /// @@ -43,6 +44,7 @@ class ScopAnnotator { public: ScopAnnotator(); + ~ScopAnnotator(); /// Build all alias scopes for the given SCoP. void buildAliasScopes(Scop &S); @@ -83,6 +85,16 @@ /// Add inter iteration alias-free base pointer @p BasePtr. void addInterIterationAliasFreeBasePtr(llvm::Value *BasePtr); + /// Stack for surrounding BandAttr annotations. + llvm::SmallVector LoopAttrEnv; + /// BandAttr to be used by the next loop opened at the current nesting level. + BandAttr *&getStagingAttrEnv() { return LoopAttrEnv.back(); } + /// BandAttr that was staged when the innermost loop context was opened. + BandAttr *getActiveAttrEnv() const { + assert(LoopAttrEnv.size() >= 2 && "No loop context has been opened"); + return LoopAttrEnv[LoopAttrEnv.size() - 2]; + } + private: /// Annotate with the second level alias metadata /// diff --git a/polly/include/polly/ManualOptimizer.h b/polly/include/polly/ManualOptimizer.h new file mode 100644 --- /dev/null +++ b/polly/include/polly/ManualOptimizer.h @@ -0,0 +1,36 @@ +//===------ ManualOptimizer.h ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Handle pragma/metadata-directed transformations. +// +//===----------------------------------------------------------------------===// + +#ifndef POLLY_MANUALOPTIMIZER_H +#define POLLY_MANUALOPTIMIZER_H + +#include "isl/isl-noexceptions.h" + +namespace polly { +class Scop; + +/// Apply loop-transformation metadata. +/// +/// The loop metadata are taken from mark-nodes in @p Sched. These nodes have +/// been added by ScopBuilder when creating a schedule for a loop with an +/// attached LoopID. +/// +/// @param S The SCoP for @p Sched. +/// @param Sched The input schedule to apply the directives on. +/// +/// @return The transformed schedule with all mark-nodes with loop +/// transformations applied. Returns NULL in case of an error or @p +/// Sched itself if no transformation has been applied. +isl::schedule applyManualTransformations(Scop *S, isl::schedule Sched); +} // namespace polly + +#endif /* POLLY_MANUALOPTIMIZER_H */ diff --git a/polly/include/polly/ScheduleTreeTransform.h b/polly/include/polly/ScheduleTreeTransform.h --- a/polly/include/polly/ScheduleTreeTransform.h +++ b/polly/include/polly/ScheduleTreeTransform.h @@ -13,14 +13,157 @@ #ifndef POLLY_SCHEDULETREETRANSFORM_H #define POLLY_SCHEDULETREETRANSFORM_H +#include "llvm/Support/ErrorHandling.h" #include "isl/isl-noexceptions.h" +#include namespace polly { +struct BandAttr; + +/// This class defines a simple visitor class that may be used for +/// various schedule tree analysis purposes. +template +struct ScheduleTreeVisitor { + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + + RetTy visit(const isl::schedule_node &Node, Args... 
args) { + assert(!Node.is_null()); + switch (isl_schedule_node_get_type(Node.get())) { + case isl_schedule_node_domain: + assert(isl_schedule_node_n_children(Node.get()) == 1); + return getDerived().visitDomain(Node, std::forward(args)...); + case isl_schedule_node_band: + assert(isl_schedule_node_n_children(Node.get()) == 1); + return getDerived().visitBand(Node, std::forward(args)...); + case isl_schedule_node_sequence: + assert(isl_schedule_node_n_children(Node.get()) >= 2); + return getDerived().visitSequence(Node, std::forward(args)...); + case isl_schedule_node_set: + assert(isl_schedule_node_n_children(Node.get()) >= 2); + return getDerived().visitSet(Node, std::forward(args)...); + case isl_schedule_node_leaf: + assert(isl_schedule_node_n_children(Node.get()) == 0); + return getDerived().visitLeaf(Node, std::forward(args)...); + case isl_schedule_node_mark: + assert(isl_schedule_node_n_children(Node.get()) == 1); + return getDerived().visitMark(Node, std::forward(args)...); + case isl_schedule_node_extension: + assert(isl_schedule_node_n_children(Node.get()) == 1); + return getDerived().visitExtension(Node, std::forward(args)...); + case isl_schedule_node_filter: + assert(isl_schedule_node_n_children(Node.get()) == 1); + return getDerived().visitFilter(Node, std::forward(args)...); + default: + llvm_unreachable("unimplemented schedule node type"); + } + } + + RetTy visitDomain(const isl::schedule_node &Domain, Args... args) { + return getDerived().visitSingleChild(Domain, std::forward(args)...); + } + + RetTy visitBand(const isl::schedule_node &Band, Args... args) { + return getDerived().visitSingleChild(Band, std::forward(args)...); + } + + RetTy visitSequence(const isl::schedule_node &Sequence, Args... args) { + return getDerived().visitMultiChild(Sequence, std::forward(args)...); + } + + RetTy visitSet(const isl::schedule_node &Set, Args... 
args) { + return getDerived().visitMultiChild(Set, std::forward(args)...); + } + + RetTy visitLeaf(const isl::schedule_node &Leaf, Args... args) { + return getDerived().visitNode(Leaf, std::forward(args)...); + } + + RetTy visitMark(const isl::schedule_node &Mark, Args... args) { + return getDerived().visitSingleChild(Mark, std::forward(args)...); + } + + RetTy visitExtension(const isl::schedule_node &Extension, Args... args) { + return getDerived().visitSingleChild(Extension, + std::forward(args)...); + } + + RetTy visitFilter(const isl::schedule_node &Filter, Args... args) { + return getDerived().visitSingleChild(Filter, + std::forward(args)...); + } + + RetTy visitSingleChild(const isl::schedule_node &Node, Args... args) { + return getDerived().visitNode(Node, std::forward(args)...); + } + + RetTy visitMultiChild(const isl::schedule_node &Node, Args... args) { + return getDerived().visitNode(Node, std::forward(args)...); + } + + RetTy visitNode(const isl::schedule_node &Node, Args... args) { + llvm_unreachable("Unimplemented other"); + } +}; + +/// Recursively visit all nodes of a schedule tree. +template +struct RecursiveScheduleTreeVisitor + : public ScheduleTreeVisitor { + using BaseTy = ScheduleTreeVisitor; + BaseTy &getBase() { return *this; } + const BaseTy &getBase() const { return *this; } + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + + /// When visiting an entire schedule tree, start at its root node. + RetTy visit(const isl::schedule &Schedule, Args... args) { + return getDerived().visit(Schedule.get_root(), std::forward(args)...); + } + + // Necessary to allow overload resolution with the added visit(isl::schedule) + // overload. + RetTy visit(const isl::schedule_node &Node, Args... args) { + return getBase().visit(Node, std::forward(args)...); + } + + /// By default, recursively visit the child nodes. + RetTy visitNode(const isl::schedule_node &Node, Args... 
args) { + isl_size NumChildren = Node.n_children(); + for (isl_size i = 0; i < NumChildren; i += 1) + getDerived().visit(Node.child(i), std::forward(args)...); + return RetTy(); + } +}; + +/// Is this node the marker for its parent band? +bool isBandMark(const isl::schedule_node &Node); + +/// Extract the BandAttr from a band's wrapping marker. Can also pass the band +/// itself and this method will try to find its wrapping mark. Returns nullptr +/// if the band has no BandAttr. +BandAttr *getBandAttr(isl::schedule_node MarkOrBand); + /// Hoist all domains from extension into the root domain node, such that there /// are no more extension nodes (which isl does not support for some /// operations). This assumes that domains added by to extension nodes do not /// overlap. isl::schedule hoistExtensionNodes(isl::schedule Sched); + +/// Replace the AST band @p BandToUnroll by a sequence of all its iterations. +/// +/// The implementation enumerates all points in the partial schedule and creates +/// an ISL sequence node for each point. The number of iterations must be a +/// constant. +isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll); + +/// Replace the AST band @p BandToUnroll by a partially unrolled equivalent. +isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor); + } // namespace polly #endif // POLLY_SCHEDULETREETRANSFORM_H diff --git a/polly/include/polly/ScopInfo.h b/polly/include/polly/ScopInfo.h --- a/polly/include/polly/ScopInfo.h +++ b/polly/include/polly/ScopInfo.h @@ -1885,6 +1885,9 @@ /// in a schedule tree is given in the isl manual. isl::schedule Schedule = nullptr; + /// Is this Scop marked as not to be transformed by an optimization heuristic? + bool HasDisableHeuristicsHint = false; + /// Whether the schedule has been modified after derived from the CFG by /// ScopBuilder. bool ScheduleModified = false; @@ -2035,7 +2038,6 @@ /// /// A new statement will be created and added to the statement vector. 
/// - /// @param Stmt The parent statement. /// @param SourceRel The source location. /// @param TargetRel The target location. /// @param Domain The original domain under which the copy statement would @@ -2744,6 +2746,13 @@ /// various places. If statistics are disabled, only zeros are returned to /// avoid the overhead. ScopStatistics getStatistics() const; + + /// Is this Scop marked as not to be transformed by an optimization heuristic? + /// In this case, only user-directed transformations are allowed. + bool hasDisableHeuristicsHint() const { return HasDisableHeuristicsHint; } + + /// Mark this Scop to not apply an optimization heuristic. + void markDisableHeuristics() { HasDisableHeuristicsHint = true; } }; /// Print Scop scop to raw_ostream OS. diff --git a/polly/include/polly/Support/ScopHelper.h b/polly/include/polly/Support/ScopHelper.h --- a/polly/include/polly/Support/ScopHelper.h +++ b/polly/include/polly/Support/ScopHelper.h @@ -544,5 +544,71 @@ /// /// Such a statement must not be removed, even if has no side-effects. bool hasDebugCall(ScopStmt *Stmt); + +/// Find a property value in a LoopID. +/// +/// Generally, a property MDNode has the format +/// +/// !{ !"Name", value } +/// +/// In which case the value is returned. +/// +/// If the property is just +/// +/// !{ !"Name" } +/// +/// Then `nullptr` is returned to mark that the property exists but does not +/// carry any value. If the property does not exist, `None` is returned. +llvm::Optional findMetadataOperand(llvm::MDNode *LoopMD, + llvm::StringRef Name); + +/// Does the loop's LoopID contain a 'llvm.loop.disable_nonforced' property? +/// +/// This is equivalent to llvm::hasDisableAllTransformsHint(Loop*), but +/// including the LoopUtils.h header indirectly also declares llvm::MemoryAccess +/// which clashes with polly::MemoryAccess. Declaring this wrapper here avoids +/// having to include LoopUtils.h in other files. 
+bool hasDisableAllTransformsHint(llvm::Loop *L); + +/// Represent the attributes of a loop. +struct BandAttr { + /// LoopID which stores the properties of the loop, such as transformations to + /// apply and the metadata of followup-loops. + /// + /// Cannot be used to identify a loop. Two different loops can have the same + /// metadata. + llvm::MDNode *Metadata = nullptr; + + /// The LoopInfo reference for this loop. + /// + /// Only loops from the original IR are represented by LoopInfo. Loops that + /// were generated by Polly are not tracked by LoopInfo. + llvm::Loop *OriginalLoop = nullptr; +}; + +/// Get an isl::id representing a loop. +/// +/// This takes ownership of the BandAttr, which is deleted when the +/// returned isl::id is freed. +isl::id getIslLoopAttr(isl::ctx Ctx, BandAttr *Attr); + +/// Create an isl::id that identifies an original loop. +/// +/// Return a null isl::id if the loop does not need a BandAttr (i.e. @p L is +/// null or has no LoopID). +/// +/// This creates a BandAttr which must be unique per loop and therefore this +/// must not be called multiple times on the same loop as their ids would be +/// different. +isl::id createIslLoopAttr(isl::ctx Ctx, llvm::Loop *L); + +/// Does @p Id represent a loop? +/// +/// Such ids contain a polly::BandAttr as their user pointer. +bool isLoopAttr(const isl::id &Id); + +/// Return the BandAttr of a loop's isl::id. 
+BandAttr *getLoopAttr(const isl::id &Id); + } // namespace polly #endif diff --git a/polly/lib/Analysis/ScopBuilder.cpp b/polly/lib/Analysis/ScopBuilder.cpp --- a/polly/lib/Analysis/ScopBuilder.cpp +++ b/polly/lib/Analysis/ScopBuilder.cpp @@ -1284,6 +1284,7 @@ auto NumBlocksProcessed = LoopData->NumBlocksProcessed; assert(std::next(LoopData) != LoopStack.rend()); + Loop *L = LoopData->L; ++LoopData; --Dimension; @@ -1291,6 +1292,25 @@ isl::union_set Domain = Schedule.get_domain(); isl::multi_union_pw_aff MUPA = mapToDimension(Domain, Dimension); Schedule = Schedule.insert_partial_schedule(MUPA); + + if (hasDisableAllTransformsHint(L)) { + // If any of the loops has a disable_nonforced hint, mark the + // entire SCoP as such. The ISL rescheduler can only reschedule the + // SCoP in its entirety. + // TODO: ScopDetection could avoid including such loops or wrap them as + // boxed loops. It still needs to pass through loops with user-defined + // metadata. + scop->markDisableHeuristics(); + } + + // It is easier to insert the marks here than to do it retroactively.
+ isl::id IslLoopId = createIslLoopAttr(scop->getIslCtx(), L); + if (!IslLoopId.is_null()) + Schedule = Schedule.get_root() + .get_child(0) + .insert_mark(IslLoopId) + .get_schedule(); + LoopData->Schedule = combineInSequence(LoopData->Schedule, Schedule); } diff --git a/polly/lib/CMakeLists.txt b/polly/lib/CMakeLists.txt --- a/polly/lib/CMakeLists.txt +++ b/polly/lib/CMakeLists.txt @@ -98,6 +98,7 @@ Transform/MaximalStaticExpansion.cpp Transform/RewriteByReferenceParameters.cpp Transform/ScopInliner.cpp + Transform/ManualOptimizer.cpp ${POLLY_HEADER_FILES} LINK_COMPONENTS diff --git a/polly/lib/CodeGen/IRBuilder.cpp b/polly/lib/CodeGen/IRBuilder.cpp --- a/polly/lib/CodeGen/IRBuilder.cpp +++ b/polly/lib/CodeGen/IRBuilder.cpp @@ -46,7 +46,15 @@ return ID; } -ScopAnnotator::ScopAnnotator() : SE(nullptr), AliasScopeDomain(nullptr) {} +ScopAnnotator::ScopAnnotator() : SE(nullptr), AliasScopeDomain(nullptr) { + // Push an empty staging BandAttr. + LoopAttrEnv.emplace_back(); +} + +ScopAnnotator::~ScopAnnotator() { + assert(LoopAttrEnv.size() == 1 && "Loop stack imbalance"); + assert(!getStagingAttrEnv() && "Forgot to clear staging attr env"); +} void ScopAnnotator::buildAliasScopes(Scop &S) { SE = S.getSE(); @@ -101,6 +109,9 @@ MDNode *AccessGroup = MDNode::getDistinct(Ctx, {}); ParallelLoops.push_back(AccessGroup); } + + // Open an empty BandAttr context for loops nested in this one. + LoopAttrEnv.emplace_back(); } void ScopAnnotator::popLoop(bool IsParallel) { @@ -110,6 +121,11 @@ assert(!ParallelLoops.empty() && "Expected a parallel loop to pop"); ParallelLoops.pop_back(); } + + // Exit the subloop context. + assert(!getStagingAttrEnv() && "Forgot to clear staging attr env"); + assert(LoopAttrEnv.size() >= 2 && "Popped too many"); + LoopAttrEnv.pop_back(); } void ScopAnnotator::annotateLoopLatch(BranchInst *B, Loop *L, bool IsParallel, @@ -120,6 +136,16 @@ // For the LoopID self-reference.
+ Args.push_back(nullptr); + // Add the user-defined loop properties to the annotation, if any. Any + // additional properties are appended. + // FIXME: What to do if these conflict with the properties appended below? + MDNode *MData = nullptr; + if (BandAttr *AttrEnv = getActiveAttrEnv()) { + MData = AttrEnv->Metadata; + if (MData) + llvm::append_range(Args, drop_begin(MData->operands(), 1)); + } + if (IsLoopVectorizerDisabled) { MDString *PropName = MDString::get(Ctx, "llvm.loop.vectorize.enable"); ConstantInt *FalseValue = ConstantInt::get(Type::getInt1Ty(Ctx), 0); @@ -134,11 +160,16 @@ } // No metadata to annotate. - if (Args.size() <= 1) + if (!MData && Args.size() <= 1) return; - MDNode *MData = MDNode::getDistinct(Ctx, Args); - MData->replaceOperandWith(0, MData); + // Reuse the MData node if possible; this avoids creating another node that + // cannot be merged because LoopIDs are 'distinct'. A new node is only + // needed if properties have been appended to Args. + if (!MData || Args.size() > MData->getNumOperands()) { + MData = MDNode::getDistinct(Ctx, Args); + MData->replaceOperandWith(0, MData); + } B->setMetadata(LLVMContext::MD_loop, MData); } diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp --- a/polly/lib/CodeGen/IslNodeBuilder.cpp +++ b/polly/lib/CodeGen/IslNodeBuilder.cpp @@ -426,7 +426,22 @@ auto *BasePtr = static_cast(isl_id_get_user(Id)); Annotator.addInterIterationAliasFreeBasePtr(BasePtr); } + + BandAttr *ChildLoopAttr = getLoopAttr(isl::manage_copy(Id)); + if (ChildLoopAttr) { + assert(!Annotator.getStagingAttrEnv() && + "conflicting loop attr environments"); + Annotator.getStagingAttrEnv() = ChildLoopAttr; + } + create(Child); + + if (ChildLoopAttr) { + assert(Annotator.getStagingAttrEnv() == ChildLoopAttr && + "Nest must not overwrite loop attr environment"); + Annotator.getStagingAttrEnv() = nullptr; + } + isl_id_free(Id); } diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp --- 
a/polly/lib/Support/ScopHelper.cpp +++ b/polly/lib/Support/ScopHelper.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" using namespace llvm; @@ -726,3 +727,85 @@ return false; } + +/// Find a property in a LoopID; returns nullptr if it is not present. +static MDNode *findNamedMetadataNode(MDNode *LoopMD, StringRef Name) { + if (!LoopMD) + return nullptr; + for (const MDOperand &X : drop_begin(LoopMD->operands(), 1)) { + auto *OpNode = dyn_cast(X.get()); + // Skip non-node operands and malformed properties without a name. + if (!OpNode || OpNode->getNumOperands() == 0) + continue; + + auto *OpName = dyn_cast(OpNode->getOperand(0)); + if (!OpName) + continue; + if (OpName->getString() == Name) + return OpNode; + } + return nullptr; +} + +Optional polly::findMetadataOperand(MDNode *LoopMD, + StringRef Name) { + MDNode *MD = findNamedMetadataNode(LoopMD, Name); + if (!MD) + return None; + switch (MD->getNumOperands()) { + case 1: + return nullptr; + case 2: + return MD->getOperand(1).get(); + default: + llvm_unreachable("loop metadata must have 0 or 1 operands"); + } +} + +bool polly::hasDisableAllTransformsHint(Loop *L) { + return llvm::hasDisableAllTransformsHint(L); +} + +isl::id polly::getIslLoopAttr(isl::ctx Ctx, BandAttr *Attr) { + assert(Attr && "Must be a valid BandAttr"); + + // The name "Loop with Metadata" signals that this id holds a BandAttr. + // The ScheduleOptimizer also uses the string "Inter iteration alias-free" in + // markers, but its user pointer is an llvm::Value. 
+ isl::id Result = isl::id::alloc(Ctx, "Loop with Metadata", Attr); + Result = isl::manage(isl_id_set_free_user(Result.release(), [](void *Ptr) { + BandAttr *Attr = reinterpret_cast(Ptr); + delete Attr; + })); + return Result; +} + +isl::id polly::createIslLoopAttr(isl::ctx Ctx, Loop *L) { + if (!L) + return {}; + + // A loop without metadata does not need to be annotated.
+ MDNode *LoopID = L->getLoopID(); + if (!LoopID) + return {}; + + BandAttr *Attr = new BandAttr(); + Attr->OriginalLoop = L; + Attr->Metadata = LoopID; + + return getIslLoopAttr(Ctx, Attr); +} + +bool polly::isLoopAttr(const isl::id &Id) { + if (Id.is_null()) + return false; + + return Id.get_name() == "Loop with Metadata"; +} + +BandAttr *polly::getLoopAttr(const isl::id &Id) { + if (!isLoopAttr(Id)) + return nullptr; + + return reinterpret_cast(Id.get_user()); +} diff --git a/polly/lib/Transform/ManualOptimizer.cpp b/polly/lib/Transform/ManualOptimizer.cpp new file mode 100644 --- /dev/null +++ b/polly/lib/Transform/ManualOptimizer.cpp @@ -0,0 +1,189 @@ +//===------ ManualOptimizer.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Handle pragma/metadata-directed transformations. +// +//===----------------------------------------------------------------------===// + +#include "polly/ManualOptimizer.h" +#include "polly/ScheduleTreeTransform.h" +#include "polly/Support/ScopHelper.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/Metadata.h" + +#define DEBUG_TYPE "polly-opt-manual" + +using namespace polly; +using namespace llvm; + +namespace { +/// Extract an integer property from a LoopID metadata node. +static llvm::Optional findOptionalIntOperand(MDNode *LoopMD, + StringRef Name) { + Metadata *AttrMD = findMetadataOperand(LoopMD, Name).getValueOr(nullptr); + if (!AttrMD) + return None; + + ConstantInt *IntMD = mdconst::extract_or_null(AttrMD); + if (!IntMD) + return None; + + return IntMD->getSExtValue(); +} + +/// Extract a boolean property from a LoopID metadata node.
+static llvm::Optional findOptionalBoolOperand(MDNode *LoopMD, + StringRef Name) { + auto MD = findOptionMDForLoopID(LoopMD, Name); + if (!MD) + return None; + + switch (MD->getNumOperands()) { + case 1: + // When the value is absent it is interpreted as 'attribute set'. + return true; + case 2: + ConstantInt *IntMD = + mdconst::extract_or_null(MD->getOperand(1).get()); + // A non-integer value is interpreted as 'attribute set' as well. + return !IntMD || IntMD->getZExtValue() != 0; + } + llvm_unreachable("unexpected number of options"); +} + +/// Apply full or partial unrolling; returns null if neither is requested. +static isl::schedule applyLoopUnroll(MDNode *LoopMD, + isl::schedule_node BandToUnroll) { + assert(BandToUnroll); + // TODO: Isl's codegen also supports unrolling by isl_ast_build via + // isl_schedule_node_band_set_ast_build_options({ unroll[x] }) which would be + // more efficient because the content duplication is delayed. However, the + // unrolled loop could be input of another loop transformation which expects + // the explicit schedule nodes. That is, we would need this explicit expansion + // anyway and using the ISL codegen option is a compile-time optimization. + int64_t Factor = + findOptionalIntOperand(LoopMD, "llvm.loop.unroll.count").getValueOr(0); + bool Full = findOptionalBoolOperand(LoopMD, "llvm.loop.unroll.full") + .getValueOr(false); + assert((!Full || !(Factor > 0)) && + "Cannot unroll fully and partially at the same time"); + + if (Full) + return applyFullUnroll(BandToUnroll); + + if (Factor > 0) + return applyPartialUnroll(BandToUnroll, Factor); + + // No explicit count or full unroll: leave it to LLVM's LoopUnroll pass. + return {}; +} + +// Return the properties from a LoopID. Scalar properties are ignored. 
+static auto getLoopMDProps(MDNode *LoopMD) { + return map_range( + make_filter_range( + drop_begin(LoopMD->operands(), 1), + [](const MDOperand &MDOp) { return isa(MDOp.get()); }), + [](const MDOperand &MDOp) { return cast(MDOp.get()); }); +} + +/// Recursively visit all nodes in a schedule, look for loop-transformation +/// metadata and apply the first encountered.
+class SearchTransformVisitor + : public RecursiveScheduleTreeVisitor { +private: + using BaseTy = RecursiveScheduleTreeVisitor; + BaseTy &getBase() { return *this; } + const BaseTy &getBase() const { return *this; } + + // Set after a transformation is applied. Recursive search must be aborted + // once this happens to ensure that any new followup transformation is + // transformed in innermost-first order. + isl::schedule Result; + +public: + static isl::schedule applyOneTransformation(const isl::schedule &Sched) { + SearchTransformVisitor Transformer; + Transformer.visit(Sched); + return Transformer.Result; + } + + void visitBand(const isl::schedule_node &Band) { + // Transform inner loops first (depth-first search). + getBase().visitBand(Band); + if (Result) + return; + + // Since it is (currently) not possible to have a BandAttr marker that is + // specific to each loop in a band, we only support single-loop bands. + if (isl_schedule_node_band_n_member(Band.get()) != 1) + return; + + BandAttr *Attr = getBandAttr(Band); + if (!Attr) { + // Band has no attribute. + return; + } + + MDNode *LoopMD = Attr->Metadata; + if (!LoopMD) + return; + + // Iterate over loop properties to find the first transformation. + // FIXME: If there is more than one transformation in the LoopMD (making + // the order of transformations ambiguous), all others are silently ignored. + for (MDNode *MD : getLoopMDProps(LoopMD)) { + // Skip malformed properties without a name operand. + if (MD->getNumOperands() == 0) + continue; + auto *NameMD = dyn_cast(MD->getOperand(0).get()); + if (!NameMD) + continue; + StringRef AttrName = NameMD->getString(); + + if (AttrName == "llvm.loop.unroll.enable") { + // TODO: Handle disabling like llvm::hasUnrollTransformation().
+ Result = applyLoopUnroll(LoopMD, Band); + } else { + // not a loop transformation; look for next property + continue; + } + + // Stop at the first applied transformation; otherwise keep searching. + if (Result) + return; + } + } + + void visitNode(const isl::schedule_node &Other) { + if (Result) + return; + getBase().visitNode(Other); + } +}; + +} // namespace + +isl::schedule polly::applyManualTransformations(Scop *S, isl::schedule Sched) { + // Search the loop nest for transformations until fixpoint. + while (true) { + isl::schedule Result = + SearchTransformVisitor::applyOneTransformation(Sched); + if (!Result) { + // No (more) transformation has been found. + break; + } + + // Use transformed schedule and look for more transformations. + Sched = Result; + } + + return Sched; +} diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp --- a/polly/lib/Transform/ScheduleOptimizer.cpp +++ b/polly/lib/Transform/ScheduleOptimizer.cpp @@ -49,6 +49,7 @@ #include "polly/CodeGen/CodeGeneration.h" #include "polly/DependenceInfo.h" #include "polly/LinkAllPasses.h" +#include "polly/ManualOptimizer.h" #include "polly/Options.h" #include "polly/ScheduleTreeTransform.h" #include "polly/ScopInfo.h" @@ -257,6 +258,11 @@ cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated, cl::cat(PollyCategory)); +static cl::opt PragmaBasedOpts( + "polly-pragma-based-opts", + cl::desc("Apply user-directed transformation from metadata"), + cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory)); + static cl::opt PMBasedOpts("polly-pattern-matching-based-opts", cl::desc("Perform optimizations based on pattern matching"), @@ -1716,6 +1722,18 @@ char IslScheduleOptimizerWrapperPass::ID = 0; +static void printSchedule(llvm::raw_ostream &OS, const isl::schedule &Schedule, + StringRef Desc) { + isl::ctx Ctx = Schedule.get_ctx(); + isl_printer *P = isl_printer_to_str(Ctx.get()); + P = isl_printer_set_yaml_style(P, ISL_YAML_STYLE_BLOCK); + P = isl_printer_print_schedule(P, Schedule.get()); + char
*Str = isl_printer_get_str(P); + OS << Desc << ": \n" << Str << "\n"; + free(Str); + isl_printer_free(P); +} + /// Collect statistics for the schedule tree. /// /// @param Schedule The schedule tree to analyze. If not a schedule tree it is @@ -1784,120 +1802,167 @@ return false; } - const Dependences &D = GetDeps(Dependences::AL_Statement); + ScopsProcessed++; - if (D.getSharedIslCtx() != S.getSharedIslCtx()) { - LLVM_DEBUG(dbgs() << "DependenceInfo for another SCoP/isl_ctx\n"); - return false; + // Schedule without optimizations. + isl::schedule Schedule = S.getScheduleTree(); + walkScheduleTreeForStatistics(S.getScheduleTree(), 0); + LLVM_DEBUG(printSchedule(dbgs(), Schedule, "Original schedule tree")); + + bool HasUserTransformation = false; + if (PragmaBasedOpts) { + isl::schedule ManuallyTransformed = + applyManualTransformations(&S, Schedule); + if (!ManuallyTransformed) { + LLVM_DEBUG(dbgs() << "Error during manual optimization\n"); + return false; + } + + if (ManuallyTransformed.get() != Schedule.get()) { + // User transformations have precedence over other transformations. + HasUserTransformation = true; + Schedule = std::move(ManuallyTransformed); + LLVM_DEBUG( + printSchedule(dbgs(), Schedule, "After manual transformations")); + } } - if (!D.hasValidDependences()) + // Only continue if either manual transformations have been applied or we are + // allowed to apply heuristics. + // TODO: Detect disabled heuristics and no user-directed transformation + // metadata earlier in ScopDetection. + if (!HasUserTransformation && S.hasDisableHeuristicsHint()) { + LLVM_DEBUG(dbgs() << "Heuristic optimizations disabled by metadata\n"); return false; - - // Build input data. 
- int ValidityKinds = - Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; - int ProximityKinds; - - if (OptimizeDeps == "all") - ProximityKinds = - Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; - else if (OptimizeDeps == "raw") - ProximityKinds = Dependences::TYPE_RAW; - else { - errs() << "Do not know how to optimize for '" << OptimizeDeps << "'" - << " Falling back to optimizing all dependences.\n"; - ProximityKinds = - Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; } - isl::union_set Domain = S.getDomains(); - - if (!Domain) + // Get dependency analysis. + const Dependences &D = GetDeps(Dependences::AL_Statement); + if (D.getSharedIslCtx() != S.getSharedIslCtx()) { + LLVM_DEBUG(dbgs() << "DependenceInfo for another SCoP/isl_ctx\n"); + return false; + } + if (!D.hasValidDependences()) { + LLVM_DEBUG(dbgs() << "Dependency information not available\n"); return false; + } - ScopsProcessed++; - walkScheduleTreeForStatistics(S.getScheduleTree(), 0); + // Apply ISL's algorithm only if not overridden by the user. Note that + // post-rescheduling optimizations (tiling, pattern-based, prevectorization) + // rely on the coincidence/permutable annotations on schedule tree bands that + // are added by the rescheduling analyzer. Therefore, disabling the + // rescheduler implicitly also disables these optimizations. + if (HasUserTransformation) { + LLVM_DEBUG( + dbgs() << "Skipping rescheduling due to manual transformation\n"); + } else { + // Build input data.
+ int ValidityKinds = + Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; + int ProximityKinds; + + if (OptimizeDeps == "all") + ProximityKinds = + Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; + else if (OptimizeDeps == "raw") + ProximityKinds = Dependences::TYPE_RAW; + else { + errs() << "Do not know how to optimize for '" << OptimizeDeps << "'" + << " Falling back to optimizing all dependences.\n"; + ProximityKinds = + Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW; + } - isl::union_map Validity = D.getDependences(ValidityKinds); - isl::union_map Proximity = D.getDependences(ProximityKinds); - - // Simplify the dependences by removing the constraints introduced by the - // domains. This can speed up the scheduling time significantly, as large - // constant coefficients will be removed from the dependences. The - // introduction of some additional dependences reduces the possible - // transformations, but in most cases, such transformation do not seem to be - // interesting anyway. In some cases this option may stop the scheduler to - // find any schedule. - if (SimplifyDeps == "yes") { - Validity = Validity.gist_domain(Domain); - Validity = Validity.gist_range(Domain); - Proximity = Proximity.gist_domain(Domain); - Proximity = Proximity.gist_range(Domain); - } else if (SimplifyDeps != "no") { - errs() << "warning: Option -polly-opt-simplify-deps should either be 'yes' " - "or 'no'. 
Falling back to default: 'yes'\n"; - } + isl::union_set Domain = S.getDomains(); - LLVM_DEBUG(dbgs() << "\n\nCompute schedule from: "); - LLVM_DEBUG(dbgs() << "Domain := " << Domain << ";\n"); - LLVM_DEBUG(dbgs() << "Proximity := " << Proximity << ";\n"); - LLVM_DEBUG(dbgs() << "Validity := " << Validity << ";\n"); + if (!Domain) + return false; - unsigned IslSerializeSCCs; + isl::union_map Validity = D.getDependences(ValidityKinds); + isl::union_map Proximity = D.getDependences(ProximityKinds); + + // Simplify the dependences by removing the constraints introduced by the + // domains. This can speed up the scheduling time significantly, as large + // constant coefficients will be removed from the dependences. The + // introduction of some additional dependences reduces the possible + // transformations, but in most cases, such transformation do not seem to be + // interesting anyway. In some cases this option may stop the scheduler to + // find any schedule. + if (SimplifyDeps == "yes") { + Validity = Validity.gist_domain(Domain); + Validity = Validity.gist_range(Domain); + Proximity = Proximity.gist_domain(Domain); + Proximity = Proximity.gist_range(Domain); + } else if (SimplifyDeps != "no") { + errs() + << "warning: Option -polly-opt-simplify-deps should either be 'yes' " + "or 'no'. Falling back to default: 'yes'\n"; + } - if (FusionStrategy == "max") { - IslSerializeSCCs = 0; - } else if (FusionStrategy == "min") { - IslSerializeSCCs = 1; - } else { - errs() << "warning: Unknown fusion strategy. 
Falling back to maximal " - "fusion.\n"; - IslSerializeSCCs = 0; - } + LLVM_DEBUG(dbgs() << "\n\nCompute schedule from: "); + LLVM_DEBUG(dbgs() << "Domain := " << Domain << ";\n"); + LLVM_DEBUG(dbgs() << "Proximity := " << Proximity << ";\n"); + LLVM_DEBUG(dbgs() << "Validity := " << Validity << ";\n"); + + unsigned IslSerializeSCCs; + + if (FusionStrategy == "max") { + IslSerializeSCCs = 0; + } else if (FusionStrategy == "min") { + IslSerializeSCCs = 1; + } else { + errs() << "warning: Unknown fusion strategy. Falling back to maximal " + "fusion.\n"; + IslSerializeSCCs = 0; + } - int IslMaximizeBands; + int IslMaximizeBands; + + if (MaximizeBandDepth == "yes") { + IslMaximizeBands = 1; + } else if (MaximizeBandDepth == "no") { + IslMaximizeBands = 0; + } else { + errs() + << "warning: Option -polly-opt-maximize-bands should either be 'yes'" + " or 'no'. Falling back to default: 'yes'\n"; + IslMaximizeBands = 1; + } - if (MaximizeBandDepth == "yes") { - IslMaximizeBands = 1; - } else if (MaximizeBandDepth == "no") { - IslMaximizeBands = 0; - } else { - errs() << "warning: Option -polly-opt-maximize-bands should either be 'yes'" - " or 'no'. Falling back to default: 'yes'\n"; - IslMaximizeBands = 1; - } + int IslOuterCoincidence; - int IslOuterCoincidence; + if (OuterCoincidence == "yes") { + IslOuterCoincidence = 1; + } else if (OuterCoincidence == "no") { + IslOuterCoincidence = 0; + } else { + errs() << "warning: Option -polly-opt-outer-coincidence should either be " + "'yes' or 'no'. Falling back to default: 'no'\n"; + IslOuterCoincidence = 0; + } - if (OuterCoincidence == "yes") { - IslOuterCoincidence = 1; - } else if (OuterCoincidence == "no") { - IslOuterCoincidence = 0; - } else { - errs() << "warning: Option -polly-opt-outer-coincidence should either be " - "'yes' or 'no'. 
Falling back to default: 'no'\n"; - IslOuterCoincidence = 0; - } + isl_ctx *Ctx = S.getIslCtx().get(); - isl_ctx *Ctx = S.getIslCtx().get(); + isl_options_set_schedule_outer_coincidence(Ctx, IslOuterCoincidence); + isl_options_set_schedule_serialize_sccs(Ctx, IslSerializeSCCs); + isl_options_set_schedule_maximize_band_depth(Ctx, IslMaximizeBands); + isl_options_set_schedule_max_constant_term(Ctx, MaxConstantTerm); + isl_options_set_schedule_max_coefficient(Ctx, MaxCoefficient); + isl_options_set_tile_scale_tile_loops(Ctx, 0); - isl_options_set_schedule_outer_coincidence(Ctx, IslOuterCoincidence); - isl_options_set_schedule_serialize_sccs(Ctx, IslSerializeSCCs); - isl_options_set_schedule_maximize_band_depth(Ctx, IslMaximizeBands); - isl_options_set_schedule_max_constant_term(Ctx, MaxConstantTerm); - isl_options_set_schedule_max_coefficient(Ctx, MaxCoefficient); - isl_options_set_tile_scale_tile_loops(Ctx, 0); + auto OnErrorStatus = isl_options_get_on_error(Ctx); + isl_options_set_on_error(Ctx, ISL_ON_ERROR_CONTINUE); - auto OnErrorStatus = isl_options_get_on_error(Ctx); - isl_options_set_on_error(Ctx, ISL_ON_ERROR_CONTINUE); + auto SC = isl::schedule_constraints::on_domain(Domain); + SC = SC.set_proximity(Proximity); + SC = SC.set_validity(Validity); + SC = SC.set_coincidence(Validity); + Schedule = SC.compute_schedule(); + isl_options_set_on_error(Ctx, OnErrorStatus); - auto SC = isl::schedule_constraints::on_domain(Domain); - SC = SC.set_proximity(Proximity); - SC = SC.set_validity(Validity); - SC = SC.set_coincidence(Validity); - auto Schedule = SC.compute_schedule(); - isl_options_set_on_error(Ctx, OnErrorStatus); + ScopsRescheduled++; + LLVM_DEBUG(printSchedule(dbgs(), Schedule, "After rescheduling")); + } walkScheduleTreeForStatistics(Schedule, 1); @@ -1906,33 +1971,23 @@ if (!Schedule) return false; - ScopsRescheduled++; - - LLVM_DEBUG({ - auto *P = isl_printer_to_str(Ctx); - P = isl_printer_set_yaml_style(P, ISL_YAML_STYLE_BLOCK); - P = 
isl_printer_print_schedule(P, Schedule.get()); - auto *str = isl_printer_get_str(P); - dbgs() << "NewScheduleTree: \n" << str << "\n"; - free(str); - isl_printer_free(P); - }); - + // Apply post-rescheduling optimizations. const OptimizerAdditionalInfoTy OAI = {TTI, const_cast(&D)}; - auto NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule, &OAI); - NewSchedule = hoistExtensionNodes(NewSchedule); - walkScheduleTreeForStatistics(NewSchedule, 2); + Schedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule, &OAI); + Schedule = hoistExtensionNodes(Schedule); + LLVM_DEBUG(printSchedule(dbgs(), Schedule, "After post-optimizations")); + walkScheduleTreeForStatistics(Schedule, 2); - if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) + if (!ScheduleTreeOptimizer::isProfitableSchedule(S, Schedule)) return false; auto ScopStats = S.getStatistics(); ScopsOptimized++; NumAffineLoopsOptimized += ScopStats.NumAffineLoops; NumBoxedLoopsOptimized += ScopStats.NumBoxedLoops; - LastSchedule = NewSchedule; + LastSchedule = Schedule; - S.setScheduleTree(NewSchedule); + S.setScheduleTree(Schedule); S.markAsOptimized(); if (OptimizedScops) diff --git a/polly/lib/Transform/ScheduleTreeTransform.cpp b/polly/lib/Transform/ScheduleTreeTransform.cpp --- a/polly/lib/Transform/ScheduleTreeTransform.cpp +++ b/polly/lib/Transform/ScheduleTreeTransform.cpp @@ -12,132 +12,18 @@ #include "polly/ScheduleTreeTransform.h" #include "polly/Support/ISLTools.h" +#include "polly/Support/ScopHelper.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Metadata.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" using namespace polly; +using namespace llvm; namespace { - -/// This class defines a simple visitor class that may be used for -/// various schedule tree analysis purposes. 
-template -struct ScheduleTreeVisitor { - Derived &getDerived() { return *static_cast(this); } - const Derived &getDerived() const { - return *static_cast(this); - } - - RetTy visit(const isl::schedule_node &Node, Args... args) { - assert(!Node.is_null()); - switch (isl_schedule_node_get_type(Node.get())) { - case isl_schedule_node_domain: - assert(isl_schedule_node_n_children(Node.get()) == 1); - return getDerived().visitDomain(Node, std::forward(args)...); - case isl_schedule_node_band: - assert(isl_schedule_node_n_children(Node.get()) == 1); - return getDerived().visitBand(Node, std::forward(args)...); - case isl_schedule_node_sequence: - assert(isl_schedule_node_n_children(Node.get()) >= 2); - return getDerived().visitSequence(Node, std::forward(args)...); - case isl_schedule_node_set: - return getDerived().visitSet(Node, std::forward(args)...); - assert(isl_schedule_node_n_children(Node.get()) >= 2); - case isl_schedule_node_leaf: - assert(isl_schedule_node_n_children(Node.get()) == 0); - return getDerived().visitLeaf(Node, std::forward(args)...); - case isl_schedule_node_mark: - assert(isl_schedule_node_n_children(Node.get()) == 1); - return getDerived().visitMark(Node, std::forward(args)...); - case isl_schedule_node_extension: - assert(isl_schedule_node_n_children(Node.get()) == 1); - return getDerived().visitExtension(Node, std::forward(args)...); - case isl_schedule_node_filter: - assert(isl_schedule_node_n_children(Node.get()) == 1); - return getDerived().visitFilter(Node, std::forward(args)...); - default: - llvm_unreachable("unimplemented schedule node type"); - } - } - - RetTy visitDomain(const isl::schedule_node &Domain, Args... args) { - return getDerived().visitSingleChild(Domain, std::forward(args)...); - } - - RetTy visitBand(const isl::schedule_node &Band, Args... args) { - return getDerived().visitSingleChild(Band, std::forward(args)...); - } - - RetTy visitSequence(const isl::schedule_node &Sequence, Args... 
args) { - return getDerived().visitMultiChild(Sequence, std::forward(args)...); - } - - RetTy visitSet(const isl::schedule_node &Set, Args... args) { - return getDerived().visitMultiChild(Set, std::forward(args)...); - } - - RetTy visitLeaf(const isl::schedule_node &Leaf, Args... args) { - return getDerived().visitNode(Leaf, std::forward(args)...); - } - - RetTy visitMark(const isl::schedule_node &Mark, Args... args) { - return getDerived().visitSingleChild(Mark, std::forward(args)...); - } - - RetTy visitExtension(const isl::schedule_node &Extension, Args... args) { - return getDerived().visitSingleChild(Extension, - std::forward(args)...); - } - - RetTy visitFilter(const isl::schedule_node &Extension, Args... args) { - return getDerived().visitSingleChild(Extension, - std::forward(args)...); - } - - RetTy visitSingleChild(const isl::schedule_node &Node, Args... args) { - return getDerived().visitNode(Node, std::forward(args)...); - } - - RetTy visitMultiChild(const isl::schedule_node &Node, Args... args) { - return getDerived().visitNode(Node, std::forward(args)...); - } - - RetTy visitNode(const isl::schedule_node &Node, Args... args) { - llvm_unreachable("Unimplemented other"); - } -}; - -/// Recursively visit all nodes of a schedule tree. -template -struct RecursiveScheduleTreeVisitor - : public ScheduleTreeVisitor { - using BaseTy = ScheduleTreeVisitor; - BaseTy &getBase() { return *this; } - const BaseTy &getBase() const { return *this; } - Derived &getDerived() { return *static_cast(this); } - const Derived &getDerived() const { - return *static_cast(this); - } - - /// When visiting an entire schedule tree, start at its root node. - RetTy visit(const isl::schedule &Schedule, Args... args) { - return getDerived().visit(Schedule.get_root(), std::forward(args)...); - } - - // Necessary to allow overload resolution with the added visit(isl::schedule) - // overload. - RetTy visit(const isl::schedule_node &Node, Args... 
args) { - return getBase().visit(Node, std::forward(args)...); - } - - RetTy visitNode(const isl::schedule_node &Node, Args... args) { - int NumChildren = isl_schedule_node_n_children(Node.get()); - for (int i = 0; i < NumChildren; i += 1) - getDerived().visit(Node.child(i), std::forward(args)...); - return RetTy(); - } -}; - /// Recursively visit all nodes of a schedule tree while allowing changes. /// /// The visit methods return an isl::schedule_node that is used to continue @@ -461,8 +347,6 @@ } }; -} // namespace - /// Return whether the schedule contains an extension node. static bool containsExtensionNode(isl::schedule Schedule) { assert(!Schedule.is_null()); @@ -485,6 +369,129 @@ return RetVal == isl_stat_error; } +/// Find a named MDNode property in a LoopID. +static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { + return dyn_cast_or_null( + findMetadataOperand(LoopMD, Name).getValueOr(nullptr)); +} + +/// Is this node of type band? +static bool isBand(const isl::schedule_node &Node) { + return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; +} + +/// Is this node of type mark? +static bool isMark(const isl::schedule_node &Node) { + return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; +} + +/// Is this node a band of a single dimension (i.e. could represent a loop)? +static bool isBandWithSingleLoop(const isl::schedule_node &Node) { + + return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; +} + +/// Create an isl::id representing the output loop after a transformation. +static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { + // Don't need to id the followup. 
+ // TODO: Append llvm.loop.disable_nonforced metadata unless overridden by + // user followup-MD + if (!FollowupLoopMD) + return {}; + + BandAttr *Attr = new BandAttr(); + Attr->Metadata = FollowupLoopMD; + return getIslLoopAttr(Ctx, Attr); +} + +/// A loop consists of a band and an optional marker that wraps it. Return the +/// outermost of the two. +/// +/// That is, either the mark or, if there is no mark, the loop itself. Can +/// start with either the mark or the band. +static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { + if (isBandMark(BandOrMark)) { + assert(isBandWithSingleLoop(BandOrMark.get_child(0))); + return BandOrMark; + } + assert(isBandWithSingleLoop(BandOrMark)); + + isl::schedule_node Mark = BandOrMark.parent(); + if (isBandMark(Mark)) + return Mark; + + // Band has no loop marker. + return BandOrMark; +} + +static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, + BandAttr *&Attr) { + MarkOrBand = moveToBandMark(MarkOrBand); + + isl::schedule_node Band; + if (isMark(MarkOrBand)) { + Attr = getLoopAttr(MarkOrBand.mark_get_id()); + Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); + } else { + Attr = nullptr; + Band = MarkOrBand; + } + + assert(isBandWithSingleLoop(Band)); + return Band; +} + +/// Remove the mark that wraps a loop. Return the band representing the loop. +static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { + BandAttr *Attr; + return removeMark(MarkOrBand, Attr); +} + +static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { + assert(isBand(Band)); + assert(moveToBandMark(Band).is_equal(Band) && + "Don't add a two marks for a band"); + + return Band.insert_mark(Mark).get_child(0); +} + +/// Return the (one-dimensional) set of numbers that are divisible by @p Factor +/// with remainder @p Offset.
+/// +/// isDivisibleBySet(Ctx, 4, 0) = { [i] : i mod 4 = 0 } +/// isDivisibleBySet(Ctx, 4, 1) = { [i] : i mod 4 = 1 } +/// +static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, + long Offset) { + isl::val ValFactor{Ctx, Factor}; + isl::val ValOffset{Ctx, Offset}; + + isl::space Unispace{Ctx, 0, 1}; + isl::local_space LUnispace{Unispace}; + isl::aff AffFactor{LUnispace, ValFactor}; + isl::aff AffOffset{LUnispace, ValOffset}; + + isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); + isl::aff DivMul = Id.mod(ValFactor); + isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); + isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); + return Modulo.domain(); +} + +} // namespace + +bool polly::isBandMark(const isl::schedule_node &Node) { + return isMark(Node) && isLoopAttr(Node.mark_get_id()); +} + +BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { + MarkOrBand = moveToBandMark(MarkOrBand); + if (!isMark(MarkOrBand)) + return nullptr; + + return getLoopAttr(MarkOrBand.mark_get_id()); +} + isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { // If there is no extension node in the first place, return the original // schedule tree. @@ -508,3 +515,119 @@ return NewSched; } + +isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { + isl::ctx Ctx = BandToUnroll.get_ctx(); + + // Remove the loop's mark, the loop will disappear anyway.
+ BandToUnroll = removeMark(BandToUnroll); + assert(isBandWithSingleLoop(BandToUnroll)); + + isl::multi_union_pw_aff PartialSched = isl::manage( + isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); + assert(PartialSched.dim(isl::dim::out) == 1 && + "Can only unroll a single dimension"); + isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0); + + isl::union_set Domain = BandToUnroll.get_domain(); + PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); + isl::union_map PartialSchedUMap = isl::union_map(PartialSchedUAff); + + // Make consumable for the following code. + // Schedule at the beginning so it is at coordinate 0. + isl::union_set PartialSchedUSet = PartialSchedUMap.reverse().wrap(); + + SmallVector Elts; + // TODO: Diagnose if not enumerable or depends on a parameter. + PartialSchedUSet.foreach_point([&Elts](isl::point P) -> isl::stat { + Elts.push_back(P); + return isl::stat::ok(); + }); + + // Don't assume that foreach_point returns in execution order. + llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { + isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); + isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); + return C1.lt(C2); + }); + + // Convert the points to a sequence of filters. + isl::union_set_list List = isl::union_set_list::alloc(Ctx, Elts.size()); + for (isl::point P : Elts) { + isl::basic_set AsSet{P}; + + // Throw away the scatter dimension. + AsSet = AsSet.unwrap().range(); + + List = List.add(AsSet); + } + + // Replace original band with unrolled sequence. + isl::schedule_node Body = + isl::manage(isl_schedule_node_delete(BandToUnroll.release())); + Body = Body.insert_sequence(List); + return Body.get_schedule(); +} + +isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, + int Factor) { + assert(Factor > 0 && "Positive unroll factor required"); + isl::ctx Ctx = BandToUnroll.get_ctx(); + + // Remove the mark, save the attribute for later use. 
+ BandAttr *Attr; + BandToUnroll = removeMark(BandToUnroll, Attr); + assert(isBandWithSingleLoop(BandToUnroll)); + + isl::multi_union_pw_aff PartialSched = isl::manage( + isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); + + // { Stmt[] -> [x] } + isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0); + + // Here we assume the schedule stride is one and starts with 0, which is not + // necessarily the case. + isl::union_pw_aff StridedPartialSchedUAff = + isl::union_pw_aff::empty(PartialSchedUAff.get_space()); + isl::val ValFactor{Ctx, Factor}; + PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, + &ValFactor](isl::pw_aff PwAff) -> isl::stat { + isl::space Space = PwAff.get_space(); + isl::set Universe = isl::set::universe(Space.domain()); + isl::pw_aff AffFactor{Universe, ValFactor}; + isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); + StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); + return isl::stat::ok(); + }); + + isl::union_set_list List = isl::union_set_list::alloc(Ctx, Factor); + for (auto i : seq(0, Factor)) { + // { Stmt[] -> [x] } + isl::union_map UMap{PartialSchedUAff}; + + // { [x] } + isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); + + // { Stmt[] } + isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); + + List = List.add(UnrolledDomain); + } + + isl::schedule_node Body = + isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); + Body = Body.insert_sequence(List); + isl::schedule_node NewLoop = + Body.insert_partial_schedule(StridedPartialSchedUAff); + + MDNode *FollowupMD = nullptr; + if (Attr && Attr->Metadata) + FollowupMD = + findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); + + isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); + if (NewBandId) + NewLoop = insertMark(NewLoop, NewBandId); + + return NewLoop.get_schedule(); +} diff --git 
a/polly/test/ScheduleOptimizer/ManualOptimization/disable_nonforced.ll b/polly/test/ScheduleOptimizer/ManualOptimization/disable_nonforced.ll new file mode 100644 --- /dev/null +++ b/polly/test/ScheduleOptimizer/ManualOptimization/disable_nonforced.ll @@ -0,0 +1,58 @@ +; RUN: opt %loadPolly -polly-opt-isl -analyze < %s | FileCheck %s -match-full-lines +; +; Check that the disable_nonforced metadata is honored; optimization +; heuristics/rescheduling must not be applied. +; +define void @func(i32 %n, double* noalias nonnull %A) { +entry: + br label %for + +for: + %j = phi i32 [0, %entry], [%j.inc, %inc] + %j.cmp = icmp slt i32 %j, %n + br i1 %j.cmp, label %inner.for, label %exit + + + inner.for: + %i = phi i32 [0, %for], [%i.inc, %inner.inc] + br label %bodyA + + + bodyA: + %mul = mul nuw nsw i32 %j, 128 + %add = add nuw nsw i32 %mul, %i + %A_idx = getelementptr inbounds double, double* %A, i32 %add + store double 0.0, double* %A_idx + br label %inner.inc + + + inner.inc: + %i.inc = add nuw nsw i32 %i, 1 + %i.cmp = icmp slt i32 %i.inc, 128 + br i1 %i.cmp, label %inner.for, label %inner.exit + + inner.exit: + br label %inc, !llvm.loop !2 + + +inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %for, !llvm.loop !2 + +exit: + br label %return + +return: + ret void +} + + +!2 = distinct !{!2, !3} +!3 = !{!"llvm.loop.disable_nonforced"} + + +; n/a indicates no new schedule was computed +; +; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; CHECK-NEXT: Calculated schedule: +; CHECK-NEXT: n/a diff --git a/polly/test/ScheduleOptimizer/ManualOptimization/unroll_double.ll b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_double.ll new file mode 100644 --- /dev/null +++ b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_double.ll @@ -0,0 +1,52 @@ +; RUN: opt %loadPolly -polly-opt-isl -analyze < %s | FileCheck %s --match-full-lines +; +; Apply two loop transformations. 
First partial, then full unrolling. +; +define void @func(double* noalias nonnull %A) { +entry: + br label %for + +for: + %j = phi i32 [0, %entry], [%j.inc, %inc] + %j.cmp = icmp slt i32 %j, 12 + br i1 %j.cmp, label %body, label %exit + + body: + store double 42.0, double* %A + br label %inc + +inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %for, !llvm.loop !2 + +exit: + br label %return + +return: + ret void +} + + +!2 = distinct !{!2, !4, !5, !6} +!4 = !{!"llvm.loop.unroll.enable", i1 true} +!5 = !{!"llvm.loop.unroll.count", i4 4} +!6 = !{!"llvm.loop.unroll.followup_unrolled", !7} + +!7 = distinct !{!7, !8, !9} +!8 = !{!"llvm.loop.unroll.enable", i1 true} +!9 = !{!"llvm.loop.unroll.full"} + + +; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; CHECK: - filter: "{ Stmt_body[0] }" +; CHECK: - filter: "{ Stmt_body[1] }" +; CHECK: - filter: "{ Stmt_body[2] }" +; CHECK: - filter: "{ Stmt_body[3] }" +; CHECK: - filter: "{ Stmt_body[4] }" +; CHECK: - filter: "{ Stmt_body[5] }" +; CHECK: - filter: "{ Stmt_body[6] }" +; CHECK: - filter: "{ Stmt_body[7] }" +; CHECK: - filter: "{ Stmt_body[8] }" +; CHECK: - filter: "{ Stmt_body[9] }" +; CHECK: - filter: "{ Stmt_body[10] }" +; CHECK: - filter: "{ Stmt_body[11] }" diff --git a/polly/test/ScheduleOptimizer/ManualOptimization/unroll_full.ll b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_full.ll new file mode 100644 --- /dev/null +++ b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_full.ll @@ -0,0 +1,42 @@ +; RUN: opt %loadPolly -polly-opt-isl -analyze < %s | FileCheck %s --match-full-lines +; +; Full unroll of a loop with 5 iterations. 
+; +define void @func(double* noalias nonnull %A) { +entry: + br label %for + +for: + %j = phi i32 [0, %entry], [%j.inc, %inc] + %j.cmp = icmp slt i32 %j, 5 + br i1 %j.cmp, label %body, label %exit + + body: + store double 42.0, double* %A + br label %inc + +inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %for, !llvm.loop !2 + +exit: + br label %return + +return: + ret void +} + + +!2 = distinct !{!2, !4, !5} +!4 = !{!"llvm.loop.unroll.enable", i1 true} +!5 = !{!"llvm.loop.unroll.full"} + + +; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; CHECK: domain: "{ Stmt_body[i0] : 0 <= i0 <= 4 }" +; CHECK: sequence: +; CHECK-NEXT: - filter: "{ Stmt_body[0] }" +; CHECK-NEXT: - filter: "{ Stmt_body[1] }" +; CHECK-NEXT: - filter: "{ Stmt_body[2] }" +; CHECK-NEXT: - filter: "{ Stmt_body[3] }" +; CHECK-NEXT: - filter: "{ Stmt_body[4] }" diff --git a/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial.ll b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial.ll new file mode 100644 --- /dev/null +++ b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial.ll @@ -0,0 +1,48 @@ +; RUN: opt %loadPolly -polly-opt-isl -polly-pragma-based-opts=1 -analyze < %s | FileCheck %s --match-full-lines +; RUN: opt %loadPolly -polly-opt-isl -polly-pragma-based-opts=0 -analyze < %s | FileCheck %s --check-prefix=OFF --match-full-lines +; +; Partial unroll by a factor of 4. 
+; +define void @func(i32 %n, double* noalias nonnull %A) { +entry: + br label %for + +for: + %j = phi i32 [0, %entry], [%j.inc, %inc] + %j.cmp = icmp slt i32 %j, %n + br i1 %j.cmp, label %body, label %exit + + body: + store double 42.0, double* %A + br label %inc + +inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %for, !llvm.loop !2 + +exit: + br label %return + +return: + ret void +} + + +!2 = distinct !{!2, !4, !5} +!4 = !{!"llvm.loop.unroll.enable", i1 true} +!5 = !{!"llvm.loop.unroll.count", i4 4} + + +; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; CHECK: domain: "[n] -> { Stmt_body[i0] : 0 <= i0 < n }" +; CHECK: schedule: "[n] -> [{ Stmt_body[i0] -> [(i0 - (i0) mod 4)] }]" +; CHECK: sequence: +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (-1 + i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (2 + i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (1 + i0) mod 4 = 0 }" + + +; OFF-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; OFF-NEXT: Calculated schedule: +; OFF-NEXT: n/a diff --git a/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial_followup.ll b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial_followup.ll new file mode 100644 --- /dev/null +++ b/polly/test/ScheduleOptimizer/ManualOptimization/unroll_partial_followup.ll @@ -0,0 +1,58 @@ +; RUN: opt %loadPolly -polly-opt-isl -polly-ast -analyze < %s | FileCheck %s --match-full-lines +; RUN: opt %loadPolly -polly-opt-isl -polly-codegen -simplifycfg -S < %s | FileCheck %s --check-prefix=CODEGEN +; +; Partial unroll by a factor of 4. 
+; +define void @func(i32 %n, double* noalias nonnull %A) { +entry: + br label %for + +for: + %j = phi i32 [0, %entry], [%j.inc, %inc] + %j.cmp = icmp slt i32 %j, %n + br i1 %j.cmp, label %body, label %exit + + body: + store double 42.0, double* %A + br label %inc + +inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %for, !llvm.loop !2 + +exit: + br label %return + +return: + ret void +} + + +!2 = distinct !{!2, !4, !5, !6} +!4 = !{!"llvm.loop.unroll.enable", i1 true} +!5 = !{!"llvm.loop.unroll.count", i4 4} +!6 = !{!"llvm.loop.unroll.followup_unrolled", !7} + +!7 = distinct !{!7, !8} +!8 = !{!"llvm.loop.id", !"This-is-the-unrolled-loop"} + + +; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func': +; CHECK: domain: "[n] -> { Stmt_body[i0] : 0 <= i0 < n }" +; CHECK: mark: "Loop with Metadata" +; CHECK: schedule: "[n] -> [{ Stmt_body[i0] -> [(i0 - (i0) mod 4)] }]" +; CHECK: sequence: +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (-1 + i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (2 + i0) mod 4 = 0 }" +; CHECK-NEXT: - filter: "[n] -> { Stmt_body[i0] : (1 + i0) mod 4 = 0 }" + + +; CHECK-LABEL: Printing analysis 'Polly - Generate an AST of the SCoP (isl)'for => return' in function 'func': +; CHECK: // Loop with Metadata +; CHECK-NEXT: for (int c0 = 0; c0 < n; c0 += 4) { + + +; CODEGEN: br i1 %polly.loop_cond, label %polly.loop_header, label %polly.exiting, !llvm.loop ![[LOOPID:[0-9]+]] +; CODEGEN: ![[LOOPID]] = distinct !{![[LOOPID]], ![[LOOPNAME:[0-9]+]]} +; CODEGEN: ![[LOOPNAME]] = !{!"llvm.loop.id", !"This-is-the-unrolled-loop"}