diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -515,14 +515,6 @@ return op->getResult(0); } -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(tensor::CastOp op, - OpOperand &opOperand) { - return op->getResult(0); -} - /// Return the OpResult that may bufferize into the same buffer as `opOperand` /// when the op is bufferized inplace. /// The inplace analysis uses this information along with interfering read @@ -533,8 +525,7 @@ // clang-format off // Ops that perform destructive updates on operand(s) to produce // result(s). - .Case( + .Case( [] (auto op) { return OpResult(); }) // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't @@ -562,9 +554,8 @@ if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp())) return SmallVector(); TypeSwitch(result.getDefiningOp()) - .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); }) - .Case( - [&](auto op) { r.push_back(&op->getOpOperand(0)); }) + .Case([&](auto op) { r.push_back(&op->getOpOperand(0)); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. // TODO: update this. @@ -619,6 +610,14 @@ return OpResult(); } +/// If a tensor::CastOp is bufferized in-place, the source operand will +/// alias with the result.
+static OpResult getAliasingOpResult(tensor::CastOp op, OpOperand &opOperand) { + if (&op->getOpOperand(0) == &opOperand) + return op->getResult(0); + return OpResult(); +} + /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. This is a superset of `getInplaceableOpResult`. /// TODO: in the future this may need to evolve towards a list of OpResult. @@ -627,7 +626,8 @@ // Some ops are different: Their result is not inplaceable on an OpOperand // but when bufferized inplace, their result is aliasing (a subregion of) // an OpOperand. - .Case( + .Case( [&](auto op) { return getAliasingOpResult(op, opOperand); }) // All other ops, return the result of `getInplaceableOpResult`. .Default( @@ -648,8 +648,8 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that create an alias but do not read. - if (isa( - uMaybeReading->getOwner())) + if (isa(uMaybeReading->getOwner())) for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) @@ -667,8 +667,8 @@ return true; // Some ops alone do not bufferize to a memory read, but one of their uses // may. - if (isa( - opOperand.getOwner())) + if (isa(opOperand.getOwner())) return false; // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. @@ -700,8 +700,8 @@ return false; // Some ops alone do not bufferize to a memory write, but one of their uses // may. - if (isa( - opOperand.getOwner())) + if (isa(opOperand.getOwner())) return false; // CallOpInterface alone doesn't bufferize to a memory write, one of the uses // of the matching bbArg may. It is the responsibility of the caller to @@ -2357,7 +2357,8 @@ // Special logic to analyze ops who's OpResults are not inplaceable on an // OpOperand but may create an alias. 
- if (isa(op)) + if (isa(op)) if (failed( bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo))) return failure();