// Patch summary: the LogicalResult returned by the `bufferize` interface
// method becomes an error signal that the bufferization driver acts on,
// instead of a value the driver discards.
//
// Per-file changes visible in the hunks below:
//   * BufferizableOpInterface.td: extends the method description to state
//     that the return value reports errors, that the driver stops and returns
//     `failure` in that case, and that ops expected to survive bufferization
//     should return `success` (together with `allow-unknown-ops`).
//   * BufferizationOps.td (hunk at line 133): the body whose comment says
//     leftover to_memref(to_tensor(x)) constructs are folded later now
//     returns success() instead of failure().
//   * BufferizationOps.cpp: ToMemrefOp::bufferize explicitly discards the
//     result of foldToMemrefToTensorPair and always returns success(); the
//     added comment explains that the result means "error", not "matched".
//   * Bufferize.cpp: the driver no longer casts the bufferize() result to
//     void; on failure it returns op->emitError("failed to bufferize op").
//   * FuncBufferizableOpInterfaceImpl.cpp: the body commented "ReturnOps are
//     bufferized as part of FuncOps" now returns success() instead of
//     failure().
//   * one-shot-module-bufferize-invalid.mlir: the expected diagnostic
//     "op was not bufferized" is replaced by "failed to bufferize op". In the
//     first test hunk the new message sits on the `@+1` line and the
//     "cannot bufferize bodiless function" message moves to the `@+2` line.
//
// NOTE(review): this copy of the patch has its line breaks collapsed, and
// angle-bracketed text appears to be missing (e.g. bare `Optional`, bare
// `tensor`). It presumably will not apply as-is -- confirm against the
// original commit before using it.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -216,6 +216,13 @@ This method will never be called on ops that do not have at least one tensor operand/result. + + The return value of this method indicates whether there was an error + while bufferizing this op (such as failing to create a new buffer + allocation op). The bufferization driver immediately stops bufferizing + the input IR and returns `failure` in that case. If this op is + expected to survive bufferization, `success` should be returned + (together with `allow-unknown-ops` enabled). }], /*retType=*/"LogicalResult", /*methodName=*/"bufferize", diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -133,7 +133,7 @@ // DCE away. In case of partial bufferization, to_memref(to_tensor(x)) // constructs may be left over. These are folded by the canonicalizer or // FinalizingBufferize. - return failure(); + return success(); } bool isWritable(Value value, const AnalysisState &state) const { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -354,7 +354,10 @@ LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, BufferizationState &state) { // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. 
- return foldToMemrefToTensorPair(rewriter, *this); + (void)foldToMemrefToTensorPair(rewriter, *this); + // Note: The return value of `bufferize` indicates whether there was an error + // or not. (And not whether the pattern matched or not.) + return success(); } Optional CloneOp::buildDealloc(OpBuilder &builder, Value alloc) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -416,7 +416,8 @@ continue; // Bufferize the op. rewriter.setInsertionPoint(op); - (void)bufferizableOp.bufferize(rewriter, bufferizationState); + if (failed(bufferizableOp.bufferize(rewriter, bufferizationState))) + return op->emitError("failed to bufferize op"); } // Fold all to_memref(to_tensor(x)) pairs. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -389,7 +389,7 @@ #endif // NDEBUG // ReturnOps are bufferized as part of FuncOps. 
- return failure(); + return success(); } }; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -11,8 +11,8 @@ // ----- -// expected-error @+2 {{op was not bufferized}} -// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} +// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}} +// expected-error @+1 {{failed to bufferize op}} func.func private @foo() -> tensor // ----- @@ -262,7 +262,7 @@ // ----- -// expected-error @+2 {{op was not bufferized}} +// expected-error @+2 {{failed to bufferize op}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func.func private @foo(%t : tensor) -> (f32, tensor, f32)