diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" + namespace mlir { class DialectRegistry; @@ -17,6 +19,11 @@ namespace comprehensive_bufferize { namespace tensor_ext { +struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep { + LogicalResult run(FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) override; +}; + void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); } // namespace tensor_ext diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -13,6 +13,8 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { @@ -21,6 +23,20 @@ using tensor::ExtractSliceOp; using tensor::InsertSliceOp; +namespace { +/// Extra bufferization state that is required for bufferization of tensor ops. +struct TensorBufferizationState : public DialectBufferizationState { + /// InsertSliceOps that bufferize inplace and do not require a copy. 
+ DenseSet insertSliceOpsWithoutCopy; +}; +} // namespace + +static TensorBufferizationState & +getTensorBufferizationState(BufferizationState &state) { + return state.getDialectState( + tensor::TensorDialect::getDialectNamespace()); +} + struct CastOpInterface : public BufferizableOpInterface::ExternalModel { @@ -374,6 +390,7 @@ // catastrophically bad scheduling decision. // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); + TensorBufferizationState &tensorState = getTensorBufferizationState(state); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -385,15 +402,8 @@ if (!dstMemref) return failure(); - // A copy of the source buffer is needed if either: - // - The producer of `source` is not inplace. This is the case where a - // slice is computed out of place into the inplace full tensor. - // - The result is not inplace. This is the case where the whole tensor is - // cloned and the clone needs to be updated. - // TODO: Is this necessary? - bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp( - state.aliasInfo, insertSliceOp) || - !state.aliasInfo.isInPlace(insertSliceOp->getResult(0)); + bool needCopy = + !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp); if (needCopy) { // Take a subview of the dst. auto dstMemrefType = dstMemref.getType().cast(); @@ -424,6 +434,24 @@ } // namespace linalg } // namespace mlir +LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext:: + InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) { + auto &tensorState = getTensorBufferizationState(state); + funcOp.walk([&](InsertSliceOp insertSliceOp) { + // A copy of the source buffer is needed if either: + // - The producer of `source` is not inplace. This is the case where a + // slice is computed out of place into the inplace full tensor. + // - The result is not inplace. 
This is the case where the whole tensor is + // cloned and the clone needs to be updated. + if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo, + insertSliceOp) && + state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) + tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp); + }); + return success(); +} + void mlir::linalg::comprehensive_bufferize::tensor_ext:: registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { registry.addOpInterface(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -90,6 +90,9 @@ // Enable InitTensorOp elimination. options.addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); + // TODO: Find a way to enable this step automatically when bufferizing tensor + // dialect ops. + options.addPostAnalysisStep(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp);