diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -929,6 +929,11 @@ supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, and `SameOperandsAndResultType`. +* InferTypeOpInterface + +Operations that implement `InferTypeOpInterface` can omit their result types in +their assembly format since the result types can be inferred from the operands. + ### `hasCanonicalizer` This boolean field indicate whether canonicalization patterns have been defined diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2021,6 +2021,24 @@ let assemblyFormat = "attr-dict $value `:` type($value)"; } +//===----------------------------------------------------------------------===// +// Test InferTypeOpInterface result type inference in the assembly format + +def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { + let results = (outs AnyType); + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); + return ::mlir::success(); + } + }]; +} + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -3,6 +3,7 @@ // This file contains tests for the specification of the declarative 
op format. include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def TestDialect : Dialect { let name = "test"; @@ -566,4 +567,6 @@ operands type($result) attr-dict }], [AllTypesMatch<["operand", "result"]>]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; - +def ZCoverageValidI : TestFormat_Op<[{ + operands type(operands) attr-dict +}], [InferTypeOpInterface]>, Arguments<(ins Variadic<I64>:$inputs)>, Results<(outs AnyType:$result)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -354,3 +354,10 @@ // CHECK: test.format_types_match_context %[[I64]] : i64 %ignored_res6 = test.format_types_match_context %i64 : i64 + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface type inference +//===----------------------------------------------------------------------===// + +// CHECK: test.format_infer_type +%ignored_res7 = test.format_infer_type diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -441,7 +441,8 @@ }; OperationFormat(const Operator &op) - : allOperands(false), allOperandTypes(false), allResultTypes(false) { + : allOperands(false), allOperandTypes(false), allResultTypes(false), + infersResultTypes(false) { operandTypes.resize(op.getNumOperands(), TypeResolution()); resultTypes.resize(op.getNumResults(), TypeResolution()); @@ -482,6 +483,9 @@ /// contains these, it can not contain individual type resolvers. bool allOperands, allOperandTypes, allResultTypes; + /// A flag indicating if this operation's result types are inferred via InferTypeOpInterface. + bool infersResultTypes; + /// A flag indicating if this operation has the SingleBlockImplicitTerminator /// trait.
bool hasImplicitTermTrait; @@ -682,6 +686,19 @@ {1}Types = {0}__{1}_functionType.getResults(); )"; +/// The code snippet used to generate a parser call that passes the parsed operands, attributes, and regions in `result` to the op's static `inferReturnTypes` and adds the inferred types to `result`. +/// +/// {0}: The C++ class name of the operation. +const char *const inferReturnTypesParserCode = R"( + ::llvm::SmallVector<::mlir::Type> inferredReturnTypes; + if (::mlir::failed({0}::inferReturnTypes(parser.getContext(), + result.location, result.operands, + result.attributes.getDictionary(parser.getContext()), + result.regions, inferredReturnTypes))) + return ::mlir::failure(); + result.addTypes(inferredReturnTypes); +)"; + /// The code snippet used to generate a parser call for a region list. /// /// {0}: The name for the region list. @@ -1437,19 +1454,25 @@ }; // Resolve each of the result types. - if (allResultTypes) { - body << " result.addTypes(allResultTypes);\n"; - } else { - for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { - body << " result.addTypes("; - emitTypeResolver(resultTypes[i], op.getResultName(i)); - body << ");\n"; + if (!infersResultTypes) { + if (allResultTypes) { + body << " result.addTypes(allResultTypes);\n"; + } else { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { + body << " result.addTypes("; + emitTypeResolver(resultTypes[i], op.getResultName(i)); + body << ");\n"; + } } } // Early exit if there are no operands. - if (op.getNumOperands() == 0) + if (op.getNumOperands() == 0) { + // No operands need resolving, so emit the result type inference before the early return. + if (infersResultTypes) + body << formatv(inferReturnTypesParserCode, op.getCppClassName()); return; + } // Handle the case where all operand types are in one group.
if (allOperandTypes) { @@ -1532,6 +1555,10 @@ body << ", " << operand.name << "OperandsLoc"; body << ", result.operands))\n return ::mlir::failure();\n"; } + + // Infer the result types now that all operands have been resolved. NOTE(review): `result.regions` is also passed to inferReturnTypes; confirm regions have been added to `result` by the time this snippet runs. + if (infersResultTypes) + body << formatv(inferReturnTypesParserCode, op.getCppClassName()); } void OperationFormat::genParserRegionResolution(Operator &op, @@ -2478,6 +2505,7 @@ // during parsing. bool hasAttrDict = false; bool hasAllRegions = false, hasAllSuccessors = false; + bool canInferResultTypes = false; llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::SmallSetVector seenAttrs; llvm::DenseSet seenOperands; @@ -2515,6 +2543,9 @@ handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { handleTypesMatchConstraint(variableTyResolver, def); + } else if (def.getName() == "InferTypeOpInterface" && + !op.allResultTypesKnown()) { + canInferResultTypes = true; // NOTE(review): only a trait record literally named InferTypeOpInterface matches; confirm DeclareOpInterfaceMethods<InferTypeOpInterface> is also detected. } } @@ -2684,6 +2715,14 @@ if (fmt.allResultTypes) return ::mlir::success(); + // If the format specifies no result types and InferTypeOpInterface applies (some result type is not buildable), + // let the generated parser infer all of them via inferReturnTypes. + if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && + canInferResultTypes) { + fmt.infersResultTypes = true; + return ::mlir::success(); + } + // Check that all of the result types can be inferred. auto &buildableTypes = fmt.buildableTypes; for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {