diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -28,6 +28,9 @@ namespace mlir { namespace bufferization { +class BufferizationState; +struct BufferizationOptions; + /// A helper type converter class that automatically populates the relevant /// materializations and type conversions for bufferization. class BufferizeTypeConverter : public TypeConverter { @@ -52,8 +55,6 @@ void populateEliminateBufferizeMaterializationsPatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); -class BufferizationState; - /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// Whether buffer copies are needed or not is queried from `state`. /// @@ -61,13 +62,21 @@ /// unknown op (that does not implement `BufferizableOpInterface`) is found. No /// to_tensor/to_memref ops are inserted in that case. /// -/// Note: Tje layout map chosen to bufferize is the most dynamic canonical +/// Note: The layout map chosen to bufferize is the most dynamic canonical /// strided layout of the proper rank. This ensures compatibility with expected /// layouts after transformations. Combinations of memref.cast + /// canonicalization are responsible for clean ups. // TODO: Extract `options` from `state` and pass as separate argument. LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Buffers are duplicated and copied before any tensor use that bufferizes to +/// a memory write. +/// +/// Note: This function bufferizes ops without utilizing analysis results. It +/// can be used to implement partial bufferization passes.
+LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options); + } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSORINTERFACEIMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSORINTERFACEIMPL_H - -namespace mlir { - -class DialectRegistry; - -namespace linalg { -namespace comprehensive_bufferize { -namespace tensor_ext { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); - -} // namespace tensor_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSORINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_TENSOR_TRANSFORMS_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace tensor { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace tensor +} // namespace mlir + +#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -132,6 +132,10 @@ return std::make_unique(); } +//===----------------------------------------------------------------------===// +// BufferizableOpInterface-based Bufferization +//===----------------------------------------------------------------------===// + static bool isaTensor(Type t) { return t.isa(); } /// Return true if the given op has a tensor result or a tensor operand. @@ -208,3 +212,38 @@ return checkBufferizationResult(op, state.getOptions()); } + +namespace { +/// This is a "no analysis, always copy" BufferizationState. In the absence of +/// an analysis, a buffer must be copied each time it is written to. Therefore, +/// all OpOperands that bufferize to a memory write must bufferize out-of-place. +class AlwaysCopyBufferizationState : public BufferizationState { +public: + explicit AlwaysCopyBufferizationState(const BufferizationOptions &options) + : BufferizationState(options) {} + + AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; + + virtual ~AlwaysCopyBufferizationState() = default; + + /// Return `true` if the given OpOperand bufferizes in-place.
+ bool isInPlace(OpOperand &opOperand) const override { + // OpOperands that bufferize to a memory write are out-of-place, i.e., an + // alloc and copy is inserted. + return !bufferizesToMemoryWrite(opOperand); + } + + /// Return true if `v1` and `v2` bufferize to equivalent buffers. + bool areEquivalentBufferizedValues(Value v1, Value v2) const override { + // There is no analysis, so we do not know if the values are equivalent. The + // conservative answer is "false". + return false; + } +}; +} // namespace + +LogicalResult bufferization::bufferizeOp(Operation *op, + const BufferizationOptions &options) { + AlwaysCopyBufferizationState state(options); + return bufferizeOp(op, state); +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -5,7 +5,6 @@ ModuleBufferization.cpp SCFInterfaceImpl.cpp StdInterfaceImpl.cpp - TensorInterfaceImpl.cpp VectorInterfaceImpl.cpp ) @@ -57,17 +56,6 @@ MLIRStandard ) -add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl - TensorInterfaceImpl.cpp - - LINK_LIBS PUBLIC - MLIRBufferizableOpInterface - MLIRIR - MLIRMemRef - MLIRSCF - MLIRTensor -) - add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl VectorInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -55,7 +55,7 @@ MLIRStandardOpsTransforms MLIRStandardToLLVM MLIRTensor - MLIRTensorBufferizableOpInterfaceImpl + MLIRTensorTransforms MLIRTransforms MLIRTransformUtils MLIRVector diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- 
a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -18,9 +18,9 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -59,7 +59,7 @@ scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); - tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp rename to mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===// +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -16,14 +16,11 @@ using namespace mlir; using namespace mlir::bufferization; +using namespace mlir::tensor; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace tensor_ext { - -using tensor::ExtractSliceOp; -using tensor::InsertSliceOp; +namespace tensor { +namespace { struct CastOpInterface : public BufferizableOpInterface::ExternalModel(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry - .addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); +void mlir::tensor::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTensorTransforms + BufferizableOpInterfaceImpl.cpp Bufferize.cpp ADDITIONAL_HEADER_DIRS @@ -9,6 +10,7 @@ LINK_LIBS PUBLIC MLIRArithmetic + MLIRBufferizableOpInterface MLIRBufferizationTransforms MLIRIR MLIRMemRef diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@
-30,7 +30,7 @@ MLIRStdBufferizableOpInterfaceImpl MLIRStandard MLIRTensor - MLIRTensorBufferizableOpInterfaceImpl + MLIRTensorTransforms MLIRTransformUtils MLIRVector MLIRVectorBufferizableOpInterfaceImpl diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -21,11 +21,11 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -65,7 +65,7 @@ linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); - tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4415,11 +4415,15 @@ "lib/Dialect/Tensor/Transforms/*.h", ], ), - hdrs = ["include/mlir/Dialect/Tensor/Transforms/Passes.h"], + hdrs = [ +
"include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h", + "include/mlir/Dialect/Tensor/Transforms/Passes.h", + ], includes = ["include"], deps = [ ":ArithmeticDialect", ":Async", + ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":IR", @@ -6684,26 +6688,6 @@ ], ) -cc_library( - name = "TensorBufferizableOpInterfaceImpl", - srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":BufferizableOpInterface", - ":IR", - ":MemRefDialect", - ":SCFDialect", - ":Support", - ":TensorDialect", - "//llvm:Support", - ], -) - cc_library( name = "VectorBufferizableOpInterfaceImpl", srcs = [ @@ -6947,8 +6931,8 @@ ":StandardOpsTransforms", ":StdBufferizableOpInterfaceImpl", ":Support", - ":TensorBufferizableOpInterfaceImpl", ":TensorDialect", + ":TensorTransforms", ":TensorUtils", ":TransformUtils", ":VectorBufferizableOpInterfaceImpl", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -403,8 +403,8 @@ "//mlir:SCFTransforms", "//mlir:StandardOps", "//mlir:StdBufferizableOpInterfaceImpl", - "//mlir:TensorBufferizableOpInterfaceImpl", "//mlir:TensorDialect", + "//mlir:TensorTransforms", "//mlir:TransformUtils", "//mlir:VectorBufferizableOpInterfaceImpl", "//mlir:VectorOps",