diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Format.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -91,8 +92,7 @@
 ///
 class StaticVerifierFunctionEmitter {
 public:
-  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
-                                raw_ostream &os);
+  explicit StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
 
-  /// Emit the static verifier functions for `llvm::Record`s. The
+  /// Emit the static verifier functions for the given constraints. The
   /// `signatureFormat` describes the required arguments and it must have a
@@ -112,30 +112,41 @@
   ///
-  /// `typeArgName` is used to identify the argument that needs to check its
-  /// type. The constraint template will replace `$_self` with it.
-  void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
-                        StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
-                        bool emitDecl);
+  /// Placeholders such as `$_self` in each constraint's condition template
+  /// are substituted from the format context configured via `setSelf` and
+  /// `setBuilder`.
+  ///
+  /// This variant wraps the emitted functions in `cppNamespace` when
+  /// `emitDecl` is false.
+  void emitConstraintMethodsInNamespace(StringRef signatureFormat,
+                                        StringRef errorHandlerFormat,
+                                        StringRef cppNamespace,
+                                        ArrayRef<const void *> constraints,
+                                        raw_ostream &os, bool emitDecl);
+
+  /// Emit the static verifier functions for the given type or attribute
+  /// constraints, without a namespace wrapper.
+  void emitConstraintMethods(StringRef signatureFormat,
+                             StringRef errorHandlerFormat,
+                             ArrayRef<const void *> constraints,
+                             raw_ostream &rawOs, bool emitDecl);
 
-  /// Get the name of the local function used for the given type constraint.
-  /// These functions are used for operand and result constraints and have the
-  /// form:
-  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
-  ///                 unsigned valueGroupStartIndex);
-  StringRef getTypeConstraintFn(const Constraint &constraint) const;
+  /// Get the name of the local function emitted for the given type or
+  /// attribute constraint. The function's signature is the `signatureFormat`
+  /// it was emitted with. The constraint must already have been emitted by
+  /// this emitter.
+  StringRef getConstraintFn(const Constraint &constraint) const;
+
+  /// Set the `$_self` substitution used in constraint condition templates.
+  StaticVerifierFunctionEmitter &setSelf(StringRef str);
+
+  /// Set the `$_builder` substitution used in constraint condition templates.
+  StaticVerifierFunctionEmitter &setBuilder(StringRef str);
 
 private:
   /// Returns a unique name to use when generating local methods.
   static std::string getUniqueName(const llvm::RecordKeeper &records);
 
-  /// Emit local methods for the type constraints used within the provided op
-  /// definitions.
-  void emitTypeConstraintMethods(StringRef signatureFormat,
-                                 StringRef errorHandlerFormat,
-                                 StringRef typeArgName,
-                                 ArrayRef<llvm::Record *> opDefs,
-                                 bool emitDecl);
-
-  raw_indented_ostream os;
+  /// The format context used for building the verifier function.
+  FmtContext fctx;
 
   /// A unique label for the file currently being generated. This is used to
   /// ensure that the local functions have a unique name.
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -113,6 +113,9 @@
   void print(raw_ostream &os) const;
 
 private:
+  friend llvm::DenseMapInfo<DagLeaf>;
+  const void *getAsOpaquePointer() const { return def; }
+
   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
   // also a subclass of the given `superclass`.
   bool isSubClassOf(StringRef superclass) const;
@@ -523,6 +526,24 @@
     return lhs.node == rhs.node;
   }
 };
