diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -626,6 +626,18 @@ [&](Operation *op) { return getInplaceableOpResult(opOperand); }); } +/// Return `true` if the given OpOperand does not bufferize to a memory read or +/// write, but creates an alias when bufferized inplace. +static bool bufferizesToAliasOnly(OpOperand &opOperand) { + Operation *owner = opOperand.getOwner(); + // TODO: In the future this may need to evolve into a TypeSwitch. For all + // currently supported ops, the aliasing-only OpOperand is always the first + // one. + return isa(owner) && + &opOperand == &owner->getOpOperand(0); +} + // Predeclaration of function. static bool bufferizesToMemoryRead(OpOperand &opOperand); @@ -640,8 +652,8 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that create an alias but do not read. - if (isa(uMaybeReading->getOwner())) - for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses()) + if (bufferizesToAliasOnly(*uMaybeReading)) + for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; @@ -658,7 +670,7 @@ return true; // Some ops alone do not bufferize to a memory read, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (bufferizesToAliasOnly(opOperand)) return false; // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. @@ -690,7 +702,7 @@ return false; // Some ops alone do not bufferize to a memory write, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (bufferizesToAliasOnly(opOperand)) return false; // CallOpInterface alone doesn't bufferize to a memory write, one of the uses // of the matching bbArg may. It is the responsibility of the caller to @@ -2318,9 +2330,8 @@ return success(); } -/// This analysis function is used for ops where the first OpOperand aliases -/// with the first OpResult, without creating a read or write. There are a few -/// ops besides ExtractSliceOp that have such semantics. +/// This analysis function is used for OpOperands that alias with an OpResult +/// but are not inplaceable on it. E.g., ExtractSliceOp. /// /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: /// @@ -2335,11 +2346,12 @@ /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. static LogicalResult -bufferizableInPlaceAnalysisAliasOnlyOp(Operation *op, +bufferizableInPlaceAnalysisAliasOnlyOp(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - return bufferizableInPlaceAnalysisImpl( - op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo); + OpResult result = getAliasingOpResult(operand); + assert(result && "expected that the OpOperand has an aliasing OpResult"); + return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); } /// Determine if `operand` can be bufferized in-place with one of the op's @@ -2372,16 +2384,17 @@ // Walk ops in reverse for better interference analysis. for (Operation *op : reverse(ops)) { - for (OpOperand &opOperand : op->getOpOperands()) + for (OpOperand &opOperand : op->getOpOperands()) { if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze ops who's OpResults are not inplaceable on an - // OpOperand but may create an alias. - if (isa(op)) - if (failed( - bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo))) - return failure(); + // Special logic to analyze OpOperands that are not inplaceable on an + // OpResult but may create an alias. + if (bufferizesToAliasOnly(opOperand)) + if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(opOperand, aliasInfo, + domInfo))) + return failure(); + } } return success(); @@ -3049,8 +3062,8 @@ aliasInfo.createAliasInfoEntry(extractOp.result()); // Run analysis on the ExtractSliceOp. - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(extractOp, aliasInfo, - domInfo))) + if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( + extractOp->getOpOperand(0), aliasInfo, domInfo))) return WalkResult::interrupt(); // Advance to the next operation.