diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -301,15 +301,11 @@ // the results of the parallel loop when it is fully rewritten. loopResults.assign(forOp.result_begin(), forOp.result_end()); first = false; - } else { - // A loop is constructed with an empty "yield" terminator by default. - // Replace it with another "yield" that forwards the results of the nested - // loop to the parent loop. We need to explicitly make sure the new - // terminator is the last operation in the block because further - // transforms rely on this. + } else if (!forOp.getResults().empty()) { + // A loop is constructed with an empty "yield" terminator if there are + // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( - rewriter.getInsertionBlock()->getTerminator(), forOp.getResults()); + rewriter.create(loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); @@ -342,9 +338,10 @@ mapping.lookup(reduceBlock.getTerminator()->getOperand(0))); } - rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( - rewriter.getInsertionBlock()->getTerminator(), yieldOperands); + if (!yieldOperands.empty()) { + rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); + rewriter.create(loc, yieldOperands); + } rewriter.replaceOp(parallelOp, loopResults); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LoopOps/LoopOps.h" + #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -47,7 +48,10 @@ for (Value v : iterArgs) result.addTypes(v.getType()); Region *bodyRegion = result.addRegion(); - ForOp::ensureTerminator(*bodyRegion, *builder, result.location); + bodyRegion->push_back(new Block()); + if (iterArgs.empty()) { + ForOp::ensureTerminator(*bodyRegion, *builder, result.location); + } bodyRegion->front().addArgument(builder->getIndexType()); for (Value v : iterArgs) bodyRegion->front().addArgument(v.getType()); @@ -201,7 +205,7 @@ void IfOp::build(Builder *builder, OperationState &result, Value cond, bool withElseRegion) { - build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion); + build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion); } void IfOp::build(Builder *builder, OperationState &result,