diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -83,6 +83,35 @@ ValueRange inputs; }; +/// This class represents upper and lower bounds on the number of times a region +/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least +/// zero, but the upper bound may not be known. +class InvocationBounds { +public: + /// Create invocation bounds. The lower bound must be at least 0 and only the + /// upper bound can be unknown. + InvocationBounds(unsigned lb, Optional<unsigned> ub) : lower(lb), upper(ub) { + assert((!ub || *ub >= lb) && "upper bound cannot be less than lower bound"); + } + + /// Return the lower bound. + unsigned getLowerBound() const { return lower; } + + /// Return the upper bound, or `None` if no upper bound is known. + Optional<unsigned> getUpperBound() const { return upper; } + + /// Returns the unknown invocation bounds, i.e., there is no information on + /// how many times a region may be invoked. + static InvocationBounds getUnknown() { return {0, llvm::None}; } + +private: + /// The minimum number of times the successor region will be invoked. + unsigned lower; + /// The maximum number of times the successor region will be invoked or `None` + /// if an upper bound is not known. + Optional<unsigned> upper; +}; + /// Return `true` if `a` and `b` are in mutually exclusive regions as per /// RegionBranchOpInterface. 
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -102,9 +102,10 @@ let methods = [ InterfaceMethod<[{ Returns the operands of this operation used as the entry arguments when - entering the region at `index`, which was specified as a successor of this - operation by `getSuccessorRegions`. These operands should correspond 1-1 - with the successor inputs specified in `getSuccessorRegions`. + entering the region at `index`, which was specified as a successor of + this operation by `getSuccessorRegions`. These operands should + correspond 1-1 with the successor inputs specified in + `getSuccessorRegions`. }], "::mlir::OperandRange", "getSuccessorEntryOperands", (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{ @@ -127,9 +128,28 @@ successor region must be non-empty. }], "void", "getSuccessorRegions", - (ins "::mlir::Optional<unsigned>":$index, "::mlir::ArrayRef<::mlir::Attribute>":$operands, + (ins "::mlir::Optional<unsigned>":$index, + "::mlir::ArrayRef<::mlir::Attribute>":$operands, "::mlir::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions) - > + >, + InterfaceMethod<[{ + Populates `invocationBounds` with the minimum and maximum number of + times this operation will invoke the attached regions (assuming the + regions yield normally, i.e. do not abort or invoke an infinite loop). + The minimum number of invocations is at least 0. If the maximum number + of invocations cannot be statically determined, then it will not have a + value (i.e., it is set to `llvm::None`). + + `operands` is a set of optional attributes that either correspond to a + constant value for each operand of this operation, or null if that + operand is not a constant.
+ }], + "void", "getRegionInvocationBounds", + (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands, + "::llvm::SmallVectorImpl<::mlir::InvocationBounds> &" + :$invocationBounds), [{}], + [{ invocationBounds.append($_op->getNumRegions(), {0, ::llvm::None}); }] + >, ]; let verify = [{ diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -74,6 +74,9 @@ ArrayRef<std::string> disabledPatterns = llvm::None, ArrayRef<std::string> enabledPatterns = llvm::None); +/// Creates a pass to perform control-flow sinking. +std::unique_ptr<Pass> createControlFlowSinkPass(); + /// Creates a pass to perform common sub expression elimination. std::unique_ptr<Pass> createCSEPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -307,6 +307,28 @@ ] # RewritePassUtils.options; } +def ControlFlowSink : Pass<"control-flow-sink"> { + let summary = "Sink operations into conditional blocks"; + let description = [{ + This pass implements a simple control-flow sink on operations that implement + `RegionBranchOpInterface` by moving dominating operations whose only uses + are in a single conditionally-executed region into that region so that + execution paths where their results are not needed do not perform + unnecessary computations. + + This is similar (but opposite) to loop-invariant code motion, which hoists + operations out of regions executed more than once.
+ + It is recommended to run canonicalization first to remove unreachable + blocks: ops in unreachable blocks may prevent other operations from being + sunk as they may contain uses of their results. + }]; + let constructor = "::mlir::createControlFlowSinkPass()"; + let statistics = [ + Statistic<"numSunk", "num-sunk", "Number of operations sunk">, + ]; +} + def CSE : Pass<"cse"> { let summary = "Eliminate common sub-expressions"; let description = [{ diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -25,6 +25,7 @@ class AffineApplyOp; class AffineForOp; +class DominanceInfo; class Location; class OpBuilder; @@ -147,6 +148,53 @@ void createAffineComputationSlice(Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps); +/// Given a list of regions, perform control flow sinking on them. For each +/// region, control-flow sinking moves operations that dominate the region but +/// whose only users are in the region into the region so that they aren't +/// executed on paths where their results are not needed. +/// +/// TODO: For the moment, this is a *simple* control-flow sink, i.e., no +/// duplicating of ops. It should be made to accept a cost model to determine +/// whether duplicating a particular op is profitable. +/// +/// Example: +/// +/// ```mlir +/// %0 = arith.addi %arg0, %arg1 : i32 +/// %1 = scf.if %cond -> (i32) { +/// scf.yield %0 : i32 +/// } else { +/// scf.yield %arg2 : i32 +/// } +/// ``` +/// +/// After control-flow sink: +/// +/// ```mlir +/// %1 = scf.if %cond -> (i32) { +/// %0 = arith.addi %arg0, %arg1 : i32 +/// scf.yield %0 : i32 +/// } else { +/// scf.yield %arg2 : i32 +/// } +/// ``` +/// +/// Users must supply a callback `shouldMoveIntoRegion` that determines whether +/// the given operation that only has users in the given region should be +/// moved into that region. +/// +/// Returns the number of operations sunk.
+size_t +controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo, + function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion); + +/// Populates `regions` with regions of the provided region branch op that are +/// executed at most once, as reported by `getRegionInvocationBounds` given the +/// constant operands of the op. These regions can be passed to `controlFlowSink` +/// to perform sinking on the regions of the operation. +void getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch, + SmallVectorImpl<Region *> &regions); + } // namespace mlir #endif // MLIR_TRANSFORMS_UTILS_H diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ BufferResultsToOutParams.cpp BufferUtils.cpp Canonicalizer.cpp + ControlFlowSink.cpp CSE.cpp Inliner.cpp LocationSnapshot.cpp diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/ControlFlowSink.cpp @@ -0,0 +1,71 @@ +//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a basic control-flow sink pass. Control-flow sinking +// moves operations whose only uses are in conditionally-executed blocks into +// those blocks so that they aren't executed on paths where their results are +// not needed.
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" + +using namespace mlir; + +namespace { +/// A basic control-flow sink pass. This pass analyzes the regions of operations +/// that implement `RegionBranchOpInterface` that are reachable and executed at +/// most once and sinks candidate operations that are side-effect free. +struct ControlFlowSink : public ControlFlowSinkBase<ControlFlowSink> { + void runOnOperation() override; +}; +} // end anonymous namespace + +/// Returns true if the given operation is side-effect free as are all of its +/// nested operations. +static bool isSideEffectFree(Operation *op) { + if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) { + // If the op has side-effects, it cannot be moved. + if (!memInterface.hasNoEffect()) + return false; + // If the op does not have recursive side effects, then it can be moved. + if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) + return true; + } else if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) { + // Otherwise, if the op does not implement the memory effect interface and + // it does not have recursive side effects, then it cannot be known that the + // op is moveable. + return false; + } + + // Recurse into the regions and ensure that all nested ops can also be moved.
+ for (Region &region : op->getRegions()) + for (Operation &nestedOp : region.getOps()) + if (!isSideEffectFree(&nestedOp)) + return false; + return true; +} + +void ControlFlowSink::runOnOperation() { + auto &domInfo = getAnalysis<DominanceInfo>(); + getOperation()->walk([&](RegionBranchOpInterface branch) { + SmallVector<Region *> regionsToSink; + getSinglyExecutedRegionsToSink(branch, regionsToSink); + numSunk += mlir::controlFlowSink( + regionsToSink, domInfo, + [](Operation *op, Region *) { return isSideEffectFree(op); }); + }); +} + +std::unique_ptr<Pass> mlir::createControlFlowSinkPass() { + return std::make_unique<ControlFlowSink>(); +} diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp GreedyPatternRewriteDriver.cpp diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp @@ -0,0 +1,152 @@ +//===- ControlFlowSinkUtils.cpp - Code to perform control-flow sinking ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for control-flow sinking. Control-flow +// sinking moves operations whose only uses are in conditionally-executed blocks +// into those blocks so that they aren't executed on paths where their results +// are not needed. +// +// Control-flow sinking is not implemented on BranchOpInterface because +// sinking ops into the successors of branch operations may move ops into loops.
+// It is idiomatic MLIR to perform optimizations at IR levels that readily +// provide the necessary information. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Transforms/Utils.h" +#include <vector> + +#define DEBUG_TYPE "cf-sink" + +using namespace mlir; + +namespace { +/// A helper class for control-flow sinking. +class Sinker { +public: + /// Create an operation sinker with the given move predicate and dominance info. + Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion, + DominanceInfo &domInfo) + : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo), + numSunk(0) {} + + /// Given a list of regions, find operations to sink and sink them. Return the + /// number of operations sunk. + size_t sinkRegions(ArrayRef<Region *> regions) &&; + +private: + /// Given a region and an op which dominates the region, returns true if all + /// users of the given op are dominated by the entry block of the region, and + /// thus the operation can be sunk into the region. + bool allUsersDominatedBy(Operation *op, Region *region); + + /// Given a region and a top-level op (an op whose parent region is the given + /// region), determine whether the defining ops of the op's operands can be + /// sunk into the region. + /// + /// Add moved ops to the work queue. + void tryToSinkPredecessors(Operation *user, Region *region, + std::vector<Operation *> &stack); + + /// Iterate over all the ops in a region and try to sink their predecessors. + /// Recurse on subgraphs using a work queue. + void sinkRegion(Region *region); + + /// The callback to determine whether an op should be moved into a region. + function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion; + /// Dominance info to determine op user dominance with respect to regions. + DominanceInfo &domInfo; + /// The number of operations sunk.
+ size_t numSunk; +}; +} // end anonymous namespace + +bool Sinker::allUsersDominatedBy(Operation *op, Region *region) { + assert(region->findAncestorOpInRegion(*op) == nullptr && + "expected op to be defined outside the region"); + return llvm::all_of(op->getUsers(), [&](Operation *user) { + // The user is dominated by the region if its containing block is dominated + // by the region's entry block. + return domInfo.dominates(&region->front(), user->getBlock()); + }); +} + +void Sinker::tryToSinkPredecessors(Operation *user, Region *region, + std::vector<Operation *> &stack) { + LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n")); + for (Value value : user->getOperands()) { + Operation *op = value.getDefiningOp(); + // Ignore block arguments and ops that are already inside the region. + if (!op || op->getParentRegion() == region) + continue; + LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n")); + + // If the op's users are all in the region and it can be moved, then do so. + if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) { + // Move the op into the region's entry block. All of its users are already + // inside the region, and any of its operands' defining ops that get sunk + // later are moved in front of it, so inserting at the start of the block + // preserves dominance. Ops can only be safely moved into the entry block as + // the region's other blocks may form a loop. + op->moveBefore(&region->front(), region->front().begin()); + ++numSunk; + // Add the op to the work queue. + stack.push_back(op); + } + } +} + +void Sinker::sinkRegion(Region *region) { + // Initialize the work queue with all the ops in the region. + std::vector<Operation *> stack; + for (Operation &op : region->getOps()) + stack.push_back(&op); + + // Process all the ops depth-first. This ensures that nodes of subgraphs are + // sunk in the correct order.
+ while (!stack.empty()) { + Operation *op = stack.back(); + stack.pop_back(); + tryToSinkPredecessors(op, region, stack); + } +} + +size_t Sinker::sinkRegions(ArrayRef<Region *> regions) && { + for (Region *region : regions) + if (!region->empty()) + sinkRegion(region); + return numSunk; +} + +size_t mlir::controlFlowSink( + ArrayRef<Region *> regions, DominanceInfo &domInfo, + function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) { + return Sinker(shouldMoveIntoRegion, domInfo).sinkRegions(regions); +} + +void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch, + SmallVectorImpl<Region *> &regions) { + // Collect constant operands. + SmallVector<Attribute> operands(branch->getNumOperands(), Attribute()); + for (auto &it : llvm::enumerate(branch->getOperands())) + matchPattern(it.value(), m_Constant(&operands[it.index()])); + // Get the invocation bounds. + SmallVector<InvocationBounds> bounds; + branch.getRegionInvocationBounds(operands, bounds); + + // For a simple control-flow sink, only consider regions that are executed at + // most once. + for (auto it : llvm::zip(branch->getRegions(), bounds)) { + const InvocationBounds &bound = std::get<1>(it); + if (bound.getUpperBound() && *bound.getUpperBound() <= 1) + regions.push_back(&std::get<0>(it)); + } +} diff --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/control-flow-sink.mlir @@ -0,0 +1,210 @@ +// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s + +// Test that operations are sunk into the single region that uses them, while ops used by several regions stay in place.
+ +// CHECK-LABEL: @test_simple_sink +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +// CHECK-NEXT: %[[V0:.*]] = arith.subi %[[ARG2]], %[[ARG1]] +// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]] +// CHECK-NEXT: test.region_if_yield %[[V2]] +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]] +// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V0]], %[[V2]] +// CHECK-NEXT: test.region_if_yield %[[V3]] +// CHECK-NEXT: } join { +// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG2]], %[[ARG2]] +// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V2]], %[[V0]] +// CHECK-NEXT: test.region_if_yield %[[V3]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V1]] +func @test_simple_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 { + %0 = arith.subi %arg1, %arg2 : i32 + %1 = arith.subi %arg2, %arg1 : i32 + %2 = arith.addi %arg1, %arg1 : i32 + %3 = arith.addi %arg2, %arg2 : i32 + %4 = test.region_if %arg0: i1 -> i32 then { + test.region_if_yield %0 : i32 + } else { + %5 = arith.addi %1, %2 : i32 + test.region_if_yield %5 : i32 + } join { + %5 = arith.addi %3, %1 : i32 + test.region_if_yield %5 : i32 + } + return %4 : i32 +} + +// ----- + +// Test that an op with regions (here a `test.region_if`) can itself be sunk.
+ +// CHECK-LABEL: @test_region_sink +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]] +// CHECK-NEXT: test.region_if_yield %[[V2]] +// CHECK-NEXT: } join { +// CHECK-NEXT: test.region_if_yield %[[ARG2]] +// CHECK-NEXT: } +// CHECK-NEXT: test.region_if_yield %[[V1]] +// CHECK-NEXT: } else { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } join { +// CHECK-NEXT: test.region_if_yield %[[ARG2]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]] +func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 { + %0 = arith.subi %arg1, %arg2 : i32 + %1 = test.region_if %arg0: i1 -> i32 then { + test.region_if_yield %arg1 : i32 + } else { + test.region_if_yield %0 : i32 + } join { + test.region_if_yield %arg2 : i32 + } + %2 = test.region_if %arg0: i1 -> i32 then { + test.region_if_yield %1 : i32 + } else { + test.region_if_yield %arg1 : i32 + } join { + test.region_if_yield %arg2 : i32 + } + return %2 : i32 +} + +// ----- + +// Test that an entire subgraph of ops feeding a single use can be sunk, in an order that keeps dominance.
+ +// CHECK-LABEL: @test_subgraph_sink +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: %[[V1:.*]] = arith.subi %[[ARG1]], %[[ARG2]] +// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG2]] +// CHECK-NEXT: %[[V3:.*]] = arith.subi %[[ARG2]], %[[ARG1]] +// CHECK-NEXT: %[[V4:.*]] = arith.muli %[[V3]], %[[V3]] +// CHECK-NEXT: %[[V5:.*]] = arith.muli %[[V2]], %[[V1]] +// CHECK-NEXT: %[[V6:.*]] = arith.addi %[[V5]], %[[V4]] +// CHECK-NEXT: test.region_if_yield %[[V6]] +// CHECK-NEXT: } else { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } join { +// CHECK-NEXT: test.region_if_yield %[[ARG2]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]] +func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 { + %0 = arith.addi %arg1, %arg2 : i32 + %1 = arith.subi %arg1, %arg2 : i32 + %2 = arith.subi %arg2, %arg1 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.muli %2, %2 : i32 + %5 = arith.addi %3, %4 : i32 + %6 = test.region_if %arg0: i1 -> i32 then { + test.region_if_yield %5 : i32 + } else { + test.region_if_yield %arg1 : i32 + } join { + test.region_if_yield %arg2 : i32 + } + return %6 : i32 +} + +// ----- + +// Test that ops can be sunk into the entry block of a region with multiple blocks.
+ +// CHECK-LABEL: @test_multiblock_region_sink +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +// CHECK-NEXT: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG2]] +// CHECK-NEXT: %[[V1:.*]] = "test.any_cond"() ({ +// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V0]], %[[ARG2]] +// CHECK-NEXT: %[[V4:.*]] = arith.addi %[[V3]], %[[ARG1]] +// CHECK-NEXT: br ^bb1(%[[V4]] : i32) +// CHECK-NEXT: ^bb1(%[[V5:.*]]: i32): +// CHECK-NEXT: %[[V6:.*]] = arith.addi %[[V5]], %[[V4]] +// CHECK-NEXT: "test.yield"(%[[V6]]) +// CHECK-NEXT: }) +// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] +// CHECK-NEXT: return %[[V2]] +func @test_multiblock_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 { + %0 = arith.addi %arg1, %arg2 : i32 + %1 = arith.addi %0, %arg2 : i32 + %2 = arith.addi %1, %arg1 : i32 + %3 = "test.any_cond"() ({ + br ^bb1(%2 : i32) + ^bb1(%5: i32): + %6 = arith.addi %5, %2 : i32 + "test.yield"(%6) : (i32) -> () + }) : () -> i32 + %4 = arith.addi %0, %3 : i32 + return %4 : i32 +} + +// ----- + +// Test that ops can be sunk recursively into nested regions, ending up in the innermost region that uses them.
+ +// CHECK-LABEL: @test_nested_region_sink +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) -> i32 { +// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then { +// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]] +// CHECK-NEXT: test.region_if_yield %[[V2]] +// CHECK-NEXT: } else { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } join { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } +// CHECK-NEXT: test.region_if_yield %[[V1]] +// CHECK-NEXT: } else { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } join { +// CHECK-NEXT: test.region_if_yield %[[ARG1]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]] +func @test_nested_region_sink(%arg0: i1, %arg1: i32) -> i32 { + %0 = arith.addi %arg1, %arg1 : i32 + %1 = test.region_if %arg0: i1 -> i32 then { + %2 = test.region_if %arg0: i1 -> i32 then { + test.region_if_yield %0 : i32 + } else { + test.region_if_yield %arg1 : i32 + } join { + test.region_if_yield %arg1 : i32 + } + test.region_if_yield %2 : i32 + } else { + test.region_if_yield %arg1 : i32 + } join { + test.region_if_yield %arg1 : i32 + } + return %1 : i32 +} + +// ----- + +// Test that ops are only moved into the entry block, even when their only uses +// are in a later block of the region.
+ +// CHECK-LABEL: @test_not_sunk_deeply +// CHECK-SAME: (%[[ARG0:.*]]: i32) -> i32 { +// CHECK-NEXT: %[[V0:.*]] = "test.any_cond"() ({ +// CHECK-NEXT: %[[V1:.*]] = arith.addi %[[ARG0]], %[[ARG0]] +// CHECK-NEXT: br ^bb1 +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: "test.yield"(%[[V1]]) : (i32) -> () +// CHECK-NEXT: }) +// CHECK-NEXT: return %[[V0]] +func @test_not_sunk_deeply(%arg0: i32) -> i32 { + %0 = arith.addi %arg0, %arg0 : i32 + %1 = "test.any_cond"() ({ + br ^bb1 + ^bb1: + "test.yield"(%0) : (i32) -> () + }) : () -> i32 + return %1 : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1127,15 +1127,15 @@ p.printOperands(op.getOperands()); p << ": " << op.getOperandTypes(); p.printArrowTypeList(op.getResultTypes()); - p << " then"; + p << " then "; p.printRegion(op.getThenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); - p << " else"; + p << " else "; p.printRegion(op.getElseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); - p << " join"; + p << " join "; p.printRegion(op.getJoinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); @@ -1189,6 +1189,34 @@ regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); } +void RegionIfOp::getRegionInvocationBounds( + ArrayRef<Attribute> operands, + SmallVectorImpl<InvocationBounds> &invocationBounds) { + // Each region is invoked at most once. + invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); +} + +//===----------------------------------------------------------------------===// +// AnyCondOp +//===----------------------------------------------------------------------===// + +void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, + ArrayRef<Attribute> operands, + SmallVectorImpl<RegionSuccessor> &regions) { + // The parent op branches into the only region, and the region branches back + // to the parent op.
+ if (!index) + regions.emplace_back(&getRegion()); + else + regions.emplace_back(getResults()); +} + +void AnyCondOp::getRegionInvocationBounds( + ArrayRef<Attribute> operands, + SmallVectorImpl<InvocationBounds> &invocationBounds) { + invocationBounds.emplace_back(1, 1); +} + //===----------------------------------------------------------------------===// // SingleNoTerminatorCustomAsmOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2342,14 +2342,15 @@ } def RegionIfOp : TEST_Op<"region_if", - [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + [DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getRegionInvocationBounds"]>, SingleBlockImplicitTerminator<"RegionIfYieldOp">, RecursiveSideEffects]> { let description =[{ Represents an abstract if-then-else-join pattern. In this context, the then and else regions jump to the join region, which finally returns to its parent op. - }]; + }]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseRegionIfOp(parser, result); }]; @@ -2372,6 +2373,14 @@ }]; } +def AnyCondOp : TEST_Op<"any_cond", + [DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getRegionInvocationBounds"]>, + RecursiveSideEffects]> { + let results = (outs Variadic<AnyType>:$results); + let regions = (region AnyRegion:$region); +} + //===----------------------------------------------------------------------===// // Test TableGen generated build() methods //===----------------------------------------------------------------------===//