diff --git a/llvm/include/llvm/Support/Casting.h b/llvm/include/llvm/Support/Casting.h
--- a/llvm/include/llvm/Support/Casting.h
+++ b/llvm/include/llvm/Support/Casting.h
@@ -132,24 +132,33 @@
   }
 };
 
-// isa<X> - Return true if the parameter to the template is an instance of the
-// template type argument.  Used like this:
+// isa<X> - Return true if the parameter to the template is an instance of one
+// of the template type arguments.  Used like this:
 //
 //  if (isa<Type>(myVal)) { ... }
+//  if (isa<Type0, Type1, Type2>(myVal)) { ... }
 //
 template <class X, class Y> LLVM_NODISCARD inline bool isa(const Y &Val) {
   return isa_impl_wrap<X, const Y,
                        typename simplify_type<const Y>::SimpleType>::doit(Val);
 }
 
+// Variadic form of isa: true if Val is an instance of any of the listed types,
+// checked left to right. Requiring an explicit `Second` type (which cannot be
+// deduced) keeps this overload out of single-type isa<X> calls.
+template <typename First, typename Second, typename... Rest, typename Y>
+LLVM_NODISCARD inline bool isa(const Y &Val) {
+  return isa<First>(Val) || isa<Second, Rest...>(Val);
+}
+
 // isa_and_nonnull<X> - Functionally identical to isa, except that a null value
 // is accepted.
 //
-template <class X, class Y>
+template <typename... X, class Y>
 LLVM_NODISCARD inline bool isa_and_nonnull(const Y &Val) {
   if (!Val)
     return false;
-  return isa<X>(Val);
+  return isa<X...>(Val);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1686,6 +1686,10 @@
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>;
 
+// Specify that the parent of an op must be one of the given ops.
+class ParentOneOf<list<string> ops>
+    : ParamNativeOpTrait<"HasParent", StrJoin<ops>.result>;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1139,16 +1139,27 @@
   };
 };
 
-/// This class provides a verifier for ops that are expecting a specific parent.
-template <typename ParentOpType> struct HasParent {
+/// This class provides a verifier for ops that are expecting their parent
+/// to be one of the given parent ops. An op with no parent op (a null
+/// getParentOp()) fails verification rather than being passed to isa<>.
+template <typename... ParentOpTypes>
+struct HasParent {
   template <typename ConcreteType>
   class Impl : public TraitBase<ConcreteType, Impl> {
   public:
     static LogicalResult verifyTrait(Operation *op) {
-      if (isa<ParentOpType>(op->getParentOp()))
+      if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
         return success();
-      return op->emitOpError() << "expects parent op '"
-                               << ParentOpType::getOperationName() << "'";
+
+      // Produces "expects parent op 'a'" for a single parent, or
+      // "expects parent op to be one of 'a, b'" for several.
+      InFlightDiagnostic diag = op->emitOpError();
+      diag << "expects parent op "
+           << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'");
+      llvm::interleaveComma(
+          llvm::makeArrayRef({ParentOpTypes::getOperationName()...}), diag);
+      diag << "'";
+      return diag;
     }
   };
 };
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -173,6 +173,38 @@
   }) : () -> ()
 }
 
+// -----
+
+// CHECK: succeededParentOneOf
+func @succeededParentOneOf() {
+  "test.parent"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK: succeededParent1OneOf
+func @succeededParent1OneOf() {
+  "test.parent1"() ({
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func @failedParentOneOf_wrong_parent1() {
+  "some.otherop"() ({
+    // expected-error@+1 {{'test.child_with_parent_one_of' op expects parent op to be one of 'test.parent, test.parent1'}}
+    "test.child_with_parent_one_of"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+}
+
 // -----
 
 func @failedSingleBlockImplicitTerminator_empty_block() {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -439,10 +439,18 @@
   let results = (outs AnyTensor);
 }
 
-// There the "HasParent" trait.
-def ParentOp : TEST_Op<"parent">;
+// Ops exercising the HasParent trait.
+def ParentOp : TEST_Op<"parent"> {
+  let regions = (region AnyRegion);
+}
 def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
 
+// Ops exercising the ParentOneOf trait.
+def ParentOp1 : TEST_Op<"parent1"> {
+  let regions = (region AnyRegion);
+}
+def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
+                                [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
 
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",