diff --git a/llvm/include/llvm/Support/Casting.h b/llvm/include/llvm/Support/Casting.h --- a/llvm/include/llvm/Support/Casting.h +++ b/llvm/include/llvm/Support/Casting.h @@ -132,24 +132,31 @@ } }; -// isa - Return true if the parameter to the template is an instance of the -// template type argument. Used like this: +// isa - Return true if the parameter to the template is an instance of one +// of the template type argument. Used like this: // // if (isa(myVal)) { ... } +// if (isa(myVal)) { ... } // template LLVM_NODISCARD inline bool isa(const Y &Val) { return isa_impl_wrap::SimpleType>::doit(Val); } +template +LLVM_NODISCARD inline typename std::enable_if::type +isa(const Y &Val) { + return isa(Val) || isa(Val); +} + // isa_and_nonnull - Functionally identical to isa, except that a null value // is accepted. // -template +template LLVM_NODISCARD inline bool isa_and_nonnull(const Y &Val) { if (!Val) return false; - return isa(Val); + return isa(Val); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -3860,7 +3860,8 @@ if (I != HasRecMap.end()) return I->second; - bool FoundAddRec = SCEVExprContains(S, isa); + bool FoundAddRec = + SCEVExprContains(S, [](const SCEV *S) { return isa(S); }); HasRecMap.insert({S, FoundAddRec}); return FoundAddRec; } @@ -11203,9 +11204,12 @@ // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. static inline bool containsParameters(SmallVectorImpl &Terms) { - for (const SCEV *T : Terms) - if (SCEVExprContains(T, isa)) + for (const SCEV *T : Terms) { + bool FoundTerm = SCEVExprContains( + T, [](const SCEV *S) { return isa(S); }); + if (FoundTerm) return true; + } return false; } diff --git a/llvm/lib/Target/PowerPC/PPCBoolRetToInt.cpp b/llvm/lib/Target/PowerPC/PPCBoolRetToInt.cpp --- a/llvm/lib/Target/PowerPC/PPCBoolRetToInt.cpp +++ b/llvm/lib/Target/PowerPC/PPCBoolRetToInt.cpp @@ -220,7 +220,7 @@ auto Defs = findAllDefs(U); // If the values are all Constants or Arguments, don't bother - if (llvm::none_of(Defs, isa)) + if (llvm::none_of(Defs, [](Value *V) { return isa(V); })) return false; // Presently, we only know how to handle PHINode, Constant, Arguments and diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1686,6 +1686,9 @@ class HasParent : ParamNativeOpTrait<"HasParent", op>; +class ParentOneOf ops> + : ParamNativeOpTrait<"HasParent", StrJoin.result>; + // Op result type is derived from the first attribute. If the attribute is an // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the // attribute content is used. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1139,16 +1139,24 @@ }; }; -/// This class provides a verifier for ops that are expecting a specific parent. -template struct HasParent { +/// This class provides a verifier for ops that are expecting their parent +/// to be one of the given parent ops +template +struct HasParent { template class Impl : public TraitBase { public: static LogicalResult verifyTrait(Operation *op) { - if (isa(op->getParentOp())) + if (llvm::isa(op->getParentOp())) return success(); - return op->emitOpError() << "expects parent op '" - << ParentOpType::getOperationName() << "'"; + + InFlightDiagnostic diag = op->emitOpError(); + diag << "expects parent op " + << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'"); + llvm::interleaveComma( + llvm::makeArrayRef({ParentOpTypes::getOperationName()...}), diag); + diag << "'"; + return diag; } }; }; diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -173,6 +173,39 @@ }) : () -> () } +// ----- + +// CHECK: succeededParentOneOf +func @succeededParentOneOf() { + "test.parent"() ({ + "test.child_with_parent_one_of"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} + +// ----- + +// CHECK: succeededParent1OneOf +func @succeededParent1OneOf() { + "test.parent1"() ({ + "test.child_with_parent_one_of"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} + +// ----- + +func @failedParentOneOf_wrong_parent1() { + "some.otherop"() ({ + // expected-error@+1 {{'test.child_with_parent_one_of' op expects parent op to be one of 'test.parent, test.parent1'}} + "test.child_with_parent_one_of"() : () -> () + "test.finish"() : () -> () + }) : () -> () +} + + // ----- func @failedSingleBlockImplicitTerminator_empty_block() { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -439,10 +439,18 @@ let results = (outs AnyTensor); } -// There the "HasParent" trait. -def ParentOp : TEST_Op<"parent">; +// HasParent trait +def ParentOp : TEST_Op<"parent"> { + let regions = (region AnyRegion); +} def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>; +// ParentOneOf trait +def ParentOp1 : TEST_Op<"parent1"> { + let regions = (region AnyRegion); +} +def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of", + [ParentOneOf<["ParentOp", "ParentOp1"]>]>; def TerminatorOp : TEST_Op<"finish", [Terminator]>; def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",