diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -23,6 +23,10 @@
 /// Creates a pass that bufferizes the SCF dialect.
 std::unique_ptr<Pass> createSCFBufferizePass();
 
+/// Creates a pass that hoists and eliminates common sub-expressions from
+/// scf.if branches.
+std::unique_ptr<Pass> createSCFCSEBranchesPass();
+
 /// Creates a pass that specializes for loop for unrolling and
 /// vectorization.
 std::unique_ptr<Pass> createForLoopSpecializationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -18,6 +18,19 @@
                            "memref::MemRefDialect"];
 }
 
+def SCFCSEBranches : Pass<"scf-cse-branches"> {
+  let summary = "Hoist and eliminate common sub-expressions in scf.if branches";
+  let description = [{
+    This pass eliminates common sub-expressions inside the "then" and the "else"
+    branches of scf.if ops. This pass relies on information provided by the
+    MemoryEffectOpInterface to identify when it is safe to eliminate operations.
+
+    Note: This pass does not eliminate duplicate operations outside of scf.if
+    ops. The `-cse` pass can be used in such cases.
+  }];
+  let constructor = "mlir::createSCFCSEBranchesPass()";
+}
+
 // Note: Making these canonicalization patterns would require a dependency
 // of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
 def SCFForLoopCanonicalization
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  CSE.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/CSE.cpp b/mlir/lib/Dialect/SCF/Transforms/CSE.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/CSE.cpp
@@ -0,0 +1,168 @@
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFCSEBRANCHES
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::scf;
+
+/// Return `true` if the given ops are equivalent.
+// TODO: This function is copied from mlir/Transforms/CSE.cpp. The same
+// implementation could be used.
+static bool isEqual(Operation *lhs, Operation *rhs) {
+  // If op has no regions, operation equivalence w.r.t operands alone is
+  // enough.
+  if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, OperationEquivalence::exactValueMatch,
+        OperationEquivalence::ignoreValueEquivalence,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  // If lhs or rhs does not have a single region with a single block, they
+  // aren't CSEed for now.
+  if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
+      !llvm::hasSingleElement(lhs->getRegion(0)) ||
+      !llvm::hasSingleElement(rhs->getRegion(0)))
+    return false;
+
+  // Compare the two blocks.
+  Block &lhsBlock = lhs->getRegion(0).front();
+  Block &rhsBlock = rhs->getRegion(0).front();
+
+  // Don't CSE if number of arguments differ.
+  if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
+    return false;
+
+  // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
+  // `rhsBlock`. `Value`s from `lhsBlock` are the key.
+  DenseMap<Value, Value> areEquivalentValues;
+  for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
+                               rhs->getRegion(0).getArguments())) {
+    areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
+  }
+
+  // Helper function to get the parent operation.
+  auto getParent = [](Value v) -> Operation * {
+    if (auto blockArg = v.dyn_cast<BlockArgument>())
+      return blockArg.getParentBlock()->getParentOp();
+    return v.getDefiningOp()->getParentOp();
+  };
+
+  // Callback to compare if operands of ops in the region of `lhs` and `rhs`
+  // are equivalent.
+  auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+    if (lhsValue == rhsValue)
+      return success();
+    if (areEquivalentValues.lookup(lhsValue) == rhsValue)
+      return success();
+    return failure();
+  };
+
+  // Callback to compare if results of ops in the region of `lhs` and `rhs`
+  // are equivalent.
+  auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
+    if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
+      auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
+      return success(insertion.first->second == rhsResult);
+    }
+    return success();
+  };
+
+  return OperationEquivalence::isEquivalentTo(
+      lhs, rhs, mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
+}
+
+/// Return `true` if `op`, or any op nested in its regions, uses a value (an
+/// op result or a block argument) that is defined directly in `block`. Such
+/// an op cannot be hoisted out of `block`.
+static bool usesValuesDefinedInBlock(Operation *op, Block *block) {
+  WalkResult status = op->walk([&](Operation *nestedOp) {
+    if (llvm::any_of(nestedOp->getOperands(),
+                     [&](Value v) { return v.getParentBlock() == block; }))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return status.wasInterrupted();
+}
+
+namespace {
+struct SCFCSEBranches : public impl::SCFCSEBranchesBase<SCFCSEBranches> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void SCFCSEBranches::runOnOperation() {
+  getOperation()->walk([](IfOp ifOp) {
+    // Only scf.if ops with a single block are supported.
+    if (!ifOp.getThenRegion().hasOneBlock() ||
+        !ifOp.getElseRegion().hasOneBlock())
+      return;
+    Block *thenBlock = &ifOp.getThenRegion().front();
+    Block *elseBlock = &ifOp.getElseRegion().front();
+
+    // Indicates if there is an impure op in the "then" branch before `thenOp`.
+    bool thenImpureOpFound = false;
+
+    for (Operation &thenOp : llvm::make_early_inc_range(*thenBlock)) {
+      bool thenOpIsPure = isPure(&thenOp);
+
+      // Terminators and ops that use values defined in this block are skipped.
+      if (isa<YieldOp>(&thenOp) ||
+          usesValuesDefinedInBlock(&thenOp, thenBlock)) {
+        thenImpureOpFound |= !thenOpIsPure;
+        continue;
+      }
+
+      // Set to true if `thenOp` was hoisted.
+      bool hoisted = false;
+      // Indicates if there is an impure op in the "else" branch before
+      // `elseOp`.
+      bool elseImpureOpFound = false;
+
+      for (Operation &elseOp : llvm::make_early_inc_range(*elseBlock)) {
+        // Do not CSE if:
+        // 1. `thenOp` and `elseOp` are not equivalent, or
+        // 2. CSE'ing would change side effects. If the ops are pure, then
+        //    there can be no change in side effects. Otherwise, hoist and CSE
+        //    them only if there is no side-effecting (impure) op in the branch
+        //    before `thenOp` or `elseOp`.
+        bool sideEffectViolation =
+            !thenOpIsPure && (thenImpureOpFound || elseImpureOpFound);
+        if (sideEffectViolation || !isEqual(&thenOp, &elseOp)) {
+          elseImpureOpFound |= !isPure(&elseOp);
+          continue;
+        }
+
+        // There may be multiple matches for `thenOp` in the "else" branch.
+        // Hoist `thenOp` only once.
+        if (!hoisted) {
+          thenOp.moveBefore(ifOp);
+          hoisted = true;
+        }
+
+        // Erase duplicate op in the "else" branch.
+        elseOp.replaceAllUsesWith(thenOp.getResults());
+        elseOp.erase();
+
+        // The hoisted op executes exactly once. If it is impure, it may
+        // replace only one op of the "else" branch; erasing further matches
+        // would drop executions of its side effects.
+        if (!thenOpIsPure)
+          break;
+      }
+
+      if (!hoisted)
+        thenImpureOpFound |= !thenOpIsPure;
+    }
+  });
+}
+
+std::unique_ptr<Pass> mlir::createSCFCSEBranchesPass() {
+  return std::make_unique<SCFCSEBranches>();
+}
diff --git a/mlir/test/Dialect/SCF/cse.mlir b/mlir/test/Dialect/SCF/cse.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/cse.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -split-input-file -scf-cse-branches %s | FileCheck %s
+
+// CHECK-LABEL: func @cse_side_effecting_ops
+//  CHECK-NEXT:   %[[a:.*]] = "test.dummy_a"
+//  CHECK-NEXT:   %[[b:.*]] = "test.dummy_b"(%[[a]])
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : index
+//  CHECK-NEXT:   %[[if:.*]] = scf.if %{{.*}} {
+//  CHECK-NEXT:     "test.other_side_effecting_op"
+//  CHECK-NEXT:     %[[c0:.*]] = "test.dummy_c"(%[[b]], %[[c5]])
+//  CHECK-NEXT:     scf.yield %[[c0]]
+//  CHECK-NEXT:   } else {
+//  CHECK-NEXT:     %[[c1:.*]] = "test.dummy_c"(%[[b]], %[[c5]])
+//  CHECK-NEXT:     scf.yield %[[c1]]
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return %[[if]]
+func.func @cse_side_effecting_ops(%arg0: i1) -> f32 {
+  %0 = scf.if %arg0 -> (f32) {
+    // %1 is CSE'd because there is no preceding impure op.
+    %1 = "test.dummy_a"() : () -> (f32)
+    // %2 is CSE'd because there is no preceding impure op (after %1 was CSE'd).
+    %2 = "test.dummy_b"(%1) : (f32) -> (f32)
+    // Not CSE'd, this op does not exist in the "else" branch.
+    "test.other_side_effecting_op"() : () -> ()
+    // %c5 is CSE'd because it is pure and does not depend on values in this
+    // block.
+    %c5 = arith.constant 5 : index
+    // %3 is not CSE'd because it is impure and there is a preceding impure op
+    // in the "then" branch.
+    %3 = "test.dummy_c"(%2, %c5) : (f32, index) -> (f32)
+    // Terminator is not hoisted.
+    scf.yield %3 : f32
+  } else {
+    %1 = "test.dummy_a"() : () -> (f32)
+    %2 = "test.dummy_b"(%1) : (f32) -> (f32)
+    %c5 = arith.constant 5 : index
+    %3 = "test.dummy_c"(%2, %c5) : (f32, index) -> (f32)
+    scf.yield %3 : f32
+  }
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @cse_impure_op_only_once
+//  CHECK-NEXT:   "test.side_effect"() : () -> ()
+//  CHECK-NEXT:   scf.if %{{.*}} {
+//  CHECK-NEXT:   } else {
+//  CHECK-NEXT:     "test.side_effect"() : () -> ()
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return
+func.func @cse_impure_op_only_once(%arg0: i1) {
+  scf.if %arg0 {
+    // Hoisted and CSE'd with the first op of the "else" branch.
+    "test.side_effect"() : () -> ()
+  } else {
+    "test.side_effect"() : () -> ()
+    // Not erased: the "else" branch must still run the effect twice.
+    "test.side_effect"() : () -> ()
+  }
+  return
+}