diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1242,6 +1242,26 @@
   (`[` $variadic^ `]`)? attr-dict
 }]>;
 
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
+    TypesMatchWith<"result type matches operand", "value", "result", "$_self">
+  ]> {
+  let arguments = (ins AnyType:$value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
+    TypesMatchWith<"result type matches constant", "value", "result", "$_self">
+  ]> {
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value";
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -94,3 +94,13 @@
 
 // CHECK: test.format_optional_operand_result_b_op : i64
 test.format_optional_operand_result_b_op : i64
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_types_match_var %[[I64]] : i64
+%ignored_res1 = test.format_types_match_var %i64 : i64
+
+// CHECK: test.format_types_match_attr 1 : i64
+%ignored_res2 = test.format_types_match_attr 1 : i64
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -280,15 +280,27 @@
   Optional<int> getBuilderIdx() const { return builderIdx; }
   void setBuilderIdx(int idx) { builderIdx = idx; }
 
-  /// Get the variable this type is resolved to, or None.
-  const NamedTypeConstraint *getVariable() const { return variable; }
+  /// Get the variable this type is resolved to, or nullptr.
+  const NamedTypeConstraint *getVariable() const {
+    return resolver.dyn_cast<const NamedTypeConstraint *>();
+  }
+  /// Get the attribute this type is resolved to, or nullptr.
+  const NamedAttribute *getAttribute() const {
+    return resolver.dyn_cast<const NamedAttribute *>();
+  }
+  /// Get the transformer for the type of the variable or attribute, or None.
   Optional<StringRef> getVarTransformer() const {
     return variableTransformer;
   }
-  void setVariable(const NamedTypeConstraint *var,
+  /// Set the operand/result or attribute this type is resolved from, along
+  /// with an optional transformer to apply to its type.
+  void setResolver(llvm::PointerUnion<const NamedTypeConstraint *,
+                                      const NamedAttribute *>
+                       varOrAttr,
                    Optional<StringRef> transformer) {
-    variable = var;
+    resolver = varOrAttr;
     variableTransformer = transformer;
+    assert((getVariable() || getAttribute()) && "expected a non-null resolver");
   }
 
 private:
@@ -296,8 +308,9 @@
   /// 'buildableTypes' in the parent format.
   Optional<int> builderIdx;
   /// If the type is resolved based upon another operand or result, this is
-  /// the variable that this type is resolved to.
-  const NamedTypeConstraint *variable;
+  /// the variable or the attribute that this type is resolved to.
+  llvm::PointerUnion<const NamedTypeConstraint *, const NamedAttribute *>
+      resolver;
   /// If the type is resolved based upon another operand or result, this is
   /// a transformer to apply to the variable when resolving.
   Optional<StringRef> variableTransformer;
@@ -726,7 +739,8 @@
       continue;
     // Ensure that we don't verify the same variables twice.
     const NamedTypeConstraint *variable = resolver.getVariable();
-    if (!verifiedVariables.insert(variable).second)
+    // Types resolved from an attribute have no variable to verify.
+    if (!variable || !verifiedVariables.insert(variable).second)
       continue;
 
     auto constraint = variable->constraint;
@@ -763,6 +777,12 @@
         body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
       else
         body << var->name << "Types";
+    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
+      if (Optional<StringRef> tform = resolver.getVarTransformer())
+        body << tgfmt(*tform,
+                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
+      else
+        body << attr->name << "Attr.getType()";
     } else {
       body << curVar << "Types";
     }
@@ -1322,7 +1342,8 @@
 /// type as well as an optional transformer to apply to that type in order to
 /// properly resolve the type of a variable.
 struct TypeResolutionInstance {
-  const NamedTypeConstraint *type;
+  llvm::PointerUnion<const NamedTypeConstraint *, const NamedAttribute *>
+      resolver;
   Optional<StringRef> transformer;
 };
 
@@ -1361,10 +1382,18 @@
   void handleSameTypesConstraint(
       llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
       bool includeResults);
+  /// Check for inferable type resolution from a `TypesMatchWith` constraint,
+  /// based on another operand, result, or attribute.
+  void handleTypesMatchConstraint(
+      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+      const llvm::Record &def);
 
   /// Returns an argument with the given name that has been seen within the
   /// format.
   const NamedTypeConstraint *findSeenArg(StringRef name);
+  /// Returns an attribute with the given name that has been seen within the
+  /// format.
+  const NamedAttribute *findSeenAttr(StringRef name);
 
   /// Parse a specific element.
   LogicalResult parseElement(std::unique_ptr<Element> &element,
@@ -1473,9 +1502,7 @@
     } else if (def.getName() == "SameOperandsAndResultType") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
-      if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
-        variableTyResolver[def.getValueAsString("rhs")] = {
-            lhsArg, def.getValueAsString("transformer")};
+      handleTypesMatchConstraint(variableTyResolver, def);
     }
   }
 
@@ -1584,8 +1611,8 @@
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.operandTypes[i].setVariable(varResolverIt->second.type,
-                                      varResolverIt->second.transformer);
+      const auto &resolver = varResolverIt->second;
+      fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
       continue;
     }
 
@@ -1623,8 +1650,8 @@
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.resultTypes[i].setVariable(varResolverIt->second.type,
-                                     varResolverIt->second.transformer);
+      const auto &resolver = varResolverIt->second;
+      fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
       continue;
     }
 
@@ -1708,6 +1735,18 @@
   }
 }
 
+void FormatParser::handleTypesMatchConstraint(
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+    const llvm::Record &def) {
+  StringRef lhsName = def.getValueAsString("lhs");
+  StringRef rhsName = def.getValueAsString("rhs");
+  StringRef transformer = def.getValueAsString("transformer");
+  if (const NamedTypeConstraint *var = findSeenArg(lhsName))
+    variableTyResolver[rhsName] = {var, transformer};
+  else if (const NamedAttribute *attr = findSeenAttr(lhsName))
+    variableTyResolver[rhsName] = {attr, transformer};
+}
+
 const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
   if (auto *arg = findArg(op.getOperands(), name))
     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
@@ -1716,6 +1755,12 @@
   return nullptr;
 }
 
+const NamedAttribute *FormatParser::findSeenAttr(StringRef name) {
+  if (const auto *attr = findArg(op.getAttributes(), name))
+    return seenAttrs.count(attr) ? attr : nullptr;
+  return nullptr;
+}
+
 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
                                          bool isTopLevel) {
   // Directives.