diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1897,6 +1897,10 @@ // Op attached regions have no arguments def NoRegionArguments : NativeOpTrait<"NoRegionArguments">; +// Runs the generated OpAdaptor's verify method as a trait verifier. Without +// this trait, that check is emitted into the op's generated verify() method. +def OpAdaptorVerifier : NativeOpTrait<"OpAdaptorVerifier">; + //===----------------------------------------------------------------------===// // OpInterface definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -291,6 +291,26 @@ } }; +//===----------------------------------------------------------------------===// +// Traits spanning Attributes, Operands & Regions + +/// This class verifies an op with its ODS-generated OpAdaptor, whose verify +/// method checks aspects of the op such as required named attributes and the +/// number of operands. The adaptor is built from the op's operands, attribute +/// dictionary and regions. +/// +/// Without this trait, the check is emitted into the op's generated verify() +/// method, after all trait verifiers; with it, it runs as a trait verifier. 
+template +class OpAdaptorVerifier : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return typename ConcreteType::Adaptor( + op->getOperands(), op->getAttrDictionary(), op->getRegions()) + .verify(op->getLoc()); + } +}; + //===----------------------------------------------------------------------===// // Operand Traits diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir --- a/mlir/test/Dialect/traits.mlir +++ b/mlir/test/Dialect/traits.mlir @@ -151,3 +151,11 @@ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> return %0 : tensor<4xi1> } + +// ----- +// The i32 operand is invalid too, yet only the 'names' error is expected. +func @adaptor_verifier(%arg0: tensor<2x2xi32>) -> tensor<4x7xi1> { + // expected-error@+1 {{op requires attribute 'names'}} + %0 = "test.op_adaptor_verifier"(%arg0) : (tensor<2x2xi32>) -> tensor<4x7xi1> + return %0 : tensor<4x7xi1> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -616,6 +616,15 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } +def OpAdaptorVerifierOp : TEST_Op<"op_adaptor_verifier", + [ + OpAdaptorVerifier, + ResultsBroadcastableShape + ]> { + let arguments = (ins TensorOf<[F32]>:$arg1, StrArrayAttr:$names); + let results = (outs AnyTensor); +} + //===----------------------------------------------------------------------===// // Test Locations //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -274,6 +274,23 @@ // CHECK: class KWithTraitOp : public ::mlir::Op { + let arguments = (ins TensorOf<[F32]>:$arg1, StrArrayAttr:$names); + let results = (outs AnyTensor); +} + +// CHECK: class OpAdaptorVerifierOp +// CHECK: 
::mlir::OpTrait::OpAdaptorVerifier, ::mlir::OpTrait::ResultsBroadcastableShape +// DEFS: ::mlir::LogicalResult OpAdaptorVerifierOp::verify() { +// DEFS-NOT: if (failed(OpAdaptorVerifierOpAdaptor + // Test that type defs have the proper namespaces when used as a constraint. // --- @@ -322,3 +339,4 @@ // REDUCE_EXC-NOT: NS::AOp declarations // REDUCE_EXC-LABEL: NS::BOp declarations + diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1950,9 +1950,13 @@ void OpEmitter::genVerifier() { auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify"); auto &body = method->body(); - body << " if (failed(" << op.getAdaptorName() - << "(*this).verify((*this)->getLoc()))) " - << "return ::mlir::failure();\n"; + // Emit the adaptor verify call only for ops without the OpAdaptorVerifier + // trait; with the trait, the same check already runs as a trait verifier. + if (op.getTrait("::mlir::OpTrait::OpAdaptorVerifier") == nullptr) { + body << " if (failed(" << op.getAdaptorName() + << "(*this).verify((*this)->getLoc()))) " + << "return ::mlir::failure();\n"; + } auto *valueInit = def.getValueInit("verifier"); StringInit *stringInit = dyn_cast(valueInit);