diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1805,17 +1805,9 @@
     /// Does loop set (and return) the final value of the control variable?
     bool hasLastValue() { return lastVal().size(); }
 
-    /// Get the body of the loop
-    mlir::Block *getBody() { return &region().front(); }
-
     /// Get the block argument corresponding to the loop control value (PHI)
     mlir::Value getInductionVar() { return getBody()->getArgument(0); }
 
-    /// Get a builder to insert operations into the LoopOp
-    mlir::OpBuilder getBodyBuilder() {
-      return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
-    }
-
     void setLowerBound(mlir::Value bound) {
       getOperation()->setOperand(0, bound);
     }
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -180,11 +180,7 @@
     static StringRef getLowerBoundAttrName() { return "lower_bound"; }
     static StringRef getUpperBoundAttrName() { return "upper_bound"; }
 
-    Block *getBody() { return &region().front(); }
     Value getInductionVar() { return getBody()->getArgument(0); }
-    OpBuilder getBodyBuilder() {
-      return OpBuilder(getBody(), std::prev(getBody()->end()));
-    }
 
     // TODO: provide iterators for the lower and upper bound operands
     // if the current access via getLowerBound(), getUpperBound() is too slow.
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -141,11 +141,7 @@
   ];
 
   let extraClassDeclaration = [{
-    Block *getBody() { return &region().front(); }
     Value getInductionVar() { return getBody()->getArgument(0); }
-    OpBuilder getBodyBuilder() {
-      return OpBuilder(getBody(), std::prev(getBody()->end()));
-    }
     Block::BlockArgListType getRegionIterArgs() {
       return getBody()->getArguments().drop_front();
     }
@@ -242,16 +238,14 @@
 
   let extraClassDeclaration = [{
     OpBuilder getThenBodyBuilder() {
-      assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
-      Block &body = thenRegion().front();
-      return OpBuilder(&body,
-                       results().empty() ? std::prev(body.end()) : body.end());
+      Block *body = getBody(0);
+      return results().empty() ? OpBuilder::atBlockTerminator(body)
+                               : OpBuilder::atBlockEnd(body);
     }
     OpBuilder getElseBodyBuilder() {
-      assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
-      Block &body = elseRegion().front();
-      return OpBuilder(&body,
-                       results().empty() ? std::prev(body.end()) : body.end());
+      Block *body = getBody(1);
+      return results().empty() ? OpBuilder::atBlockTerminator(body)
+                               : OpBuilder::atBlockEnd(body);
     }
   }];
 }
@@ -322,7 +316,6 @@
   ];
 
   let extraClassDeclaration = [{
-    Block *getBody() { return &region().front(); }
     ValueRange getInductionVars() {
       return getBody()->getArguments();
     }
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -533,11 +533,6 @@
   ];
 
   let extraClassDeclaration = [{
-    OpBuilder getBodyBuilder() {
-      assert(!body().empty() && "Unexpected empty 'body' region.");
-      Block &block = body().front();
-      return OpBuilder(&block, block.end());
-    }
     // The value stored in memref[ivs].
     Value getCurrentValue() {
       return body().front().getArgument(0);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -195,7 +195,7 @@
   }
 
   /// Create a builder and set the insertion point to before the first operation
-  /// in the block but still inside th block.
+  /// in the block but still inside the block.
   static OpBuilder atBlockBegin(Block *block) {
     return OpBuilder(block, block->begin());
   }
@@ -206,6 +206,14 @@
     return OpBuilder(block, block->end());
   }
 
+  /// Create a builder and set the insertion point to before the block
+  /// terminator. The block must be non-empty and end with a terminator.
+  static OpBuilder atBlockTerminator(Block *block) {
+    assert(!block->empty() && "the block has no terminator");
+    auto *terminator = block->getTerminator();
+    return OpBuilder(block, terminator->getIterator());
+  }
+
   /// This class represents a saved insertion point.
   class InsertPoint {
   public:
@@ -322,7 +330,7 @@
     OpTy::build(this, state, std::forward<Args>(args)...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
-    assert(result && "Builder didn't return the right type");
+    assert(result && "builder didn't return the right type");
     return result;
   }
 
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1093,6 +1093,13 @@
       ::mlir::impl::template ensureRegionTerminator<TerminatorOpType>(
           region, builder, loc);
     }
+
+    /// Return the first block of the region at `idx`; the region must not be empty.
+    Block *getBody(unsigned idx = 0) {
+      Region &region = this->getOperation()->getRegion(idx);
+      assert(!region.empty() && "unexpected empty region");
+      return &region.front();
+    }
   };
 };
 
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -192,7 +192,7 @@
 
   BlockAndValueMapping operandMap;
 
-  OpBuilder bodyBuilder = loopChunk.getBodyBuilder();
+  auto bodyBuilder = OpBuilder::atBlockTerminator(loopChunk.getBody());
   for (auto it = opGroupQueue.begin() + offset, e = opGroupQueue.end();
        it != e; ++it) {
     uint64_t shift = it->first;
@@ -470,7 +470,7 @@
 
   // Builder to insert unrolled bodies just before the terminator of the body of
   // 'forOp'.
-  OpBuilder builder = forOp.getBodyBuilder();
+  auto builder = OpBuilder::atBlockTerminator(forOp.getBody());
 
   // Keep a pointer to the last non-terminator operation in the original block
   // so that we know what to clone (since we are doing this in-place).
@@ -906,7 +906,7 @@
   SmallVector<AffineForOp, 8> innerLoops;
   for (auto t : targets) {
     // Insert newForOp before the terminator of `t`.
-    OpBuilder b = t.getBodyBuilder();
+    auto b = OpBuilder::atBlockTerminator(t.getBody());
     auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
                                           ubOperands, ubMap, originalStep);
     auto begin = t.getBody()->begin();
@@ -938,7 +938,7 @@
     auto nOps = t.getBody()->getOperations().size();
 
     // Insert newForOp before the terminator of `t`.
-    OpBuilder b(t.getBodyBuilder());
+    auto b = OpBuilder::atBlockTerminator(t.getBody());
     Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
     Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
                                   forOp.upperBound(), stepped);
@@ -1510,7 +1510,7 @@
     if (d == 0)
       copyNestRoot = forOp;
 
-    b = forOp.getBodyBuilder();
+    b = OpBuilder::atBlockTerminator(forOp.getBody());
 
     auto fastBufOffsetMap =
         AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
@@ -2309,7 +2309,7 @@
     AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
         b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
         ubVmap.getOperands(), ubVmap.getAffineMap());
-    b = fullTileLoop.getBodyBuilder();
+    b = OpBuilder::atBlockTerminator(fullTileLoop.getBody());
     fullTileLoops.push_back(fullTileLoop);
   }
 
@@ -2318,7 +2318,7 @@
   for (auto loopEn : llvm::enumerate(inputNest))
     operandMap.map(loopEn.value().getInductionVar(),
                    fullTileLoops[loopEn.index()].getInductionVar());
-  b = fullTileLoops.back().getBodyBuilder();
+  b = OpBuilder::atBlockTerminator(fullTileLoops.back().getBody());
   for (auto &op : inputNest.back().getBody()->without_terminator())
     b.clone(op, operandMap);
   return success();