diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -768,10 +768,11 @@ } static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, - ArrayRef<Value> memrefs, - ArrayRef<Value> conditions, + ValueRange memrefs, + ValueRange conditions, PatternRewriter &rewriter) { - if (deallocOp.getMemrefs() == memrefs) + if (deallocOp.getMemrefs() == memrefs && + deallocOp.getConditions() == conditions) return failure(); rewriter.updateRootInPlace(deallocOp, [&]() { @@ -983,6 +984,49 @@ } }; +/// The `memref.extract_strided_metadata` is often inserted to get the base +/// memref if the operand is not already guaranteed to be the result of a memref +/// allocation operation. This canonicalization pattern bypasses this extraction +/// operation if the operand is now produced by an allocation operation (e.g.,
+/// due to other canonicalizations simplifying the IR). +/// +/// Example: +/// ```mlir +/// %alloc = memref.alloc() : memref<2xi32> +/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata +/// %alloc : memref<2xi32> -> memref<i32>, index, index, index +/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond) +/// ``` +/// is canonicalized to +/// ```mlir +/// %alloc = memref.alloc() : memref<2xi32> +/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond) +/// ``` +struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> { + using OpRewritePattern<DeallocOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + SmallVector<Value> newMemrefs( + llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) { + auto extractStridedOp = + memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); + if (!extractStridedOp) + return memref; + Value allocMemref = extractStridedOp.getOperand(); + auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>(); + if (!allocOp) + return memref; + if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref)) + return allocMemref; + return memref; + })); + + return updateDeallocIfChanged(deallocOp, newMemrefs, + deallocOp.getConditions(), rewriter); + } +}; + } // anonymous namespace void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -990,7 +1034,7 @@ results.add(context); + EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -323,3 +323,20 @@ // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1) // CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]]) // CHECK-NEXT: return + +// ----- + +func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3:
memref<2xi32>) { + %alloc = memref.alloc() : memref<2xi32> + %base0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index + %base1, %offset1, %size1, %stride1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index + bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2) + return +} + +// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc() : memref<2xi32> +// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] : +// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]]) +// CHECK-NEXT: return