Index: mlir/docs/Dialects/Standard.md =================================================================== --- mlir/docs/Dialects/Standard.md +++ mlir/docs/Dialects/Standard.md @@ -72,10 +72,13 @@ operation ::= `return` (ssa-use-list `:` type-list-no-parens)? ``` -The `return` terminator operation represents the completion of a function, and -produces the result values. The count and types of the operands must match the -result types of the enclosing function. It is legal for multiple blocks in a -single function to return. +The `return` terminator operation, when enclosed immediately by a function, +represents the completion of that function and produces its result values. The +count and types of the operands must match the result types of the enclosing +function. It is valid for multiple blocks in a single function to return. In the +case that the op holding the region in which the `return` appears is not a +function, the `return` represents a transfer of control to the enclosing op, +whose semantics define where control flows next. ## Core Operations Index: mlir/include/mlir/Dialect/StandardOps/Ops.td =================================================================== --- mlir/include/mlir/Dialect/StandardOps/Ops.td +++ mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1003,17 +1003,22 @@ let hasFolder = 1; } -def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { +def ReturnOp : Std_Op<"return", [Terminator]> { let summary = "return operation"; let description = [{ - The "return" operation represents a return operation within a function. - The operation takes variable number of operands and produces no results. - The operand number and types must match the signature of the function - that contains the operation. For example: - - func @foo() : (i32, f8) { - ... - return %0, %1 : i32, f8 + The `return` terminator operation, when enclosed immediately by a function, + represents the completion of that function and produces its result values. 
+ The count and types of the operands must match the result types of the + enclosing function. It is valid for multiple blocks in a single function to + return. In the case that the `return` op's parent is not a function op, the + `return` represents a transfer of control to its parent, whose semantics + determine where control flows next, and whose result values match the values + returned. + + func @foo() -> (i32, f32) { + ... + return %0, %1 : i32, f32 + } }]; let arguments = (ins Variadic:$operands); Index: mlir/lib/Dialect/StandardOps/Ops.cpp =================================================================== --- mlir/lib/Dialect/StandardOps/Ops.cpp +++ mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1944,21 +1944,34 @@ } static LogicalResult verify(ReturnOp op) { - auto function = cast(op.getParentOp()); + auto *parentOp = op.getParentOp(); + if (!parentOp) + return op.emitOpError("has no parent op"); + + // The operand count and types must be consistent with the parent op. When + // the parent is a FuncOp (which is declarative), check against its return + // types; for the rest, check against the number of actual SSA results. + SmallVector retTypes; + if (auto funcOp = dyn_cast(parentOp)) { + retTypes.assign(funcOp.getType().getResults().begin(), + funcOp.getType().getResults().end()); + } else { + retTypes.reserve(parentOp->getNumResults()); + for (auto result : parentOp->getResults()) + retTypes.push_back(result->getType()); + } - // The operand number and types must match the function signature. 
- const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) + if (op.getNumOperands() != retTypes.size()) return op.emitOpError("has ") << op.getNumOperands() - << " operands, but enclosing function returns " << results.size(); + << " operands, but enclosing function returns " << retTypes.size(); - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i)->getType() != results[i]) + for (unsigned i = 0, e = retTypes.size(); i != e; ++i) + if (op.getOperand(i)->getType() != retTypes[i]) return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i)->getType() - << ") doesn't match function result type (" << results[i] << ")"; + << ") doesn't match function result type (" << retTypes[i] << ")"; return success(); } Index: mlir/test/IR/core-ops.mlir =================================================================== --- mlir/test/IR/core-ops.mlir +++ mlir/test/IR/core-ops.mlir @@ -728,3 +728,12 @@ tensor_store %1, %0 : memref<4x4xi32> return } + +// CHECK-LABEL: func @return_in_op_with_region +func @return_in_op_with_region() { + "foo.region"() ({ + %c9 = constant 9 : i32 + return %c9 : i32 + }): () -> (i32) + return +} Index: mlir/test/IR/invalid-ops.mlir =================================================================== --- mlir/test/IR/invalid-ops.mlir +++ mlir/test/IR/invalid-ops.mlir @@ -693,9 +693,9 @@ func @return_not_in_function() { "foo.region"() ({ - // expected-error@+1 {{'std.return' op expects parent op 'func'}} + // expected-error@+1 {{'std.return' op has 0 operands, but enclosing function returns 1}} return - }): () -> () + }): () -> (i32) return }