diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -171,6 +171,45 @@ /// loop bounds and loop steps are canonicalized. void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns); +/// Expands an scf.if op's regions by pulling ops before and after scf.if op +/// into both regions of the scf.if op. +/// +/// For example, it converts the following IR: +/// ``` +/// %0 = opA .. +/// %1 = opB .. +/// %2 = scf.if .. { +/// %3 = opC %0 .. +/// scf.yield %3 +/// } else { +/// %4 = opD .. +/// scf.yield %4 +/// } +/// %5 = opE %2 .. +/// ``` +/// Into: +/// ``` +/// %2 = scf.if .. { +/// %0 = opA .. +/// %1 = opB .. +/// %3 = opC %0 .. +/// %5 = opE %3 .. +/// scf.yield %5 +/// } else { +/// %0 = opA .. +/// %1 = opB .. +/// %4 = opD .. +/// %5 = opE %4 .. +/// scf.yield %5 +/// } +/// ``` +/// +/// `canMoveToRegion` is called for each op before and after the scf.if op to +/// decide whether it can be moved into the region. If not, the op and its +/// backward/forward slices are kept as-is.
+void populateIfRegionExpansionPatterns( + RewritePatternSet &patterns, + function_ref canMoveToRegion); } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterfaceImpl.cpp Bufferize.cpp ForToWhile.cpp + IfRegionExpansion.cpp LoopCanonicalization.cpp LoopPipelining.cpp LoopRangeFolding.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp @@ -0,0 +1,239 @@ +//===- IfRegionExpansion.cpp - Pull ops into scf.if Region ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns and passes for expanding scf.if's regions +// by pulling in ops before and after the scf.if op into both regions of the +// scf.if op. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "mlir-scf-expand-if-region" + +using namespace mlir; + +/// Pulls ops at the same nest level as the given `ifOp` into both regions of +/// the if `ifOp`. +static FailureOr +pullOpsIntoIfRegions(scf::IfOp ifOp, + function_ref canMoveToRegion, + RewriterBase &rewriter) { + // Need to pull ops into both regions. + if (!ifOp.elseBlock()) + return failure(); + + // Expect to only have one block in the enclosing region. This is the common + // case for the level where we have structured control flows and it avoids + // traditional control flow and simplifies the analysis. + if (!llvm::hasSingleElement(ifOp->getParentRegion()->getBlocks())) + return failure(); + + SmallVector allOps; + for (Operation &op : ifOp->getBlock()->without_terminator()) + allOps.push_back(&op); + + // If no ops before or after the if op, there is nothing to do. + if (allOps.size() == 1) + return failure(); + + // Return true if the given `op` is in the same region as the scf.if op. + auto isFromSameRegion = [ifOp](Operation *op) { + return op->getParentRegion() == ifOp->getParentRegion(); + }; + + // Collect ops before and after the scf.if op. 
+ auto allPrevOps = llvm::makeArrayRef(allOps).take_while( + [&ifOp](Operation *op) { return op != ifOp.getOperation(); }); + auto allNextOps = + llvm::makeArrayRef(allOps).drop_front(allPrevOps.size() + 1); + + // Find previous ops that cannot be moved into the regions. + SetVector stickyPrevOps; + // We cannot move the op into the region if + // - The op is part of the backward slice for computing conditions. + // - The op is part of the backward slice for an op that cannot move. + if (Operation *condOp = ifOp.getCondition().getDefiningOp()) { + getBackwardSlice(condOp, &stickyPrevOps, isFromSameRegion); + stickyPrevOps.insert(condOp); + } + for (Operation *op : llvm::reverse(allPrevOps)) { + if (stickyPrevOps.contains(op)) + continue; + if (!canMoveToRegion(op)) { + getBackwardSlice(op, &stickyPrevOps, isFromSameRegion); + stickyPrevOps.insert(op); // Add the current op back. + } + } + + // Find next ops that cannot be moved into the regions. + SetVector stickyNextOps; + // We cannot move the op into the region if + // - The op is part of the forward slice for an op that cannot move. + for (Operation *op : allNextOps) { + if (!canMoveToRegion(op)) { + getForwardSlice(op, &stickyNextOps, isFromSameRegion); + stickyNextOps.insert(op); // Add the current op back. + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "sticky previous ops:\n"; + for (Operation *op : stickyPrevOps) + llvm::dbgs() << " " << *op << "\n"; + llvm::dbgs() << "sticky next ops:\n"; + for (Operation *op : stickyNextOps) + llvm::dbgs() << " " << *op << "\n"; + }); + + // NYI support for the case where we have sticky next ops. For such cases we + // need to analyze their operands and figure out which are later coming from + // the if region. It can be complicated; support this only when truly needed. + if (!stickyNextOps.empty()) + return failure(); + + // Now get the ops that can be moved into the regions.
+ SmallVector prevOps, nextOps; + for (Operation *op : allPrevOps) + if (!stickyPrevOps.contains(op)) + prevOps.push_back(op); + for (Operation *op : allNextOps) + if (!stickyNextOps.contains(op)) + nextOps.push_back(op); + + LLVM_DEBUG({ + llvm::dbgs() << "previous ops to move:\n"; + for (Operation *op : prevOps) + llvm::dbgs() << " " << *op << "\n"; + llvm::dbgs() << "next ops to move:\n"; + for (Operation *op : nextOps) + llvm::dbgs() << " " << *op << "\n"; + }); + if (prevOps.empty() && nextOps.empty()) + return failure(); + + Operation *parentTerminator = ifOp->getBlock()->getTerminator(); + TypeRange resultTypes = ifOp.getResultTypes(); + if (!nextOps.empty()) { + // The if op should yield the values used by the terminator. + resultTypes = parentTerminator->getOperandTypes(); + } + + auto newIfOp = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), ifOp.elseBlock()); + + auto pullIntoBlock = [&](Block *newblock, Block *oldBlock) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newblock); + BlockAndValueMapping bvm; + + // Clone all ops defined before the original if op. + for (Operation *prevOp : prevOps) + rewriter.clone(*prevOp, bvm); + + // Clone all ops defined inside the original if block. + for (Operation &blockOp : oldBlock->without_terminator()) + rewriter.clone(blockOp, bvm); + + if (nextOps.empty()) { + // If the if op needs to return value, its builder won't automatically + // insert terminators. Just clone the old one here. + if (newIfOp->getNumResults()) + rewriter.clone(*oldBlock->getTerminator(), bvm); + return; + } + + // There are ops after the old if op. Uses of the old if op should be + // replaced by the cloned yield value. + auto oldYieldOp = cast(oldBlock->back()); + for (int i = 0, e = ifOp->getNumResults(); i < e; ++i) { + bvm.map(ifOp->getResult(i), bvm.lookup(oldYieldOp.getOperand(i))); + } + + // Clone all ops defined after the original if op. 
While doing that, we need + // to check whether the op is used by the terminator. If so, we need to + // yield its result value at the proper index. NOTE(review): a terminator operand that is not a result of one of `nextOps` (e.g. a block argument) never gets its `yieldValues` slot filled here -- confirm a null value cannot reach the yield created below. + SmallVector yieldValues(newIfOp.getNumResults()); + for (Operation *nextOp : nextOps) { + rewriter.clone(*nextOp, bvm); + for (OpOperand &use : nextOp->getUses()) { + if (use.getOwner() == parentTerminator) { + unsigned index = use.getOperandNumber(); + yieldValues[index] = bvm.lookup(use.get()); + } + } + } + + if (!yieldValues.empty()) { + // Again the if builder won't insert terminators automatically. + rewriter.create(ifOp.getLoc(), yieldValues); + } + }; + + pullIntoBlock(newIfOp.thenBlock(), ifOp.thenBlock()); + pullIntoBlock(newIfOp.elseBlock(), ifOp.elseBlock()); + + if (nextOps.empty()) { + rewriter.replaceOp(ifOp, newIfOp->getResults()); + } else { + // Update the terminator to use the new if op's results. + rewriter.updateRootInPlace(parentTerminator, [&]() { + parentTerminator->setOperands(newIfOp->getResults()); + }); + // We have pulled in all ops following the if op into both regions. Now + // remove them all. Do this in the reverse order.
+ for (Operation *op : llvm::reverse(nextOps)) + rewriter.eraseOp(op); + rewriter.eraseOp(ifOp); + } + for (Operation *op : llvm::reverse(prevOps)) + rewriter.eraseOp(op); + + return newIfOp; +} + +namespace { + +class IfRegionExpansionPattern final : public OpRewritePattern { +public: + IfRegionExpansionPattern(MLIRContext *context, + function_ref canMoveToRegion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), canMoveToRegion(canMoveToRegion) {} + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + return pullOpsIntoIfRegions(ifOp, canMoveToRegion, rewriter); + } + +private: + std::function canMoveToRegion; +}; + +} // namespace + +void scf::populateIfRegionExpansionPatterns( + RewritePatternSet &patterns, + function_ref canMoveToRegion) { + patterns.insert(patterns.getContext(), + canMoveToRegion); +} diff --git a/mlir/test/Dialect/SCF/if-region-expansion.mlir b/mlir/test/Dialect/SCF/if-region-expansion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/if-region-expansion.mlir @@ -0,0 +1,289 @@ +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -test-scf-transform-patterns=test-expand-if-region-patterns %s | FileCheck %s + +// CHECK-LABEL: func @parent_region_only_contains_if_op +func @parent_region_only_contains_if_op(%cond: i1, %val: i32) -> i32 { + %if = scf.if %cond -> i32 { + scf.yield %val: i32 + } else { + scf.yield %val: i32 + } + return %if: i32 +} + +// CHECK: scf.if +// CHECK: scf.yield +// CHECK: else +// CHECK: scf.yield + +// ----- + +// CHECK-LABEL: func @if_op_without_else_branch +func @if_op_without_else_branch(%cond: i1, %v0: i32, %v1: i32, %buffer: memref) { + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[] : memref + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[] : memref + return +} + +// CHECK: arith.addi +// CHECK: scf.if +// CHECK: memref.store +// CHECK: arith.muli +// CHECK: memref.store 
+ +// ----- + +// CHECK-LABEL: func @non_side_effecting_ops_before_and_after_if_op +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @non_side_effecting_ops_before_and_after_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if = scf.if %cond -> i32 { + scf.yield %add: i32 + } else { + scf.yield %sub: i32 + } + %mul = arith.muli %if, %v0 : i32 + %div = arith.divsi %if, %v1 : i32 + return %mul, %div: i32, i32 +} + +// CHECK: %[[IF:.+]]:2 = scf.if %[[COND]] -> (i32, i32) +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: return %[[IF]]#0, %[[IF]]#1 + +// ----- + +// CHECK-LABEL: func @side_effect_op_before_after_if_op +func @side_effect_op_before_after_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + memref.store %v0, %buffer[%c0] : memref<3xi32> + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[%c1] : memref<3xi32> + } else { + memref.store %add, %buffer[%c1] : memref<3xi32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[%c2] : memref<3xi32> + return +} + +// The control in the test pass allows moving side effecting ops. 
+ +// CHECK: scf.if +// CHECK: memref.store +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: arith.muli +// CHECK: memref.store +// CHECK: } else { +// CHECK: memref.store +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: arith.muli +// CHECK: memref.store + +// ----- + +// CHECK-LABEL: func @zero_result_if_op +func @zero_result_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[] : memref + } else { + memref.store %sub, %buffer[] : memref + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[] : memref + return +} + +// CHECK: scf.if +// CHECK: %[[ADD:.+]] = arith.addi +// CHECK: memref.store %[[ADD]] +// CHECK: %[[MUL:.+]] = arith.muli +// CHECK: memref.store %[[MUL]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi +// CHECK: memref.store %[[SUB]] +// CHECK: %[[MUL:.+]] = arith.muli +// CHECK: memref.store %[[MUL]] + +// ----- + +// CHECK-LABEL: func @multi_result_if_op +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @multi_result_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if:2 = scf.if %cond -> (i32, i32) { + scf.yield %add, %sub: i32, i32 + } else { + scf.yield %sub, %add: i32, i32 + } + %mul = arith.muli %if#0, %v0 : i32 + %div = arith.divsi %if#1, %v1 : i32 + return %mul, %div: i32, i32 +} + +// CHECK: %[[IF:.+]]:2 = scf.if +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: } else { +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], 
%[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: return %[[IF]]#0, %[[IF]]#1 + +// ----- + +// CHECK-LABEL: func @multi_use_in_terminator +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @multi_use_in_terminator(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if = scf.if %cond -> i32 { + scf.yield %add: i32 + } else { + scf.yield %sub: i32 + } + %mul = arith.muli %if, %if : i32 + return %mul, %mul, %mul: i32, i32, i32 +} + +// CHECK: %[[IF:.+]]:3 = scf.if +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[ADD]] +// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[SUB]] +// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]] +// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2 + +// ----- + +// CHECK-LABEL: func @sticky_op_before_if_op +func @sticky_op_before_if_op(%cond: i1, %index: index, %v0: i32, %v1: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %loc = arith.muli %index, %index : index + memref.store %v0, %buffer[%loc] {sticky} : memref<3xi32> + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[%c1] : memref<3xi32> + } else { + memref.store %add, %buffer[%c1] : memref<3xi32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[%c2] : memref<3xi32> + return +} + +// The control in the test pass disallows moving previous ops with "sticky" attribute and its backward slice. 
+ +// CHECK: arith.muli +// CHECK: memref.store {{.*}} {sticky} +// CHECK: scf.if +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: arith.muli +// CHECK: memref.store +// CHECK: } else { +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: arith.muli +// CHECK: memref.store + +// ----- + +// CHECK-LABEL: func @sticky_op_after_if_op +func @sticky_op_after_if_op(%cond: i1, %index: index, %v0: i32, %v1: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %loc = arith.muli %index, %index : index + memref.store %v0, %buffer[%loc] : memref<3xi32> + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[%c1] : memref<3xi32> + } else { + memref.store %add, %buffer[%c1] : memref<3xi32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[%c2] {sticky} : memref<3xi32> + return +} + +// NYI case for "sticky" next ops and their forward slices.
+ +// CHECK: arith.muli +// CHECK: memref.store +// CHECK: arith.addi +// CHECK: scf.if {{.*}} { +// CHECK: memref.store +// CHECK: } else { +// CHECK: memref.store +// CHECK: } +// CHECK: arith.muli +// CHECK: memref.store {{.*}} {sticky} + +// ----- + +// CHECK-LABEL: func @condition_back_slice +func @condition_back_slice(%cond0: i1, %cond1: i1, %cond2 : i1, %v0: i32, %v1: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %add = arith.addi %v0, %v1 : i32 + %and = arith.andi %cond0, %cond1 : i1 + %or = arith.ori %and, %cond2 : i1 + scf.if %or { + memref.store %add, %buffer[%c1] : memref<3xi32> + } else { + memref.store %add, %buffer[%c1] : memref<3xi32> + } + return +} + +// CHECK: %[[AND:.+]] = arith.andi +// CHECK: %[[OR:.+]] = arith.ori %[[AND]] +// CHECK: scf.if %[[OR]] +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: } else { +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: arith.addi +// CHECK: memref.store diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt --- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRSCFTestPasses TestLoopParametricTiling.cpp TestLoopUnrolling.cpp + TestSCFTransforms.cpp TestSCFUtils.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp b/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/SCF/TestSCFTransforms.cpp @@ -0,0 +1,72 @@ +//===- TestSCFTransforms.cpp - Test SCF transformation patterns -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for testing SCF transformations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +struct TestSCFTransforms + : public PassWrapper> { + TestSCFTransforms() = default; + TestSCFTransforms(const TestSCFTransforms &pass) : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + StringRef getArgument() const final { return "test-scf-transform-patterns"; } + StringRef getDescription() const final { + return "Test SCF patterns by applying them selectively or greedily."; + } + + void runOnOperation() override; + + Option testExpandIfRegionPatterns{ + *this, "test-expand-if-region-patterns", + llvm::cl::desc("Test patterns to expand if regions"), + llvm::cl::init(false)}; +}; +} // namespace + +static void applyExpandIfRegionPatterns(FuncOp funcOp) { + SmallVector candidates; + funcOp.walk([&](scf::IfOp ifOp) { candidates.push_back(ifOp); }); + + RewritePatternSet patterns(funcOp.getContext()); + auto canMoveToRegion = [](Operation *op) { + return !op->hasAttrOfType("sticky"); + }; + scf::populateIfRegionExpansionPatterns(patterns, canMoveToRegion); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + for (scf::IfOp ifOp : candidates) { + // Apply transforms per op to avoid recursive behavior.
+ (void)applyOpPatternsAndFold(ifOp, frozenPatterns, /*erased=*/nullptr); + } +} + +void TestSCFTransforms::runOnOperation() { + FuncOp func = getOperation(); + if (testExpandIfRegionPatterns) + applyExpandIfRegionPatterns(func); +} + +namespace mlir { +namespace test { +void registerTestSCFTransforms() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -104,6 +104,7 @@ void registerTestPDLByteCodePass(); void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); +void registerTestSCFTransforms(); void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); @@ -193,6 +194,7 @@ mlir::test::registerTestPadFusion(); mlir::test::registerTestPDLByteCodePass(); mlir::test::registerTestRecursiveTypesPass(); + mlir::test::registerTestSCFTransforms(); mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms();