[mlir] Make SingleBlockImplicitTerminator a TraitList that includes SingleBlock

Everything above the first `diff --git` header is description only; `git
apply` and `patch` skip it.

* OpBase.td: `SingleBlockImplicitTerminator<op>` now expands to
  `TraitList<[SingleBlock, SingleBlockImplicitTerminatorImpl<op>]>`. The
  native C++ trait is declared by `SingleBlockImplicitTerminatorImpl<op>`,
  which also lists `SingleBlock` as a trait it depends on.
* OpDefinition.h: `SingleBlockImplicitTerminator<T>::Impl` no longer derives
  from `SingleBlock<ConcreteType>`, and its region verifier no longer calls
  `SingleBlock::verifyTrait` (the separate `SingleBlock` trait runs that
  check). `Impl` and `SingleBlock<ConcreteType>` are now unrelated sibling
  bases of the op, so `push_back`/`insert` first cast `this` down to
  `ConcreteType *` and only then convert to `SingleBlock<ConcreteType> *`.
  A direct `static_cast` between the two sibling bases is ill-formed.
* OpFormatGen.cpp: the implicit-terminator check matches the
  `SingleBlockImplicitTerminatorImpl` class. The single-block check now relies
  only on the `SingleBlock` trait, which is always present alongside it.
* gen-dialect-doc.td: the documented trait list now includes `SingleBlock`.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -122,10 +122,15 @@
 // Op's regions have a single block.
 def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;
 
+// Native trait implementing SingleBlockImplicitTerminator. Ops should use
+// SingleBlockImplicitTerminator, which also adds the SingleBlock trait.
+class SingleBlockImplicitTerminatorImpl<string op>
+    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op, [SingleBlock]>,
+      StructuralOpTrait;
+
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
-    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>,
-      StructuralOpTrait;
+    : TraitList<[SingleBlock, SingleBlockImplicitTerminatorImpl<op>]>;
 
 // Op's regions don't have terminator.
 def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -942,10 +942,8 @@
 /// block that must terminate with `TerminatorOpType`.
 template <typename TerminatorOpType>
 struct SingleBlockImplicitTerminator {
-  template <typename ConcreteType>
-  class Impl : public SingleBlock<ConcreteType> {
+  template <typename ConcreteType> class Impl {
   private:
-    using Base = SingleBlock<ConcreteType>;
     /// Builds a terminator operation without relying on OpBuilder APIs to avoid
     /// cyclic header inclusion.
     static Operation *buildTerminator(OpBuilder &builder, Location loc) {
@@ -959,8 +957,6 @@
     using ImplicitTerminatorOpT = TerminatorOpType;
 
     static LogicalResult verifyRegionTrait(Operation *op) {
-      if (failed(Base::verifyTrait(op)))
-        return failure();
       for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
         Region &region = op->getRegion(i);
         // Empty regions are fine.
@@ -1002,7 +998,6 @@
     //===------------------------------------------------------------------===//
     // Single Region Utilities
     //===------------------------------------------------------------------===//
-    using Base::getBody;
 
     template <typename OpT, typename RetT = void>
     using enable_if_single_region =
@@ -1011,7 +1006,12 @@
     /// Insert the operation into the back of the body, before the terminator.
     template <typename OpT = ConcreteType>
     enable_if_single_region<OpT> push_back(Operation *op) {
-      insert(Block::iterator(getBody()->getTerminator()), op);
+      // `Impl` no longer derives from `SingleBlock`, so go through the op
+      // (which has both traits) to reach `SingleBlock::getBody()`.
+      Block *body = static_cast<SingleBlock<ConcreteType> *>(
+                        static_cast<ConcreteType *>(this))
+                        ->getBody();
+      insert(Block::iterator(body->getTerminator()), op);
     }
 
     /// Insert the operation at the given insertion point. Note: The operation
@@ -1024,7 +1024,9 @@
     template <typename OpT = ConcreteType>
     enable_if_single_region<OpT> insert(Block::iterator insertPt,
                                         Operation *op) {
-      auto *body = getBody();
+      Block *body = static_cast<SingleBlock<ConcreteType> *>(
+                        static_cast<ConcreteType *>(this))
+                        ->getBody();
       if (insertPt == body->end())
         insertPt = Block::iterator(body->getTerminator());
       body->getOperations().insert(insertPt, op);
diff --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td
--- a/mlir/test/mlir-tblgen/gen-dialect-doc.td
+++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td
@@ -81,7 +81,7 @@
 // CHECK: Other group
 // CHECK: test.b
 // CHECK: test.c
-// CHECK: Traits: SingleBlockImplicitTerminator<YieldOp>
+// CHECK: Traits: SingleBlock, SingleBlockImplicitTerminator<YieldOp>
 // CHECK: Interfaces: NoMemoryEffect (MemoryEffectOpInterface)
 // CHECK: Effects: MemoryEffects::Effect{}
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -313,11 +313,12 @@
     resultTypes.resize(op.getNumResults(), TypeResolution());
 
     hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
-      return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+      return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
     });
 
-    hasSingleBlockTrait =
-        hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
+    // SingleBlockImplicitTerminator<op> also adds SingleBlock to the op's
+    // trait list, so no special case is needed here.
+    hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
   }
 
   /// Generate the operation parser from this format.