diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -38,6 +38,10 @@ // CHECK-NEXT: odsState.addTypes(resultType1) // CHECK-NEXT: odsState.addTypes(z) +// CHECK: void OpC::build(Builder *odsBuilder, OperationState &odsState, ArrayRef resultTypes) { +// CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results"); +// CHECK-NEXT: odsState.addTypes(resultTypes); + def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">; def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -702,6 +702,11 @@ } return; case TypeParamKind::Collective: + body << " " + << "assert(resultTypes.size() " + << (op.getNumVariadicResults() == 0 ? "==" : ">=") << " " + << (op.getNumResults() - op.getNumVariadicResults()) + << "u && \"mismatched number of results\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; return; };