diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -55,6 +55,10 @@
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
+/// Creates a pass to pull ops before and after an scf.if op into both scf.if op
+/// regions.
+std::unique_ptr<Pass> createIfRegionExpansionPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -114,4 +114,19 @@
   }];
 }
 
+def SCFIfRegionExpansion
+    : FunctionPass<"if-region-expansion"> {
+  let summary = "Pulls ops before and after scf.if into both scf.if regions";
+  let constructor = "mlir::createIfRegionExpansionPass()";
+  let description = [{
+    This pass expands an scf.if op's regions by pulling in ops before and after
+    scf.if op into both regions of the scf.if op. This can be helpful as a
+    preliminary step to enable further optimizations in both regions, which then
+    do not need to cross region boundaries.
+
+    The scf.if op must have an else region, its parent region should only
+    contain one block, and all ops before it must be free of side effects.
+  }];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -158,6 +158,46 @@
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Expands an scf.if op's regions by pulling ops before and after scf.if op
+/// into both regions of the scf.if op.
+///
+/// The scf.if op must have an else region and its parent region must contain
+/// a single block. All ops before the scf.if op must be free of side effects;
+/// they are cloned into both regions and left for DCE to clean up.
+///
+/// For example, it converts the following IR:
+/// ```
+/// %0 = opA ..
+/// %1 = opB ..
+/// %2 = scf.if .. {
+///   %3 = opC %0 ..
+///   scf.yield %3
+/// } else {
+///   %4 = opD ..
+///   scf.yield %4
+/// }
+/// %5 = opE %2 ..
+/// return %5
+/// ```
+/// Into:
+/// ```
+/// %2 = scf.if .. {
+///   %0 = opA ..
+///   %1 = opB ..
+///   %3 = opC %0 ..
+///   %5 = opE %3 ..
+///   scf.yield %5
+/// } else {
+///   %0 = opA ..
+///   %1 = opB ..
+///   %4 = opD ..
+///   %5 = opE %4 ..
+///   scf.yield %5
+/// }
+/// return %2
+/// ```
+void populateIfRegionExpansionPatterns(RewritePatternSet &patterns);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   AffineCanonicalizationUtils.cpp
   Bufferize.cpp
   ForToWhile.cpp
+  IfRegionExpansion.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp
@@ -0,0 +1,199 @@
+//===- IfRegionExpansion.cpp - Pull ops into scf.if Region ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and passes for expanding scf.if's regions
+// by pulling in ops before and after the scf.if op into both regions of the
+// scf.if op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/AffineCanonicalizationUtils.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-scf-expand-if-region"
+
+using namespace mlir;
+
+/// Unit attribute used to tag scf.if ops produced by this transformation so
+/// that the pattern does not expand them again (e.g. while the original ops
+/// before them are still present, waiting for DCE).
+static constexpr char kExpandedIfMarker[] = "__expanded_if_regions__";
+
+/// Pulls ops at the same nest level as the given `ifOp` into both regions of
+/// `ifOp` by creating a new scf.if op. Returns the new op, or failure if the
+/// transformation does not apply.
+///
+/// Ops before `ifOp` are cloned into both regions and left in place; they must
+/// be side-effect free so that DCE can remove them later. Ops after `ifOp` are
+/// cloned into both regions and erased; the new scf.if op then yields the
+/// values consumed by the enclosing block's terminator.
+static FailureOr<scf::IfOp> pullOpsIntoIfRegions(scf::IfOp ifOp,
+                                                 RewriterBase &rewriter) {
+  // Need to pull ops into both regions.
+  if (!ifOp.elseBlock())
+    return failure();
+
+  // Expect to only have one block in the enclosing region. This is the common
+  // case for the level where we have structured control flows and it avoids
+  // traditional control flow and simplifies the analysis.
+  if (!llvm::hasSingleElement(ifOp->getParentRegion()->getBlocks()))
+    return failure();
+
+  SmallVector<Operation *> allOps;
+  for (Operation &op : ifOp->getBlock()->without_terminator())
+    allOps.push_back(&op);
+
+  // If no ops before or after the if op, there is nothing to do.
+  if (allOps.size() == 1)
+    return failure();
+
+  auto prevOps = llvm::makeArrayRef(allOps).take_while(
+      [&ifOp](Operation *op) { return op != ifOp.getOperation(); });
+  auto nextOps = llvm::makeArrayRef(allOps).drop_front(prevOps.size() + 1);
+
+  // Require all previous ops to have no side effects, so that after cloning
+  // them into both regions, we can rely on DCE to remove them.
+  if (llvm::any_of(prevOps, [](Operation *op) {
+        return !MemoryEffectOpInterface::hasNoEffect(op);
+      }))
+    return failure();
+
+  Operation *parentTerminator = ifOp->getBlock()->getTerminator();
+  TypeRange resultTypes = ifOp.getResultTypes();
+  if (!nextOps.empty()) {
+    // The if op should yield the values used by the terminator.
+    resultTypes = parentTerminator->getOperandTypes();
+  }
+
+  // The original if op is known to have an else block (checked above), so the
+  // new one needs both regions as well.
+  auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
+                                            ifOp.getCondition(),
+                                            /*withElseRegion=*/true);
+
+  auto pullIntoBlock = [&](Block *newBlock, Block *oldBlock) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(newBlock);
+    BlockAndValueMapping bvm;
+
+    // Clone all ops defined before the original if op.
+    for (Operation *prevOp : prevOps)
+      rewriter.clone(*prevOp, bvm);
+
+    // Clone all ops defined inside the original if block.
+    for (Operation &blockOp : oldBlock->without_terminator())
+      rewriter.clone(blockOp, bvm);
+
+    if (nextOps.empty()) {
+      // If the if op needs to return value, its builder won't automatically
+      // insert terminators. Just clone the old one here.
+      if (newIfOp->getNumResults())
+        rewriter.clone(*oldBlock->getTerminator(), bvm);
+      return;
+    }
+
+    // There are ops after the old if op. Uses of the old if op should be
+    // replaced by the (cloned) yielded values. A yielded value may also be
+    // defined outside the parent block's ops, in which case it is used as-is.
+    auto oldYieldOp = cast<scf::YieldOp>(oldBlock->back());
+    for (auto it : llvm::zip(ifOp->getResults(), oldYieldOp.getOperands()))
+      bvm.map(std::get<0>(it), bvm.lookupOrDefault(std::get<1>(it)));
+
+    // Clone all ops defined after the original if op.
+    for (Operation *nextOp : nextOps)
+      rewriter.clone(*nextOp, bvm);
+
+    // With no values to yield, the if builder has already inserted an empty
+    // terminator.
+    if (parentTerminator->getNumOperands() == 0)
+      return;
+
+    // Yield every value the parent terminator consumes, at the same index.
+    // Operands may come from the cloned ops, from the old if op's results
+    // (mapped above), or from outside the parent block's ops (used as-is).
+    SmallVector<Value> yieldValues;
+    yieldValues.reserve(parentTerminator->getNumOperands());
+    for (Value operand : parentTerminator->getOperands())
+      yieldValues.push_back(bvm.lookupOrDefault(operand));
+    // Again the if builder won't insert terminators automatically.
+    rewriter.create<scf::YieldOp>(ifOp.getLoc(), yieldValues);
+  };
+
+  pullIntoBlock(newIfOp.thenBlock(), ifOp.thenBlock());
+  pullIntoBlock(newIfOp.elseBlock(), ifOp.elseBlock());
+
+  if (nextOps.empty()) {
+    rewriter.replaceOp(ifOp, newIfOp->getResults());
+  } else {
+    // Update the terminator to use the new if op's results.
+    rewriter.updateRootInPlace(parentTerminator, [&]() {
+      parentTerminator->setOperands(newIfOp->getResults());
+    });
+    // We have pulled in all ops following the if op into both regions. Now
+    // remove them all. Do this in the reverse order.
+    for (Operation *op : llvm::reverse(nextOps))
+      rewriter.eraseOp(op);
+    rewriter.eraseOp(ifOp);
+  }
+
+  return newIfOp;
+}
+
+namespace {
+
+/// Rewrites an scf.if op by pulling the ops around it into both of its
+/// regions. scf.if ops produced by this pattern are tagged with
+/// `kExpandedIfMarker` and skipped afterwards.
+struct IfRegionExpansionPattern final : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    if (ifOp->hasAttr(kExpandedIfMarker))
+      return failure();
+
+    auto newOp = pullOpsIntoIfRegions(ifOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    newOp.getValue()->setAttr(kExpandedIfMarker, rewriter.getUnitAttr());
+    return success();
+  }
+};
+
+/// Function pass that greedily applies the scf.if region expansion pattern.
+struct IfRegionExpansion : public SCFIfRegionExpansionBase<IfRegionExpansion> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    RewritePatternSet patterns(funcOp.getContext());
+    scf::populateIfRegionExpansionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+
+} // namespace
+
+void scf::populateIfRegionExpansionPatterns(RewritePatternSet &patterns) {
+  patterns.add<IfRegionExpansionPattern>(patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createIfRegionExpansionPass() {
+  return std::make_unique<IfRegionExpansion>();
+}
diff --git a/mlir/test/Dialect/SCF/if-region-expansion.mlir b/mlir/test/Dialect/SCF/if-region-expansion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-region-expansion.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -if-region-expansion %s | FileCheck %s
+
+// CHECK-LABEL: func @parent_region_only_contains_if_op
+func @parent_region_only_contains_if_op(%cond: i1, %val: i32) -> i32 {
+  %if = scf.if %cond -> i32 {
+    scf.yield %val: i32
+  } else {
+    scf.yield %val: i32
+  }
+  return %if: i32
+}
+
+// CHECK-NOT: __expanded_if_regions__
+
+// -----
+
+// CHECK-LABEL: func @side_effect_op_before_if_op
+func @side_effect_op_before_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<3xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  memref.store %v0, %buffer[%c0] : memref<3xi32>
+  %add = arith.addi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[%c1] : memref<3xi32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[%c2] : memref<3xi32>
+  return
+}
+
+// CHECK-NOT: __expanded_if_regions__
+
+// -----
+
+// CHECK-LABEL: func @if_op_without_else_branch
+func @if_op_without_else_branch(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) {
+  %add = arith.addi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[] : memref<i32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[] : memref<i32>
+  return
+}
+
+// CHECK-NOT: __expanded_if_regions__
+
+// -----
+
+// CHECK-LABEL: func @ops_before_and_after_if_op
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @ops_before_and_after_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %sub: i32
+  }
+  %mul = arith.muli %if, %v0 : i32
+  %div = arith.divsi %if, %v1 : i32
+  return %mul, %div: i32, i32
+}
+
+// CHECK: %[[IF:.+]]:2 = scf.if %[[COND]] -> (i32, i32)
+// CHECK:   %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK:   %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]]
+// CHECK:   scf.yield %[[MUL]], %[[DIV]]
+// CHECK: } else {
+// CHECK:   %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]]
+// CHECK:   %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]]
+// CHECK:   scf.yield %[[MUL]], %[[DIV]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1
+
+// -----
+
+// CHECK-LABEL: func @zero_result_if_op
+func @zero_result_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[] : memref<i32>
+  } else {
+    memref.store %sub, %buffer[] : memref<i32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[] : memref<i32>
+  return
+}
+
+// CHECK: scf.if
+// CHECK:   %[[ADD:.+]] = arith.addi
+// CHECK:   memref.store %[[ADD]]
+// CHECK:   %[[MUL:.+]] = arith.muli
+// CHECK:   memref.store %[[MUL]]
+// CHECK: } else {
+// CHECK:   %[[SUB:.+]] = arith.subi
+// CHECK:   memref.store %[[SUB]]
+// CHECK:   %[[MUL:.+]] = arith.muli
+// CHECK:   memref.store %[[MUL]]
+
+// -----
+
+// CHECK-LABEL: func @multi_result_if_op
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @multi_result_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if:2 = scf.if %cond -> (i32, i32) {
+    scf.yield %add, %sub: i32, i32
+  } else {
+    scf.yield %sub, %add: i32, i32
+  }
+  %mul = arith.muli %if#0, %v0 : i32
+  %div = arith.divsi %if#1, %v1 : i32
+  return %mul, %div: i32, i32
+}
+
+// CHECK: %[[IF:.+]]:2 = scf.if
+// CHECK:   %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK:   %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]]
+// CHECK:   scf.yield %[[MUL]], %[[DIV]]
+// CHECK: } else {
+// CHECK:   %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]]
+// CHECK:   %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]]
+// CHECK:   scf.yield %[[MUL]], %[[DIV]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1
+
+// -----
+
+// CHECK-LABEL: func @multi_use_in_terminator
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @multi_use_in_terminator(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %sub: i32
+  }
+  %mul = arith.muli %if, %if : i32
+  return %mul, %mul, %mul: i32, i32, i32
+}
+
+// CHECK: %[[IF:.+]]:3 = scf.if
+// CHECK:   %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[ADD]], %[[ADD]]
+// CHECK:   scf.yield %[[MUL]], %[[MUL]], %[[MUL]]
+// CHECK: } else {
+// CHECK:   %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[SUB]], %[[SUB]]
+// CHECK:   scf.yield %[[MUL]], %[[MUL]], %[[MUL]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2
+
+// -----
+
+// CHECK-LABEL: func @terminator_uses_values_not_defined_after_if_op
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @terminator_uses_values_not_defined_after_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %v1: i32
+  }
+  %mul = arith.muli %if, %v0 : i32
+  return %mul, %if, %add, %v0: i32, i32, i32, i32
+}
+
+// CHECK: %[[IF:.+]]:4 = scf.if %[[COND]] -> (i32, i32, i32, i32)
+// CHECK:   %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK:   scf.yield %[[MUL]], %[[ADD]], %[[ADD]], %[[V0]]
+// CHECK: } else {
+// CHECK:   %[[ADD2:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK:   %[[MUL2:.+]] = arith.muli %[[V1]], %[[V0]]
+// CHECK:   scf.yield %[[MUL2]], %[[V1]], %[[ADD2]], %[[V0]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2, %[[IF]]#3