diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -73,19 +73,19 @@ namespace { /// Remove values from the `memref` operand list that are also present in the -/// `retained` list since they will always alias and thus never actually be -/// deallocated. However, we also need to be certain that no other value in the -/// `retained` list can alias, for which we use a static alias analysis. This is -/// necessary because the `dealloc` operation is defined to return one `i1` -/// value per memref in the `retained` list which represents the disjunction of -/// the condition values corresponding to all aliasing values in the `memref` -/// list. In particular, this means that if there is some value R in the -/// `retained` list which aliases with a value M in the `memref` list (but can -/// only be staticaly determined to may-alias) and M is also present in the -/// `retained` list, then it would be illegal to remove M because the result -/// corresponding to R would be computed incorrectly afterwards. -/// Because we require an alias analysis, this pattern cannot be applied as a -/// regular canonicalization pattern. +/// `retained` list (or a guaranteed alias of it) because they will never +/// actually be deallocated. However, we also need to be certain about which +/// other memrefs in the `retained` list can alias, i.e., there must not be any +/// may-aliasing memref. This is necessary because the `dealloc` operation is +/// defined to return one `i1` value per memref in the `retained` list which +/// represents the disjunction of the condition values corresponding to all +/// aliasing values in the `memref` list.
In particular, this means that if +/// there is some value R in the `retained` list which aliases with a value M in +/// the `memref` list (but can only be statically determined to may-alias) and M +/// is also present in the `retained` list, then it would be illegal to remove M +/// because the result corresponding to R would be computed incorrectly +/// afterwards. Because we require an alias analysis, this pattern cannot be +/// applied as a regular canonicalization pattern. /// /// Example: /// ```mlir @@ -101,63 +101,75 @@ /// // replace %0#0 with %1 /// ``` /// given that `%r0` and `%r1` may not alias with `%m0`. -struct DeallocRemoveDeallocMemrefsContainedInRetained +struct RemoveDeallocMemrefsContainedInRetained : public OpRewritePattern { - DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context, - AliasAnalysis &aliasAnalysis) + RemoveDeallocMemrefsContainedInRetained(MLIRContext *context, + AliasAnalysis &aliasAnalysis) : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} + /// The passed 'memref' must not have a may-alias relation to any retained + /// memref, and at least one must-alias relation. If there is no must-aliasing + /// memref in the retain list, we cannot simply remove the memref as there + /// could be situations in which it actually has to be deallocated. If it's + /// no-alias, then just proceed, if it's must-alias we need to update the + /// updated condition returned by the dealloc operation for that alias. + LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond, + PatternRewriter &rewriter) const { + rewriter.setInsertionPointAfter(deallocOp); + + // Check that there is no may-aliasing memref and that at least one memref + // in the retain list aliases (because otherwise it might have to be + // deallocated in some situations and can thus not be dropped).
+ bool atLeastOneMustAlias = false; + for (Value retained : deallocOp.getRetained()) { + AliasResult analysisResult = aliasAnalysis.alias(retained, memref); + if (analysisResult.isMay()) + return failure(); + if (analysisResult.isMust() || analysisResult.isPartial()) + atLeastOneMustAlias = true; + } + if (!atLeastOneMustAlias) + return failure(); + + // Insert arith.ori operations to update the corresponding dealloc result + // values to incorporate the condition of the must-aliasing memref such that + // we can remove that operand later on. + for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) { + Value updatedCondition = deallocOp.getUpdatedConditions()[i]; + AliasResult analysisResult = aliasAnalysis.alias(retained, memref); + if (analysisResult.isMust() || analysisResult.isPartial()) { + auto disjunction = rewriter.create( + deallocOp.getLoc(), updatedCondition, cond); + rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), + disjunction); + } + } + + return success(); + } + LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { - // Unique memrefs to be deallocated. - DenseMap retained; - for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained())) - retained[ret] = i; - // There must not be any duplicates in the retain list anymore because we // would miss updating one of the result values otherwise. 
+ DenseSet retained(deallocOp.getRetained().begin(), + deallocOp.getRetained().end()); if (retained.size() != deallocOp.getRetained().size()) return failure(); SmallVector newMemrefs, newConditions; - for (auto memrefAndCond : + for (auto [memref, cond] : llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { - Value memref = std::get<0>(memrefAndCond); - Value cond = std::get<1>(memrefAndCond); - - auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool { - Value retainedMemref = deallocOp.getRetained()[retained[memref]]; - // The current memref must not have a may-alias relation to any retained - // memref, and exactly one must-alias relation. - // TODO: it is possible to extend this pattern to allow an arbitrary - // number of must-alias relations as long as there is no may-alias. If - // it's no-alias, then just proceed (only supported case as of now), if - // it's must-alias, we also need to update the condition for that alias. - if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) { - return aliasAnalysis.alias(mr, memref).isNo() || - mr == retainedMemref; - })) { - rewriter.setInsertionPointAfter(deallocOp); - auto orOp = rewriter.create( - deallocOp.getLoc(), - deallocOp.getUpdatedConditions()[retained[memref]], cond); - rewriter.replaceAllUsesExcept( - deallocOp.getUpdatedConditions()[retained[memref]], - orOp.getResult(), orOp); - return true; - } - return false; - }; - - if (retained.contains(memref) && - replaceResultsIfNoInvalidAliasing(memref)) - continue; - auto extractOp = memref.getDefiningOp(); - if (extractOp && retained.contains(extractOp.getOperand()) && - replaceResultsIfNoInvalidAliasing(extractOp.getOperand())) + if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter))) continue; + if (auto extractOp = + memref.getDefiningOp()) + if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond, + rewriter))) + continue; + newMemrefs.push_back(memref); newConditions.push_back(cond); } @@ -332,7 +344,7 @@ 
void runOnOperation() override { AliasAnalysis &aliasAnalysis = getAnalysis(); RewritePatternSet patterns(&getContext()); - patterns.add(&getContext(), aliasAnalysis); diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir @@ -1,51 +1,79 @@ // RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s -func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) { +func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) { %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>) %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) - return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 + // multiple must-alias + %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref + %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>> + %alloc = memref.alloc() : memref<2xi32> + %5:3 = bufferization.dealloc (%arg0, %4 : memref<2xi32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref) + return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1 } // CHECK-LABEL: func @dealloc_deallocated_in_retained -// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: 
memref<2xi32>, [[ARG3:%.+]]: i1) +// CHECK-NEXT: arith.constant false // CHECK-NEXT: bufferization.dealloc // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) -// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 : +// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the +// COM: retained memrefs since the list of memrefs to be deallocated becomes empty +// COM: due to the pattern under test (and thus there is no memref the retain values +// COM: could alias to) +// CHECK-NEXT: bufferization.dealloc +// CHECK-NOT: if +// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]] +// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]] +// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] : // ----- -func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) { +func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) { %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref, index, index, index %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref, index, index, index %0 = bufferization.dealloc (%base_buffer : memref) if (%arg1) retain (%arg0 : memref<2xi32>) %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref, memref) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) %2:2 = bufferization.dealloc (%base_buffer : memref) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) - return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 + // 
multiple must-alias + %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref + %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>> + %alloc = memref.alloc() : memref<2xi32> + %5:3 = bufferization.dealloc (%base_buffer, %4 : memref, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref) + return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1 } // CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref -// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1) +// CHECK-NEXT: arith.constant false // CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] : // CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] : // CHECK-NEXT: bufferization.dealloc // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) // CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] // CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) -// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 : +// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the +// COM: retained memrefs since the list of memrefs to be deallocated becomes empty +// COM: due to the pattern under test (and thus there is no memref the retain values +// COM: could alias to) +// CHECK-NEXT: bufferization.dealloc +// CHECK-NOT: if +// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]] +// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]] +// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] : // ----- -func.func @dealloc_deallocated_in_retained(%arg0: 
i1, %arg1: memref<2xi32>) -> (i1, i1) { +func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) { %alloc = memref.alloc() : memref<2xi32> %alloc0 = memref.alloc() : memref<2xi32> %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>) return %0#0, %0#1 : i1, i1 } -// CHECK-LABEL: func @dealloc_deallocated_in_retained +// CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias // CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>) // CHECK-NEXT: [[FALSE:%.+]] = arith.constant false // CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(