diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -48,6 +48,28 @@
   return success();
 }
 
+/// Returns true if `memref` potentially aliases a value in `memrefList`, i.e.
+/// if the alias analysis does not report `isNo()` for at least one of them.
+///
+/// If `allowSelfAlias` is set, `memref` is assumed to be an element of
+/// `memrefList` and exactly one entry identical to it is skipped. Any further
+/// identical entry is a duplicate operand and is handed to the alias analysis
+/// like every other entry instead of being ignored.
+static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+                                     ValueRange memrefList, Value memref,
+                                     bool allowSelfAlias) {
+  for (auto mr : memrefList) {
+    if (allowSelfAlias && mr == memref) {
+      // Skip only the queried entry itself; later duplicates must be checked.
+      allowSelfAlias = false;
+      continue;
+    }
+    if (!analysis.alias(mr, memref).isNo())
+      return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -159,21 +181,13 @@
                                             AliasAnalysis &aliasAnalysis)
       : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
 
-  bool potentiallyAliasesMemref(DeallocOp deallocOp,
-                                Value retainedMemref) const {
-    for (auto memref : deallocOp.getMemrefs()) {
-      if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
-        return true;
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
     Value falseValue;
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
+      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+                                   retainedMemref, false)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
         continue;
@@ -206,6 +220,95 @@
   AliasAnalysis &aliasAnalysis;
 };
 
+/// Split off memrefs to separate (newly created) dealloc operations to reduce
+/// the number of runtime checks required and enable further canonicalization of
+/// the new and simpler dealloc operations. A memref can be split off if it is
+/// guaranteed to not alias with any other memref in the `memref` operand list.
+/// The results of the old and the new dealloc operation have to be combined by
+/// computing the element-wise disjunction of them.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
+///                           if (%cond0, %cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// return %0#0, %0#1
+/// ```
+/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
+/// canonicalized to the following, thus reducing the number of runtime alias
+/// checks by 1 and potentially enabling further canonicalization of the new
+/// split-up dealloc operations.
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %2 = arith.ori %0#0, %1#0
+/// %3 = arith.ori %0#1, %1#1
+/// return %2, %3
+/// ```
+struct SplitDeallocWhenNotAliasingAnyOther
+    : public OpRewritePattern<DeallocOp> {
+  SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
+                                      AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // With at most one memref there is no other operand to split away from.
+    if (deallocOp.getMemrefs().size() <= 1)
+      return failure();
+
+    SmallVector<Value> newMemrefs, newConditions, replacements;
+    DenseSet<Operation *> exceptedUsers;
+    // Start from the original results; every split-off op ORs its own in.
+    replacements = deallocOp.getUpdatedConditions();
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Operands that may alias another operand stay on the original op.
+      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+                                   memref, true)) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+        continue;
+      }
+
+      // NOTE(review): the ops below are created at the rewriter's current
+      // insertion point, and the first disjunction takes results of
+      // `deallocOp` as operands -- confirm that point lies after `deallocOp`.
+      auto newDeallocOp = rewriter.create<DeallocOp>(
+          deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
+      replacements = SmallVector<Value>(llvm::map_range(
+          llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
+          [&](auto replAndNew) -> Value {
+            auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
+                                                      std::get<0>(replAndNew),
+                                                      std::get<1>(replAndNew));
+            // Remember the disjunction so its own operands are not replaced.
+            exceptedUsers.insert(orOp);
+            return orOp.getResult();
+          }));
+    }
+
+    // No memref could be split off.
+    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+      return failure();
+
+    // Redirect all other users to the combined values. The disjunctions are
+    // excluded because they take the original results as operands.
+    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
+                               [&](OpOperand &operand) {
+                                 return !exceptedUsers.contains(
+                                     operand.getOwner());
+                               });
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+
+private:
+  AliasAnalysis &aliasAnalysis;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -227,5 +330,6 @@
-                 RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
-                                                            aliasAnalysis);
+                 RemoveRetainedMemrefsGuaranteedToNotAlias,
+                 SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
+                                                      aliasAnalysis);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -52,3 +52,24 @@
 // CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
 // CHECK-NOT: retain
 // CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
+func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
+  %alloc = memref.alloc() : memref<2xi32>
+  %alloc0 = memref.alloc() : memref<2xi32>
+  %0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
+  %1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
+  return %1#0, %1#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+// CHECK-NEXT: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-NEXT: [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
+// CHECK-NEXT: bufferization.dealloc
+// CHECK-NEXT: return [[V2]]#0, [[V3]] :