diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td @@ -158,6 +158,21 @@ return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given OpResult can be written to in-place. This + is the case for most ops, but some ops such as ConstantOp may + bufferize to non-writable (read-only) memory locations. This method + will never be called on OpResults that do not have a tensor type. + }], + /*retType=*/"bool", + /*methodName=*/"isWritable", + /*args=*/(ins "OpResult":$opResult), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return true; + }] + > ]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -587,12 +587,12 @@ return true; } - if (Operation *op = v.getDefiningOp()) { - if (isa<arith::ConstantOp>(op) || - !dyn_cast<BufferizableOpInterface>(op)) { - LDBG("-----------notWritable op\n"); - return true; - } + auto bufferizableOp = dyn_cast<BufferizableOpInterface>(v.getDefiningOp()); + if (!bufferizableOp || !bufferizableOp.isWritable(v.cast<OpResult>())) { + // Unknown ops are treated conservatively: Assume that it is illegal to + // write to their OpResults in-place. + LDBG("-----------notWritable op\n"); + return true; } } LDBG("---->value is writable\n"); @@ -2382,6 +2382,11 @@ return success(); } + + bool isWritable(Operation *op, OpResult opResult) const { + // Memory locations returned by memref::GetGlobalOp may not be written to. + return false; + } }; } // namespace arith_ext