diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -18,6 +18,7 @@
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -198,6 +199,7 @@
 
 private:
   friend class SymbolInfoMap;
+  friend llvm::DenseMapInfo<DagNode>;
   const void *getAsOpaquePointer() const { return node; }
 
   const llvm::DagInit *node; // nullptr means null DagNode
@@ -242,10 +244,17 @@
   // Class for information regarding a symbol.
   class SymbolInfo {
   public:
+    // Returns the C++ type of the variable that holds this symbol's value.
+    std::string getVarTypeStr(StringRef name) const;
+
     // Returns a string for defining a variable named as `name` to store the
     // value bound by this symbol.
     std::string getVarDecl(StringRef name) const;
 
+    // Returns a string declaring a by-reference parameter for the variable
+    // named as `name` (used in the signature of static matchers).
+    std::string getArgDecl(StringRef name) const;
+
     // Returns a variable name for the symbol named as `name`.
     std::string getVarName(StringRef name) const;
 
@@ -383,6 +392,7 @@
   // with index `argIndex` for operator `op`.
   const_iterator findBoundSymbol(StringRef key, DagNode node,
                                  const Operator &op, int argIndex) const;
+  const_iterator findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const;
 
   // Returns the bounds of a range that includes all the elements which
   // bind to the `key`.
@@ -474,15 +484,15 @@
   // pair).
   std::vector<IdentifierLine> getLocation() const;
 
-private:
-  // Helper function to verify variabld binding.
-  void verifyBind(bool result, StringRef symbolName);
-
   // Recursively collects all bound symbols inside the DAG tree rooted
   // at `tree` and updates the given `infoMap`.
   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                            bool isSrcPattern);
 
+private:
+  // Helper function to verify variable binding.
+  void verifyBind(bool result, StringRef symbolName);
+
   // The TableGen definition of this pattern.
   const llvm::Record &def;
 
@@ -495,4 +505,24 @@
 } // end namespace tblgen
 } // end namespace mlir
 
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::tblgen::DagNode> {
+  static mlir::tblgen::DagNode getEmptyKey() {
+    return mlir::tblgen::DagNode(
+        llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey());
+  }
+  static mlir::tblgen::DagNode getTombstoneKey() {
+    return mlir::tblgen::DagNode(
+        llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::DagNode node) {
+    return llvm::hash_value(node.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) {
+    return lhs.node == rhs.node;
+  }
+};
+} // end namespace llvm
+
 #endif // MLIR_TABLEGEN_PATTERN_H_
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -230,45 +230,50 @@
   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
 }
 
-std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
-  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
+std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
   switch (kind) {
   case Kind::Attr: {
-    if (op) {
-      auto type = op->getArg(getArgIndex())
-                      .get<NamedAttribute *>()
-                      ->attr.getStorageType();
-      return std::string(formatv("{0} {1};\n", type, name));
-    }
+    if (op)
+      return op->getArg(getArgIndex())
+          .get<NamedAttribute *>()
+          ->attr.getStorageType()
+          .str();
     // TODO(suderman): Use a more exact type when available.
-    return std::string(formatv("Attribute {0};\n", name));
+    return "Attribute";
   }
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
     // operands).
-    return std::string(
-        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
-                getVarName(name)));
+    return "::mlir::Operation::operand_range";
   }
   case Kind::Value: {
-    return std::string(formatv("::mlir::Value {0};\n", name));
+    return "::mlir::Value";
   }
   case Kind::MultipleValues: {
-    // This is for the variable used in the source pattern. Each named value in
-    // source pattern will only be bound to a Value. The others in the result
-    // pattern may be associated with multiple Values as we will use `auto` to
-    // do the type inference.
-    return std::string(formatv(
-        "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
+    return "::mlir::ValueRange";
   }
   case Kind::Result: {
     // Use the op itself for captured results.
-    return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
+    return op->getQualCppClassName();
   }
   }
   llvm_unreachable("unknown kind");
 }
 
+std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
+  StringRef varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
+  return std::string(
+      formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
+}
+
+std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
+  return std::string(
+      formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
+}
+
 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
     StringRef name, int index, const char *fmt, const char *separator) const {
   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
@@ -486,11 +491,14 @@
 SymbolInfoMap::const_iterator
 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
                                int argIndex) const {
+  return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
+}
+
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
   std::string name = getValuePackName(key).str();
   auto range = symbolInfoMap.equal_range(name);
 
