diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -58,6 +58,12 @@ LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref); +/// Add the canonicalization patterns for bufferization.dealloc to the given +/// pattern set to make them available to other passes (such as +/// BufferDeallocationSimplification). +void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context); + } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -1018,10 +1018,15 @@ void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + populateDeallocOpCanonicalizationPatterns(results, context); +} + +void bufferization::populateDeallocOpCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -432,6 +432,7 @@ SplitDeallocWhenNotAliasingAnyOther, RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(), aliasAnalysis); + populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); if (failed( applyPatternsAndFoldGreedily(getOperation(), 
std::move(patterns)))) diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir @@ -15,7 +15,6 @@ // CHECK-LABEL: func @dealloc_deallocated_in_retained // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1) // CHECK-NEXT: arith.constant false -// CHECK-NEXT: bufferization.dealloc // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) @@ -23,7 +22,6 @@ // COM: retained memrefs since the list of memrefs to be deallocated becomes empty // COM: due to the pattern under test (and thus there is no memref the retain values // COM: could alias to) -// CHECK-NEXT: bufferization.dealloc // CHECK-NOT: if // CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]] // CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]] @@ -50,7 +48,6 @@ // CHECK-NEXT: arith.constant false // CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] : // CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] : -// CHECK-NEXT: bufferization.dealloc // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) @@ -58,7 +55,6 @@ // COM: retained memrefs since the 
list of memrefs to be deallocated becomes empty // COM: due to the pattern under test (and thus there is no memref the retain values // COM: could alias to) -// CHECK-NEXT: bufferization.dealloc // CHECK-NOT: if // CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]] // CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]] @@ -66,11 +62,11 @@ // ----- -func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) { +func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) { %alloc = memref.alloc() : memref<2xi32> %alloc0 = memref.alloc() : memref<2xi32> %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>) - return %0#0, %0#1 : i1, i1 + return %0#0, %0#1, %alloc : i1, i1, memref<2xi32> } // CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias @@ -79,7 +75,7 @@ // CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc( // CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]]) // CHECK-NOT: retain -// CHECK-NEXT: return [[FALSE]], [[FALSE]] : +// CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] : // ----- @@ -104,7 +100,6 @@ // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>) // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>) // CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1 -// CHECK-NEXT: bufferization.dealloc // CHECK-NEXT: return [[V2]]#0, [[V3]] : // -----