Index: include/polly/ScheduleOptimizer.h
===================================================================
--- include/polly/ScheduleOptimizer.h
+++ include/polly/ScheduleOptimizer.h
@@ -12,6 +12,186 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "isl/isl-noexceptions.h"
+#include "llvm/Support/raw_ostream.h"
+#include <functional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+/*
+ * Taken from: https://github.com/PollyLabs/islutils
+ * Author: Alex (https://ozinenko.com/)
+ * \brief Structural matchers on schedule trees.
+ *
+ * A matcher is an object that captures the structure of schedule trees.
+ * Conceptually, a matcher is a tree itself where every node is assigned a node
+ * type. The matcher class provides functionality to detect if a subtree in
+ * the schedule tree has the same structure, that is the same types of nodes
+ * and parent/child relationships. Contrary to regular trees, matchers can be
+ * constructed using nested call syntax omitting the details about the content
+ * of nodes. For example,
+ *
+ * ```
+ * auto m = domain(
+ *            context(
+ *              sequence(
+ *                filter(),
+ *                filter())));
+ * ```
+ *
+ * matches a subtree that starts at a domain node, having context as only
+ * child, which in turn has a sequence as only child node, and the latter has
+ * two filter children. The structure is not anchored at any position in the
+ * tree: the first node is not necessarily the tree root, and the innermost
+ * node may have children of their own.
+ */
+
+namespace matchers {
+class ScheduleNodeMatcher;
+
+/*
+ * These functions construct a structural matcher on the schedule tree by
+ * specifying the type of the node (indicated by the function name). They take
+ * other matchers as arguments to describe the children of the node. Depending
+ * on the node type, functions take a single child matcher or an arbitrary
+ * number thereof. Sequence and set matcher builders take multiple children as
+ * these types of node are the only ones that can have more than one child.
+ * Additionally, all constructors are overloaded with an extra leading argument
+ * to store a callback function for finer-grain matching. This function is
+ * called on the node before attempting to match its children. It is passed
+ * the node itself and returns true if the matching may continue and false if
+ * it should fail immediately without processing the children. When no child
+ * matchers are provided, the node is allowed to have zero or more children.
+ */
+
+ScheduleNodeMatcher sequence();
+
+template <typename Arg, typename... Args,
+          typename = typename std::enable_if<
+              std::is_same<typename std::remove_reference<Arg>::type,
+                           ScheduleNodeMatcher>::value>::type>
+ScheduleNodeMatcher sequence(Arg, Args... args);
+
+template <typename... Args>
+ScheduleNodeMatcher sequence(std::function<bool(isl::schedule_node)> callback,
+                             Args... args);
+
+ScheduleNodeMatcher set();
+
+template <typename Arg, typename... Args,
+          typename = typename std::enable_if<
+              std::is_same<typename std::remove_reference<Arg>::type,
+                           ScheduleNodeMatcher>::value>::type>
+ScheduleNodeMatcher set(Arg, Args... args);
+
+template <typename... Args>
+ScheduleNodeMatcher set(std::function<bool(isl::schedule_node)> callback,
+                        Args... args);
+
+ScheduleNodeMatcher band();
+ScheduleNodeMatcher band(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher band(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher band(std::function<bool(isl::schedule_node)> callback,
+                         ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher context();
+ScheduleNodeMatcher context(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher context(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher context(std::function<bool(isl::schedule_node)> callback,
+                            ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher domain();
+ScheduleNodeMatcher domain(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher domain(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher domain(std::function<bool(isl::schedule_node)> callback,
+                           ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher extension();
+ScheduleNodeMatcher extension(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher extension(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher extension(std::function<bool(isl::schedule_node)> callback,
+                              ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher filter();
+ScheduleNodeMatcher filter(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher filter(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher filter(std::function<bool(isl::schedule_node)> callback,
+                           ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher guard();
+ScheduleNodeMatcher guard(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher guard(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher guard(std::function<bool(isl::schedule_node)> callback,
+                          ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher mark();
+ScheduleNodeMatcher mark(ScheduleNodeMatcher &&child);
+ScheduleNodeMatcher mark(std::function<bool(isl::schedule_node)> callback);
+ScheduleNodeMatcher mark(std::function<bool(isl::schedule_node)> callback,
+                         ScheduleNodeMatcher &&child);
+
+ScheduleNodeMatcher leaf();
+
+class ScheduleNodeMatcher {
+#define DECL_FRIEND_TYPE_MATCH(name)                                           \
+  friend ScheduleNodeMatcher name();                                           \
+  template <typename... Args>                                                  \
+  friend ScheduleNodeMatcher name(std::function<bool(isl::schedule_node)>,     \
+                                  Args...);                                    \
+  template <typename Arg, typename... Args, typename>                          \
+  friend ScheduleNodeMatcher name(Arg, Args...);
+
+  DECL_FRIEND_TYPE_MATCH(sequence)
+  DECL_FRIEND_TYPE_MATCH(set)
+
+#undef DECL_FRIEND_TYPE_MATCH
+
+#define DECL_FRIEND_TYPE_MATCH(name)                                           \
+  friend ScheduleNodeMatcher name();                                           \
+  friend ScheduleNodeMatcher name(ScheduleNodeMatcher &&);                     \
+  friend ScheduleNodeMatcher name(std::function<bool(isl::schedule_node)>);    \
+  friend ScheduleNodeMatcher name(std::function<bool(isl::schedule_node)>,     \
+                                  ScheduleNodeMatcher &&);
+
+  DECL_FRIEND_TYPE_MATCH(band)
+  DECL_FRIEND_TYPE_MATCH(context)
+  DECL_FRIEND_TYPE_MATCH(domain)
+  DECL_FRIEND_TYPE_MATCH(extension)
+  DECL_FRIEND_TYPE_MATCH(filter)
+  DECL_FRIEND_TYPE_MATCH(guard)
+  DECL_FRIEND_TYPE_MATCH(mark)
+  DECL_FRIEND_TYPE_MATCH(leaf)
+
+#undef DECL_FRIEND_TYPE_MATCH
+
+public:
+  /// Return true if \p node has the structure described by \p matcher: the
+  /// same node type, accepted by the matcher's callback (if any) and, when
+  /// \p matcher has child matchers, exactly that many children, each matching
+  /// its child matcher recursively. A null \p node never matches.
+  static bool isMatching(const ScheduleNodeMatcher &matcher,
+                         isl::schedule_node node);
+
+  /// Print \p matcher to \p OS as an indented tree, one node type per line,
+  /// starting at column \p indent; children are indented two more columns.
+  static void printMatcher(llvm::raw_ostream &OS,
+                           const ScheduleNodeMatcher &matcher, int indent);
+
+private:
+  /// Node type this matcher accepts. The builders always set it; the default
+  /// keeps a default-constructed matcher from holding an indeterminate value
+  /// (isl never reports isl_schedule_node_error for a valid node).
+  isl_schedule_node_type current_ = isl_schedule_node_error;
+  /// Child matchers; empty means the node may have any number of children.
+  // TODO: consider llvm::SmallVector here.
+  std::vector<ScheduleNodeMatcher> children_;
+  /// Optional predicate evaluated on the node before its children.
+  std::function<bool(isl::schedule_node)> nodeCallback_;
+};
+
+#include "matchers-inl.h"
+
+} // namespace matchers
 
 namespace llvm {
 
Index: include/polly/matchers-inl.h
===================================================================
--- /dev/null
+++ include/polly/matchers-inl.h
@@ -0,0 +1,99 @@
+//===- matchers-inl.h - Schedule tree matcher builders ---------*- C++ -*-===//
+//
+// Inline definitions of the schedule tree matcher builders declared in
+// polly/ScheduleOptimizer.h.
+//
+// This file is textually included from inside `namespace matchers` in
+// polly/ScheduleOptimizer.h, after the definition of ScheduleNodeMatcher. It
+// must not #include anything itself: any header pulled in here would be
+// declared inside that namespace. The parent header includes what is needed.
+//
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Collect the child matchers passed to a variadic builder into a vector, in
+/// argument order. This is a named namespace rather than an anonymous one:
+/// the inline builders below have external linkage and must refer to the
+/// same helper in every translation unit.
+template <typename... Args>
+inline std::vector<ScheduleNodeMatcher> varargToVector(Args... args) {
+  std::vector<ScheduleNodeMatcher> result;
+  result.reserve(sizeof...(Args));
+  // C++11 pack expansion: braced-init-list elements are evaluated in order.
+  using Expander = int[];
+  (void)Expander{0, (result.push_back(std::move(args)), 0)...};
+  return result;
+}
+} // namespace detail
+
+/* Definitions for schedule tree matcher factory functions ********************/
+
+// Builders for node types that may have several children (sequence, set).
+#define DEF_TYPE_MATCHER(name, type)                                           \
+  inline ScheduleNodeMatcher name() {                                          \
+    ScheduleNodeMatcher matcher;                                               \
+    matcher.current_ = type;                                                   \
+    return matcher;                                                            \
+  }                                                                            \
+                                                                               \
+  template <typename Arg, typename... Args, typename>                          \
+  inline ScheduleNodeMatcher name(Arg arg, Args... args) {                     \
+    ScheduleNodeMatcher matcher = name();                                      \
+    matcher.children_ =                                                        \
+        detail::varargToVector(std::move(arg), std::move(args)...);            \
+    return matcher;                                                            \
+  }                                                                            \
+                                                                               \
+  template <typename... Args>                                                  \
+  inline ScheduleNodeMatcher name(                                             \
+      std::function<bool(isl::schedule_node)> callback, Args... args) {        \
+    ScheduleNodeMatcher matcher = name(std::move(args)...);                    \
+    matcher.nodeCallback_ = std::move(callback);                               \
+    return matcher;                                                            \
+  }
+
+DEF_TYPE_MATCHER(sequence, isl_schedule_node_sequence)
+DEF_TYPE_MATCHER(set, isl_schedule_node_set)
+
+#undef DEF_TYPE_MATCHER
+
+// Builders for node types that have at most one child.
+#define DEF_TYPE_MATCHER(name, type)                                           \
+  inline ScheduleNodeMatcher name() {                                          \
+    ScheduleNodeMatcher matcher;                                               \
+    matcher.current_ = type;                                                   \
+    return matcher;                                                            \
+  }                                                                            \
+                                                                               \
+  inline ScheduleNodeMatcher name(ScheduleNodeMatcher &&child) {               \
+    ScheduleNodeMatcher matcher = name();                                      \
+    /* Move the child: a named rvalue reference would otherwise be copied. */  \
+    matcher.children_.emplace_back(std::move(child));                          \
+    return matcher;                                                            \
+  }                                                                            \
+                                                                               \
+  inline ScheduleNodeMatcher name(                                             \
+      std::function<bool(isl::schedule_node)> callback) {                      \
+    ScheduleNodeMatcher matcher = name();                                      \
+    matcher.nodeCallback_ = std::move(callback);                               \
+    return matcher;                                                            \
+  }                                                                            \
+                                                                               \
+  inline ScheduleNodeMatcher name(                                             \
+      std::function<bool(isl::schedule_node)> callback,                        \
+      ScheduleNodeMatcher &&child) {                                           \
+    ScheduleNodeMatcher matcher = name(std::move(child));                      \
+    matcher.nodeCallback_ = std::move(callback);                               \
+    return matcher;                                                            \
+  }
+
+DEF_TYPE_MATCHER(band, isl_schedule_node_band)
+DEF_TYPE_MATCHER(context, isl_schedule_node_context)
+DEF_TYPE_MATCHER(domain, isl_schedule_node_domain)
+DEF_TYPE_MATCHER(extension, isl_schedule_node_extension)
+DEF_TYPE_MATCHER(filter, isl_schedule_node_filter)
+DEF_TYPE_MATCHER(guard, isl_schedule_node_guard)
+DEF_TYPE_MATCHER(mark, isl_schedule_node_mark)
+DEF_TYPE_MATCHER(leaf, isl_schedule_node_leaf)
+
+#undef DEF_TYPE_MATCHER
Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -299,6 +299,83 @@
 STATISTIC(MatMulOpts,
           "Number of matrix multiplication patterns detected and optimized");
 
+namespace matchers {
+
+/// Return the label printMatcher uses for a schedule node type; types that
+/// have no matcher builder are printed as "ND".
+static const char *getNodeTypeName(isl_schedule_node_type Type) {
+  switch (Type) {
+  case isl_schedule_node_sequence:
+    return "Sequence Node";
+  case isl_schedule_node_set:
+    return "Set Node";
+  case isl_schedule_node_band:
+    return "Band Node";
+  case isl_schedule_node_context:
+    return "Context Node";
+  case isl_schedule_node_domain:
+    return "Domain Node";
+  case isl_schedule_node_extension:
+    return "Extension Node";
+  case isl_schedule_node_filter:
+    return "Filter Node";
+  case isl_schedule_node_guard:
+    return "Guard Node";
+  case isl_schedule_node_mark:
+    return "Mark Node";
+  case isl_schedule_node_leaf:
+    return "Leaf Node";
+  default:
+    return "ND";
+  }
+}
+
+void ScheduleNodeMatcher::printMatcher(raw_ostream &OS,
+                                       const ScheduleNodeMatcher &matcher,
+                                       int indent) {
+  OS.indent(indent) << getNodeTypeName(matcher.current_) << "\n";
+
+  if (matcher.children_.empty())
+    return;
+
+  for (const ScheduleNodeMatcher &child : matcher.children_)
+    printMatcher(OS, child, indent + 2);
+
+  // An empty line closes the children of every matcher that has children.
+  OS << "\n";
+}
+
+bool ScheduleNodeMatcher::isMatching(const ScheduleNodeMatcher &matcher,
+                                     isl::schedule_node node) {
+  if (!node.get())
+    return false;
+
+  if (matcher.current_ != isl_schedule_node_get_type(node.get()))
+    return false;
+
+  // The callback runs before any child is inspected.
+  if (matcher.nodeCallback_ && !matcher.nodeCallback_(node))
+    return false;
+
+  // No child matchers: the node may have any number of children.
+  if (matcher.children_.empty())
+    return true;
+
+  // isl signals an error with a negative count; reject it explicitly rather
+  // than letting it wrap around to a huge unsigned value.
+  int nChildren = isl_schedule_node_n_children(node.get());
+  if (nChildren < 0 ||
+      matcher.children_.size() != static_cast<size_t>(nChildren))
+    return false;
+
+  for (int i = 0; i < nChildren; ++i)
+    if (!isMatching(matcher.children_[i], node.child(i)))
+      return false;
+  return true;
+}
+
+} // namespace matchers
+
 /// Create an isl::union_set, which describes the isolate option based on
 /// IsolateDomain.
/// @@ -1467,6 +1554,11 @@ &Version); } +static bool handlerMatcher(isl::schedule_node Node) { + LLVM_DEBUG(dbgs() << "hello matchers\n"); + return true; +} + bool IslScheduleOptimizer::runOnScop(Scop &S) { // Skip SCoPs in case they're already optimised by PPCGCodeGeneration if (S.isToBeSkipped()) @@ -1617,6 +1709,35 @@ isl_printer_free(P); }); +{ + using namespace matchers; + // try different constructors for matchers. (to be removed) + auto MatcherObj = domain(sequence(filter(band()),filter(band()))); + MatcherObj.printMatcher(llvm::dbgs(), MatcherObj, 2); + MatcherObj = band(sequence(filter())); + MatcherObj.printMatcher(llvm::dbgs(), MatcherObj, 2); + MatcherObj = band(set(filter())); + MatcherObj.printMatcher(llvm::dbgs(), MatcherObj, 2); + MatcherObj = band(sequence()); + MatcherObj.printMatcher(llvm::dbgs(), MatcherObj, 2); + MatcherObj = domain(context(filter())); + MatcherObj.printMatcher(llvm::dbgs(), MatcherObj, 2); + // callback (to be removed) + bool isAMatch = false; + isl::schedule_node Root = Schedule.get_root(); + MatcherObj = domain(sequence(handlerMatcher)); + isAMatch = MatcherObj.isMatching(MatcherObj,Root); + LLVM_DEBUG(dbgs() << "Match :=" << isAMatch << "\n"); + MatcherObj = domain(handlerMatcher); + isAMatch = MatcherObj.isMatching(MatcherObj,Root); + LLVM_DEBUG(dbgs() << "Match :=" << isAMatch << "\n"); + // callback plus child (to be removed) + MatcherObj = domain(handlerMatcher,band()); + isAMatch = MatcherObj.isMatching(MatcherObj,Root); + LLVM_DEBUG(dbgs() << "Match :=" << isAMatch << "\n"); +} + + Function &F = S.getFunction(); auto *TTI = &getAnalysis().getTTI(F); const OptimizerAdditionalInfoTy OAI = {TTI, const_cast(&D)};