-  const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
-
   for (auto it = range.first; it != range.second; ++it)
     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
       return it;
diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -0,0 +1,48 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<OpTrait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    AnyInteger:$any_integer
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr,
+    AnyInteger
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def COp : NS_Op<"c_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr,
+    AnyInteger
+  );
+
+  let results = (outs AnyInteger);
+}
+
+// Test static matcher for duplicate DagNode
+// ---
+
+// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
+
+// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
+          (AOp $int)>;
+
+// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
+          (COp $attr, $int)>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -18,6 +18,7 @@
 #include "mlir/TableGen/Pattern.h"
 #include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Type.h"
+#include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -54,13 +55,20 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+class StaticMatcherHelper;
+
 class PatternEmitter {
 public:
-  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
+  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
+                 StaticMatcherHelper &helper);
 
   // Emits the mlir::RewritePattern struct named `rewriteName`.
   void emit(StringRef rewriteName);
 
+  // Emits a static function named `funcName` that matches the DAG `tree`.
+  void emitStaticMatcher(DagNode tree, StringRef funcName);
+
 private:
   // Emits the code for matching ops.
   void emitMatchLogic(DagNode tree, StringRef opName);
@@ -75,6 +83,9 @@
   // Emits C++ statements for matching the DAG structure.
   void emitMatch(DagNode tree, StringRef name, int depth);
 
+  // Emits a C++ call to the static DAG matcher generated for `tree`.
+  void emitStaticMatchCall(DagNode tree, StringRef opName);
+
   // Emits C++ statements for matching using a native code call.
   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
 
@@ -216,6 +227,8 @@
   // Map for all bound symbols' info.
   SymbolInfoMap symbolInfoMap;
 
+  StaticMatcherHelper &staticMatcherHelper;
+
   // The next unused ID for newly created values.
   unsigned nextValueId;
 
@@ -223,16 +236,79 @@
 
   // Format contexts containing placeholder substitutions.
   FmtContext fmtCtx;
+};
+
+// Counts how often each DagNode is referenced across all patterns and emits a
+// shared static matcher function for those referenced more than once.
+class StaticMatcherHelper {
+public:
+  explicit StaticMatcherHelper(RecordOperatorMap &mapper);
+
+  // Returns true if the match logic for `node` should be delegated to a static
+  // matcher function instead of being inlined.
+  bool useStaticMatcher(DagNode node) const {
+    return refStats.lookup(node) > kStaticMatcherThreshold;
+  }
+
+  // Returns the name of the static matcher function for `node`, or an empty
+  // string if it has not been emitted yet (e.g. while emitting its own body).
+  std::string getMatcherName(DagNode node) const {
+    assert(useStaticMatcher(node));
+    return matcherNames.lookup(node);
+  }
+
+  // Records the source DAG of the DRR `record`, counting how many times each
+  // of its DagNodes is referenced.
+  void addPattern(Record *record);
+
+  // Emits the static matcher functions for all DagNodes that need one.
+  void populateStaticMatchers(raw_ostream &os);
 
-  // Number of op processed.
-  int opCounter = 0;
+private:
+  static constexpr unsigned kStaticMatcherThreshold = 1;
+
+  // Consider the two patterns below:
+  //  DagNode_Root_A    DagNode_Root_B
+  //        |                 |
+  //     DagNode_C         DagNode_C
+  //        |                 |
+  //     DagNode_D         DagNode_D
+  //
+  // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
+  // DagNode_C and DagNode_D. Suppose both DagNode_C and DagNode_D get static
+  // matchers (e.g. DagNode_D is also matched directly by another pattern).
+  // When emitting the match logic for DagNode_C, we check whether DagNode_D
+  // already has its static matcher generated: if so we emit a call to it,
+  // otherwise its match logic would be inlined, which is not what we want
+  // here. Generating the static matchers in topological order guarantees that
+  // every dependent static matcher exists first, avoiding that accidental
+  // inlining.
+  //
+  // The topological order of all the DagNodes among all patterns.
+  SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
+
+  RecordOperatorMap &opMap;
+
+  // Name of the static matcher function emitted for each DagNode.
+  DenseMap<DagNode, std::string> matcherNames;
+
+  // Number of times each DagNode is reached while walking the source DAGs of
+  // all patterns; the children of an already-seen DagNode are not revisited.
+  // DagNodes whose count exceeds kStaticMatcherThreshold get a static matcher.
+  DenseMap<DagNode, unsigned> refStats;
+
+  // Number of static matchers generated so far; used to give each of them a
+  // unique function name.
+  int staticMatcherCounter = 0;
 };
