diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -367,7 +367,11 @@
   let hasFolder = 1;
 }
 
-def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
+def Shape_YieldOp : Shape_Op<"yield",
+    [HasParent<"ReduceOp">,
+     NoSideEffect,
+     ReturnLike,
+     Terminator]> {
   let summary = "Returns the value to parent op";
 
   let arguments = (ins Variadic<AnyType>:$operands);
@@ -376,6 +380,7 @@
     "OpBuilder &b, OperationState &result",
     [{ build(b, result, llvm::None); }]
   >];
+  let verifier = [{ return ::verify(*this); }];
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -391,6 +391,40 @@
 }
 
 //===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the values yielded by a shape.yield match, in number and in
+/// type, the results of its parent op.
+static LogicalResult verify(YieldOp op) {
+  // ODS runs trait verifiers (including HasParent<"ReduceOp">) before this
+  // custom verifier, so the parent op is known to exist here.
+  Operation *parentOp = op.getParentOp();
+  unsigned numOperands = op.getNumOperands();
+  unsigned numResults = parentOp->getNumResults();
+
+  // Every yielded value must correspond to exactly one parent result.
+  if (numOperands != numResults)
+    return op.emitOpError() << "number of operands does not match number of "
+                               "results of its parent (got "
+                            << numOperands << ", expected " << numResults
+                            << ")";
+
+  // Each yielded value must have exactly the type of the matching result.
+  for (unsigned i = 0; i < numOperands; ++i) {
+    Type operandType = op.getOperand(i).getType();
+    Type resultType = parentOp->getResult(i).getType();
+    if (operandType != resultType)
+      return op.emitOpError()
+             << "types mismatch between yield op and its parent: operand #"
+             << i << " has type " << operandType
+             << " but the parent result has type " << resultType;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -4,7 +4,7 @@
   // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size):
-      "shape.yield"(%dim) : (!shape.size) -> ()
+      shape.yield %dim : !shape.size
   }
 }
 
@@ -13,9 +13,10 @@
 func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
   // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
-    ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
-      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
-      "shape.yield"(%acc) : (!shape.size) -> ()
+    ^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = "shape.add"(%acc, %dim)
+          : (!shape.size, !shape.size) -> !shape.size
+      shape.yield %new_acc : !shape.size
   }
 }
 
@@ -25,7 +26,7 @@
   // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: f32, %lci: !shape.size):
-      "shape.yield"() : () -> ()
+      shape.yield
   }
 }
 
@@ -35,6 +36,34 @@
   // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
   %num_elements = shape.reduce(%shape, %init) -> f32 {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
-      "shape.yield"() : () -> ()
+      shape.yield
+  }
+}
+
+// -----
+
+func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+3 {{number of operands does not match number of results of its parent}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      shape.yield %dim, %dim : !shape.size, !shape.size
+  }
+}
+
+// -----
+
+func @yield_op_outside_reduce() {
+  // expected-error@+1 {{expects parent op 'shape.reduce'}}
+  shape.yield
+}
+
+// -----
+
+func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+4 {{types mismatch between yield op and its parent}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      %c1 = constant 1 : index
+      shape.yield %c1 : index
   }
 }
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -10,7 +10,7 @@
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
-      "shape.yield"(%acc) : (!shape.size) -> ()
+      shape.yield %acc : !shape.size
   }
   return %num_elements : !shape.size
 }