Make `$_ctxt` available to TypesMatchWith type transformers.

A transformer such as `::mlir::TupleType::get($_ctxt, $_self)` needs a context
expression to be substituted for `$_ctxt`. This patch registers that
substitution in the two places touched below:
  * OpDefinitionsGen.cpp: `verifyCtx` maps `_ctxt` to
    `this->getOperation()->getContext()`.
  * OpFormatGen.cpp: the FmtContext used to expand a variable's type
    transformer maps `_ctxt` to `parser.getBuilder().getContext()`.
The new test op `test.format_types_match_context` uses such a transformer, and
op-format.mlir round-trips it through the declarative assembly format.

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1876,6 +1876,15 @@
   let assemblyFormat = "attr-dict $value";
 }
 
+def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
+    TypesMatchWith<"tuple result type matches operand type", "value", "result",
+                   "::mlir::TupleType::get($_ctxt, $_self)">
+  ]> {
+  let arguments = (ins AnyType:$value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -348,3 +348,6 @@
 
 // CHECK: test.format_types_match_attr 1 : i64
 %ignored_res5 = test.format_types_match_attr 1 : i64
+
+// CHECK: test.format_types_match_context %[[I64]] : i64
+%ignored_res6 = test.format_types_match_context %i64 : i64
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -579,6 +579,7 @@
       opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
       staticVerifierEmitter(staticVerifierEmitter) {
   verifyCtx.withOp("(*this->getOperation())");
+  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
 
   genTraits();
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1343,6 +1343,7 @@
     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
       if (Optional<StringRef> tform = resolver.getVarTransformer()) {
         FmtContext fmtContext;
+        fmtContext.addSubst("_ctxt", "parser.getBuilder().getContext()");
         if (var->isVariadic())
           fmtContext.withSelf(var->name + "Types");
         else