diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -963,13 +963,76 @@
   }
 };
 
+/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
+/// other user of the allocated value and the allocating operation can be safely
+/// removed. If the same value is present multiple times, this pattern relies on
+/// other canonicalization patterns to remove the duplicate first.
+///
+/// An allocating operation is only erased if (1) it has an `Allocate` effect on
+/// the deallocated value, (2) `Allocate` is its only memory effect on any
+/// value, and (3) the dealloc is the only user of any of its results. Checking
+/// (2) and (3) on the whole operation (not just on the deallocated value)
+/// avoids dropping effects on other values, e.g. an op that also frees one of
+/// its operands, and avoids erasing an op whose other results are still used.
+///
+/// Example:
+/// ```mlir
+/// %alloc = memref.alloc() : memref<2xi32>
+/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// bufferization.dealloc (%arg0 : ...) if (%true)
+/// ```
+struct RemoveAllocDeallocPairWhenNoOtherUsers
+    : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newMemrefs, newConditions;
+    SmallVector<Operation *> toDelete;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
+        // Check that it is indeed an allocate effect, that the op has no other
+        // side effects on any value (which would not allow us to remove the
+        // op), and that this dealloc is the only user of any of its results.
+        // The op-level use check also ensures that an op is never queued for
+        // deletion twice.
+        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
+            hasSingleEffect<MemoryEffects::Allocate>(allocOp) &&
+            allocOp->hasOneUse()) {
+          toDelete.push_back(allocOp);
+          continue;
+        }
+      }
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                      rewriter)))
+      return failure();
+
+    // Erase the allocations only after the dealloc has been rewritten without
+    // them, so that they have no remaining users.
+    for (Operation *op : toDelete)
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 } // anonymous namespace
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
-              EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
+              EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
+              RemoveAllocDeallocPairWhenNoOtherUsers>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -323,12 +323,12 @@
 
 // -----
 
-func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) {
+func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
   %alloc = memref.alloc() : memref<2xi32>
   %base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
   %base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
   bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
-  return
+  return %alloc : memref<2xi32>
 }
 
 // CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
@@ -337,3 +337,17 @@
 // CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
 // CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
 // CHECK-NEXT: return
+
+// -----
+
+func.func @dealloc_remove_alloc_without_other_users(%arg0: memref<2xi32>) {
+  %true = arith.constant true
+  %alloc = memref.alloc() : memref<2xi32>
+  bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
+  return
+}
+
+// CHECK-LABEL: func @dealloc_remove_alloc_without_other_users
+// CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
+// CHECK-NOT: memref.alloc(
+// CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true)