diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -1,4 +1,4 @@ -//===- BufferizableOpInterface.h - Comprehensive Bufferize ------*- C++ -*-===// +//===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_ -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_ +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ #include @@ -25,13 +25,11 @@ class DominanceInfo; class FuncOp; -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { // TODO: from some HW description. static constexpr int64_t kBufferAlignments = 128; -class BufferizationAliasInfo; class BufferizableOpInterface; struct BufferizationOptions; class BufferizationState; @@ -241,7 +239,8 @@ } /// Return dialect-specific bufferization state or create one if none exists. - template StateT &getOrCreateDialectState(StringRef name) { + template + StateT &getOrCreateDialectState(StringRef name) { // Create state if it does not exist yet. 
if (!dialectState.count(name)) dialectState[name] = std::make_unique(); @@ -321,15 +320,13 @@ LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options); -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc" namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { /// AllocationHoistingBarrierOnly is an external implementation of /// BufferizableOpInterface for ops that are (not yet) bufferizable, but are @@ -378,8 +375,7 @@ bool isAllocationHoistingBarrier(Operation *op) const { return true; } }; -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_ +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -1,4 +1,4 @@ -//===-- BufferizableOpInterface.td - Compreh. Bufferize ----*- tablegen -*-===// +//===-- BufferizableOpInterface.td - Bufferizable Ops ------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -16,7 +16,7 @@ An op interface for Comprehensive Bufferization. 
Ops that implement this interface can be bufferized using Comprehensive Bufferization. }]; - let cppNamespace = "::mlir::linalg::comprehensive_bufferize"; + let cppNamespace = "::mlir::bufferization"; let methods = [ InterfaceMethod< /*desc=*/[{ @@ -311,12 +311,12 @@ // TODO: The following two attributes should belong to the tensor dialect. // The corresponding verifier should also be in the tensor dialect. /// Attribute name used to mark region arguments that can be bufferized - /// in-place during linalg comprehensive bufferization. + /// in-place during one-shot bufferization. constexpr const static ::llvm::StringLiteral kInplaceableAttrName = "linalg.inplaceable"; /// Attribute name used to mark the bufferization layout for region - /// arguments during linalg comprehensive bufferization. + /// arguments during one-shot bufferization. constexpr const static ::llvm::StringLiteral kBufferLayoutAttrName = "linalg.buffer_layout"; }]; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h @@ -6,22 +6,20 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ namespace mlir { class DialectRegistry; -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { 
namespace bufferization_ext { void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); } // namespace bufferization_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -1,3 +1,4 @@ add_mlir_dialect(BufferizationOps bufferization) add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc) add_mlir_interface(AllocationOpInterface) +add_mlir_interface(BufferizableOpInterface) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -52,6 +52,22 @@ void populateEliminateBufferizeMaterializationsPatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); +class BufferizationState; + +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Whether buffer copies are needed or not is queried from `state`. +/// +/// Note: If `allowUnknownOps` is set to false, bufferization fails when an +/// unknown op (that does not implement `BufferizableOpInterface`) is found. No +/// to_tensor/to_memref ops are inserted in that case. +/// +/// Note: The layout map chosen to bufferize is the most dynamic canonical +/// strided layout of the proper rank. This ensures compatibility with expected +/// layouts after transformations. Combinations of memref.cast + +/// canonicalization are responsible for cleanups.
+// TODO: Extract `options` from `state` and pass as separate argument. +LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); + } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -1,4 +1,4 @@ -//===- ComprehensiveBufferize.h - Linalg bufferization pass -----*- C++ -*-===// +//===- OneShotAnalysis.h - One-Shot (Single Pass) Analysis ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,22 +6,19 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/EquivalenceClasses.h" namespace mlir { - -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { class AnalysisBufferizationState; class BufferizationAliasInfo; struct AnalysisBufferizationOptions; -class BufferizationState; /// PostAnalysisSteps can be registered with `BufferizationOptions` and are /// executed after the analysis, but before bufferization. 
They can be used to @@ -168,7 +165,7 @@ private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal - /// functions and `runComprehensiveBufferize` may access this object. + /// functions and `runOneShotBufferize` may access this object. BufferizationAliasInfo aliasInfo; }; @@ -176,16 +173,12 @@ /// `state`. LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state); -/// Bufferize `op` and its nested ops. Bufferization decisions are stored in -/// `state`. -LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); - -/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization -LogicalResult runComprehensiveBufferize( - Operation *op, std::unique_ptr options); +/// Run One-Shot Bufferize on the given op: Analysis + Bufferization +LogicalResult +runOneShotBufferize(Operation *op, + std::unique_ptr options); -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -1,5 +1,2 @@ -set(LLVM_TARGET_DEFINITIONS BufferizableOpInterface.td) -mlir_tablegen(BufferizableOpInterface.h.inc -gen-op-interface-decls) -mlir_tablegen(BufferizableOpInterface.cpp.inc -gen-op-interface-defs) -add_public_tablegen_target(MLIRBufferizableOpInterfaceIncGen) -add_dependencies(mlir-headers MLIRBufferizableOpInterfaceIncGen) +# no targets defined here + diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h 
b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -9,20 +9,16 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" namespace mlir { - class DialectRegistry; namespace linalg { namespace comprehensive_bufferize { - -class BufferizationAliasInfo; - namespace linalg_ext { -struct InitTensorEliminationStep : public PostAnalysisStep { +struct InitTensorEliminationStep : public bufferization::PostAnalysisStep { /// A function that matches anchor OpOperands for InitTensorOp elimination. using AnchorMatchFn = std::function; @@ -39,11 +35,11 @@ /// InitTensorOp. /// * The result of `rewriteFunc` must usually be analyzed for inplacability. /// This analysis can be skipped with `skipAnalysis`. - LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - AnchorMatchFn anchorMatchFunc, - RewriteFn rewriteFunc, - SmallVector &newOps); + LogicalResult + eliminateInitTensors(Operation *op, bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, + AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, + SmallVector &newOps); }; /// Try to eliminate InitTensorOps inside `op` that are anchored on an @@ -51,8 +47,8 @@ /// (and some other conditions are met). 
struct InsertSliceAnchoredInitTensorEliminationStep : public InitTensorEliminationStep { - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, + LogicalResult run(Operation *op, bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -17,16 +17,19 @@ struct LogicalResult; class ModuleOp; +namespace bufferization { +struct AnalysisBufferizationOptions; +} // namespace bufferization + namespace linalg { namespace comprehensive_bufferize { -struct AnalysisBufferizationOptions; - /// Run Module Bufferization on the given module. Performs a simple function /// call analysis to determine which function arguments are inplaceable. Then -/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization. +/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize. 
LogicalResult runComprehensiveBufferize( - ModuleOp moduleOp, std::unique_ptr options); + ModuleOp moduleOp, + std::unique_ptr options); namespace std_ext { diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h @@ -9,7 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" namespace mlir { @@ -23,9 +23,9 @@ /// bbArgs. This is required because the i-th OpResult of an scf.for op is /// currently assumed to alias with the i-th iter_arg (in the absence of /// conflicts). -struct AssertScfForAliasingProperties : public PostAnalysisStep { - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, +struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep { + LogicalResult run(Operation *op, bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp rename to mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -1,4 +1,4 @@ -//===- BufferizableOpInterface.cpp - Comprehensive Bufferize --------------===// +//===- BufferizableOpInterface.cpp - 
Bufferizable Ops --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" - +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" @@ -19,13 +18,11 @@ #include "llvm/Support/Debug.h" namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir #define DEBUG_TYPE "bufferizable-op-interface" @@ -33,7 +30,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) using namespace mlir; -using namespace linalg::comprehensive_bufferize; +using namespace bufferization; //===----------------------------------------------------------------------===// // BufferizationOptions @@ -42,15 +39,15 @@ // Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() {} -BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: - BufferizationOptions::dynCastBufferizableOp(Operation *op) const { +BufferizableOpInterface +BufferizationOptions::dynCastBufferizableOp(Operation *op) const { if (isOpAllowed(op)) return dyn_cast(op); return nullptr; } -BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: - BufferizationOptions::dynCastBufferizableOp(Value value) const { +BufferizableOpInterface +BufferizationOptions::dynCastBufferizableOp(Value value) const { if (auto bufferizableOp = value.getDefiningOp()) if (isOpAllowed(bufferizableOp.getOperation())) return bufferizableOp; @@ -72,8 +69,7 @@ /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector -mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand( - OpResult result) const { +BufferizationState::getAliasingOpOperand(OpResult result) const { if (Operation *op = result.getDefiningOp()) if (auto bufferizableOp = dyn_cast(op)) return bufferizableOp.getAliasingOpOperand(result, *this); @@ -82,9 +78,7 @@ /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return an empty OpResult if the op is not bufferizable. -OpResult -mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult( - OpOperand &opOperand) const { +OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.getAliasingOpResult(opOperand, *this); @@ -93,8 +87,7 @@ /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. 
-bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToMemoryRead(OpOperand &opOperand) const { +bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); @@ -106,8 +99,7 @@ /// Return true if `opOperand` bufferizes to a memory write. Return /// `true` if the op is not bufferizable. -bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToMemoryWrite(OpOperand &opOperand) const { +bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); @@ -119,8 +111,7 @@ /// Return true if `opOperand` does neither read nor write but bufferizes to an /// alias. Return false if the op is not bufferizable. -bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToAliasOnly(OpOperand &opOperand) const { +bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); @@ -133,8 +124,7 @@ /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). -bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead( - Value value) const { +bool BufferizationState::isValueRead(Value value) const { assert(value.getType().isa() && "expected TensorType"); SmallVector workingSet; for (OpOperand &use : value.getUses()) @@ -157,9 +147,8 @@ // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any // further. 
-llvm::SetVector mlir::linalg::comprehensive_bufferize:: - BufferizationState::findValueInReverseUseDefChain( - Value value, llvm::function_ref condition) const { +llvm::SetVector BufferizationState::findValueInReverseUseDefChain( + Value value, llvm::function_ref condition) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -185,8 +174,8 @@ } // Find the Values of the last preceding write of a given Value. -llvm::SetVector mlir::linalg::comprehensive_bufferize:: - BufferizationState::findLastPrecedingWrite(Value value) const { +llvm::SetVector +BufferizationState::findLastPrecedingWrite(Value value) const { return findValueInReverseUseDefChain(value, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) @@ -198,8 +187,7 @@ }); } -mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState( - const BufferizationOptions &options) +BufferizationState::BufferizationState(const BufferizationOptions &options) : options(options) {} // bufferization.to_memref is not allowed to change the rank. @@ -237,8 +225,7 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. -FailureOr -mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer( +FailureOr BufferizationState::getBuffer( RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace, Optional customCopyInsertionPoint) const { OpBuilder::InsertionGuard guard(rewriter); @@ -294,8 +281,9 @@ return resultBuffer; } -void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues( - RewriterBase &rewriter, Operation *op, ValueRange values) { +void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, + Operation *op, + ValueRange values) { OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. 
@@ -409,9 +397,10 @@ /// Create an AllocOp/DeallocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. -FailureOr mlir::linalg::comprehensive_bufferize::createAlloc( - OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref, - const BufferizationOptions &options) { +FailureOr +bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue, + bool deallocMemref, + const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -445,9 +434,10 @@ } /// Create a memref allocation. -FailureOr mlir::linalg::comprehensive_bufferize::createAlloc( - OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape, - const BufferizationOptions &options) { +FailureOr +bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, + ArrayRef dynShape, + const BufferizationOptions &options) { if (options.allocationFn) return (*options.allocationFn)(b, loc, type, dynShape); @@ -458,9 +448,9 @@ } /// Create a memref deallocation. -LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc( - OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options) { +LogicalResult +bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, + const BufferizationOptions &options) { if (options.deallocationFn) return (*options.deallocationFn)(b, loc, allocatedBuffer); @@ -470,9 +460,9 @@ } /// Create a memory copy between two memref buffers. 
-LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy( - OpBuilder &b, Location loc, Value from, Value to, - const BufferizationOptions &options) { +LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc, + Value from, Value to, + const BufferizationOptions &options) { if (options.memCpyFn) return (*options.memCpyFn)(b, loc, from, to); @@ -484,27 +474,28 @@ // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// -bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) { +bool bufferization::isFunctionArgument(Value value) { auto bbArg = value.dyn_cast(); if (!bbArg) return false; return isa(bbArg.getOwner()->getParentOp()); } -MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType( - ShapedType shapedType, MemRefLayoutAttrInterface layout, - Attribute memorySpace) { +MemRefType +bufferization::getContiguousMemRefType(ShapedType shapedType, + MemRefLayoutAttrInterface layout, + Attribute memorySpace) { return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), layout, memorySpace); } -UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType( - Type elementType, Attribute memorySpace) { +UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType, + Attribute memorySpace) { return UnrankedMemRefType::get(elementType, memorySpace); } -MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType( - RankedTensorType tensorType, unsigned addressSpace) { +MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType, + unsigned addressSpace) { // TODO: address space decisions to connect with the actual alloc. 
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; SmallVector dynamicStrides(tensorType.getRank(), diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp rename to mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp @@ -6,24 +6,21 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" using namespace mlir; -using namespace linalg; -using namespace comprehensive_bufferize; +using namespace mlir::bufferization; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { +namespace bufferization { namespace bufferization_ext { -// TODO: These ops should implement BufferizableOpInterface directly when moved -// to the Bufferization dialect. +// TODO: These ops should implement BufferizableOpInterface. /// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded /// to x. Other to_memref ops are ignored during bufferization. @@ -57,7 +54,6 @@ bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, const BufferizationState &state) const { // ToMemrefOps always bufferize inplace. - // TODO: Remove ToMemrefOps from the analysis. 
return true; } @@ -121,14 +117,11 @@ }; } // namespace bufferization_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace bufferization } // namespace mlir -void mlir::linalg::comprehensive_bufferize::bufferization_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { - registry.addOpInterface(); - registry.addOpInterface(); +void bufferization_ext::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRBufferization + PARTIAL_SOURCES_INTENDED AllocationOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp @@ -16,3 +17,17 @@ MLIRTensor MLIRMemRef ) + +add_mlir_dialect_library(MLIRBufferizableOpInterface + PARTIAL_SOURCES_INTENDED + BufferizableOpInterface.cpp + BufferizationInterfaceImpl.cpp + + DEPENDS + MLIRBufferizableOpInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRBufferization + MLIRMemRef +) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -8,10 +8,12 @@ #include "PassDetail.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::bufferization; @@ -129,3 +131,80 @@ mlir::bufferization::createFinalizingBufferizePass() { return
std::make_unique(); } + +static bool isaTensor(Type t) { return t.isa(); } + +/// Return true if the given op has a tensor result or a tensor operand. +static bool hasTensorSemantics(Operation *op) { + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); + bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); + return hasTensorResult || hasTensorOperand; +} + +/// Rewrite pattern that bufferizes bufferizable ops. +struct BufferizationPattern + : public OpInterfaceRewritePattern { + BufferizationPattern(MLIRContext *context, const BufferizationState &state, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + state(state) {} + + LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, + PatternRewriter &rewriter) const override { + // No tensors => no buffers. + if (!hasTensorSemantics(bufferizableOp.getOperation())) + return failure(); + if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) + return failure(); + return bufferizableOp.bufferize(rewriter, state); + } + +private: + const BufferizationState &state; +}; + +/// Check the result of bufferization. Return an error if an op was not +/// bufferized, unless partial bufferization is allowed. +static LogicalResult +checkBufferizationResult(Operation *op, const BufferizationOptions &options) { + if (!options.allowUnknownOps) { + // Check if all ops were bufferized. + LogicalResult status = success(); + op->walk([&](Operation *op) { + if (!hasTensorSemantics(op)) + return WalkResult::advance(); + + // Bufferization dialect ops will canonicalize away if all other ops are + // bufferized. + if (isa(op)) + return WalkResult::advance(); + + // Ops that are not in the allow list can be ignored. + if (!options.isOpAllowed(op)) + return WalkResult::advance(); + + // Ops without any uses and no side effects will fold away. 
+ if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) + return WalkResult::advance(); + + status = op->emitError("op was not bufferized"); + return WalkResult::interrupt(); + }); + + if (failed(status)) + return status; + } + + return success(); +} + +LogicalResult bufferization::bufferizeOp(Operation *op, + const BufferizationState &state) { + // Bufferize the op and its nested ops. + OwningRewritePatternList patterns(op->getContext()); + patterns.add(op->getContext(), state); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return failure(); + + return checkBufferizationResult(op, state.getOptions()); +} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms Bufferize.cpp BufferDeallocation.cpp + OneShotAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization @@ -9,7 +10,13 @@ MLIRBufferizationPassIncGen LINK_LIBS PUBLIC + MLIRBufferizableOpInterface MLIRBufferization + MLIRControlFlowInterfaces + MLIRInferTypeOpInterface + MLIRIR + MLIRMemRef MLIRPass + MLIRStandard MLIRTransforms ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -1,4 +1,4 @@ -//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===// +//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===// // // Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// Comprehensive Bufferize bufferizes function bodies. Function boundaries -// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. -// ModuleBufferization.cpp is an extension of Comprehensive Bufferize for simple +// One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp +// bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. +// ModuleBufferization.cpp is an extension of One-Shot Analysis for simple // call graphs. // -// Comprehensive Bufferize consists of two phases. +// One-Shot Bufferize consists of two phases. // // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without // inserting buffer copies. The analysis queries op bufferization semantics @@ -20,49 +20,43 @@ // function does not generate buffer copies for OpResults that were decided // to bufferize inplace during the analysis phase. // +// This file contains only the analysis. The actual bufferization is implemented +// via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a +// helper function `runOneShotBufferize` that analyzes an op (and its nested +// ops) and then bufferizes it. +// // Inplace bufferization decisions are passed from the analysis to the // bufferization phase via `BufferizationState` and `BufferizationAliasInfo`. // They can be printed for debugging purposes with `testAnalysisOnly`. // // Ops that do not implement `BufferizableOpInterface` can be analyzed but are -// treated conservatively. E.g., the analysis has to assume that their +// treated conservatively. E.g., the analysis has to assume that their tensor // OpOperands bufferize to memory writes. While such ops can be analyzed, they // are not bufferized and remain in the IR. 
to_tensor and to_memref ops are // inserted at the bufferization boundary. // -// Note: If `allowUnknownOps` is set to false, bufferization fails when an -// unknown op (that does not implement `BufferizableOpInterface`) is found. No -// to_tensor/to_memref ops are inserted. -// -// This pass caters to high-performance codegen where buffer reuse is deemed -// critical: the pass should fail if the bufferized form of the function needs -// to return any buffer, unless `allowReturnMemref` is enabled. -// -// Lastly, note that layout map chosen to bufferize is the most dynamic -// canonical strided layout of the proper rank. This ensures compatibility with -// expected layouts after transformations. Combinations of memref.cast + -// canonicalization are responsible for clean ups. +// This analysis caters to high-performance codegen where buffer reuse is deemed +// critical: the analysis should fail if the bufferized form of the function +// needs to return a buffer, unless `allowReturnMemref` is enabled. 
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" using namespace mlir; -using namespace linalg; -using namespace tensor; -using namespace comprehensive_bufferize; +using namespace mlir::bufferization; static bool isaTensor(Type t) { return t.isa(); } @@ -751,65 +745,8 @@ } }; -/// Rewrite pattern that bufferizes bufferizable ops. -struct BufferizationPattern - : public OpInterfaceRewritePattern { - BufferizationPattern(MLIRContext *context, const BufferizationState &state, - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - state(state) {} - - LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, - PatternRewriter &rewriter) const override { - // No tensors => no buffers. - if (!hasTensorSemantics(bufferizableOp.getOperation())) - return failure(); - if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) - return failure(); - return bufferizableOp.bufferize(rewriter, state); - } - -private: - const BufferizationState &state; -}; - -/// Check the result of bufferization. Return an error if an op was not -/// bufferized, unless partial bufferization is allowed. 
-static LogicalResult -checkBufferizationResult(Operation *op, const BufferizationOptions &options) { - if (!options.allowUnknownOps) { - // Check if all ops were bufferized. - LogicalResult status = success(); - op->walk([&](Operation *op) { - if (!hasTensorSemantics(op)) - return WalkResult::advance(); - - // Bufferization dialect ops will canonicalize away if all other ops are - // bufferized. - if (isa(op)) - return WalkResult::advance(); - - // Ops that are not in the allow list can be ignored. - if (!options.isOpAllowed(op)) - return WalkResult::advance(); - - // Ops without any uses and no side effects will fold away. - if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) - return WalkResult::advance(); - - status = op->emitError("op was not bufferized"); - return WalkResult::interrupt(); - }); - - if (failed(status)) - return status; - } - - return success(); -} - -LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp( - Operation *op, AnalysisBufferizationState &state) { +LogicalResult bufferization::analyzeOp(Operation *op, + AnalysisBufferizationState &state) { DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); const auto &options = @@ -849,18 +786,7 @@ return success(); } -LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( - Operation *op, const BufferizationState &state) { - // Bufferize the op and its nested ops. 
- OwningRewritePatternList patterns(op->getContext()); - patterns.add(op->getContext(), state); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return failure(); - - return checkBufferizationResult(op, state.getOptions()); -} - -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( +LogicalResult bufferization::runOneShotBufferize( Operation *op, std::unique_ptr options) { AnalysisBufferizationState state(op, *options); if (failed(analyzeOp(op, state))) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp @@ -9,7 +9,9 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" + +using namespace mlir::bufferization; void mlir::linalg::comprehensive_bufferize::affine_ext:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -9,12 +9,14 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include 
"mlir/Transforms/BufferUtils.h" +using namespace mlir::bufferization; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -1,9 +1,6 @@ set(LLVM_OPTIONAL_SOURCES AffineInterfaceImpl.cpp ArithInterfaceImpl.cpp - BufferizableOpInterface.cpp - BufferizationInterfaceImpl.cpp - ComprehensiveBufferize.cpp LinalgInterfaceImpl.cpp ModuleBufferization.cpp SCFInterfaceImpl.cpp @@ -12,18 +9,6 @@ VectorInterfaceImpl.cpp ) -add_mlir_dialect_library(MLIRBufferizableOpInterface - BufferizableOpInterface.cpp - - DEPENDS - MLIRBufferizableOpInterfaceIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRBufferization - MLIRMemRef -) - add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl AffineInterfaceImpl.cpp @@ -48,7 +33,7 @@ LINK_LIBS PUBLIC MLIRBufferizableOpInterface - MLIRComprehensiveBufferize + MLIRBufferizationTransforms MLIRIR MLIRLinalg MLIRTensor @@ -59,7 +44,7 @@ LINK_LIBS PUBLIC MLIRBufferizableOpInterface - MLIRComprehensiveBufferize + MLIRBufferizationTransforms MLIRIR MLIRSCF ) @@ -91,18 +76,14 @@ MLIRVector ) -add_mlir_dialect_library(MLIRComprehensiveBufferize - BufferizationInterfaceImpl.cpp - ComprehensiveBufferize.cpp +add_mlir_dialect_library(MLIRModuleBufferization ModuleBufferization.cpp LINK_LIBS PUBLIC MLIRBufferizableOpInterface - MLIRControlFlowInterfaces - MLIRInferTypeOpInterface + MLIRBufferizationTransforms MLIRIR MLIRMemRef MLIRStandard MLIRStandardOpsTransforms - MLIRTransforms ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ 
b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" @@ -17,6 +17,7 @@ using namespace mlir; using namespace linalg; using namespace comprehensive_bufferize; +using namespace mlir::bufferization; namespace { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -71,9 +71,10 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" @@ -82,6 +83,7 @@ using namespace linalg; using namespace tensor; using namespace comprehensive_bufferize; +using namespace mlir::bufferization; namespace { /// The state of analysis of a FuncOp. 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -7,13 +7,15 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +using namespace mlir::bufferization; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp @@ -8,11 +8,13 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir::bufferization; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -7,13 +7,14 
@@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" using namespace mlir; +using namespace mlir::bufferization; namespace mlir { namespace linalg { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir::bufferization; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -65,12 +65,12 @@ /// Attribute 
name used to mark the bufferization layout for region /// arguments during linalg comprehensive bufferization. constexpr const ::llvm::StringLiteral - comprehensive_bufferize::BufferizableOpInterface::kBufferLayoutAttrName; + bufferization::BufferizableOpInterface::kBufferLayoutAttrName; /// Attribute name used to mark region arguments that can be bufferized /// in-place during linalg comprehensive bufferization. constexpr const ::llvm::StringLiteral - comprehensive_bufferize::BufferizableOpInterface::kInplaceableAttrName; + bufferization::BufferizableOpInterface::kInplaceableAttrName; /// Trait to check if T provides a `regionBuilder` method. template @@ -125,7 +125,7 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - using comprehensive_bufferize::BufferizableOpInterface; + using bufferization::BufferizableOpInterface; if (attr.getName() == BufferizableOpInterface::kInplaceableAttrName) { if (!attr.getValue().isa()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,7 +38,6 @@ MLIRArithmetic MLIRBufferizableOpInterface MLIRComplex - MLIRComprehensiveBufferize MLIRInferTypeOpInterface MLIRIR MLIRMemRef @@ -46,6 +45,7 @@ MLIRLinalgAnalysis MLIRLinalgBufferizableOpInterfaceImpl MLIRLinalgUtils + MLIRModuleBufferization MLIRSCF MLIRSCFBufferizableOpInterfaceImpl MLIRSCFTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -8,12 +8,12 @@ #include "PassDetail.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include 
"mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" @@ -27,6 +27,7 @@ #include "mlir/Transforms/Passes.h" using namespace mlir; +using namespace mlir::bufferization; using namespace mlir::linalg; using namespace mlir::linalg::comprehensive_bufferize; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -13,8 +13,8 @@ #include "CodegenUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -284,8 +284,7 @@ if (auto funcOp = dyn_cast(arg.getOwner()->getParentOp())) if (auto attr = funcOp.getArgAttrOfType( arg.getArgNumber(), - linalg::comprehensive_bufferize::BufferizableOpInterface:: - kInplaceableAttrName)) + bufferization::BufferizableOpInterface::kInplaceableAttrName)) return 
attr.getValue(); return false; } diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -17,7 +17,7 @@ MLIRArithBufferizableOpInterfaceImpl MLIRArithmetic MLIRBufferizableOpInterface - MLIRComprehensiveBufferize + MLIRBufferizationTransforms MLIRGPUTransforms MLIRLinalg MLIRLinalgBufferizableOpInterfaceImpl diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -12,14 +12,13 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" @@ -34,6 +33,7 @@ using namespace mlir; using namespace mlir::linalg; using namespace 
mlir::linalg::comprehensive_bufferize; +using namespace mlir::bufferization; namespace { /// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are @@ -118,8 +118,8 @@ options->dialectFilter->insert(dialectNamespace); } - Operation *op = getOperation().getOperation(); - if (failed(runComprehensiveBufferize(op, std::move(options)))) + Operation *op = getOperation(); + if (failed(runOneShotBufferize(op, std::move(options)))) return; if (testAnalysisOnly) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6506,7 +6506,7 @@ td_library( name = "BufferizableOpInterfaceTdFiles", srcs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td", + "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td", ], includes = ["include"], deps = [ @@ -6520,15 +6520,15 @@ tbl_outs = [ ( ["-gen-op-interface-decls"], - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc", + "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc", ), ( ["-gen-op-interface-defs"], - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc", + "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc", ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td", + td_file = "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td", deps = [ ":BufferizableOpInterfaceTdFiles", ], @@ -6537,10 +6537,12 @@ cc_library( name = "BufferizableOpInterface", srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp", + "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", + "lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp", ], hdrs = [ - 
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h", + "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", + "include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h", ], includes = ["include"], deps = [ @@ -6601,7 +6603,7 @@ deps = [ ":BufferizableOpInterface", ":BufferizationDialect", - ":ComprehensiveBufferize", + ":BufferizationTransforms", ":IR", ":LinalgOps", ":LinalgStructuredOpsIncGen", @@ -6621,7 +6623,7 @@ deps = [ ":BufferizableOpInterface", ":BufferizationDialect", - ":ComprehensiveBufferize", + ":BufferizationTransforms", ":IR", ":SCFDialect", ":Support", @@ -6891,7 +6893,6 @@ ":BufferizationDialect", ":BufferizationTransforms", ":ComplexDialect", - ":ComprehensiveBufferize", ":DialectUtils", ":IR", ":InferTypeOpInterface", @@ -6901,6 +6902,7 @@ ":LinalgStructuredOpsIncGen", ":MathDialect", ":MemRefDialect", + ":ModuleBufferization", ":Pass", ":SCFBufferizableOpInterfaceImpl", ":SCFDialect", @@ -6921,30 +6923,23 @@ ) cc_library( - name = "ComprehensiveBufferize", + name = "ModuleBufferization", srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp", - "lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp", "lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp", ], hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h", - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h", "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h", ], includes = ["include"], deps = [ ":BufferizableOpInterface", ":BufferizationDialect", - ":ControlFlowInterfaces", + ":BufferizationTransforms", ":DialectUtils", ":IR", - ":InferTypeOpInterface", ":MemRefDialect", - ":Pass", ":StandardOps", ":Support", - ":Transforms", "//llvm:Support", ], ) @@ -7957,12 +7952,10 @@ cc_library( name = "BufferizationDialect", - srcs = glob( - [ - "lib/Dialect/Bufferization/IR/Bufferization*.h", - 
"lib/Dialect/Bufferization/IR/Bufferization*.cpp", - ], - ), + srcs = [ + "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp", + "lib/Dialect/Bufferization/IR/BufferizationOps.cpp", + ], hdrs = ["include/mlir/Dialect/Bufferization/IR/Bufferization.h"], includes = ["include"], deps = [ @@ -8011,11 +8004,17 @@ deps = [ ":AllocationOpInterface", ":Analysis", + ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationPassIncGen", + ":ControlFlowInterfaces", + ":DialectUtils", ":IR", + ":InferTypeOpInterface", ":MemRefDialect", ":Pass", + ":StandardOps", + ":Support", ":Transforms", "//llvm:Support", ], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -390,7 +390,7 @@ "//mlir:ArithmeticDialect", "//mlir:BufferizableOpInterface", "//mlir:BufferizationDialect", - "//mlir:ComprehensiveBufferize", + "//mlir:BufferizationTransforms", "//mlir:GPUDialect", "//mlir:IR", "//mlir:LinalgBufferizableOpInterfaceImpl",