diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -327,6 +328,88 @@ AliasAnalysis &aliasAnalysis; }; +/// Check for every retained memref if a must-aliasing memref exists in the +/// 'memref' operand list with constant 'true' condition. If so, we can replace +/// the operation result corresponding to that retained memref with 'true'. If +/// this condition holds for all retained memrefs we can also remove the +/// aliasing memrefs and their conditions since they will never be deallocated +/// due to the must-alias and we don't need them to compute the result value +/// anymore since it got replaced with 'true'. +/// +/// Example: +/// ```mlir +/// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...) +/// if (%true, %true, %true) +/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) +/// ``` +/// becomes +/// ```mlir +/// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true) +/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) +/// // replace %0#0 with %true +/// // replace %0#1 with %true +/// ``` +/// Note that the dealloc operation will still have the result values, but they +/// don't have uses anymore. 
+struct RetainedMemrefAliasingAlwaysDeallocatedMemref
+    : public OpRewritePattern<DeallocOp> {
+  RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
+                                                AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  /// Replaces uses of each result whose retained memref must- or partially
+  /// aliases a memref deallocated under constant 'true'; drops such memrefs
+  /// once all retained memrefs are covered. Succeeds iff the IR was modified.
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
+    SmallVector<Value> newMemrefs, newConditions;
+    bool replacedAnyUse = false;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      bool canDropMemref = false;
+      // The condition does not depend on the retained memref: test it once.
+      if (matchPattern(cond, m_One())) {
+        // TODO: once our alias analysis is powerful enough we can remove the
+        // look-through of extract_strided_metadata.
+        Value source;
+        if (auto extractOp =
+                memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
+          source = extractOp.getOperand();
+        for (auto [i, retained, res] : llvm::enumerate(
+                 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
+          if (!aliases(retained, memref) &&
+              !(source && aliases(retained, source)))
+            continue;
+          // Only a rewritten use counts as an IR change.
+          replacedAnyUse |= !res.use_empty();
+          rewriter.replaceAllUsesWith(res, cond);
+          aliasesWithConstTrueMemref[i] = true;
+          canDropMemref = true;
+        }
+      }
+      if (!canDropMemref) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+      }
+    }
+    // Memrefs may only be dropped if every retained memref is covered.
+    LogicalResult operandsUpdated = failure();
+    if (aliasesWithConstTrueMemref.all())
+      operandsUpdated = updateDeallocIfChanged(deallocOp, newMemrefs,
+                                               newConditions, rewriter);
+    return success(replacedAnyUse || succeeded(operandsUpdated));
+  }
+
+private:
+  /// Returns true if the two values must- or partially alias.
+  bool aliases(Value lhs, Value rhs) const {
+    AliasResult result = aliasAnalysis.alias(lhs, rhs);
+    return result.isMust() || result.isPartial();
+  }
+  AliasAnalysis &aliasAnalysis;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -346,8 +429,9 @@
 RewritePatternSet patterns(&getContext()); patterns.add(&getContext(), - aliasAnalysis); +
SplitDeallocWhenNotAliasingAnyOther,
+                 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
+                                                                aliasAnalysis);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -106,3 +106,39 @@
 // CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
 // CHECK-NEXT: bufferization.dealloc
 // CHECK-NEXT: return [[V2]]#0, [[V3]] :
+
+// -----
+
+// %arg0 and %arg1 are retained and also deallocated under a constant true
+// condition: only %arg2 stays in the dealloc list, both results become %true.
+func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition(
+  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
+  %true = arith.constant true
+  %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK: bufferization.dealloc ([[ARG2]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :
+
+// -----
+
+// Same scenario, but the deallocated memrefs are the base buffers returned by
+// memref.extract_strided_metadata on the function arguments.
+func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition(
+  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
+  %true = arith.constant true
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
+  %base_buffer_1, %offset_1, %size_1, %stride_1 = memref.extract_strided_metadata %arg1 : memref<2xi32> -> memref<i32>, index, index, index
+  %base_buffer_2, %offset_2, %size_2, %stride_2 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
+  %0:2 = bufferization.dealloc (%base_buffer, %base_buffer_1, %base_buffer_2 : memref<i32>, memref<i32>, memref<i32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :