diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -154,6 +154,91 @@
   AliasAnalysis &aliasAnalysis;
 };
 
+/// Remove memrefs from the `retained` list which are guaranteed to not alias
+/// any memref in the `memrefs` list. The corresponding result value can be
+/// replaced with `false` in that case according to the operation description.
+/// Because it requires an `AliasAnalysis`, the pattern is registered in this
+/// simplification pass rather than as a canonicalization pattern.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
+///                         retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// return %0#0, %0#1
+/// ```
+/// can be simplified to the following given that `%r0` and `%r1` do not
+/// alias `%m`:
+/// ```mlir
+/// %false = arith.constant false
+/// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
+/// return %false, %false
+/// ```
+struct RemoveRetainedMemrefsGuaranteedToNotAlias
+    : public OpRewritePattern<DeallocOp> {
+  RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
+                                            AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  /// Returns true unless the alias analysis proves that `retainedMemref` does
+  /// not alias any memref in the `memrefs` list of `deallocOp`.
+  bool potentiallyAliasesMemref(DeallocOp deallocOp,
+                                Value retainedMemref) const {
+    for (auto memref : deallocOp.getMemrefs()) {
+      if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
+        return true;
+    }
+    return false;
+  }
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // `replacements` holds one value per result of `deallocOp`. A null entry
+    // is a placeholder for the corresponding result of the new dealloc op.
+    SmallVector<Value> newRetainedMemrefs, replacements;
+    // Materialize the `false` constant lazily so that no IR is created when
+    // the pattern does not apply.
+    Value falseValue;
+    auto getOrCreateFalse = [&]() -> Value {
+      if (!falseValue)
+        falseValue = rewriter.create<arith::ConstantOp>(
+            deallocOp.getLoc(), rewriter.getBoolAttr(false));
+      return falseValue;
+    };
+
+    for (auto retainedMemref : deallocOp.getRetained()) {
+      if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
+        newRetainedMemrefs.push_back(retainedMemref);
+        replacements.push_back({});
+        continue;
+      }
+
+      replacements.push_back(getOrCreateFalse());
+    }
+
+    // Every retained memref may alias: nothing to remove, and thanks to the
+    // lazy constant creation above no IR has been modified yet.
+    if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
+      return failure();
+
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
+        newRetainedMemrefs);
+    // Fill the placeholders in order; the new op has one updated condition
+    // per entry of `newRetainedMemrefs`.
+    int i = 0;
+    for (auto &repl : replacements) {
+      if (!repl)
+        repl = newDeallocOp.getUpdatedConditions()[i++];
+    }
+
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
+  }
+
+private:
+  AliasAnalysis &aliasAnalysis;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -171,8 +256,9 @@
   void runOnOperation() override {
     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
     RewritePatternSet patterns(&getContext());
-    patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(),
-                                                                 aliasAnalysis);
+    patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
+                 RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
+                                                            aliasAnalysis);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -9,12 +9,11 @@
 
 // CHECK-LABEL: func @dealloc_deallocated_in_retained
 // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: bufferization.dealloc
 // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
 
 // -----
 
@@ -31,9 +30,25 @@
 // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
 // CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
 // CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: bufferization.dealloc
 // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+
+// -----
+
+func.func @remove_retained_memrefs_guaranteed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+  %alloc = memref.alloc() : memref<2xi32>
+  %alloc0 = memref.alloc() : memref<2xi32>
+  %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @remove_retained_memrefs_guaranteed_to_not_alias
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :