diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -194,6 +194,16 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Lowers `scf.execute_region` by inlining its region into the parent region.
+/// Each `scf.yield` terminator becomes a branch to a continuation block whose
+/// arguments take the place of the op's results.
+struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
 
@@ -400,6 +410,44 @@
   return success();
 }
 
+LogicalResult
+ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
+                                       PatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+
+  // Split the block at the current insertion point (expected to be `op`):
+  // the ops before it stay in `condBlock`, which will branch into the region,
+  // and the rest move to `remainingOpsBlock`, the continuation block.
+  auto *condBlock = rewriter.getInsertionBlock();
+  auto opPosition = rewriter.getInsertionPoint();
+  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+
+  auto &region = op.region();
+  rewriter.setInsertionPointToEnd(condBlock);
+  rewriter.create<BranchOp>(loc, &region.front());
+
+  // Replace each `scf.yield` with a branch to `remainingOpsBlock` that
+  // forwards the yielded values as block arguments.
+  for (Block &block : region) {
+    if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+      ValueRange terminatorOperands = terminator->getOperands();
+      rewriter.setInsertionPointToEnd(&block);
+      rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
+      rewriter.eraseOp(terminator);
+    }
+  }
+
+  rewriter.inlineRegionBefore(region, remainingOpsBlock);
+
+  // The op's results become the continuation block's new arguments.
+  SmallVector<Value> vals;
+  for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes())) {
+    vals.push_back(arg);
+  }
+  rewriter.replaceOp(op, vals);
+  return success();
+}
+
 LogicalResult
 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
                                   PatternRewriter &rewriter) const {
@@ -569,8 +617,8 @@
 }
 
 void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
-      patterns.getContext());
+  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
+               ExecuteRegionLowering>(patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
@@ -580,7 +628,8 @@
-  // Configure conversion to lower out scf.for, scf.if, scf.parallel and
-  // scf.while. Anything else is fine.
+  // Configure conversion to lower out scf.for, scf.if, scf.parallel,
+  // scf.while and scf.execute_region. Anything else is fine.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
+                      scf::ExecuteRegionOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))