diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -303,6 +303,17 @@ } }; +/// Op trait that verifies an op through its `ConcreteType::Adaptor`: the adaptor is built from the op's operands, attribute dictionary, and regions, and `verify` is called on it with the op's location. +template <typename ConcreteType> +class OpAdaptorVerifier : public TraitBase<ConcreteType, OpAdaptorVerifier> { +public: + static LogicalResult verifyTrait(Operation *op) { + return typename ConcreteType::Adaptor( + op->getOperands(), op->getAttrDictionary(), op->getRegions()) + .verify(op->getLoc()); + } +}; + //===----------------------------------------------------------------------===// // Operand Traits diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir --- a/mlir/test/Dialect/traits.mlir +++ b/mlir/test/Dialect/traits.mlir @@ -151,3 +151,11 @@ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> return %0 : tensor<4xi1> } + +// ----- + +func @adaptor_verifier(%arg0: tensor<2x2xi32>) -> tensor<4x7xi1> { + // expected-error@+1 {{op requires attribute 'names'}} + %0 = "test.op_adaptor_verifier"(%arg0) : (tensor<2x2xi32>) -> tensor<4x7xi1> + return %0 : tensor<4x7xi1> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -571,6 +571,16 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } +def OpAdaptorVerifier : NativeOpTrait<"OpAdaptorVerifier">; +def OpAdaptorVerifierOp : TEST_Op<"op_adaptor_verifier", + [ + OpAdaptorVerifier, + ResultsBroadcastableShape + ]> { + let arguments = (ins TensorOf<[F32]>:$arg1, StrArrayAttr:$names); + let results = (outs AnyTensor); +} + //===----------------------------------------------------------------------===// // Test Locations //===----------------------------------------------------------------------===//