diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,7 +37,10 @@
   let options = [
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
-           "Only runs inplaceability analysis (for testing purposes only)">
+           "Only runs inplaceability analysis (for testing purposes only)">,
+    Option<"allowReturnMemref", "allow-return-memref", "bool",
+            /*default=*/"false",
+           "Allows the return of memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2914,6 +2914,15 @@
       signalPassFailure();
       return;
     }
+    // Ranked or unranked memref results are rejected unless allowReturnMemref.
+    if (!allowReturnMemref &&
+        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
+          return t.isa<MemRefType, UnrankedMemRefType>();
+        })) {
+      funcOp->emitError("memref return type is unsupported");
+      signalPassFailure();
+      return;
+    }
   }
 
   // Perform a post-processing pass of layout modification at function boundary
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -130,3 +130,15 @@
   %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>)
   return %r: tensor<4xf32>
 }
+
+// -----
+
+// A function returning a tensor is expected to be rejected by default: its
+// result would have to become a memref, which needs allow-return-memref.
+// expected-error @+1 {{memref return type is unsupported}}
+func @mini_test_case1() -> tensor<10x20xf32> {
+  %f0 = constant 0.0 : f32
+  %t = linalg.init_tensor [10, 20] : tensor<10x20xf32>
+  %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32>
+  return %r : tensor<10x20xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #map>) -> vector<4xf32> {
 func @transfer_read(%A : tensor<?xf32>) -> (vector<4xf32>) {