Review notes (free-form preamble placed ahead of the first `diff --git` header;
it is not part of any hunk).

What this patch does, as far as the visible text shows:
  * OpBase.td: TypesMatchWith gains an optional `comparator` parameter
    (default "std::equal_to<>()"). Per the added comment, the constraint
    changes from `transform(lhs.getType()) == rhs.getType()` to
    `comparator(transform(lhs.getType()), rhs.getType())`.
  * OpBase.td: adds RangedTypesMatchWith, a TypesMatchWith variant documented
    as providing "a comparator suitable for ranged arguments".
  * TestOps.td: adds test op FormatTypesMatchVariadicOp
    ("format_types_match_variadic") with variadic `value` operands and
    `result` results. It uses RangedTypesMatchWith with the transform
    "llvm::make_range($_self.begin(), $_self.end())".
  * op-format.mlir: adds a round-trip CHECK for the new op with three i64
    operands (`%ignored_res4:3`). The existing format_types_match_attr result
    is renamed from `%ignored_res4` to `%ignored_res5` so the name is not
    defined twice.
  * OpFormatGen.cpp: when a type is resolved through a variable transformer,
    `$_self` is now bound to `<name>Types` (the whole list) if the variable
    is variadic, and still to `<name>Types[0]` otherwise. The no-transformer
    path (`body << var->name << "Types"`) is unchanged apart from braces.

NOTE(review): this copy of the patch is damaged and will not apply as-is.
  - Line breaks have been collapsed, so hunk lines run together.
  - Text between angle brackets appears to have been stripped. Examples:
    `class TypesMatchWith : - PredOpTrait> {`, `Optional tform`,
    `Variadic:$value`. The TableGen parameter lists, the CPred predicate
    body, the RangedTypesMatchWith comparator argument and the C++ template
    arguments are all missing.
  - The final OpFormatGen.cpp hunk is truncated mid-statement.
  TODO: regenerate from the original commit instead of hand-repairing.

NOTE(review): RangedTypesMatchWith's comparator is not visible here (it was
stripped). Confirm that it compares the two type ranges element-wise.

NOTE(review): the OpFormatGen change keys only on `var->isVariadic()`.
Confirm how Optional (variable-length, non-variadic) operands should bind
`$_self`; presumably `Types[0]` would index an empty list when the optional
operand is absent. Verify against the operand resolution code.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2191,16 +2191,28 @@ AllMatchSameOperatorTrait; // A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`. +// An optional comparator function may be provided that changes the above form +// into: `comparator(transform(lhs.getType()), rhs.getType())`. class TypesMatchWith : - PredOpTrait> { + string transform, string comparator = "std::equal_to<>()"> + : PredOpTrait> { string lhs = lhsArg; string rhs = rhsArg; string transformer = transform; } +// Special variant of `TypesMatchWith` that provides a comparator suitable for +// ranged arguments. +class RangedTypesMatchWith + : TypesMatchWith; + // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : And<[ CPred<"$_op.getNumOperands() > " # idx>, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1733,6 +1733,15 @@ let assemblyFormat = "attr-dict $value `:` type($value)"; } +def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [ + RangedTypesMatchWith<"result type matches operand", "value", "result", + "llvm::make_range($_self.begin(), $_self.end())"> + ]> { + let arguments = (ins Variadic:$value); + let results = (outs Variadic:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ TypesMatchWith<"result type matches constant", "value", "result", "$_self"> ]> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -308,5 +308,8 @@ // CHECK: test.format_types_match_var %[[I64]] : i64 %ignored_res3 = test.format_types_match_var %i64 : i64 +// CHECK: 
test.format_types_match_variadic %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64 +%ignored_res4:3 = test.format_types_match_variadic %i64, %i64, %i64 : i64, i64, i64 + // CHECK: test.format_types_match_attr 1 : i64 -%ignored_res4 = test.format_types_match_attr 1 : i64 +%ignored_res5 = test.format_types_match_attr 1 : i64 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1287,10 +1287,16 @@ if (Optional val = resolver.getBuilderIdx()) { body << "odsBuildableType" << *val; } else if (const NamedTypeConstraint *var = resolver.getVariable()) { - if (Optional tform = resolver.getVarTransformer()) - body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]")); - else + if (Optional tform = resolver.getVarTransformer()) { + FmtContext fmtContext; + if (var->isVariadic()) + fmtContext.withSelf(var->name + "Types"); + else + fmtContext.withSelf(var->name + "Types[0]"); + body << tgfmt(*tform, &fmtContext); + } else { body << var->name << "Types"; + } } else if (const NamedAttribute *attr = resolver.getAttribute()) { if (Optional tform = resolver.getVarTransformer()) body << tgfmt(*tform,