diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -195,7 +195,10 @@
 
   }];
   let constructor = "mlir::createBufferizationToMemRefPass()";
-  let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect"];
+  let dependentDialects = [
+    "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect",
+    "func::FuncDialect"
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -15,7 +15,9 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
@@ -77,11 +79,223 @@
     return success();
   }
 };
+
+/// The DeallocOpConversion transforms all bufferization dealloc operations into
+/// memref dealloc operations potentially guarded by scf if operations.
+/// Additionally, memref extract_aligned_pointer_as_index and arith operations
+/// are inserted to compute the guard conditions. In the general case, a helper
+/// func is created to avoid quadratic code size explosion (relative to the
+/// number of operands of the dealloc operation).
+class DeallocOpConversion
+    : public OpConversionPattern<bufferization::DeallocOp> {
+  using OpConversionPattern<bufferization::DeallocOp>::OpConversionPattern;
+
+  /// Helper to avoid creation of multiple constant operations producing the
+  /// same value.
+  static Value getConstValue(OpBuilder &builder, Location loc,
+                             SmallVectorImpl<Value> &cache, int64_t value) {
+    if (cache.size() > value && cache[value])
+      return cache[value];
+
+    if (cache.size() <= value)
+      cache.resize(value + 1);
+
+    return cache[value] = builder.create<arith::ConstantOp>(
+               loc, builder.getIndexAttr(value));
+  }
+
+  /// Lower a simple case avoiding the helper function. Ideally, static analysis
+  /// can provide enough aliasing information to split the dealloc operations up
+  /// into this simple case as much as possible before running this pass.
+  LogicalResult
+  rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+    rewriter.create<scf::IfOp>(op.getLoc(), adaptor.getConditions()[0],
+                               [&](OpBuilder &builder, Location loc) {
+                                 builder.create<memref::DeallocOp>(
+                                     loc, adaptor.getMemrefs()[0]);
+                                 builder.create<scf::YieldOp>(loc);
+                               });
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op,
+                                                   rewriter.getBoolAttr(false));
+    return success();
+  }
+
+  /// Lowering that supports all features the dealloc operation has to offer.
+  LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
+                                   OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+    // Allocate two memrefs holding the aligned pointer indices for the list of
+    // memrefs to be deallocated and the ones to be retained. These can then be
+    // passed to the helper function, where for-loops iterate over them. Without
+    // storing the pointers to memrefs first, the aliasing checks would have to
+    // be emitted fully unrolled, leading to code size that is quadratic in the
+    // number of operands.
+    Value toDeallocMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+                                     rewriter.getIndexType()));
+    Value toRetainMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+                                     rewriter.getIndexType()));
+
+    SmallVector<Value> constantCache;
+
+    // Extract the base pointers of the memrefs as indices to check for aliasing
+    // at runtime.
+    for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
+      Value memrefAsIdx =
+          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+                                                                  toDealloc);
+      rewriter.create<memref::StoreOp>(
+          op.getLoc(), memrefAsIdx, toDeallocMemref,
+          getConstValue(rewriter, op.getLoc(), constantCache, i));
+    }
+    // The retained pointers go into their own buffer: the helper reads them
+    // from its second argument, separately from the to-be-deallocated ones.
+    for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
+      Value memrefAsIdx =
+          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+                                                                  toRetain);
+      rewriter.create<memref::StoreOp>(
+          op.getLoc(), memrefAsIdx, toRetainMemref,
+          getConstValue(rewriter, op.getLoc(), constantCache, i));
+    }
+
+    // Cast the allocated memrefs to dynamic shape because we want only one
+    // helper function no matter how many operands the bufferization.dealloc
+    // has.
+    Value castedDeallocMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+        toDeallocMemref);
+    Value castedRetainMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+        toRetainMemref);
+
+    SmallVector<Value> replacements;
+    for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
+      auto callOp = rewriter.create<func::CallOp>(
+          op.getLoc(), "deallocHelper",
+          SmallVector<Type>{rewriter.getI1Type(), rewriter.getI1Type()},
+          SmallVector<Value>{
+              castedDeallocMemref, castedRetainMemref,
+              getConstValue(rewriter, op.getLoc(), constantCache,
+                            adaptor.getMemrefs().size()),
+              getConstValue(rewriter, op.getLoc(), constantCache,
+                            adaptor.getRetained().size()),
+              getConstValue(rewriter, op.getLoc(), constantCache, i)});
+      Value shouldDealloc = rewriter.create<arith::AndIOp>(
+          op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]);
+      Value ownership = rewriter.create<arith::AndIOp>(
+          op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]);
+      replacements.push_back(ownership);
+      rewriter.create<scf::IfOp>(
+          op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+            builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
+            builder.create<scf::YieldOp>(loc);
+          });
+    }
+
+    // Deallocate above allocated memrefs again to avoid memory leaks.
+    // Deallocation will not be run on code after this stage.
+    rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
+    rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+
+public:
+  /// Dispatch to the cheap single-memref lowering when nothing is retained,
+  /// otherwise fall back to the helper-function based general lowering.
+  LogicalResult
+  matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
+      return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
+
+    return rewriteGeneralCase(op, adaptor, rewriter);
+  }
+
+  /// Build the `deallocHelper` function at the insertion point of `builder`.
+  /// It takes two `memref<?xindex>` lists of aligned base pointers (to
+  /// deallocate, to retain), their sizes, and the index of the entry of the
+  /// first list to examine; the size of the first list is currently unused.
+  /// The first `i1` result is true iff the examined pointer differs from all
+  /// retained pointers and from all earlier entries of the first list, i.e.
+  /// it should be deallocated. The second `i1` result is true iff the pointer
+  /// equals at least one retained pointer. On return, `builder` is left
+  /// positioned inside the body of the created function.
+  static void buildDeallocationHelperFunction(OpBuilder &builder,
+                                              Location loc) {
+    Type idxType = builder.getIndexType();
+    Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType);
+    SmallVector<Type> argTypes{memrefArgType, memrefArgType, idxType, idxType,
+                               idxType};
+    // TODO: don't hardcode the name, allow for uniquing when this name is
+    // already present
+    auto helperFunc = builder.create<func::FuncOp>(
+        loc, "deallocHelper",
+        builder.getFunctionType(argTypes,
+                                {builder.getI1Type(), builder.getI1Type()}));
+    auto &block = helperFunc.getFunctionBody().emplaceBlock();
+    block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
+
+    builder.setInsertionPointToStart(&block);
+    Value toDeallocMemref = helperFunc.getArguments()[0];
+    Value toRetainMemref = helperFunc.getArguments()[1];
+    Value toRetainSize = helperFunc.getArguments()[3];
+    Value idxArg = helperFunc.getArguments()[4];
+
+    Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
+    Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
+    Value trueValue =
+        builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+
+    Value toDealloc =
+        builder.create<memref::LoadOp>(loc, toDeallocMemref, idxArg);
+    Value noRetainAlias =
+        builder
+            .create<scf::ForOp>(
+                loc, c0, toRetainSize, c1, trueValue,
+                [&](OpBuilder &builder, Location loc, Value i,
+                    ValueRange iterArgs) {
+                  Value retainValue =
+                      builder.create<memref::LoadOp>(loc, toRetainMemref, i);
+                  Value doesntAlias = builder.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, retainValue, toDealloc);
+                  Value yieldValue = builder.create<arith::AndIOp>(
+                      loc, iterArgs[0], doesntAlias);
+                  builder.create<scf::YieldOp>(loc, yieldValue);
+                })
+            .getResult(0);
+    Value noAlias =
+        builder
+            .create<scf::ForOp>(
+                loc, c0, idxArg, c1, noRetainAlias,
+                [&](OpBuilder &builder, Location loc, Value i,
+                    ValueRange iterArgs) {
+                  Value prevDeallocValue =
+                      builder.create<memref::LoadOp>(loc, toDeallocMemref, i);
+                  Value doesntAlias = builder.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, prevDeallocValue,
+                      toDealloc);
+                  Value yieldValue = builder.create<arith::AndIOp>(
+                      loc, iterArgs[0], doesntAlias);
+                  builder.create<scf::YieldOp>(loc, yieldValue);
+                })
+            .getResult(0);
+    Value ownership =
+        builder.create<arith::XOrIOp>(loc, noRetainAlias, trueValue);
+    builder.create<func::ReturnOp>(loc, SmallVector<Value>{noAlias, ownership});
+  }
+};
 } // namespace
 
 void mlir::populateBufferizationToMemRefConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CloneOpConversion>(patterns.getContext());
