diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -120,6 +120,100 @@
     return success();
   }
 
+  /// A special case lowering for the deallocation operation with exactly one
+  /// memref, but arbitrary number of retained values. This avoids the helper
+  /// function that the general case needs and thus also avoids storing indices
+  /// to specifically allocated memrefs. The size of the code produced by this
+  /// lowering is linear in the number of retained values.
+  ///
+  /// Requires exactly one memref and at least one retained value;
+  /// `matchAndRewrite` only dispatches here in that situation.
+  ///
+  /// Example:
+  /// ```mlir
+  /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+  ///                         retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+  /// return %0#0, %0#1 : i1, i1
+  /// ```
+  /// ```mlir
+  /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
+  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+  /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
+  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+  /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
+  /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
+  /// %should_dealloc = arith.andi %not_retained, %cond : i1
+  /// scf.if %should_dealloc {
+  ///   memref.dealloc %m : memref<2xf32>
+  /// }
+  /// %true = arith.constant true
+  /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
+  /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
+  /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
+  /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
+  /// return %r0_ownership, %r1_ownership : i1, i1
+  /// ```
+  LogicalResult rewriteOneMemrefMultipleRetainCase(
+      bufferization::DeallocOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const {
+    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+    assert(!adaptor.getRetained().empty() &&
+           "expected at least one retained value");
+
+    Value toDealloc = adaptor.getMemrefs()[0];
+    Value cond = adaptor.getConditions()[0];
+
+    // Compute the base pointer indices, compare all retained indices to the
+    // memref index to check if they alias.
+    SmallVector<Value> doesNotAliasList;
+    Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
+        op.getLoc(), toDealloc);
+    for (Value retained : adaptor.getRetained()) {
+      Value retainedAsIdx =
+          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+                                                                  retained);
+      Value doesNotAlias = rewriter.create<arith::CmpIOp>(
+          op.getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
+      doesNotAliasList.push_back(doesNotAlias);
+    }
+
+    // AND-reduce the list of booleans from above: the memref is only freed
+    // if its base pointer differs from that of every retained value.
+    Value notRetained = doesNotAliasList.front();
+    for (Value doesNotAlias : llvm::drop_begin(doesNotAliasList))
+      notRetained = rewriter.create<arith::AndIOp>(op.getLoc(), notRetained,
+                                                   doesNotAlias);
+
+    // Also consider the condition given by the dealloc operation and perform a
+    // conditional deallocation guarded by that value.
+    Value shouldDealloc =
+        rewriter.create<arith::AndIOp>(op.getLoc(), notRetained, cond);
+
+    rewriter.create<scf::IfOp>(
+        op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+          builder.create<memref::DeallocOp>(loc, toDealloc);
+          builder.create<scf::YieldOp>(loc);
+        });
+
+    // Compute the replacement values for the dealloc operation results. This
+    // inserts an already canonicalized form of
+    // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
+    // value r.
+    SmallVector<Value> replacements;
+    Value trueVal = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), rewriter.getBoolAttr(true));
+    for (Value doesNotAlias : doesNotAliasList) {
+      Value aliases =
+          rewriter.create<arith::XOrIOp>(op.getLoc(), doesNotAlias, trueVal);
+      Value result = rewriter.create<arith::AndIOp>(op.getLoc(), aliases, cond);
+      replacements.push_back(result);
+    }
+
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+
   /// Lowering that supports all features the dealloc operation has to offer. It
   /// computes the base pointer of each memref (as an index), stores it in a
   /// new memref helper structure and passes it to the helper function generated
@@ -310,12 +404,27 @@
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Lower the trivial case.
-    if (adaptor.getMemrefs().empty())
-      return rewriter.eraseOp(op), success();
+    if (adaptor.getMemrefs().empty()) {
+      // With no memrefs nothing is freed, so every result (one per retained
+      // value) is `false`. Only materialize that constant when there are
+      // results to replace, so the result-less form leaves no dead op behind.
+      if (adaptor.getRetained().empty()) {
+        rewriter.eraseOp(op);
+        return success();
+      }
+      Value falseVal = rewriter.create<arith::ConstantOp>(
+          op.getLoc(), rewriter.getBoolAttr(false));
+      rewriter.replaceOp(
+          op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
+      return success();
+    }
 
     if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
       return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
 
+    if (adaptor.getMemrefs().size() == 1)
+      return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
+
     return rewriteGeneralCase(op, adaptor, rewriter);
   }
 
@@ -535,8 +644,9 @@
     // Build dealloc helper function if there are deallocs.
     func::FuncOp helperFuncOp;
     getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-      if (deallocOp.getMemrefs().size() > 1 ||
-          !deallocOp.getRetained().empty()) {
+      // Only the general lowering (more than one memref) calls the helper.
+      // Keep in sync with DeallocOpConversion::matchAndRewrite.
+      if (deallocOp.getMemrefs().size() > 1) {
         helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
             builder, getOperation()->getLoc(), symbolTable);
         return WalkResult::interrupt();
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -66,17 +66,29 @@
   memref.dealloc %arg0 : memref>
   return %1 : memref>
 }
+
 // -----
 
 // CHECK-LABEL: func @conversion_dealloc_empty
 func.func @conversion_dealloc_empty() {
   // CHECK-NEXT: return
   bufferization.dealloc
   return
 }
 
 // -----
 
+func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
+  %0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_empty_but_retains
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
 // CHECK-NOT: func @deallocHelper
 // CHECK-LABEL: func @conversion_dealloc_simple
 // CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
@@ -93,6 +105,33 @@
 
 // -----
 
+func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+  %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>)
+// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]]
+// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index
+// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index
+// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]]
+// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]]
+// CHECK: scf.if [[SHOULD_DEALLOC]]
+// CHECK: memref.dealloc [[ARG0]]
+// CHECK: }
+// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true
+// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true
+// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]]
+// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]]
+// CHECK: return [[RES0]], [[RES1]]
+
+// CHECK-NOT: func @dealloc_helper
+
+// -----
+
 func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
   %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
   return %0#0, %0#1 : i1, i1