diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/FunctionExtras.h" namespace mlir { @@ -447,6 +448,30 @@ Region::iterator before); void cloneRegionBefore(Region &region, Block *before); + /// This method replaces the uses of the results of `op` with the values in + /// `newValues` when the provided `functor` returns true for a specific use. + /// The number of values in `newValues` is required to match the number of + /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of + /// the uses of `op` were replaced. Note that in some pattern rewriters, the + /// given `functor` may be stored beyond the lifetime of the pattern being + /// applied. As such, the function should not capture by reference and instead + /// use value capture as necessary. + virtual void + replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, + llvm::unique_function<bool(OpOperand &)> functor); + void replaceOpWithIf(Operation *op, ValueRange newValues, + llvm::unique_function<bool(OpOperand &)> functor) { + replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr, + std::move(functor)); + } + + /// This method replaces the uses of the results of `op` with the values in + /// `newValues` when a use is nested within the given `block`. The number of + /// values in `newValues` is required to match the number of results of `op`. + /// If all uses of this operation are replaced, the operation is erased. + void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, + bool *allUsesReplaced = nullptr); + /// This method performs the final replacement for a pattern, where the /// results of the operation are updated to use the specified list of SSA /// values. 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -470,6 +470,12 @@ // PatternRewriter Hooks //===--------------------------------------------------------------------===// + /// PatternRewriter hook for replacing the results of an operation when the + /// given functor returns true. + void replaceOpWithIf( + Operation *op, ValueRange newValues, bool *allUsesReplaced, + llvm::unique_function<bool(OpOperand &)> functor) override; + /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues) override; using PatternRewriter::replaceOp; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -155,6 +155,41 @@ // Out of line to provide a vtable anchor for the class. } +/// This method replaces the uses of the results of `op` with the values in +/// `newValues` when the provided `functor` returns true for a specific use. +/// The number of values in `newValues` is required to match the number of +/// results of `op`. +void PatternRewriter::replaceOpWithIf( + Operation *op, ValueRange newValues, bool *allUsesReplaced, + llvm::unique_function<bool(OpOperand &)> functor) { + assert(op->getNumResults() == newValues.size() && + "incorrect number of values to replace operation"); + + // Notify the rewriter subclass that we're about to replace this root. + notifyRootReplaced(op); + + // Replace each use of the results when the functor is true. 
+ bool replacedAllUses = true; + for (auto it : llvm::zip(op->getResults(), newValues)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); + replacedAllUses &= std::get<0>(it).use_empty(); + } + if (allUsesReplaced) + *allUsesReplaced = replacedAllUses; +} + +/// This method replaces the uses of the results of `op` with the values in +/// `newValues` when a use is nested within the given `block`. The number of +/// values in `newValues` is required to match the number of results of `op`. +/// If all uses of this operation are replaced, the operation is erased. +void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues, + Block *block, + bool *allUsesReplaced) { + replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) { + return block->getParentOp()->isProperAncestor(use.getOwner()); + }); +} + /// This method performs the final replacement for a pattern, where the /// results of the operation are updated to use the specified list of SSA /// values. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1250,6 +1250,21 @@ impl(new detail::ConversionPatternRewriterImpl(*this)) {} ConversionPatternRewriter::~ConversionPatternRewriter() {} +/// PatternRewriter hook for replacing the results of an operation when the +/// given functor returns true. +void ConversionPatternRewriter::replaceOpWithIf( + Operation *op, ValueRange newValues, bool *allUsesReplaced, + llvm::unique_function<bool(OpOperand &)> functor) { + // TODO: To support this we will need to rework a bit of how replacements are + // tracked, given that this isn't guaranteed to replace all of the uses of an + // operation. The main change is that now an operation can be replaced + // multiple times, in parts. 
The current "set" based tracking is mainly useful + // for tracking if a replaced operation should be ignored, i.e. if all of the + // uses will be replaced. + llvm_unreachable( + "replaceOpWithIf is currently not supported by DialectConversion"); +} + /// PatternRewriter hook for replacing the results of an operation. void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { LLVM_DEBUG({ diff --git a/mlir/test/Transforms/test-pattern-selective-replacement.mlir b/mlir/test/Transforms/test-pattern-selective-replacement.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-pattern-selective-replacement.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-pattern-selective-replacement -verify-diagnostics %s | FileCheck %s + +// Test that operations can be selectively replaced. + +// CHECK-LABEL: @test1 +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @test1(%arg0: i32, %arg1 : i32) -> () { + // CHECK: addi %[[ARG1]], %[[ARG1]] + // CHECK-NEXT: "test.return"(%[[ARG0]] + %cast = "test.cast"(%arg0, %arg1) : (i32, i32) -> (i32) + %non_terminator = addi %cast, %cast : i32 + "test.return"(%cast, %non_terminator) : (i32, i32) -> () +} + +// ----- diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -847,6 +847,10 @@ }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Test Block Merging +//===----------------------------------------------------------------------===// + namespace { /// A rewriter pattern that tests that blocks can be merged. 
struct TestMergeBlock : public OpConversionPattern { @@ -955,6 +959,46 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Test Selective Replacement +//===----------------------------------------------------------------------===// + +namespace { +/// A rewrite pattern that replaces the terminator uses of a two-operand cast +/// with its first operand, and all remaining uses with its second operand. +struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { + using OpRewritePattern<TestCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(TestCastOp op, + PatternRewriter &rewriter) const final { + if (op.getNumOperands() != 2) + return failure(); + OperandRange operands = op.getOperands(); + + // Replace terminator uses with the first operand. + rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { + return operand.getOwner()->isKnownTerminator(); + }); + // Replace everything else with the second operand if the operation isn't + // dead. + rewriter.replaceOp(op, op.getOperand(1)); + return success(); + } +}; + +struct TestSelectiveReplacementPatternDriver + : public PassWrapper<TestSelectiveReplacementPatternDriver, OperationPass<>> { + void runOnOperation() override { + mlir::OwningRewritePatternList patterns; + MLIRContext *context = &getContext(); + patterns.insert<TestSelectiveOpReplacementPattern>(context); + applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // PassRegistration //===----------------------------------------------------------------------===// @@ -992,6 +1036,9 @@ PassRegistration{ "test-merge-blocks", "Test Merging operation in ConversionPatternRewriter"}; + PassRegistration<TestSelectiveReplacementPatternDriver>{ + "test-pattern-selective-replacement", + "Test selective replacement in the PatternRewriter"}; } } // namespace test } // namespace mlir