diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,17 +13,20 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_H_
 
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 class FuncOp;
+class Location;
 class Operation;
 class OpBuilder;
+class Region;
+class RewriterBase;
 class ValueRange;
 class Value;
-class AffineExpr;
-class Operation;
 
 namespace scf {
 class IfOp;
@@ -55,16 +58,37 @@
                                     ValueRange newYieldedValues,
                                     bool replaceLoopResults = true);
 
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of
+/// the single block. This constraint makes it easy to determine the result.
+/// This method also clones captured `arith::ConstantIndexOp` values at the
+/// start of the outlined function body to allow simple canonicalizations.
+/// Creates a new FuncOp and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `funcName` that will not
+/// collide with another FuncOp name.
+/// Returns failure, without modifying the IR, if `region` does not have
+/// exactly one block or is not nested within a FuncOp.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> outlineSingleBlockRegion(RewriterBase &rewriter, Location loc,
+                                          Region &region, StringRef funcName);
+
 /// Outline the then and/or else regions of `ifOp` as follows:
 /// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
 ///   region is inlined into a new FuncOp that is captured by the pointer.
 /// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
 ///   region is inlined into a new FuncOp that is captured by the pointer.
-void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
-                 StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
+/// Creates new FuncOps and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `thenFnName`/`elseFnName`
+/// that will not collide with another FuncOp name.
+/// Returns failure if outlining either region fails.
+LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+                          StringRef thenFnName, FuncOp *elseFn,
+                          StringRef elseFnName);
 
-/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
-/// loops are those that do not contain further parallel loops themselves.
+/// Get a list of innermost parallel loops contained in `rootOp`. Innermost
+/// parallel loops are those that do not contain further parallel loops
+/// themselves.
 bool getInnermostParallelLoops(Operation *rootOp,
                                SmallVectorImpl<scf::ParallelOp> &result);
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -12,12 +12,15 @@
 
 #include "mlir/Dialect/SCF/Utils.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/RegionUtils.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
@@ -77,51 +80,130 @@
   return newLoop;
 }
 
-void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
-                       StringRef thenFnName, FuncOp *elseFn,
-                       StringRef elseFnName) {
-  Location loc = ifOp.getLoc();
-  MLIRContext *ctx = ifOp.getContext();
-  auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
-    assert(!funcName.empty() && "Expected function name for outlining");
-    assert(ifOrElseRegion.getBlocks().size() <= 1 &&
-           "Expected at most one block");
-
-    // Outline before current function.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(ifOp->getParentOfType<FuncOp>());
-
-    SetVector<Value> captures;
-    getUsedValuesDefinedAbove(ifOrElseRegion, captures);
-
-    ValueRange values(captures.getArrayRef());
-    FunctionType type =
-        FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
-    auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
-    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of
+/// the single block. This constraint makes it easy to determine the result.
+/// This method also clones captured `arith::ConstantIndexOp` values at the
+/// start of `outlinedFuncBody` to allow simple canonicalizations.
+/// Returns failure, without modifying the IR, if `region` does not have
+/// exactly one block or is not nested within a FuncOp.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
+                                                 Location loc, Region &region,
+                                                 StringRef funcName) {
+  assert(!funcName.empty() && "funcName cannot be empty");
+  if (!region.hasOneBlock())
+    return failure();
+  // The outlined function is inserted right before the enclosing function,
+  // so bail out (before touching the IR) if there is none.
+  auto parentFunc = region.getParentOfType<FuncOp>();
+  if (!parentFunc)
+    return failure();
+
+  Block *originalBlock = &region.front();
+  Operation *originalTerminator = originalBlock->getTerminator();
+
+  // Outline before current function.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(parentFunc);
+
+  SetVector<Value> captures;
+  getUsedValuesDefinedAbove(region, captures);
+
+  ValueRange outlinedValues(captures.getArrayRef());
+  SmallVector<Type> outlinedFuncArgTypes;
+  // Region's arguments are exactly the first block's arguments as per
+  // Region::getArguments().
+  // Func's arguments are cat(region's arguments, captures arguments).
+  llvm::append_range(outlinedFuncArgTypes, region.getArgumentTypes());
+  llvm::append_range(outlinedFuncArgTypes, outlinedValues.getTypes());
+  FunctionType outlinedFuncType =
+      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
+                        originalTerminator->getOperandTypes());
+  auto outlinedFunc = rewriter.create<FuncOp>(loc, funcName, outlinedFuncType);
+  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
+
+  // Merge blocks while replacing the original block operands.
+  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
+  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
+  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(outlinedFuncBody);
+    rewriter.mergeBlocks(
+        originalBlock, outlinedFuncBody,
+        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
+    // Explicitly set up a new ReturnOp terminator.
+    rewriter.setInsertionPointToEnd(outlinedFuncBody);
+    rewriter.create<ReturnOp>(loc, originalTerminator->getResultTypes(),
+                              originalTerminator->getOperands());
+  }
+
+  // Reconstruct the block that was deleted and add a
+  // terminator(call_results).
+  Block *newBlock = rewriter.createBlock(
+      &region, region.begin(),
+      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments));
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(newBlock);
+    SmallVector<Value> callValues;
+    llvm::append_range(callValues, newBlock->getArguments());
+    llvm::append_range(callValues, outlinedValues);
+    Operation *call = rewriter.create<CallOp>(loc, outlinedFunc, callValues);
+
+    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
+    // Clone `originalTerminator` to take the callOp results then erase it from
+    // `outlinedFuncBody`.
     BlockAndValueMapping bvm;
-    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
-      bvm.map(std::get<0>(it), std::get<1>(it));
-    for (Operation &op : ifOrElseRegion.front().without_terminator())
-      b.clone(op, bvm);
-
-    Operation *term = ifOrElseRegion.front().getTerminator();
-    SmallVector<Value, 4> terminatorOperands;
-    for (auto op : term->getOperands())
-      terminatorOperands.push_back(bvm.lookup(op));
-    b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
-
-    ifOrElseRegion.front().clear();
-    b.setInsertionPointToEnd(&ifOrElseRegion.front());
-    Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
-    b.create<scf::YieldOp>(loc, call->getResults());
-    return outlinedFunc;
-  };
-
-  if (thenFn && !ifOp.getThenRegion().empty())
-    *thenFn = outline(ifOp.getThenRegion(), thenFnName);
-  if (elseFn && !ifOp.getElseRegion().empty())
-    *elseFn = outline(ifOp.getElseRegion(), elseFnName);
+    bvm.map(originalTerminator->getOperands(), call->getResults());
+    rewriter.clone(*originalTerminator, bvm);
+    rewriter.eraseOp(originalTerminator);
+  }
+
+  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
+  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
+  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
+                                               outlinedValues.size()))) {
+    Value orig = std::get<0>(it);
+    Value repl = std::get<1>(it);
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointToStart(outlinedFuncBody);
+      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>())
+        repl = rewriter.clone(*cst)->getResult(0);
+    }
+    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
+      return outlinedFunc->isProperAncestor(opOperand.getOwner());
+    });
+  }
+
+  return outlinedFunc;
+}
+
+LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+                                StringRef thenFnName, FuncOp *elseFn,
+                                StringRef elseFnName) {
+  // Outline through `b` itself rather than a copy of it, so that any listener
+  // attached to the caller's rewriter observes the IR changes.
+  Location loc = ifOp.getLoc();
+  FailureOr<FuncOp> outlinedFuncOpOrFailure;
+  if (thenFn && !ifOp.getThenRegion().empty()) {
+    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+        b, loc, ifOp.getThenRegion(), thenFnName);
+    if (failed(outlinedFuncOpOrFailure))
+      return failure();
+    *thenFn = *outlinedFuncOpOrFailure;
+  }
+  if (elseFn && !ifOp.getElseRegion().empty()) {
+    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+        b, loc, ifOp.getElseRegion(), elseFnName);
+    if (failed(outlinedFuncOpOrFailure))
+      return failure();
+    *elseFn = *outlinedFuncOpOrFailure;
+  }
+  return success();
 }
 
 bool mlir::getInnermostParallelLoops(Operation *rootOp,
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -59,21 +59,26 @@
 };
 
 class TestSCFIfUtilsPass
-    : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
+    : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
 public:
   StringRef getArgument() const final { return "test-scf-if-utils"; }
   StringRef getDescription() const final { return "test scf.if utils"; }
   explicit TestSCFIfUtilsPass() = default;
 
-  void runOnFunction() override {
+  void runOnOperation() override {
     int count = 0;
-    FuncOp func = getFunction();
-    func.walk([&](scf::IfOp ifOp) {
+    getOperation().walk([&](scf::IfOp ifOp) {
       auto strCount = std::to_string(count++);
       FuncOp thenFn, elseFn;
       OpBuilder b(ifOp);
-      outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
-                  &elseFn, std::string("outlined_else") + strCount);
+      IRRewriter rewriter(b);
+      if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
+                             std::string("outlined_then") + strCount, &elseFn,
+                             std::string("outlined_else") + strCount))) {
+        this->signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
     });
   }
 };