diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -795,9 +795,8 @@ Value v; }; - using EquivalenceClassRangeType = - llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; @@ -1106,15 +1105,13 @@ void BufferizationAliasInfo::printAliases(raw_ostream &os) const { os << "\n/===================== AliasInfo =====================\n"; - for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; - ++it) { + for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) { if (!it->isLeader()) continue; Value leader = it->getData(); os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false) << '\n'; - for (auto mit = aliasInfo.member_begin(it), - meit = aliasInfo.member_end(); + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); mit != meit; ++mit) { Value v = static_cast(*mit); os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false) @@ -1148,13 +1145,12 @@ BufferizationAliasInfo::getAliases(Value v) const { DenseSet res; auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); - for (auto mit = aliasInfo.member_begin(it), - meit = aliasInfo.member_end(); - mit != meit; ++mit) { - res.insert(static_cast(*mit)); - } + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); + mit != meit; ++mit) { + res.insert(static_cast(*mit)); + } return BufferizationAliasInfo::EquivalenceClassRangeType( - aliasInfo.member_begin(it), aliasInfo.member_end()); + aliasInfo.member_begin(it), aliasInfo.member_end()); } /// This is one particular type of relationship between ops on tensors that @@ -2631,6 +2627,11 @@ // 2. Rewrite the terminator without the inPlace bufferizable values. ValueRange retValues{returnValues}; + if (llvm::any_of(retValues.getTypes(), [](Type t) { + return t.isa(); + })) + return returnOp->emitError("memref return type is unsupported"); + FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( funcOp, funcOp.getType().getInputs(), retValues.getTypes(), bufferizedFunctionTypes); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -130,3 +130,13 @@ %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>) return %r: tensor<4xf32> } + +// ----- + +func @mini_test_case1() -> tensor<10x20xf32> { + %f0 = constant 0.0 : f32 + %t = linalg.init_tensor [10, 20] : tensor<10x20xf32> + %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32> + // expected-error @+1 {{memref return type is unsupported}} + return %r : tensor<10x20xf32> +}