diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -18,6 +18,8 @@ class BlockAndValueMapping; namespace linalg { +namespace comprehensive_bufferize { + struct AllocationCallbacks; class BufferizationAliasInfo; @@ -28,6 +30,37 @@ // TODO: OperandContainsResult, Equivalent }; + +/// Determine which OpOperand* will alias with `result` if the op is bufferized +/// in place. Return an empty vector if the op is not bufferizable. +SmallVector getAliasingOpOperand(OpResult result); + +/// Determine which OpResult will alias with `opOperand` if the op is bufferized +/// in place. Return an empty OpResult if the op is not bufferizable. +OpResult getAliasingOpResult(OpOperand &opOperand); + +/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the +/// op is not bufferizable. +bool bufferizesToMemoryRead(OpOperand &opOperand); + +/// Return true if `opOperand` bufferizes to a memory write. Return +/// `true` if the op is not bufferizable. +bool bufferizesToMemoryWrite(OpOperand &opOperand); + +/// Return true if `opOperand` does neither read nor write but bufferizes to an +/// alias. Return false if the op is not bufferizable. +bool bufferizesToAliasOnly(OpOperand &opOperand); + +/// Return true if the given value is read by an op that bufferizes to a memory +/// read. Also takes into account ops that create an alias but do not read by +/// themselves (e.g., ExtractSliceOp). +bool isValueRead(Value value); + +/// Return the relationship between the operand and its corresponding +/// OpResult that it may alias with. Return None if the op is not bufferizable. 
+BufferRelation bufferRelation(OpOperand &opOperand); + +} // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -16,7 +16,7 @@ An op interface for Comprehensive Bufferization. Ops that implement this interface can be bufferized using Comprehensive Bufferization. }]; - let cppNamespace = "::mlir::linalg"; + let cppNamespace = "::mlir::linalg::comprehensive_bufferize"; let methods = [ InterfaceMethod< /*desc=*/[{ diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -22,6 +22,7 @@ class ModuleOp; namespace linalg { +namespace comprehensive_bufferize { // TODO: from some HW description. 
static constexpr int64_t kBufferAlignments = 128; @@ -217,6 +218,7 @@ LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, const BufferizationOptions &options); +} // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -7,11 +7,114 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/IR/Operation.h" namespace mlir { namespace linalg { +namespace comprehensive_bufferize { #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc" +} // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir + +using namespace mlir; +using namespace linalg::comprehensive_bufferize; + +//===----------------------------------------------------------------------===// +// Helper functions for BufferizableOpInterface +//===----------------------------------------------------------------------===// + +/// Determine which OpOperand* will alias with `result` if the op is bufferized +/// in place. Return an empty vector if the op is not bufferizable. +SmallVector +mlir::linalg::comprehensive_bufferize::getAliasingOpOperand(OpResult result) { + if (Operation *op = result.getDefiningOp()) + if (auto bufferizableOp = dyn_cast(op)) + return bufferizableOp.getAliasingOpOperand(result); + return {}; +} + +/// Determine which OpResult will alias with `opOperand` if the op is bufferized +/// in place. Return an empty OpResult if the op is not bufferizable. 
+OpResult mlir::linalg::comprehensive_bufferize::getAliasingOpResult( + OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.getAliasingOpResult(opOperand); + return OpResult(); +} + +/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the +/// op is not bufferizable. +bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead( + OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryRead(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; +} + +/// Return true if `opOperand` bufferizes to a memory write. Return +/// `true` if the op is not bufferizable. +bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite( + OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryWrite(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; +} + +/// Return true if `opOperand` does neither read nor write but bufferizes to an +/// alias. Return false if the op is not bufferizable. +bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly( + OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToAliasOnly(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return false. + return false; +} + +/// Return true if the given value is read by an op that bufferizes to a memory +/// read. Also takes into account ops that create an alias but do not read by +/// themselves (e.g., ExtractSliceOp). 
+bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) { + SmallVector workingSet; + for (OpOperand &use : value.getUses()) + workingSet.push_back(&use); + + while (!workingSet.empty()) { + OpOperand *uMaybeReading = workingSet.pop_back_val(); + // Skip over all ops that neither read nor write (but create an alias). + if (bufferizesToAliasOnly(*uMaybeReading)) + for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) + workingSet.push_back(&use); + if (bufferizesToMemoryRead(*uMaybeReading)) + return true; + } + + return false; +} + +/// Return the relationship between the operand and its corresponding +/// OpResult that it may alias with. Return None if the op is not bufferizable. +BufferRelation +mlir::linalg::comprehensive_bufferize::bufferRelation(OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferRelation(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return None. + return BufferRelation::None; +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -132,6 +132,7 @@ using namespace mlir; using namespace linalg; using namespace tensor; +using namespace comprehensive_bufferize; #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) @@ -404,97 +405,6 @@ return result; } -//===----------------------------------------------------------------------===// -// Helper functions for BufferizableOpInterface -//===----------------------------------------------------------------------===// - -/// Determine which OpOperand* will alias with `result` if the op is bufferized -/// in place. 
Return an empty vector if the op is not bufferizable. -static SmallVector getAliasingOpOperand(OpResult result) { - if (Operation *op = result.getDefiningOp()) - if (auto bufferizableOp = dyn_cast(op)) - return bufferizableOp.getAliasingOpOperand(result); - return {}; -} - -/// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. Return an empty OpResult if the op is not bufferizable. -static OpResult getAliasingOpResult(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.getAliasingOpResult(opOperand); - return OpResult(); -} - -/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the -/// op is not bufferizable. -static bool bufferizesToMemoryRead(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryRead(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - -/// Return true if `opOperand` bufferizes to a memory write. Return -/// `true` if the op is not bufferizable. -static bool bufferizesToMemoryWrite(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryWrite(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - -/// Return true if `opOperand` does neither read nor write but bufferizes to an -/// alias. Return false if the op is not bufferizable. -static bool bufferizesToAliasOnly(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToAliasOnly(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return false. 
- return false; -} - -/// Return true if the given value is read by an op that bufferizes to a memory -/// read. Also takes into account ops that create an alias but do not read by -/// themselves (e.g., ExtractSliceOp). -static bool isValueRead(Value value) { - SmallVector workingSet; - for (OpOperand &use : value.getUses()) - workingSet.push_back(&use); - - while (!workingSet.empty()) { - OpOperand *uMaybeReading = workingSet.pop_back_val(); - // Skip over all ops that neither read nor write (but create an alias). - if (bufferizesToAliasOnly(*uMaybeReading)) - for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) - workingSet.push_back(&use); - if (bufferizesToMemoryRead(*uMaybeReading)) - return true; - } - - return false; -} - -/// Return the relationship between the operand and the its corresponding -/// OpResult that it may alias with. Return None if the op is not bufferizable. -static BufferRelation bufferRelation(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferRelation(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return None. - return BufferRelation::None; -} - //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -1623,10 +1533,9 @@ /// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in /// reverse and bufferize ops greedily. This is a good starter heuristic. /// ExtractSliceOps are interleaved with other ops in traversal order. 
-LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo, - unsigned analysisFuzzerSeed) { +LogicalResult mlir::linalg::comprehensive_bufferize::inPlaceAnalysis( + SmallVector &ops, BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo, unsigned analysisFuzzerSeed) { if (analysisFuzzerSeed) { // This is a fuzzer. For testing purposes only. Randomize the order in which // operations are analyzed. The bufferization quality is likely worse, but @@ -1685,25 +1594,27 @@ // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -Optional -mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { +Optional mlir::linalg::comprehensive_bufferize::defaultAllocationFn( + OpBuilder &b, Location loc, MemRefType type, + const SmallVector &dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; } -void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc, - Value allocatedBuffer) { +void mlir::linalg::comprehensive_bufferize::defaultDeallocationFn( + OpBuilder &b, Location loc, Value allocatedBuffer) { b.create(loc, allocatedBuffer); } -void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from, - Value to) { +void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b, + Location loc, + Value from, + Value to) { b.create(loc, from, to); } -LogicalResult mlir::linalg::bufferizeOp( +LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns, DenseMap *bufferizedFunctionTypes) { @@ -2119,7 +2030,7 @@ /// OpOperand. 
"Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. -LogicalResult mlir::linalg::initTensorElimination( +LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination( FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo, std::function anchorMatchFunc, std::function rewriteFunc, @@ -2214,8 +2125,10 @@ /// /// Note that the newly inserted ExtractSliceOp may have to bufferize /// out-of-place due to RaW conflicts. -LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) { +LogicalResult mlir::linalg::comprehensive_bufferize:: + eliminateInsertSliceAnchoredInitTensorOps(FuncOp funcOp, + BufferizationAliasInfo &aliasInfo, + DominanceInfo &domInfo) { return initTensorElimination( funcOp, aliasInfo, domInfo, [](OpOperand &operand) { @@ -2256,9 +2169,8 @@ } #endif -LogicalResult -mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp, - const BufferizationOptions &options) { +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + ModuleOp moduleOp, const BufferizationOptions &options) { SmallVector orderedFuncOps; DenseMap> callerMap; DenseMap bufferizedFunctionTypes; @@ -2356,6 +2268,7 @@ namespace mlir { namespace linalg { +namespace comprehensive_bufferize { namespace arith_ext { struct ConstantOpInterface @@ -3585,5 +3498,6 @@ >::registerOpInterface(registry); } +} // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -16,6 +16,7 @@ using namespace mlir; 
using namespace mlir::linalg; +using namespace mlir::linalg::comprehensive_bufferize; namespace { struct LinalgComprehensiveModuleBufferize