diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1686,6 +1686,11 @@
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>;
 
+// Verifies that the op's parent is one of the given ops; the names are joined
+// into the template argument list of the variadic C++ `HasParent` trait.
+class ParentOneOf<list<string> ops>
+    : ParamNativeOpTrait<"HasParent", StrJoin<ops>.result>;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1139,16 +1139,59 @@
   };
 };
 
-/// This class provides a verifier for ops that are expecting a specific parent.
-template <typename ParentOpType> struct HasParent {
+/// Helper functions for the HasParent trait.
+namespace detail {
+/// Returns true if `op` is non-null and is an instance of one of the given op
+/// types. A null `op` (the parent of a top-level or detached operation) is
+/// rejected explicitly because `isa<>` asserts on null pointers.
+template <typename OpType>
+bool isOneOf(Operation *op) {
+  return op && isa<OpType>(op);
+}
+
+template <typename OpType, typename... OpTypes>
+typename std::enable_if<sizeof...(OpTypes) != 0, bool>::type
+isOneOf(Operation *op) {
+  return isOneOf<OpType>(op) || isOneOf<OpTypes...>(op);
+}
+
+/// Streams the operation names of the given op types into `diag`, separated
+/// by ", ".
+template <typename OpType>
+void streamOperationNames(InFlightDiagnostic &diag) {
+  diag << OpType::getOperationName();
+}
+
+template <typename OpType, typename... OpTypes>
+typename std::enable_if<sizeof...(OpTypes) != 0>::type
+streamOperationNames(InFlightDiagnostic &diag) {
+  streamOperationNames<OpType>(diag);
+  diag << ", ";
+  streamOperationNames<OpTypes...>(diag);
+}
+} // end namespace detail
+
+/// This class provides a verifier for ops that are expecting their parent
+/// to be one of the given parent ops: `HasParent<A>` requires parent `A`,
+/// `HasParent<A, B>` accepts either `A` or `B`.
+template <typename... ParentOpTypes>
+struct HasParent {
+  static_assert(sizeof...(ParentOpTypes) != 0,
+                "HasParent requires at least one parent op type");
+
   template <typename ConcreteType>
   class Impl : public TraitBase<ConcreteType, Impl> {
   public:
     static LogicalResult verifyTrait(Operation *op) {
-      if (isa<ParentOpType>(op->getParentOp()))
+      if (detail::isOneOf<ParentOpTypes...>(op->getParentOp()))
         return success();
-      return op->emitOpError() << "expects parent op '"
-                               << ParentOpType::getOperationName() << "'";
+
+      InFlightDiagnostic diag = op->emitOpError();
+      diag << "expects parent op "
+           << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'");
+      detail::streamOperationNames<ParentOpTypes...>(diag);
+      diag << "'";
+      return diag;
     }
   };
 };
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -173,6 +173,39 @@
   }) : () -> ()
 }
 
+// -----
+
+// CHECK: succeededParentOneOf
+func @succeededParentOneOf() {
+  "test.parent"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK: succeededParent1OneOf
+func @succeededParent1OneOf() {
+  "test.parent1"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func @failedParentOneOf_wrong_parent1() {
+  "some.otherop"() ({
+    // expected-error@+1 {{'test.child_with_parent_one_of' op expects parent op to be one of 'test.parent, test.parent1'}}
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
 // -----
 
 func @failedSingleBlockImplicitTerminator_empty_block() {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -439,10 +439,18 @@
   let results = (outs AnyTensor);
 }
 
-// There the "HasParent" trait.
-def ParentOp : TEST_Op<"parent">;
+// HasParent trait
+def ParentOp : TEST_Op<"parent"> {
+  let regions = (region AnyRegion);
+}
 def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
 
+// ParentOneOf trait
+def ParentOp1 : TEST_Op<"parent1"> {
+  let regions = (region AnyRegion);
+}
+def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
+                                [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",