diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -16,6 +16,7 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/Dialect.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -94,46 +95,63 @@ StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records, raw_ostream &os); - /// Emit the static verifier functions for `llvm::Record`s. The - /// `signatureFormat` describes the required arguments and it must have a - /// placeholder for function name. - /// Example, - /// const char *typeVerifierSignature = - /// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type" - /// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)"; + /// Collect and unique all compatible type, attribute, successor, and region + /// constraints in the file and emit them at the top of the generated file. /// - /// `errorHandlerFormat` describes the error message to return. It may have a - /// placeholder for the summary of Constraint and bring more information for - /// the error message. - /// Example, - /// const char *typeVerifierErrorHandler = - /// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << " - /// "\" must be {0}, but got \" << type"; - /// - /// `typeArgName` is used to identify the argument that needs to check its - /// type. The constraint template will replace `$_self` with it. - void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat, - StringRef typeArgName, ArrayRef opDefs, - bool emitDecl); + /// Constraints that do not meet the restriction that they can only reference + /// `$_self` and `$_op` are not uniqued. + void emitLocalFunctions(ArrayRef opDefs, bool emitDecl); /// Get the name of the local function used for the given type constraint. 
/// These functions are used for operand and result constraints and have the /// form: + /// /// LogicalResult(Operation *op, Type type, StringRef valueKind, - /// unsigned valueGroupStartIndex); + /// unsigned valueIndex); + /// StringRef getTypeConstraintFn(const Constraint &constraint) const; + /// Get the name of the local function used for the given attribute + /// constraint. These functions are in the form: + /// + /// LogicalResult(Operation *op, Attribute attr, StringRef attrName); + /// + /// If a uniqued constraint was not found, this function returns None. The + /// uniqued constraints cannot be used in the context of an OpAdaptor. + Optional<StringRef> getAttrConstraintFn(const Constraint &constraint) const; + + /// Get the name of the local function used for the given successor + /// constraint. These functions are in the form: + /// + /// LogicalResult(Operation *op, Block *successor, StringRef successorName, + /// unsigned successorIndex); + /// + StringRef getSuccessorConstraintFn(const Constraint &constraint) const; + + /// Get the name of the local function used for the given region constraint. + /// These functions are in the form: + /// + /// LogicalResult(Operation *op, Region &region, StringRef regionName, + /// unsigned regionIndex); + /// + /// The region name may be empty. + StringRef getRegionConstraintFn(const Constraint &constraint) const; + private: /// Returns a unique name to use when generating local methods. static std::string getUniqueName(const llvm::RecordKeeper &records); - /// Emit local methods for the type constraints used within the provided op - /// definitions. - void emitTypeConstraintMethods(StringRef signatureFormat, - StringRef errorHandlerFormat, - StringRef typeArgName, - ArrayRef<llvm::Record *> opDefs, - bool emitDecl); + /// Emit local type constraint functions. + void emitTypeConstraints(); + /// Emit local attribute constraint functions. + void emitAttrConstraints(); + /// Emit local successor constraint functions.
+ void emitSuccessorConstraints(); + /// Emit local region constraint functions. + void emitRegionConstraints(); + + /// Collect and unique all the constraints used by operations. + void collectAllConstraints(ArrayRef opDefs); raw_indented_ostream os; @@ -141,9 +159,33 @@ /// ensure that the local functions have a unique name. std::string uniqueOutputLabel; - /// A set of functions implementing type constraints, used for operand and - /// result verification. - llvm::DenseMap localTypeConstraints; + /// Unique constraints by their predicate and summary. Constraints that share + /// the same predicate may have different descriptions; ensure that the + /// correct error message is reported when verification fails. + struct ConstraintUniquer { + static Constraint getEmptyKey(); + static Constraint getTombstoneKey(); + static unsigned getHashValue(Constraint constraint); + static bool isEqual(Constraint lhs, Constraint rhs); + }; + /// Use a MapVector to ensure that functions are generated deterministically. + using ConstraintMap = + llvm::MapVector>; + + /// A generic function to emit constraints. + void emitConstraints(const ConstraintMap &constraints, StringRef selfName, + const char *const codeTemplate); + + /// The set of type constraints used for operand and result verification in + /// the current file. + ConstraintMap localTypeConstraints; + /// The set of attribute constraints used in the current file. + ConstraintMap localAttrConstraints; + /// The set of successor constraints used in the current file. + ConstraintMap localSuccessorConstraints; + /// The set of region constraints used in the current file. + ConstraintMap localRegionConstraints; }; // Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar. diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -29,6 +29,7 @@ // TableGen. 
class Constraint { public: + Constraint() : def(nullptr) {} Constraint(const llvm::Record *record); bool operator==(const Constraint &that) { return def == that.def; } @@ -52,13 +53,6 @@ Kind getKind() const { return kind; } - /// Get an opaque pointer to the constraint. - const void *getAsOpaquePointer() const { return def; } - /// Construct a constraint from the opaque pointer representation. - static Constraint getFromOpaquePointer(const void *ptr) { - return Constraint(reinterpret_cast(ptr)); - } - // Return the underlying def. const llvm::Record *getDef() const { return def; } diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h --- a/mlir/include/mlir/TableGen/Predicate.h +++ b/mlir/include/mlir/TableGen/Predicate.h @@ -59,9 +59,10 @@ // Get the location of the predicate. ArrayRef getLoc() const; -protected: - friend llvm::DenseMapInfo; + // Get the TableGen definition. + const llvm::Record *getDef() const { return def; } +protected: // The TableGen definition of this predicate. 
const llvm::Record *def; }; @@ -119,18 +120,4 @@ } // end namespace tblgen } // end namespace mlir -namespace llvm { -template <> -struct DenseMapInfo { - static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); } - static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); } - static unsigned getHashValue(mlir::tblgen::Pred pred) { - return llvm::hash_value(pred.def); - } - static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) { - return lhs == rhs; - } -}; -} // end namespace llvm - #endif // MLIR_TABLEGEN_PREDICATE_H_ diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/constraint-unique.td @@ -0,0 +1,156 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +class NS_Op traits = []> : + Op; + +/// Test unique'ing of type, attribute, successor, and region constraints. + +def ATypePred : CPred<"typePred($_self, $_op)">; +def AType : Type; +def OtherType : Type; + +def AnAttrPred : CPred<"attrPred($_self, $_op)">; +def AnAttr : Attr; +def OtherAttr : Attr; + +def ASuccessorPred : CPred<"successorPred($_self, $_op)">; +def ASuccessor : Successor; +def OtherSuccessor : Successor; + +def ARegionPred : CPred<"regionPred($_self, $_op)">; +def ARegion : Region; +def OtherRegion : Region; + +// OpA and OpB have the same type, attribute, successor, and region constraints. 
+ +def OpA : NS_Op<"op_a"> { + let arguments = (ins AType:$a, AnAttr:$b); + let results = (outs AType:$ret); + let successors = (successor ASuccessor:$c); + let regions = (region ARegion:$d); +} + +def OpB : NS_Op<"op_b"> { + let arguments = (ins AType:$a, AnAttr:$b); + let successors = (successor ASuccessor:$c); + let regions = (region ARegion:$d); +} + +// OpC has the same type, attribute, successor, and region predicates but has +// different descriptions for them. + +def OpC : NS_Op<"op_c"> { + let arguments = (ins OtherType:$a, OtherAttr:$b); + let results = (outs OtherType:$ret); + let successors = (successor OtherSuccessor:$c); + let regions = (region OtherRegion:$d); +} + +/// Test that a type constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!((typePred(type, *op)))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be a type, but got " << type; + +/// Test that duplicate type constraint was not generated. +// CHECK-NOT: << " must be a type, but got " << type; + +/// Test that a type constraint with a different description was generated. +// CHECK: static ::mlir::LogicalResult [[$O_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!((typePred(type, *op)))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be another type, but got " << type; + +/// Test that an attribute constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( +// CHECK: if (attr && !((attrPred(attr, *op)))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: an attribute"; + +/// Test that duplicate attribute constraint was not generated.
+// CHECK-NOT: << "' failed to satisfy constraint: an attribute"; + +/// Test that an attribute constraint with a different description was generated. +// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( +// CHECK: if (attr && !((attrPred(attr, *op)))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: another attribute"; + +/// Test that a successor constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]]( +// CHECK: if (!((successorPred(successor, *op)))) { +// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('" +// CHECK-NEXT: << successorName << ")' failed to verify constraint: a successor"; + +/// Test that duplicate successor constraint was not generated. +// CHECK-NOT: << successorName << ")' failed to verify constraint: a successor"; + +/// Test that a successor constraint with a different description was generated. +// CHECK: static ::mlir::LogicalResult [[$O_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]]( +// CHECK: if (!((successorPred(successor, *op)))) { +// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('" +// CHECK-NEXT: << successorName << ")' failed to verify constraint: another successor"; + +/// Test that a region constraint was generated. +// CHECK: static ::mlir::LogicalResult [[$A_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]]( +// CHECK: if (!((regionPred(region, *op)))) { +// CHECK-NEXT: return op->emitOpError("region #") << regionIndex +// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ") +// CHECK-NEXT: << "failed to verify constraint: a region"; + +/// Test that duplicate region constraint was not generated. +// CHECK-NOT: << "failed to verify constraint: a region"; + +/// Test that a region constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]]( +// CHECK: if (!((regionPred(region, *op)))) { +// CHECK-NEXT: return op->emitOpError("region #") << regionIndex +// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ") +// CHECK-NEXT: << "failed to verify constraint: another region"; + +/// Test that the uniqued constraints are being used. +// CHECK-LABEL: OpA::verify +// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); +// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0); +// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0))) +// CHECK-NEXT: if (::mlir::failed([[$A_REGION_CONSTRAINT]](*this, region, "d", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c())) +// CHECK-NEXT: if (::mlir::failed([[$A_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++))) +// CHECK-NEXT: return ::mlir::failure(); + +/// Test that the op with the same predicates but different descriptions +/// uses the different constraints.
+// CHECK-LABEL: OpC::verify +// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); +// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0); +// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]]) +// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0))) +// CHECK-NEXT: if (::mlir::failed([[$O_REGION_CONSTRAINT]](*this, region, "d", index++))) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c())) +// CHECK-NEXT: if (::mlir::failed([[$O_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++))) +// CHECK-NEXT: return ::mlir::failure(); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -17,24 +17,28 @@ } // CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type; +// CHECK: if (!((type.isInteger(32) || type.isF32()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be 32-bit integer or floating-point type, but got " << type; // Check there is no verifier with same predicate generated.
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) { -// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type; +// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NOT: << " must be 32-bit integer or floating-point type, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type; +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be tensor of any type values, but got " << type; // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( -// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) { -// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; // CHECK-LABEL: OpA::verify // CHECK: auto valueGroup0 = getODSOperands(0); -// CHECK: for (::mlir::Value v : valueGroup0) { +// CHECK: for (auto v : valueGroup0) { // CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]] def OpB :
NS_Op<"op_for_And_PredOpTrait", [ @@ -109,7 +113,7 @@ // CHECK-LABEL: OpK::verify // CHECK: auto valueGroup0 = getODSOperands(0); -// CHECK: for (::mlir::Value v : valueGroup0) { +// CHECK: for (auto v : valueGroup0) { // CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]] def OpL : NS_Op<"op_for_StringEscaping", []> { diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp --- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp @@ -27,22 +27,17 @@ const llvm::RecordKeeper &records, raw_ostream &os) : os(os), uniqueOutputLabel(getUniqueName(records)) {} -void StaticVerifierFunctionEmitter::emitFunctionsFor( - StringRef signatureFormat, StringRef errorHandlerFormat, - StringRef typeArgName, ArrayRef opDefs, bool emitDecl) { - llvm::Optional namespaceEmitter; - if (!emitDecl) - namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace()); +void StaticVerifierFunctionEmitter::emitLocalFunctions( + ArrayRef opDefs, bool emitDecl) { + collectAllConstraints(opDefs); + if (emitDecl) + return; - emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName, - opDefs, emitDecl); -} - -StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( - const Constraint &constraint) const { - auto it = localTypeConstraints.find(constraint.getAsOpaquePointer()); - assert(it != localTypeConstraints.end() && "expected valid constraint fn"); - return it->second; + NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); + emitTypeConstraints(); + emitAttrConstraints(); + emitSuccessorConstraints(); + emitRegionConstraints(); } std::string StaticVerifierFunctionEmitter::getUniqueName( @@ -65,78 +60,225 @@ return uniqueName; } -void StaticVerifierFunctionEmitter::emitTypeConstraintMethods( - StringRef signatureFormat, StringRef errorHandlerFormat, - StringRef typeArgName, ArrayRef opDefs, bool emitDecl) { - // Collect a set of all of the used type 
constraints within the operation - // definitions. - llvm::SetVector<const void *> typeConstraints; - for (Record *def : opDefs) { - Operator op(*def); - for (NamedTypeConstraint &operand : op.getOperands()) - if (operand.hasPredicate()) - typeConstraints.insert(operand.constraint.getAsOpaquePointer()); - for (NamedTypeConstraint &result : op.getResults()) - if (result.hasPredicate()) - typeConstraints.insert(result.constraint.getAsOpaquePointer()); +//===----------------------------------------------------------------------===// +// Constraint Getters + +StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( + const Constraint &constraint) const { + auto it = localTypeConstraints.find(constraint); + assert(it != localTypeConstraints.end() && + "expected to find a type constraint"); + return it->second; +} + +Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( + const Constraint &constraint) const { + auto it = localAttrConstraints.find(constraint); + return it == localAttrConstraints.end() ? Optional<StringRef>() + : StringRef(it->second); +} + +StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( + const Constraint &constraint) const { + auto it = localSuccessorConstraints.find(constraint); + assert(it != localSuccessorConstraints.end() && + "expected to find a successor constraint"); + return it->second; +} + +StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( + const Constraint &constraint) const { + auto it = localRegionConstraints.find(constraint); + assert(it != localRegionConstraints.end() && + "expected to find a region constraint"); + return it->second; +} + +//===----------------------------------------------------------------------===// +// Constraint Emission + +/// Code templates for emitting type, attribute, successor, and region +/// constraints. Each of these templates requires the following arguments: +/// +/// {0}: The unique constraint name. +/// {1}: The constraint code. +/// {2}: The constraint description.
+ +/// Code for a type constraint. These may be called on the type of either +/// operands or results. +static const char *const typeConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!({1})) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be {2}, but got " << type; } + return ::mlir::success(); +} +)"; - // Record the mapping from predicate to constraint. If two constraints has the - // same predicate and constraint summary, they can share the same verification - // function. - llvm::DenseMap predToConstraint; - FmtContext fctx; - for (auto it : llvm::enumerate(typeConstraints)) { - std::string name; - Constraint constraint = Constraint::getFromOpaquePointer(it.value()); - Pred pred = constraint.getPredicate(); - auto iter = predToConstraint.find(pred); - if (iter != predToConstraint.end()) { - do { - Constraint built = Constraint::getFromOpaquePointer(iter->second); - // We may have the different constraints but have the same predicate, - // for example, ConstraintA and Variadic, note that - // Variadic<> doesn't introduce new predicate. In this case, we can - // share the same predicate function if they also have consistent - // summary, otherwise we may report the wrong message while verification - // fails. - if (constraint.getSummary() == built.getSummary()) { - name = getTypeConstraintFn(built).str(); - break; - } - ++iter; - } while (iter != predToConstraint.end() && iter->first == pred); - } +/// Code for an attribute constraint. These may be called from ops only. +/// Attribute constraints cannot reference anything other than `$_self` and +/// `$_op`. +/// +/// TODO: Unique constraints for adaptors. However, most Adaptor::verify +/// functions are stripped anyways. 
+static const char *const attrConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + if (attr && !({1})) { + return op->emitOpError("attribute '") << attrName + << "' failed to satisfy constraint: {2}"; + } + return ::mlir::success(); +} +)"; - if (!name.empty()) { - localTypeConstraints.try_emplace(it.value(), name); - continue; - } +/// Code for a successor constraint. +static const char *const successorConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Block *successor, + ::llvm::StringRef successorName, unsigned successorIndex) { + if (!({1})) { + return op->emitOpError("successor #") << successorIndex << " ('" + << successorName << ")' failed to verify constraint: {2}"; + } + return ::mlir::success(); +} +)"; - // Generate an obscure and unique name for this type constraint. - name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel + - Twine(it.index())) - .str(); - predToConstraint.insert( - std::make_pair(constraint.getPredicate(), it.value())); - localTypeConstraints.try_emplace(it.value(), name); - - // Only generate the methods if we are generating definitions. - if (emitDecl) - continue; - - os << formatv(signatureFormat.data(), name) << " {\n"; - os.indent() << "if (!(" - << tgfmt(constraint.getConditionTemplate(), - &fctx.withSelf(typeArgName)) - << ")) {\n"; - os.indent() << "return " - << formatv(errorHandlerFormat.data(), constraint.getSummary()) - << ";\n"; - os.unindent() << "}\nreturn ::mlir::success();\n"; - os.unindent() << "}\n\n"; +/// Code for a region constraint. Callers will need to pass in the region's name +/// for emitting an error message. 
+static const char *const regionConstraintCode = R"( +static ::mlir::LogicalResult {0}( + ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!({1})) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; } + return ::mlir::success(); +} +)"; + +void StaticVerifierFunctionEmitter::emitConstraints( + const ConstraintMap &constraints, StringRef selfName, + const char *const codeTemplate) { + FmtContext ctx; + ctx.withOp("*op").withSelf(selfName); + for (auto &it : constraints) { + os << formatv(codeTemplate, it.second, + tgfmt(it.first.getConditionTemplate(), &ctx), + escapeString(it.first.getSummary())); + } +} + +void StaticVerifierFunctionEmitter::emitTypeConstraints() { + emitConstraints(localTypeConstraints, "type", typeConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitAttrConstraints() { + emitConstraints(localAttrConstraints, "attr", attrConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { + emitConstraints(localSuccessorConstraints, "successor", + successorConstraintCode); +} + +void StaticVerifierFunctionEmitter::emitRegionConstraints() { + emitConstraints(localRegionConstraints, "region", regionConstraintCode); +} + +//===----------------------------------------------------------------------===// +// Constraint Uniquing + +Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() { + return Constraint(); +} +Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() { + return Constraint(); } +unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue( + Constraint constraint) { + if (!constraint.getDef()) + return 0; + return llvm::hash_combine(constraint.getPredicate().getDef(), + constraint.getSummary()); +} +bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs, + Constraint rhs) { + if
(!lhs.getDef() ^ !rhs.getDef()) + return false; + return !lhs.getDef() || (lhs.getPredicate() == rhs.getPredicate() && + lhs.getSummary() == rhs.getSummary()); +} + +/// An attribute constraint that references anything other than itself and the +/// current op cannot be generically extracted into a function. Most +/// prohibitive are operands and results, which require calls to +/// `getODSOperands` or `getODSResults`. Attribute references are tricky too +/// because ops use cached identifiers. +static bool canUniqueAttrConstraint(Attribute attr) { + FmtContext ctx; + auto test = + tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op")) + .str(); + return !StringRef(test).contains("<no-subst-found>"); +} + +void StaticVerifierFunctionEmitter::collectAllConstraints( + ArrayRef<llvm::Record *> opDefs) { + /// Function to assign a unique name to a unique constraint. + const auto getUniqueName = [this](StringRef kind, unsigned index) { + return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + + Twine(index)) + .str(); + }; + + /// Function to unique a constraint. + const auto collectConstraint = [&](ConstraintMap &map, StringRef kind, + Constraint constraint) { + auto it = map.find(constraint); + if (it == map.end()) + map.insert({constraint, getUniqueName(kind, map.size())}); + }; + + const auto collectTypeConstraints = [&](Operator::value_range values) { + for (const NamedTypeConstraint &value : values) + if (value.hasPredicate()) + collectConstraint(localTypeConstraints, "type", value.constraint); + }; + + for (Record *def : opDefs) { + Operator op(*def); + /// Collect type constraints. + collectTypeConstraints(op.getOperands()); + collectTypeConstraints(op.getResults()); + /// Collect attribute constraints. + for (const NamedAttribute &namedAttr : op.getAttributes()) + if (!namedAttr.attr.getPredicate().isNull() && + canUniqueAttrConstraint(namedAttr.attr)) + collectConstraint(localAttrConstraints, "attr", namedAttr.attr); + /// Collect successor constraints.
+ for (const NamedSuccessor &successor : op.getSuccessors()) { + if (!successor.constraint.getPredicate().isNull()) { + collectConstraint(localSuccessorConstraints, "successor", + successor.constraint); + } + } + /// Collect region constraints. + for (const NamedRegion &region : op.getRegions()) + if (!region.constraint.getPredicate().isNull()) + collectConstraint(localRegionConstraints, "region", region.constraint); + } +} + +//===----------------------------------------------------------------------===// +// Public Utility Functions +//===----------------------------------------------------------------------===// std::string mlir::tblgen::escapeString(StringRef value) { std::string ret; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -127,14 +127,6 @@ std::next({0}, valueRange.first + valueRange.second)}; )"; -static const char *const typeVerifierSignature = - "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type " - "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)"; - -static const char *const typeVerifierErrorHandler = - " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must " - "be {0}, but got \" << type"; - static const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} @@ -475,29 +467,42 @@ // Generate attribute verification. If an op instance is not available, then // attribute checks that require one will not be emitted. -static void genAttributeVerifier(const OpOrAdaptorHelper &emit, FmtContext &ctx, - OpMethodBody &body) { +static void genAttributeVerifier( + const OpOrAdaptorHelper &emit, FmtContext &ctx, OpMethodBody &body, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { // Check that a required attribute exists. // // {0}: Attribute variable name.
// {1}: Emit error prefix. // {2}: Attribute name. - const char *const checkRequiredAttr = R"( + const char *const verifyRequiredAttr = R"( if (!{0}) return {1}"requires attribute '{2}'"); - )"; - // Check the condition on an attribute if it is required. This assumes that - // default values are valid. +)"; + // Verify the attribute if it is present. This assumes that default values + // are valid. This code snippet pastes the condition inline. + // // TODO: verify the default value is valid (perhaps in debug mode only). // // {0}: Attribute variable name. // {1}: Attribute condition code. // {2}: Emit error prefix. - // {3}: Attribute/constraint description. - const char *const checkAttrCondition = R"( + // {3}: Attribute name. + // {4}: Attribute/constraint description. + const char *const verifyAttrInline = R"( if ({0} && !({1})) return {2}"attribute '{3}' failed to satisfy constraint: {4}"); - )"; +)"; + // Verify the attribute using a uniqued constraint. Can only be used within + // the context of an op. + // + // {0}: Unique constraint name. + // {1}: Attribute variable name. + // {2}: Attribute name. + const char *const verifyAttrUnique = R"( + if (::mlir::failed({0}(*this, {1}, "{2}"))) + return ::mlir::failure(); +)"; for (const auto &namedAttr : emit.getOp().getAttributes()) { const auto &attr = namedAttr.attr; @@ -511,7 +516,8 @@ // If the attribute's condition needs an op but none is available, then the // condition cannot be emitted. bool canEmitCondition = - !StringRef(condition).contains("$_op") || emit.isUsingOp(); + !condition.empty() && + (!StringRef(condition).contains("$_op") || emit.isUsingOp()); // Prefix with `tblgen_` to avoid hiding the attribute accessor. 
Twine varName = tblgenNamePrefix + attrName; @@ -525,15 +531,21 @@ emit.getAttr(attrName)); if (!allowMissingAttr) { - body << formatv(checkRequiredAttr, varName, emit.emitErrorPrefix(), + body << formatv(verifyRequiredAttr, varName, emit.emitErrorPrefix(), attrName); } if (canEmitCondition) { - body << formatv( - checkAttrCondition, varName, tgfmt(condition, &ctx.withSelf(varName)), - emit.emitErrorPrefix(), attrName, escapeString(attr.getSummary())); + Optional constraintFn; + if (emit.isUsingOp() && + (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { + body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); + } else { + body << formatv( + verifyAttrInline, varName, tgfmt(condition, &ctx.withSelf(varName)), + emit.emitErrorPrefix(), attrName, escapeString(attr.getSummary())); + } } - body << "}\n"; + body << " }\n"; } } @@ -2206,7 +2218,7 @@ bool hasCustomVerify = stringInit && !stringInit->getValue().empty(); populateSubstitutions(emit, verifyCtx); - genAttributeVerifier(emit, verifyCtx, body); + genAttributeVerifier(emit, verifyCtx, body, staticVerifierEmitter); genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); @@ -2235,10 +2247,38 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body, Operator::value_range values, StringRef valueKind) { + // Check that an optional value is at most 1 element. + // + // {0}: Value index. + // {1}: "operand" or "result" + const char *const verifyOptional = R"( + if (valueGroup{0}.size() > 1) { + return emitOpError("{1} group starting at #") << index + << " requires 0 or 1 element, but found " << valueGroup{0}.size(); + } +)"; + // Check the types of a range of values. + // + // {0}: Value index. + // {1}: Type constraint function. 
+ // {2}: "operand" or "result" + const char *const verifyValues = R"( + for (auto v : valueGroup{0}) { + if (::mlir::failed({1}(*this, v.getType(), "{2}", index++))) + return ::mlir::failure(); + } +)"; + + const auto canSkip = [](const NamedTypeConstraint &value) { + return !value.hasPredicate() && !value.isOptional() && + !value.isVariadicOfVariadic(); + }; + if (values.empty() || llvm::all_of(values, canSkip)) + return; + FmtContext fctx; - body << " {\n"; - body << " unsigned index = 0; (void)index;\n"; + body << " {\n unsigned index = 0; (void)index;\n"; for (auto staticValue : llvm::enumerate(values)) { const NamedTypeConstraint &value = staticValue.value(); @@ -2256,11 +2296,7 @@ // If the constraint is optional check that the value group has at most 1 // value. if (isOptional) { - body << formatv(" if (valueGroup{0}.size() > 1)\n" - " return emitOpError(\"{1} group starting at #\") " - "<< index << \" requires 0 or 1 element, but found \" << " - "valueGroup{0}.size();\n", - staticValue.index(), valueKind); + body << formatv(verifyOptional, staticValue.index(), valueKind); } else if (isVariadicOfVariadic) { body << formatv( " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" @@ -2276,92 +2312,88 @@ // Emit a loop to check all the dynamic values in the pack. StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(value.constraint); - body << " for (::mlir::Value v : valueGroup" << staticValue.index() - << ") {\n" - << " if (::mlir::failed(" << constraintFn - << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n" - << " return ::mlir::failure();\n" - << " ++index;\n" - << " }\n"; + body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); } body << " }\n"; } void OpEmitter::genRegionVerifier(OpMethodBody &body) { + /// Code to verify a region. + /// + /// {0}: Getter for the regions. + /// {1}: The region constraint. + /// {2}: The region's name. + /// {3}: The region description. 
+ const char *const verifyRegion = R"( + for (auto &region : {0}) + if (::mlir::failed({1}(*this, region, "{2}", index++))) + return ::mlir::failure(); +)"; + /// Get a single region. + /// + /// {0}: Reserved. + /// {1}: The region's index. + const char *const getSingleRegion = + "::llvm::makeMutableArrayRef((*this)->getRegion({1}))"; + // If we have no regions, there is nothing more to do. - unsigned numRegions = op.getNumRegions(); - if (numRegions == 0) + const auto canSkip = [](const NamedRegion &region) { + return region.constraint.getPredicate().isNull(); + }; + auto regions = op.getRegions(); + if (regions.empty() && llvm::all_of(regions, canSkip)) return; - body << "{\n"; - body << " unsigned index = 0; (void)index;\n"; - - for (unsigned i = 0; i < numRegions; ++i) { - const auto &region = op.getRegion(i); - if (region.constraint.getPredicate().isNull()) + body << " {\n unsigned index = 0; (void)index;\n"; + for (auto it : llvm::enumerate(regions)) { + const auto &region = it.value(); + if (canSkip(region)) continue; - body << " for (::mlir::Region &region : "; - body << formatv(region.isVariadic() - ? "{0}()" - : "::mlir::MutableArrayRef<::mlir::Region>((*this)" - "->getRegion({1}))", - op.getGetterName(region.name), i); - body << ") {\n"; - auto constraint = tgfmt(region.constraint.getConditionTemplate(), - &verifyCtx.withSelf("region")) - .str(); - - body << formatv(" (void)region;\n" - " if (!({0})) {\n " - "return emitOpError(\"region #\") << index << \" {1}" - "failed to " - "verify constraint: {2}\";\n }\n", - constraint, - region.name.empty() ? "" : "('" + region.name + "') ", - region.constraint.getSummary()) - << " ++index;\n" - << " }\n"; + auto getRegion = formatv(region.isVariadic() ?
"{0}()" : getSingleRegion, + op.getGetterName(region.name), it.index()); + auto constraintFn = + staticVerifierEmitter.getRegionConstraintFn(region.constraint); + body << formatv(verifyRegion, getRegion, constraintFn, region.name); } body << " }\n"; } void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { + const char *const verifySuccessor = R"( + for (auto *successor : {0}) + if (::mlir::failed({1}(*this, successor, "{2}", index++))) + return ::mlir::failure(); +)"; + /// Get a single successor. + /// + /// {0}: The successor's name. + const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())"; + // If we have no successors, there is nothing more to do. - unsigned numSuccessors = op.getNumSuccessors(); - if (numSuccessors == 0) + const auto canSkip = [](const NamedSuccessor &successor) { + return successor.constraint.getPredicate().isNull(); + }; + auto successors = op.getSuccessors(); + if (successors.empty() && llvm::all_of(successors, canSkip)) return; - body << "{\n"; - body << " unsigned index = 0; (void)index;\n"; + body << " {\n unsigned index = 0; (void)index;\n"; - for (unsigned i = 0; i < numSuccessors; ++i) { - const auto &successor = op.getSuccessor(i); - if (successor.constraint.getPredicate().isNull()) + for (auto it : llvm::enumerate(successors)) { + const auto &successor = it.value(); + if (canSkip(successor)) continue; - if (successor.isVariadic()) { - body << formatv(" for (::mlir::Block *successor : {0}()) {\n", - successor.name); - } else { - body << " {\n"; - body << formatv(" ::mlir::Block *successor = {0}();\n", - successor.name); - } - auto constraint = tgfmt(successor.constraint.getConditionTemplate(), - &verifyCtx.withSelf("successor")) - .str(); - - body << formatv(" (void)successor;\n" - " if (!({0})) {\n " - "return emitOpError(\"successor #\") << index << \"('{1}') " - "failed to " - "verify constraint: {2}\";\n }\n", - constraint, successor.name, - successor.constraint.getSummary()) - << " ++index;\n" - << " 
}\n"; + auto getSuccessor = + formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor, + successor.name, it.index()) + .str(); + auto constraintFn = + staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); + body << formatv(verifySuccessor, getSuccessor, constraintFn, + successor.name); } body << " }\n"; } @@ -2501,11 +2533,16 @@ // getters identical to those defined in the Op. class OpOperandAdaptorEmitter { public: - static void emitDecl(const Operator &op, raw_ostream &os); - static void emitDef(const Operator &op, raw_ostream &os); + static void emitDecl(const Operator &op, + StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); + static void emitDef(const Operator &op, + StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); private: - explicit OpOperandAdaptorEmitter(const Operator &op); + explicit OpOperandAdaptorEmitter( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter); // Add verification function. This generates a verify method for the adaptor // which verifies all the op-independent attribute constraints. 
@@ -2513,11 +2550,14 @@ const Operator &op; Class adaptor; + StaticVerifierFunctionEmitter &staticVerifierEmitter; }; } // end anonymous namespace -OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) - : op(op), adaptor(op.getAdaptorName()) { +OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter) + : op(op), adaptor(op.getAdaptorName()), + staticVerifierEmitter(staticVerifierEmitter) { adaptor.newField("::mlir::ValueRange", "odsOperands"); adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); adaptor.newField("::mlir::RegionRange", "odsRegions"); @@ -2641,17 +2681,21 @@ FmtContext verifyCtx; populateSubstitutions(emit, verifyCtx); - genAttributeVerifier(emit, verifyCtx, body); + genAttributeVerifier(emit, verifyCtx, body, staticVerifierEmitter); body << " return ::mlir::success();"; } -void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); +void OpOperandAdaptorEmitter::emitDecl( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os) { + OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); } -void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); +void OpOperandAdaptorEmitter::emitDef( + const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os) { + OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); } // Emits the opcode enum and op classes. @@ -2678,9 +2722,7 @@ // Generate all of the locally instantiated methods first. 
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os); os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); - staticVerifierEmitter.emitFunctionsFor( - typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type", - defs, emitDecl); + staticVerifierEmitter.emitLocalFunctions(defs, emitDecl); for (auto *def : defs) { Operator op(*def); @@ -2689,7 +2731,7 @@ NamespaceEmitter emitter(os, op.getCppNamespace()); os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - OpOperandAdaptorEmitter::emitDecl(op, os); + OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); OpEmitter::emitDecl(op, os, staticVerifierEmitter); } // Emit the TypeID explicit specialization to have a single definition. @@ -2700,7 +2742,7 @@ { NamespaceEmitter emitter(os, op.getCppNamespace()); os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - OpOperandAdaptorEmitter::emitDef(op, os); + OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os); OpEmitter::emitDef(op, os, staticVerifierEmitter); } // Emit the TypeID explicit specialization to have a single definition.