diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -210,9 +210,13 @@ //===--------------------------------------------------------------------===// /// Get the terminator operation of this block. This function asserts that - /// the block has a valid terminator operation. + /// the block has a valid terminator operation: an op with the IsTerminator + /// trait or an unregistered op. Operation *getTerminator(); + /// Return true if this block has a terminator op with the IsTerminator trait. + bool hasRegisteredTerminator(); + //===--------------------------------------------------------------------===// // Predecessors and successors. //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -288,7 +288,7 @@ tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types()); // Remove terminator. - if (!newTiledLoopOp.getBody()->empty()) + if (newTiledLoopOp.getBody()->hasRegisteredTerminator()) rewriter.eraseOp(tiledLoopOp.getBody()->getTerminator()); // Compute new loop body arguments. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -163,7 +163,7 @@ /*withElseRegion=*/true); // Remove terminators. 
- if (!newIfOp.thenBlock()->empty()) { + if (newIfOp.thenBlock()->hasRegisteredTerminator()) { rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); } @@ -326,7 +326,7 @@ iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar()); // Erase terminator if present. - if (iterArgs.size() == 1) + if (loopBody->hasRegisteredTerminator()) rewriter.eraseOp(loopBody->getTerminator()); // Move loop body to new loop. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -434,12 +434,12 @@ // Consider the operations within this block, ignoring the terminator if // requested. - bool hasTerminator = - !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); auto range = llvm::make_range( block->begin(), std::prev(block->end(), - (!hasTerminator || printBlockTerminator) ? 0 : 1)); + (!block->hasRegisteredTerminator() || printBlockTerminator) + ? 0 + : 1)); for (Operation &op : range) print(&op); } @@ -2675,12 +2675,11 @@ } currentIndent += indentWidth; - bool hasTerminator = - !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); auto range = llvm::make_range( block->begin(), - std::prev(block->end(), - (!hasTerminator || printBlockTerminator) ? 0 : 1)); + std::prev( + block->end(), + (!block->hasRegisteredTerminator() || printBlockTerminator) ? 0 : 1)); for (auto &op : range) { print(&op); os << newLine; diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -253,12 +253,18 @@ //===----------------------------------------------------------------------===// /// Get the terminator operation of this block. This function asserts that -/// the block has a valid terminator operation. +/// the block has a valid terminator operation: an op with the IsTerminator +/// trait or an unregistered op.
Operation *Block::getTerminator() { assert(!empty() && back().mightHaveTrait<OpTrait::IsTerminator>()); return &back(); } +/// Return true if this block has a terminator op with the IsTerminator trait. +bool Block::hasRegisteredTerminator() { + return !empty() && back().hasTrait<OpTrait::IsTerminator>(); +} + // Indexed successor access. unsigned Block::getNumSuccessors() { return empty() ? 0 : back().getNumSuccessors(); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1292,7 +1292,7 @@ builder.createBlock(&region); Block &block = region.back(); - if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) + if (block.hasRegisteredTerminator()) return; builder.setInsertionPointToEnd(&block);