diff --git a/mlir/docs/Dialects/Standard.md b/mlir/docs/Dialects/Standard.md --- a/mlir/docs/Dialects/Standard.md +++ b/mlir/docs/Dialects/Standard.md @@ -72,10 +72,14 @@ operation ::= `return` (ssa-use-list `:` type-list-no-parens)? ``` -The `return` terminator operation represents the completion of a function, and -produces the result values. The count and types of the operands must match the -result types of the enclosing function. It is legal for multiple blocks in a -single function to return. +The `return` terminator operation, when enclosed immediately by a function, +represents the completion of that function and produces its result values. The +count and types of the operands must match the result types of the enclosing +function. It is valid for multiple blocks in a single function to return. In the +case that the op holding the region in which the `return` appears is not a +function, the `return` represents a transfer of control to the enclosing op, +whose semantics define where control flows next; the operands must then match +the result types of that op. ## Core Operations diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1289,18 +1289,27 @@ // ReturnOp //===----------------------------------------------------------------------===// -def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, - Terminator]> { +def ReturnOp : Std_Op<"return", [NoSideEffect, Terminator]> { let summary = "return operation"; let description = [{ - The "return" operation represents a return operation within a function. - The operation takes variable number of operands and produces no results. - The operand number and types must match the signature of the function - that contains the operation.
For example: + The `return` terminator operation, when enclosed immediately by a function, + represents the completion of that function and produces its result values. + The count and types of the operands must match the result types of the + enclosing function. It is valid for multiple blocks in a single function to + return. In the case that the `return` op's parent is not a function op, the + `return` represents a transfer of control to its parent, whose semantics + determine where control flows next, and whose result values match the values + returned. For such an imperative parent op, as a guideline, `std.return` + should be used as a terminator only when the values being returned are + actually the results of that op, which is typically the case when its region + is being executed once. - func @foo() : (i32, f8) { + ```mlir + func @foo() -> (i32, f32) { ... - return %0, %1 : i32, f8 + return %0, %1 : i32, f32 + } + ``` }]; let arguments = (ins Variadic<AnyType>:$operands); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1605,21 +1605,36 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(ReturnOp op) { - auto function = cast<FuncOp>(op.getParentOp()); - - // The operand number and types must match the function signature. - const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) + auto *parentOp = op.getParentOp(); + if (!parentOp) + return op.emitOpError("has no parent op"); + + // The operand count and types must be consistent with the parent op. When the + // parent is a FuncOp (which is declarative), check against its return types; + // for any other parent, check against the types of its SSA results. + // Diagnostics call the parent a "function" only when it is a FuncOp. + // TODO: create an op interface for declarative func like ops instead of + // treating just the FuncOp specially.
+ ArrayRef<Type> retTypes; + StringRef parentKind = "op"; + if (auto funcOp = dyn_cast<FuncOp>(parentOp)) { + retTypes = funcOp.getType().getResults(); + parentKind = "function"; + } else { + retTypes = parentOp->getResultTypes(); + } + + if (op.getNumOperands() != retTypes.size()) return op.emitOpError("has ") - << op.getNumOperands() - << " operands, but enclosing function returns " << results.size(); + << op.getNumOperands() << " operands, but enclosing " << parentKind + << " returns " << retTypes.size(); - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i).getType() != results[i]) + for (unsigned i = 0, e = retTypes.size(); i != e; ++i) + if (op.getOperand(i).getType() != retTypes[i]) return op.emitError() << "type of return operand " << i << " (" - << op.getOperand(i).getType() - << ") doesn't match function result type (" << results[i] << ")"; + << op.getOperand(i).getType() << ") doesn't match " << parentKind + << " result type (" << retTypes[i] << ")"; return success(); } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -551,6 +551,15 @@ return } +// CHECK-LABEL: func @return_in_op_with_region +func @return_in_op_with_region() { + "foo.region"() ({ + %c9 = constant 9 : i32 + return %c9 : i32 + }): () -> (i32) + return +} + // Test with zero-dimensional operands using no index in load/store. // CHECK-LABEL: func @zero_dim_no_idx func @zero_dim_no_idx(%arg0 : memref<i32>, %arg1 : memref<i32>, %arg2 : memref<i32>) { diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -773,9 +773,9 @@ func @return_not_in_function() { "foo.region"() ({ - // expected-error@+1 {{'std.return' op expects parent op 'func'}} + // expected-error@+1 {{'std.return' op has 0 operands, but enclosing op returns 1}} return - }): () -> () + }): () -> (i32) return }