diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -511,6 +511,12 @@ return newOp; } +/// Return `true` if the buffer of given OpResult should be deallocated. This +/// function should be called during `BufferizableOpInterface::bufferize` +/// implementations that allocate a new buffer for the given OpResult. +bool shouldDeallocateOpResult(OpResult opResult, + const BufferizationOptions &options); + /// Return a MemRefType to which the type of the given value can be bufferized. /// /// If possible, op bufferization implementations should not use this function diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -206,6 +206,29 @@ return success(); } +bool bufferization::shouldDeallocateOpResult( + OpResult opResult, const BufferizationOptions &options) { + Operation *op = opResult.getOwner(); + assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) && + "expected that op allocates"); + + AnalysisState analysisState(options); + if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) { + // AllocTensorOp has one result. + ArrayAttr escapeAttr = + op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>(); + return !escapeAttr[0].cast<BoolAttr>().getValue(); + } + + // No "escape" annotation found. + if (options.createDeallocs) { + // Perform an ad-hoc analysis. 
+ return !analysisState.isTensorYielded(opResult); + } + + return false; +} + //===----------------------------------------------------------------------===// // OpFilter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -204,22 +204,8 @@ } // Should the buffer be deallocated? - AnalysisState analysisState(options); - bool dealloc; - if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) { - // AllocTensorOp has one result. - ArrayAttr escapeAttr = - op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>(); - dealloc = !escapeAttr[0].cast<BoolAttr>().getValue(); - } else { - // No "escape" annotation found. - if (options.createDeallocs) { - // Perform an ad-hoc analysis. - dealloc = !analysisState.isTensorYielded(getResult()); - } else { - dealloc = false; - } - } + bool dealloc = + shouldDeallocateOpResult(getResult().cast<OpResult>(), options); // Replace op. replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -363,9 +363,17 @@ struct FromElementsOpInterface : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, tensor::FromElementsOp> { + + bool bufferizesToAllocation(Operation *op, OpResult opResult) const { + return true; + } + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto fromElementsOp = cast<tensor::FromElementsOp>(op); + // Should the buffer be deallocated? 
+ bool dealloc = shouldDeallocateOpResult( + fromElementsOp.getResult().cast<OpResult>(), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != static_cast<unsigned>(0)) @@ -376,11 +384,10 @@ auto tensorType = fromElementsOp.getType().cast<RankedTensorType>(); auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. - AnalysisState analysisState(options); - FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, fromElementsOp.getResult(), - analysisState.isTensorYielded(fromElementsOp.getResult()), options, - /*copy=*/false); + FailureOr<Value> tensorAlloc = + allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(), + /*escape=*/!dealloc, options, + /*copy=*/false); if (failed(tensorAlloc)) return failure(); auto memrefType = @@ -416,6 +423,7 @@ indices); replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); } }; @@ -424,9 +432,17 @@ struct GenerateOpInterface : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, tensor::GenerateOp> { + + bool bufferizesToAllocation(Operation *op, OpResult opResult) const { + return true; + } + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto generateOp = cast<tensor::GenerateOp>(op); + // Should the buffer be deallocated? + bool dealloc = shouldDeallocateOpResult( + generateOp.getResult().cast<OpResult>(), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != static_cast<unsigned>(0)) @@ -436,11 +452,10 @@ // Allocate memory. Location loc = op->getLoc(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. 
- AnalysisState analysisState(options); - FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, generateOp.getResult(), - analysisState.isTensorYielded(generateOp.getResult()), options, - /*copy=*/false); + FailureOr<Value> tensorAlloc = + allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(), + /*escape=*/!dealloc, options, + /*copy=*/false); if (failed(tensorAlloc)) return failure(); auto memrefType = @@ -484,6 +499,7 @@ parallelBody->getArguments()); replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); } }; diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -217,3 +217,22 @@ // CHECK: } return } + +// ----- + +// CHECK-LABEL: func @dealloc_generate_buffer +func.func @dealloc_generate_buffer(%arg: tensor<*xf32>, %sz: index, %idx: index) + -> index +{ + // CHECK: memref.alloc + // CHECK: scf.parallel + // CHECK: memref.load + // CHECK: memref.dealloc + %0 = tensor.generate %sz { + ^bb0(%i : index): + %elem = tensor.dim %arg, %i : tensor<*xf32> + tensor.yield %elem : index + } : tensor<?xindex> + %r = tensor.extract %0[%idx] : tensor<?xindex> + return %r : index +}