diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -780,8 +780,8 @@ as traits on the operation; for example the true, false, and result values of a `select` operation often have the same type. The assembly format may inspect these equal constraints to discern the types of missing variables. The currently -supported traits are: `AllTypesMatch`, `SameTypeOperands`, and -`SameOperandsAndResultType`. +supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, +and `SameOperandsAndResultType`. ### `hasCanonicalizer` diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1352,6 +1352,46 @@ let assemblyFormat = "$operands attr-dict `:` type($result)"; } +//===----------------------------------------------------------------------===// +// AllTypesMatch type inference from operands and attributes +//===----------------------------------------------------------------------===// + +def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins AnyType:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)"; +} + +def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins AnyAttr:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2"; +} + +//===----------------------------------------------------------------------===// +// TypesMatchWith type inference from operands and attributes +//===----------------------------------------------------------------------===// + +def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ + TypesMatchWith<"result type matches operand", 
"value", "result", "$_self"> + ]> { + let arguments = (ins AnyType:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + +def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ + TypesMatchWith<"result type matches constant", "value", "result", "$_self"> + ]> { + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value"; +} + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -108,3 +108,23 @@ // CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64 test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 + +//===----------------------------------------------------------------------===// +// AllTypesMatch type inference from operands and attributes +//===----------------------------------------------------------------------===// + +// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64 +%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64 + +// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]] +%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64 + +//===----------------------------------------------------------------------===// +// TypesMatchWith type inference from operands and attributes +//===----------------------------------------------------------------------===// + +// CHECK: test.format_types_match_var %[[I64]] : i64 +%ignored_res3 = test.format_types_match_var %i64 : i64 + +// CHECK: test.format_types_match_attr 1 : i64 +%ignored_res4 = test.format_types_match_attr 1 : i64 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- 
a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -270,6 +270,10 @@ //===----------------------------------------------------------------------===// namespace { + +using ConstArgument = + llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; + struct OperationFormat { /// This class represents a specific resolver for an operand or result type. class TypeResolution { @@ -280,15 +284,22 @@ Optional<int> getBuilderIdx() const { return builderIdx; } void setBuilderIdx(int idx) { builderIdx = idx; } - /// Get the variable this type is resolved to, or None. - const NamedTypeConstraint *getVariable() const { return variable; } + /// Get the variable this type is resolved to, or nullptr. + const NamedTypeConstraint *getVariable() const { + return resolver.dyn_cast<const NamedTypeConstraint *>(); + } + /// Get the attribute this type is resolved to, or nullptr. + const NamedAttribute *getAttribute() const { + return resolver.dyn_cast<const NamedAttribute *>(); + } + /// Get the transformer for the type of the variable, or None. Optional<StringRef> getVarTransformer() const { return variableTransformer; } - void setVariable(const NamedTypeConstraint *var, - Optional<StringRef> transformer) { - variable = var; + void setResolver(ConstArgument arg, Optional<StringRef> transformer) { + resolver = arg; variableTransformer = transformer; + assert(getVariable() || getAttribute()); } private: @@ -296,8 +307,8 @@ /// 'buildableTypes' in the parent format. Optional<int> builderIdx; /// If the type is resolved based upon another operand or result, this is - /// the variable that this type is resolved to. - const NamedTypeConstraint *variable; + /// the variable or the attribute that this type is resolved to. + ConstArgument resolver; /// If the type is resolved based upon another operand or result, this is /// a transformer to apply to the variable when resolving. 
Optional<StringRef> variableTransformer; @@ -729,7 +740,7 @@ continue; // Ensure that we don't verify the same variables twice. 
const NamedTypeConstraint *variable = resolver.getVariable(); - if (!verifiedVariables.insert(variable).second) + if (!variable || !verifiedVariables.insert(variable).second) continue; auto constraint = variable->constraint; @@ -764,6 +775,12 @@ body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]")); else body << var->name << "Types"; + } else if (const NamedAttribute *attr = resolver.getAttribute()) { + if (Optional<StringRef> tform = resolver.getVarTransformer()) + body << tgfmt(*tform, + &FmtContext().withSelf(attr->name + "Attr.getType()")); + else + body << attr->name << "Attr.getType()"; } else { body << curVar << "Types"; } @@ -1353,7 +1370,7 @@ /// type as well as an optional transformer to apply to that type in order to /// properly resolve the type of a variable. struct TypeResolutionInstance { - const NamedTypeConstraint *type; + ConstArgument resolver; Optional<StringRef> transformer; }; @@ -1392,10 +1409,15 @@ void handleSameTypesConstraint( llvm::StringMap<TypeResolutionInstance> &variableTyResolver, bool includeResults); + /// Check for inferable type resolution based on another operand, result, or + /// attribute. + void handleTypesMatchConstraint( + llvm::StringMap<TypeResolutionInstance> &variableTyResolver, + const llvm::Record &def); - /// Returns an argument with the given name that has been seen within the - /// format. - const NamedTypeConstraint *findSeenArg(StringRef name); + /// Returns an argument or attribute with the given name that has been seen + /// within the format. + ConstArgument findSeenArg(StringRef name); /// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element, @@ -1504,9 +1526,7 @@ } else if (def.getName() == "SameOperandsAndResultType") { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { - if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs"))) - variableTyResolver[def.getValueAsString("rhs")] = { - lhsArg, def.getValueAsString("transformer")}; + handleTypesMatchConstraint(variableTyResolver, def); } } @@ -1615,8 +1635,8 @@ // Check to see if we can infer this type from another variable. auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); if (varResolverIt != variableTyResolver.end()) { - fmt.operandTypes[i].setVariable(varResolverIt->second.type, - varResolverIt->second.transformer); + const TypeResolutionInstance &resolver = varResolverIt->second; + fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer); continue; } @@ -1654,8 +1674,8 @@ // Check to see if we can infer this type from another variable. auto varResolverIt = variableTyResolver.find(op.getResultName(i)); if (varResolverIt != variableTyResolver.end()) { - fmt.resultTypes[i].setVariable(varResolverIt->second.type, - varResolverIt->second.transformer); + const TypeResolutionInstance &resolver = varResolverIt->second; + fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer); continue; } @@ -1702,7 +1722,7 @@ llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { for (unsigned i = 0, e = values.size(); i != e; ++i) { // Check to see if this value matches a resolved operand or result type.
- const NamedTypeConstraint *arg = findSeenArg(values[i]); + ConstArgument arg = findSeenArg(values[i]); if (!arg) continue; @@ -1739,11 +1759,23 @@ } } -const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) { - if (auto *arg = findArg(op.getOperands(), name)) +void FormatParser::handleTypesMatchConstraint( + llvm::StringMap<TypeResolutionInstance> &variableTyResolver, + const llvm::Record &def) { + StringRef lhsName = def.getValueAsString("lhs"); + StringRef rhsName = def.getValueAsString("rhs"); + StringRef transformer = def.getValueAsString("transformer"); + if (ConstArgument arg = findSeenArg(lhsName)) + variableTyResolver[rhsName] = {arg, transformer}; +} + +ConstArgument FormatParser::findSeenArg(StringRef name) { + if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; - if (auto *arg = findArg(op.getResults(), name)) + if (const NamedTypeConstraint *arg = findArg(op.getResults(), name)) return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr; + if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) + return seenAttrs.count(attr) ? attr : nullptr; return nullptr; }