Allow the types of a variadic operand list to be resolved from a
non-variadic type variable in the generated declarative-format parser.

When the resolving variable is non-variadic and has no transformer, the
generated code indexes the resolved type list with `[0]` and omits the
operand location argument, so the operand/type size check is skipped.
Adds a test op and a round-trip test exercising this path.

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1344,6 +1344,14 @@
   }];
 }
 
+def FormatInferVariadicTypeFromNonVariadic
+    : TEST_Op<"format_infer_variadic_type_from_non_variadic",
+              [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$operands attr-dict `:` type($result)";
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -101,3 +101,10 @@
 
 // CHECK: test.format_optional_operand_result_b_op : i64
 test.format_optional_operand_result_b_op : i64
+
+//===----------------------------------------------------------------------===//
+// Format trait type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64
+test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -840,10 +840,28 @@
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
     NamedTypeConstraint &operand = op.getOperand(i);
     body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
-    emitTypeResolver(operandTypes[i], operand.name);
 
-    // If this isn't a buildable type, verify the sizes match by adding the loc.
-    if (!operandTypes[i].getBuilderIdx())
+    // Resolve the type of this operand.
+    TypeResolution &operandType = operandTypes[i];
+    emitTypeResolver(operandType, operand.name);
+
+    // If the type is resolved by a non-variadic variable, index into the
+    // resolved type list. This allows for resolving the types of a variadic
+    // operand list from a non-variadic variable.
+    bool verifyOperandAndTypeSize = true;
+    if (auto *resolverVar = operandType.getVariable()) {
+      if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
+        body << "[0]";
+        verifyOperandAndTypeSize = false;
+      }
+    } else {
+      verifyOperandAndTypeSize = !operandType.getBuilderIdx();
+    }
+
+    // Check to see if the sizes between the types and operands must match. If
+    // they do, provide the operand location to select the proper resolution
+    // overload.
+    if (verifyOperandAndTypeSize)
       body << ", " << operand.name << "OperandsLoc";
     body << ", result.operands))\n return failure();\n";
   }