+  patterns.add<CloneOpConversion, DeallocOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -90,12 +304,27 @@
   BufferizationToMemRefPass() = default;
 
   void runOnOperation() override {
+    OpBuilder builder =
+        OpBuilder::atBlockBegin(&getOperation()->getRegion(0).front());
+
+    // Only emit the helper function if at least one dealloc op needs the
+    // general lowering; that single helper is shared by all such ops.
+    getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+      if (deallocOp.getMemrefs().size() > 1 ||
+          !deallocOp.getRetained().empty()) {
+        DeallocOpConversion::buildDeallocationHelperFunction(
+            builder, getOperation()->getLoc());
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
     RewritePatternSet patterns(&getContext());
     populateBufferizationToMemRefConversionPatterns(patterns);
 
     ConversionTarget target(getContext());
-    target.addLegalDialect<memref::MemRefDialect>();
-    target.addLegalOp<arith::ConstantOp>();
+    target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
+                           scf::SCFDialect, func::FuncDialect>();
     target.addIllegalDialect<bufferization::BufferizationDialect>();
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
--- a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -9,6 +9,10 @@
 
   LINK_LIBS PUBLIC
   MLIRBufferizationDialect
+  MLIRSCFDialect
+  MLIRFuncDialect
+  MLIRArithDialect
+  MLIRMemRefDialect
   MLIRPass
   MLIRTransforms
   )
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -66,3 +66,70 @@
   memref.dealloc %arg0 : memref<?xf32, strided<[10], offset: ?>>
   return %1 : memref<?xf32, strided<[10], offset: ?>>
 }
