diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -79,17 +79,12 @@ /// is empty, insert a new block first. `buildTerminatorOp` should return the /// terminator operation to insert. void ensureRegionTerminator( - Region ®ion, Location loc, - function_ref buildTerminatorOp); -/// Templated version that fills the generates the provided operation type. -template -void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) { - ensureRegionTerminator(region, loc, [&](OpBuilder &b) { - OperationState state(loc, OpTy::getOperationName()); - OpTy::build(b, state); - return Operation::create(state); - }); -} + Region ®ion, OpBuilder &builder, Location loc, + function_ref buildTerminatorOp); +void ensureRegionTerminator( + Region ®ion, Builder &builder, Location loc, + function_ref buildTerminatorOp); + } // namespace impl /// This is the concrete base class that holds the operation pointer and has @@ -1077,6 +1072,15 @@ template struct SingleBlockImplicitTerminator { template class Impl : public TraitBase { + private: + /// Builds a terminator operation without relying on OpBuilder APIs to avoid + /// cyclic header inclusion. + static Operation *buildTerminator(OpBuilder &builder, Location loc) { + OperationState state(loc, TerminatorOpType::getOperationName()); + TerminatorOpType::build(builder, state); + return Operation::create(state); + } + public: static LogicalResult verifyTrait(Operation *op) { for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { @@ -1112,10 +1116,19 @@ } /// Ensure that the given region has the terminator required by this trait. + /// If OpBuilder is provided, use it to build the terminator and notify the + /// OpBuilder listeners accordingly.
If only a Builder is provided, locally + /// construct an OpBuilder with no listeners; this should only be used if no + /// OpBuilder is available at the call site, e.g., in the parser. static void ensureTerminator(Region ®ion, Builder &builder, Location loc) { - ::mlir::impl::template ensureRegionTerminator( - region, builder, loc); + ::mlir::impl::ensureRegionTerminator(region, builder, loc, + buildTerminator); + } + static void ensureTerminator(Region ®ion, OpBuilder &builder, + Location loc) { + ::mlir::impl::ensureRegionTerminator(region, builder, loc, + buildTerminator); } Block *getBody(unsigned idx = 0) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1099,17 +1099,27 @@ /// is empty, insert a new block first. `buildTerminatorOp` should return the /// terminator operation to insert. void impl::ensureRegionTerminator( - Region ®ion, Location loc, - function_ref buildTerminatorOp) { + Region ®ion, OpBuilder &builder, Location loc, + function_ref buildTerminatorOp) { + OpBuilder::InsertionGuard guard(builder); if (region.empty()) - region.push_back(new Block); + builder.createBlock(®ion); Block &block = region.back(); if (!block.empty() && block.back().isKnownTerminator()) return; - OpBuilder builder(loc.getContext()); - block.push_back(buildTerminatorOp(builder)); + builder.setInsertionPointToEnd(&block); + builder.insert(buildTerminatorOp(builder, loc)); +} + +/// Create a simple OpBuilder and forward to the OpBuilder version of this +/// function. +void impl::ensureRegionTerminator( + Region ®ion, Builder &builder, Location loc, + function_ref buildTerminatorOp) { + OpBuilder opBuilder(builder.getContext()); + ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); } //===----------------------------------------------------------------------===//