diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -32,7 +32,7 @@ // in TableGen. class AttrConstraint : public Constraint { public: - explicit AttrConstraint(const llvm::Record *record); + using Constraint::Constraint; static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; } diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -13,10 +13,10 @@ #ifndef MLIR_TABLEGEN_CODEGENHELPERS_H #define MLIR_TABLEGEN_CODEGENHELPERS_H -#include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/Format.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -26,8 +26,8 @@ namespace mlir { namespace tblgen { - class Constraint; +class DagLeaf; // Simple RAII helper for defining ifdef-undef-endif scopes. class IfDefScope { @@ -92,68 +92,128 @@ /// class StaticVerifierFunctionEmitter { public: - StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records); - - /// Emit the static verifier functions for `llvm::Record`s. The - /// `signatureFormat` describes the required arguments and it must have a - /// placeholder for function name. - /// Example, - /// const char *typeVerifierSignature = - /// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type" - /// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)"; + StaticVerifierFunctionEmitter(raw_ostream &os, + const llvm::RecordKeeper &records); + + /// Collect and unique all compatible type, attribute, successor, and region + /// constraints from the operations in the file and emit them at the top of + /// the generated file. If `emitDecl` is set, they are only collected.
/// - /// `errorHandlerFormat` describes the error message to return. It may have a - /// placeholder for the summary of Constraint and bring more information for - /// the error message. - /// Example, - /// const char *typeVerifierErrorHandler = - /// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << " - /// "\" must be {0}, but got \" << type"; + /// Constraints that do not meet the restriction that they can only reference + /// `$_self` and `$_op` are not uniqued. + void emitOpConstraints(ArrayRef opDefs, bool emitDecl); + + /// Unique all compatible type and attribute constraints from a pattern file + /// and emit them at the top of the generated file. /// - /// `typeArgName` is used to identify the argument that needs to check its - /// type. The constraint template will replace `$_self` with it. - - /// This is the helper to generate the constraint functions from op - /// definitions. - void emitConstraintMethodsInNamespace(StringRef signatureFormat, - StringRef errorHandlerFormat, - StringRef cppNamespace, - ArrayRef constraints, - raw_ostream &rawOs, bool emitDecl); - - /// Emit the static functions for the giving type constraints. - void emitConstraintMethods(StringRef signatureFormat, - StringRef errorHandlerFormat, - ArrayRef constraints, - raw_ostream &rawOs, bool emitDecl); - - /// Get the name of the local function used for the given type constraint. + /// Constraints that do not meet the restriction that they can only reference + /// `$_self`, `$_op`, and `$_builder` are not uniqued. + void emitPatternConstraints(const DenseSet &constraints); + + /// Get the name of the static function used for the given type constraint.
/// These functions are used for operand and result constraints and have the /// form: + /// /// LogicalResult(Operation *op, Type type, StringRef valueKind, - /// unsigned valueGroupStartIndex); - StringRef getConstraintFn(const Constraint &constraint) const; + /// unsigned valueIndex); + /// + /// Pattern constraints have the form: + /// + /// LogicalResult(PatternRewriter &rewriter, Operation *op, Type type, + /// StringRef failureStr); + /// + StringRef getTypeConstraintFn(const Constraint &constraint) const; + + /// Get the name of the static function used for the given attribute + /// constraint. These functions are in the form: + /// + /// LogicalResult(Operation *op, Attribute attr, StringRef attrName); + /// + /// If a uniqued constraint was not found, this function returns None. The + /// uniqued constraints cannot be used in the context of an OpAdaptor. + /// + /// Pattern constraints have the form: + /// + /// LogicalResult(PatternRewriter &rewriter, Operation *op, Attribute attr, + /// StringRef failureStr); + /// + Optional getAttrConstraintFn(const Constraint &constraint) const; - /// The setter to set `self` in format context. - StaticVerifierFunctionEmitter &setSelf(StringRef str); + /// Get the name of the static function used for the given successor + /// constraint. These functions are in the form: + /// + /// LogicalResult(Operation *op, Block *successor, StringRef successorName, + /// unsigned successorIndex); + /// + StringRef getSuccessorConstraintFn(const Constraint &constraint) const; - /// The setter to set `builder` in format context. - StaticVerifierFunctionEmitter &setBuilder(StringRef str); + /// Get the name of the static function used for the given region constraint. + /// These functions are in the form: + /// + /// LogicalResult(Operation *op, Region &region, StringRef regionName, + /// unsigned regionIndex); + /// + /// The region name may be empty.
+ StringRef getRegionConstraintFn(const Constraint &constraint) const; private: - /// Returns a unique name to use when generating local methods. - static std::string getUniqueName(const llvm::RecordKeeper &records); - - /// The format context used for building the verifier function. - FmtContext fctx; + /// Emit static type constraint functions. + void emitTypeConstraints(); + /// Emit static attribute constraint functions. + void emitAttrConstraints(); + /// Emit static successor constraint functions. + void emitSuccessorConstraints(); + /// Emit static region constraint functions. + void emitRegionConstraints(); + + /// Emit the collected type and attribute constraints for patterns. + void emitPatternConstraints(); + + /// Collect and unique all the constraints used by operations. + void collectOpConstraints(ArrayRef opDefs); + /// Collect and unique all pattern constraints. + void collectPatternConstraints(const DenseSet &constraints); + + /// The output stream. + raw_ostream &os; /// A unique label for the file currently being generated. This is used to - /// ensure that the local functions have a unique name. + /// ensure that the static functions have a unique name. std::string uniqueOutputLabel; - /// A set of functions implementing type constraints, used for operand and - /// result verification. - llvm::DenseMap localTypeConstraints; + /// Unique constraints by their predicate and summary. Constraints that share + /// the same predicate may have different descriptions; ensure that the + /// correct error message is reported when verification fails. + struct ConstraintUniquer { + static Constraint getEmptyKey(); + static Constraint getTombstoneKey(); + static unsigned getHashValue(Constraint constraint); + static bool isEqual(Constraint lhs, Constraint rhs); + }; + /// Use a MapVector to ensure that functions are generated deterministically.
+ using ConstraintMap = + llvm::MapVector>; + + /// Emit each constraint using `codeTemplate`, with `$_self` as `selfName`. + void emitConstraints(const ConstraintMap &constraints, StringRef selfName, + const char *const codeTemplate); + + /// Assign a unique name to a unique constraint. + std::string getUniqueName(StringRef kind, unsigned index); + /// Unique a constraint in the map. + void collectConstraint(ConstraintMap &map, StringRef kind, + Constraint constraint); + + /// The set of type constraints used for operand and result verification in + /// the current file. + ConstraintMap typeConstraints; + /// The set of attribute constraints used in the current file. + ConstraintMap attrConstraints; + /// The set of successor constraints used in the current file. + ConstraintMap successorConstraints; + /// The set of region constraints used in the current file. + ConstraintMap regionConstraints; }; // Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar. diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -29,8 +29,15 @@ // TableGen. class Constraint { public: + // Constraint kind + enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized }; + + // Create a constraint of the given kind, without looking through OpVariable. + Constraint(const llvm::Record *record, Kind kind) : def(record), kind(kind) {} + // Create a constraint with a TableGen definition, and infer the kind. Constraint(const llvm::Record *record); + /// Constraints are pointer-comparable. bool operator==(const Constraint &that) { return def == that.def; } bool operator!=(const Constraint &that) { return def != that.def; } @@ -47,24 +54,9 @@ // description is not provided, returns the TableGen def name.
StringRef getSummary() const; - // Constraint kind - enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized }; - Kind getKind() const { return kind; } - /// Get an opaque pointer to the constraint. - const void *getAsOpaquePointer() const { return def; } - /// Construct a constraint from the opaque pointer representation. - static Constraint getFromOpaquePointer(const void *ptr) { - return Constraint(reinterpret_cast(ptr)); - } - - // Return the underlying def. - const llvm::Record *getDef() const { return def; } - protected: - Constraint(Kind kind, const llvm::Record *record); - // The TableGen definition of this constraint. const llvm::Record *def; diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h --- a/mlir/include/mlir/TableGen/Predicate.h +++ b/mlir/include/mlir/TableGen/Predicate.h @@ -53,15 +53,21 @@ // record of type CombinedPred. bool isCombined() const; + // Get the location of the predicate. + ArrayRef getLoc() const; + // Records are pointer-comparable. bool operator==(const Pred &other) const { return def == other.def; } - // Get the location of the predicate. - ArrayRef getLoc() const; + // Return true if the predicate is not null. + operator bool() const { return def; } -protected: - friend llvm::DenseMapInfo; + // Hash a predicate by its pointer value. + friend llvm::hash_code hash_value(Pred pred) { + return llvm::hash_value(pred.def); + } +protected: // The TableGen definition of this predicate.
const llvm::Record *def; }; @@ -119,18 +125,4 @@ } // end namespace tblgen } // end namespace mlir -namespace llvm { -template <> -struct DenseMapInfo { - static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); } - static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); } - static unsigned getHashValue(mlir::tblgen::Pred pred) { - return llvm::hash_value(pred.def); - } - static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) { - return lhs == rhs; - } -}; -} // end namespace llvm - #endif // MLIR_TABLEGEN_PREDICATE_H_ diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -29,8 +29,9 @@ // TableGen. class TypeConstraint : public Constraint { public: - explicit TypeConstraint(const llvm::Record *record); - explicit TypeConstraint(const llvm::DefInit *init); + using Constraint::Constraint; + + TypeConstraint(const llvm::DefInit *init); static bool classof(const Constraint *c) { return c->getKind() == CK_Type; } diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -31,12 +31,6 @@ return {}; } -AttrConstraint::AttrConstraint(const Record *record) - : Constraint(Constraint::CK_Attr, record) { - assert(isSubClassOf("AttrConstraint") && - "must be subclass of TableGen 'AttrConstraint' class"); -} - bool AttrConstraint::isSubClassOf(StringRef className) const { return def->isSubClassOf(className); } diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -17,10 +17,11 @@ using namespace mlir::tblgen; Constraint::Constraint(const llvm::Record *record) - : def(record), kind(CK_Uncategorized) { + : Constraint(record, CK_Uncategorized) { // Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable")) def = def->getValueAsDef("constraint"); + if (def->isSubClassOf("TypeConstraint")) { kind = CK_Type; } else if (def->isSubClassOf("AttrConstraint")) { @@ -34,13 +35,6 @@ } } -Constraint::Constraint(Kind kind, const llvm::Record *record) - : def(record), kind(kind) { - // Look through OpVariable's to their constraint. - if (def->isSubClassOf("OpVariable")) - def = def->getValueAsDef("constraint"); -} - Pred Constraint::getPredicate() const { auto *val = def->getValue("predicate"); diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -19,12 +19,6 @@ using namespace mlir; using namespace mlir::tblgen; -TypeConstraint::TypeConstraint(const llvm::Record *record) - : Constraint(Constraint::CK_Type, record) { - assert(def->isSubClassOf("TypeConstraint") && - "must be subclass of TableGen 'TypeConstraint' class"); -} - TypeConstraint::TypeConstraint(const llvm::DefInit *init) : TypeConstraint(init->getDef()) {} diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/constraint-unique.td @@ -0,0 +1,156 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +class NS_Op traits = []> : + Op; + +/// Test unique'ing of type, attribute, successor, and region constraints.
+ +def ATypePred : CPred<"typePred($_self, $_op)">; +def AType : Type; +def OtherType : Type; + +def AnAttrPred : CPred<"attrPred($_self, $_op)">; +def AnAttr : Attr; +def OtherAttr : Attr; + +def ASuccessorPred : CPred<"successorPred($_self, $_op)">; +def ASuccessor : Successor; +def OtherSuccessor : Successor; + +def ARegionPred : CPred<"regionPred($_self, $_op)">; +def ARegion : Region; +def OtherRegion : Region; + +// OpA and OpB have the same type, attribute, successor, and region constraints. + +def OpA : NS_Op<"op_a"> { + let arguments = (ins AType:$a, AnAttr:$b); + let results = (outs AType:$ret); + let successors = (successor ASuccessor:$c); + let regions = (region ARegion:$d); +} + +def OpB : NS_Op<"op_b"> { + let arguments = (ins AType:$a, AnAttr:$b); + let successors = (successor ASuccessor:$c); + let regions = (region ARegion:$d); +} + +// OpC has the same type, attribute, successor, and region predicates but has +// different descriptions for them. + +def OpC : NS_Op<"op_c"> { + let arguments = (ins OtherType:$a, OtherAttr:$b); + let results = (outs OtherType:$ret); + let successors = (successor OtherSuccessor:$c); + let regions = (region OtherRegion:$d); +} + +/// Test that a type constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!((typePred(type, *op)))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be a type, but got " << type; + +/// Test that duplicate type constraint was not generated. +// CHECK-NOT: << " must be a type, but got " << type; + +/// Test that a type constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!((typePred(type, *op)))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be another type, but got " << type; + +/// Test that an attribute constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( +// CHECK: if (attr && !((attrPred(attr, *op)))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: an attribute"; + +/// Test that duplicate attribute constraint was not generated. +// CHECK-NOT: << "' failed to satisfy constraint: an attribute"; + +/// Test that an attribute constraint with a different description was generated. +// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( +// CHECK: if (attr && !((attrPred(attr, *op)))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: another attribute"; + +/// Test that a successor constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]]( +// CHECK: if (!((successorPred(successor, *op)))) { +// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('" +// CHECK-NEXT: << successorName << "') failed to verify constraint: a successor"; + +/// Test that duplicate successor constraint was not generated. +// CHECK-NOT: << successorName << "') failed to verify constraint: a successor"; + +/// Test that a successor constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]]( +// CHECK: if (!((successorPred(successor, *op)))) { +// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('" +// CHECK-NEXT: << successorName << "') failed to verify constraint: another successor"; + +/// Test that a region constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]]( +// CHECK: if (!((regionPred(region, *op)))) { +// CHECK-NEXT: return op->emitOpError("region #") << regionIndex +// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ") +// CHECK-NEXT: << "failed to verify constraint: a region"; + +/// Test that duplicate region constraint was not generated. +// CHECK-NOT: << "failed to verify constraint: a region"; + +/// Test that a region constraint with a different description was generated. +// CHECK: static ::mlir::LogicalResult [[$O_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]]( +// CHECK: if (!((regionPred(region, *op)))) { +// CHECK-NEXT: return op->emitOpError("region #") << regionIndex +// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ") +// CHECK-NEXT: << "failed to verify constraint: another region"; + +/// Test that the uniqued constraints are being used.
+// CHECK-LABEL: OpA::verify +// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); +// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0); +// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0))) +// CHECK-NEXT: if (::mlir::failed([[$A_REGION_CONSTRAINT]](*this, region, "d", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c())) +// CHECK-NEXT: if (::mlir::failed([[$A_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++))) +// CHECK-NEXT: return ::mlir::failure(); + +/// Test that the op with the same predicates but different descriptions +/// uses the different constraints.
+// CHECK-LABEL: OpC::verify +// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); +// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0); +// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0))) +// CHECK-NEXT: if (::mlir::failed([[$O_REGION_CONSTRAINT]](*this, region, "d", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c())) +// CHECK-NEXT: if (::mlir::failed([[$O_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++))) +// CHECK-NEXT: return ::mlir::failure(); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -17,24 +17,28 @@ } // CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type; +// CHECK: if (!((type.isInteger(32) || type.isF32()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be 32-bit integer or floating-point type, but got " << type; // Check there is no verifier with same predicate generated.
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) { -// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type; +// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NOT: << " must be 32-bit integer or floating-point type, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type; +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be tensor of any type values, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; // CHECK-LABEL:
OpA::verify // CHECK: auto valueGroup0 = getODSOperands(0); -// CHECK: for (::mlir::Value v : valueGroup0) { +// CHECK: for (auto v : valueGroup0) { // CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]] def OpB : NS_Op<"op_for_And_PredOpTrait", [ @@ -109,7 +113,7 @@ // CHECK-LABEL: OpK::verify // CHECK: auto valueGroup0 = getODSOperands(0); -// CHECK: for (::mlir::Value v : valueGroup0) { +// CHECK: for (auto v : valueGroup0) { // CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]] def OpL : NS_Op<"op_for_StringEscaping", []> { diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td --- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td +++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td @@ -37,11 +37,13 @@ // Test static matcher for duplicate DagNode // --- -// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}} -// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}} -// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]]( -// CHECK: if(failed([[$TYPE_CONSTRAINT]] +// CHECK: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK-NEXT: {{.*::mlir::Type type}} +// CHECK: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( +// CHECK-NEXT: {{.*::mlir::Attribute attr}} +// CHECK: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]]( // CHECK: if(failed([[$ATTR_CONSTRAINT]] +// CHECK: if(failed([[$TYPE_CONSTRAINT]] // CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)), diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp --- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp @@ -13,6 +13,7 @@ #include
"mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Pattern.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" @@ -22,43 +23,9 @@ using namespace mlir; using namespace mlir::tblgen; -StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( - const llvm::RecordKeeper &records) - : uniqueOutputLabel(getUniqueName(records)) {} - -StaticVerifierFunctionEmitter & -StaticVerifierFunctionEmitter::setSelf(StringRef str) { - fctx.withSelf(str); - return *this; -} - -StaticVerifierFunctionEmitter & -StaticVerifierFunctionEmitter::setBuilder(StringRef str) { - fctx.withBuilder(str); - return *this; -} - -void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace( - StringRef signatureFormat, StringRef errorHandlerFormat, - StringRef cppNamespace, ArrayRef constraints, raw_ostream &os, - bool emitDecl) { - llvm::Optional namespaceEmitter; - if (!emitDecl) - namespaceEmitter.emplace(os, cppNamespace); - - emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os, - emitDecl); -} - -StringRef StaticVerifierFunctionEmitter::getConstraintFn( - const Constraint &constraint) const { - auto it = localTypeConstraints.find(constraint.getAsOpaquePointer()); - assert(it != localTypeConstraints.end() && "expected valid constraint fn"); - return it->second; -} - -std::string StaticVerifierFunctionEmitter::getUniqueName( - const llvm::RecordKeeper &records) { +/// Generate a unique label based on the current file name to prevent name +/// collisions if multiple generated files are included at once. +static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { // Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename(); @@ -77,66 +44,306 @@ return uniqueName; } -void StaticVerifierFunctionEmitter::emitConstraintMethods( - StringRef signatureFormat, StringRef errorHandlerFormat, - ArrayRef constraints, raw_ostream &rawOs, bool emitDecl) { - raw_indented_ostream os(rawOs); - - // Record the mapping from predicate to constraint. If two constraints has the - // same predicate and constraint summary, they can share the same verification - // function. - llvm::DenseMap predToConstraint; - for (auto it : llvm::enumerate(constraints)) { - std::string name; - Constraint constraint = Constraint::getFromOpaquePointer(it.value()); - Pred pred = constraint.getPredicate(); - auto iter = predToConstraint.find(pred); - if (iter != predToConstraint.end()) { - do { - Constraint built = Constraint::getFromOpaquePointer(iter->second); - // We may have the different constraints but have the same predicate, - // for example, ConstraintA and Variadic, note that - // Variadic<> doesn't introduce new predicate. In this case, we can - // share the same predicate function if they also have consistent - // summary, otherwise we may report the wrong message while verification - // fails.
- if (constraint.getSummary() == built.getSummary()) { - name = getConstraintFn(built).str(); - break; - } - ++iter; - } while (iter != predToConstraint.end() && iter->first == pred); - } +StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( + raw_ostream &os, const llvm::RecordKeeper &records) + : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {} + +void StaticVerifierFunctionEmitter::emitOpConstraints( + ArrayRef opDefs, bool emitDecl) { + collectOpConstraints(opDefs); + if (emitDecl || opDefs.empty()) + return; + + NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); + emitTypeConstraints(); + emitAttrConstraints(); + emitSuccessorConstraints(); + emitRegionConstraints(); +} + +void StaticVerifierFunctionEmitter::emitPatternConstraints( + const DenseSet &constraints) { + collectPatternConstraints(constraints); + emitPatternConstraints(); +} + +//===----------------------------------------------------------------------===// +// Constraint Getters + +StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( + const Constraint &constraint) const { + auto it = typeConstraints.find(constraint); + assert(it != typeConstraints.end() && "expected to find a type constraint"); + return it->second; +} + +// Find a uniqued attribute constraint. Since not all attribute constraints can +// be uniqued, return None if one was not found. +Optional StaticVerifierFunctionEmitter::getAttrConstraintFn( + const Constraint &constraint) const { + auto it = attrConstraints.find(constraint); + return it == attrConstraints.end() ?
Optional() + : StringRef(it->second); +} + +StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( + const Constraint &constraint) const { + auto it = successorConstraints.find(constraint); + assert(it != successorConstraints.end() && + "expected to find a successor constraint"); + return it->second; +} + +StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( + const Constraint &constraint) const { + auto it = regionConstraints.find(constraint); + assert(it != regionConstraints.end() && + "expected to find a region constraint"); + return it->second; +} + +//===----------------------------------------------------------------------===// +// Constraint Emission + +/// Code templates for emitting type, attribute, successor, and region +/// constraints. Each of these templates requires the following arguments: +/// +/// {0}: The unique constraint name. +/// {1}: The constraint code. +/// {2}: The constraint description. + +/// Code for a type constraint. These may be called on the type of either +/// operands or results. +static const char *const typeConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!({1})) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be {2}, but got " << type; + } + return ::mlir::success(); +} +)"; + +/// Code for an attribute constraint. These may be called from ops only. +/// Attribute constraints cannot reference anything other than `$_self` and +/// `$_op`. +/// +/// TODO: Unique constraints for adaptors. However, most Adaptor::verify +/// functions are stripped anyways.
+static const char *const attrConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + if (attr && !({1})) { + return op->emitOpError("attribute '") << attrName + << "' failed to satisfy constraint: {2}"; + } + return ::mlir::success(); +} +)"; - if (!name.empty()) { - localTypeConstraints.try_emplace(it.value(), name); - continue; +/// Code for a successor constraint. +static const char *const successorConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Block *successor, + ::llvm::StringRef successorName, unsigned successorIndex) { + if (!({1})) { + return op->emitOpError("successor #") << successorIndex << " ('" + << successorName << "') failed to verify constraint: {2}"; + } + return ::mlir::success(); +} +)"; + +/// Code for a region constraint. Callers will need to pass in the region's name +/// for emitting an error message. +static const char *const regionConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!({1})) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; + } + return ::mlir::success(); +} +)"; + +/// Code for a pattern type or attribute constraint. +/// +/// {3}: "Type type" or "Attribute attr".
+static const char *const patternAttrOrTypeConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3}, + ::llvm::StringRef failureStr) { + if (!({1})) { + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << failureStr << ": {2}"; + }); + } + return ::mlir::success(); +} +)"; + +void StaticVerifierFunctionEmitter::emitConstraints( + const ConstraintMap &constraints, StringRef selfName, + const char *const codeTemplate) { + FmtContext ctx; + ctx.withOp("*op").withSelf(selfName); + for (auto &it : constraints) { + os << formatv(codeTemplate, it.second, + tgfmt(it.first.getConditionTemplate(), &ctx), + it.first.getSummary()); + } +} + +void StaticVerifierFunctionEmitter::emitTypeConstraints() { + emitConstraints(typeConstraints, "type", typeConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitAttrConstraints() { + emitConstraints(attrConstraints, "attr", attrConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { + emitConstraints(successorConstraints, "successor", successorConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitRegionConstraints() { + emitConstraints(regionConstraints, "region", regionConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitPatternConstraints() { + FmtContext ctx; + ctx.withOp("*op").withBuilder("rewriter").withSelf("type"); + for (auto &it : typeConstraints) { + os << formatv(patternAttrOrTypeConstraintCode, it.second, + tgfmt(it.first.getConditionTemplate(), &ctx), + it.first.getSummary(), "Type type"); + } + ctx.withSelf("attr"); + for (auto &it : attrConstraints) { + os << formatv(patternAttrOrTypeConstraintCode, it.second, + tgfmt(it.first.getConditionTemplate(), &ctx), + it.first.getSummary(), "Attribute attr"); + } +} + +//===----------------------------------------------------------------------===// +// Constraint Uniquing + +using RecordDenseMapInfo = llvm::DenseMapInfo; + 
+Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() { + return Constraint(RecordDenseMapInfo::getEmptyKey(), + Constraint::CK_Uncategorized); +} + +Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() { + return Constraint(RecordDenseMapInfo::getTombstoneKey(), + Constraint::CK_Uncategorized); +} + +unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue( + Constraint constraint) { + if (constraint == getEmptyKey()) + return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey()); + if (constraint == getTombstoneKey()) { + return RecordDenseMapInfo::getHashValue( + RecordDenseMapInfo::getTombstoneKey()); + } + return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary()); +} + +bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs, + Constraint rhs) { + if (lhs == rhs) + return true; + if (lhs == getEmptyKey() || lhs == getTombstoneKey()) + return false; + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs.getPredicate() == rhs.getPredicate() && + lhs.getSummary() == rhs.getSummary(); +} + +/// An attribute constraint that references anything other than itself and the +/// current op cannot be generically extracted into a function. Most +/// prohibitive are operands and results, which require calls to +/// `getODSOperands` or `getODSResults`. Attribute references are tricky too +/// because ops use cached identifiers. 
+static bool canUniqueAttrConstraint(Attribute attr) { + FmtContext ctx; + auto test = + tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op")) + .str(); + return !StringRef(test).contains(""); +} + +std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, + unsigned index) { + return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + + Twine(index)) + .str(); +} + +void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, + StringRef kind, + Constraint constraint) { + auto it = map.find(constraint); + if (it == map.end()) + map.insert({constraint, getUniqueName(kind, map.size())}); +} + +void StaticVerifierFunctionEmitter::collectOpConstraints( + ArrayRef opDefs) { + const auto collectTypeConstraints = [&](Operator::value_range values) { + for (const NamedTypeConstraint &value : values) + if (value.hasPredicate()) + collectConstraint(typeConstraints, "type", value.constraint); + }; + + for (Record *def : opDefs) { + Operator op(*def); + /// Collect type constraints. + collectTypeConstraints(op.getOperands()); + collectTypeConstraints(op.getResults()); + /// Collect attribute constraints. + for (const NamedAttribute &namedAttr : op.getAttributes()) { + if (!namedAttr.attr.getPredicate().isNull() && + canUniqueAttrConstraint(namedAttr.attr)) + collectConstraint(attrConstraints, "attr", namedAttr.attr); + } + /// Collect successor constraints. + for (const NamedSuccessor &successor : op.getSuccessors()) { + if (!successor.constraint.getPredicate().isNull()) { + collectConstraint(successorConstraints, "successor", + successor.constraint); + } } + /// Collect region constraints. + for (const NamedRegion ®ion : op.getRegions()) + if (!region.constraint.getPredicate().isNull()) + collectConstraint(regionConstraints, "region", region.constraint); + } +} - // Generate an obscure and unique name for this type constraint. 
- name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel + - Twine(it.index())) - .str(); - predToConstraint.insert( - std::make_pair(constraint.getPredicate(), it.value())); - localTypeConstraints.try_emplace(it.value(), name); - - // Only generate the methods if we are generating definitions. - if (emitDecl) - continue; - - os << formatv(signatureFormat.data(), name) << " {\n"; - os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx) - << ")) {\n"; - os.indent() << "return " - << formatv(errorHandlerFormat.data(), - escapeString(constraint.getSummary())) - << ";\n"; - os.unindent() << "}\nreturn ::mlir::success();\n"; - os.unindent() << "}\n\n"; +void StaticVerifierFunctionEmitter::collectPatternConstraints( + const DenseSet &constraints) { + for (auto &leaf : constraints) { + assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); + collectConstraint( + leaf.isOperandMatcher() ? typeConstraints : attrConstraints, + leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint()); } } +//===----------------------------------------------------------------------===// +// Public Utility Functions +//===----------------------------------------------------------------------===// + std::string mlir::tblgen::escapeString(StringRef value) { std::string ret; llvm::raw_string_ostream os(ret); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -127,14 +127,6 @@ std::next({0}, valueRange.first + valueRange.second)}; )"; -static const char *const typeVerifierSignature = - "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type " - "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)"; - -static const char *const typeVerifierErrorHandler = - " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must " - "be {0}, but got \" << type"; - static 
const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} @@ -477,29 +469,42 @@ // Generate attribute verification. If an op instance is not available, then // attribute checks that require one will not be emitted. -static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, - FmtContext &ctx, OpMethodBody &body) { +static void genAttributeVerifier( + const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { // Check that a required attribute exists. // // {0}: Attribute variable name. // {1}: Emit error prefix. // {2}: Attribute name. - const char *const checkRequiredAttr = R"( + const char *const verifyRequiredAttr = R"( if (!{0}) return {1}"requires attribute '{2}'"); - )"; - // Check the condition on an attribute if it is required. This assumes that - // default values are valid. +)"; + // Verify the attribute if it is present. This assumes that default values + // are valid. This code snippet pastes the condition inline. + // // TODO: verify the default value is valid (perhaps in debug mode only). // // {0}: Attribute variable name. // {1}: Attribute condition code. // {2}: Emit error prefix. - // {3}: Attribute/constraint description. - const char *const checkAttrCondition = R"( + // {3}: Attribute name. + // {4}: Attribute/constraint description. + const char *const verifyAttrInline = R"( if ({0} && !({1})) return {2}"attribute '{3}' failed to satisfy constraint: {4}"); - )"; +)"; + // Verify the attribute using a uniqued constraint. Can only be used within + // the context of an op. + // + // {0}: Unique constraint name. + // {1}: Attribute variable name. + // {2}: Attribute name. 
+ const char *const verifyAttrUnique = R"( + if (::mlir::failed({0}(*this, {1}, "{2}"))) + return ::mlir::failure(); +)"; for (const auto &namedAttr : emitHelper.getOp().getAttributes()) { const auto &attr = namedAttr.attr; @@ -513,7 +518,8 @@ // If the attribute's condition needs an op but none is available, then the // condition cannot be emitted. bool canEmitCondition = - !StringRef(condition).contains("$_op") || emitHelper.isEmittingForOp(); + !condition.empty() && (!StringRef(condition).contains("$_op") || + emitHelper.isEmittingForOp()); // Prefix with `tblgen_` to avoid hiding the attribute accessor. Twine varName = tblgenNamePrefix + attrName; @@ -527,16 +533,22 @@ emitHelper.getAttr(attrName)); if (!allowMissingAttr) { - body << formatv(checkRequiredAttr, varName, emitHelper.emitErrorPrefix(), + body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(), attrName); } if (canEmitCondition) { - body << formatv(checkAttrCondition, varName, - tgfmt(condition, &ctx.withSelf(varName)), - emitHelper.emitErrorPrefix(), attrName, - escapeString(attr.getSummary())); + Optional constraintFn; + if (emitHelper.isEmittingForOp() && + (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { + body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); + } else { + body << formatv(verifyAttrInline, varName, + tgfmt(condition, &ctx.withSelf(varName)), + emitHelper.emitErrorPrefix(), attrName, + escapeString(attr.getSummary())); + } } - body << "}\n"; + body << " }\n"; } } @@ -2209,7 +2221,7 @@ bool hasCustomVerify = stringInit && !stringInit->getValue().empty(); populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, body); + genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); @@ -2238,10 +2250,38 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body, 
Operator::value_range values, StringRef valueKind) { + // Check that an optional value is at most 1 element. + // + // {0}: Value index. + // {1}: "operand" or "result" + const char *const verifyOptional = R"( + if (valueGroup{0}.size() > 1) { + return emitOpError("{1} group starting at #") << index + << " requires 0 or 1 element, but found " << valueGroup{0}.size(); + } +)"; + // Check the types of a range of values. + // + // {0}: Value index. + // {1}: Type constraint function. + // {2}: "operand" or "result" + const char *const verifyValues = R"( + for (auto v : valueGroup{0}) { + if (::mlir::failed({1}(*this, v.getType(), "{2}", index++))) + return ::mlir::failure(); + } +)"; + + const auto canSkip = [](const NamedTypeConstraint &value) { + return !value.hasPredicate() && !value.isOptional() && + !value.isVariadicOfVariadic(); + }; + if (values.empty() || llvm::all_of(values, canSkip)) + return; + FmtContext fctx; - body << " {\n"; - body << " unsigned index = 0; (void)index;\n"; + body << " {\n unsigned index = 0; (void)index;\n"; for (auto staticValue : llvm::enumerate(values)) { const NamedTypeConstraint &value = staticValue.value(); @@ -2259,11 +2299,7 @@ // If the constraint is optional check that the value group has at most 1 // value. if (isOptional) { - body << formatv(" if (valueGroup{0}.size() > 1)\n" - " return emitOpError(\"{1} group starting at #\") " - "<< index << \" requires 0 or 1 element, but found \" << " - "valueGroup{0}.size();\n", - staticValue.index(), valueKind); + body << formatv(verifyOptional, staticValue.index(), valueKind); } else if (isVariadicOfVariadic) { body << formatv( " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" @@ -2278,93 +2314,89 @@ continue; // Emit a loop to check all the dynamic values in the pack. 
StringRef constraintFn = - staticVerifierEmitter.getConstraintFn(value.constraint); - body << " for (::mlir::Value v : valueGroup" << staticValue.index() - << ") {\n" - << " if (::mlir::failed(" << constraintFn - << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n" - << " return ::mlir::failure();\n" - << " ++index;\n" - << " }\n"; + staticVerifierEmitter.getTypeConstraintFn(value.constraint); + body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); } body << " }\n"; } void OpEmitter::genRegionVerifier(OpMethodBody &body) { + /// Code to verify a region. + /// + /// {0}: Getter for the regions. + /// {1}: The region constraint. + /// {2}: The region's name. + /// {3}: The region description. + const char *const verifyRegion = R"( + for (auto ®ion : {0}) + if (::mlir::failed({1}(*this, region, "{2}", index++))) + return ::mlir::failure(); +)"; + /// Get a single region. + /// + /// {0}: The region's index. + const char *const getSingleRegion = + "::llvm::makeMutableArrayRef((*this)->getRegion({0}))"; + // If we have no regions, there is nothing more to do. - unsigned numRegions = op.getNumRegions(); - if (numRegions == 0) + const auto canSkip = [](const NamedRegion ®ion) { + return region.constraint.getPredicate().isNull(); + }; + auto regions = op.getRegions(); + if (regions.empty() && llvm::all_of(regions, canSkip)) return; - body << "{\n"; - body << " unsigned index = 0; (void)index;\n"; - - for (unsigned i = 0; i < numRegions; ++i) { - const auto ®ion = op.getRegion(i); - if (region.constraint.getPredicate().isNull()) + body << " {\n unsigned index = 0; (void)index;\n"; + for (auto it : llvm::enumerate(regions)) { + const auto ®ion = it.value(); + if (canSkip(region)) continue; - body << " for (::mlir::Region ®ion : "; - body << formatv(region.isVariadic() - ? 
"{0}()" - : "::mlir::MutableArrayRef<::mlir::Region>((*this)" - "->getRegion({1}))", - op.getGetterName(region.name), i); - body << ") {\n"; - auto constraint = tgfmt(region.constraint.getConditionTemplate(), - &verifyCtx.withSelf("region")) - .str(); - - body << formatv(" (void)region;\n" - " if (!({0})) {\n " - "return emitOpError(\"region #\") << index << \" {1}" - "failed to " - "verify constraint: {2}\";\n }\n", - constraint, - region.name.empty() ? "" : "('" + region.name + "') ", - region.constraint.getSummary()) - << " ++index;\n" - << " }\n"; + auto getRegion = region.isVariadic() + ? formatv("{0}()", op.getGetterName(region.name)).str() + : formatv(getSingleRegion, it.index()).str(); + auto constraintFn = + staticVerifierEmitter.getRegionConstraintFn(region.constraint); + body << formatv(verifyRegion, getRegion, constraintFn, region.name); } body << " }\n"; } void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { + const char *const verifySuccessor = R"( + for (auto *successor : {0}) + if (::mlir::failed({1}(*this, successor, "{2}", index++))) + return ::mlir::failure(); +)"; + /// Get a single successor. + /// + /// {0}: The successor's name. + const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())"; + // If we have no successors, there is nothing more to do. 
- unsigned numSuccessors = op.getNumSuccessors(); - if (numSuccessors == 0) + const auto canSkip = [](const NamedSuccessor &successor) { + return successor.constraint.getPredicate().isNull(); + }; + auto successors = op.getSuccessors(); + if (successors.empty() && llvm::all_of(successors, canSkip)) return; - body << "{\n"; - body << " unsigned index = 0; (void)index;\n"; + body << " {\n unsigned index = 0; (void)index;\n"; - for (unsigned i = 0; i < numSuccessors; ++i) { - const auto &successor = op.getSuccessor(i); - if (successor.constraint.getPredicate().isNull()) + for (auto it : llvm::enumerate(successors)) { + const auto &successor = it.value(); + if (canSkip(successor)) continue; - if (successor.isVariadic()) { - body << formatv(" for (::mlir::Block *successor : {0}()) {\n", - successor.name); - } else { - body << " {\n"; - body << formatv(" ::mlir::Block *successor = {0}();\n", - successor.name); - } - auto constraint = tgfmt(successor.constraint.getConditionTemplate(), - &verifyCtx.withSelf("successor")) - .str(); - - body << formatv(" (void)successor;\n" - " if (!({0})) {\n " - "return emitOpError(\"successor #\") << index << \"('{1}') " - "failed to " - "verify constraint: {2}\";\n }\n", - constraint, successor.name, - successor.constraint.getSummary()) - << " ++index;\n" - << " }\n"; + auto getSuccessor = + formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor, + successor.name, it.index()) + .str(); + auto constraintFn = + staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); + body << formatv(verifySuccessor, getSuccessor, constraintFn, + successor.name); } body << " }\n"; } @@ -2504,11 +2536,16 @@ // getters identical to those defined in the Op. 
class OpOperandAdaptorEmitter { public: - static void emitDecl(const Operator &op, raw_ostream &os); - static void emitDef(const Operator &op, raw_ostream &os); + static void emitDecl(const Operator &op, + StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); + static void emitDef(const Operator &op, + StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); private: - explicit OpOperandAdaptorEmitter(const Operator &op); + explicit OpOperandAdaptorEmitter( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter); // Add verification function. This generates a verify method for the adaptor // which verifies all the op-independent attribute constraints. @@ -2516,11 +2553,14 @@ const Operator &op; Class adaptor; + StaticVerifierFunctionEmitter &staticVerifierEmitter; }; } // end anonymous namespace -OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) - : op(op), adaptor(op.getAdaptorName()) { +OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter) + : op(op), adaptor(op.getAdaptorName()), + staticVerifierEmitter(staticVerifierEmitter) { adaptor.newField("::mlir::ValueRange", "odsOperands"); adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); adaptor.newField("::mlir::RegionRange", "odsRegions"); @@ -2644,17 +2684,21 @@ FmtContext verifyCtx; populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, body); + genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); body << " return ::mlir::success();"; } -void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); +void OpOperandAdaptorEmitter::emitDecl( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os) { + OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); } -void 
OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); +void OpOperandAdaptorEmitter::emitDef( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os) { + OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); } // Emits the opcode enum and op classes. @@ -2679,27 +2723,9 @@ return; // Generate all of the locally instantiated methods first. - StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper); + StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); - staticVerifierEmitter.setSelf("type"); - - // Collect a set of all of the used type constraints within the operation - // definitions. - llvm::SetVector typeConstraints; - for (Record *def : defs) { - Operator op(*def); - for (NamedTypeConstraint &operand : op.getOperands()) - if (operand.hasPredicate()) - typeConstraints.insert(operand.constraint.getAsOpaquePointer()); - for (NamedTypeConstraint &result : op.getResults()) - if (result.hasPredicate()) - typeConstraints.insert(result.constraint.getAsOpaquePointer()); - } - - staticVerifierEmitter.emitConstraintMethodsInNamespace( - typeVerifierSignature, typeVerifierErrorHandler, - Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os, - emitDecl); + staticVerifierEmitter.emitOpConstraints(defs, emitDecl); for (auto *def : defs) { Operator op(*def); @@ -2708,7 +2734,7 @@ NamespaceEmitter emitter(os, op.getCppNamespace()); os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - OpOperandAdaptorEmitter::emitDecl(op, os); + OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); OpEmitter::emitDecl(op, os, staticVerifierEmitter); } // Emit the TypeID explicit specialization to have a single definition. 
@@ -2719,7 +2745,7 @@ { NamespaceEmitter emitter(os, op.getCppNamespace()); os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - OpOperandAdaptorEmitter::emitDef(op, os); + OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os); OpEmitter::emitDef(op, os, staticVerifierEmitter); } // Emit the TypeID explicit specialization to have a single definition. diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -42,23 +42,6 @@ #define DEBUG_TYPE "mlir-tblgen-rewritergen" -// The signature of static type verification function -static const char *typeVerifierSignature = - "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " - "::mlir::Operation *op, ::mlir::Type typeOrAttr, " - "::llvm::StringRef failureStr)"; - -// The signature of static attribute verification function -static const char *attrVerifierSignature = - "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " - "::mlir::Operation *op, ::mlir::Attribute typeOrAttr, " - "::llvm::StringRef failureStr)"; - -// The template of error handler in static type/attribute verification function -static const char *verifierErrorHandler = - "rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n diag " - "<< failureStr << \": {0}\";\n});"; - namespace llvm { template <> struct format_provider { @@ -273,7 +256,7 @@ // inlining them. class StaticMatcherHelper { public: - StaticMatcherHelper(const RecordKeeper &recordKeeper, + StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper, RecordOperatorMap &mapper); // Determine if we should inline the match logic or delegate to a static @@ -289,7 +272,7 @@ } // Get the name of static type/attribute verification function. 
- StringRef getVerifierName(Constraint constraint); + StringRef getVerifierName(DagLeaf leaf); // Collect the `Record`s, i.e., the DRR, so that we can get the information of // the duplicated DAGs. @@ -541,7 +524,7 @@ self = argName; else self = formatv("{0}.getType()", argName); - StringRef verifier = staticMatcherHelper.getVerifierName(constraint); + StringRef verifier = staticMatcherHelper.getVerifierName(leaf); emitStaticVerifierCall( verifier, opName, self, formatv("\"operand {0} of native code call '{1}' failed to satisfy " @@ -684,7 +667,7 @@ PrintFatalError(loc, error); } auto self = formatv("(*{0}.begin()).getType()", operandName); - StringRef verifier = staticMatcherHelper.getVerifierName(constraint); + StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher); emitStaticVerifierCall( verifier, opName, self.str(), formatv( @@ -809,8 +792,7 @@ // If a constraint is specified, we need to generate function call to its // static verifier. - StringRef verifier = - staticMatcherHelper.getVerifierName(matcher.getAsConstraint()); + StringRef verifier = staticMatcherHelper.getVerifierName(matcher); emitStaticVerifierCall( verifier, opName, "tblgen_attr", formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " @@ -1690,9 +1672,10 @@ } } -StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper, +StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, + const RecordKeeper &recordKeeper, RecordOperatorMap &mapper) - : opMap(mapper), staticVerifierEmitter(recordKeeper) {} + : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {} void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { // PatternEmitter will use the static matcher if there's one generated. 
To @@ -1713,28 +1696,7 @@ } void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { - llvm::SetVector typeConstraints; - llvm::SetVector attrConstraints; - for (DagLeaf leaf : constraints) { - if (leaf.isOperandMatcher()) { - typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer()); - } else { - assert(leaf.isAttrMatcher()); - attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer()); - } - } - - staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr"); - - staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature, - verifierErrorHandler, - typeConstraints.getArrayRef(), os, - /*emitDecl=*/false); - - staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature, - verifierErrorHandler, - attrConstraints.getArrayRef(), os, - /*emitDecl=*/false); + staticVerifierEmitter.emitPatternConstraints(constraints); } void StaticMatcherHelper::addPattern(Record *record) { @@ -1765,8 +1727,15 @@ dfs(pat.getSourcePattern()); } -StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) { - return staticVerifierEmitter.getConstraintFn(constraint); +StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { + if (leaf.isAttrMatcher()) { + Optional constraint = + staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint()); + assert(constraint.hasValue() && "attribute constraint was not uniqued"); + return *constraint; + } + assert(leaf.isOperandMatcher()); + return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); } static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { @@ -1779,7 +1748,7 @@ // Exam all the patterns and generate static matcher for the duplicated // DagNode. - StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap); + StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap); for (Record *p : patterns) staticMatcher.addPattern(p); staticMatcher.populateStaticConstraintFunctions(os);