diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -67,8 +67,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); + void bufferizeInPlace(OpResult result, OpOperand &operand); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -136,6 +136,8 @@ using namespace linalg; using namespace tensor; +using BufferRelation = BufferizationAliasInfo::BufferRelation; + #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) @@ -421,8 +423,10 @@ // buffers in memory. // 3. Whether an op operand, when bufferized inplace, aliases a return value. // 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Wheher an op bufferizes to a memory read. -// 6. Wheher an op bufferizes to a memory write. +// 5. Whether an op bufferizes to a memory read. +// 6. Whether an op bufferizes to a memory write. +// 7. The buffer relationship between an operand and its corresponding result +// (in case of in-place bufferization). // These interfaces are necessary to distinguish between various cases and allow // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs.
//===----------------------------------------------------------------------===// @@ -682,6 +686,16 @@ getInPlace(opResult) == inPlaceSpec; } +/// Returns the relationship between the operand and its corresponding +/// OpResult that it may alias with. +static BufferRelation bufferRelation(OpOperand &operand) { + return TypeSwitch<Operation *, BufferRelation>(operand.getOwner()) + // ExtractSliceOp returns a subview of the original tensor. + .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) + // All other ops: Buffers are equivalent. + .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -787,13 +801,12 @@ /// Set the inPlace bufferization spec to true. void BufferizationAliasInfo::bufferizeInPlace(OpResult result, - OpOperand &operand, - BufferRelation bufferRelation) { + OpOperand &operand) { setInPlaceOpResult(result, InPlaceSpec::True); aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); - if (bufferRelation == BufferRelation::Equivalent) + if (bufferRelation(operand) == BufferRelation::Equivalent) equivalentInfo.unionSets(result, operand.get()); // Dump the updated equivalence analysis. LLVM_DEBUG(dumpEquivalences()); @@ -2293,8 +2306,7 @@ else // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support // more cases on a per-need basis. - aliasInfo.bufferizeInPlace( - result, operand, BufferizationAliasInfo::BufferRelation::Equivalent); + aliasInfo.bufferizeInPlace(result, operand); LDBG("Done inplace analysis for result #" << resultNumber << '\n');