diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -616,6 +616,9 @@ an argument(attribute or operand), result, etc. In the `CallOp` example above, the variables would be `$callee` and `$args`. +Attribute variables are printed with their respective value type, unless that +value type is buildable. In those cases, the type of the attribute is elided. + #### Requirements The format specification has a certain set of requirements that must be adhered diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -629,6 +629,10 @@ // Requires a constBuilderCall defined. string defaultValue = ?; + // The value type of this attribute. This corresponds to the mlir::Type that + // this attribute returns via `getType()`. + Type valueType = ?; + // Whether the attribute is optional. Typically requires a custom // convertFromStorage method to handle the case where the attribute is // not present. @@ -660,6 +664,7 @@ let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = val; + let valueType = attr.valueType; let baseAttr = attr; } @@ -673,6 +678,7 @@ let returnType = "Optional<" # attr.returnType #">"; let convertFromStorage = "$_self ? " # returnType # "(" # attr.convertFromStorage # ") : (llvm::None)"; + let valueType = attr.valueType; let isOptional = 1; let baseAttr = attr; @@ -681,14 +687,15 @@ //===----------------------------------------------------------------------===// // Primitive attribute kinds -// A generic attribute that must be constructed around a specific type +// A generic attribute that must be constructed around a specific buildable type // `attrValType`. Backed by MLIR attribute kind `attrKind`. 
-class TypedAttrBase : +class TypedAttrBase : Attr { let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = attrKind; + let valueType = attrValType; } // Any attribute. @@ -1227,6 +1234,7 @@ let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = attr.defaultValue; + let valueType = attr.valueType; let isOptional = attr.isOptional; let baseAttr = attr; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -58,6 +58,10 @@ virtual void printType(Type type) = 0; virtual void printAttribute(Attribute attr) = 0; + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + virtual void printAttributeWithoutType(Attribute attr) = 0; + /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -25,6 +25,7 @@ namespace mlir { namespace tblgen { +class Type; // Wrapper class with helper methods for accessing attribute constraints defined // in TableGen. @@ -54,6 +55,10 @@ // Returns the return type for this attribute. StringRef getReturnType() const; + // Return the type constraint corresponding to the type of this attribute, or + // None if this is not a TypedAttr. + llvm::Optional getValueType() const; + // Returns the template getter method call which reads this attribute's // storage and returns the value as of the desired return type. // The call will contain a `{0}` which will be expanded to this attribute. 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -823,10 +823,21 @@ mlir::interleaveComma(c, os, each_fn); } - /// Print the given attribute. If 'mayElideType' is true, some attributes are - /// printed without the type when the type matches the default used in the - /// parser (for example i64 is the default for integer attributes). - void printAttribute(Attribute attr, bool mayElideType = false); + /// This enum describes the different kinds of elision for the type of an + /// attribute when printing it. + enum class AttrTypeElision { + /// The type must not be elided. + Never, + /// The type may be elided when it matches the default used in the parser + /// (for example i64 is the default for integer attributes). + May, + /// The type must be elided. + Must + }; + + /// Print the given attribute. + void printAttribute(Attribute attr, + AttrTypeElision typeElision = AttrTypeElision::Never); void printType(Type type); void printLocation(LocationAttr loc); @@ -1150,7 +1161,8 @@ os << R"(opaque<"", "0xDEADBEEF">)"; } -void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { +void ModulePrinter::printAttribute(Attribute attr, + AttrTypeElision typeElision) { if (!attr) { os << "<>"; return; @@ -1165,6 +1177,7 @@ } } + auto attrType = attr.getType(); switch (attr.getKind()) { default: return printDialectAttribute(attr); @@ -1201,12 +1214,11 @@ case StandardAttributes::Integer: { auto intAttr = attr.cast(); // Print all integer attributes as signed unless i1. - bool isSigned = intAttr.getType().isIndex() || - intAttr.getType().getIntOrFloatBitWidth() != 1; + bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1; intAttr.getValue().print(os, isSigned); // IntegerAttr elides the type if I64.
- if (mayElideType && intAttr.getType().isInteger(64)) + if (typeElision == AttrTypeElision::May && attrType.isInteger(64)) return; break; } @@ -1215,7 +1227,7 @@ printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. - if (mayElideType && floatAttr.getType().isF64()) + if (typeElision == AttrTypeElision::May && attrType.isF64()) return; break; } @@ -1227,7 +1239,7 @@ case StandardAttributes::Array: os << '['; interleaveComma(attr.cast().getValue(), [&](Attribute attr) { - printAttribute(attr, /*mayElideType=*/true); + printAttribute(attr, AttrTypeElision::May); }); os << ']'; break; @@ -1304,9 +1316,8 @@ break; } - // Print the type if it isn't a 'none' type. - auto attrType = attr.getType(); - if (!attrType.isa()) { + // Don't print the type if we must elide it, or if it is a None type. + if (typeElision != AttrTypeElision::Must && !attrType.isa()) { os << " : "; printType(attrType); } @@ -1869,6 +1880,12 @@ ModulePrinter::printAttribute(attr); } + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + void printAttributeWithoutType(Attribute attr) override { + ModulePrinter::printAttribute(attr, AttrTypeElision::Must); + } + /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -285,14 +285,14 @@ Attribute parseDecOrHexAttr(Type type, bool isNegative); /// Parse an opaque elements attribute. - Attribute parseOpaqueElementsAttr(); + Attribute parseOpaqueElementsAttr(Type attrType); /// Parse a dense elements attribute. - Attribute parseDenseElementsAttr(); - ShapedType parseElementsLiteralType(); + Attribute parseDenseElementsAttr(Type attrType); + ShapedType parseElementsLiteralType(Type type); /// Parse a sparse elements attribute. 
- Attribute parseSparseElementsAttr(); + Attribute parseSparseElementsAttr(Type attrType); //===--------------------------------------------------------------------===// // Location Parsing @@ -1505,7 +1505,7 @@ // Parse a dense elements attribute. case Token::kw_dense: - return parseDenseElementsAttr(); + return parseDenseElementsAttr(type); // Parse a dictionary attribute. case Token::l_brace: { @@ -1543,11 +1543,11 @@ // Parse an opaque elements attribute. case Token::kw_opaque: - return parseOpaqueElementsAttr(); + return parseOpaqueElementsAttr(type); // Parse a sparse elements attribute. case Token::kw_sparse: - return parseSparseElementsAttr(); + return parseSparseElementsAttr(type); // Parse a string attribute. case Token::string: { @@ -1783,7 +1783,7 @@ } /// Parse an opaque elements attribute. -Attribute Parser::parseOpaqueElementsAttr() { +Attribute Parser::parseOpaqueElementsAttr(Type attrType) { consumeToken(Token::kw_opaque); if (parseToken(Token::less, "expected '<' after 'opaque'")) return nullptr; @@ -1816,11 +1816,10 @@ return (emitError("opaque string only contains hex digits"), nullptr); consumeToken(Token::string); - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; @@ -2086,7 +2085,7 @@ } /// Parse a dense elements attribute. 
-Attribute Parser::parseDenseElementsAttr() { +Attribute Parser::parseDenseElementsAttr(Type attrType) { consumeToken(Token::kw_dense); if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; @@ -2096,12 +2095,11 @@ if (literalParser.parse()) return nullptr; - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; auto typeLoc = getToken().getLoc(); - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; return literalParser.getAttr(typeLoc, type); @@ -2112,10 +2110,14 @@ /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType() { - auto type = parseType(); - if (!type) - return nullptr; +ShapedType Parser::parseElementsLiteralType(Type type) { + // If the user didn't provide a type, parse the colon type for the literal. + if (!type) { + if (parseToken(Token::colon, "expected ':'")) + return nullptr; + if (!(type = parseType())) + return nullptr; + } if (!type.isa() && !type.isa()) { emitError("elements literal must be a ranked tensor or vector type"); @@ -2130,7 +2132,7 @@ } /// Parse a sparse elements attribute. 
-Attribute Parser::parseSparseElementsAttr() { +Attribute Parser::parseSparseElementsAttr(Type attrType) { consumeToken(Token::kw_sparse); if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; @@ -2150,11 +2152,10 @@ if (valuesParser.parse()) return nullptr; - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -75,6 +75,14 @@ return getValueAsString(init); } +// Return the type constraint corresponding to the type of this attribute, or +// None if this is not a TypedAttr. +llvm::Optional tblgen::Attribute::getValueType() const { + if (auto *defInit = dyn_cast(def->getValueInit("valueType"))) + return tblgen::Type(defInit->getDef()); + return llvm::None; +} + StringRef tblgen::Attribute::getConvertFromStorageCall() const { const auto *init = def->getValueInit("convertFromStorage"); return getValueAsString(init); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -268,9 +268,10 @@ /// /// {0}: The storage type of the attribute. /// {1}: The name of the attribute. +/// {2}: Optional ', <type>' argument; empty unless the type is buildable. const char *const attrParserCode = R"( {0} {1}Attr; - if (parser.parseAttribute({1}Attr, "{1}", result.attributes)) + if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes)) return failure(); )"; @@ -368,6 +369,10 @@ OpMethod::MP_Static); auto &body = method.body(); + // A format context used when parsing attributes with buildable types.
+ FmtContext attrTypeCtx; + attrTypeCtx.withBuilder("parser.getBuilder()"); + // Generate parsers for each of the elements. for (auto &element : elements) { /// Literals. @@ -377,7 +382,19 @@ /// Arguments. } else if (auto *attr = dyn_cast(element.get())) { const NamedAttribute *var = attr->getVar(); - body << formatv(attrParserCode, var->attr.getStorageType(), var->name); + + // If this attribute has a buildable type, use that when parsing the + // attribute. + std::string attrTypeStr; + if (Optional attrType = var->attr.getValueType()) { + if (Optional typeBuilder = attrType->getBuilderCall()) { + llvm::raw_string_ostream os(attrTypeStr); + os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + } + } + + body << formatv(attrParserCode, var->attr.getStorageType(), var->name, + attrTypeStr); } else if (auto *operand = dyn_cast(element.get())) { bool isVariadic = operand->getVar()->isVariadic(); body << formatv(isVariadic ? variadicOperandParserCode @@ -615,7 +632,14 @@ shouldEmitSpace = true; if (auto *attr = dyn_cast(element.get())) { - body << " p << " << attr->getVar()->name << "Attr();\n"; + const NamedAttribute *var = attr->getVar(); + + // Elide the attribute type if it is buildable. + Optional attrType = var->attr.getValueType(); + if (attrType && attrType->getBuilderCall()) + body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + else + body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element.get())) { body << " p << " << operand->getVar()->name << "();\n"; } else if (isa(element.get())) {