diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -138,6 +138,10 @@ bufferized in-place. This method will never be called on OpResults that do not have a tensor type. + By default, this method is the inverse of `getAliasingOpResult`. Ops + with a region that yield values may want to override this method to + return the OpOperands that are yielded by the terminator. + Note: This method can return multiple OpOperands, indicating that the given OpResult may at runtime alias with any of the OpOperands. This is useful for branches and for ops such as `std.select`. @@ -147,8 +151,18 @@ /*args=*/(ins "OpResult":$opResult), /*methodBody=*/"", /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpResults. - llvm_unreachable("getInplaceableOpResult not implemented"); + assert(opResult.getType().isa() && + "expected OpResult with tensor type"); + SmallVector result; + auto bufferizableOp = + cast($_op.getOperation()); + for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) { + if (!opOperand.get().getType().isa()) + continue; + if (bufferizableOp.getAliasingOpResult(opOperand) == opResult) + result.push_back(&opOperand); + } + return result; }] >, InterfaceMethod< diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -23,11 +23,6 @@ struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {}; - } - LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto constantOp = cast(op); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -41,11 +41,6 @@ return true; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -69,11 +64,6 @@ struct ToTensorOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {}; - } - LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto tensorLoadOp = cast(op); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -145,11 +145,6 @@ struct InitTensorOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {}; - } - bool isMemoryWrite(Operation *op, OpResult opResult) const { // InitTensorOps allocate but do not write. return false; @@ -191,15 +186,6 @@ return static_cast(bufferizableOp.getAliasingOpResult(opOperand)); } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - // TODO: TiledLoopOp helper method to avoid leaking impl details. - auto tiledLoopOp = cast(op); - return {&op->getOpOperand(tiledLoopOp.getNumControlOperands() + - tiledLoopOp.getNumInputs() + - opResult.getResultNumber())}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto tiledLoopOp = cast(op); return tiledLoopOp.getTiedOpResult(opOperand); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -447,12 +447,6 @@ return true; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - // TODO: Can we do better? - return {}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -167,12 +167,6 @@ return true; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - auto forOp = cast(op); - return {&forOp.getIterOpOperands()[opResult.getResultNumber()]}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -48,11 +48,6 @@ return false; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(0)}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return op->getResult(0); } @@ -137,11 +132,6 @@ return false; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(0) /*source*/}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return &opOperand == &op->getOpOperand(0) /*source*/ ? op->getResult(0) @@ -335,11 +325,6 @@ return &opOperand == &op->getOpOperand(1) /*dest*/; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(1) /*dest*/}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return &opOperand == &op->getOpOperand(1) /*dest*/ ? op->getResult(0) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -68,11 +68,6 @@ return true; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(1)}; - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { assert(opOperand.get().getType().isa() && "only tensor types expected");