diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -264,6 +264,15 @@
   return builder.create<TestOpConstant>(loc, type, value);
 }
 
+::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
+    ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+  return ::mlir::success();
+}
+
 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                                OperationName opName) {
   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2139,6 +2139,12 @@
   }];
 }
 
+// Check that format gen supports DeclareOpInterfaceMethods.
+def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let results = (outs AnyType);
+  let assemblyFormat = "attr-dict";
+}
+
 // Base class for testing mixing allOperandTypes, allOperands, and
 // inferResultTypes.
 class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -409,7 +409,10 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK: test.format_infer_type
-%ignored_res7 = test.format_infer_type
+%ignored_res7a = test.format_infer_type
+
+// CHECK: test.format_infer_type2
+%ignored_res7b = test.format_infer_type2
 
 // CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
 %ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2345,9 +2345,16 @@
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
       handleTypesMatchConstraint(variableTyResolver, def);
-    } else if (def.getName() == "InferTypeOpInterface" &&
-               !op.allResultTypesKnown()) {
-      canInferResultTypes = true;
+    } else if (!op.allResultTypesKnown()) {
+      // Match on the C++ class name rather than the record name so that
+      //    DeclareOpInterfaceMethods<InferTypeOpInterface>
+      // and similar wrappers are recognized too.
+      // TODO: Add hasCppInterface check.
+      if (auto name = def.getValueAsOptionalString("cppClassName")) {
+        if (*name == "InferTypeOpInterface" &&
+            def.getValueAsString("cppNamespace") == "::mlir")
+          canInferResultTypes = true;
+      }
     }
   }
 