[mlir][ODS] Infer operand/result types from attributes in TypesMatchWith.

If the `lhs` of a TypesMatchWith constraint names an attribute that is an
explicit element of the assembly format, the `rhs` operand/result type is
resolved from `<lhs>Attr.getType()`, with the constraint's transformer applied.

diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -289,6 +289,22 @@
                      Optional<StringRef> transformer) {
       variable = var;
       variableTransformer = transformer;
+      attribute = nullptr;
+      attributeTransformer = llvm::None;
+    }
+
+    /// Get the attribute this type is resolved to, or nullptr.
+    const NamedAttribute *getAttribute() const { return attribute; }
+    /// Get the transformer to apply to the attribute's type, or None.
+    Optional<StringRef> getAttrTransformer() const {
+      return attributeTransformer;
+    }
+    void setAttribute(const NamedAttribute *attr,
+                      Optional<StringRef> transformer) {
+      variable = nullptr;
+      variableTransformer = llvm::None;
+      attribute = attr;
+      attributeTransformer = transformer;
     }
 
   private:
@@ -301,6 +317,12 @@
     /// If the type is resolved based upon another operand or result, this is
     /// a transformer to apply to the variable when resolving.
     Optional<StringRef> variableTransformer;
+    /// If the type is resolved based upon an attribute, this is the attribute
+    /// that this type is resolved to.
+    const NamedAttribute *attribute = nullptr;
+    /// If the type is resolved based upon an attribute, this is a transformer
+    /// to apply to the type of that attribute when resolving.
+    Optional<StringRef> attributeTransformer;
   };
 
   OperationFormat(const Operator &op)
@@ -763,6 +785,12 @@
         body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
       else
         body << var->name << "Types";
+    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
+      if (Optional<StringRef> tform = resolver.getAttrTransformer())
+        body << tgfmt(*tform,
+                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
+      else
+        body << attr->name << "Attr.getType()";
     } else {
       body << curVar << "Types";
     }
@@ -1323,6 +1351,8 @@
   /// properly resolve the type of a variable.
   struct TypeResolutionInstance {
     const NamedTypeConstraint *type;
     Optional<StringRef> transformer;
+    /// If non-null, the type is resolved from this attribute instead.
+    const NamedAttribute *attr = nullptr;
   };
 
@@ -1361,10 +1391,18 @@
   void handleSameTypesConstraint(
       llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
       bool includeResults);
+  /// Check for inferable type resolution from a TypesMatchWith constraint,
+  /// based on another operand, result, or attribute.
+  void handleTypesMatchConstraint(
+      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+      const llvm::Record &def);
 
   /// Returns an argument with the given name that has been seen within the
   /// format.
   const NamedTypeConstraint *findSeenArg(StringRef name);
+  /// Returns an attribute with the given name that has been seen within the
+  /// format.
+  const NamedAttribute *findSeenAttr(StringRef name);
 
   /// Parse a specific element.
   LogicalResult parseElement(std::unique_ptr<Element> &element,
@@ -1473,9 +1511,7 @@
     } else if (def.getName() == "SameOperandsAndResultType") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
-      if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
-        variableTyResolver[def.getValueAsString("rhs")] = {
-            lhsArg, def.getValueAsString("transformer")};
+      handleTypesMatchConstraint(variableTyResolver, def);
     }
   }
 
@@ -1584,8 +1620,11 @@
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.operandTypes[i].setVariable(varResolverIt->second.type,
-                                      varResolverIt->second.transformer);
+      const TypeResolutionInstance &resolver = varResolverIt->second;
+      if (resolver.type)
+        fmt.operandTypes[i].setVariable(resolver.type, resolver.transformer);
+      else if (resolver.attr)
+        fmt.operandTypes[i].setAttribute(resolver.attr, resolver.transformer);
       continue;
     }
 
@@ -1623,8 +1662,11 @@
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.resultTypes[i].setVariable(varResolverIt->second.type,
-                                     varResolverIt->second.transformer);
+      const TypeResolutionInstance &resolver = varResolverIt->second;
+      if (resolver.type)
+        fmt.resultTypes[i].setVariable(resolver.type, resolver.transformer);
+      else if (resolver.attr)
+        fmt.resultTypes[i].setAttribute(resolver.attr, resolver.transformer);
       continue;
     }
 
@@ -1708,6 +1750,18 @@
   }
 }
 
+void FormatParser::handleTypesMatchConstraint(
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+    const llvm::Record &def) {
+  StringRef lhsName = def.getValueAsString("lhs");
+  StringRef rhsName = def.getValueAsString("rhs");
+  StringRef transformer = def.getValueAsString("transformer");
+  if (const NamedTypeConstraint *arg = findSeenArg(lhsName))
+    variableTyResolver[rhsName] = {arg, transformer};
+  else if (const NamedAttribute *attr = findSeenAttr(lhsName))
+    variableTyResolver[rhsName] = {/*type=*/nullptr, transformer, attr};
+}
+
 const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
   if (auto *arg = findArg(op.getOperands(), name))
     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
@@ -1716,6 +1770,12 @@
   return nullptr;
 }
 
+const NamedAttribute *FormatParser::findSeenAttr(StringRef name) {
+  if (const auto *attr = findArg(op.getAttributes(), name))
+    return seenAttrs.count(attr) ? attr : nullptr;
+  return nullptr;
+}
+
 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
                                          bool isTopLevel) {
   // Directives.