diff --git a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h --- a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h +++ b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h @@ -51,12 +51,19 @@ /// /// Users must supply a callback `shouldMoveIntoRegion` that determines whether /// the given operation that only has users in the given operation should be -/// moved into that region. +/// moved into that region. If this returns true, `moveIntoRegion` is called on +/// the same operation and region. +/// +/// `moveIntoRegion` must move the operation into the region such that dominance +/// of the operation is preserved; for example, by moving the operation to the +/// start of the entry block. This ensures the preservation of SSA dominance of +/// the operation's results. /// /// Returns the number of operations sunk. size_t controlFlowSink(ArrayRef regions, DominanceInfo &domInfo, - function_ref shouldMoveIntoRegion); + function_ref shouldMoveIntoRegion, + function_ref moveIntoRegion); /// Populates `regions` with regions of the provided region branch op that are /// executed at most once at that are reachable given the current operands of diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp --- a/mlir/lib/Transforms/ControlFlowSink.cpp +++ b/mlir/lib/Transforms/ControlFlowSink.cpp @@ -60,9 +60,14 @@ // Get the regions are that known to be executed at most once. getSinglyExecutedRegionsToSink(branch, regionsToSink); // Sink side-effect free operations. - numSunk = - controlFlowSink(regionsToSink, domInfo, [](Operation *op, Region *) { - return isSideEffectFree(op); + numSunk = controlFlowSink( + regionsToSink, domInfo, + [](Operation *op, Region *) { return isSideEffectFree(op); }, + [](Operation *op, Region *region) { + // Move the operation to the beginning of the region's entry block. 
+ // This guarantees the preservation of SSA dominance if all of the + // operation's uses are in the region. + op->moveBefore(®ion->front(), region->front().begin()); }); }); } diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp --- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp +++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp @@ -34,8 +34,10 @@ public: /// Create an operation sinker with given dominance info. Sinker(function_ref shouldMoveIntoRegion, + function_ref moveIntoRegion, DominanceInfo &domInfo) - : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo) {} + : shouldMoveIntoRegion(shouldMoveIntoRegion), + moveIntoRegion(moveIntoRegion), domInfo(domInfo), numSunk(0) {} /// Given a list of regions, find operations to sink and sink them. Return the /// number of operations sunk. @@ -61,6 +63,8 @@ /// The callback to determine whether an op should be moved in to a region. function_ref shouldMoveIntoRegion; + /// The callback to move an operation into the region. + function_ref moveIntoRegion; /// Dominance info to determine op user dominance with respect to regions. DominanceInfo &domInfo; /// The number of operations sunk. @@ -90,12 +94,7 @@ // If the op's users are all in the region and it can be moved, then do so. if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) { - // Move the op into the region's entry block. If the op is part of a - // subgraph, dependee ops would have been moved first, so inserting before - // the start of the block will ensure SSA dominance is preserved locally - // in the subgraph. Ops can only be safely moved into the entry block as - // the region's other blocks may for a loop. - op->moveBefore(®ion->front(), region->front().begin()); + moveIntoRegion(op, region); ++numSunk; // Add the op to the work queue. 
stack.push_back(op); @@ -127,8 +126,10 @@ size_t mlir::controlFlowSink( ArrayRef regions, DominanceInfo &domInfo, - function_ref shouldMoveIntoRegion) { - return Sinker(shouldMoveIntoRegion, domInfo).sinkRegions(regions); + function_ref shouldMoveIntoRegion, + function_ref moveIntoRegion) { + return Sinker(shouldMoveIntoRegion, moveIntoRegion, domInfo) + .sinkRegions(regions); } void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch, diff --git a/mlir/test/Transforms/control-flow-sink-test.mlir b/mlir/test/Transforms/control-flow-sink-test.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/control-flow-sink-test.mlir @@ -0,0 +1,44 @@ +// Invoke the test control-flow sink pass to test the utilities. +// RUN: mlir-opt -test-control-flow-sink %s | FileCheck %s + +// CHECK-LABEL: func @test_sink +func @test_sink() { + %0 = "test.sink_me"() : () -> i32 + // CHECK-NEXT: test.sink_target + "test.sink_target"() ({ + // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {was_sunk = 0 : i32} + // CHECK-NEXT: "test.use"(%[[V0]]) + "test.use"(%0) : (i32) -> () + }) : () -> () + return +} + +// CHECK-LABEL: func @test_sink_first_region_only +func @test_sink_first_region_only() { + %0 = "test.sink_me"() {first} : () -> i32 + // CHECK-NEXT: %[[V1:.*]] = "test.sink_me"() {second} + %1 = "test.sink_me"() {second} : () -> i32 + // CHECK-NEXT: test.sink_target + "test.sink_target"() ({ + // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {first, was_sunk = 0 : i32} + // CHECK-NEXT: "test.use"(%[[V0]]) + "test.use"(%0) : (i32) -> () + }, { + "test.use"(%1) : (i32) -> () + }) : () -> () + return +} + +// CHECK-LABEL: func @test_sink_targeted_op_only +func @test_sink_targeted_op_only() { + %0 = "test.sink_me"() : () -> i32 + // CHECK-NEXT: %[[V1:.*]] = "test.dont_sink_me" + %1 = "test.dont_sink_me"() : () -> i32 + // CHECK-NEXT: test.sink_target + "test.sink_target"() ({ + // CHECK-NEXT: %[[V0:.*]] = "test.sink_me" + // CHECK-NEXT: "test.use"(%[[V0]], %[[V1]]) + 
"test.use"(%0, %1) : (i32, i32) -> () + }) : () -> () + return +} diff --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir --- a/mlir/test/Transforms/control-flow-sink.mlir +++ b/mlir/test/Transforms/control-flow-sink.mlir @@ -1,3 +1,4 @@ +// Test the default control-flow sink pass. // RUN: mlir-opt -control-flow-sink %s | FileCheck %s // Test that operations can be sunk. diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestConstantFold.cpp + TestControlFlowSink.cpp TestInlining.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestControlFlowSink.cpp b/mlir/test/lib/Transforms/TestControlFlowSink.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestControlFlowSink.cpp @@ -0,0 +1,65 @@ +//===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the control-flow sink utilities by implementing an example +// control-flow sink pass. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/ControlFlowSinkUtils.h" + +using namespace mlir; + +namespace { +/// An example control-flow sink pass to test the control-flow sink utilities. +/// This pass sinks ops named `test.sink_me` into the first region of +/// `test.sink_target` ops and tags them with a `was_sunk` attribute. 
+struct TestControlFlowSinkPass + : public PassWrapper> { + /// Get the command-line argument of the test pass. + StringRef getArgument() const final { return "test-control-flow-sink"; } + /// Get the description of the test pass. + StringRef getDescription() const final { + return "Test control-flow sink pass"; + } + + /// Runs the pass: sink `test.sink_me` ops into nested `test.sink_target` ops. + void runOnOperation() override { + auto &domInfo = getAnalysis(); + auto shouldMoveIntoRegion = [](Operation *op, Region *region) { + return region->getRegionNumber() == 0 && + op->getName().getStringRef() == "test.sink_me"; + }; + auto moveIntoRegion = [](Operation *op, Region *region) { + Block &entry = region->front(); + op->moveBefore(&entry, entry.begin()); + op->setAttr("was_sunk", + Builder(op).getI32IntegerAttr(region->getRegionNumber())); + }; + + getOperation()->walk([&](Operation *op) { + if (op->getName().getStringRef() != "test.sink_target") + return; + SmallVector regions = + llvm::to_vector(RegionRange(op->getRegions())); + controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion); + }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestControlFlowSink() { + PassRegistration(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -65,6 +65,7 @@ void registerTestBuiltinAttributeInterfaces(); void registerTestCallGraphPass(); void registerTestConstantFold(); +void registerTestControlFlowSink(); void registerTestGpuSerializeToCubinPass(); void registerTestGpuSerializeToHsacoPass(); void registerTestDataLayoutQuery(); @@ -151,6 +152,7 @@ mlir::test::registerTestBuiltinAttributeInterfaces(); mlir::test::registerTestCallGraphPass(); mlir::test::registerTestConstantFold(); + mlir::test::registerTestControlFlowSink(); mlir::test::registerTestDiagnosticsPass(); #if 
MLIR_CUDA_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToCubinPass();