Index: g3doc/Dialects/Standard.md =================================================================== --- g3doc/Dialects/Standard.md +++ g3doc/Dialects/Standard.md @@ -72,10 +72,14 @@ operation ::= `return` (ssa-use-list `:` type-list-no-parens)? ``` -The `return` terminator operation represents the completion of a function, and -produces the result values. The count and types of the operands must match the -result types of the enclosing function. It is legal for multiple blocks in a -single function to return. +The `return` terminator operation represents the transfer of control to its +parent op, which is the op immediately enclosing the region in which the +`return` appears. The semantics of the parent op define where control flows +next. When the parent op of a `return` is a function op, the `return` represents +the completion of that function and produces its result values. The count and +types of the operands must match the result types of that function. In the case +that the parent op is not a function, the count and types of the return's +operands must match those of the result values of the parent op. ## Core Operations Index: include/mlir/Dialect/StandardOps/Ops.td =================================================================== --- include/mlir/Dialect/StandardOps/Ops.td +++ include/mlir/Dialect/StandardOps/Ops.td @@ -1012,7 +1012,7 @@ let hasFolder = 1; } -def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { +def ReturnOp : Std_Op<"return", [Terminator]> { let summary = "return operation"; let description = [{ The "return" operation represents a return operation within a function. 
Index: lib/Dialect/StandardOps/Ops.cpp =================================================================== --- lib/Dialect/StandardOps/Ops.cpp +++ lib/Dialect/StandardOps/Ops.cpp @@ -1956,21 +1956,35 @@ } static LogicalResult verify(ReturnOp op) { - auto function = cast(op.getParentOp()); + auto *parentOp = op.getParentOp(); + if (!parentOp) + // TODO: this path can only be tested via the builder. + return op.emitOpError("has no parent op"); + + // The operand count and types must be consistent with the parent op. When + // the parent is a FuncOp (which is declarative), check against its return + // types; for the rest, check against the actual SSA results. + SmallVector retTypes; + if (auto funcOp = dyn_cast(parentOp)) { + retTypes.assign(funcOp.getType().getResults().begin(), + funcOp.getType().getResults().end()); + } else { + retTypes.reserve(parentOp->getNumResults()); + for (auto *result : parentOp->getResults()) + retTypes.push_back(result->getType()); + } - // The operand number and types must match the function signature. 
- const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) + if (op.getNumOperands() != retTypes.size()) return op.emitOpError("has ") << op.getNumOperands() - << " operands, but enclosing function returns " << results.size(); + << " operands, but enclosing function returns " << retTypes.size(); - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i)->getType() != results[i]) + for (unsigned i = 0, e = retTypes.size(); i != e; ++i) + if (op.getOperand(i)->getType() != retTypes[i]) return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i)->getType() - << ") doesn't match function result type (" << results[i] << ")"; + << ") doesn't match function result type (" << retTypes[i] << ")"; return success(); } Index: test/IR/core-ops.mlir =================================================================== --- test/IR/core-ops.mlir +++ test/IR/core-ops.mlir @@ -682,3 +682,12 @@ tensor_store %1, %0 : memref<4x4xi32> return } + +// CHECK-LABEL: func @return_in_op_with_region +func @return_in_op_with_region() { + "foo.region"() ({ + %c9 = constant 9 : i32 + return %c9 : i32 + }): () -> (i32) + return +} Index: test/IR/invalid-ops.mlir =================================================================== --- test/IR/invalid-ops.mlir +++ test/IR/invalid-ops.mlir @@ -693,9 +693,9 @@ func @return_not_in_function() { "foo.region"() ({ - // expected-error@+1 {{'std.return' op expects parent op 'func'}} + // expected-error@+1 {{'std.return' op has 0 operands, but enclosing function returns 1}} return - }): () -> () + }): () -> (i32) return }