diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -564,14 +564,36 @@ _additional_ verification, you can use ```tablegen -let verifier = [{ +let earlyVerifier = [{ + ... +}]; + +let lateVerifier = [{ ... }]; ``` -Code placed in `verifier` will be called after the auto-generated verification -code. The order of trait verification excluding those of `verifier` should not -be relied upon. +#### Verification Ordering + +The verification of an operation involves several steps: + +1. Operation structure verification, e.g., the number of operands. +1. Trait/Interface verification. +1. Invoking the custom verifier. + +The operation structure verification is done by several core traits such as +`OneOperand` and by the generated method `verifyInvariants`, which verifies the +types, attributes, etc. Once structure verification succeeds, it is guaranteed +that basic accesses to the operation's fields are safe, i.e., there is no need +to check them for null. The custom verifier can then perform further checks that +build on the properties already verified by the traits/interfaces. + +Because an operation may contain nested operations, some verifiers depend on +the verification results of the nested operations while others don't. +As a result, custom verifiers fall into two categories: early and late. +`earlyVerifier` is invoked before the nested operations are verified, and +`lateVerifier` is invoked after the nested operations are verified. Note +that it's valid to define both of them. ### Declarative Assembly Format diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -36,24 +36,29 @@ }; ``` -Operation traits may also provide a `verifyTrait` hook, that is called when -verifying the concrete operation. The trait verifiers will currently always be -invoked before the main `Op::verify`.
+Operation traits may also provide a verification hook by defining +`earlyVerifyTrait` or `lateVerifyTrait`; these are called when verifying the +concrete operation. `earlyVerifyTrait` is called before the nested operations +are verified, and `lateVerifyTrait` is called after the nested operations +are verified. + +Note that the legacy `verifyTrait` method is going to be deprecated; use the two +methods mentioned above instead. ```c++ template class MyTrait : public OpTrait::TraitBase { public: - /// Override the 'verifyTrait' hook to add additional verification on the + /// Override the 'earlyVerifyTrait' hook to add additional verification on the /// concrete operation. - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult earlyVerifyTrait(Operation *op) { // ... } }; ``` Note: It is generally good practice to define the implementation of the -`verifyTrait` hook out-of-line as a free function when possible to avoid +verification hook out-of-line as a free function when possible to avoid instantiating the implementation for every concrete operation type. Operation traits may also provide a `foldTrait` hook that is called when folding diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2169,10 +2169,17 @@ // interfaces specified for ops in TableGen less alien and more integrated. class OpInterfaceTrait traits = []> : InterfaceTrait, OpTrait { - // Specify the body of the verification function. `$_op` will be replaced with - // the operation being verified. + // TODO: This is going to be deprecated. Use `earlyVerifier`/`lateVerifier` + // instead. code verify = verifyBody; + // Specify the body of the verification function. `$_op` will be replaced with + // the operation being verified. `early-` means it needs to be called before + // nested operations. `late-` means it needs to be called after nested + // operations.
Note that it's valid to define both of them. + code earlyVerifier = ?; + code lateVerifier = ?; + // Specify the list of trait verifiers that need to be run before the verifier // of this OpInterfaceTrait. list dependentTraits = traits; @@ -2430,8 +2437,18 @@ string assemblyFormat = ?; // Custom verifier. + // TODO: This is going to be deprecated. Use `earlyVerifier`/`lateVerifier` + // instead. code verifier = ?; + // Custom verifier with `early-` and `late-` prefix to indicate when this + // verifier should be invoked. `early-` means it needs to be called before + // the verification of nested operations. `late-` means it needs to be called + // after the verification of nested operations are done. It's valid to define + // both verifiers. + code earlyVerifier = ?; + code lateVerifier = ?; + // Whether this op has associated canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -356,7 +356,7 @@ template class ZeroOperands : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyZeroOperands(op); } @@ -375,7 +375,7 @@ void setOperand(Value value) { this->getOperation()->setOperand(0, value); } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOneOperand(op); } }; @@ -394,7 +394,7 @@ class Impl : public detail::MultiOperandTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyNOperands(op, N); } }; @@ -412,7 +412,7 @@ class Impl : public detail::MultiOperandTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return 
impl::verifyAtLeastNOperands(op, N); } }; @@ -432,7 +432,7 @@ template class ZeroRegion : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyZeroRegion(op); } }; @@ -473,7 +473,7 @@ return getRegion().template getOps(); } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOneRegion(op); } }; @@ -489,7 +489,7 @@ class Impl : public detail::MultiRegionTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyNRegions(op, N); } }; @@ -504,7 +504,7 @@ class Impl : public detail::MultiRegionTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyAtLeastNRegions(op, N); } }; @@ -524,7 +524,7 @@ template class ZeroResult : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyZeroResult(op); } }; @@ -598,7 +598,7 @@ this->getOperation()->replaceAllUsesWith(op); } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOneResult(op); } }; @@ -637,7 +637,7 @@ class Impl : public detail::MultiResultTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyNResults(op, N); } }; @@ -655,7 +655,7 @@ class Impl : public detail::MultiResultTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -679,7 +679,7 @@ template class IsTerminator : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + 
static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyIsTerminator(op); } }; @@ -689,7 +689,7 @@ template class ZeroSuccessor : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyZeroSuccessor(op); } }; @@ -733,7 +733,7 @@ this->getOperation()->setSuccessor(succ, 0); } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOneSuccessor(op); } }; @@ -749,7 +749,7 @@ class Impl : public detail::MultiSuccessorTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyNSuccessors(op, N); } }; @@ -765,7 +765,7 @@ : public detail::MultiSuccessorTraitBase::Impl> { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyAtLeastNSuccessors(op, N); } }; @@ -786,7 +786,7 @@ template struct SingleBlock : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { Region ®ion = op->getRegion(i); @@ -882,8 +882,8 @@ /// The type of the operation used as the implicit terminator type. 
using ImplicitTerminatorOpT = TerminatorOpType; - static LogicalResult verifyTrait(Operation *op) { - if (failed(Base::verifyTrait(op))) + static LogicalResult verifyCoreTrait(Operation *op) { + if (failed(Base::verifyCoreTrait(op))) return failure(); for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { Region ®ion = op->getRegion(i); @@ -986,7 +986,7 @@ template class SameOperandsShape : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameOperandsShape(op); } }; @@ -998,7 +998,7 @@ class SameOperandsAndResultShape : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -1010,7 +1010,7 @@ class SameOperandsElementType : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameOperandsElementType(op); } }; @@ -1022,7 +1022,7 @@ class SameOperandsAndResultElementType : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameOperandsAndResultElementType(op); } }; @@ -1036,7 +1036,7 @@ class SameOperandsAndResultType : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -1046,7 +1046,7 @@ template class ResultsAreBoolLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -1057,7 +1057,7 @@ class ResultsAreFloatLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) 
{ return impl::verifyResultsAreFloatLike(op); } }; @@ -1068,7 +1068,7 @@ class ResultsAreSignlessIntegerLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyResultsAreSignlessIntegerLike(op); } }; @@ -1082,7 +1082,7 @@ template class IsInvolution : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert(ConcreteType::template hasTrait(), "expected operation to produce one result"); static_assert(ConcreteType::template hasTrait(), @@ -1105,7 +1105,7 @@ template class IsIdempotent : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert(ConcreteType::template hasTrait(), "expected operation to produce one result"); static_assert(ConcreteType::template hasTrait() || @@ -1129,7 +1129,7 @@ class OperandsAreFloatLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOperandsAreFloatLike(op); } }; @@ -1140,7 +1140,7 @@ class OperandsAreSignlessIntegerLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyOperandsAreSignlessIntegerLike(op); } }; @@ -1150,7 +1150,7 @@ template class SameTypeOperands : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifySameTypeOperands(op); } }; @@ -1161,7 +1161,7 @@ template class ConstantLike : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert(ConcreteType::template hasTrait(), "expected operation to produce one result"); 
static_assert(ConcreteType::template hasTrait(), @@ -1180,7 +1180,7 @@ class IsIsolatedFromAbove : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return impl::verifyIsIsolatedFromAbove(op); } }; @@ -1193,7 +1193,7 @@ template class AffineScope : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert(!ConcreteType::template hasTrait(), "expected operation to have one or more regions"); return success(); @@ -1209,7 +1209,7 @@ class AutomaticAllocationScope : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert(!ConcreteType::template hasTrait(), "expected operation to have one or more regions"); return success(); @@ -1223,7 +1223,7 @@ template class Impl : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { if (llvm::isa(op->getParentOp())) return success(); @@ -1254,7 +1254,7 @@ return "operand_segment_sizes"; } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return ::mlir::OpTrait::impl::verifyOperandSizeAttr( op, getOperandSegmentSizeAttr()); } @@ -1267,7 +1267,7 @@ public: static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; } - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return ::mlir::OpTrait::impl::verifyResultSizeAttr( op, getResultSegmentSizeAttr()); } @@ -1277,7 +1277,7 @@ /// not have any arguments template struct NoRegionArguments : public TraitBase { - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); } }; @@ -1325,7 +1325,7 @@ /// an 
`Elementwise` op. template struct Elementwise : public TraitBase { - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { return ::mlir::OpTrait::impl::verifyElementwise(op); } }; @@ -1354,7 +1354,7 @@ /// ``` template struct Scalarizable : public TraitBase { - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert( ConcreteType::template hasTrait(), "`Scalarizable` trait is only applicable to `Elementwise` ops."); @@ -1374,7 +1374,7 @@ /// broadcasting in cases like `%select_scalar_pred` below. template struct Vectorizable : public TraitBase { - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert( ConcreteType::template hasTrait(), "`Vectorizable` trait is only applicable to `Elementwise` ops."); @@ -1415,7 +1415,7 @@ /// ``` template struct Tensorizable : public TraitBase { - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyCoreTrait(Operation *op) { static_assert( ConcreteType::template hasTrait(), "`Tensorizable` trait is only applicable to `Elementwise` ops."); @@ -1547,6 +1547,30 @@ template using detect_has_verify_trait = llvm::is_detected; +/// Trait to check if T provides a `verifyCoreTrait` method. +template +using has_verify_core_trait = + decltype(T::verifyCoreTrait(std::declval())); +template +using detect_has_verify_core_trait = + llvm::is_detected; + +/// Trait to check if T provides a `earlyVerifyTrait` method. +template +using has_early_verify_trait = + decltype(T::earlyVerifyTrait(std::declval())); +template +using detect_has_early_verify_trait = + llvm::is_detected; + +/// Trait to check if T provides a `lateVerifyTrait` method. 
+template +using has_late_verify_trait = + decltype(T::lateVerifyTrait(std::declval())); +template +using detect_has_late_verify_trait = + llvm::is_detected; + /// The internal implementation of `verifyTraits` below that returns the result /// of verifying the current operation with all of the provided trait types /// `Ts`. @@ -1558,6 +1582,39 @@ return result; } +/// The internal implementation of `verifyCoreTraits` below that returns the +/// result of verifying the current operation with all of the provided trait +/// types `Ts`. +template +static LogicalResult verifyCoreTraitsImpl(Operation *op, std::tuple *) { + LogicalResult result = success(); + (void)std::initializer_list{ + (result = succeeded(result) ? Ts::verifyCoreTrait(op) : failure(), 0)...}; + return result; +} + +/// The internal implementation of `earlyVerifyTraits` below that returns the +/// result of verifying the current operation with all of the provided trait +/// types `Ts`. +template +static LogicalResult earlyVerifyTraitsImpl(Operation *op, std::tuple *) { + LogicalResult result = success(); + (void)std::initializer_list{( + result = succeeded(result) ? Ts::earlyVerifyTrait(op) : failure(), 0)...}; + return result; +} + +/// The internal implementation of `lateVerifyTraits` below that returns the +/// result of verifying the current operation with all of the provided trait +/// types `Ts`. +template +static LogicalResult lateVerifyTraitsImpl(Operation *op, std::tuple *) { + LogicalResult result = success(); + (void)std::initializer_list{ + (result = succeeded(result) ? Ts::lateVerifyTrait(op) : failure(), 0)...}; + return result; +} + /// Given a tuple type containing a set of traits that contain a /// `verifyTrait` method, return the result of verifying the given operation. 
template @@ -1565,6 +1622,33 @@ return verifyTraitsImpl(op, (TraitTupleT *)nullptr); } +/// Given a tuple type containing a set of traits that contain a +/// `verifyCoreTrait` method, return the result of verifying the given +/// operation. +/// `verifyCoreTrait` can only be used in core MLIR traits to define the +/// verifier. It's supposed to be able to run independently, i.e., it doesn't +/// depend on any traits to be held. +template +static LogicalResult verifyCoreTraits(Operation *op) { + return verifyCoreTraitsImpl(op, (TraitTupleT *)nullptr); +} + +/// Given a tuple type containing a set of traits that contain a +/// `earlyVerifyTrait` method, return the result of verifying the given +/// operation. +template +static LogicalResult earlyVerifyTraits(Operation *op) { + return earlyVerifyTraitsImpl(op, (TraitTupleT *)nullptr); +} + +/// Given a tuple type containing a set of traits that contain a +/// `lateVerifyTrait` method, return the result of verifying the given +/// operation. +template +static LogicalResult lateVerifyTraits(Operation *op) { + return lateVerifyTraitsImpl(op, (TraitTupleT *)nullptr); +} + /// A trait verifier may specify a set of trait verifiers that need to /// be run before itself. The dependent traits are labeled in the trailing /// template arguments of TraitBase. The declaration order of traits in an Op @@ -1770,6 +1854,13 @@ info->attachInterface(); } + LogicalResult verify() const { return success(); } + /// This is used to verify the things specified in ODS such as the types, + /// attrs, .etc. + LogicalResult verifyInvariants() const { return success(); } + LogicalResult earlyVerify() const { return success(); } + LogicalResult lateVerify() const { return success(); } + private: /// Trait to check if T provides a 'fold' method for a single result op. 
template @@ -1800,6 +1891,16 @@ typename detail::FilterTypes...>::type; + using VerifiableCoreTraitsTupleT = typename detail::FilterTypes< + op_definition_impl::detect_has_verify_core_trait, + Traits...>::type; + using EarlyVerifiableTraitsTupleT = typename detail::FilterTypes< + op_definition_impl::detect_has_early_verify_trait, + Traits...>::type; + using LateVerifiableTraitsTupleT = typename detail::FilterTypes< + op_definition_impl::detect_has_late_verify_trait, + Traits...>::type; + /// Returns an interface map containing the interfaces registered to this /// operation. static detail::InterfaceMap getInterfaceMap() { @@ -1928,9 +2029,13 @@ OpState::printOpName(op, p, defaultDialect); return cast(op).print(p); } - /// Implementation of `VerifyInvariantsFn` OperationName hook. - static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { - return &verifyInvariants; + + static OperationName::VerificationFn getEarlyVerifyFn() { + return &earlyVerifyOps; + } + + static OperationName::VerificationFn getLateVerifyFn() { + return &lateVerifyOps; } static constexpr bool hasNoDataMembers() { @@ -1940,12 +2045,42 @@ return sizeof(ConcreteType) == sizeof(EmptyOp); } - static LogicalResult verifyInvariants(Operation *op) { + static LogicalResult earlyVerifyOps(Operation *op) { static_assert(hasNoDataMembers(), "Op class shouldn't define new data members"); return failure( + // This will verify the operation structure such as number of operands. + // After this, the access to the field of operation is supposed to be + // safe. + failed(op_definition_impl::verifyCoreTraits( + op)) || + // This will verify the properties specified in ODS, such as types, + // attrs, .etc. + failed(cast(op).verifyInvariants()) || + // TODO: This is going to be deprecated. failed(op_definition_impl::verifyTraits(op)) || - failed(cast(op).verify())); + // The traits/interfaces which specify a verifier that needs to be run + // before nested operations. 
+ failed( + op_definition_impl::earlyVerifyTraits( + op)) || + // TODO: This calls the user defined verifier. It's going to be + // deprecated. + failed(cast(op).verify()) || + // The custom verifier which needs to be run before the nested + // operation. + failed(cast(op).earlyVerify())); + } + + static LogicalResult lateVerifyOps(Operation *op) { + return failure( + // The traits/interfaces which specify a verifier that needs to be run + // after nested operations. + failed(op_definition_impl::lateVerifyTraits( + op)) || + // The custom verifier which needs to be run after the nested + // operations. + failed(cast(op).lateVerify())); } /// Allow access to internal implementation methods. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -73,7 +73,7 @@ llvm::unique_function; using PrintAssemblyFn = llvm::unique_function; - using VerifyInvariantsFn = + using VerificationFn = llvm::unique_function; protected: @@ -113,7 +113,8 @@ HasTraitFn hasTraitFn; ParseAssemblyFn parseAssemblyFn; PrintAssemblyFn printAssemblyFn; - VerifyInvariantsFn verifyInvariantsFn; + VerificationFn earlyVerifyFn; + VerificationFn lateVerifyFn; /// A list of attribute names registered to this operation in StringAttr /// form. 
This allows for operation classes to use StringAttr for attribute @@ -242,7 +243,7 @@ static void insert(Dialect &dialect) { insert(T::getOperationName(), dialect, TypeID::get(), T::getParseAssemblyFn(), T::getPrintAssemblyFn(), - T::getVerifyInvariantsFn(), T::getFoldHookFn(), + T::getEarlyVerifyFn(), T::getLateVerifyFn(), T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames()); } @@ -251,7 +252,8 @@ static void insert(StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, - VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, + VerificationFn &&earyVerifyFn, VerificationFn &&lateVerifyFn, + FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, ArrayRef attrNames); @@ -276,11 +278,14 @@ return impl->printAssemblyFn(op, p, defaultDialect); } - /// This hook implements the verifier for this operation. It should emits an + /// These hooks implement the verifier for this operation. It should emits an /// error message and returns failure if a problem is detected, or returns /// success if everything is ok. - LogicalResult verifyInvariants(Operation *op) const { - return impl->verifyInvariantsFn(op); + LogicalResult earlyVerification(Operation *op) const { + return impl->earlyVerifyFn(op); + } + LogicalResult lateVerification(Operation *op) const { + return impl->lateVerifyFn(op); } /// This hook implements a generalized folder for this operation. 
Operations diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -337,7 +337,7 @@ template class SymbolTable : public TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult lateVerifyTrait(Operation *op) { return ::mlir::detail::verifySymbolTable(op); } diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -15,6 +15,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" #include namespace llvm { @@ -67,7 +68,7 @@ // After splicing, verify just this operation to ensure it can properly // contain the operations inside of it. - if (failed(op.verify())) + if (failed(verify(op))) return OwningOpRef(); return opRef; } diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -95,6 +95,10 @@ // Return the verify method body if it has one. llvm::Optional getVerify() const; + // Return the {early|late} verify method body if it has one. + llvm::Optional getEarlyVerifier() const; + llvm::Optional getLateVerifier() const; + llvm::ArrayRef getDependentTraits() const; // Returns the Tablegen definition this interface was constructed from. 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -677,7 +677,8 @@ void RegisteredOperationName::insert( StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, - VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, + VerificationFn &&earlyVerifyFn, VerificationFn &&lateVerifyFn, + FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, ArrayRef attrNames) { @@ -720,7 +721,8 @@ impl.hasTraitFn = std::move(hasTrait); impl.parseAssemblyFn = std::move(parseAssembly); impl.printAssemblyFn = std::move(printAssembly); - impl.verifyInvariantsFn = std::move(verifyInvariants); + impl.earlyVerifyFn = std::move(earlyVerifyFn); + impl.lateVerifyFn = std::move(lateVerifyFn); impl.attributeNames = cachedAttrNames; } diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -178,7 +178,7 @@ // If we can get operation info for this, check the custom hook. OperationName opName = op.getName(); Optional registeredInfo = opName.getRegisteredInfo(); - if (registeredInfo && failed(registeredInfo->verifyInvariants(&op))) + if (registeredInfo && failed(registeredInfo->earlyVerification(&op))) return failure(); if (unsigned numRegions = op.getNumRegions()) { @@ -218,8 +218,11 @@ } // If this is a registered operation, there is nothing left to do. - if (registeredInfo) + if (registeredInfo) { + if (failed(registeredInfo->lateVerification(&op))) + return failure(); return success(); + } // Otherwise, verify that the parent dialect allows un-registered operations. 
Dialect *dialect = opName.getDialect(); diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -145,6 +145,22 @@ return value.empty() ? llvm::Optional() : value; } +// Return the body for this method if it has one. +llvm::Optional Interface::getEarlyVerifier() const { + // Only OpInterface supports the verify method. + if (!isa(this)) + return llvm::None; + return def->getValueAsOptionalString("earlyVerifier"); +} + +// Return the body for this method if it has one. +llvm::Optional Interface::getLateVerifier() const { + // Only OpInterface supports the verify method. + if (!isa(this)) + return llvm::None; + return def->getValueAsOptionalString("lateVerifier"); +} + llvm::ArrayRef Interface::getDependentTraits() const { return dependentTraits; } diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir --- a/mlir/test/Dialect/Arithmetic/invalid.mlir +++ b/mlir/test/Dialect/Arithmetic/invalid.mlir @@ -285,7 +285,7 @@ // ----- func @index_cast_float(%arg0: index, %arg1: f32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}} %0 = arith.index_cast %arg0 : index to f32 return } @@ -293,7 +293,7 @@ // ----- func @index_cast_float_to_index(%arg0: f32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}} %0 = arith.index_cast %arg0 : f32 to index return } @@ -301,7 +301,7 @@ // ----- func @sitofp_i32_to_i64(%arg0 : i32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i64'}} %0 = arith.sitofp %arg0 : i32 to i64 return } @@ -309,7 +309,7 @@ // ----- func @sitofp_f32_to_i32(%arg0 : f32) { - // expected-error@+1 {{are cast 
incompatible}} + // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'f32'}} %0 = arith.sitofp %arg0 : f32 to i32 return } @@ -333,7 +333,7 @@ // ----- func @fpext_i32_to_f32(%arg0 : i32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}} %0 = arith.extf %arg0 : i32 to f32 return } @@ -341,7 +341,7 @@ // ----- func @fpext_f32_to_i32(%arg0 : f32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}} %0 = arith.extf %arg0 : f32 to i32 return } @@ -373,7 +373,7 @@ // ----- func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}} %0 = arith.extf %arg0 : vector<2xi32> to vector<2xf32> return } @@ -381,7 +381,7 @@ // ----- func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}} %0 = arith.extf %arg0 : vector<2xf32> to vector<2xi32> return } @@ -405,7 +405,7 @@ // ----- func @fptrunc_i32_to_f32(%arg0 : i32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}} %0 = arith.truncf %arg0 : i32 to f32 return } @@ -413,7 +413,7 @@ // ----- func @fptrunc_f32_to_i32(%arg0 : f32) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}} %0 = arith.truncf %arg0 : f32 to i32 return } @@ -445,7 +445,7 @@ // ----- func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}} %0 = arith.truncf %arg0 : vector<2xi32> to vector<2xf32> 
return } @@ -453,7 +453,7 @@ // ----- func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}} %0 = arith.truncf %arg0 : vector<2xf32> to vector<2xi32> return } @@ -461,7 +461,7 @@ // ----- func @sexti_index_as_operand(%arg0 : index) { - // expected-error@+1 {{are cast incompatible}} + // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}} %0 = arith.extsi %arg0 : index to i128 return } @@ -469,7 +469,7 @@ // ----- func @zexti_index_as_operand(%arg0 : index) { - // expected-error@+1 {{operand type 'index' and result type}} + // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}} %0 = arith.extui %arg0 : index to i128 return } @@ -477,7 +477,7 @@ // ----- func @trunci_index_as_operand(%arg0 : index) { - // expected-error@+1 {{operand type 'index' and result type}} + // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}} %2 = arith.trunci %arg0 : index to i128 return } @@ -485,7 +485,7 @@ // ----- func @sexti_index_as_result(%arg0 : i1) { - // expected-error@+1 {{result type 'index' are cast incompatible}} + // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}} %0 = arith.extsi %arg0 : i1 to index return } @@ -493,7 +493,7 @@ // ----- func @zexti_index_as_operand(%arg0 : i1) { - // expected-error@+1 {{result type 'index' are cast incompatible}} + // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}} %0 = arith.extui %arg0 : i1 to index return } @@ -501,7 +501,7 @@ // ----- func @trunci_index_as_result(%arg0 : i128) { - // expected-error@+1 {{result type 'index' are cast incompatible}} + // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}} %2 = arith.trunci %arg0 : i128 to 
index return } diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -80,8 +80,8 @@ // ----- -// expected-error @+1 {{requires string attribute 'sym_name'}} -"llvm.mlir.global"() ({}) {type = i64, constant, value = 42 : i64} : () -> () +// expected-error @+1 {{op requires attribute 'sym_name'}} +"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> () // ----- diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -214,15 +214,15 @@ // ----- -func @generic_scalar_operand_block_arg_type(%arg0: f32) { +func @generic_scalar_operand_block_arg_type(%arg0: tensor) { // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}} linalg.generic { indexing_maps = [ affine_map<() -> ()> ], iterator_types = []} - outs(%arg0 : f32) { + outs(%arg0 : tensor) { ^bb(%i: i1): linalg.yield %i : i1 - } + } -> tensor } // ----- @@ -243,7 +243,7 @@ func @generic_result_tensor_type(%arg0: memref(off + i)>>, %arg1: tensor) { - // expected-error @+1 {{expected type of operand #1 ('tensor') to match type of corresponding result ('f32')}} + // expected-error @+1 {{expected type of operand #1 ('tensor') to match type of corresponding result ('tensor')}} %0 = linalg.generic { indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ], iterator_types = ["parallel"]} @@ -251,7 +251,7 @@ outs(%arg1 : tensor) { ^bb(%i: f32, %j: f32): linalg.yield %i: f32 - } -> f32 + } -> tensor } // ----- @@ -427,11 +427,11 @@ // ----- -func @illegal_fill_memref_with_return(%arg0 : memref, %arg1 : f32) -> memref +func @illegal_fill_memref_with_return(%arg0 : memref, %arg1 : f32) -> tensor { - // expected-error @+1 {{expected the number of results (1) 
to be equal to the number of output tensors (0)}} - %0 = linalg.fill(%arg1, %arg0) : f32, memref -> memref - return %0 : memref + // expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}} + %0 = linalg.fill(%arg1, %arg0) : f32, memref -> tensor + return %0 : tensor } // ----- @@ -449,7 +449,7 @@ func @illegal_fill_tensor_with_memref_return (%arg0 : tensor, %arg1 : f32) -> memref { - // expected-error @+1 {{expected type of operand #1 ('tensor') to match type of corresponding result ('memref')}} + // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref'}} %0 = linalg.fill(%arg1, %arg0) : f32, tensor -> memref return %0 : memref } diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -98,7 +98,7 @@ // ----- func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'strides'}} + // expected-error @+1 {{op requires attribute 'strides'}} linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>} ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) @@ -108,7 +108,7 @@ // ----- func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'dilations'}} + // expected-error @+1 {{op requires attribute 'dilations'}} linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>} ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) @@ -118,7 +118,7 @@ // ----- func 
@depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}} + // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}} linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>} ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) @@ -128,7 +128,7 @@ // ----- func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}} + // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}} linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> } ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -59,7 +59,7 @@ // ----- func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> { - // expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}} + // expected-error @+1 {{failed to verify that all of {base, result} have same type}} %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32> spv.ReturnValue %0 : vector<4xi32> } diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir 
--- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -98,8 +98,8 @@ // ----- func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) { - // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}} - %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32> + // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}} + %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex> return } diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir --- a/mlir/test/Dialect/traits.mlir +++ b/mlir/test/Dialect/traits.mlir @@ -58,7 +58,7 @@ // Check incompatible vector and tensor result type func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> { ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): - // expected-error @+1 {{cannot broadcast vector with tensor}} + // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}} %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> return %0 : vector<4xf32> } diff --git a/mlir/test/IR/invalid-module-op.mlir b/mlir/test/IR/invalid-module-op.mlir --- a/mlir/test/IR/invalid-module-op.mlir +++ b/mlir/test/IR/invalid-module-op.mlir @@ -3,7 +3,7 @@ // ----- func @module_op() { - // expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}} + // expected-error@+1 {{'builtin.module' op expects region #0 to have 0 or 1 blocks}} builtin.module { ^bb1: "test.dummy"() : () -> () diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -332,12 +332,12 @@ // Test the invariants of operations with the Symbol Trait. 
-// expected-error@+1 {{requires string attribute 'sym_name'}} +// expected-error@+1 {{op requires attribute 'sym_name'}} "test.symbol"() {} : () -> () // ----- -// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}} +// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}} "test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> () // ----- @@ -363,8 +363,8 @@ // ----- -// Test that operation with the SymbolTable Trait fails with too many blocks. -// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}} +// Test that operation with the SymbolTable Trait fails with too many blocks. +// expected-error@+1 {{op expects region #0 to have 0 or 1 blocks}} "test.symbol_scope"() ({ ^entry: "test.finish" () : () -> () @@ -599,4 +599,4 @@ // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}} "test.attr_with_trait"() {attr = 42 : i32} : () -> () return -} \ No newline at end of file +} diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -63,7 +63,7 @@ // Test verify method // --- -// DEF: ::mlir::LogicalResult AOpAdaptor::verify +// DEF: ::mlir::LogicalResult AOpAdaptor::verifyInvariants // DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); // DEF-NEXT: if (!tblgen_aAttr) // DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); @@ -177,7 +177,7 @@ // Test verify method // --- -// DEF: ::mlir::LogicalResult AgetOpAdaptor::verify +// DEF: ::mlir::LogicalResult AgetOpAdaptor::verifyInvariants // DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); // DEF-NEXT: if (!tblgen_aAttr) // DEF-NEXT. 
return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'"); @@ -270,7 +270,7 @@ // Test common attribute kinds' constraints // --- -// DEF-LABEL: BOpAdaptor::verify +// DEF-LABEL: BOpAdaptor::verifyInvariants // DEF: if (tblgen_any_attr && !((true))) // DEF: if (tblgen_bool_attr && !((tblgen_bool_attr.isa<::mlir::BoolAttr>()))) // DEF: if (tblgen_i32_attr && !(((tblgen_i32_attr.isa<::mlir::IntegerAttr>())) && ((tblgen_i32_attr.cast<::mlir::IntegerAttr>().getType().isSignlessInteger(32))))) diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -41,6 +41,8 @@ let parser = [{ foo }]; let printer = [{ bar }]; let verifier = [{ baz }]; + let earlyVerifier = [{ earlyBaz }]; + let lateVerifier = [{ lateBaz }]; let hasCanonicalizer = 1; let hasFolder = 1; @@ -96,7 +98,10 @@ // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); +// CHECK: ::mlir::LogicalResult verifyInvariants(); // CHECK: ::mlir::LogicalResult verify(); +// CHECK: ::mlir::LogicalResult earlyVerify(); +// CHECK: ::mlir::LogicalResult lateVerify(); // CHECK: static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); // CHECK: ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results); // CHECK: // Display a graph for debugging purposes. 
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -58,7 +58,7 @@ // ----- func @complex_f64_failure() { - // expected-error@+1 {{op inferred type(s) 'complex' are incompatible with return type(s) of operation 'f64'}} + // expected-error@+1 {{op result #0 must be complex type with 64-bit float elements, but got 'f64'}} "test.complex_f64"() : () -> (f64) return } @@ -438,7 +438,7 @@ // ----- func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { - // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<*xf32>'}} + // expected-error@+1 {{op failed to verify that all of {x, res} have same type}} "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32> return } @@ -446,7 +446,7 @@ // ----- func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) { - // expected-error@+1 {{op inferred type(s) 'tensor<1x2xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}} + // expected-error@+1 {{op failed to verify that all of {x, res} have same type}} "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32> return } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -396,6 +396,9 @@ // Generates verify method for the operation. void genVerifier(); + // Generates custom verify method for the operation. + void genCustomVerifier(); + // Generates verify statements for operands and results in the operation. // The generated code will be attached to `body`. 
void genOperandResultVerifier(MethodBody &body, Operator::value_range values, @@ -586,6 +589,7 @@ genParser(); genPrinter(); genVerifier(); + genCustomVerifier(); genCanonicalizerDecls(); genFolderDecls(); genTypeInterfaceMethods(); @@ -2199,17 +2203,15 @@ } } +// TODO: Change the function name from 'verify' to 'verifyInvariant'. void OpEmitter::genVerifier() { - auto *method = opClass.addMethod("::mlir::LogicalResult", "verify"); - ERROR_IF_PRUNED(method, "verify", op); + auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants"); + ERROR_IF_PRUNED(method, "verifyInvariants", op); auto &body = method->body(); OpOrAdaptorHelper emitHelper(op, /*isOp=*/true); genNativeTraitAttrVerifier(body, emitHelper); - auto *valueInit = def.getValueInit("verifier"); - StringInit *stringInit = dyn_cast(valueInit); - bool hasCustomVerify = stringInit && !stringInit->getValue().empty(); populateSubstitutions(emitHelper, verifyCtx); genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); @@ -2228,14 +2230,29 @@ genRegionVerifier(body); genSuccessorVerifier(body); - if (hasCustomVerify) { + body << " return ::mlir::success();\n"; +} + +void OpEmitter::genCustomVerifier() { + auto emitCustomVerifier = [&](StringRef verifier, StringRef methodName) { + StringInit *stringInit = dyn_cast(def.getValueInit(verifier)); + if (!stringInit || stringInit->getValue().empty()) { + return; + } + + auto *method = opClass.addMethod("::mlir::LogicalResult", methodName); + ERROR_IF_PRUNED(method, methodName, op); + FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r"); - body << " " << tgfmt(printer, &fctx); - } else { - body << " return ::mlir::success();\n"; - } + method->body() << " " << tgfmt(printer, &fctx); + }; + + // TODO: `verifier` is going to be deprecated. 
+ emitCustomVerifier("verifier", "verify"); + emitCustomVerifier("earlyVerifier", "earlyVerify"); + emitCustomVerifier("lateVerifier", "lateVerify"); } void OpEmitter::genOperandResultVerifier(MethodBody &body, @@ -2665,9 +2682,9 @@ } void OpOperandAdaptorEmitter::addVerification() { - auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify", + auto *method = adaptor.addMethod("::mlir::LogicalResult", "verifyInvariants", MethodParameter("::mlir::Location", "loc")); - ERROR_IF_PRUNED(method, "verify", op); + ERROR_IF_PRUNED(method, "verifyInvariants", op); auto &body = method->body(); OpOrAdaptorHelper emitHelper(op, /*isOp=*/false); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -388,10 +388,13 @@ bool isOpInterface = isa(interface); for (auto &method : interface.getMethods()) { // Flag interface methods named verifyTrait. - if (method.getName() == "verifyTrait") + if (method.getName() == "verifyTrait" || + method.getName() == "earlyVerifyTrait" || + method.getName() == "lateVerifyTrait") PrintFatalError( formatv("'verifyTrait' method cannot be specified as interface " - "method for '{0}'; use the 'verify' field instead", + "method for '{0}'; use the 'earlyVerifier' or 'lateVerifier' " + "field instead", interfaceName)); auto defaultImpl = method.getDefaultImplementation(); if (!defaultImpl) @@ -414,6 +417,26 @@ "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; } + if (auto verify = interface.getEarlyVerifier()) { + assert(isa(interface) && + "only OpInterface supports 'earlyVerifier'"); + + tblgen::FmtContext verifyCtx; + verifyCtx.withOp("op"); + os << " static ::mlir::LogicalResult earlyVerifyTrait(::mlir::Operation " + "*op) {\n " + << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; + } + if (auto verify = interface.getLateVerifier()) { + assert(isa(interface) && + "only 
OpInterface supports 'lateVerifier'"); + + tblgen::FmtContext verifyCtx; + verifyCtx.withOp("op"); + os << " static ::mlir::LogicalResult lateVerifyTrait(::mlir::Operation " + "*op) {\n " + << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; + } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";