diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -51,6 +51,24 @@
 void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                SetVector<Value> &values);
 
+/// Make a region isolated from above.
+/// - Capture the values that are defined above the region and used within it.
+/// - Append to the entry block arguments that represent the captured values
+///   (one per captured value).
+/// - Replace all uses within the region of the captured values with the
+///   newly added arguments.
+/// - `cloneOperationIntoRegion` is a callback that allows the caller to
+///   specify if the operation defining an `OpOperand` needs to be cloned into
+///   the region. Then the operands of this operation become part of the
+///   captured values set (unless the operations that define the operands
+///   themselves are to be cloned). The cloned operations are added to the
+///   entry block of the region.
+/// Returns the captured values, in the order of the appended block arguments.
+SmallVector<Value> makeRegionIsolatedFromAbove(
+    RewriterBase &rewriter, Region &region,
+    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
+        [](Operation *) { return false; });
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE.
This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -8,17 +8,21 @@
 
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallSet.h"
 
+#include <deque>
+
 using namespace mlir;
 
 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
@@ -69,6 +73,121 @@
   getUsedValuesDefinedAbove(region, region, values);
 }
 
+//===----------------------------------------------------------------------===//
+// Make block isolated from above.
+//===----------------------------------------------------------------------===//
+
+SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
+    RewriterBase &rewriter, Region &region,
+    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
+  // An empty region has no entry block to extend and nothing to capture.
+  if (region.empty())
+    return {};
+
+  // Get initial list of values used within region but defined above.
+  llvm::SetVector<Value> initialCapturedValues;
+  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
+
+  std::deque<Value> worklist(initialCapturedValues.begin(),
+                             initialCapturedValues.end());
+  llvm::DenseSet<Value> visited;
+  // Whether each defining op reached so far is cloned into the region. The
+  // decision is made once per op so that all results of a cloned multi-result
+  // op are remapped to the clone instead of some of them being captured.
+  llvm::DenseMap<Operation *, bool> cloneDecisions;
+
+  llvm::SetVector<Value> finalCapturedValues;
+  SmallVector<Operation *> clonedOperations;
+  while (!worklist.empty()) {
+    Value currValue = worklist.front();
+    worklist.pop_front();
+    if (!visited.insert(currValue).second)
+      continue;
+
+    // Block arguments have no defining op to clone; always capture them.
+    Operation *definingOp = currValue.getDefiningOp();
+    if (!definingOp) {
+      finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    // The op was already reached through another of its results; follow the
+    // decision taken then.
+    auto decisionIt = cloneDecisions.find(definingOp);
+    if (decisionIt != cloneDecisions.end()) {
+      if (!decisionIt->second)
+        finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    bool cloneOp = cloneOperationIntoRegion(definingOp);
+    cloneDecisions[definingOp] = cloneOp;
+    if (!cloneOp) {
+      // Defining operation isn't cloned, so add the current value to the final
+      // captured values list.
+      finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    // Add all operands of the operation to the worklist and mark the op as to
+    // be cloned.
+    for (Value operand : definingOp->getOperands()) {
+      if (!visited.contains(operand))
+        worklist.push_back(operand);
+    }
+    clonedOperations.push_back(definingOp);
+  }
+
+  // Clone in topological order so that each clone is created after the
+  // clones of the operations defining its operands.
+  mlir::computeTopologicalSorting(clonedOperations);
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Collect the argument types and locations of the existing entry block.
+  Block *entryBlock = &region.front();
+  SmallVector<Type> newArgTypes =
+      llvm::to_vector(entryBlock->getArgumentTypes());
+  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
+      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
+
+  // Append the types of the captured values.
+  for (Value value : finalCapturedValues) {
+    newArgTypes.push_back(value.getType());
+    newArgLocs.push_back(value.getLoc());
+  }
+
+  // Create a new entry block.
+  Block *newEntryBlock =
+      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
+  auto newEntryBlockArgs = newEntryBlock->getArguments();
+
+  // Rewrite uses anywhere inside `region`, including in nested regions;
+  // matching only ops whose block is directly in `region` would leave nested
+  // uses referring to the values defined above.
+  auto replaceIfFn = [&](OpOperand &use) {
+    Region *useRegion = use.getOwner()->getParentRegion();
+    return useRegion && region.isAncestor(useRegion);
+  };
+
+  // Create a mapping between the captured values and the new arguments added.
+  IRMapping map;
+  for (auto [arg, capturedVal] :
+       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
+                 finalCapturedValues)) {
+    map.map(capturedVal, arg);
+    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
+  }
+  rewriter.setInsertionPointToStart(newEntryBlock);
+  for (Operation *clonedOp : clonedOperations) {
+    Operation *newOp = rewriter.clone(*clonedOp, map);
+    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+  }
+  rewriter.mergeBlocks(
+      entryBlock, newEntryBlock,
+      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
+  return llvm::to_vector(finalCapturedValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Unreachable Block Elimination
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt -test-make-isolated-from-above=simple -allow-unregistered-dialect --split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-make-isolated-from-above=clone-ops-with-no-operands -allow-unregistered-dialect --split-input-file %s | FileCheck %s --check-prefix=CLONE1
+// RUN: mlir-opt -test-make-isolated-from-above=clone-ops-with-operands -allow-unregistered-dialect --split-input-file %s | FileCheck %s --check-prefix=CLONE2
+
+func.func @make_isolated_from_above_single_block(%arg0 : index, %arg1 : index) {
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1 : index
+  %empty = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+  %d0 = tensor.dim %empty, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %empty, %c1 : tensor<?x?xf32>
+  "test.one_region_with_operands_op"() ({
+    "foo.yield"(%c0, %c1, %d0, %d1) : (index, index, index, index) -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_single_block(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CHECK: test.isolated_one_region_op %[[C0]], %[[C1]], %[[D0]], %[[D1]]
+// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index)
+// CHECK: "foo.yield"(%[[B0]], %[[B1]], %[[B2]], %[[B3]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_single_block(
+// CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE1-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CLONE1-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CLONE1-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CLONE1-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CLONE1-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CLONE1-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CLONE1: test.isolated_one_region_op %[[D0]], %[[D1]]
+// CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+// CLONE1-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+// CLONE1-DAG: %[[C1_0:.+]] = arith.constant 1 : index
+// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B0]], %[[B1]])
+
+// CLONE2-LABEL: func @make_isolated_from_above_single_block(
+// CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE2-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CLONE2: test.isolated_one_region_op %[[ARG0]], %[[ARG1]]
+// CLONE2-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+// CLONE2-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CLONE2-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CLONE2-DAG: %[[EMPTY:.+]] = tensor.empty(%[[B0]], %[[B1]])
+// CLONE2-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CLONE2-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]])
+
+// -----
+
+func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1 : index
+  %empty = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+  %d0 = tensor.dim %empty, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %empty, %c1 : tensor<?x?xf32>
+  "test.one_region_with_operands_op"(%arg2) ({
+    ^bb0(%b0 : index):
+      cf.br ^bb1(%b0: index)
+    ^bb1(%b1 : index):
+      "foo.yield"(%c0, %c1, %d0, %d1, %b1) : (index, index, index, index, index) -> ()
+  }) : (index) -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_multiple_blocks(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CHECK: test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
+// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
+// CHECK-NEXT: cf.br ^bb1(%[[B0]] : index)
+// CHECK: ^bb1(%[[B5:.+]]: index)
+// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
+// CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE1-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CLONE1-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CLONE1-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CLONE1-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CLONE1-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CLONE1-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CLONE1-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CLONE1: test.isolated_one_region_op %[[ARG2]], %[[D0]], %[[D1]]
+// CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
+// CLONE1-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+// CLONE1-DAG: %[[C1_0:.+]] = arith.constant 1 : index
+// CLONE1-NEXT: cf.br ^bb1(%[[B0]] : index)
+// CLONE1: ^bb1(%[[B3:.+]]: index)
+// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+
+// CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
+// CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE2-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CLONE2-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CLONE2: test.isolated_one_region_op %[[ARG2]], %[[ARG0]], %[[ARG1]]
+// CLONE2-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
+// CLONE2-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CLONE2-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CLONE2-DAG: %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
+// CLONE2-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+// CLONE2-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+// CLONE2-NEXT: cf.br ^bb1(%[[B0]] : index)
+// CLONE2: ^bb1(%[[B3:.+]]: index)
+// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])
+
+// -----
+
+func.func @make_isolated_from_above_nested_region(%arg0 : index) {
+  %c0 = arith.constant 0 : index
+  "test.one_region_with_operands_op"() ({
+    "foo.op"() ({
+      "foo.yield"(%arg0, %c0) : (index, index) -> ()
+    }) : () -> ()
+    "foo.yield"() : () -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_nested_region(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: test.isolated_one_region_op %[[ARG0]], %[[C0]]
+// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+// CHECK-NEXT: "foo.op"()
+// CHECK-NEXT: "foo.yield"(%[[B0]], %[[B1]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_nested_region(
+// CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE1: test.isolated_one_region_op %[[ARG0]]
+// CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index)
+// CLONE1-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CLONE1-NEXT: "foo.op"()
+// CLONE1-NEXT: "foo.yield"(%[[B0]], %[[C0]])
+
+// CLONE2-LABEL: func @make_isolated_from_above_nested_region(
+// CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CLONE2: test.isolated_one_region_op %[[ARG0]]
+// CLONE2-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index)
+// CLONE2-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CLONE2-NEXT: "foo.op"()
+// CLONE2-NEXT: "foo.yield"(%[[B0]], %[[C0]])
+
+// -----
+
+func.func @make_isolated_from_above_multi_result_op() {
+  %0:2 = "foo.multi_result"() : () -> (index, index)
+  "test.one_region_with_operands_op"() ({
+    "foo.yield"(%0#0, %0#1) : (index, index) -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_multi_result_op(
+// CHECK: %[[R:.+]]:2 = "foo.multi_result"()
+// CHECK: test.isolated_one_region_op %[[R]]#0, %[[R]]#1
+// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+// CHECK-NEXT: "foo.yield"(%[[B0]], %[[B1]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_multi_result_op(
+// CLONE1: test.isolated_one_region_op {
+// CLONE1-NEXT: %[[R:.+]]:2 = "foo.multi_result"()
+// CLONE1-NEXT: "foo.yield"(%[[R]]#0, %[[R]]#1)
+
+// CLONE2-LABEL: func @make_isolated_from_above_multi_result_op(
+// CLONE2: test.isolated_one_region_op {
+// CLONE2-NEXT: %[[R:.+]]:2 = "foo.multi_result"()
+// CLONE2-NEXT: "foo.yield"(%[[R]]#0, %[[R]]#1)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -438,6 +438,19 @@
   }];
 }
 
+def OneRegionWithOperandsOp : TEST_Op<"one_region_with_operands_op", []> {
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let regions = (region AnyRegion);
+}
+
+def IsolatedOneRegionOp : TEST_Op<"isolated_one_region_op", [IsolatedFromAbove]> {
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let regions = (region AnyRegion:$my_region);
+  let assemblyFormat = [{
+    attr-dict-with-keyword $operands
$my_region `:` type($operands)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NoTerminator Operation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@
   TestDialectConversion.cpp
   TestInlining.cpp
   TestIntRangeInference.cpp
+  TestMakeIsolatedFromAbove.cpp
   TestTopologicalSort.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -0,0 +1,156 @@
+//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+
+/// Helper that calls `makeRegionIsolatedFromAbove` to convert a
+/// `test.one_region_with_operands_op` into a `test.isolated_one_region_op`.
+static LogicalResult
+makeIsolatedFromAboveImpl(RewriterBase &rewriter,
+                          test::OneRegionWithOperandsOp regionOp,
+                          llvm::function_ref<bool(Operation *)> callBack) {
+  Region &region = regionOp.getRegion();
+  SmallVector<Value> capturedValues =
+      makeRegionIsolatedFromAbove(rewriter, region, callBack);
+  SmallVector<Value> operands = regionOp.getOperands();
+  operands.append(capturedValues);
+  auto isolatedRegionOp =
+      rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
+  rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
+                              isolatedRegionOp.getRegion().begin());
+  rewriter.eraseOp(regionOp);
+  return success();
+}
+
+namespace {
+
+/// Simple test for making region isolated from above without cloning any
+/// operations.
+struct SimpleMakeIsolatedFromAbove
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp,
+                                     [](Operation *) { return false; });
+  }
+};
+
+/// Test for making region isolated from above while cloning operations
+/// with no operands.
+struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
+      return op->getNumOperands() == 0;
+    });
+  }
+};
+
+/// Test for making region isolated from above while cloning all operations
+/// that define captured values, whether or not they have operands.
+struct MakeIsolatedFromAboveAndCloneOpsWithOperands
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp,
+                                     [](Operation *) { return true; });
+  }
+};
+
+/// Test pass for testing the `makeRegionIsolatedFromAbove` function.
+struct TestMakeIsolatedFromAbovePass
+    : public PassWrapper<TestMakeIsolatedFromAbovePass,
+                         OperationPass<func::FuncOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)
+
+  TestMakeIsolatedFromAbovePass() = default;
+  TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final {
+    return "test-make-isolated-from-above";
+  }
+
+  StringRef getDescription() const final {
+    return "Test making a region isolated from above";
+  }
+
+  Option<bool> simple{
+      *this, "simple",
+      llvm::cl::desc("Test simple case with no cloning of operations"),
+      llvm::cl::init(false)};
+
+  Option<bool> cloneOpsWithNoOperands{
+      *this, "clone-ops-with-no-operands",
+      llvm::cl::desc("Test case with cloning of operations with no operands"),
+      llvm::cl::init(false)};
+
+  Option<bool> cloneOpsWithOperands{
+      *this, "clone-ops-with-operands",
+      llvm::cl::desc("Test case with cloning of all operations"),
+      llvm::cl::init(false)};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void TestMakeIsolatedFromAbovePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  func::FuncOp funcOp = getOperation();
+
+  if (simple) {
+    RewritePatternSet patterns(context);
+    patterns.insert<SimpleMakeIsolatedFromAbove>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+    return;
+  }
+
+  if (cloneOpsWithNoOperands) {
+    RewritePatternSet patterns(context);
+    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+    return;
+  }
+
+  if (cloneOpsWithOperands) {
+    RewritePatternSet patterns(context);
+    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+    return;
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMakeIsolatedFromAbovePass() {
+  PassRegistration<TestMakeIsolatedFromAbovePass>();
+}
+} // namespace test
+} // namespace mlir
diff --git
a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -104,6 +104,7 @@
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
 void registerTestLowerToLLVM();
+void registerTestMakeIsolatedFromAbovePass();
 void registerTestMatchReductionPass();
 void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
@@ -218,6 +219,7 @@
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
   mlir::test::registerTestLowerToLLVM();
+  mlir::test::registerTestMakeIsolatedFromAbovePass();
   mlir::test::registerTestMatchReductionPass();
   mlir::test::registerTestMathAlgebraicSimplificationPass();
   mlir::test::registerTestMathPolynomialApproximationPass();