// Patch summary: lets the declarative assembly format parse/print attributes whose TableGen `valueType` has a builder call without spelling the type (typed `parseAttribute` + new `printAttributeWithoutType`), lets elements attributes (dense/opaque/sparse) take a caller-provided type instead of a trailing `: type`, and moves ReduceReturnOp, the SPIR-V BitFieldExtract ops and several no-I/O SPIR-V ops from hand-written parsers/printers to `assemblyFormat`.
// NOTE(review): in Parser::parseDenseElementsAttr, `typeLoc` is now captured before parseElementsLiteralType consumes the optional ':', so type diagnostics for dense literals point at the ':' token rather than the type — confirm this is acceptable.
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -249,6 +249,7 @@ }]; let arguments = (ins AnyType:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; } def TerminatorOp : Loop_Op<"terminator", [Terminator]> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td @@ -23,7 +23,8 @@ [NoSideEffect, SameOperandsAndResultType])>; class SPV_BitFieldExtractOp<string mnemonic, list<OpTrait> traits = []> : - SPV_Op<mnemonic, !listconcat(traits, [NoSideEffect])> { + SPV_Op<mnemonic, !listconcat(traits, + [NoSideEffect, AllTypesMatch<["base", "result"]>])> { let arguments = (ins SPV_ScalarOrVectorOf<SPV_Integer>:$base, SPV_Integer:$offset, @@ -34,9 +35,11 @@ SPV_ScalarOrVectorOf<SPV_Integer>:$result ); - let parser = [{ return ::parseBitFieldExtractOp(parser, result); }]; - let printer = [{ ::printBitFieldExtractOp(this->getOperation(), p); }]; - let verifier = [{ return ::verifyBitFieldExtractOp(this->getOperation()); }]; + let verifier = [{ return success(); }]; + + let assemblyFormat = [{ + operands attr-dict `:` type($base) `,` type($offset) `,` type($count) + }]; } class SPV_BitUnaryOp<string mnemonic, list<OpTrait> traits = []> : diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -333,8 +333,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; let hasOpcode = 0; @@ -360,8 +359,7 @@ let results = (outs); - let parser = 
[{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; } // ----- @@ -383,8 +381,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -365,8 +365,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; let verifier = [{ return success(); }]; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -629,6 +629,9 @@ // Requires a constBuilderCall defined. string defaultValue = ?; + // The value type of this attribute. + Type valueType = ?; + // Whether the attribute is optional. Typically requires a custom // convertFromStorage method to handle the case where the attribute is // not present. @@ -681,14 +684,15 @@ //===----------------------------------------------------------------------===// // Primitive attribute kinds -// A generic attribute that must be constructed around a specific type +// A generic attribute that must be constructed around a specific buildable type // `attrValType`. Backed by MLIR attribute kind `attrKind`.
-class TypedAttrBase<BuildableType attrValType, string attrKind, - Pred condition, string descr> : +class TypedAttrBase<Type attrValType, string attrKind, Pred condition, + string descr> : Attr<condition, descr> { let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = attrKind; + let valueType = attrValType; } // Any attribute. @@ -1227,6 +1231,7 @@ let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = attr.defaultValue; + let valueType = attr.valueType; let isOptional = attr.isOptional; let baseAttr = attr; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -58,6 +58,10 @@ virtual void printType(Type type) = 0; virtual void printAttribute(Attribute attr) = 0; + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + virtual void printAttributeWithoutType(Attribute attr) = 0; + /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -25,6 +25,7 @@ namespace mlir { namespace tblgen { +class Type; // Wrapper class with helper methods for accessing attribute constraints defined // in TableGen. @@ -54,6 +55,10 @@ // Returns the return type for this attribute. StringRef getReturnType() const; + // Return the type constraint corresponding to the type of this attribute, or + // None if this is not a TypedAttr.
+ llvm::Optional<Type> getValueType() const; + // Returns the template getter method call which reads this attribute's // storage and returns the value as of the desired return type. // The call will contain a `{0}` which will be expanded to this attribute. diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -418,22 +418,6 @@ return success(); } -static ParseResult parseReduceReturnOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operand; - Type resultType; - if (parser.parseOperand(operand) || parser.parseColonType(resultType) || - parser.resolveOperand(operand, resultType, result.operands)) - return failure(); - - return success(); -} - -static void print(OpAsmPrinter &p, ReduceReturnOp op) { - p << op.getOperationName() << " " << op.result() << " : " - << op.result().getType(); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -488,41 +488,6 @@ // Common parsers and printers //===----------------------------------------------------------------------===// -static ParseResult parseBitFieldExtractOp(OpAsmParser &parser, - OperationState &state) { - SmallVector<OpAsmParser::OperandType, 3> operandInfo; - Type baseType; - Type offsetType; - Type countType; - auto loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(operandInfo, 3) || parser.parseColon() || - parser.parseType(baseType) || parser.parseComma() || - parser.parseType(offsetType) || parser.parseComma() || - parser.parseType(countType) || - parser.resolveOperands(operandInfo, {baseType, offsetType, countType}, - loc, 
state.operands)) { - return failure(); - } - state.addTypes(baseType); - return success(); -} - -static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName() << ' ' << op->getOperands() << " : " - << op->getOperandTypes(); -} - -static LogicalResult verifyBitFieldExtractOp(Operation *op) { - if (op->getOperand(0).getType() != op->getResult(0).getType()) { - return op->emitError("expected the same type for the first operand and " - "result, but provided ") - << op->getOperand(0).getType() << " and " - << op->getResult(0).getType(); - } - return success(); -} - // Parses an atomic update op. If the update op does not take a value (like // AtomicIIncrement) `hasValue` must be false. static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, @@ -668,19 +633,6 @@ return success(); } -// Parses an op that has no inputs and no outputs. -static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { - if (parser.parseOptionalAttrDict(state.attributes)) - return failure(); - return success(); -} - -// Prints an op that has no inputs and no outputs. -static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName(); - printer.printOptionalAttrDict(op->getAttrs()); -} - static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) { OpAsmParser::OperandType operandInfo; Type type; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -823,10 +823,21 @@ mlir::interleaveComma(c, os, each_fn); } - /// Print the given attribute. If 'mayElideType' is true, some attributes are - /// printed without the type when the type matches the default used in the - /// parser (for example i64 is the default for integer attributes). - void printAttribute(Attribute attr, bool mayElideType = false); + /// This enum describes the different kinds of elision for the type of an + /// attribute when printing it.
+ enum class AttrTypeElision { + /// The type must not be elided. + Never, + /// The type may be elided when it matches the default used in the parser + /// (for example i64 is the default for integer attributes). + May, + /// The type must be elided. + Must + }; + + /// Print the given attribute. + void printAttribute(Attribute attr, + AttrTypeElision typeElision = AttrTypeElision::Never); void printType(Type type); void printLocation(LocationAttr loc); @@ -1150,7 +1161,8 @@ os << R"(opaque<"", "0xDEADBEEF">)"; } -void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { +void ModulePrinter::printAttribute(Attribute attr, + AttrTypeElision typeElision) { if (!attr) { os << "<<NULL ATTRIBUTE>>"; return; @@ -1165,6 +1177,7 @@ } } + auto attrType = attr.getType(); switch (attr.getKind()) { default: return printDialectAttribute(attr); @@ -1201,12 +1214,11 @@ case StandardAttributes::Integer: { auto intAttr = attr.cast<IntegerAttr>(); // Print all integer attributes as signed unless i1. - bool isSigned = intAttr.getType().isIndex() || - intAttr.getType().getIntOrFloatBitWidth() != 1; + bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1; intAttr.getValue().print(os, isSigned); // IntegerAttr elides the type if I64. - if (mayElideType && intAttr.getType().isInteger(64)) + if (typeElision == AttrTypeElision::May && attrType.isInteger(64)) return; break; } @@ -1215,7 +1227,7 @@ printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. - if (mayElideType && floatAttr.getType().isF64()) + if (typeElision == AttrTypeElision::May && attrType.isF64()) return; break; } @@ -1227,7 +1239,7 @@ case StandardAttributes::Array: os << '['; interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) { - printAttribute(attr, /*mayElideType=*/true); + printAttribute(attr, AttrTypeElision::May); }); os << ']'; break; @@ -1304,9 +1316,8 @@ break; } - // Print the type if it isn't a 'none' type.
- auto attrType = attr.getType(); - if (!attrType.isa<NoneType>()) { + // Don't print the type if we must elide it, or if it is a None type. + if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) { os << " : "; printType(attrType); } @@ -1869,6 +1880,12 @@ ModulePrinter::printAttribute(attr); } + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + void printAttributeWithoutType(Attribute attr) override { + ModulePrinter::printAttribute(attr, AttrTypeElision::Must); + } + /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -285,14 +285,14 @@ Attribute parseDecOrHexAttr(Type type, bool isNegative); /// Parse an opaque elements attribute. - Attribute parseOpaqueElementsAttr(); + Attribute parseOpaqueElementsAttr(Type attrType); /// Parse a dense elements attribute. - Attribute parseDenseElementsAttr(); - ShapedType parseElementsLiteralType(); + Attribute parseDenseElementsAttr(Type attrType); + ShapedType parseElementsLiteralType(Type type); /// Parse a sparse elements attribute. - Attribute parseSparseElementsAttr(); + Attribute parseSparseElementsAttr(Type attrType); //===--------------------------------------------------------------------===// // Location Parsing @@ -1505,7 +1505,7 @@ // Parse a dense elements attribute. case Token::kw_dense: - return parseDenseElementsAttr(); + return parseDenseElementsAttr(type); // Parse a dictionary attribute. case Token::l_brace: { @@ -1543,11 +1543,11 @@ // Parse an opaque elements attribute. case Token::kw_opaque: - return parseOpaqueElementsAttr(); + return parseOpaqueElementsAttr(type); // Parse a sparse elements attribute.
case Token::kw_sparse: - return parseSparseElementsAttr(); + return parseSparseElementsAttr(type); // Parse a string attribute. case Token::string: { @@ -1783,7 +1783,7 @@ } /// Parse an opaque elements attribute. -Attribute Parser::parseOpaqueElementsAttr() { +Attribute Parser::parseOpaqueElementsAttr(Type attrType) { consumeToken(Token::kw_opaque); if (parseToken(Token::less, "expected '<' after 'opaque'")) return nullptr; @@ -1816,11 +1816,10 @@ return (emitError("opaque string only contains hex digits"), nullptr); consumeToken(Token::string); - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; @@ -2086,7 +2085,7 @@ } /// Parse a dense elements attribute. -Attribute Parser::parseDenseElementsAttr() { +Attribute Parser::parseDenseElementsAttr(Type attrType) { consumeToken(Token::kw_dense); if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; @@ -2096,12 +2095,11 @@ if (literalParser.parse()) return nullptr; - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; auto typeLoc = getToken().getLoc(); - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; return literalParser.getAttr(typeLoc, type); @@ -2112,10 +2110,14 @@ /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType() { - auto type = parseType(); - if (!type) - return nullptr; +ShapedType Parser::parseElementsLiteralType(Type type) { + // If the user didn't provide a type, parse the colon type for the literal.
+ if (!type) { + if (parseToken(Token::colon, "expected ':'")) + return nullptr; + if (!(type = parseType())) + return nullptr; + } if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { emitError("elements literal must be a ranked tensor or vector type"); @@ -2130,7 +2132,7 @@ } /// Parse a sparse elements attribute. -Attribute Parser::parseSparseElementsAttr() { +Attribute Parser::parseSparseElementsAttr(Type attrType) { consumeToken(Token::kw_sparse); if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; @@ -2150,11 +2152,10 @@ if (valuesParser.parse()) return nullptr; - if (parseToken(Token::greater, "expected '>'") || - parseToken(Token::colon, "expected ':'")) + if (parseToken(Token::greater, "expected '>'")) return nullptr; - auto type = parseElementsLiteralType(); + auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -75,6 +75,14 @@ return getValueAsString(init); } +// Return the type constraint corresponding to the type of this attribute, or +// None if this is not a TypedAttr.
+llvm::Optional<tblgen::Type> tblgen::Attribute::getValueType() const { + if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType"))) + return tblgen::Type(defInit->getDef()); + return llvm::None; +} + StringRef tblgen::Attribute::getConvertFromStorageCall() const { const auto *init = def->getValueInit("convertFromStorage"); return getValueAsString(init); diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -257,7 +257,7 @@ // ----- func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> { - // expected-error @+1 {{expected the same type for the first operand and result, but provided 'vector<3xi32>' and 'vector<4xi32>'}} + // expected-error @+1 {{failed to verify that all of {base, result} have same type}} %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32> spv.ReturnValue %0 : vector<4xi32> } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -268,9 +268,10 @@ /// /// {0}: The storage type of the attribute. /// {1}: The name of the attribute. +/// {2}: Either empty, or `, <type builder call>` supplying the attribute's buildable type to parseAttribute. const char *const attrParserCode = R"( {0} {1}Attr; - if (parser.parseAttribute({1}Attr, "{1}", result.attributes)) + if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes)) return failure(); )"; @@ -368,6 +369,10 @@ OpMethod::MP_Static); auto &body = method.body(); + // A format context used when parsing attributes with buildable types. + FmtContext attrTypeCtx; + attrTypeCtx.withBuilder("parser.getBuilder()"); + // Generate parsers for each of the elements. for (auto &element : elements) { /// Literals. @@ -377,7 +382,19 @@ /// Arguments.
} else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) { const NamedAttribute *var = attr->getVar(); - body << formatv(attrParserCode, var->attr.getStorageType(), var->name); + + // If this attribute has a buildable type, use that when parsing the + // attribute. + std::string attrTypeStr; + if (Optional<Type> attrType = var->attr.getValueType()) { + if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) { + llvm::raw_string_ostream os(attrTypeStr); + os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + } + } + + body << formatv(attrParserCode, var->attr.getStorageType(), var->name, + attrTypeStr); } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) { bool isVariadic = operand->getVar()->isVariadic(); body << formatv(isVariadic ? variadicOperandParserCode @@ -615,7 +632,14 @@ shouldEmitSpace = true; if (auto *attr = dyn_cast<AttributeVariable>(element.get())) { - body << " p << " << attr->getVar()->name << "Attr();\n"; + const NamedAttribute *var = attr->getVar(); + + // Elide the attribute type if it is buildable. + Optional<Type> attrType = var->attr.getValueType(); + if (attrType && attrType->getBuilderCall()) + body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + else + body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) { body << " p << " << operand->getVar()->name << "();\n"; } else if (isa<OperandsDirective>(element.get())) {