+
+template <>
+struct DenseMapInfo<mlir::tblgen::DagLeaf> {
+  static mlir::tblgen::DagLeaf getEmptyKey() {
+    return mlir::tblgen::DagLeaf(
+        llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
+  }
+  static mlir::tblgen::DagLeaf getTombstoneKey() {
+    return mlir::tblgen::DagLeaf(
+        llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
+    return llvm::hash_value(leaf.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
+    return lhs.def == rhs.def;
+  }
+};
 } // end namespace llvm
 
 #endif // MLIR_TABLEGEN_PATTERN_H_
diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
--- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -37,12 +37,16 @@
 // Test static matcher for duplicate DagNode
 // ---
 
-// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
+// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
+// CHECK: if(failed([[$TYPE_CONSTRAINT]]
+// CHECK: if(failed([[$ATTR_CONSTRAINT]]
 
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
 def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)), (AOp $int)>;
 
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
 def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)), (COp $attr, $int)>;
 
 
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/CodeGenHelpers.h"
-#include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -24,21 +23,34 @@
 using namespace mlir::tblgen;
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    const llvm::RecordKeeper &records, raw_ostream &os)
-    : os(os), uniqueOutputLabel(getUniqueName(records)) {}
+    const llvm::RecordKeeper &records)
+    : uniqueOutputLabel(getUniqueName(records)) {}
 
-void StaticVerifierFunctionEmitter::emitFunctionsFor(
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setSelf(StringRef str) {
+  fctx.withSelf(str);
+  return *this;
+}
+
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
+  fctx.withBuilder(str);
+  return *this;
+}
+
+void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
     StringRef signatureFormat, StringRef errorHandlerFormat,
-    StringRef typeArgName, ArrayRef<Record *> opDefs, bool emitDecl) {
+    StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
+    bool emitDecl) {
   llvm::Optional<NamespaceEmitter> namespaceEmitter;
   if (!emitDecl)
-    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+    namespaceEmitter.emplace(os, cppNamespace);
 
-  emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
-                            opDefs, emitDecl);
+  emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
+                        emitDecl);
 }
 
-StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+StringRef StaticVerifierFunctionEmitter::getConstraintFn(
     const Constraint &constraint) const {
   auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
   assert(it != localTypeConstraints.end() && "expected valid constraint fn");
@@ -65,28 +77,16 @@
   return uniqueName;
 }
 
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+void StaticVerifierFunctionEmitter::emitConstraintMethods(
     StringRef signatureFormat, StringRef errorHandlerFormat,
-    StringRef typeArgName, ArrayRef<Record *> opDefs, bool emitDecl) {
-  // Collect a set of all of the used type constraints within the operation
-  // definitions.
-  llvm::SetVector<const void *> typeConstraints;
-  for (Record *def : opDefs) {
-    Operator op(*def);
-    for (NamedTypeConstraint &operand : op.getOperands())
-      if (operand.hasPredicate())
-        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
-    for (NamedTypeConstraint &result : op.getResults())
-      if (result.hasPredicate())
-        typeConstraints.insert(result.constraint.getAsOpaquePointer());
-  }
+    ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
+  raw_indented_ostream os(rawOs);
 
   // Record the mapping from predicate to constraint. If two constraints has the
   // same predicate and constraint summary, they can share the same verification
   // function.
   llvm::DenseMap<Pred, const void *> predToConstraint;
-  FmtContext fctx;
-  for (auto it : llvm::enumerate(typeConstraints)) {
+  for (auto it : llvm::enumerate(constraints)) {
     std::string name;
     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
     Pred pred = constraint.getPredicate();
@@ -101,7 +101,7 @@
       // summary, otherwise we may report the wrong message while verification
       // fails.
       if (constraint.getSummary() == built.getSummary()) {
-        name = getTypeConstraintFn(built).str();
+        name = getConstraintFn(built).str();
         break;
       }
       ++iter;
@@ -126,12 +126,11 @@
       continue;
 
     os << formatv(signatureFormat.data(), name) << " {\n";
-    os.indent() << "if (!("
-                << tgfmt(constraint.getConditionTemplate(),
-                         &fctx.withSelf(typeArgName))
+    os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
                 << ")) {\n";
     os.indent() << "return "
-                << formatv(errorHandlerFormat.data(), constraint.getSummary())
+                << formatv(errorHandlerFormat.data(),
+                           escapeString(constraint.getSummary()))
                 << ";\n";
     os.unindent() << "}\nreturn ::mlir::success();\n";
     os.unindent() << "}\n\n";
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2233,7 +2233,7 @@
       continue;
     // Emit a loop to check all the dynamic values in the pack.
     StringRef constraintFn =
-        staticVerifierEmitter.getTypeConstraintFn(value.constraint);
+        staticVerifierEmitter.getConstraintFn(value.constraint);
     body << " for (::mlir::Value v : valueGroup" << staticValue.index()
          << ") {\n"
          << " if (::mlir::failed(" << constraintFn
@@ -2639,11 +2639,27 @@
     return;
 
   // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
   os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitFunctionsFor(
-      typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
-      defs, emitDecl);
+  staticVerifierEmitter.setSelf("type");
+
+  // Collect a set of all of the used type constraints within the operation
+  // definitions.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : defs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  staticVerifierEmitter.emitConstraintMethodsInNamespace(
+      typeVerifierSignature, typeVerifierErrorHandler,
+      Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
+      emitDecl);
 
   for (auto *def : defs) {
     Operator op(*def);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -20,6 +20,7 @@
 #include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -41,6 +42,23 @@
 
 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
 
+// The signature of static type verification function
+static const char *const typeVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+    "::mlir::Operation *op, ::mlir::Type typeOrAttr, "
+    "::llvm::StringRef failureStr)";
+
+// The signature of static attribute verification function
+static const char *const attrVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+    "::mlir::Operation *op, ::mlir::Attribute typeOrAttr, "
+    "::llvm::StringRef failureStr)";
+
+// The template of error handler in static type/attribute verification function
+static const char *const verifierErrorHandler =
+    "rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n diag "
+    "<< failureStr << \": {0}\";\n});";
+
 namespace llvm {
 template <>
 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
@@ -87,6 +105,10 @@
   // Emit C++ function call to static DAG matcher.
   void emitStaticMatchCall(DagNode tree, StringRef name);
 
+  // Emit C++ function call to static type/attribute constraint function.
+  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
+                              StringRef arg, StringRef failureStr);
+
   // Emits C++ statements for matching using a native code call.
   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
 
@@ -244,7 +266,8 @@
 // inlining them.
 class StaticMatcherHelper {
 public:
-  StaticMatcherHelper(RecordOperatorMap &mapper);
+  StaticMatcherHelper(const RecordKeeper &recordKeeper,
+                      RecordOperatorMap &mapper);
 
   // Determine if we should inline the match logic or delegate to a static
   // function.
@@ -258,6 +281,9 @@
     return matcherNames[node];
   }
 
+  // Get the name of static type/attribute verification function.
+  StringRef getVerifierName(Constraint constraint);
+
   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
   // the duplicated DAGs.
   void addPattern(Record *record);
@@ -265,6 +291,9 @@
   // Emit all static functions of DAG Matcher.
   void populateStaticMatchers(raw_ostream &os);
 
+  // Emit all static functions for Constraints.
+  void populateStaticConstraintFunctions(raw_ostream &os);
+
 private:
   static constexpr unsigned kStaticMatcherThreshold = 1;
 
@@ -301,6 +330,12 @@
   // Number of static matcher generated. This is used to generate a unique name
   // for each DagNode.
   int staticMatcherCounter = 0;
+
+  // Type/attr constraint leaves, in insertion order for deterministic output.
+  llvm::SetVector<DagLeaf> constraints;
+
+  // Static type/attribute verification function emitter.
+  StaticVerifierFunctionEmitter staticVerifierEmitter;
 };
 } // end anonymous namespace
 
@@ -395,6 +430,15 @@
   os << "}\n";
 }
 
+void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
+                                            StringRef opName, StringRef arg,
+                                            StringRef failureStr) {
+  os << formatv("if(failed({0}(rewriter, {1}, {2}, {3}))) {{\n", funcName,
+                opName, arg, failureStr);
+  os.scope().os << "return ::mlir::failure();\n";
+  os << "}\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                          int depth) {
@@ -487,14 +531,15 @@
       self = argName;
     else
       self = formatv("{0}.getType()", argName);
-    emitMatchCheck(
-        opName,
-        tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+    StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+    emitStaticVerifierCall(
+        verifier, opName, self,
         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
                 "constraint: "
                 "'{2}'\"",
                 i, tree.getNativeCodeTemplate(),
-                escapeString(constraint.getSummary())));
+                escapeString(constraint.getSummary()))
+            .str());
   }
 
   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
