Populate 'result_segment_sizes' automatically in ODS-generated builders.

- OpDefinitionsGen.cpp: in the builder branch that adds each result type
  individually (the code just before `case TypeParamKind::Collective`), ops
  carrying the ::mlir::OpTrait::AttrSizedResultSegments trait now also get an
  addAttribute() call that fills the result segment sizes from the builder's
  result arguments: 1 for a fixed-length result, (x ? 1 : 0) for an Optional
  result, and static_cast<int32_t>(x.size()) for a Variadic result.
- TestOps.td: AttrSizedResultOp no longer declares result_segment_sizes as an
  explicit argument. Add AttrSizedResultCompileTestOp, which mixes a variadic,
  a fixed and an optional result (presumably to check the generated builder
  compiles).
- op-result.td: FileCheck the builder emitted for such a mixed op.

NOTE(review): the template arguments (Variadic<...>, Optional<...>,
llvm::seq<int>, static_cast<int32_t>) were missing from the captured text of
this patch and have been restored from the surrounding code. Blank lines and
runs of whitespace (including inside string literals) were also lost in
capture, so hunk contents and line counts must be re-checked against the tree
before applying.

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -722,9 +722,6 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results", [AttrSizedResultSegments]> {
-  let arguments = (ins
-    DenseI32ArrayAttr:$result_segment_sizes
-  );
   let results = (outs
     Variadic<I32>:$a,
     Variadic<I32>:$b,
@@ -733,6 +730,11 @@
   );
 }
 
+def AttrSizedResultCompileTestOp : TEST_Op<"attr_sized_results_compile_test",
+    [AttrSizedResultSegments]> {
+  let results = (outs Variadic<I32>:$a, I32:$b, Optional<I32>:$c);
+}
+
 // This is used to test that the fallback for a custom op's parser and printer
 // is the dialect parser and printer hooks.
 def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -157,3 +157,10 @@
 // CHECK-NOT: }
 // CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;
+
+def OpM : NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> {
+  let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Optional<AnyTensor>:$output3);
+}
+
+// CHECK-LABEL: OpM::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange output1, ::mlir::Type output2, /*optional*/::mlir::Type output3)
+// CHECK: odsState.addAttribute(result_segment_sizesAttrName(odsState.name), odsBuilder.getDenseI32ArrayAttr({static_cast<int32_t>(output1.size()), 1, (output3 ? 1 : 0)}));
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1514,6 +1514,34 @@
       body << " " << builderOpState << ".addTypes(" << resultNames[i]
            << ");\n";
     }
+    // Automatically create the 'result_segment_sizes' attribute using
+    // the length of the type ranges.
+    if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+      std::string getterName = op.getGetterName(resultSegmentAttrName);
+      body << " " << builderOpState << ".addAttribute(" << getterName
+           << "AttrName(" << builderOpState << ".name), "
+           << "odsBuilder.getDenseI32ArrayAttr({";
+
+      interleaveComma(
+          llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
+            const NamedTypeConstraint &result = op.getResult(i);
+            if (!result.isVariableLength()) {
+              body << "1";
+            } else if (result.isOptional()) {
+              body << "(" << resultNames[i] << " ? 1 : 0)";
+            } else {
+              // VariadicOfVariadic of results are currently unsupported in
+              // MLIR, hence it can only be a simple variadic.
+              // TODO: Add implementation for VariadicOfVariadic results here
+              // once supported.
+              assert(result.isVariadic());
+              body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
+            }
+          });
+      body << "}));\n";
+    }
+
     return;
   case TypeParamKind::Collective: {
     int numResults = op.getNumResults();