diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -9,8 +9,6 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" - namespace mlir { class DialectRegistry; @@ -19,12 +17,6 @@ namespace comprehensive_bufferize { namespace tensor_ext { -struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep { - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override; -}; - void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); } // namespace tensor_ext diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -23,20 +23,6 @@ using tensor::ExtractSliceOp; using tensor::InsertSliceOp; -namespace { -/// Extra bufferization state that is required for bufferization of tensor ops. -struct TensorBufferizationState : public DialectBufferizationState { - /// InsertSliceOps that bufferize inplace and do not require a copy. 
- DenseSet insertSliceOpsWithoutCopy; -}; -} // namespace - -static TensorBufferizationState & -getTensorBufferizationState(BufferizationState &state) { - return state.getDialectState( - tensor::TensorDialect::getDialectNamespace()); -} - struct CastOpInterface : public BufferizableOpInterface::ExternalModel { @@ -274,23 +260,6 @@ return true; } -/// Return true if the source of a `insertSliceOp` bufferizes to an -/// equivalent ExtractSliceOp that bufferizes inplace. -static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( - const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) { - bool foundOp = false; - aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) { - auto extractSliceOp = value.getDefiningOp(); - if (extractSliceOp && - areEquivalentExtractSliceOps(aliasInfo, extractSliceOp, - insertSliceOp) && - aliasInfo.isInPlace(extractSliceOp->getResult(0))) { - foundOp = true; - } - }); - return foundOp; -} - /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, @@ -419,7 +388,6 @@ // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); Location loc = insertSliceOp.getLoc(); - TensorBufferizationState &tensorState = getTensorBufferizationState(state); // When bufferizing out-of-place, `getResultBuffer` allocates. Value dstMemref = @@ -427,24 +395,22 @@ if (!dstMemref) return failure(); - bool needCopy = - !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp); - if (needCopy) { - // Take a subview of the dst. 
- auto dstMemrefType = dstMemref.getType().cast(); - auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - insertSliceOp.getSourceType().getRank(), dstMemrefType, - insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), - insertSliceOp.getMixedStrides()) - .cast(); - Value subView = rewriter.create( - loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), - insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); - // Copy tensor. - Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); - state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView); - } + // Take a subview of the dst. + auto dstMemrefType = dstMemref.getType().cast(); + auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + insertSliceOp.getSourceType().getRank(), dstMemrefType, + insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), + insertSliceOp.getMixedStrides()) + .cast(); + Value subView = rewriter.create( + loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), + insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); + + // Copy tensor. If this tensor.insert_slice has a matching + // tensor.extract_slice, the copy operation will eventually fold away. + Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); + state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView); state.replaceOp(rewriter, op, dstMemref); return success(); @@ -456,25 +422,6 @@ } // namespace linalg } // namespace mlir -LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext:: - InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { - auto &tensorState = getTensorBufferizationState(state); - op->walk([&](InsertSliceOp insertSliceOp) { - // A copy of the source buffer is needed if either: - // - The producer of `source` is not inplace. 
This is the case where a - // slice is computed out of place into the inplace full tensor. - // - The result is not inplace. This is the case where the whole tensor is - // cloned and the clone needs to be updated. - if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo, - insertSliceOp) && - state.isInPlace(insertSliceOp->getResult(0))) - tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp); - }); - return success(); -} - void mlir::linalg::comprehensive_bufferize::tensor_ext:: registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { registry.addOpInterface(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -94,9 +94,6 @@ // Enable InitTensorOp elimination. options->addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); - // TODO: Find a way to enable this step automatically when bufferizing tensor - // dialect ops. - options->addPostAnalysisStep(); if (!allowReturnMemref) options->addPostAnalysisStep(); diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -99,9 +99,6 @@ // Enable InitTensorOp elimination. options->addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); - // TODO: Find a way to enable this step automatically when bufferizing - // tensor dialect ops. - options->addPostAnalysisStep(); if (!allowReturnMemref) options->addPostAnalysisStep();