diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,11 +13,15 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_H_
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
+class FuncOp;
 class OpBuilder;
 class ValueRange;
 
 namespace scf {
+class IfOp;
 class ForOp;
 class ParallelOp;
 } // end namespace scf
@@ -46,5 +50,20 @@
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Outline the then and/or else regions of `ifOp` into new functions:
+/// - if `thenFn` is not null, `thenFnName` must be non-empty; the ops of the
+///   `then` region are cloned into a new FuncOp named `thenFnName`, the region
+///   body is replaced by a call to that function whose results are yielded,
+///   and the new function is returned through `thenFn`.
+/// - if `elseFn` is not null, the same is done for the `else` region using
+///   `elseFnName`, and the new function is returned through `elseFn`.
+/// Values used in a region but defined above it become the arguments of the
+/// outlined function. The new functions are inserted right before the FuncOp
+/// enclosing `ifOp`, which must exist; the names are used verbatim, so the
+/// caller must ensure they do not clash with existing symbols. Each region
+/// must have at most one block; a region without blocks is skipped and the
+/// corresponding output pointer is left unchanged.
+void outlineIfOp(scf::IfOp ifOp, FuncOp *thenFn, StringRef thenFnName,
+                 FuncOp *elseFn, StringRef elseFnName);
 } // end namespace mlir
 #endif // MLIR_DIALECT_SCF_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -17,4 +17,5 @@
   MLIRSCF
   MLIRStandardOps
   MLIRSupport
+  MLIRTransformUtils
   )
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -13,7 +13,12 @@
 #include "mlir/Dialect/SCF/Utils.h"
 
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
@@ -71,3 +76,60 @@
 
   return newLoop;
 }
+
+void mlir::outlineIfOp(scf::IfOp ifOp, FuncOp *thenFn, StringRef thenFnName,
+                       FuncOp *elseFn, StringRef elseFnName) {
+  Location loc = ifOp.getLoc();
+  MLIRContext *ctx = ifOp.getContext();
+  // Outlines the single-block `ifOrElseRegion` into a new function named
+  // `funcName` and replaces the region body with a call to that function.
+  auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
+    assert(!funcName.empty() && "Expected function name for outlining");
+    assert(ifOrElseRegion.getBlocks().size() <= 1 &&
+           "Expected at most one block");
+
+    // Outline before current function.
+    auto parentFunc = ifOp.getParentOfType<FuncOp>();
+    assert(parentFunc && "Expected scf.if to be nested in a FuncOp");
+    OpBuilder b(parentFunc);
+
+    // Values used in the region but defined above it become the arguments of
+    // the outlined function.
+    llvm::SetVector<Value> captures;
+    getUsedValuesDefinedAbove(ifOrElseRegion, captures);
+
+    ValueRange values(captures.getArrayRef());
+    // TODO: use TypeRange when it works.
+    SmallVector<Type, 4> types;
+    types.assign(values.getTypes().begin(), values.getTypes().end());
+    FunctionType type = FunctionType::get(types, ifOp.getResultTypes(), ctx);
+    auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
+    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+
+    // Clone the body, then turn the scf.yield into a std.return of the
+    // remapped yielded values.
+    Block &body = ifOrElseRegion.front();
+    for (Operation &op : body.without_terminator())
+      b.clone(op, bvm);
+    SmallVector<Value, 4> results;
+    for (Value yielded : body.getTerminator()->getOperands())
+      results.push_back(bvm.lookup(yielded));
+    b.create<ReturnOp>(loc, results);
+
+    // Replace the original body with a call to the outlined function whose
+    // results are yielded back to the scf.if.
+    body.clear();
+    b.setInsertionPointToEnd(&body);
+    Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
+    b.create<scf::YieldOp>(loc, call->getResults());
+    return outlinedFunc;
+  };
+
+  if (thenFn && !ifOp.thenRegion().empty())
+    *thenFn = outline(ifOp.thenRegion(), thenFnName);
+  if (elseFn && !ifOp.elseRegion().empty())
+    *elseFn = outline(ifOp.elseRegion(), elseFnName);
+}
diff --git a/mlir/test/Transforms/scf-if-utils.mlir b/mlir/test/Transforms/scf-if-utils.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/scf-if-utils.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
+// CHECK-NEXT:   %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT:   return %{{.*}} : i8
+// CHECK-NEXT: }
+// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
+// CHECK-NEXT:   return %{{.*}} : i8
+// CHECK-NEXT: }
+// CHECK: func @outline_if_else(
+// CHECK-NEXT:   %{{.*}} = scf.if %{{.*}} -> (i8) {
+// CHECK-NEXT:     %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT:     scf.yield %{{.*}} : i8
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
+// CHECK-NEXT:     scf.yield %{{.*}} : i8
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  %r = scf.if %cond -> (i8) {
+    %r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
+    scf.yield %r : i8
+  } else {
+    scf.yield %c : i8
+  }
+  return
+}
+
+// -----
+
+// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
+// CHECK-NEXT:   "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+// CHECK: func @outline_if(
+// CHECK-NEXT:   scf.if %{{.*}} {
+// CHECK-NEXT:     call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK: func @outlined_then0() {
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
+// CHECK-NEXT:   "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+// CHECK: func @outline_empty_if_else(
+// CHECK-NEXT:   scf.if %{{.*}} {
+// CHECK-NEXT:     call @outlined_then0() : () -> ()
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+  } else {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+  }
+  return
+}
diff --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir
rename from mlir/test/Transforms/loop-utils.mlir
rename to mlir/test/Transforms/scf-loop-utils.mlir
--- a/mlir/test/Transforms/loop-utils.mlir
+++ b/mlir/test/Transforms/scf-loop-utils.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
 
 // CHECK-LABEL: @hoist
 // CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
diff --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
--- a/mlir/test/lib/Transforms/TestSCFUtils.cpp
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -21,9 +21,10 @@
 using namespace mlir;
 
 namespace {
-class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+class TestSCFForUtilsPass
+    : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
 public:
-  explicit TestSCFUtilsPass() {}
+  explicit TestSCFForUtilsPass() {}
 
   void runOnFunction() override {
     FuncOp func = getFunction();
@@ -49,10 +50,31 @@
       loop.erase();
   }
 };
+
+// Outlining inserts new FuncOps into the parent module, which a function
+// pass must not do, so this test pass operates on the whole module.
+class TestSCFIfUtilsPass
+    : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
+public:
+  explicit TestSCFIfUtilsPass() {}
+
+  void runOnOperation() override {
+    int count = 0;
+    getOperation().walk([&](scf::IfOp ifOp) {
+      auto strCount = std::to_string(count++);
+      FuncOp thenFn, elseFn;
+      outlineIfOp(ifOp, &thenFn, std::string("outlined_then") + strCount,
+                  &elseFn, std::string("outlined_else") + strCount);
+    });
+  }
+};
 } // end namespace
 
 namespace mlir {
 void registerTestSCFUtilsPass() {
-  PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
+  PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
+                                        "test scf.for utils");
+  PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
+                                       "test scf.if utils");
 }
 } // namespace mlir