diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SetVector.h"
 
 namespace mlir {
 class BlockAndValueMapping;
@@ -21,7 +22,6 @@
 namespace linalg {
 namespace comprehensive_bufferize {
 
-struct AllocationCallbacks;
 class BufferizationAliasInfo;
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
@@ -160,6 +160,61 @@
 /// OpResult that it may alias with. Return None if the op is not bufferizable.
 BufferRelation bufferRelation(OpOperand &opOperand);
 
+/// Starting from `value`, follow the use-def chain in reverse, always selecting
+/// the aliasing OpOperands. Find and return Values for which `condition`
+/// evaluates to true. OpOperands of such matching Values are not traversed any
+/// further.
+///
+/// When reaching the end of a chain (BlockArgument or Value without aliasing
+/// OpOperands), also return the last Value of that chain.
+///
+/// Example:
+///
+///                               8
+///                               |
+///   6*         7*         +-----+----+
+///   |          |          |          |
+///   2*         3          4          5
+///   |          |          |          |
+///   +----------+----------+----------+
+///              |
+///              1
+///
+/// In the above example, Values with a star satisfy the condition. When
+/// starting the traversal from Value 1, the resulting SetVector is:
+/// { 2, 7, 8, 5 }
+llvm::SetVector<Value>
+findValueInReverseUseDefChain(Value value,
+                              std::function<bool(Value)> condition);
+
+/// Find the Value of the last preceding write of a given Value.
+///
+/// Note: Unknown ops are handled conservatively and assumed to be writes.
+/// Furthermore, BlockArguments are also assumed to be writes. There is no
+/// analysis across block boundaries.
+///
+/// Note: When reaching an end of the reverse SSA use-def chain, that value
+/// is returned regardless of whether it is a memory write or not.
+Value findLastPrecedingWrite(Value value);
+
+/// Callback functions that are used by the comprehensive bufferization pass to
+/// allocate/deallocate memory. The `deallocationFn` is guaranteed to receive
+/// the `Value` returned by the `allocationFn`.
+struct AllocationCallbacks {
+  using AllocationFn = std::function<Optional<Value>(
+      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
+  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+  AllocationFn allocationFn;
+  DeallocationFn deallocationFn;
+  MemCpyFn memCpyFn;
+};
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -23,6 +23,7 @@
 namespace linalg {
 namespace comprehensive_bufferize {
 
+struct AllocationCallbacks;
 class BufferizationAliasInfo;
 
 // TODO: from some HW description.
@@ -49,30 +50,6 @@
 /// pass. Creates a `linalg.copy` op.
 void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
 
-/// Callback functions that are used by the comprehensive bufferization pass to
-/// allocate/deallocate memory. These default to use the
-/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
-/// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
-/// by the `allocationFn`.
-struct AllocationCallbacks {
-  using AllocationFn = std::function<Optional<Value>(
-      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
-  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
-  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
-  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
-
-  AllocationCallbacks()
-      : allocationFn(defaultAllocationFn),
-        deallocationFn(defaultDeallocationFn), memCpyFn(defaultMemCpyFn) {}
-
-  AllocationFn allocationFn;
-  DeallocationFn deallocationFn;
-  MemCpyFn memCpyFn;
-};
-
 /// Bufferize one particular op.
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
 /// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
@@ -108,8 +85,7 @@
     FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
 
 struct BufferizationOptions {
-  BufferizationOptions()
-      : allocationFns(std::make_unique<AllocationCallbacks>()) {}
+  BufferizationOptions();
 
   std::unique_ptr<AllocationCallbacks> allocationFns;
   bool allowReturnMemref = false;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -260,3 +260,56 @@
   // Conservatively return None.
   return BufferRelation::None;
 }
+
+// Starting from `value`, follow the use-def chain in reverse, always selecting
+// the aliasing OpOperands. Find and return Values for which `condition`
+// evaluates to true. OpOperands of such matching Values are not traversed any
+// further.
+llvm::SetVector<Value>
+mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
+    Value value, std::function<bool(Value)> condition) {
+  llvm::SetVector<Value> result, workingSet;
+  workingSet.insert(value);
+
+  while (!workingSet.empty()) {
+    Value value = workingSet.pop_back_val();
+    if (condition(value) || value.isa<BlockArgument>()) {
+      result.insert(value);
+      continue;
+    }
+
+    OpResult opResult = value.cast<OpResult>();
+    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+    if (opOperands.empty()) {
+      result.insert(value);
+      continue;
+    }
+
+    for (OpOperand *o : opOperands)
+      workingSet.insert(o->get());
+  }
+
+  return result;
+}
+
+// Find the Value of the last preceding write of a given Value.
+Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
+    Value value) {
+  SetVector<Value> result =
+      findValueInReverseUseDefChain(value, [](Value value) {
+        Operation *op = value.getDefiningOp();
+        if (!op)
+          return true;
+        auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+        if (!bufferizableOp)
+          return true;
+        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
+      });
+
+  // To simplify the analysis, `scf.if` ops are considered memory writes. There
+  // are currently no other ops where one OpResult may alias with multiple
+  // OpOperands. Therefore, this function should return exactly one result at
+  // the moment.
+  assert(result.size() == 1 && "expected exactly one result");
+  return result.front();
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -402,84 +402,6 @@
   return foundInplaceWrite;
 }
 
-/// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the aliasing OpOperands. Find and return Values for which `condition`
-/// evaluates to true. OpOperands of such matching Values are not traversed any
-/// further.
-///
-/// When reaching the end of a chain (BlockArgument or Value without aliasing
-/// OpOperands), also return the last Value of that chain.
-///
-/// Example:
-///
-///                               8
-///                               |
-///   6*         7*         +-----+----+
-///   |          |          |          |
-///   2*         3          4*         5
-///   |          |          |          |
-///   +----------+----------+----------+
-///              |
-///              1
-///
-/// In the above example, Values with a star satisfy the condition. When
-/// starting the traversal from Value 1, the resulting SetVector is:
-/// { 2, 7, 8, 5 }
-static llvm::SetVector<Value>
-findValueInReverseUseDefChain(Value value,
-                              std::function<bool(Value)> condition) {
-  llvm::SetVector<Value> result, workingSet;
-  workingSet.insert(value);
-
-  while (!workingSet.empty()) {
-    Value value = workingSet.pop_back_val();
-    if (condition(value) || value.isa<BlockArgument>()) {
-      result.insert(value);
-      continue;
-    }
-
-    OpResult opResult = value.cast<OpResult>();
-    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    if (opOperands.empty()) {
-      result.insert(value);
-      continue;
-    }
-
-    for (OpOperand *o : opOperands)
-      workingSet.insert(o->get());
-  }
-
-  return result;
-}
-
-/// Find the Value of the last preceding write of a given Value.
-///
-/// Note: Unknown ops are handled conservatively and assumed to be writes.
-/// Furthermore, BlockArguments are also assumed to be writes. There is no
-/// analysis across block boundaries.
-///
-/// Note: When reaching an end of the reverse SSA use-def chain, that value
-/// is returned regardless of whether it is a memory write or not.
-static Value findLastPrecedingWrite(Value value) {
-  SetVector<Value> result =
-      findValueInReverseUseDefChain(value, [](Value value) {
-        Operation *op = value.getDefiningOp();
-        if (!op)
-          return true;
-        auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
-        if (!bufferizableOp)
-          return true;
-        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
-      });
-
-  // To simplify the analysis, `scf.if` ops are considered memory writes. There
-  // are currently no other ops where one OpResult may alias with multiple
-  // OpOperands. Therefore, this function should return exactly one result at
-  // the moment.
-  assert(result.size() == 1 && "expected exactly one result");
-  return result.front();
-}
-
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
@@ -1980,6 +1902,12 @@
   return success();
 }
 
+// Default constructor for BufferizationOptions that sets all allocation
+// callbacks to their default functions.
+BufferizationOptions::BufferizationOptions()
+    : allocationFns(std::make_unique<AllocationCallbacks>(
+          defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn)) {}
+
 //===----------------------------------------------------------------------===//
 // BufferizableOpInterface Implementations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/Pass.h"
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6374,6 +6374,7 @@
         ":AffineUtils",
         ":Analysis",
         ":ArithmeticDialect",
+        ":BufferizableOpInterface",
         ":ComplexDialect",
         ":ComprehensiveBufferize",
         ":DialectUtils",