diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -10,15 +10,9 @@ #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Value.h" -#include "llvm/ADT/SetOperations.h" namespace mlir { -class DominanceInfo; -class FuncOp; -class GlobalCreator; class ModuleOp; namespace linalg { @@ -27,29 +21,6 @@ // TODO: from some HW description. static constexpr int64_t kBufferAlignments = 128; -struct BufferizationState; - -/// Analyze the `ops` to determine which OpResults are inplaceable. -LogicalResult inPlaceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo, - unsigned analysisFuzzerSeed = 0); - -// TODO: Do not expose those functions in the header file. -/// Default allocation function that is used by the comprehensive bufferization -/// pass. The default currently creates a ranked memref using `memref.alloc`. -Optional defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape); - -/// Default deallocation function that is used by the comprehensive -/// bufferization pass. It expects to recieve back the value called from the -/// `defaultAllocationFn`. -void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer); - -/// Default memory copy function that is used by the comprehensive bufferization -/// pass. Creates a `linalg.copy` op. -void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to); - /// Return default allocation callbacks.
std::unique_ptr defaultAllocationCallbacks(); @@ -63,9 +34,13 @@ /// Register external models implemented for the `BufferizableOpInterface`. void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +/// Options for ComprehensiveBufferize. struct BufferizationOptions { BufferizationOptions(); + // BufferizationOptions cannot be copied. + BufferizationOptions(const BufferizationOptions &other) = delete; + /// Register a "post analysis" step. Such steps are executed after the /// analysis, but before bufferization. template @@ -74,10 +49,22 @@ std::make_unique(std::forward(args)...)); } + /// Helper functions for allocation, deallocation, and memory copying. std::unique_ptr allocationFns; + + /// Specifies whether returning newly allocated memrefs should be allowed. + /// Otherwise, a pass failure is triggered. bool allowReturnMemref = false; + + /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated. + /// Should be used only with `testAnalysisOnly = true`. unsigned analysisFuzzerSeed = 0; + + /// If set to `true`, does not modify the IR apart from adding attributes (for + /// checking the results of the analysis) and post analysis steps. bool testAnalysisOnly = false; + + /// Registered post analysis steps. std::vector> postAnalysisSteps; }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -982,8 +982,8 @@ return success(); } -/// Determine if `operand` can be bufferized in-place with one of the op's -/// results. +/// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in +/// reverse and analyze them greedily. This is a good starter heuristic.
/// /// Even if an op does not read or write, it may still create an alias when /// bufferized in-place. An example of such ops is tensor.extract_slice. @@ -1000,24 +1000,10 @@ /// /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. -static LogicalResult -bufferizableInPlaceAnalysis(OpOperand &operand, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - auto bufferizableOp = dyn_cast(operand.getOwner()); - if (!bufferizableOp) - return success(); - if (OpResult result = bufferizableOp.getAliasingOpResult(operand)) - return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); - return success(); -} - -/// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in -/// reverse and bufferize ops greedily. This is a good starter heuristic. -/// ExtractSliceOps are interleaved with other ops in traversal order. -LogicalResult mlir::linalg::comprehensive_bufferize::inPlaceAnalysis( - SmallVector &ops, BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo, unsigned analysisFuzzerSeed) { +static LogicalResult inPlaceAnalysis(SmallVector &ops, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo, + unsigned analysisFuzzerSeed = 0) { if (analysisFuzzerSeed) { // This is a fuzzer. For testing purposes only. Randomize the order in which // operations are analyzed.
The bufferization quality is likely worse, but @@ -1030,8 +1016,11 @@ for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) - if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) - return failure(); + if (auto bufferizableOp = dyn_cast(op)) + if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) + if (failed(bufferizableInPlaceAnalysisImpl(opOperand, opResult, + aliasInfo, domInfo))) + return failure(); return success(); } @@ -1076,26 +1065,6 @@ // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -Optional mlir::linalg::comprehensive_bufferize::defaultAllocationFn( - OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { - Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); - return allocated; -} - -void mlir::linalg::comprehensive_bufferize::defaultDeallocationFn( - OpBuilder &b, Location loc, Value allocatedBuffer) { - b.create(loc, allocatedBuffer); -} - -void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b, - Location loc, - Value from, - Value to) { - b.create(loc, from, to); -} - LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( Operation *op, BufferizationState &state, DenseMap *bufferizedFunctionTypes) { @@ -1642,6 +1611,30 @@ return success(); } +/// Default allocation function that is used by the comprehensive bufferization +/// pass. The default currently creates a ranked memref using `memref.alloc`. +static Optional defaultAllocationFn(OpBuilder &b, Location loc, + MemRefType type, + const SmallVector &dynShape) { + Value allocated = b.create( + loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); + return allocated; +} + +/// Default deallocation function that is used by the comprehensive +/// bufferization pass.
It expects to receive the value returned by +/// `defaultAllocationFn`. +static void defaultDeallocationFn(OpBuilder &b, Location loc, + Value allocatedBuffer) { + b.create(loc, allocatedBuffer); +} + +/// Default memory copy function that is used by the comprehensive bufferization +/// pass. Creates a `memref.copy` op. +static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) { + b.create(loc, from, to); +} + std::unique_ptr mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { return std::make_unique(