diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -195,12 +195,14 @@ This pass converts bufferization operations into memref operations. In the current state, this pass only transforms a `bufferization.clone` - operation into `memref.alloc` and `memref.copy` operations. This conversion - is needed, since some clone operations could remain after applying several - transformation processes. Currently, only `canonicalize` transforms clone - operations or even eliminates them. This can lead to errors if any clone op - survived after all conversion passes (starting from the bufferization - dialect) are performed. + operation into `memref.alloc` and `memref.copy` operations and + `bufferization.dealloc` operations (the same way as the + `-bufferization-lower-deallocations` pass). The conversion of `clone` + operations is needed, since some clone operations could remain after + applying several transformation processes. Currently, only `canonicalize` + transforms clone operations or even eliminates them. This can lead to errors + if any clone op survived after all conversion passes (starting from the + bufferization dialect) are performed. See: https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665 diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -5,6 +5,9 @@ namespace mlir { class ModuleOp; +class RewritePatternSet; +class OpBuilder; +class SymbolTable; namespace func { class FuncOp; @@ -29,6 +32,98 @@ /// static alias analysis. std::unique_ptr createBufferDeallocationSimplificationPass(); +/// Creates an instance of the LowerDeallocations pass to lower +/// `bufferization.dealloc` operations to the `memref` dialect. +std::unique_ptr createLowerDeallocationsPass(); + +/// Adds the conversion pattern of the `bufferization.dealloc` operation to the +/// given pattern set for use in other transformation passes. +void populateBufferizationDeallocLoweringPattern( + RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc); + +/// Construct the library function needed for the fully generic +/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass. +/// The function can then be called at bufferization dealloc sites to determine +/// aliasing and ownership. +/// +/// The generated function takes two memrefs of indices and three memrefs of +/// booleans as arguments: +/// * The first argument A should contain the result of the +/// extract_aligned_pointer_as_index operation applied to the memrefs to be +/// deallocated +/// * The second argument B should contain the result of the +/// extract_aligned_pointer_as_index operation applied to the memrefs to be +/// retained +/// * The third argument C should contain the conditions as passed directly +/// to the deallocation operation. +/// * The fourth argument D is used to pass results to the caller. Those +/// represent the condition under which the memref at the corresponding +/// position in A should be deallocated. +/// * The fifth argument E is used to pass results to the caller. 
It +/// provides the ownership value corresponding the the memref at the same +/// position in B +/// +/// This helper function is supposed to be called once for each +/// `bufferization.dealloc` operation to determine the deallocation need and new +/// ownership indicator for the retained values, but does not perform the +/// deallocation itself. +/// +/// Generated code: +/// ``` +/// func.func @dealloc_helper( +/// %dyn_dealloc_base_pointer_list: memref, +/// %dyn_retain_base_pointer_list: memref, +/// %dyn_cond_list: memref, +/// %dyn_dealloc_cond_out: memref, +/// %dyn_ownership_out: memref) { +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %true = arith.constant true +/// %false = arith.constant false +/// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0 +/// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0 +/// // Zero initialize result buffer. +/// scf.for %i = %c0 to %num_retain_memrefs step %c1 { +/// memref.store %false, %dyn_ownership_out[%i] : memref +/// } +/// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 { +/// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i] +/// %cond = memref.load %dyn_cond_list[%i] +/// // Check for aliasing with retained memrefs. +/// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs +/// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) { +/// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j] +/// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index +/// scf.if %does_alias { +/// %curr_ownership = memref.load %dyn_ownership_out[%j] +/// %updated_ownership = arith.ori %curr_ownership, %cond : i1 +/// memref.store %updated_ownership, %dyn_ownership_out[%j] +/// } +/// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index +/// %updated_aggregate = arith.andi %does_not_alias_aggregated, +/// %does_not_alias : i1 +/// scf.yield %updated_aggregate : i1 +/// } +/// // Check for aliasing with dealloc memrefs in the list before the +/// // current one, i.e., +/// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j], +/// // %dyn_dealloc_base_pointer[i])` +/// %does_not_alias_any = scf.for %j = %c0 to %i step %c1 +/// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) { +/// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j] +/// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp +/// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias +/// scf.yield %updated_alias_agg : i1 +/// } +/// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1 +/// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref +/// } +/// return +/// } +/// ``` +func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, + SymbolTable &symbolTable); + /// Run buffer deallocation. LogicalResult deallocateBuffers(Operation *op); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -108,6 +108,29 @@ ]; } +def LowerDeallocations : Pass<"bufferization-lower-deallocations"> { + let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`" + "operations"; + let description = [{ + This pass lowers `bufferization.dealloc` operations to the `memref` dialect. 
+ It can be applied to a `builtin.module` or operations implementing the + `FunctionOpInterface`. For the latter, only simple `dealloc` operations can + be lowered because the library function necessary for the fully generic + lowering cannot be inserted. In this case, an error will be emitted. + Next to `memref.dealloc` operations, it may also emit operations from the + `arith`, `scf`, and `func` dialects to build conditional deallocations and + library functions to avoid code-size blow-up. + }]; + + let constructor = + "mlir::bufferization::createLowerDeallocationsPass()"; + + let dependentDialects = [ + "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect", + "func::FuncDialect" + ]; +} + def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -80,543 +81,6 @@ } }; -/// The DeallocOpConversion transforms all bufferization dealloc operations into -/// memref dealloc operations potentially guarded by scf if operations. -/// Additionally, memref extract_aligned_pointer_as_index and arith operations -/// are inserted to compute the guard conditions. We distinguish multiple cases -/// to provide an overall more efficient lowering. In the general case, a helper -/// func is created to avoid quadratic code size explosion (relative to the -/// number of operands of the dealloc operation). For examples of each case, -/// refer to the documentation of the member functions of this class. -class DeallocOpConversion - : public OpConversionPattern { - - /// Lower a simple case without any retained values and a single memref to - /// avoiding the helper function. Ideally, static analysis can provide enough - /// aliasing information to split the dealloc operations up into this simple - /// case as much as possible before running this pass. - /// - /// Example: - /// ``` - /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) - /// ``` - /// is lowered to - /// ``` - /// scf.if %arg1 { - /// memref.dealloc %arg0 : memref<2xf32> - /// } - /// ``` - LogicalResult - rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getMemrefs().size() == 1 && "expected only one memref"); - assert(adaptor.getRetained().empty() && "expected no retained memrefs"); - - rewriter.replaceOpWithNewOp( - op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[0]); - builder.create(loc); - }); - return success(); - } - - /// A special case lowering for the deallocation operation with exactly one - /// memref, but arbitrary number of retained values. This avoids the helper - /// function that the general case needs and thus also avoids storing indices - /// to specifically allocated memrefs. 
The size of the code produced by this - /// lowering is linear to the number of retained values. - /// - /// Example: - /// ```mlir - /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond) - // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) - /// return %0#0, %0#1 : i1, i1 - /// ``` - /// ```mlir - /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m - /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 - /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer - /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 - /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer - /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1 - /// %should_dealloc = arith.andi %not_retained, %cond : i1 - /// scf.if %should_dealloc { - /// memref.dealloc %m : memref<2xf32> - /// } - /// %true = arith.constant true - /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1 - /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1 - /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1 - /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1 - /// return %r0_ownership, %r1_ownership : i1, i1 - /// ``` - LogicalResult rewriteOneMemrefMultipleRetainCase( - bufferization::DeallocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getMemrefs().size() == 1 && "expected only one memref"); - - // Compute the base pointer indices, compare all retained indices to the - // memref index to check if they alias. - SmallVector doesNotAliasList; - Value memrefAsIdx = rewriter.create( - op->getLoc(), adaptor.getMemrefs()[0]); - for (Value retained : adaptor.getRetained()) { - Value retainedAsIdx = - rewriter.create(op->getLoc(), - retained); - Value doesNotAlias = rewriter.create( - op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx); - doesNotAliasList.push_back(doesNotAlias); - } - - // AND-reduce the list of booleans from above. - Value prev = doesNotAliasList.front(); - for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front()) - prev = rewriter.create(op->getLoc(), prev, doesNotAlias); - - // Also consider the condition given by the dealloc operation and perform a - // conditional deallocation guarded by that value. - Value shouldDealloc = rewriter.create( - op->getLoc(), prev, adaptor.getConditions()[0]); - - rewriter.create( - op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[0]); - builder.create(loc); - }); - - // Compute the replacement values for the dealloc operation results. This - // inserts an already canonicalized form of - // `select(does_alias_with_memref(r), memref_cond, false)` for each retained - // value r. - SmallVector replacements; - Value trueVal = rewriter.create( - op->getLoc(), rewriter.getBoolAttr(true)); - for (Value doesNotAlias : doesNotAliasList) { - Value aliases = - rewriter.create(op->getLoc(), doesNotAlias, trueVal); - Value result = rewriter.create(op->getLoc(), aliases, - adaptor.getConditions()[0]); - replacements.push_back(result); - } - - rewriter.replaceOp(op, replacements); - - return success(); - } - - /// Lowering that supports all features the dealloc operation has to offer. It - /// computes the base pointer of each memref (as an index), stores it in a - /// new memref helper structure and passes it to the helper function generated - /// in 'buildDeallocationHelperFunction'. 
The results are stored in two lists - /// (represented as memrefs) of booleans passed as arguments. The first list - /// stores whether the corresponding condition should be deallocated, the - /// second list stores the ownership of the retained values which can be used - /// to replace the result values of the `bufferization.dealloc` operation. - /// - /// Example: - /// ``` - /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>) - /// if (%cond0, %cond1) - /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) - /// ``` - /// lowers to (simplified): - /// ``` - /// %c0 = arith.constant 0 : index - /// %c1 = arith.constant 1 : index - /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex> - /// %cond_list = memref.alloc() : memref<2xi1> - /// %retain_base_pointer_list = memref.alloc() : memref<2xindex> - /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0 - /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0] - /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1 - /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1] - /// memref.store %cond0, %cond_list[%c0] - /// memref.store %cond1, %cond_list[%c1] - /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 - /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0] - /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 - /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1] - /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list : - /// memref<2xindex> to memref - /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref - /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list : - /// memref<2xindex> to memref - /// %dealloc_cond_out = memref.alloc() : memref<2xi1> - /// %ownership_out = memref.alloc() : memref<2xi1> - /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out : - /// memref<2xi1> to memref - /// %dyn_ownership_out = memref.cast %ownership_out : - /// memref<2xi1> to memref - /// call @dealloc_helper(%dyn_dealloc_base_pointer_list, - /// %dyn_retain_base_pointer_list, - /// %dyn_cond_list, - /// %dyn_dealloc_cond_out, - /// %dyn_ownership_out) : (...) - /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1> - /// scf.if %m0_dealloc_cond { - /// memref.dealloc %m0 : memref<2xf32> - /// } - /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1> - /// scf.if %m1_dealloc_cond { - /// memref.dealloc %m1 : memref<5xf32> - /// } - /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1> - /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1> - /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex> - /// memref.dealloc %retain_base_pointer_list : memref<2xindex> - /// memref.dealloc %cond_list : memref<2xi1> - /// memref.dealloc %dealloc_cond_out : memref<2xi1> - /// memref.dealloc %ownership_out : memref<2xi1> - /// // replace %0#0 with %r0_ownership - /// // replace %0#1 with %r1_ownership - /// ``` - LogicalResult rewriteGeneralCase(bufferization::DeallocOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // Allocate two memrefs holding the base pointer indices of the list of - // memrefs to be deallocated and the ones to be retained. These can then be - // passed to the helper function and the for-loops can iterate over them. 
- // Without storing them to memrefs, we could not use for-loops but only a - // completely unrolled version of it, potentially leading to code-size - // blow-up. - Value toDeallocMemref = rewriter.create( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, - rewriter.getIndexType())); - Value conditionMemref = rewriter.create( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, - rewriter.getI1Type())); - Value toRetainMemref = rewriter.create( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, - rewriter.getIndexType())); - - auto getConstValue = [&](uint64_t value) -> Value { - return rewriter.create(op.getLoc(), - rewriter.getIndexAttr(value)); - }; - - // Extract the base pointers of the memrefs as indices to check for aliasing - // at runtime. - for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) { - Value memrefAsIdx = - rewriter.create(op.getLoc(), - toDealloc); - rewriter.create(op.getLoc(), memrefAsIdx, - toDeallocMemref, getConstValue(i)); - } - - for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) - rewriter.create(op.getLoc(), cond, conditionMemref, - getConstValue(i)); - - for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { - Value memrefAsIdx = - rewriter.create(op.getLoc(), - toRetain); - rewriter.create(op.getLoc(), memrefAsIdx, toRetainMemref, - getConstValue(i)); - } - - // Cast the allocated memrefs to dynamic shape because we want only one - // helper function no matter how many operands the bufferization.dealloc - // has. - Value castedDeallocMemref = rewriter.create( - op->getLoc(), - MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), - toDeallocMemref); - Value castedCondsMemref = rewriter.create( - op->getLoc(), - MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), - conditionMemref); - Value castedRetainMemref = rewriter.create( - op->getLoc(), - MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), - toRetainMemref); - - Value deallocCondsMemref = rewriter.create( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, - rewriter.getI1Type())); - Value retainCondsMemref = rewriter.create( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, - rewriter.getI1Type())); - - Value castedDeallocCondsMemref = rewriter.create( - op->getLoc(), - MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), - deallocCondsMemref); - Value castedRetainCondsMemref = rewriter.create( - op->getLoc(), - MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), - retainCondsMemref); - - rewriter.create( - op.getLoc(), deallocHelperFunc, - SmallVector{castedDeallocMemref, castedRetainMemref, - castedCondsMemref, castedDeallocCondsMemref, - castedRetainCondsMemref}); - - for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { - Value idxValue = getConstValue(i); - Value shouldDealloc = rewriter.create( - op.getLoc(), deallocCondsMemref, idxValue); - rewriter.create( - op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[i]); - builder.create(loc); - }); - } - - SmallVector replacements; - for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { - Value idxValue = getConstValue(i); - Value ownership = rewriter.create( - op.getLoc(), retainCondsMemref, idxValue); - replacements.push_back(ownership); - } - - // Deallocate above allocated memrefs again to avoid memory leaks. 
- // Deallocation will not be run on code after this stage. - rewriter.create(op.getLoc(), toDeallocMemref); - rewriter.create(op.getLoc(), toRetainMemref); - rewriter.create(op.getLoc(), conditionMemref); - rewriter.create(op.getLoc(), deallocCondsMemref); - rewriter.create(op.getLoc(), retainCondsMemref); - - rewriter.replaceOp(op, replacements); - return success(); - } - -public: - DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc) - : OpConversionPattern(context), - deallocHelperFunc(deallocHelperFunc) {} - - LogicalResult - matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Lower the trivial case. - if (adaptor.getMemrefs().empty()) { - Value falseVal = rewriter.create( - op.getLoc(), rewriter.getBoolAttr(false)); - rewriter.replaceOp( - op, SmallVector(adaptor.getRetained().size(), falseVal)); - return success(); - } - - if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty()) - return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter); - - if (adaptor.getMemrefs().size() == 1) - return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter); - - if (!deallocHelperFunc) - return op->emitError( - "library function required for generic lowering, but cannot be " - "automatically inserted when operating on functions"); - - return rewriteGeneralCase(op, adaptor, rewriter); - } - - /// Build a helper function per compilation unit that can be called at - /// bufferization dealloc sites to determine aliasing and ownership. - /// - /// The generated function takes two memrefs of indices and three memrefs of - /// booleans as arguments: - /// * The first argument A should contain the result of the - /// extract_aligned_pointer_as_index operation applied to the memrefs to be - /// deallocated - /// * The second argument B should contain the result of the - /// extract_aligned_pointer_as_index operation applied to the memrefs to be - /// retained - /// * The third argument C should contain the conditions as passed directly - /// to the deallocation operation. - /// * The fourth argument D is used to pass results to the caller. Those - /// represent the condition under which the memref at the corresponding - /// position in A should be deallocated. - /// * The fifth argument E is used to pass results to the caller. It - /// provides the ownership value corresponding the the memref at the same - /// position in B - /// - /// This helper function is supposed to be called once for each - /// `bufferization.dealloc` operation to determine the deallocation need and - /// new ownership indicator for the retained values, but does not perform the - /// deallocation itself. - /// - /// Generated code: - /// ``` - /// func.func @dealloc_helper( - /// %dyn_dealloc_base_pointer_list: memref, - /// %dyn_retain_base_pointer_list: memref, - /// %dyn_cond_list: memref, - /// %dyn_dealloc_cond_out: memref, - /// %dyn_ownership_out: memref) { - /// %c0 = arith.constant 0 : index - /// %c1 = arith.constant 1 : index - /// %true = arith.constant true - /// %false = arith.constant false - /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0 - /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0 - /// // Zero initialize result buffer. 
- /// scf.for %i = %c0 to %num_retain_memrefs step %c1 { - /// memref.store %false, %dyn_ownership_out[%i] : memref - /// } - /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 { - /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i] - /// %cond = memref.load %dyn_cond_list[%i] - /// // Check for aliasing with retained memrefs. - /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs - /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) { - /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j] - /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index - /// scf.if %does_alias { - /// %curr_ownership = memref.load %dyn_ownership_out[%j] - /// %updated_ownership = arith.ori %curr_ownership, %cond : i1 - /// memref.store %updated_ownership, %dyn_ownership_out[%j] - /// } - /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index - /// %updated_aggregate = arith.andi %does_not_alias_aggregated, - /// %does_not_alias : i1 - /// scf.yield %updated_aggregate : i1 - /// } - /// // Check for aliasing with dealloc memrefs in the list before the - /// // current one, i.e., - /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j], - /// // %dyn_dealloc_base_pointer[i])` - /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1 - /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) { - /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j] - /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp - /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias - /// scf.yield %updated_alias_agg : i1 - /// } - /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1 - /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref - /// } - /// return - /// } - /// ``` - static func::FuncOp - buildDeallocationHelperFunction(OpBuilder &builder, Location loc, - SymbolTable &symbolTable) { - Type indexMemrefType = - MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); - Type boolMemrefType = - MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); - SmallVector argTypes{indexMemrefType, indexMemrefType, boolMemrefType, - boolMemrefType, boolMemrefType}; - builder.clearInsertionPoint(); - - // Generate the func operation itself. - auto helperFuncOp = func::FuncOp::create( - loc, "dealloc_helper", builder.getFunctionType(argTypes, {})); - symbolTable.insert(helperFuncOp); - auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); - block.addArguments(argTypes, SmallVector(argTypes.size(), loc)); - - builder.setInsertionPointToStart(&block); - Value toDeallocMemref = helperFuncOp.getArguments()[0]; - Value toRetainMemref = helperFuncOp.getArguments()[1]; - Value conditionMemref = helperFuncOp.getArguments()[2]; - Value deallocCondsMemref = helperFuncOp.getArguments()[3]; - Value retainCondsMemref = helperFuncOp.getArguments()[4]; - - // Insert some prerequisites. 
- Value c0 = builder.create(loc, builder.getIndexAttr(0)); - Value c1 = builder.create(loc, builder.getIndexAttr(1)); - Value trueValue = - builder.create(loc, builder.getBoolAttr(true)); - Value falseValue = - builder.create(loc, builder.getBoolAttr(false)); - Value toDeallocSize = - builder.create(loc, toDeallocMemref, c0); - Value toRetainSize = builder.create(loc, toRetainMemref, c0); - - builder.create( - loc, c0, toRetainSize, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - builder.create(loc, falseValue, retainCondsMemref, - i); - builder.create(loc); - }); - - builder.create( - loc, c0, toDeallocSize, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value outerIter, - ValueRange iterArgs) { - Value toDealloc = - builder.create(loc, toDeallocMemref, outerIter); - Value cond = - builder.create(loc, conditionMemref, outerIter); - - // Build the first for loop that computes aliasing with retained - // memrefs. - Value noRetainAlias = - builder - .create( - loc, c0, toRetainSize, c1, trueValue, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value retainValue = builder.create( - loc, toRetainMemref, i); - Value doesAlias = builder.create( - loc, arith::CmpIPredicate::eq, retainValue, - toDealloc); - builder.create( - loc, doesAlias, - [&](OpBuilder &builder, Location loc) { - Value retainCondValue = - builder.create( - loc, retainCondsMemref, i); - Value aggregatedRetainCond = - builder.create( - loc, retainCondValue, cond); - builder.create( - loc, aggregatedRetainCond, retainCondsMemref, - i); - builder.create(loc); - }); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, retainValue, - toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - // Build the second for loop that adds aliasing with previously - // deallocated memrefs. - Value noAlias = - builder - .create( - loc, c0, outerIter, c1, noRetainAlias, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value prevDeallocValue = builder.create( - loc, toDeallocMemref, i); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, prevDeallocValue, - toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - Value shouldDealoc = - builder.create(loc, noAlias, cond); - builder.create(loc, shouldDealoc, deallocCondsMemref, - outerIter); - builder.create(loc); - }); - - builder.create(loc); - return helperFuncOp; - } - -private: - func::FuncOp deallocHelperFunc; -}; } // namespace namespace { @@ -641,7 +105,7 @@ // Build dealloc helper function if there are deallocs. 
getOperation()->walk([&](bufferization::DeallocOp deallocOp) { if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction( + helperFuncOp = bufferization::buildDeallocationLibraryFunction( builder, getOperation()->getLoc(), symbolTable); return WalkResult::interrupt(); } @@ -651,7 +115,8 @@ RewritePatternSet patterns(&getContext()); patterns.add(patterns.getContext()); - patterns.add(patterns.getContext(), helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern(patterns, + helperFuncOp); ConversionTarget target(getContext()); target.addLegalDialect { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Check for unranked memref types which are currently not supported. - Type type = op.getType(); - if (isa(type)) { - return rewriter.notifyMatchFailure( - op, "UnrankedMemRefType is not supported."); - } - MemRefType memrefType = cast(type); - MemRefLayoutAttrInterface layout; - auto allocType = - MemRefType::get(memrefType.getShape(), memrefType.getElementType(), - layout, memrefType.getMemorySpace()); - // Since this implementation always allocates, certain result types of the - // clone op cannot be lowered. - if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) - return failure(); - - // Transform a clone operation into alloc + copy operation and pay - // attention to the shape dimensions. - Location loc = op->getLoc(); - SmallVector dynamicOperands; - for (int i = 0; i < memrefType.getRank(); ++i) { - if (!memrefType.isDynamicDim(i)) - continue; - Value dim = rewriter.createOrFold(loc, op.getInput(), i); - dynamicOperands.push_back(dim); - } - - // Allocate a memref with identity layout. - Value alloc = rewriter.create(op->getLoc(), allocType, - dynamicOperands); - // Cast the allocation to the specified type if needed. - if (memrefType != allocType) - alloc = rewriter.create(op->getLoc(), memrefType, alloc); - rewriter.replaceOp(op, alloc); - rewriter.create(loc, op.getInput(), alloc); - return success(); - } -}; - /// The DeallocOpConversion transforms all bufferization dealloc operations into /// memref dealloc operations potentially guarded by scf if operations. /// Additionally, memref extract_aligned_pointer_as_index and arith operations @@ -417,213 +369,15 @@ return rewriteGeneralCase(op, adaptor, rewriter); } - /// Build a helper function per compilation unit that can be called at - /// bufferization dealloc sites to determine aliasing and ownership. - /// - /// The generated function takes two memrefs of indices and three memrefs of - /// booleans as arguments: - /// * The first argument A should contain the result of the - /// extract_aligned_pointer_as_index operation applied to the memrefs to be - /// deallocated - /// * The second argument B should contain the result of the - /// extract_aligned_pointer_as_index operation applied to the memrefs to be - /// retained - /// * The third argument C should contain the conditions as passed directly - /// to the deallocation operation. - /// * The fourth argument D is used to pass results to the caller. Those - /// represent the condition under which the memref at the corresponding - /// position in A should be deallocated. - /// * The fifth argument E is used to pass results to the caller. 
It - /// provides the ownership value corresponding the the memref at the same - /// position in B - /// - /// This helper function is supposed to be called once for each - /// `bufferization.dealloc` operation to determine the deallocation need and - /// new ownership indicator for the retained values, but does not perform the - /// deallocation itself. - /// - /// Generated code: - /// ``` - /// func.func @dealloc_helper( - /// %dyn_dealloc_base_pointer_list: memref, - /// %dyn_retain_base_pointer_list: memref, - /// %dyn_cond_list: memref, - /// %dyn_dealloc_cond_out: memref, - /// %dyn_ownership_out: memref) { - /// %c0 = arith.constant 0 : index - /// %c1 = arith.constant 1 : index - /// %true = arith.constant true - /// %false = arith.constant false - /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0 - /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0 - /// // Zero initialize result buffer. - /// scf.for %i = %c0 to %num_retain_memrefs step %c1 { - /// memref.store %false, %dyn_ownership_out[%i] : memref - /// } - /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 { - /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i] - /// %cond = memref.load %dyn_cond_list[%i] - /// // Check for aliasing with retained memrefs. - /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs - /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) { - /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j] - /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index - /// scf.if %does_alias { - /// %curr_ownership = memref.load %dyn_ownership_out[%j] - /// %updated_ownership = arith.ori %curr_ownership, %cond : i1 - /// memref.store %updated_ownership, %dyn_ownership_out[%j] - /// } - /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index - /// %updated_aggregate = arith.andi %does_not_alias_aggregated, - /// %does_not_alias : i1 - /// scf.yield %updated_aggregate : i1 - /// } - /// // Check for aliasing with dealloc memrefs in the list before the - /// // current one, i.e., - /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j], - /// // %dyn_dealloc_base_pointer[i])` - /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1 - /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) { - /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j] - /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp - /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias - /// scf.yield %updated_alias_agg : i1 - /// } - /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1 - /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref - /// } - /// return - /// } - /// ``` - static func::FuncOp - buildDeallocationHelperFunction(OpBuilder &builder, Location loc, - SymbolTable &symbolTable) { - Type indexMemrefType = - MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); - Type boolMemrefType = - MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); - SmallVector argTypes{indexMemrefType, indexMemrefType, boolMemrefType, - boolMemrefType, boolMemrefType}; - builder.clearInsertionPoint(); - - // Generate the func operation itself. 
- auto helperFuncOp = func::FuncOp::create( - loc, "dealloc_helper", builder.getFunctionType(argTypes, {})); - symbolTable.insert(helperFuncOp); - auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); - block.addArguments(argTypes, SmallVector(argTypes.size(), loc)); - - builder.setInsertionPointToStart(&block); - Value toDeallocMemref = helperFuncOp.getArguments()[0]; - Value toRetainMemref = helperFuncOp.getArguments()[1]; - Value conditionMemref = helperFuncOp.getArguments()[2]; - Value deallocCondsMemref = helperFuncOp.getArguments()[3]; - Value retainCondsMemref = helperFuncOp.getArguments()[4]; - - // Insert some prerequisites. - Value c0 = builder.create(loc, builder.getIndexAttr(0)); - Value c1 = builder.create(loc, builder.getIndexAttr(1)); - Value trueValue = - builder.create(loc, builder.getBoolAttr(true)); - Value falseValue = - builder.create(loc, builder.getBoolAttr(false)); - Value toDeallocSize = - builder.create(loc, toDeallocMemref, c0); - Value toRetainSize = builder.create(loc, toRetainMemref, c0); - - builder.create( - loc, c0, toRetainSize, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - builder.create(loc, falseValue, retainCondsMemref, - i); - builder.create(loc); - }); - - builder.create( - loc, c0, toDeallocSize, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value outerIter, - ValueRange iterArgs) { - Value toDealloc = - builder.create(loc, toDeallocMemref, outerIter); - Value cond = - builder.create(loc, conditionMemref, outerIter); - - // Build the first for loop that computes aliasing with retained - // memrefs. - Value noRetainAlias = - builder - .create( - loc, c0, toRetainSize, c1, trueValue, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value retainValue = builder.create( - loc, toRetainMemref, i); - Value doesAlias = builder.create( - loc, arith::CmpIPredicate::eq, retainValue, - toDealloc); - builder.create( - loc, doesAlias, - [&](OpBuilder &builder, Location loc) { - Value retainCondValue = - builder.create( - loc, retainCondsMemref, i); - Value aggregatedRetainCond = - builder.create( - loc, retainCondValue, cond); - builder.create( - loc, aggregatedRetainCond, retainCondsMemref, - i); - builder.create(loc); - }); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, retainValue, - toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - // Build the second for loop that adds aliasing with previously - // deallocated memrefs. 
- Value noAlias = - builder - .create( - loc, c0, outerIter, c1, noRetainAlias, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value prevDeallocValue = builder.create( - loc, toDeallocMemref, i); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, prevDeallocValue, - toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - Value shouldDealoc = - builder.create(loc, noAlias, cond); - builder.create(loc, shouldDealoc, deallocCondsMemref, - outerIter); - builder.create(loc); - }); - - builder.create(loc); - return helperFuncOp; - } - private: func::FuncOp deallocHelperFunc; }; } // namespace namespace { -struct BufferizationToMemRefPass - : public impl::ConvertBufferizationToMemRefBase { - BufferizationToMemRefPass() = default; - +struct LowerDeallocationsPass + : public bufferization::impl::LowerDeallocationsBase< + LowerDeallocationsPass> { void runOnOperation() override { if (!isa(getOperation())) { emitError(getOperation()->getLoc(), @@ -641,7 +395,7 @@ // Build dealloc helper function if there are deallocs. getOperation()->walk([&](bufferization::DeallocOp deallocOp) { if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction( + helperFuncOp = bufferization::buildDeallocationLibraryFunction( builder, getOperation()->getLoc(), symbolTable); return WalkResult::interrupt(); } @@ -650,13 +404,13 @@ } RewritePatternSet patterns(&getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext(), helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern(patterns, + helperFuncOp); ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -665,6 +419,126 @@ }; } // namespace -std::unique_ptr mlir::createBufferizationToMemRefPass() { - return std::make_unique(); +func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( + OpBuilder &builder, Location loc, SymbolTable &symbolTable) { + Type indexMemrefType = + MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); + Type boolMemrefType = + MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); + SmallVector argTypes{indexMemrefType, indexMemrefType, boolMemrefType, + boolMemrefType, boolMemrefType}; + builder.clearInsertionPoint(); + + // Generate the func operation itself. + auto helperFuncOp = func::FuncOp::create( + loc, "dealloc_helper", builder.getFunctionType(argTypes, {})); + symbolTable.insert(helperFuncOp); + auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); + block.addArguments(argTypes, SmallVector(argTypes.size(), loc)); + + builder.setInsertionPointToStart(&block); + Value toDeallocMemref = helperFuncOp.getArguments()[0]; + Value toRetainMemref = helperFuncOp.getArguments()[1]; + Value conditionMemref = helperFuncOp.getArguments()[2]; + Value deallocCondsMemref = helperFuncOp.getArguments()[3]; + Value retainCondsMemref = helperFuncOp.getArguments()[4]; + + // Insert some prerequisites. 
+ Value c0 = builder.create(loc, builder.getIndexAttr(0)); + Value c1 = builder.create(loc, builder.getIndexAttr(1)); + Value trueValue = + builder.create(loc, builder.getBoolAttr(true)); + Value falseValue = + builder.create(loc, builder.getBoolAttr(false)); + Value toDeallocSize = builder.create(loc, toDeallocMemref, c0); + Value toRetainSize = builder.create(loc, toRetainMemref, c0); + + builder.create( + loc, c0, toRetainSize, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + builder.create(loc, falseValue, retainCondsMemref, i); + builder.create(loc); + }); + + builder.create( + loc, c0, toDeallocSize, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value outerIter, + ValueRange iterArgs) { + Value toDealloc = + builder.create(loc, toDeallocMemref, outerIter); + Value cond = + builder.create(loc, conditionMemref, outerIter); + + // Build the first for loop that computes aliasing with retained + // memrefs. + Value noRetainAlias = + builder + .create( + loc, c0, toRetainSize, c1, trueValue, + [&](OpBuilder &builder, Location loc, Value i, + ValueRange iterArgs) { + Value retainValue = builder.create( + loc, toRetainMemref, i); + Value doesAlias = builder.create( + loc, arith::CmpIPredicate::eq, retainValue, + toDealloc); + builder.create( + loc, doesAlias, + [&](OpBuilder &builder, Location loc) { + Value retainCondValue = + builder.create( + loc, retainCondsMemref, i); + Value aggregatedRetainCond = + builder.create( + loc, retainCondValue, cond); + builder.create( + loc, aggregatedRetainCond, retainCondsMemref, + i); + builder.create(loc); + }); + Value doesntAlias = builder.create( + loc, arith::CmpIPredicate::ne, retainValue, + toDealloc); + Value yieldValue = builder.create( + loc, iterArgs[0], doesntAlias); + builder.create(loc, yieldValue); + }) + .getResult(0); + + // Build the second for loop that adds aliasing with previously + // deallocated memrefs. 
+ Value noAlias = + builder + .create( + loc, c0, outerIter, c1, noRetainAlias, + [&](OpBuilder &builder, Location loc, Value i, + ValueRange iterArgs) { + Value prevDeallocValue = builder.create( + loc, toDeallocMemref, i); + Value doesntAlias = builder.create( + loc, arith::CmpIPredicate::ne, prevDeallocValue, + toDealloc); + Value yieldValue = builder.create( + loc, iterArgs[0], doesntAlias); + builder.create(loc, yieldValue); + }) + .getResult(0); + + Value shouldDealoc = builder.create(loc, noAlias, cond); + builder.create(loc, shouldDealoc, deallocCondsMemref, + outerIter); + builder.create(loc); + }); + + builder.create(loc); + return helperFuncOp; +} + +void mlir::bufferization::populateBufferizationDeallocLoweringPattern( + RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) { + patterns.add(patterns.getContext(), deallocLibraryFunc); +} + +std::unique_ptr mlir::bufferization::createLowerDeallocationsPass() { + return std::make_unique(); } diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir --- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir +++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir @@ -68,26 +68,7 @@ } // ----- - -// CHECK-LABEL: func @conversion_dealloc_empty -func.func @conversion_dealloc_empty() { - // CHECK-NOT: bufferization.dealloc - bufferization.dealloc - return -} - -// ----- - -func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) { - %0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) - return %0#0, %0#1 : i1, i1 -} - -// CHECK-LABEL: func @conversion_dealloc_empty -// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false -// CHECK-NEXT: return [[FALSE]], [[FALSE]] : - -// ----- +// Test: check that the dealloc lowering pattern is registered. 
// CHECK-NOT: func @deallocHelper // CHECK-LABEL: func @conversion_dealloc_simple @@ -102,124 +83,3 @@ // CHECk-NEXT: memref.dealloc [[ARG0]] : memref<2xf32> // CHECk-NEXT: } // CHECk-NEXT: return - -// ----- - -func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) { - %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>) - return %0#0, %0#1 : i1, i1 -} - -// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained -// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>) -// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]] -// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]] -// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]] -// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index -// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index -// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]] -// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]] -// CHECK: scf.if [[SHOULD_DEALLOC]] -// CHECK: memref.dealloc [[ARG0]] -// CHECK: } -// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true -// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true -// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]] -// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]] -// CHECK: return [[RES0]], [[RES1]] - -// CHECK-NOT: func @dealloc_helper - -// ----- - -func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { - %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) - return %0#0, %0#1 : i1, i1 -} - -// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained -// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>, -// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1, -// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>) -// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex> -// CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1> -// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex> -// CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]] -// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index -// CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]] -// CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]] -// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index -// CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]] -// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index -// CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]] -// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index -// CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]] -// CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]] -// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index -// CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]] -// CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]] -// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index -// CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]] -// 
CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref -// CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref -// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref -// CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1> -// CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1> -// CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref -// CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref -// CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]]) -// CHECK: [[C0:%.+]] = arith.constant 0 : index -// CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]] -// CHECK: scf.if [[SHOULD_DEALLOC_0]] { -// CHECK: memref.dealloc %arg0 -// CHECK: } -// CHECK: [[C1:%.+]] = arith.constant 1 : index -// CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]] -// CHECK: scf.if [[SHOULD_DEALLOC_1]] -// CHECK: memref.dealloc [[ARG1]] -// CHECK: } -// CHECK: [[C0:%.+]] = arith.constant 0 : index -// CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]] -// CHECK: [[C1:%.+]] = arith.constant 1 : index -// CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]] -// CHECK: memref.dealloc [[TO_DEALLOC_MR]] -// CHECK: memref.dealloc [[TO_RETAIN_MR]] -// CHECK: memref.dealloc [[CONDS]] -// CHECK: memref.dealloc [[DEALLOC_CONDS]] -// CHECK: memref.dealloc [[RETAIN_CONDS]] -// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]] - -// CHECK: func @dealloc_helper -// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref, [[TO_RETAIN_MR:%.+]]: memref, -// CHECK-SAME: [[CONDS:%.+]]: memref, [[DEALLOC_CONDS_OUT:%.+]]: memref, -// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref) -// CHECK: [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0 -// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0 -// CHECK: scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 { -// CHECK-NEXT: memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]] -// CHECK-NEXT: } -// CHECK: scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 { -// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]] -// CHECK-NEXT: [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]] -// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) { -// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref -// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index -// CHECK-NEXT: scf.if [[DOES_ALIAS]] -// CHECK-NEXT: [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]] -// CHECK-NEXT: [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1 -// CHECK-NEXT: memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]] -// CHECK-NEXT: } -// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index -// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1 -// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1 -// CHECK-NEXT: } -// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) { -// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref -// CHECK-NEXT: [[DOES_ALIAS:%.+]] = 
arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index -// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1 -// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1 -// CHECK-NEXT: } -// CHECK-NEXT: [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1 -// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]] -// CHECK-NEXT: } -// CHECK-NEXT: return diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir rename from mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir rename to mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir --- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(convert-bufferization-to-memref))" -split-input-file %s | FileCheck %s +// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(bufferization-lower-deallocations))" -split-input-file %s | FileCheck %s // CHECK-NOT: func @deallocHelper // CHECK-LABEL: func @conversion_dealloc_simple diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir copy from mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir copy to mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir --- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir @@ -1,73 +1,4 @@ -// RUN: mlir-opt -verify-diagnostics -convert-bufferization-to-memref -split-input-file %s | FileCheck %s - -// CHECK-LABEL: @conversion_static -func.func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> { - %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32> - memref.dealloc %arg0 : memref<2xf32> - return %0 : memref<2xf32> -} - -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: memref.copy %[[ARG:.*]], %[[ALLOC]] -// CHECK-NEXT: memref.dealloc %[[ARG]] -// CHECK-NEXT: return %[[ALLOC]] - -// ----- - -// CHECK-LABEL: @conversion_dynamic -func.func @conversion_dynamic(%arg0 : memref) -> memref { - %1 = bufferization.clone %arg0 : memref to memref - memref.dealloc %arg0 : memref - return %1 : memref -} - -// CHECK: %[[CONST:.*]] = arith.constant -// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]] -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) -// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]] -// CHECK-NEXT: memref.dealloc %[[ARG]] -// CHECK-NEXT: return %[[ALLOC]] - -// ----- - -func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> { -// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}} - %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32> - memref.dealloc %arg0 : memref<*xf32> - return %1 : memref<*xf32> -} - -// ----- - -// CHECK-LABEL: func @conversion_with_layout_map( -// CHECK-SAME: %[[ARG:.*]]: memref> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref -// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref to memref> -// CHECK: memref.copy -// CHECK: memref.dealloc 
-// CHECK: return %[[CASTED]] -func.func @conversion_with_layout_map(%arg0 : memref>) -> memref> { - %1 = bufferization.clone %arg0 : memref> to memref> - memref.dealloc %arg0 : memref> - return %1 : memref> -} - -// ----- - -// This bufferization.clone cannot be lowered because a buffer with this layout -// map cannot be allocated (or casted to). - -func.func @conversion_with_invalid_layout_map(%arg0 : memref>) - -> memref> { -// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}} - %1 = bufferization.clone %arg0 : memref> to memref> - memref.dealloc %arg0 : memref> - return %1 : memref> -} - -// ----- +// RUN: mlir-opt -verify-diagnostics -bufferization-lower-deallocations -split-input-file %s | FileCheck %s // CHECK-LABEL: func @conversion_dealloc_empty func.func @conversion_dealloc_empty() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -11967,7 +11967,9 @@ ":MemRefDialect", ":MemRefUtils", ":Pass", + ":SCFDialect", ":SideEffectInterfaces", + ":Support", ":TensorDialect", ":Transforms", ":ViewLikeInterface", @@ -11987,6 +11989,7 @@ deps = [ ":ArithDialect", ":BufferizationDialect", + ":BufferizationTransforms", ":ConversionPassIncGen", ":FuncDialect", ":IR",
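
Note (reviewer illustration, not part of the patch): as a quick sanity check of the behavior this change carries over into the new `bufferization-lower-deallocations` pass, below is a minimal sketch of the simple single-memref, no-retained-values case that the pattern documentation above describes. The function and SSA value names are made up for the example; the pass flag matches the RUN lines added in the renamed tests, and the expected output mirrors the `rewriteOneMemrefNoRetainCase` doc comment.

```mlir
// Input, run through: mlir-opt -bufferization-lower-deallocations
func.func @simple_dealloc(%arg0: memref<2xf32>, %cond: i1) {
  bufferization.dealloc (%arg0 : memref<2xf32>) if (%cond)
  return
}

// Expected lowering: the memref.dealloc is guarded by an scf.if on the
// condition, as documented for the single-memref, no-retain case above.
func.func @simple_dealloc(%arg0: memref<2xf32>, %cond: i1) {
  scf.if %cond {
    memref.dealloc %arg0 : memref<2xf32>
  }
  return
}
```

The general case (multiple memrefs with retained values) instead goes through the `@dealloc_helper` library function built by `buildDeallocationLibraryFunction`, which is why that lowering is only available when the pass runs on a `builtin.module`.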