diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -37,7 +37,10 @@ let options = [ Option<"testAnalysisOnly", "test-analysis-only", "bool", /*default=*/"false", - "Only runs inplaceability analysis (for testing purposes only)"> + "Only runs inplaceability analysis (for testing purposes only)">, + Option<"allowReturnMemref", "allow-return-memref", "bool", + /*default=*/"false", + "Allows the return of memrefs (for testing purposes only)"> ]; let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()"; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -795,9 +795,8 @@ Value v; }; - using EquivalenceClassRangeType = - llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; @@ -1106,15 +1105,13 @@ void BufferizationAliasInfo::printAliases(raw_ostream &os) const { os << "\n/===================== AliasInfo =====================\n"; - for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; - ++it) { + for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) { if (!it->isLeader()) continue; Value leader = it->getData(); os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false) << '\n'; - for (auto mit = aliasInfo.member_begin(it), - meit = aliasInfo.member_end(); + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); mit != meit; ++mit) { Value v = static_cast(*mit); os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false) @@ -1148,13 +1145,12 @@ BufferizationAliasInfo::getAliases(Value v) const { DenseSet res; auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); - for (auto mit = aliasInfo.member_begin(it), - meit = aliasInfo.member_end(); - mit != meit; ++mit) { - res.insert(static_cast(*mit)); - } + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); + mit != meit; ++mit) { + res.insert(static_cast(*mit)); + } return BufferizationAliasInfo::EquivalenceClassRangeType( - aliasInfo.member_begin(it), aliasInfo.member_end()); + aliasInfo.member_begin(it), aliasInfo.member_end()); } /// This is one particular type of relationship between ops on tensors that @@ -2917,6 +2913,14 @@ signalPassFailure(); return; } + if (!allowReturnMemref && + llvm::any_of(funcOp.getType().getResults(), [](Type t) { + return t.isa(); + })) { + funcOp->emitError("memref return type is unsupported"); + signalPassFailure(); + return; + } } // Perform a post-processing pass of layout modification at function boundary diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -130,3 +130,13 @@ %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>) return %r: tensor<4xf32> } + +// ----- + +// expected-error @+1 {{memref return type is unsupported}} +func @mini_test_case1() -> tensor<10x20xf32> { + %f0 = constant 0.0 : f32 + %t = linalg.init_tensor [10, 20] : tensor<10x20xf32> + %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32> + return %r : tensor<10x20xf32> +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s // CHECK-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> { func @transfer_read(%A : tensor) -> (vector<4xf32>) {