@@ -626,13 +671,14 @@
       }
       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
                           opName, operandIndex);
-      emitMatchCheck(
-          opName,
-          tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
-          formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
-                  "'{2}'\"",
-                  operand - op.operand_begin(), op.getOperationName(),
-                  escapeString(constraint.getSummary())));
+      StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+      emitStaticVerifierCall(
+          verifier, opName, self.str(),
+          formatv(
+              "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
+              operandIndex, op.getOperationName(),
+              escapeString(constraint.getSummary()))
+              .str());
     }
   }
 
@@ -690,15 +736,17 @@
                          op.getOperationName(), argIndex + 1));
     }
 
-    // If a constraint is specified, we need to generate C++ statements to
-    // check the constraint.
-    emitMatchCheck(
-        opName,
-        tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
+    // If a constraint is specified, we need to generate function call to its
+    // static verifier.
+    StringRef verifier =
+        staticMatcherHelper.getVerifierName(matcher.getAsConstraint());
+    emitStaticVerifierCall(
+        verifier, opName, "tblgen_attr",
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
-                escapeString(matcher.getAsConstraint().getSummary())));
+                escapeString(matcher.getAsConstraint().getSummary()))
+            .str());
   }
 
   // Capture the value
@@ -1571,8 +1619,9 @@
   }
 }
 
-StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
-    : opMap(mapper) {}
+StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper,
+                                         RecordOperatorMap &mapper)
+    : opMap(mapper), staticVerifierEmitter(recordKeeper) {}
 
 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
   // PatternEmitter will use the static matcher if there's one generated. To
@@ -1592,6 +1641,31 @@
   }
 }
 
+void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
+  llvm::SetVector<const void *> typeConstraints;
+  llvm::SetVector<const void *> attrConstraints;
+  for (DagLeaf leaf : constraints) {
+    if (leaf.isOperandMatcher()) {
+      typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+    } else {
+      assert(leaf.isAttrMatcher() && "expected an attribute constraint");
+      attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+    }
+  }
+
+  staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr");
+
+  staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature,
+                                              verifierErrorHandler,
+                                              typeConstraints.getArrayRef(), os,
+                                              /*emitDecl=*/false);
+
+  staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature,
+                                              verifierErrorHandler,
+                                              attrConstraints.getArrayRef(), os,
+                                              /*emitDecl=*/false);
+}
+
 void StaticMatcherHelper::addPattern(Record *record) {
   Pattern pat(record, &opMap);
 
@@ -1608,6 +1682,12 @@
     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
       if (DagNode sibling = node.getArgAsNestedDag(i))
         dfs(sibling);
+      else {
+        // Only operand/attribute constraint leaves get static verifiers.
+        DagLeaf leaf = node.getArgAsLeaf(i);
+        if (leaf.isOperandMatcher() || leaf.isAttrMatcher())
+          constraints.insert(leaf);
+      }
     topologicalOrder.push_back(std::make_pair(node, record));
   };
 
@@ -1615,6 +1695,10 @@
   dfs(pat.getSourcePattern());
 }
 
+StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) {
+  return staticVerifierEmitter.getConstraintFn(constraint);
+}
+
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
 
@@ -1625,9 +1709,10 @@
 
   // Exam all the patterns and generate static matcher for the duplicated
   // DagNode.
-  StaticMatcherHelper staticMatcher(recordOpMap);
+  StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap);
   for (Record *p : patterns)
     staticMatcher.addPattern(p);
+  staticMatcher.populateStaticConstraintFunctions(os);
   staticMatcher.populateStaticMatchers(os);
 
   std::vector<std::string> rewriterNames;