diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2021,6 +2021,30 @@
   let assemblyFormat = "attr-dict $value `:` type($value)";
 }
 
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference in assembly format
+
+def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = "attr-dict $inputs `:` type($inputs)";
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *,
+        ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+        ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      // `inputs` is variadic, so guard against an empty operand list.
+      if (operands.empty())
+        return ::mlir::emitOptionalError(location,
+                                         "expected at least one operand");
+      inferredReturnTypes.assign({operands[0].getType()});
+      return ::mlir::success();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -3,6 +3,7 @@
 // This file contains tests for the specification of the declarative op format.
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def TestDialect : Dialect {
   let name = "test";
@@ -566,4 +567,8 @@
   operands type($result) attr-dict
 }], [AllTypesMatch<["operand", "result"]>]>,
      Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
-
+def ZCoverageValidI : TestFormat_Op<[{
+  operands type(operands) attr-dict
+}], [InferTypeOpInterface]>,
+     Arguments<(ins Variadic<AnyType>:$inputs)>,
+     Results<(outs AnyType:$result)>;
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -354,3 +354,10 @@
 
 // CHECK: test.format_types_match_context %[[I64]] : i64
 %ignored_res6 = test.format_types_match_context %i64 : i64
+
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_infer_type %[[I32]], %[[I64]] : i32, i64
+%ignored_res7 = test.format_infer_type %i32, %i64 : i32, i64
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -441,7 +441,8 @@
   };
 
   OperationFormat(const Operator &op)
-      : allOperands(false), allOperandTypes(false), allResultTypes(false) {
+      : allOperands(false), allOperandTypes(false), allResultTypes(false),
+        infersResultTypes(false) {
     operandTypes.resize(op.getNumOperands(), TypeResolution());
     resultTypes.resize(op.getNumResults(), TypeResolution());
 
@@ -482,6 +483,9 @@
   /// contains these, it can not contain individual type resolvers.
   bool allOperands, allOperandTypes, allResultTypes;
 
+  /// A flag indicating if this operation infers its result types.
+  bool infersResultTypes;
+
   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
   /// trait.
   bool hasImplicitTermTrait;
@@ -1437,13 +1441,15 @@
   };
 
   // Resolve each of the result types.
-  if (allResultTypes) {
-    body << "  result.addTypes(allResultTypes);\n";
-  } else {
-    for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
-      body << "  result.addTypes(";
-      emitTypeResolver(resultTypes[i], op.getResultName(i));
-      body << ");\n";
+  if (!infersResultTypes) {
+    if (allResultTypes) {
+      body << "  result.addTypes(allResultTypes);\n";
+    } else {
+      for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+        body << "  result.addTypes(";
+        emitTypeResolver(resultTypes[i], op.getResultName(i));
+        body << ");\n";
+      }
     }
   }
 
@@ -1532,6 +1538,22 @@
     body << ", " << operand.name << "OperandsLoc";
     body << ", result.operands))\n    return ::mlir::failure();\n";
   }
+
+  // Handle return type inference once all operands have been resolved.
+  // NOTE(review): if region resolution is emitted after this point,
+  // `result.regions` is still empty when inferReturnTypes is called.
+  if (infersResultTypes) {
+    body << formatv(R"(
+  ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
+  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
+      result.location, result.operands,
+      result.attributes.getDictionary(parser.getContext()),
+      result.regions, inferredReturnTypes)))
+    return ::mlir::failure();
+  result.addTypes(inferredReturnTypes);
+)",
+                    op.getCppClassName());
+  }
 }
 
 void OperationFormat::genParserRegionResolution(Operator &op,
@@ -2478,6 +2500,7 @@
   // during parsing.
   bool hasAttrDict = false;
   bool hasAllRegions = false, hasAllSuccessors = false;
+  bool canInferResultTypes = false;
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
   llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
@@ -2515,6 +2538,9 @@
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
       handleTypesMatchConstraint(variableTyResolver, def);
+    } else if (def.getName() == "InferTypeOpInterface" &&
+               !op.allResultTypesKnown()) {
+      canInferResultTypes = true;
     }
   }
 
@@ -2686,15 +2712,20 @@
 
   // Check that all of the result types can be inferred.
   auto &buildableTypes = fmt.buildableTypes;
+  llvm::Optional<unsigned> buildableTypeErrorIndex;
+  bool hasVerifiedAnyResult = false;
   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
-    if (seenResultTypes.test(i))
+    if (seenResultTypes.test(i)) {
+      hasVerifiedAnyResult = true;
       continue;
+    }
 
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
     if (varResolverIt != variableTyResolver.end()) {
       TypeResolutionInstance resolver = varResolverIt->second;
       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
+      hasVerifiedAnyResult = true;
       continue;
     }
 
@@ -2703,18 +2734,36 @@
     NamedTypeConstraint &result = op.getResult(i);
     Optional<StringRef> builder = result.constraint.getBuilderCall();
     if (!builder || result.isVariableLength()) {
-      return emitErrorAndNote(
-          loc,
-          "type of result #" + Twine(i) + ", named '" + result.name +
-              "', is not buildable and a buildable type cannot be inferred",
-          "suggest adding a type constraint to the operation or adding a "
-          "'type($" +
-              result.name + ")' directive to the " + "custom assembly format");
+      if (!buildableTypeErrorIndex) {
+        // Defer reporting a non-buildable type error in case the result
+        // types can be inferred.
+        buildableTypeErrorIndex = i;
+      }
+      continue;
     }
 
     // Note in the format that this result uses the custom builder.
     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
     fmt.resultTypes[i].setBuilderIdx(it.first->second);
+    hasVerifiedAnyResult = true;
+  }
+
+  if (op.getNumResults() > 0 && !hasVerifiedAnyResult && canInferResultTypes) {
+    fmt.infersResultTypes = true;
+    return ::mlir::success();
   }
+
+  if (buildableTypeErrorIndex) {
+    unsigned i = *buildableTypeErrorIndex;
+    auto &result = op.getResult(i);
+    return emitErrorAndNote(
+        loc,
+        "type of result #" + Twine(i) + ", named '" + result.name +
+            "', is not buildable and a buildable type cannot be inferred",
+        "suggest adding a type constraint to the operation or adding a "
+        "'type($" +
+            result.name + ")' directive to the " + "custom assembly format");
+  }
+
   return ::mlir::success();
 }