+
 } // end anonymous namespace
 
 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
-                               raw_ostream &os)
+                               raw_ostream &os, StaticMatcherHelper &helper)
     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
-      symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
+      symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), nextValueId(0),
+      os(os) {
   fmtCtx.withBuilder("rewriter");
 }
 
@@ -246,6 +322,33 @@
   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
 }
 
+void PatternEmitter::emitStaticMatcher(DagNode tree, StringRef funcName) {
+  os << formatv(
+      "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+      "::mlir::Operation *op0, "
+      "::llvm::SmallVectorImpl<::mlir::Operation *> &tblgen_ops",
+      funcName);
+
+  // Captured values are passed by reference, so collect all the symbols bound
+  // inside `tree` first to build the parameter list.
+  pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
+  symbolInfoMap.assignUniqueAlternativeNames();
+  for (const auto &info : symbolInfoMap)
+    os << formatv(", {0}", info.second.getArgDecl(info.first));
+
+  os << ") {\n";
+  os.indent();
+  os << "(void)tblgen_ops;\n";
+
+  // The root of a static matcher is treated as at least one level below the
+  // pattern's match entry, hence depth 1.
+  emitMatch(tree, "op0", /*depth=*/1);
+
+  os << "return ::mlir::success();\n";
+  os.unindent();
+  os << "}\n\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
   if (tree.isNativeCodeCall()) {
@@ -261,6 +364,36 @@
   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
 }
 
+void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
+  std::string funcName = staticMatcherHelper.getMatcherName(tree);
+  os << formatv("if(failed({0}(rewriter, {1}, tblgen_ops", funcName, opName);
+
+  // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
+  // one pass.
+
+  // In general, bound symbols have unique names within a pattern. Operands are
+  // the exception: binding the same symbol to multiple operands implies an
+  // additional constraint, and those operands are given different names. As a
+  // result, we collect all the symbolInfos from the DagNode and then look up
+  // the (possibly renamed) local variable for each of them in the pattern-wide
+  // symbolInfoMap.
+
+  // Collect all the bound symbols in the DagNode.
+  SymbolInfoMap localSymbolMap(loc);
+  pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
+
+  for (const auto &info : localSymbolMap) {
+    StringRef name = info.first;
+    const auto &symbolInfo = info.second;
+    auto found = symbolInfoMap.findBoundSymbol(name, symbolInfo);
+    os << formatv(", {0}", found->second.getVarName(name));
+  }
+
+  os << "))) {\n";
+  os.scope().os << "return ::mlir::failure();\n";
+  os << "}\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                          int depth) {
@@ -268,6 +401,21 @@
   LLVM_DEBUG(tree.print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
+  // Static matchers are generated in topological order, so every nested
+  // DagNode that needs one already has it by now. `getMatcherName(tree)` is
+  // only empty while the static matcher for `tree` itself is being generated;
+  // in that case we must emit the matching logic (the function body) rather
+  // than a call to the function being defined.
+  if (staticMatcherHelper.useStaticMatcher(tree) &&
+      !staticMatcherHelper.getMatcherName(tree).empty()) {
+    emitStaticMatchCall(tree, opName);
+
+    // A NativeCodeCall is never at depth 0, so unlike emitOpMatch() there is
+    // no root operation (castedOp0) to capture here.
+
+    return;
+  }
+
   // TODO(suderman): iterate through arguments, determine their types, output
   // names.
   SmallVector<std::string, 8> capture;
@@ -356,7 +504,28 @@
                           << op.getOperationName() << "' at depth " << depth
                           << '\n');
 
-  std::string castedName = formatv("castedOp{0}", depth);
+  auto getCastedName = [depth]() -> std::string {
+    return formatv("castedOp{0}", depth);
+  };
+
+  // Static matchers are generated in topological order, so every nested
+  // DagNode that needs one already has it by now. `getMatcherName(tree)` is
+  // only empty while the static matcher for `tree` itself is being generated;
+  // in that case we must emit the matching logic (the function body) rather
+  // than a call to the function being defined.
+  if (staticMatcherHelper.useStaticMatcher(tree) &&
+      !staticMatcherHelper.getMatcherName(tree).empty()) {
+    emitStaticMatchCall(tree, opName);
+    // The rewrite code expects castedOp0 to refer to the root operation, so
+    // declare it here when the root DagNode is matched by a static matcher.
+    if (depth == 0)
+      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
+                    "(void){2};\n",
+                    opName, op.getQualCppClassName(), getCastedName());
+    return;
+  }
+
+  std::string castedName = getCastedName();
   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                 "(void){0};\n",
                 castedName, opName, op.getQualCppClassName());
@@ -405,7 +574,7 @@
           formatv("\"Operand {0} of {1} has null definingOp\"",
                   nextOperand++, castedName));
       emitMatch(argTree, argName, depth + 1);
-      os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
+      os << formatv("tblgen_ops.push_back({0});\n", argName);
       os.unindent() << "}\n";
       continue;
     }
@@ -704,13 +873,12 @@
   }
   // TODO: capture ops with consistent numbering so that it can be
   // reused for fused loc.
-  os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
-                pattern.getSourcePattern().getNumOps());
+  os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
   LLVM_DEBUG(llvm::dbgs()
              << "done creating local variables for capturing matches\n");
 
   os << "// Match\n";
-  os << "tblgen_ops[0] = op0;\n";
+  os << "tblgen_ops.push_back(op0);\n";
   emitMatchLogic(sourceTree, "op0");
 
   os << "\n// Rewrite\n";
@@ -1399,17 +1567,67 @@
   }
 }
 
+StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
+    : opMap(mapper) {}
+
+void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
+  // PatternEmitter emits a call rather than inline match logic once a static
+  // matcher exists for a DagNode. Visiting DagNodes in topological order
+  // ensures every matcher a DagNode depends on is emitted before its own.
+  for (const auto &dagInfo : topologicalOrder) {
+    DagNode node = dagInfo.first;
+    if (!useStaticMatcher(node))
+      continue;
+
+    std::string funcName =
+        formatv("static_dag_matcher_{0}", staticMatcherCounter++);
+    assert(matcherNames.find(node) == matcherNames.end());
+    PatternEmitter(dagInfo.second, &opMap, os, *this)
+        .emitStaticMatcher(node, funcName);
+    matcherNames[node] = std::move(funcName);
+  }
+}
+
+void StaticMatcherHelper::addPattern(Record *record) {
+  Pattern pat(record, &opMap);
+
+  // The body of one static matcher may call other static matchers. To make
+  // sure those are emitted first, record the DagNodes of every pattern in
+  // topological (post-)order and emit the static matchers in that order
+  // (see populateStaticMatchers).
+  llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
+    // Only the first visit descends into the children, so every DagNode is
+    // appended to `topologicalOrder` once, after all of its nested DagNodes.
+    if (++refStats[node] != 1)
+      return;
+
+    for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
+      if (DagNode child = node.getArgAsNestedDag(i))
+        dfs(child);
+
+    topologicalOrder.emplace_back(node, record);
+  };
+
+  dfs(pat.getSourcePattern());
+}
+
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
 
   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
-  auto numPatterns = patterns.size();
 
   // We put the map here because it can be shared among multiple patterns.
   RecordOperatorMap recordOpMap;
 
+  // Examine all the patterns and generate static matchers for the DagNodes
+  // that are referenced more than once.
+  StaticMatcherHelper staticMatcher(recordOpMap);
+  for (Record *p : patterns)
+    staticMatcher.addPattern(p);
+  staticMatcher.populateStaticMatchers(os);
+
   std::vector<std::string> rewriterNames;
-  rewriterNames.reserve(numPatterns);
+  rewriterNames.reserve(patterns.size());
 
   std::string baseRewriterName = "GeneratedConvert";
   int rewriterIndex = 0;
@@ -1425,7 +1643,7 @@
     }
     LLVM_DEBUG(llvm::dbgs()
                << "=== start generating pattern '" << name << "' ===\n");
-    PatternEmitter(p, &recordOpMap, os).emit(name);
+    PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
     LLVM_DEBUG(llvm::dbgs()
                << "=== done generating pattern '" << name << "' ===\n");
     rewriterNames.push_back(std::move(name));