Summary: move the ComprehensiveBufferize option/callback declarations next to
BufferizationState in BufferizableOpInterface.

This patch moves `kBufferAlignments`, `defaultAllocationCallbacks()` and
`BufferizationOptions` from ComprehensiveBufferize.h into
BufferizableOpInterface.h. It also hoists `AllocationCallbacks` and
`PostAnalysisStep` to the top of that header.

The default allocation/deallocation/memcpy callbacks and the
`BufferizationOptions` constructor move from ComprehensiveBufferize.cpp to
BufferizableOpInterface.cpp. ComprehensiveBufferize.h now includes
mlir/IR/BuiltinOps.h instead of BufferizableOpInterface.h. It forward-declares
`BufferizationOptions` and `BufferizationState`.

`BufferizationState` now takes and stores `const BufferizationOptions &options`
instead of `AllocationCallbacks &allocationFns`. All call sites go through
`state.options.allocationFns->...`.

Review notes:
  * `BufferizationState::options` is a reference member, so the options object
    must outlive the state.
    The constructor parameter is `const BufferizationOptions &`, so a temporary
    `BufferizationOptions` would also bind and then dangle. Consider adding a
    deleted `BufferizationOptions &&` constructor overload.
  * `addPostAnalysisStep(Args... args)` takes its arguments by value and then
    forwards them. `Args &&...args` would avoid the intermediate copies.
    NOTE(review): the template parameter list is not visible in this copy;
    confirm it is a `typename... Args` pack.
  * This copy of the patch has lost its line breaks. It also appears to have
    lost the contents of `<...>` template argument lists (for example
    `std::unique_ptr defaultAllocationCallbacks();`). It therefore cannot be
    applied as-is; regenerate it from the original commit.

diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -26,7 +26,84 @@ namespace linalg { namespace comprehensive_bufferize { -class BufferizationAliasInfo; +// TODO: from some HW description. +static constexpr int64_t kBufferAlignments = 128; + +struct BufferizationState; + +/// Callback functions that are used to allocate/deallocate/copy memory buffers. +/// Comprehensive Bufferize provides default implementations of these functions. +// TODO: Could be replaced with a "bufferization strategy" object with virtual +// functions in the future. +struct AllocationCallbacks { + using AllocationFn = std::function( + OpBuilder &, Location, MemRefType, ArrayRef)>; + using DeallocationFn = std::function; + using MemCpyFn = std::function; + + AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, + MemCpyFn copyFn) + : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {} + + /// A function that allocates memory. + AllocationFn allocationFn; + + /// A function that deallocates memory. Must be allocated by `allocationFn`. + DeallocationFn deallocationFn; + + /// A function that copies memory between two allocations. + MemCpyFn memCpyFn; +}; + +/// Return default allocation callbacks. +std::unique_ptr defaultAllocationCallbacks(); + +/// PostAnalysisSteps can be registered with `BufferizationOptions` and are +/// executed after the analysis, but before bufferization. They can be used +/// to implement custom dialect-specific optimizations. +struct PostAnalysisStep { + virtual ~PostAnalysisStep() {} + + /// Run the post analysis step. This function may modify the IR, but must keep + /// `aliasInfo` (inside `state`) consistent.
Newly created operations and + /// operations that should be re-analyzed must be stored in `newOps`. + virtual LogicalResult run(FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) = 0; +}; + +/// Options for ComprehensiveBufferize. +struct BufferizationOptions { + BufferizationOptions(); + + // BufferizationOptions cannot be copied. + BufferizationOptions(const BufferizationOptions &other) = delete; + + /// Register a "post analysis" step. Such steps are executed after the + /// analysis, but before bufferization. + template + void addPostAnalysisStep(Args... args) { + postAnalysisSteps.emplace_back( + std::make_unique(std::forward(args)...)); + } + + /// Helper functions for allocation, deallocation, memory copying. + std::unique_ptr allocationFns; + + /// Specifies whether returning newly allocated memrefs should be allowed. + /// Otherwise, a pass failure is triggered. + bool allowReturnMemref = false; + + /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated. + /// Should be used only with `testAnalysisOnly = true`. + unsigned analysisFuzzerSeed = 0; + + /// If set to `true`, does not modify the IR apart from adding attributes (for + /// checking the results of the analysis) and post analysis steps. + bool testAnalysisOnly = false; + + /// Registered post analysis steps. + std::vector> postAnalysisSteps; +}; /// Specify fine-grain relationship between buffers to enable more analysis. enum class BufferRelation { @@ -204,32 +281,6 @@ /// is returned regardless of whether it is a memory write or not. Value findLastPrecedingWrite(Value value); -struct BufferizationState; - -/// Callback functions that are used to allocate/deallocate/copy memory buffers. -/// Comprehensive Bufferize provides default implementations of these functions. -// TODO: Could be replaced with a "bufferization strategy" object with virtual -// functions in the future.
-struct AllocationCallbacks { - using AllocationFn = std::function( - OpBuilder &, Location, MemRefType, ArrayRef)>; - using DeallocationFn = std::function; - using MemCpyFn = std::function; - - AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, - MemCpyFn copyFn) - : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {} - - /// A function that allocates memory. - AllocationFn allocationFn; - - /// A function that deallocated memory. Must be allocated by `allocationFn`. - DeallocationFn deallocationFn; - - /// A function that copies memory between two allocations. - MemCpyFn memCpyFn; -}; - /// Dialect-specific bufferization state. Analysis/bufferization information /// that is specific to ops from a certain dialect can be stored in derived /// variants of this struct. @@ -240,8 +291,8 @@ /// BufferizationState keeps track of bufferization state and provides access to /// the results of the analysis. struct BufferizationState { - BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns) - : aliasInfo(moduleOp), allocationFns(allocationFns) {} + BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options) + : aliasInfo(moduleOp), options(options) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; @@ -289,10 +340,6 @@ /// `aliasInfo` keeps track of aliasing and equivalent values. BufferizationAliasInfo aliasInfo; - /// `allocationFns` contains helper functions for creating alloc ops, dealloc - /// ops and memcpy ops. - AllocationCallbacks &allocationFns; - /// The mapping of tensors to buffers. May also contain mappings of non-tensor /// values. BlockAndValueMapping mapping; @@ -302,6 +349,9 @@ /// Dialect-specific bufferization state. DenseMap> dialectState; + + /// Current bufferization options (not owned; must outlive this state). + const BufferizationOptions &options; }; /// Return the result buffer (memref) for a given OpResult (tensor).
Allocate @@ -320,19 +370,6 @@ /// method of `BufferizableOpInterface`. LogicalResult bufferize(Operation *op, BufferizationState &state); -/// PostAnalysisSteps can be registered with `BufferizationOptions` and are -/// executed after the analysis, but before bufferization. They can be used -/// implement custom dialect-specific optimizations. -struct PostAnalysisStep { - virtual ~PostAnalysisStep() {} - - /// Run the post analysis step. This function may modify the IR, but must keep - /// `aliasInfo` (inside `state`) consistent. Newly created operations and - /// operations that should be re-analyzed must be stored in `newOps`. - virtual LogicalResult run(FuncOp funcOp, BufferizationState &state, - SmallVector &newOps) = 0; -}; - /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) /// with the same shape as `shapedType` and specified `layout` and /// `addressSpace`. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -9,54 +9,15 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/IR/BuiltinOps.h" namespace mlir { -class ModuleOp; - namespace linalg { namespace comprehensive_bufferize { -// TODO: from some HW description. -static constexpr int64_t kBufferAlignments = 128; - -/// Return default allocation callbacks. -std::unique_ptr defaultAllocationCallbacks(); - -/// Options for ComprehensiveBufferize. -struct BufferizationOptions { - BufferizationOptions(); - - // BufferizationOptions cannot be copied.
- BufferizationOptions(const BufferizationOptions &other) = delete; - - /// Register a "post analysis" step. Such steps are executed after the - /// analysis, but before bufferization. - template - void addPostAnalysisStep(Args... args) { - postAnalysisSteps.emplace_back( - std::make_unique(std::forward(args)...)); - } - - /// Helper functions for allocation, deallocation, memory copying. - std::unique_ptr allocationFns; - - /// Specifies whether returning newly allocated memrefs should be allowed. - /// Otherwise, a pass failure is triggered. - bool allowReturnMemref = false; - - /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated. - /// Should be used only with `testAnalysisOnly = true`. - unsigned analysisFuzzerSeed = 0; - - /// If set to `true`, does not modify the IR apart from adding attributes (for - /// checking the results of the analysis) and post analysis steps. - bool testAnalysisOnly = false; - - /// Registered post analysis steps. - std::vector> postAnalysisSteps; -}; +struct BufferizationOptions; +struct BufferizationState; /// Bufferize the given function. Does not bufferize the function boundary. // TODO: This function is meant to be called from ModuleBufferize and not can diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -35,6 +35,45 @@ using namespace mlir; using namespace linalg::comprehensive_bufferize; +//===----------------------------------------------------------------------===// +// BufferizationOptions +//===----------------------------------------------------------------------===// + +/// Default allocation function that is used by the comprehensive bufferization +/// pass. The default currently creates a ranked memref using `memref.alloc`.
+static Optional defaultAllocationFn(OpBuilder &b, Location loc, + MemRefType type, + ArrayRef dynShape) { + Value allocated = b.create( + loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); + return allocated; +} + +/// Default deallocation function that is used by the comprehensive +/// bufferization pass. It expects to receive back the value returned by +/// `defaultAllocationFn`. +static void defaultDeallocationFn(OpBuilder &b, Location loc, + Value allocatedBuffer) { + b.create(loc, allocatedBuffer); +} + +/// Default memory copy function that is used by the comprehensive bufferization +/// pass. Creates a `memref.copy` op. +static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) { + b.create(loc, from, to); +} + +std::unique_ptr +mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { + return std::make_unique( + defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); +} + +// Default constructor for BufferizationOptions that sets all allocation +// callbacks to their default functions. +BufferizationOptions::BufferizationOptions() + : allocationFns(defaultAllocationCallbacks()) {} + //===----------------------------------------------------------------------===// // BufferizationAliasInfo //===----------------------------------------------------------------------===// @@ -384,7 +423,8 @@ if (!skipCopy) { // The copy happens right before the op that is bufferized. b.setInsertionPoint(op); - state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); + state.options.allocationFns->memCpyFn(b, loc, operandBuffer, + resultBuffer); } return resultBuffer; } @@ -537,7 +577,7 @@ MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); Optional allocated = - allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); + options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape); // TODO: For now just assert the value is returned.
Eventually need to // error-propagate. assert(allocated && "allocation failed"); @@ -549,7 +589,7 @@ // 2. Create memory deallocation. b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - allocationFns.deallocationFn(b, loc, allocated.getValue()); + options.allocationFns->deallocationFn(b, loc, allocated.getValue()); return casted; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -783,39 +783,3 @@ return success(); } - -/// Default allocation function that is used by the comprehensive bufferization -/// pass. The default currently creates a ranked memref using `memref.alloc`. -static Optional defaultAllocationFn(OpBuilder &b, Location loc, - MemRefType type, - ArrayRef dynShape) { - Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); - return allocated; -} - -/// Default deallocation function that is used by the comprehensive -/// bufferization pass. It expects to recieve back the value called from the -/// `defaultAllocationFn`. -static void defaultDeallocationFn(OpBuilder &b, Location loc, - Value allocatedBuffer) { - b.create(loc, allocatedBuffer); -} - -/// Default memory copy function that is used by the comprehensive bufferization -/// pass. Creates a `memref.copy` op. -static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) { - b.create(loc, from, to); -} - -std::unique_ptr -mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { - return std::make_unique( - defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); -} - -// Default constructor for BufferizationOptions that sets all allocation -// callbacks to their default functions.
-BufferizationOptions::BufferizationOptions() - : allocationFns(defaultAllocationCallbacks()) {} - diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -648,7 +648,7 @@ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); - BufferizationState state(moduleOp, *options.allocationFns); + BufferizationState state(moduleOp, options); BufferizationAliasInfo &aliasInfo = state.aliasInfo; // Interestingly, all function args that are not visible outside of a module diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -189,8 +189,8 @@ if (!inplace) { // Do not copy if the copied data is never read. if (isValueRead(extractSliceOp.result())) - state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView, - alloc); + state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(), + subView, alloc); subView = alloc; } @@ -464,8 +464,8 @@ state.aliasInfo.insertNewBufferAlias(subView, dstMemref); // Copy tensor. Value srcMemref = state.lookupBuffer(insertSliceOp.source()); - state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, - subView); + state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(), + srcMemref, subView); } state.mapBuffer(insertSliceOp.result(), dstMemref);