diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -773,8 +773,9 @@ as traits on the operation; for example the true, false, and result values of a `select` operation often have the same type. The assembly format may inspect these equal constraints to discern the types of missing variables. The currently -supported traits are: `AllTypesMatch`, `SameTypeOperands`, and -`SameOperandsAndResultType`. +supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, +and `SameOperandsAndResultType`. The type of a missing variable may also be +inferred from the type of an attribute that is present in the format. ### `hasCanonicalizer` diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1300,6 +1300,46 @@ }]; } +//===----------------------------------------------------------------------===// +// AllTypesMatch type inference +//===----------------------------------------------------------------------===// + +def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins AnyType:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)"; +} + +def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins AnyAttr:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2"; +} + +//===----------------------------------------------------------------------===// +// TypesMatchWith type inference +//===----------------------------------------------------------------------===// + +def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ + TypesMatchWith<"result type matches operand", "value", "result", "$_self"> + ]> { + let arguments = (ins 
AnyType:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + +def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ + TypesMatchWith<"result type matches constant", "value", "result", "$_self"> + ]> { + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value"; +} + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -94,3 +94,23 @@ // CHECK: test.format_optional_operand_result_b_op : i64 test.format_optional_operand_result_b_op : i64 + +//===----------------------------------------------------------------------===// +// AllTypesMatch type inference +//===----------------------------------------------------------------------===// + +// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64 +%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64 + +// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]] +%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64 + +//===----------------------------------------------------------------------===// +// TypesMatchWith type inference +//===----------------------------------------------------------------------===// + +// CHECK: test.format_types_match_var %[[I64]] : i64 +%ignored_res3 = test.format_types_match_var %i64 : i64 + +// CHECK: test.format_types_match_attr 1 : i64 +%ignored_res4 = test.format_types_match_attr 1 : i64 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -270,6 +270,10 @@
//===----------------------------------------------------------------------===// namespace { + +using NamedArgOrAttr = + llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; + struct OperationFormat { /// This class represents a specific resolver for an operand or result type. class TypeResolution { @@ -280,15 +284,23 @@ Optional<int> getBuilderIdx() const { return builderIdx; } void setBuilderIdx(int idx) { builderIdx = idx; } - /// Get the variable this type is resolved to, or None. - const NamedTypeConstraint *getVariable() const { return variable; } + /// Get the variable this type is resolved to, or nullptr. + const NamedTypeConstraint *getVariable() const { + return resolver.dyn_cast<const NamedTypeConstraint *>(); + } + /// Get the attribute this type is resolved to, or nullptr. + const NamedAttribute *getAttribute() const { + return resolver.dyn_cast<const NamedAttribute *>(); + } + /// Get the transformer for the type of the variable, or None. Optional<StringRef> getVarTransformer() const { return variableTransformer; } - void setVariable(const NamedTypeConstraint *var, + void setResolver(NamedArgOrAttr varOrAttr, Optional<StringRef> transformer) { - variable = var; + resolver = varOrAttr; variableTransformer = transformer; + assert(getVariable() || getAttribute()); } private: @@ -296,8 +308,8 @@ /// 'buildableTypes' in the parent format. Optional<int> builderIdx; - /// If the type is resolved based upon another operand or result, this is - /// the variable that this type is resolved to. - const NamedTypeConstraint *variable; + /// If the type is resolved based upon another operand, result, or + /// attribute, this is the variable or attribute it is resolved to. + NamedArgOrAttr resolver; /// If the type is resolved based upon another operand or result, this is /// a transformer to apply to the variable when resolving. Optional<StringRef> variableTransformer; @@ -722,7 +734,7 @@ continue; // Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable(); - if (!verifiedVariables.insert(variable).second) + if (!variable || !verifiedVariables.insert(variable).second) continue; auto constraint = variable->constraint; @@ -757,6 +769,12 @@ body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]")); else body << var->name << "Types"; + } else if (const NamedAttribute *attr = resolver.getAttribute()) { + if (Optional<StringRef> tform = resolver.getVarTransformer()) + body << tgfmt(*tform, + &FmtContext().withSelf(attr->name + "Attr.getType()")); + else + body << attr->name << "Attr.getType()"; } else { body << curVar << "Types"; } @@ -1293,7 +1311,8 @@ /// Function to find an element within the given range that has the same name as /// 'name'. -template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) { +template <typename RangeT> +static auto findArg(RangeT &&range, StringRef name) { auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); return it != range.end() ? &*it : nullptr; } @@ -1316,7 +1335,7 @@ /// type as well as an optional transformer to apply to that type in order to /// properly resolve the type of a variable. struct TypeResolutionInstance { - const NamedTypeConstraint *type; + NamedArgOrAttr resolver; Optional<StringRef> transformer; }; @@ -1355,10 +1374,21 @@ void handleSameTypesConstraint( llvm::StringMap<TypeResolutionInstance> &variableTyResolver, bool includeResults); + /// Check for inferable type resolution based on another operand, result, or + /// attribute. + void handleTypesMatchConstraint( + llvm::StringMap<TypeResolutionInstance> &variableTyResolver, + const llvm::Record &def); /// Returns an argument with the given name that has been seen within the /// format. const NamedTypeConstraint *findSeenArg(StringRef name); + /// Returns an attribute with the given name that has been seen within the + /// format. + const NamedAttribute *findSeenAttr(StringRef name); + /// Returns an argument or attribute with the given name that has been seen + /// within the format.
+ NamedArgOrAttr findSeenArgOrAttr(StringRef name); /// Parse a specific element. LogicalResult parseElement(std::unique_ptr<Element> &element, @@ -1467,9 +1497,7 @@ } else if (def.getName() == "SameOperandsAndResultType") { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { - if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs"))) - variableTyResolver[def.getValueAsString("rhs")] = { - lhsArg, def.getValueAsString("transformer")}; + handleTypesMatchConstraint(variableTyResolver, def); } } @@ -1578,8 +1606,8 @@ // Check to see if we can infer this type from another variable. auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); if (varResolverIt != variableTyResolver.end()) { - fmt.operandTypes[i].setVariable(varResolverIt->second.type, - varResolverIt->second.transformer); + auto resolver = varResolverIt->second; + fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer); continue; } @@ -1617,8 +1645,8 @@ // Check to see if we can infer this type from another variable. auto varResolverIt = variableTyResolver.find(op.getResultName(i)); if (varResolverIt != variableTyResolver.end()) { - fmt.resultTypes[i].setVariable(varResolverIt->second.type, - varResolverIt->second.transformer); + auto resolver = varResolverIt->second; + fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer); continue; } @@ -1665,7 +1693,7 @@ llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { for (unsigned i = 0, e = values.size(); i != e; ++i) { // Check to see if this value matches a resolved operand or result type.
- const NamedTypeConstraint *arg = findSeenArg(values[i]); + NamedArgOrAttr arg = findSeenArgOrAttr(values[i]); if (!arg) continue; @@ -1702,6 +1730,16 @@ } } +void FormatParser::handleTypesMatchConstraint( + llvm::StringMap<TypeResolutionInstance> &variableTyResolver, + const llvm::Record &def) { + StringRef lhsName = def.getValueAsString("lhs"); + StringRef rhsName = def.getValueAsString("rhs"); + StringRef transformer = def.getValueAsString("transformer"); + if (NamedArgOrAttr arg = findSeenArgOrAttr(lhsName)) + variableTyResolver[rhsName] = {arg, transformer}; +} + const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) { if (auto *arg = findArg(op.getOperands(), name)) return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; @@ -1710,6 +1748,20 @@ return nullptr; } +const NamedAttribute *FormatParser::findSeenAttr(StringRef name) { + if (const auto *attr = findArg(op.getAttributes(), name)) + return seenAttrs.count(attr) ? attr : nullptr; + return nullptr; +} + +NamedArgOrAttr FormatParser::findSeenArgOrAttr(StringRef name) { + if (auto *arg = findSeenArg(name)) + return arg; + if (auto *attr = findSeenAttr(name)) + return attr; + return nullptr; +} + LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element, bool isTopLevel) { // Directives.