diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -583,39 +583,34 @@ }); } +/// If an ExtractSliceOp is bufferized in-place, its source operand will +/// alias with the result. Any other operand yields a null OpResult. +static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) { + if (op.source() == opOperand.get()) + return op->getResult(0); + return OpResult(); +} + /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. This is a superset of `getInplaceableOpResult`. -/// Return None if the owner of `opOperand` does not have known -/// bufferization aliasing behavior, which indicates that the op must allocate -/// all of its tensor results. /// TODO: in the future this may need to evolve towards a list of OpResult. -static Optional getAliasingOpResult(OpOperand &opOperand) { - if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) - return None; +static OpResult getAliasingOpResult(OpOperand &opOperand) { return TypeSwitch(opOperand.getOwner()) - // These terminators legitimately have no result. - .Case( - [&](auto op) { return OpResult(); }) - // DimOp has no tensor result. - .Case([&](auto op) { return None; }) - // ConstantOp is never inplaceable. - .Case([&](ConstantOp op) { return op->getResult(0); }) // ExtractSliceOp is different: its result is not inplaceable on op.source // but when bufferized inplace, the result is an aliasing subregion of // op.source. - .Case([&](ExtractSliceOp op) { return op->getResult(0); }) - // All other ops, including scf::ForOp, return the result of - // `getInplaceableOpResult`. + .Case( + [&](ExtractSliceOp op) { return getAliasingOpResult(op, opOperand); }) + // For all other ops, return the result of `getInplaceableOpResult`.
.Default( [&](Operation *op) { return getInplaceableOpResult(opOperand); }); } /// Return true if `opOperand` bufferizes to a memory read. static bool bufferizesToMemoryRead(OpOperand &opOperand) { - Optional maybeOpResult = getAliasingOpResult(opOperand); // Unknown op that returns a tensor. The inplace analysis does not support // it. Conservatively return true. - if (!maybeOpResult) + if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; // ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses // may. @@ -672,19 +667,19 @@ // conservative. if (auto callOp = dyn_cast(opOperand.getOwner())) return true; - Optional maybeOpResult = getAliasingOpResult(opOperand); // Unknown op that returns a tensor. The inplace analysis does not support // it. Conservatively return true. - if (!maybeOpResult) + if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; + OpResult opResult = getAliasingOpResult(opOperand); // Supported op without a matching result for opOperand (e.g. ReturnOp). // This does not bufferize to a write. - if (!*maybeOpResult) + if (!opResult) return false; // If we have a matching OpResult, this is a write. // Additionally allow to restrict to only inPlace write, if so specified. return inPlaceSpec == InPlaceSpec::None || - getInPlace(*maybeOpResult) == inPlaceSpec; + getInPlace(opResult) == inPlaceSpec; } //===----------------------------------------------------------------------===//