diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -52,7 +52,7 @@ the lower bound but does not include the upper bound. The body region must contain exactly one block that terminates with - "loop.yield". Calling ForOp::build will create such a region and insert + "loop.yield". Calling ForOp::build will create such a region and insert the terminator implicitly if none is defined, so will the parsing even in cases when it is absent from the custom format. For example: @@ -71,9 +71,11 @@ The region must terminate with a "loop.yield" that passes all the current iteration variables to the next iteration, or to the "loop.for" result, if - at the last iteration. "loop.for" results hold the final values after the - last iteration. + at the last iteration. Note that when the loop-carried variables are + present, calling ForOp::build will not insert the terminator implicitly. + The caller must insert "loop.yield" in that case. + "loop.for" results hold the final values after the last iteration. For example, to sum-reduce a memref: ```mlir diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -301,15 +301,11 @@ // the results of the parallel loop when it is fully rewritten. loopResults.assign(forOp.result_begin(), forOp.result_end()); first = false; - } else { - // A loop is constructed with an empty "yield" terminator by default. - // Replace it with another "yield" that forwards the results of the nested - // loop to the parent loop. We need to explicitly make sure the new - // terminator is the last operation in the block because further - // transforms rely on this. 
+ } else if (!forOp.getResults().empty()) { + // A loop is constructed with an empty "yield" terminator if there are + // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( - rewriter.getInsertionBlock()->getTerminator(), forOp.getResults()); + rewriter.create(loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); @@ -342,9 +338,10 @@ mapping.lookup(reduceBlock.getTerminator()->getOperand(0))); } - rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( - rewriter.getInsertionBlock()->getTerminator(), yieldOperands); + if (!yieldOperands.empty()) { + rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); + rewriter.create(loc, yieldOperands); + } rewriter.replaceOp(parallelOp, loopResults); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LoopOps/LoopOps.h" + #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -47,7 +48,9 @@ for (Value v : iterArgs) result.addTypes(v.getType()); Region *bodyRegion = result.addRegion(); - ForOp::ensureTerminator(*bodyRegion, *builder, result.location); + bodyRegion->push_back(new Block()); + if (iterArgs.empty()) + ForOp::ensureTerminator(*bodyRegion, *builder, result.location); bodyRegion->front().addArgument(builder->getIndexType()); for (Value v : iterArgs) bodyRegion->front().addArgument(v.getType()); @@ -201,7 +204,7 @@ void IfOp::build(Builder *builder, OperationState &result, Value cond, bool withElseRegion) { - build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion); + build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion); } void IfOp::build(Builder *builder, 
OperationState &result,