diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -0,0 +1,239 @@ +//===- ComprehensiveBufferize.h - Linalg bufferization pass -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H +#define MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SetOperations.h" + +namespace mlir { + +class DominanceInfo; +class FuncOp; +class GlobalCreator; + +namespace linalg { + +/// The BufferizationAliasInfo class maintains a list of buffer aliases and +/// equivalence classes to support bufferization. +/// ExtractSliceOps have special behavior, they act as a level of indirection +/// for bufferization. They don't create reads or writes themselves and analysis +/// needs to look through their uses. +/// ExtractSliceOp + InsertSliceOp have special joint behavior: they may +/// bufferize to the same buffer (i.e. subview), which is what introduces the +/// need for bufferization classes. +/// Some of these functionalities could be refactored in a Bufferizer class that +/// uses BufferizationAliasInfo. +class BufferizationAliasInfo { +public: + /// Specify fine-grain relationship between buffers to enable more analysis. 
+ enum class BufferRelation { + None, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent + }; + + explicit BufferizationAliasInfo(Operation *rootOp); + + /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the + /// beginning the alias and equivalence sets only contain `v` itself. + void createAliasInfoEntry(Value v); + + /// Insert an info entry for `newValue` and merge its alias set with that of + /// `alias`. + void insertNewBufferAlias(Value newValue, Value alias); + + /// Insert an info entry for `newValue` and merge its alias set with that of + /// `alias`. Additionally, merge their equivalence classes. + void insertNewBufferEquivalence(Value newValue, Value alias); + + /// Return true if the buffer to which `operand` would bufferize aliases a + /// buffer that is known to not be writeable. This implies that the matching + /// OpResult cannot be bufferized inplace. + bool aliasesNonWriteableBuffer(OpOperand &operand) const; + + /// Return true if the buffer to which `v` would bufferize is equivalent + /// to some buffer write. + bool aliasesInPlaceWrite(Value v) const; + + /// Set the inPlace bufferization spec to true. + /// Merge result's and operand's aliasing sets and iterate to a fixed point. + void bufferizeInPlace(OpResult result, OpOperand &operand, + BufferRelation bufferRelation = BufferRelation::None); + + /// Set the inPlace bufferization spec to false. + void bufferizeOutOfPlace(OpResult result); + + /// Return true if it is possible to find an inplace write W among `usesWrite` + /// and a read R among `usesRead`, such that W and R interfere. + /// Such a (W, R) pair is an interference to the inplace bufferization of + /// opResult when: + /// 1. R is not known to properly dominate W (i.e. the effects of the write + /// may be visible from R). + /// 2. one cannot find an intermediate clobbering write `C` to W, such that + /// C is interleaved between W and R (i.e.
W -> C -> R where -> denotes + /// dominance). + bool wouldCreateReadAfterWriteInterference( + Operation *opToBufferize, DenseSet &usesRead, + DenseSet &usesWrite, const DominanceInfo &domInfo) const; + + /// Assume that result bufferizes in-place with one of the operation's + /// operands. Return true if it is possible to find an inplace write W (resp. + /// a read R) among the uses of `aliasInfo[result]`, and a read R (resp. an + /// inplace write W) among the uses of + /// `aliasInfo[getAliasingOpOperand(result)]`, such that W and R interfere. + /// Interference detection is needed to determine which cases may bufferize + /// inplace without interferences. Such cases comprise: + /// + /// ``` + /// %0 = op_to_bufferize(%1) + /// read(%1) + /// + /// %0 = op_to_bufferize(%1) + /// write(%0) + /// read(%1) + /// + /// %0 = op_to_bufferize(%1) + /// write(%1) + /// read(%0) + /// ``` + bool + wouldCreateReadAfterWriteInterference(OpResult result, + const DominanceInfo &domInfo) const; + + /// Return true if `v1` and `v2` bufferize to equivalent buffers. + bool areEquivalentBufferizedValues(Value v1, Value v2) const { + return equivalentInfo.getLeaderValue(v1) == + equivalentInfo.getLeaderValue(v2); + } + + /// Return true if the source of an `insertSliceOp` bufferizes to an + /// equivalent ExtractSliceOp. + bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( + tensor::InsertSliceOp insertSliceOp) const; + + /// Apply `fun` to all the members of the equivalence class of `v`. + void applyOnEquivalenceClass(Value v, function_ref fun) const; + + /// Print to `os`. + void printAliases(raw_ostream &os) const; + void printEquivalences(raw_ostream &os) const; + + /// Print to `errs()`. + void dumpAliases() const; + void dumpEquivalences() const; + +private: + /// llvm::EquivalenceClasses wants comparable elements because it uses + /// std::set as the underlying impl. + /// ValueWrapper wraps Value and uses pointer comparison on `getImpl()`.
+ /// This is a poor man's comparison but it's not like UnionFind needs ordering + /// anyway .. + struct ValueWrapper { + ValueWrapper(Value val) : v(val) {} + operator Value() const { return v; } + bool operator<(const ValueWrapper &wrap) const { + return v.getImpl() < wrap.v.getImpl(); + } + bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } + Value v; + }; + + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; + /// Check that aliasInfo for `v` exists and return the range of its aliases. + EquivalenceClassRangeType getAliases(Value v) const; + + /// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e. + /// equivalent operand / result and same offset/sizes/strides specification). + /// + /// This is one particular type of relationship between ops on tensors that + /// reduce to an equivalence on buffers. This should be generalized and + /// exposed as interfaces on the proper types. + bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st, + tensor::InsertSliceOp sti) const; + + /// Return true if there is a `candidateOp` that would write to memory after + /// bufferization and such that: + /// 1. The written buffer is equivalent to either `aliasingRead` or + /// `aliasingWrite` under the inPlace bufferization decisions taken + /// so far. + /// 2. `aliasingWrite` properly dominates `candidateOp`. + /// 3. `candidateOp` properly dominates `aliasingReadOp`. + // TODO: richer clobbering analysis with container-containee relationship + // instead of equivalence. + bool existsInterleavedValueClobber(OpOperand &aliasingRead, + OpOperand &aliasingWrite, + const DominanceInfo &domInfo) const; + + /// Return true if there is a write that: + /// 1. Properly dominates aliasingReadOp. + /// 2. Is properly dominated by aliasingWriteOp. + /// 3. Clobbers the write that would be interfering with the read.
+ /// + /// Case discussion: + /// ================ + /// Case 1: opOperand is produced by opToBufferize, + /// Case 2: opResult is produced by opToBufferize, + /// Common case: + /// - aliasingReadOp is a read to an alias of opOperand. + /// - aliasingWriteOp is an inplace write to an alias of opResult. + /// - aliasingWriteOp dominates aliasingReadOp. + /// + /// ``` + /// // Either case 1: + /// %opOperand = opToBufferize(%opResult) + /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace + /// aliasingReadOp( %aliasingRead = alias(%opOperand)) + /// ``` + /// + /// ``` + /// // Or case 2: + /// %opResult = opToBufferize(%opOperand) + /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace + /// aliasingReadOp( %aliasingRead = alias(%opOperand)) + /// ``` + /// + /// Capture possible cases where `aliasingWriteOp(alias(%opResult))` has no + /// visible effect on `aliasingReadOp(alias(%opOperand))`. + bool isClobberedWriteBeforeRead(Operation *opToBufferize, + OpOperand &aliasingRead, + OpOperand &aliasingWrite, + const DominanceInfo &domInfo) const; + + /// Auxiliary structure to store all the values a given value aliases with. + /// These are the conservative cases that can further decompose into + /// "equivalent" buffer relationships. + llvm::EquivalenceClasses aliasInfo; + + /// Auxiliary structure to store all the equivalent buffer classes. + llvm::EquivalenceClasses equivalentInfo; +}; + +/// Analyze the `ops` to determine which OpResults are inplaceable. +LogicalResult inPlaceAnalysis(SmallVector &ops, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo); + +/// Bufferize one particular op. +/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be +/// non-null if `op` is a CallOpInterface (resp. a tensor ConstantOp).
+LogicalResult +bufferizeOp(Operation *op, BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + DenseMap *bufferizedFunctionTypes = nullptr, + GlobalCreator *globalCreator = nullptr); + +} // namespace linalg +} // namespace mlir + +#endif // define MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -105,6 +105,8 @@ // expected layouts after transformations. Combinations of memref.cast + // canonicalization are responsible for clean ups. +#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h" + #include "PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -123,9 +125,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" @@ -219,7 +219,7 @@ } else { parentOp = key.getDefiningOp()->getParentOfType(); } - LDBG("In func:\n" << *parentOp << "NO VALUE FOR KEY: " << key << '\n'); + LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n'); (void)parentOp; return Value(); } @@ -690,205 +690,6 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// -namespace { - -/// The BufferizationAliasInfo class maintains a list of buffer aliases and -/// equivalence classes to support bufferization. -/// ExtractSliceOps have special behavior, they act as a level of indirection -/// for bufferization. They don't create reads or writes themselves and analysis -/// needs to look through their uses. 
-/// ExtractSliceOp + InsertSliceOp have special joint behavior: they may -/// bufferize to the same buffer (i.e. subview), which is what introduces the -/// need for bufferization classes. -/// Some of these functionalities could be refactored in a Bufferizer class that -/// uses BufferizationAliasInfo. -class BufferizationAliasInfo { -public: - /// Specify fine-grain relationship between buffers to enable more analysis. - enum class BufferRelation { - None, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent - }; - - explicit BufferizationAliasInfo(Operation *rootOp); - - /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the - /// beginning the alias and equivalence sets only contain `v` itself. - void createAliasInfoEntry(Value v); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. - void insertNewBufferAlias(Value newValue, Value alias); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. Additionally, merge their equivalence classes. - void insertNewBufferEquivalence(Value newValue, Value alias); - - /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writeable. This implies that the matching - /// OpResult cannot be bufferized inplace. - bool aliasesNonWriteableBuffer(OpOperand &operand) const; - - /// Return true if the buffer to which `operand` would bufferize is equivalent - /// to some buffer write. - bool aliasesInPlaceWrite(Value v) const; - - /// Set the inPlace bufferization spec to true. - /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); - - /// Set the inPlace bufferization spec to false. 
- void bufferizeOutOfPlace(OpResult result); - - /// Return true if it is possible to find an inplace write W among `usesWrite` - /// and a read R among `usesRead`, such that W and R interfere. - /// Such a (W, R) pair is an interference to the inplace bufferization of - /// opResult when: - /// 1. R is not known properly dominate W (i.e. the effects of the write may - /// be visible from R). - /// 2. one cannot find an intermediate clobbering write `C` to W, such that - /// C interleaved between W and R (i.e. W -> C -> R where -> denotes - /// dominance). - bool wouldCreateReadAfterWriteInterference( - Operation *opToBufferize, DenseSet &usesRead, - DenseSet &usesWrite, const DominanceInfo &domInfo) const; - - /// Assume that result bufferizes in-place with one of the operation's - /// operands. Return true if it is possible to find an inplace write W (resp. - /// a read R) among the uses of `aliasInfo[result]`, and a read R (resp. an - /// inplace write W) among the uses of - /// `aliasInfo[getAliasingOpOperand(result)]`, such that W and R interfere. - /// Interference detection is needed to determine which cases may bufferize - /// inplace without interferences. Such cases comprise: - /// - /// ``` - /// %0 = op_to_bufferize(%1) - /// read(%1) - /// - /// %0 = op_to_bufferize(%1) - /// write(%0) - /// read(%1) - /// - /// %0 = op_to_bufferize(%1) - /// write(%1) - /// read(%0) - /// ``` - bool - wouldCreateReadAfterWriteInterference(OpResult result, - const DominanceInfo &domInfo) const; - - /// Return true if `v1` and `v2` bufferize to equivalent buffers. - bool areEquivalentBufferizedValues(Value v1, Value v2) const { - return equivalentInfo.getLeaderValue(v1) == - equivalentInfo.getLeaderValue(v2); - } - - /// Return true if the source of an `insertSliceOp` bufferizes to an - /// equivalent ExtractSliceOp. 
- bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( - InsertSliceOp insertSliceOp) const; - - /// Apply `fun` to all the members of the equivalence class of `v`. - void applyOnEquivalenceClass(Value v, function_ref fun) const; - - /// Print to `os`. - void printAliases(raw_ostream &os) const; - void printEquivalences(raw_ostream &os) const; - - /// Print to `errs()`. - void dumpAliases() const { printAliases(llvm::errs()); } - void dumpEquivalences() const { printEquivalences(llvm::errs()); } - -private: - /// llvm::EquivalenceClasses wants comparable elements because it uses - /// std::set as the underlying impl. - /// ValueWrapper wraps Value and uses pointer comparison on the defining op. - /// This is a poor man's comparison but it's not like UnionFind needs ordering - /// anyway .. - struct ValueWrapper { - ValueWrapper(Value val) : v(val) {} - operator Value() const { return v; } - bool operator<(const ValueWrapper &wrap) const { - return v.getImpl() < wrap.v.getImpl(); - } - bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } - Value v; - }; - - using EquivalenceClassRangeType = llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; - /// Check that aliasInfo for `v` exists and return a reference to it. - EquivalenceClassRangeType getAliases(Value v) const; - - /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. - /// equivalent operand / result and same offset/sizes/strides specification). - /// - /// This is one particular type of relationship between ops on tensors that - /// reduce to an equivalence on buffers. This should be generalized and - /// exposed as interfaces on the proper types. - bool areEquivalentExtractSliceOps(ExtractSliceOp st, InsertSliceOp sti) const; - - /// Return true if there is a `candidateOp` that would write to memory after - /// bufferization and such that: - /// 1. 
The written buffer is equivalent to either `aliasingRead` or - /// `aliasingWrite` under the inPlace bufferization decisions taken - /// so far. - /// 2. `aliasingWrite` properly dominates `candidateOp`. - /// 3. `candidateOp` properly dominates `aliasingReadOp`. - // TODO: richer clobbering analysis with container-containee relationship - // instead of equivalence. - bool existsInterleavedValueClobber(OpOperand &aliasingRead, - OpOperand &aliasingWrite, - const DominanceInfo &domInfo) const; - - /// Return true if there is a write that: - /// 1. Properly dominates aliasingReadOp. - /// 2. Is properly dominated by aliasingWriteOp. - /// 3. Clobbers the write that would be interfering with the read. - /// - /// Case discussion: - /// ================ - /// Case 1: opOperand is produced by opToBufferize, - /// Case 2: opResult is produced by opToBufferize, - /// Common case: - /// - aliasingReadOp is a read to an alias of opOperand. - /// - aliasingWriteOp is an inplace write to an alias of opResult. - /// - aliasingWriteOp dominates aliasingReadOp. - /// - /// ``` - /// // Either case 1: - /// %opOperand = opToBufferize(%opResult) - /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace - /// aliasingReadOp( %aliasingRead = alias(%opOperand)) - /// ``` - /// - /// ``` - /// // Or case 2: - /// %opResult = opToBufferize(%opOperand) - /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace - /// aliasingReadOp( %aliasingRead = alias(%opOperand)) - /// ``` - /// - /// Capture possible cases where `aliasingWriteOp(alias(%opResult))` has no - /// visible effect on `aliasingReadOp(alias(%opOperand))`. - bool isClobberedWriteBeforeRead(Operation *opToBufferize, - OpOperand &aliasingRead, - OpOperand &aliasingWrite, - const DominanceInfo &domInfo) const; - - /// Auxiliary structure to store all the values a given value aliases with. - /// These are the conservative cases that can further decompose into - /// "equivalent" buffer relationships. 
- llvm::EquivalenceClasses aliasInfo; - - /// Auxiliary structure to store all the equivalent buffer classes. - llvm::EquivalenceClasses equivalentInfo; -}; -} // namespace - BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) @@ -1232,6 +1033,12 @@ aliasInfo.member_begin(it), aliasInfo.member_end()); } +void BufferizationAliasInfo::dumpAliases() const { printAliases(llvm::errs()); } + +void BufferizationAliasInfo::dumpEquivalences() const { + printEquivalences(llvm::errs()); +} + /// This is one particular type of relationship between ops on tensors that /// reduce to an equivalence on buffers. This should be generalized and exposed /// as interfaces on the proper types. @@ -2453,30 +2260,15 @@ return success(); } -/// Analyze the `funcOp` body to determine which OpResults are inplaceable: +/// Analyze the `ops` to determine which OpResults are inplaceable: /// 1. First, analyze InsertSliceOp greedily: we almost never want to /// bufferize the tensor "inserted into" to become out-of-place. /// 2. Walk the other ops in reverse. This is a good starter heuristic. /// ExtractSliceOps are interleaved with other ops in traversal order. /// -static LogicalResult -inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - LLVM_DEBUG(llvm::dbgs() << "\n\n"); - LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); - assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() && - "expected a funcOp definition with a body"); - - // Collect ops so we can build our own reverse traversal. - SmallVector ops; - funcOp.walk([&](Operation *op) { - // No tensors => no buffers. 
- if (none_of(op->getOperandTypes(), isaTensor) && - none_of(op->getResultTypes(), isaTensor)) - return; - ops.push_back(op); - }); - +LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector &ops, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { // Walk ops in reverse for better interference analysis. for (Operation *op : reverse(ops)) { for (OpOperand &opOperand : op->getOpOperands()) { @@ -2498,63 +2290,68 @@ continue; } } + return success(); +} + +/// Analyze the `funcOp` body to determine which OpResults are inplaceable. +static LogicalResult +inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { + LLVM_DEBUG(llvm::dbgs() << "\n\n"); + LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); + assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() && + "expected a funcOp definition with a body"); + + // Collect ops so we can build our own reverse traversal. + SmallVector ops; + funcOp.walk([&](Operation *op) { + // No tensors => no buffers. + if (none_of(op->getOperandTypes(), isaTensor) && + none_of(op->getResultTypes(), isaTensor)) + return; + ops.push_back(op); + }); + LogicalResult res = inPlaceAnalysis(ops, aliasInfo, domInfo); LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); - return success(); + return res; } //===----------------------------------------------------------------------===// // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -static LogicalResult bufferizeFuncOpInternals( - FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - DenseMap &bufferizedFunctionTypes, - GlobalCreator &globalCreator) { - LLVM_DEBUG(llvm::dbgs() << "\n\n"); - LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); - OpBuilder b(funcOp->getContext()); - /// Start by bufferizing `funcOp` arguments. 
- if (failed(bufferize(b, funcOp, bvm, aliasInfo))) - return failure(); - // Walk in PreOrder to ensure ops with regions are handled before their body. - // Since walk has to be PreOrder, we need to erase ops that require it - // separately: this is the case for CallOp - // clang-format off - SmallVector toErase; - WalkResult result = funcOp.walk([&](Operation *op) - -> WalkResult { - WalkResult result = - TypeSwitch(op) +LogicalResult mlir::linalg::bufferizeOp( + Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, + DenseMap *bufferizedFunctionTypes, + GlobalCreator *globalCreator) { + OpBuilder b(op->getContext()); + return TypeSwitch(op) // Skip BufferCast and TensorLoad ops. - .Case([&](auto) { return success(); }) - .Case( + [&](auto) { return success(); }) + .Case([&](auto op) { LDBG("Begin bufferize:\n" << op << '\n'); return bufferize(b, op, bvm, aliasInfo); }) .Case([&](CallOpInterface op) { LDBG("Begin bufferize:\n" << op << '\n'); - return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes); + if (!bufferizedFunctionTypes) + llvm_unreachable( + "null bufferizedFunctionTypes when bufferizing CallOpInterface"); + return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes); }) .Case([&](ConstantOp op) { if (!isaTensor(op.getResult().getType())) return success(); LDBG("Begin bufferize:\n" << op << '\n'); - return bufferize(b, op, bvm, aliasInfo, globalCreator); + if (!globalCreator) + llvm_unreachable("null globalCreator when bufferizing ConstantOp"); + return bufferize(b, op, bvm, aliasInfo, *globalCreator); }) .Default([&](Operation *op) -> LogicalResult { auto isaTensor = [](Type t) { return t.isa(); }; @@ -2563,22 +2360,45 @@ return op->emitError() << "unsupported op with tensors"; return success(); }); +} - // Register post-walk erasure, if necessary. 
- if (isa(op)) - if (llvm::any_of(op->getOperandTypes(), isaTensor) || - llvm::any_of(op->getResultTypes(), isaTensor)) - toErase.push_back(op); +static LogicalResult bufferizeFuncOpInternals( + FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, + DenseMap &bufferizedFunctionTypes, + GlobalCreator &globalCreator) { + + LLVM_DEBUG(llvm::dbgs() << "\n\n"); + LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); + OpBuilder b(funcOp->getContext()); + /// Start by bufferizing `funcOp` arguments. + if (failed(bufferize(b, funcOp, bvm, aliasInfo))) + return failure(); + + // Walk in PreOrder to ensure ops with regions are handled before their body. + // Since walk has to be PreOrder, we need to erase ops that require it + // separately: this is the case for CallOp + SmallVector toErase; + if (funcOp + .walk([&](Operation *op) -> WalkResult { + if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes, + &globalCreator))) + return failure(); + // Register post-walk erasure, if necessary. 
+ if (isa(op)) + if (llvm::any_of(op->getOperandTypes(), isaTensor) || + llvm::any_of(op->getResultTypes(), isaTensor)) + toErase.push_back(op); + return success(); + }) + .wasInterrupted()) + return failure(); - return result; - }); - // clang-format on LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n'); for (Operation *op : toErase) op->erase(); - return failure(result.wasInterrupted()); + return success(); } //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6153,6 +6153,7 @@ "include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h", "include/mlir/Dialect/Linalg/Passes.h", "include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h", + "include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h", "include/mlir/Dialect/Linalg/Transforms/HoistPadding.h", "include/mlir/Dialect/Linalg/Transforms/Hoisting.h", "include/mlir/Dialect/Linalg/Transforms/Transforms.h",