diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -213,8 +213,8 @@ std::enable_if_t::value, bool> -matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { - return matcher.match(op->getOperand(idx)); +matchOperationOrValue(Value val, MatcherClass &matcher) { + return matcher.match(val); } /// Statically switch to an Operation matcher. @@ -222,8 +222,8 @@ std::enable_if_t::value, bool> -matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { - if (auto *defOp = op->getOperand(idx).getDefiningOp()) +matchOperationOrValue(Value val, MatcherClass &matcher) { + if (auto *defOp = val.getDefiningOp()) return matcher.match(defOp); return false; } @@ -243,6 +243,21 @@ } }; +/// Terminal matcher, returns true when operation type matches and binds +/// operation, otherwise returns false. +template +struct AnyCapturedOperationMatcher { + OpTy &what; + AnyCapturedOperationMatcher(OpTy &what) : what(what) {} + bool match(Operation *op) { + if (isa(op)) { + what = cast(op); + return true; + } + return false; + } +}; + /// Binds to a specific value and matches it. struct PatternMatcherValue { PatternMatcherValue(Value val) : value(val) {} @@ -274,13 +289,29 @@ return false; bool res = true; enumerate(operandMatchers, [&](size_t index, auto &matcher) { - res &= matchOperandOrValueAtIndex(op, index, matcher); + res &= matchOperationOrValue(op->getOperand(index), matcher); }); return res; } std::tuple operandMatchers; }; +/// All patterns have to match. Child patterns may either match on values or +/// operations, but this pattern always has to be applied to values. +template +struct AllOfPatternMatcher { + AllOfPatternMatcher(Patterns... patterns) : patterns(patterns...) 
{} + + bool match(Value op) { + bool res = true; + enumerate(patterns, [&](size_t index, auto &pattern) { + res &= matchOperationOrValue(op, pattern); + }); + return res; + } + std::tuple patterns; +}; + } // namespace detail /// Matches a constant foldable operation. @@ -373,6 +404,13 @@ return detail::op_matcher(); } +/// Matches the given OpClass like `m_Op()`, but binds the matched operation to +/// the argument. +template +inline auto m_AnyOpOfType(OpClass &op) { + return detail::AnyCapturedOperationMatcher(op); +} + /// Entry point for matching a pattern over a Value. template inline bool matchPattern(Value value, const Pattern &pattern) { @@ -407,6 +445,13 @@ return detail::RecursivePatternMatcher(matchers...); } +/// All of the matchers have to be successful. May only be applied to Values, +/// not Operations. +template +auto m_AllOf(Matchers... matchers) { + return detail::AllOfPatternMatcher(matchers...); +} + namespace matchers { inline auto m_Any() { return detail::AnyValueMatcher(); } inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -768,10 +768,11 @@ } static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, - ArrayRef memrefs, - ArrayRef conditions, + ValueRange memrefs, + ValueRange conditions, PatternRewriter &rewriter) { - if (deallocOp.getMemrefs() == memrefs) + if (deallocOp.getMemrefs() == memrefs && + deallocOp.getConditions() == conditions) return failure(); rewriter.updateRootInPlace(deallocOp, [&]() { @@ -983,6 +984,49 @@ } }; +/// The `memref.extract_strided_metadata` is often inserted to get the base +/// memref if the operand is not already guaranteed to be the result of a memref +/// allocation operation. 
This canonicalization pattern removes this extraction +/// operation if the operand is now produced by an allocation operation (e.g., +/// due to other canonicalizations simplifying the IR). +/// +/// Example: +/// ```mlir +/// %alloc = memref.alloc() : memref<2xi32> +/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata +/// %alloc : memref<2xi32> -> memref<i32>, index, index, index +/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond) +/// ``` +/// is canonicalized to +/// ```mlir +/// %alloc = memref.alloc() : memref<2xi32> +/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond) +/// ``` +struct SkipExtractMetadataOfAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + SmallVector newMemrefs; + for (Value memref : deallocOp.getMemrefs()) { + Value allocMemref; + MemoryEffectOpInterface allocOp; + if (matchPattern(memref, + m_Op(m_AllOf( + matchers::m_Any(&allocMemref), + m_AnyOpOfType(allocOp)))) && + allocOp.getEffectOnValue(allocMemref)) { + newMemrefs.push_back(allocMemref); + continue; + } + newMemrefs.push_back(memref); + } + + return updateDeallocIfChanged(deallocOp, newMemrefs, + deallocOp.getConditions(), rewriter); + } +}; + } // anonymous namespace void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -990,7 +1034,7 @@ results.add(context); + EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -323,3 +323,20 @@ // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1) // CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) 
if ([[ARG2]]) // CHECK-NEXT: return + +// ----- + +func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) { + %alloc = memref.alloc() : memref<2xi32> + %base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref, index, index, index + %base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref, index, index, index + bufferization.dealloc (%base0, %arg0, %base1 : memref, memref<2xi32>, memref) if (%arg1, %arg2, %arg2) + return +} + +// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc() : memref<2xi32> +// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] : +// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref) if ([[ARG1]], [[ARG2]], [[ARG2]]) +// CHECK-NEXT: return diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir --- a/mlir/test/IR/test-matchers.mlir +++ b/mlir/test/IR/test-matchers.mlir @@ -52,3 +52,5 @@ // CHECK-LABEL: test3 // CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched // CHECK: Pattern m_Attr("fastmath") matched and bound value to: fast +// CHECK: Pattern allOf{mul(*, add(*, m_Op("test.name"))), m_Attr("fastmath")} matched +// CHECK: Pattern mul(*, *) matched and captured operation: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} fastmath : f32 diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -150,9 +150,12 @@ void test3(FunctionOpInterface f) { arith::FastMathFlagsAttr fastMathAttr; + arith::MulFOp mulFOp; auto p = m_Op(m_Any(), m_Op(m_Any(), m_Op("test.name"))); auto p1 = m_Attr("fastmath", &fastMathAttr); + auto p2 = 
m_AllOf(p, p1); + auto p3 = m_AnyOpOfType(mulFOp); // Last operation that is not the terminator. Operation *lastOp = f.getFunctionBody().front().back().getPrevNode(); @@ -161,6 +164,12 @@ if (p1.match(lastOp)) llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: " << fastMathAttr.getValue() << "\n"; + if (p2.match(lastOp->getResult(0))) + llvm::outs() << "Pattern allOf{mul(*, add(*, m_Op(\"test.name\"))), " + "m_Attr(\"fastmath\")} matched\n"; + if (p3.match(lastOp)) + llvm::outs() << "Pattern mul(*, *) matched and captured operation: " + << mulFOp << "\n"; } void TestMatchers::runOnOperation() {