+
+// -----
+
+// CHECK-NOT: func.func @deallocHelper
+// CHECK-LABEL: func.func @conversion_dealloc_simple
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 {
+  %0 = bufferization.dealloc %arg0 if %arg1 : memref<2xf32>
+  return %0 : i1
+}
+
+// CHECK-NEXT: scf.if %arg1 {
+// CHECK-NEXT:   memref.dealloc %arg0 : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %false = arith.constant false
+// CHECK-NEXT: return %false : i1
+
+// -----
+
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) {
+  %0:2 = bufferization.dealloc %arg0, %arg1 if %arg3, %arg4 retain %arg2 : memref<2xf32>, memref<5xf32> retain memref<1xf32>
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func.func @deallocHelper(%arg0: memref<?xindex>, %arg1: memref<?xindex>, %arg2: index, %arg3: index, %arg4: index) -> (i1, i1)
+// CHECK: [[TO_DEALLOC:%.+]] = memref.load %arg0[%arg4] : memref<?xindex>
+// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for %arg5 = %c0 to %arg3 step %c1 iter_args(%arg6 = %true) -> (i1) {
+// CHECK-NEXT:   [[RETAIN_VAL:%.+]] = memref.load %arg1[%arg5] : memref<?xindex>
+// CHECK-NEXT:   [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:   [[AGG_DOES_ALIAS:%.+]] = arith.andi %arg6, [[DOES_ALIAS]] : i1
+// CHECK-NEXT:   scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for %arg5 = %c0 to %arg4 step %c1 iter_args(%arg6 = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT:   [[OTHER_DEALLOC_VAL:%.+]] = memref.load %arg0[%arg5] : memref<?xindex>
+// CHECK-NEXT:   [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:   [[AGG_DOES_ALIAS:%.+]] = arith.andi %arg6, [[DOES_ALIAS]] : i1
+// CHECK-NEXT:   scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1
+// CHECK-NEXT: return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1
+
+
+// CHECK-LABEL: func.func @conversion_dealloc_multiple_memrefs_and_retained
+// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
+// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex>
+// CHECK: [[V0:%.+]] = memref.extract_aligned_pointer_as_index %arg0
+// CHECK: memref.store [[V0]], [[TO_DEALLOC_MR]][%c0]
+// CHECK: [[V1:%.+]] = memref.extract_aligned_pointer_as_index %arg1
+// CHECK: memref.store [[V1]], [[TO_DEALLOC_MR]][%c1]
+// CHECK: [[V2:%.+]] = memref.extract_aligned_pointer_as_index %arg2
+// CHECK: memref.store [[V2]], [[TO_RETAIN_MR]][%c0]
+// CHECK: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
+// CHECK: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref<?xindex>
+// CHECK: [[RES0:%.+]]:2 = call @deallocHelper([[CAST_DEALLOC]], [[CAST_RETAIN]], %c2, %c1, %c0)
+// CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, %arg3
+// CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, %arg3
+// CHECK: scf.if [[SHOULD_DEALLOC_0]] {
+// CHECK:   memref.dealloc %arg0
+// CHECK: }
+// CHECK: [[RES1:%.+]]:2 = call @deallocHelper([[CAST_DEALLOC]], [[CAST_RETAIN]], %c2, %c1, %c1)
+// CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1]]#0, %arg4
+// CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1]]#1, %arg4
+// CHECK: scf.if [[SHOULD_DEALLOC_1]]
+// CHECK:   memref.dealloc %arg1
+// CHECK: }
+// CHECK: memref.dealloc [[TO_DEALLOC_MR]]
+// CHECK: memref.dealloc [[TO_RETAIN_MR]]
+// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11550,6 +11550,7 @@
         ":IR",
         ":MemRefDialect",
         ":Pass",
+        ":SCFDialect",
         ":Support",
         ":Transforms",
     ],