diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1684,6 +1684,10 @@
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>;
 
+// Op's parent operation is one of the provided ones.
+class ParentOneOf<list<string> ops>
+    : ParamNativeOpTrait<"HasParent", StrJoin<ops>.result>;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1145,16 +1145,28 @@
   };
 };
 
-/// This class provides a verifier for ops that are expecting a specific parent.
-template <typename ParentOpType> struct HasParent {
+/// This class provides a verifier for ops that are expecting their parent
+/// to be one of the given parent ops. An op that has no parent op (e.g. a
+/// detached op) fails verification.
+template <typename... ParentOpTypes>
+struct HasParent {
   template <typename ConcreteType>
   class Impl : public TraitBase<ConcreteType, Impl> {
   public:
     static LogicalResult verifyTrait(Operation *op) {
-      if (isa<ParentOpType>(op->getParentOp()))
+      // getParentOp() may return null, and isa<> asserts on a null pointer,
+      // so check for a parent explicitly and report a diagnostic instead.
+      Operation *parentOp = op->getParentOp();
+      if (parentOp && llvm::isa<ParentOpTypes...>(parentOp))
         return success();
-      return op->emitOpError() << "expects parent op '"
-                               << ParentOpType::getOperationName() << "'";
+
+      InFlightDiagnostic diag = op->emitOpError();
+      diag << "expects parent op "
+           << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'");
+      llvm::interleaveComma(
+          llvm::makeArrayRef({ParentOpTypes::getOperationName()...}), diag);
+      diag << "'";
+      return diag;
     }
   };
 };
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -173,6 +173,38 @@
   }) : () -> ()
 }
 
+// -----
+
+// CHECK: succeededParentOneOf
+func @succeededParentOneOf() {
+  "test.parent"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK: succeededParent1OneOf
+func @succeededParent1OneOf() {
+  "test.parent1"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func @failedParentOneOf_wrong_parent1() {
+  "some.otherop"() ({
+    // expected-error@+1 {{'test.child_with_parent_one_of' op expects parent op to be one of 'test.parent, test.parent1'}}
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+}
+
 // -----
 
 func @failedSingleBlockImplicitTerminator_empty_block() {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -439,10 +439,19 @@
   let results = (outs AnyTensor);
 }
 
-// There the "HasParent" trait.
-def ParentOp : TEST_Op<"parent">;
+// Test ops for the HasParent trait.
+def ParentOp : TEST_Op<"parent"> {
+  let regions = (region AnyRegion);
+}
 def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
 
+// Test ops for the ParentOneOf trait.
+def ParentOp1 : TEST_Op<"parent1"> {
+  let regions = (region AnyRegion);
+}
+def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
+                                   [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
+
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",