diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -171,6 +171,48 @@
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);

+/// Expands an scf.if op's regions by pulling the ops before and after the
+/// scf.if op (within the same block) into both regions of the scf.if op.
+///
+/// The pattern only applies when the scf.if op has an else region, the
+/// region enclosing it has a single block, and all ops before it have no
+/// memory effects: those ops are cloned into both regions while the
+/// originals are left in place for dead code elimination to remove. When ops
+/// follow the scf.if op, the new scf.if op yields the operands of the
+/// enclosing block's terminator, and the terminator is updated to use the
+/// new results.
+///
+/// For example, it converts the following IR:
+/// ```
+/// %0 = opA ..
+/// %1 = opB ..
+/// %2 = scf.if .. {
+///   %3 = opC %0 ..
+///   scf.yield %3
+/// } else {
+///   %4 = opD ..
+///   scf.yield %4
+/// }
+/// %5 = opE %2 ..
+/// ```
+/// Into the following, assuming %5 is the only value used by the enclosing
+/// block's terminator and DCE has removed the original opA and opB:
+/// ```
+/// %2 = scf.if .. {
+///   %0 = opA ..
+///   %1 = opB ..
+///   %3 = opC %0 ..
+///   %5 = opE %3 ..
+///   scf.yield %5
+/// } else {
+///   %0 = opA ..
+///   %1 = opB ..
+///   %4 = opD ..
+///   %5 = opE %4 ..
+///   scf.yield %5
+/// }
+/// ```
+void populateIfRegionExpansionPatterns(RewritePatternSet &patterns);

 } // namespace scf
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForToWhile.cpp
+  IfRegionExpansion.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp
@@ -0,0 +1,160 @@
+//===- IfRegionExpansion.cpp - Pull ops into scf.if Region ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for expanding scf.if's regions by pulling in
+// ops before and after the scf.if op into both regions of the scf.if op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-scf-expand-if-region"
+
+using namespace mlir;
+
+/// Returns true if neither `op` nor any op nested in its regions has memory
+/// effects. Checking only `op` itself is not enough for region-holding ops
+/// such as scf.if, whose effects come from the ops nested inside them.
+static bool hasNoEffectRecursively(Operation *op) {
+  // Stop at the first (possibly nested) op that may have memory effects.
+  WalkResult result = op->walk([](Operation *nestedOp) {
+    return MemoryEffectOpInterface::hasNoEffect(nestedOp)
+               ? WalkResult::advance()
+               : WalkResult::interrupt();
+  });
+  return !result.wasInterrupted();
+}
+
+/// Pulls the ops in the same block as `ifOp` (before and after it) into both
+/// regions of a new scf.if op that replaces `ifOp`.
+///
+/// Returns failure, without modifying the IR, if `ifOp` has no else region,
+/// if its parent region has more than one block, if the block holds no op
+/// besides `ifOp` and the terminator, or if any op before `ifOp` (or any op
+/// nested in it) may have memory effects. The ops before `ifOp` are cloned
+/// into both regions but left in place, relying on DCE to remove them.
+///
+/// If ops follow `ifOp`, the new scf.if op yields exactly the operands of the
+/// block terminator, which is rewired to the new results; the ops after
+/// `ifOp` and `ifOp` itself are erased. Otherwise `ifOp` is replaced by the
+/// new scf.if op, which keeps the original result types.
+///
+/// NOTE(review): if an op before `ifOp` stays alive after the rewrite (e.g.
+/// it defines the condition), the new scf.if op can match this rewrite again.
+/// Confirm this is acceptable before running the pattern with a greedy driver.
+///
+/// Returns the newly created scf.if op on success.
+static FailureOr<scf::IfOp> pullOpsIntoIfRegions(scf::IfOp ifOp,
+                                                 RewriterBase &rewriter) {
+  // Need to pull ops into both regions.
+  if (!ifOp.elseBlock())
+    return failure();
+
+  // Expect to only have one block in the enclosing region. This is the common
+  // case for the level where we have structured control flows and it avoids
+  // traditional control flow and simplifies the analysis.
+  if (!llvm::hasSingleElement(ifOp->getParentRegion()->getBlocks()))
+    return failure();
+
+  SmallVector<Operation *> allOps;
+  for (Operation &op : ifOp->getBlock()->without_terminator())
+    allOps.push_back(&op);
+
+  // If no ops before or after the if op, there is nothing to do.
+  if (allOps.size() == 1)
+    return failure();
+
+  auto prevOps = llvm::makeArrayRef(allOps).take_while(
+      [&ifOp](Operation *op) { return op != ifOp.getOperation(); });
+  auto nextOps = llvm::makeArrayRef(allOps).drop_front(prevOps.size() + 1);
+
+  // Require all previous ops, including the ops nested in their regions, to
+  // have no side effects, so that after cloning them into both regions, we
+  // can rely on DCE to remove them.
+  if (!llvm::all_of(prevOps, hasNoEffectRecursively))
+    return failure();
+
+  Operation *parentTerminator = ifOp->getBlock()->getTerminator();
+  TypeRange resultTypes = ifOp.getResultTypes();
+  if (!nextOps.empty()) {
+    // The if op should yield the values used by the terminator.
+    resultTypes = parentTerminator->getOperandTypes();
+  }
+
+  // `ifOp` is known to have an else region, so always create one.
+  auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
+                                            ifOp.getCondition(),
+                                            /*withElseRegion=*/true);
+
+  auto pullIntoBlock = [&](Block *newBlock, Block *oldBlock) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(newBlock);
+    BlockAndValueMapping bvm;
+
+    // Clone all ops defined before the original if op.
+    for (Operation *prevOp : prevOps)
+      rewriter.clone(*prevOp, bvm);
+
+    // Clone all ops defined inside the original if block.
+    for (Operation &blockOp : oldBlock->without_terminator())
+      rewriter.clone(blockOp, bvm);
+
+    if (nextOps.empty()) {
+      // If the if op needs to return value, its builder won't automatically
+      // insert terminators. Just clone the old one here.
+      if (newIfOp->getNumResults())
+        rewriter.clone(*oldBlock->getTerminator(), bvm);
+      return;
+    }
+
+    // There are ops after the old if op. Uses of the old if op should be
+    // replaced by the value yielded in this region. A yielded value that was
+    // not cloned (e.g. a function argument) is not in `bvm` and is used as-is.
+    auto oldYieldOp = cast<scf::YieldOp>(oldBlock->getTerminator());
+    for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i)
+      bvm.map(ifOp->getResult(i),
+              bvm.lookupOrDefault(oldYieldOp.getOperand(i)));
+
+    // Clone all ops defined after the original if op.
+    for (Operation *nextOp : nextOps)
+      rewriter.clone(*nextOp, bvm);
+
+    // Yield the values used by the parent terminator, at the same indices.
+    // Each operand is either mapped in `bvm` (a result of a cloned op or of
+    // the old if op) or a value that is still visible here (e.g. a block
+    // argument), which is used as-is.
+    SmallVector<Value> yieldValues;
+    yieldValues.reserve(parentTerminator->getNumOperands());
+    for (Value operand : parentTerminator->getOperands())
+      yieldValues.push_back(bvm.lookupOrDefault(operand));
+
+    if (!yieldValues.empty()) {
+      // Again the if builder won't insert terminators automatically.
+      rewriter.create<scf::YieldOp>(ifOp.getLoc(), yieldValues);
+    }
+  };
+
+  pullIntoBlock(newIfOp.thenBlock(), ifOp.thenBlock());
+  pullIntoBlock(newIfOp.elseBlock(), ifOp.elseBlock());
+
+  if (nextOps.empty()) {
+    rewriter.replaceOp(ifOp, newIfOp->getResults());
+  } else {
+    // Update the terminator to use the new if op's results.
+    rewriter.updateRootInPlace(parentTerminator, [&]() {
+      parentTerminator->setOperands(newIfOp->getResults());
+    });
+    // We have pulled in all ops following the if op into both regions. Now
+    // remove them all. Do this in the reverse order so that users are erased
+    // before the ops they use.
+    for (Operation *op : llvm::reverse(nextOps))
+      rewriter.eraseOp(op);
+    rewriter.eraseOp(ifOp);
+  }
+
+  return newIfOp;
+}
+
+namespace {
+
+/// Rewrites an scf.if op by pulling the surrounding ops of its block into
+/// both of its regions; see pullOpsIntoIfRegions.
+struct IfRegionExpansionPattern final : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern<scf::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    // FailureOr<scf::IfOp> converts to success() exactly when the rewrite
+    // happened.
+    return pullOpsIntoIfRegions(ifOp, rewriter);
+  }
+};
+
+} // namespace
+
+void scf::populateIfRegionExpansionPatterns(RewritePatternSet &patterns) {
+  patterns.insert<IfRegionExpansionPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/SCF/if-region-expansion.mlir b/mlir/test/Dialect/SCF/if-region-expansion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-region-expansion.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -test-scf-transform-patterns=test-expand-if-region-patterns %s | FileCheck %s
+
+// Not expanded: the scf.if op is the only op besides the terminator.
+
+// CHECK-LABEL: func @parent_region_only_contains_if_op
+func @parent_region_only_contains_if_op(%cond: i1, %val: i32) -> i32 {
+  %if = scf.if %cond -> i32 {
+    scf.yield %val: i32
+  } else {
+    scf.yield %val: i32
+  }
+  return %if: i32
+}
+
+// CHECK: scf.if
+// CHECK: scf.yield
+// CHECK: else
+// CHECK: scf.yield
+
+// -----
+
+// Not expanded: an op before the scf.if op (memref.store) has memory effects.
+
+// CHECK-LABEL: func @side_effect_op_before_if_op
+func @side_effect_op_before_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<3xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  memref.store %v0, %buffer[%c0] : memref<3xi32>
+  %add = arith.addi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[%c1] : memref<3xi32>
+  } else {
+    memref.store %v1, %buffer[%c1] : memref<3xi32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[%c2] : memref<3xi32>
+  return
+}
+
+// CHECK: memref.store
+// CHECK: arith.addi
+// CHECK: scf.if
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: } else {
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: }
+// CHECK-NEXT: arith.muli
+// CHECK-NEXT: memref.store
+
+// -----
+
+// Not expanded: the op before the scf.if op is a region op whose nested ops
+// have memory effects.
+
+// CHECK-LABEL: func @nested_side_effect_op_before_if_op
+func @nested_side_effect_op_before_if_op(%cond0: i1, %cond1: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) -> i32 {
+  scf.if %cond0 {
+    memref.store %v0, %buffer[] : memref<i32>
+  }
+  %if = scf.if %cond1 -> i32 {
+    scf.yield %v0: i32
+  } else {
+    scf.yield %v1: i32
+  }
+  return %if: i32
+}
+
+// CHECK: scf.if
+// CHECK: memref.store
+// CHECK: scf.if
+// CHECK-NOT: memref.store
+// CHECK: return
+
+// -----
+
+// Values yielded by the scf.if op or used by the terminator may be function
+// arguments or the scf.if op's own results; the new scf.if op must yield the
+// right value for each of them.
+
+// CHECK-LABEL: func @yield_values_not_defined_by_pulled_ops
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @yield_values_not_defined_by_pulled_ops(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %v1: i32
+  }
+  %mul = arith.muli %if, %v0 : i32
+  return %if, %mul, %v0: i32, i32, i32
+}
+
+// CHECK: %[[IF:.+]]:3 = scf.if %[[COND]] -> (i32, i32, i32)
+// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK: scf.yield %[[ADD]], %[[MUL]], %[[V0]]
+// CHECK: } else {
+// CHECK: %[[MUL:.+]] = arith.muli %[[V1]], %[[V0]]
+// CHECK: scf.yield %[[V1]], %[[MUL]], %[[V0]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2
+
+// -----
+
+// Not expanded: the scf.if op has no else region.
+
+// CHECK-LABEL: func @if_op_without_else_branch
+func @if_op_without_else_branch(%cond: i1, %v0: i32, %v1: i32, %buffer:
memref<i32>) {
+  %add = arith.addi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[] : memref<i32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[] : memref<i32>
+  return
+}
+
+// CHECK: arith.addi
+// CHECK: scf.if
+// CHECK: memref.store
+// CHECK: arith.muli
+// CHECK: memref.store
+
+// -----
+
+// Ops before and after a single-result scf.if op are pulled into both
+// regions, and the new scf.if op yields the values the function returns.
+
+// CHECK-LABEL: func @ops_before_and_after_if_op
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @ops_before_and_after_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %sub: i32
+  }
+  %mul = arith.muli %if, %v0 : i32
+  %div = arith.divsi %if, %v1 : i32
+  return %mul, %div: i32, i32
+}
+
+// CHECK: %[[IF:.+]]:2 = scf.if %[[COND]] -> (i32, i32)
+// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]]
+// CHECK: scf.yield %[[MUL]], %[[DIV]]
+// CHECK: } else {
+// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]]
+// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]]
+// CHECK: scf.yield %[[MUL]], %[[DIV]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1
+
+// -----
+
+// A zero-result scf.if op followed by side-effecting ops is expanded; the
+// new scf.if op still has no results.
+
+// CHECK-LABEL: func @zero_result_if_op
+func @zero_result_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  scf.if %cond {
+    memref.store %add, %buffer[] : memref<i32>
+  } else {
+    memref.store %sub, %buffer[] : memref<i32>
+  }
+  %mul = arith.muli %v0, %v1 : i32
+  memref.store %mul, %buffer[] : memref<i32>
+  return
+}
+
+// CHECK: scf.if
+// CHECK: %[[ADD:.+]] = arith.addi
+// CHECK: memref.store %[[ADD]]
+// CHECK: %[[MUL:.+]] = arith.muli
+// CHECK: memref.store %[[MUL]]
+// CHECK: } else {
+// CHECK: %[[SUB:.+]] = arith.subi
+// CHECK: memref.store %[[SUB]]
+// CHECK: %[[MUL:.+]] = arith.muli
+// CHECK: memref.store %[[MUL]]
+
+// -----
+
+// Each result of a multi-result scf.if op is remapped to the matching value
+// yielded in each region.
+
+// CHECK-LABEL: func @multi_result_if_op
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @multi_result_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if:2 = scf.if %cond -> (i32, i32) {
+    scf.yield %add, %sub: i32, i32
+  } else {
+    scf.yield %sub, %add: i32, i32
+  }
+  %mul = arith.muli %if#0, %v0 : i32
+  %div = arith.divsi %if#1, %v1 : i32
+  return %mul, %div: i32, i32
+}
+
+// CHECK: %[[IF:.+]]:2 = scf.if
+// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]]
+// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]]
+// CHECK: scf.yield %[[MUL]], %[[DIV]]
+// CHECK: } else {
+// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]]
+// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]]
+// CHECK: scf.yield %[[MUL]], %[[DIV]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1
+
+// -----
+
+// A value used several times by the terminator is yielded at every index
+// where it is used.
+
+// CHECK-LABEL: func @multi_use_in_terminator
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32)
+func @multi_use_in_terminator(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32) {
+  %add = arith.addi %v0, %v1 : i32
+  %sub = arith.subi %v0, %v1 : i32
+  %if = scf.if %cond -> i32 {
+    scf.yield %add: i32
+  } else {
+    scf.yield %sub: i32
+  }
+  %mul = arith.muli %if, %if : i32
+  return %mul, %mul, %mul: i32, i32, i32
+}
+
+// CHECK: %[[IF:.+]]:3 = scf.if
+// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[ADD]]
+// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]]
+// CHECK: } else {
+// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[SUB]]
+// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]]
+// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2
diff
--git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestSCFTransforms.cpp
   TestSCFUtils.cpp

   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp b/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp
