diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -488,14 +488,14 @@ The memrefs to be deallocated must be the originally allocated memrefs, however, the memrefs to be retained may be arbitrary memrefs. - Returns a list of conditions corresponding to the list of memrefs which - indicates the new ownerships, i.e., if the memref was deallocated the - ownership was dropped (set to 'false') and otherwise will be the same as the - input condition. + Returns a list of conditions corresponding to the list of retained memrefs + which indicates their ownerships, i.e., the disjunction of ownership values + of all aliases in the list of memrefs to be deallocated. If there is no + alias, the result will be 'false' Example: ```mlir - %0:2 = bufferization.dealloc %a0, %a1 if %cond0, %cond1 retain %r0, %r1 : + bufferization.dealloc %a0, %a1 if %cond0, %cond1 retain %r0, %r1 : memref<2xf32>, memref<4xi32> retain memref, memref ``` Deallocation will be called on `%a0` if `%cond0` is 'true' and neither `%r0` diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -97,71 +97,82 @@ /// /// Example: /// ``` - /// %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) + /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) /// ``` /// is lowered to /// ``` /// scf.if %arg1 { /// memref.dealloc %arg0 : memref<2xf32> /// } - /// %0 = arith.constant false /// ``` LogicalResult rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - 
rewriter.create(op.getLoc(), adaptor.getConditions()[0], - [&](OpBuilder &builder, Location loc) { - builder.create( - loc, adaptor.getMemrefs()[0]); - builder.create(loc); - }); - rewriter.replaceOpWithNewOp(op, - rewriter.getBoolAttr(false)); + rewriter.replaceOpWithNewOp( + op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { + builder.create(loc, adaptor.getMemrefs()[0]); + builder.create(loc); + }); return success(); } /// Lowering that supports all features the dealloc operation has to offer. It /// computes the base pointer of each memref (as an index), stores them in a /// new memref and passes it to the helper function generated in - /// 'buildDeallocationHelperFunction'. The two return values are used as - /// condition for the scf if operation containing the memref deallocate and as - /// replacement for the original bufferization dealloc respectively. + /// 'buildDeallocationHelperFunction'. The results are stored in two memrefs + /// of booleans passed as arguments. The first stores the condition under + /// which the memref should be deallocated, the second one stores the + /// ownership of the retained values which can be used to replace the result + /// values of the `bufferization.dealloc` operation. 
/// /// Example: /// ``` /// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) - /// if (%arg3, %arg4) retain (%arg2 : memref<1xf32>) + /// if (%arg3, %arg4) + /// retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) /// ``` /// lowers to (simplified): /// ``` /// %c0 = arith.constant 0 : index /// %c1 = arith.constant 1 : index /// %alloc = memref.alloc() : memref<2xindex> - /// %alloc_0 = memref.alloc() : memref<1xindex> + /// %alloc_0 = memref.alloc() : memref<2xi1> + /// %alloc_1 = memref.alloc() : memref<2xindex> /// %intptr = memref.extract_aligned_pointer_as_index %arg0 /// memref.store %intptr, %alloc[%c0] : memref<2xindex> - /// %intptr_1 = memref.extract_aligned_pointer_as_index %arg1 - /// memref.store %intptr_1, %alloc[%c1] : memref<2xindex> - /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg2 - /// memref.store %intptr_2, %alloc_0[%c0] : memref<1xindex> + /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg1 + /// memref.store %intptr_2, %alloc[%c1] : memref<2xindex> + /// memref.store %arg3, %alloc_0[%c0] : memref<2xi1> + /// memref.store %arg4, %alloc_0[%c1] : memref<2xi1> + /// %intptr_5 = memref.extract_aligned_pointer_as_index %arg2 + /// memref.store %intptr_5, %alloc_1[%c0] : memref<2xindex> + /// %intptr_7 = memref.extract_aligned_pointer_as_index %arg5 + /// memref.store %intptr_7, %alloc_1[%c1] : memref<2xindex> /// %cast = memref.cast %alloc : memref<2xindex> to memref - /// %cast_4 = memref.cast %alloc_0 : memref<1xindex> to memref - /// %0:2 = call @dealloc_helper(%cast, %cast_4, %c0) - /// %1 = arith.andi %0#0, %arg3 : i1 - /// %2 = arith.andi %0#1, %arg3 : i1 - /// scf.if %1 { + /// %cast_9 = memref.cast %alloc_0 : memref<2xi1> to memref + /// %cast_10 = memref.cast %alloc_1 : memref<2xindex> to memref + /// %alloc_11 = memref.alloc() : memref<2xi1> + /// %alloc_12 = memref.alloc() : memref<2xi1> + /// %cast_13 = memref.cast %alloc_11 : memref<2xi1> to memref + /// %cast_14 = memref.cast 
%alloc_12 : memref<2xi1> to memref + /// call @dealloc_helper(%cast, %cast_10, %cast_9, %cast_13, %cast_14) : (...) + /// %0 = memref.load %alloc_11[%c0] : memref<2xi1> + /// %1 = memref.load %alloc_12[%c0] : memref<2xi1> + /// scf.if %0 { /// memref.dealloc %arg0 : memref<2xf32> /// } - /// %3:2 = call @dealloc_helper(%cast, %cast_4, %c1) - /// %4 = arith.andi %3#0, %arg4 : i1 - /// %5 = arith.andi %3#1, %arg4 : i1 - /// scf.if %4 { + /// %2 = memref.load %alloc_11[%c1] : memref<2xi1> + /// %3 = memref.load %alloc_12[%c1] : memref<2xi1> + /// scf.if %2 { /// memref.dealloc %arg1 : memref<5xf32> /// } /// memref.dealloc %alloc : memref<2xindex> - /// memref.dealloc %alloc_0 : memref<1xindex> - /// // replace %0#0 with %2 - /// // replace %0#1 with %5 + /// memref.dealloc %alloc_1 : memref<2xindex> + /// memref.dealloc %alloc_0 : memref<2xi1> + /// memref.dealloc %alloc_11 : memref<2xi1> + /// memref.dealloc %alloc_12 : memref<2xi1> + /// // replace %0#0 with %1 + /// // replace %0#1 with %3 /// ``` LogicalResult rewriteGeneralCase(bufferization::DeallocOp op, OpAdaptor adaptor, @@ -175,6 +186,9 @@ Value toDeallocMemref = rewriter.create( op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, rewriter.getIndexType())); + Value conditionMemref = rewriter.create( + op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, + rewriter.getI1Type())); Value toRetainMemref = rewriter.create( op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, rewriter.getIndexType())); @@ -193,6 +207,11 @@ rewriter.create(op.getLoc(), memrefAsIdx, toDeallocMemref, getConstValue(i)); } + + for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) + rewriter.create(op.getLoc(), cond, conditionMemref, + getConstValue(i)); + for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { Value memrefAsIdx = rewriter.create(op.getLoc(), @@ -208,21 +227,44 @@ op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), 
toDeallocMemref); + Value castedCondsMemref = rewriter.create( + op->getLoc(), + MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), + conditionMemref); Value castedRetainMemref = rewriter.create( op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), toRetainMemref); + Value deallocCondsMemref = rewriter.create( + op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, + rewriter.getI1Type())); + Value retainCondsMemref = rewriter.create( + op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, + rewriter.getI1Type())); + + Value castedDeallocCondsMemref = rewriter.create( + op->getLoc(), + MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), + deallocCondsMemref); + Value castedRetainCondsMemref = rewriter.create( + op->getLoc(), + MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), + retainCondsMemref); + + rewriter.create( + op.getLoc(), deallocHelperFunc, + SmallVector{castedDeallocMemref, castedRetainMemref, + castedCondsMemref, castedDeallocCondsMemref, + castedRetainCondsMemref}); + SmallVector replacements; - for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { - auto callOp = rewriter.create( - op.getLoc(), deallocHelperFunc, - SmallVector{castedDeallocMemref, castedRetainMemref, - getConstValue(i)}); - Value shouldDealloc = rewriter.create( - op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]); - Value ownership = rewriter.create( - op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]); + for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { + Value idxValue = getConstValue(i); + Value shouldDealloc = rewriter.create( + op.getLoc(), deallocCondsMemref, idxValue); + Value ownership = rewriter.create( + op.getLoc(), retainCondsMemref, idxValue); replacements.push_back(ownership); rewriter.create( op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { @@ -235,6 +277,9 @@ // Deallocation will not be run on code 
after this stage. rewriter.create(op.getLoc(), toDeallocMemref); rewriter.create(op.getLoc(), toRetainMemref); + rewriter.create(op.getLoc(), conditionMemref); + rewriter.create(op.getLoc(), deallocCondsMemref); + rewriter.create(op.getLoc(), retainCondsMemref); rewriter.replaceOp(op, replacements); return success(); @@ -261,70 +306,95 @@ /// Build a helper function per compilation unit that can be called at /// bufferization dealloc sites to determine aliasing and ownership. /// - /// The generated function takes two memrefs of indices and one index value as - /// arguments and returns two boolean values: - /// * The first memref argument A should contain the result of the + /// The generated function takes two memrefs of indices and three memrefs of + /// booleans as arguments: + /// * The first argument A should contain the result of the /// extract_aligned_pointer_as_index operation applied to the memrefs to be /// deallocated - /// * The second memref argument B should contain the result of the + /// * The second argument B should contain the result of the /// extract_aligned_pointer_as_index operation applied to the memrefs to be /// retained - /// * The index argument I represents the currently processed index of - /// memref A and is needed because aliasing with all previously deallocated - /// memrefs has to be checked to avoid double deallocation - /// * The first result indicates whether the memref at position I should be - /// deallocated - /// * The second result provides the updated ownership value corresponding - /// the the memref at position I + /// * The third argument C should contain the conditions as passed directly + /// to the deallocation operation. + /// * The fourth argument D is used to pass results to the caller. Those + /// represent the condition under which the memref at the corresponding + /// position in A should be deallocated. + /// * The fifth argument E is used to pass results to the caller. 
It + /// provides the ownership value corresponding to the memref at the same + /// position in B /// - /// This helper function is supposed to be called for each element in the list - /// of memrefs to be deallocated to determine the deallocation need and new - /// ownership indicator, but does not perform the deallocation itself. + /// This helper function is supposed to be called once for each + /// `bufferization.dealloc` operation to determine the deallocation need and + /// new ownership indicator for the retained values, but does not perform the + /// deallocation itself. /// - /// The first scf for loop in the body computes whether the memref at index I - /// aliases with any memref in the list of retained memrefs. - /// The second loop additionally checks whether one of the previously - /// deallocated memrefs aliases with the currently processed one. + /// The first scf for loop zero-initializes the output memref for aggregation. + /// The second scf for loop contains two more loops, the first of which + /// computes whether the memref at the index given by the outer loop aliases + /// with any memref in the list of retained memrefs. The second nested loop + /// additionally checks whether one of the previously deallocated memrefs + /// aliases with the currently processed one. 
/// /// Generated code: /// ``` - /// func.func @dealloc_helper(%arg0: memref, - /// %arg1: memref, - /// %arg2: index) -> (i1, i1) { + /// func.func @dealloc_helper( + /// %arg0: memref, + /// %arg1: memref, + /// %arg2: memref, + /// %arg3: memref, + /// %arg4: memref) { /// %c0 = arith.constant 0 : index /// %c1 = arith.constant 1 : index /// %true = arith.constant true - /// %dim = memref.dim %arg1, %c0 : memref - /// %0 = memref.load %arg0[%arg2] : memref - /// %1 = scf.for %i = %c0 to %dim step %c1 iter_args(%arg4 = %true) -> (i1){ - /// %4 = memref.load %arg1[%i] : memref - /// %5 = arith.cmpi ne, %4, %0 : index - /// %6 = arith.andi %arg4, %5 : i1 - /// scf.yield %6 : i1 + /// %false = arith.constant false + /// %dim = memref.dim %arg0, %c0 : memref + /// %dim_0 = memref.dim %arg1, %c0 : memref + /// scf.for %arg5 = %c0 to %dim_0 step %c1 { + /// memref.store %false, %arg4[%arg5] : memref /// } - /// %2 = scf.for %i = %c0 to %arg2 step %c1 iter_args(%arg4 = %1) -> (i1) { - /// %4 = memref.load %arg0[%i] : memref - /// %5 = arith.cmpi ne, %4, %0 : index - /// %6 = arith.andi %arg4, %5 : i1 - /// scf.yield %6 : i1 + /// scf.for %arg5 = %c0 to %dim step %c1 { + /// %0 = memref.load %arg0[%arg5] : memref + /// %1 = memref.load %arg2[%arg5] : memref + /// %2 = scf.for %arg6 = %c0 to %dim_0 step %c1 + /// iter_args(%arg7 = %true) -> (i1) { + /// %5 = memref.load %arg1[%arg6] : memref + /// %6 = arith.cmpi eq, %5, %0 : index + /// scf.if %6 { + /// %9 = memref.load %arg4[%arg6] : memref + /// %10 = arith.ori %9, %1 : i1 + /// memref.store %10, %arg4[%arg6] : memref + /// } + /// %7 = arith.cmpi ne, %5, %0 : index + /// %8 = arith.andi %arg7, %7 : i1 + /// scf.yield %8 : i1 + /// } + /// %3 = scf.for %arg6 = %c0 to %arg5 step %c1 + /// iter_args(%arg7 = %2) -> (i1) { + /// %5 = memref.load %arg0[%arg6] : memref + /// %6 = arith.cmpi ne, %5, %0 : index + /// %7 = arith.andi %arg7, %6 : i1 + /// scf.yield %7 : i1 + /// } + /// %4 = arith.andi %3, %1 : i1 + /// 
memref.store %4, %arg3[%arg5] : memref /// } - /// %3 = arith.xori %1, %true : i1 - /// return %2, %3 : i1, i1 + /// return /// } /// ``` static func::FuncOp buildDeallocationHelperFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable) { - Type idxType = builder.getIndexType(); - Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType); - SmallVector argTypes{memrefArgType, memrefArgType, idxType}; + Type indexMemrefType = + MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); + Type boolMemrefType = + MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); + SmallVector argTypes{indexMemrefType, indexMemrefType, boolMemrefType, + boolMemrefType, boolMemrefType}; builder.clearInsertionPoint(); // Generate the func operation itself. auto helperFuncOp = func::FuncOp::create( - loc, "dealloc_helper", - builder.getFunctionType(argTypes, - {builder.getI1Type(), builder.getI1Type()})); + loc, "dealloc_helper", builder.getFunctionType(argTypes, {})); symbolTable.insert(helperFuncOp); auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); block.addArguments(argTypes, SmallVector(argTypes.size(), loc)); @@ -332,57 +402,101 @@ builder.setInsertionPointToStart(&block); Value toDeallocMemref = helperFuncOp.getArguments()[0]; Value toRetainMemref = helperFuncOp.getArguments()[1]; - Value idxArg = helperFuncOp.getArguments()[2]; + Value conditionMemref = helperFuncOp.getArguments()[2]; + Value deallocCondsMemref = helperFuncOp.getArguments()[3]; + Value retainCondsMemref = helperFuncOp.getArguments()[4]; // Insert some prerequisites. 
Value c0 = builder.create(loc, builder.getIndexAttr(0)); Value c1 = builder.create(loc, builder.getIndexAttr(1)); Value trueValue = builder.create(loc, builder.getBoolAttr(true)); + Value falseValue = + builder.create(loc, builder.getBoolAttr(false)); + Value toDeallocSize = + builder.create(loc, toDeallocMemref, c0); Value toRetainSize = builder.create(loc, toRetainMemref, c0); - Value toDealloc = - builder.create(loc, toDeallocMemref, idxArg); - - // Build the first for loop that computes aliasing with retained memrefs. - Value noRetainAlias = - builder - .create( - loc, c0, toRetainSize, c1, trueValue, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value retainValue = - builder.create(loc, toRetainMemref, i); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, retainValue, toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - // Build the second for loop that adds aliasing with previously deallocated - // memrefs. 
- Value noAlias = - builder - .create( - loc, c0, idxArg, c1, noRetainAlias, - [&](OpBuilder &builder, Location loc, Value i, - ValueRange iterArgs) { - Value prevDeallocValue = - builder.create(loc, toDeallocMemref, i); - Value doesntAlias = builder.create( - loc, arith::CmpIPredicate::ne, prevDeallocValue, - toDealloc); - Value yieldValue = builder.create( - loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); - }) - .getResult(0); - - Value ownership = - builder.create(loc, noRetainAlias, trueValue); - builder.create(loc, SmallVector{noAlias, ownership}); + builder.create( + loc, c0, toRetainSize, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + builder.create(loc, falseValue, retainCondsMemref, + i); + builder.create(loc); + }); + + builder.create( + loc, c0, toDeallocSize, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value outerIter, + ValueRange iterArgs) { + Value toDealloc = + builder.create(loc, toDeallocMemref, outerIter); + Value cond = + builder.create(loc, conditionMemref, outerIter); + + // Build the first for loop that computes aliasing with retained + // memrefs. 
+ Value noRetainAlias = + builder + .create( + loc, c0, toRetainSize, c1, trueValue, + [&](OpBuilder &builder, Location loc, Value i, + ValueRange iterArgs) { + Value retainValue = builder.create( + loc, toRetainMemref, i); + Value doesAlias = builder.create( + loc, arith::CmpIPredicate::eq, retainValue, + toDealloc); + builder.create( + loc, doesAlias, + [&](OpBuilder &builder, Location loc) { + Value retainCondValue = + builder.create( + loc, retainCondsMemref, i); + Value aggregatedRetainCond = + builder.create( + loc, retainCondValue, cond); + builder.create( + loc, aggregatedRetainCond, retainCondsMemref, + i); + builder.create(loc); + }); + Value doesntAlias = builder.create( + loc, arith::CmpIPredicate::ne, retainValue, + toDealloc); + Value yieldValue = builder.create( + loc, iterArgs[0], doesntAlias); + builder.create(loc, yieldValue); + }) + .getResult(0); + + // Build the second for loop that adds aliasing with previously + // deallocated memrefs. + Value noAlias = + builder + .create( + loc, c0, outerIter, c1, noRetainAlias, + [&](OpBuilder &builder, Location loc, Value i, + ValueRange iterArgs) { + Value prevDeallocValue = builder.create( + loc, toDeallocMemref, i); + Value doesntAlias = builder.create( + loc, arith::CmpIPredicate::ne, prevDeallocValue, + toDealloc); + Value yieldValue = builder.create( + loc, iterArgs[0], doesntAlias); + builder.create(loc, yieldValue); + }) + .getResult(0); + + Value shouldDealoc = + builder.create(loc, noAlias, cond); + builder.create(loc, shouldDealoc, deallocCondsMemref, + outerIter); + builder.create(loc); + }); + + builder.create(loc); return helperFuncOp; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -755,7 +755,8 @@ ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange 
regions, SmallVectorImpl &inferredReturnTypes) { DeallocOpAdaptor adaptor(operands, attributes, properties, regions); - inferredReturnTypes = SmallVector(adaptor.getConditions().getTypes()); + inferredReturnTypes = SmallVector(adaptor.getRetained().size(), + IntegerType::get(context, 1)); return success(); } @@ -766,44 +767,43 @@ return success(); } +static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, + ArrayRef memrefs, + ArrayRef conditions) { + if (deallocOp.getMemrefs() == memrefs) + return failure(); + + deallocOp.getMemrefsMutable().assign(memrefs); + deallocOp.getConditionsMutable().assign(conditions); + return success(); +} + namespace { -/// Remove duplicate values in the list of retained memrefs as well as the list -/// of memrefs to be deallocated. For the latter, we need to make sure the -/// corresponding condition values match as well, or otherwise have to combine -/// them (by computing the disjunction of them). +/// Remove duplicate values in the list of memrefs to be deallocated. We need to +/// make sure the corresponding condition value is updated accordingly since +/// their two conditions might not cover the same set of cases. In that case, we +/// have to combine them (by computing the disjunction of them). /// Example: /// ```mlir -/// %0:2 = bufferization.dealloc (%arg0, %arg0 : ...) -/// if (%arg1, %arg2) -/// retain (%arg3, %arg3 : ...) +/// bufferization.dealloc (%arg0, %arg0 : ...) 
if (%arg1, %arg2) /// ``` /// is canonicalized to /// ```mlir /// %0 = arith.ori %arg1, %arg2 : i1 -/// %1 = bufferization.dealloc (%arg0 : memref<2xi32>) -/// if (%0) -/// retain (%arg3 : memref<2xi32>) +/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0) /// ``` -struct DeallocRemoveDuplicates : public OpRewritePattern { +struct DeallocRemoveDuplicateDeallocMemrefs + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { // Unique memrefs to be deallocated. - DenseSet retained(deallocOp.getRetained().begin(), - deallocOp.getRetained().end()); DenseMap memrefToCondition; - SmallVector newMemrefs, newConditions, newRetained; - SmallVector resultIndices(deallocOp.getMemrefs().size(), -1); + SmallVector newMemrefs, newConditions; for (auto [i, memref, cond] : llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) { - if (retained.contains(memref)) { - rewriter.replaceAllUsesWith(deallocOp.getResult(i), - deallocOp.getConditions()[i]); - continue; - } - if (memrefToCondition.count(memref)) { // If the dealloc conditions don't match, we need to make sure that the // dealloc happens on the union of cases. @@ -816,42 +816,102 @@ newMemrefs.push_back(memref); newConditions.push_back(cond); } - resultIndices[i] = memrefToCondition[memref]; } + // Return failure if we don't change anything such that we don't run into an + // infinite loop of pattern applications. + return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions); + } +}; + +/// Remove duplicate values in the list of retained memrefs. We need to make +/// sure the corresponding result condition value is replaced properly. +/// Example: +/// ```mlir +/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...) 
+/// ``` +/// is canonicalized to +/// ```mlir +/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>) +/// ``` +struct DeallocRemoveDuplicateRetainedMemrefs + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { // Unique retained values - DenseSet seen; - for (auto retained : deallocOp.getRetained()) { - if (!seen.contains(retained)) { - seen.insert(retained); - newRetained.push_back(retained); + DenseMap seen; + SmallVector newRetained; + SmallVector resultReplacementIdx; + for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) { + if (seen.count(retained)) { + resultReplacementIdx.push_back(seen[retained]); + continue; } + + seen[retained] = i; + newRetained.push_back(retained); + resultReplacementIdx.push_back(i); } // Return failure if we don't change anything such that we don't run into an // infinite loop of pattern applications. - if (newConditions.size() == deallocOp.getConditions().size() && - newRetained.size() == deallocOp.getRetained().size()) + if (newRetained.size() == deallocOp.getRetained().size()) return failure(); // We need to create a new op because the number of results is always the // same as the number of condition operands. 
- auto newDealloc = rewriter.create(deallocOp.getLoc(), newMemrefs, - newConditions, newRetained); - for (auto [i, newIdx] : llvm::enumerate(resultIndices)) - if (newIdx != -1) - rewriter.replaceAllUsesWith(deallocOp.getResult(i), - newDealloc.getResult(newIdx)); - - rewriter.eraseOp(deallocOp); + auto newDeallocOp = + rewriter.create(deallocOp.getLoc(), deallocOp.getMemrefs(), + deallocOp.getConditions(), newRetained); + SmallVector replacements( + llvm::map_range(resultReplacementIdx, [&](unsigned idx) { + return newDeallocOp.getUpdatedConditions()[idx]; + })); + rewriter.replaceOp(deallocOp, replacements); return success(); } }; +/// Remove memrefs to be deallocated that are also present in the retained list +/// since they will always alias and thus never actually be deallocated. +/// Example: +/// ```mlir +/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...) +/// ``` +/// is canonicalized to +/// ```mlir +/// %0 = bufferization.dealloc retain (%arg0 : ...) +/// ``` +struct DeallocRemoveDeallocMemrefsContainedInRetained + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + // Unique memrefs to be deallocated. + DenseSet retained(deallocOp.getRetained().begin(), + deallocOp.getRetained().end()); + SmallVector newMemrefs, newConditions; + for (auto [memref, cond] : + llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { + if (!retained.contains(memref)) { + newMemrefs.push_back(memref); + newConditions.push_back(cond); + } + } + + // Return failure if we don't change anything such that we don't run into an + // infinite loop of pattern applications. + return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions); + } +}; + /// Erase deallocation operations where the variadic list of memrefs to -/// deallocate is emtpy. Example: +/// deallocate is empty. 
Example: /// ```mlir -/// bufferization.dealloc retain (%arg0: memref<2xi32>) +/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>) /// ``` struct EraseEmptyDealloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -859,7 +919,11 @@ LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { if (deallocOp.getMemrefs().empty()) { - rewriter.eraseOp(deallocOp); + Value constFalse = rewriter.create( + deallocOp.getLoc(), rewriter.getBoolAttr(false)); + rewriter.replaceOp( + deallocOp, SmallVector(deallocOp.getUpdatedConditions().size(), + constFalse)); return success(); } return failure(); @@ -871,12 +935,12 @@ /// /// Example: /// ``` -/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) +/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) /// if (%arg2, %false) /// ``` /// becomes /// ``` -/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2) +/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2) /// ``` struct EraseAlwaysFalseDealloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -884,32 +948,15 @@ LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { SmallVector newMemrefs, newConditions; - SmallVector replacements; - - for (auto [res, memref, cond] : - llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(), - deallocOp.getConditions())) { - if (matchPattern(cond, m_Zero())) { - replacements.push_back(cond); - continue; + for (auto [memref, cond] : + llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { + if (!matchPattern(cond, m_Zero())) { + newMemrefs.push_back(memref); + newConditions.push_back(cond); } - newMemrefs.push_back(memref); - newConditions.push_back(cond); - replacements.push_back({}); } - if (newMemrefs.size() == deallocOp.getMemrefs().size()) - return failure(); - - auto newDeallocOp = rewriter.create( - deallocOp.getLoc(), 
newMemrefs, newConditions, deallocOp.getRetained()); - unsigned i = 0; - for (auto &repl : replacements) - if (!repl) - repl = newDeallocOp.getResult(i++); - - rewriter.replaceOp(deallocOp, replacements); - return success(); + return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions); } }; @@ -917,9 +964,10 @@ void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add( - context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir --- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir +++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir @@ -81,78 +81,104 @@ // CHECK-LABEL: func @conversion_dealloc_simple // CHECK-SAME: [[ARG0:%.+]]: memref<2xf32> // CHECK-SAME: [[ARG1:%.+]]: i1 -func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 { - %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) - return %0 : i1 +func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) { + bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) + return } // CHECk: scf.if [[ARG1]] { // CHECk-NEXT: memref.dealloc [[ARG0]] : memref<2xf32> // CHECk-NEXT: } -// CHECk-NEXT: [[FALSE:%.+]] = arith.constant false -// CHECk-NEXT: return [[FALSE]] : i1 +// CHECk-NEXT: return // ----- -func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) { - %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2 : memref<1xf32>) +func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) 
-> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) return %0#0, %0#1 : i1, i1 } // CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained -// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>, -// CHECK-SAME: [[ARG1:%.+]]: memref<5xf32>, -// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, -// CHECK-SAME: [[ARG3:%.+]]: i1, -// CHECK-SAME: [[ARG4:%.+]]: i1 +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>, +// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1, +// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>) // CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex> -// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex> +// CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1> +// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex> // CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]] // CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index // CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]] // CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]] // CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index // CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]] +// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index +// CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]] +// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index +// CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]] // CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]] // CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index // CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]] +// CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]] +// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index +// CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]] // CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref -// 
CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref -// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index -// CHECK: [[RES0:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C0]]) -// CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, [[ARG3]] -// CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, [[ARG3]] +// CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref +// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref +// CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1> +// CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1> +// CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref +// CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref +// CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]]) +// CHECK: [[C0:%.+]] = arith.constant 0 : index +// CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]] +// CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]] // CHECK: scf.if [[SHOULD_DEALLOC_0]] { // CHECK: memref.dealloc %arg0 // CHECK: } // CHECK: [[C1:%.+]] = arith.constant 1 : index -// CHECK: [[RES1:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C1]]) -// CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1:%.+]]#0, [[ARG4]] -// CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1:%.+]]#1, [[ARG4]] +// CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]] +// CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]] // CHECK: scf.if [[SHOULD_DEALLOC_1]] // CHECK: memref.dealloc [[ARG1]] // CHECK: } // CHECK: memref.dealloc [[TO_DEALLOC_MR]] // CHECK: memref.dealloc [[TO_RETAIN_MR]] +// CHECK: memref.dealloc [[CONDS]] +// CHECK: memref.dealloc [[DEALLOC_CONDS]] +// CHECK: memref.dealloc [[RETAIN_CONDS]] // CHECK: 
return [[OWNERSHIP0]], [[OWNERSHIP1]] // CHECK: func @dealloc_helper -// CHECK-SAME: [[ARG0:%.+]]: memref, [[ARG1:%.+]]: memref -// CHECK-SAME: [[ARG2:%.+]]: index -// CHECK-SAME: -> (i1, i1) -// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[ARG1]], %c0 -// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[ARG0]][[[ARG2]]] : memref -// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) { -// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[ARG1]][[[ITER]]] : memref -// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index -// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1 -// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1 +// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref, [[TO_RETAIN_MR:%.+]]: memref, +// CHECK-SAME: [[CONDS:%.+]]: memref, [[DEALLOC_CONDS_OUT:%.+]]: memref, +// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref) +// CHECK: [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0 +// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0 +// CHECK: scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 { +// CHECK-NEXT: memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]] // CHECK-NEXT: } -// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[ARG2]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) { -// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref -// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index -// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1 -// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1 +// CHECK: scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 { +// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]] +// CHECK-NEXT: [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]] +// CHECK-NEXT: 
[[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) { +// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref +// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index +// CHECK-NEXT: scf.if [[DOES_ALIAS]] +// CHECK-NEXT: [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]] +// CHECK-NEXT: [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1 +// CHECK-NEXT: memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]] +// CHECK-NEXT: } +// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index +// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1 +// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) { +// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[TO_DEALLOC_MR]][[[ITER]]] : memref +// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index +// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1 +// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1 +// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]] // CHECK-NEXT: } -// CHECK-NEXT: [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1 -// CHECK-NEXT: return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1 +// CHECK-NEXT: return diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -282,44 +282,44 @@ // ----- -func.func 
@dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) { +func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) { %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>) - %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2) - return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1 + bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2) + return %0#0, %0#1, %0#2 : i1, i1, i1 } // CHECK-LABEL: func @dealloc_canonicalize_duplicates // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>) // CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>) // CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1 -// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]]) -// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] : +// CHECK-NEXT: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]]) +// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#0 : // ----- -func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) { +func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) { %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>) - %1:2 = bufferization.dealloc (%arg0, 
%arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) + %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) bufferization.dealloc bufferization.dealloc retain (%arg0 : memref<2xi32>) - return %0, %1#0, %1#1 : i1, i1, i1 + return %0, %1 : i1, i1 } // CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false // CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) -// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] : +// CHECK-NEXT: return [[FALSE]], [[V0]] : // ----- -func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) { +func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) { %false = arith.constant false - %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2) - return %0#0, %0#1 : i1, i1 + bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2) + return } // CHECK-LABEL: func @dealloc_always_false_condition // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1) -// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false -// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]]) -// CHECK-NEXT: return [[FALSE]], [[V0]] : +// CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]]) +// CHECK-NEXT: return diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -106,8 +106,16 @@ // ----- -func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: 
memref<4xi32>, %arg2: i1) -> i1 { +func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) { // expected-error @below{{must have the same number of conditions as memrefs to deallocate}} - %0 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2) - return %0 : i1 + bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2) + return +} + +// ----- + +func.func @invalid_dealloc_wrong_number_of_results(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 { + // expected-error @below{{operation defines 1 results but was provided 2 to bind}} + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg2) retain (%arg1 : memref<4xi32>) + return %0#0 : i1 } diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -73,8 +73,8 @@ // CHECK: bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref, memref<*xf64>) %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref, memref<*xf64>) // CHECK: bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) - %1 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) + bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) // CHECK: bufferization.dealloc bufferization.dealloc - return %0, %1 : i1, i1 + return %0#0, %0#1 : i1, i1 }