@@ -0,0 +1,77 @@
+//===- TestSCFTransforms.cpp - Test SCF transformation patterns -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing SCF transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that applies SCF transformation patterns enabled by its options.
+struct TestSCFTransforms
+    : public PassWrapper<TestSCFTransforms, OperationPass<FuncOp>> {
+  TestSCFTransforms() = default;
+  TestSCFTransforms(const TestSCFTransforms &pass) : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-scf-transform-patterns"; }
+  StringRef getDescription() const final {
+    return "Test SCF patterns by applying them selectively or greedily.";
+  }
+
+  void runOnOperation() override;
+
+  Option<bool> testExpandIfRegionPatterns{
+      *this, "test-expand-if-region-patterns",
+      llvm::cl::desc("Test patterns to expand if regions"),
+      llvm::cl::init(false)};
+};
+} // namespace
+
+/// Applies the scf.if region expansion patterns to every scf.if op that
+/// exists in `funcOp` before any rewrite, one op at a time.
+static void applyExpandIfRegionPatterns(FuncOp funcOp) {
+  // Collect candidates in pre-order and visit them in reverse. A successful
+  // expansion erases the scf.if op (with everything nested in it) and may
+  // erase the ops after it in the same block, which are cloned into the new
+  // scf.if op. Reversed pre-order visits all of those before the scf.if op
+  // itself, so no candidate is visited after it has been erased.
+  SmallVector<scf::IfOp> candidates;
+  funcOp.walk<WalkOrder::PreOrder>(
+      [&](scf::IfOp ifOp) { candidates.push_back(ifOp); });
+
+  RewritePatternSet patterns(funcOp.getContext());
+  scf::populateIfRegionExpansionPatterns(patterns);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  for (scf::IfOp ifOp : llvm::reverse(candidates)) {
+    // Apply transforms per op to avoid recursive behavior.
+    (void)applyOpPatternsAndFold(ifOp, frozenPatterns, /*erased=*/nullptr);
+  }
+}
+
+void TestSCFTransforms::runOnOperation() {
+  FuncOp func = getOperation();
+  if (testExpandIfRegionPatterns)
+    applyExpandIfRegionPatterns(func);
+}
+
+namespace mlir {
+namespace test {
+void registerTestSCFTransforms() { PassRegistration<TestSCFTransforms>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -104,6 +104,7 @@
 void registerTestPDLByteCodePass();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
+void registerTestSCFTransforms();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
@@ -193,6 +194,7 @@
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestRecursiveTypesPass();
+  mlir::test::registerTestSCFTransforms();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();