diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -249,6 +249,7 @@ }]; let arguments = (ins AnyType:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; } def TerminatorOp : Loop_Op<"terminator", [Terminator]> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td @@ -23,7 +23,8 @@ [NoSideEffect, SameOperandsAndResultType])>; class SPV_BitFieldExtractOp traits = []> : - SPV_Op { + SPV_Op])> { let arguments = (ins SPV_ScalarOrVectorOf:$base, SPV_Integer:$offset, @@ -34,9 +35,11 @@ SPV_ScalarOrVectorOf:$result ); - let parser = [{ return ::parseBitFieldExtractOp(parser, result); }]; - let printer = [{ ::printBitFieldExtractOp(this->getOperation(), p); }]; - let verifier = [{ return ::verifyBitFieldExtractOp(this->getOperation()); }]; + let verifier = [{ return success(); }]; + + let assemblyFormat = [{ + operands attr-dict `:` type($base) `,` type($offset) `,` type($count) + }]; } class SPV_BitUnaryOp traits = []> : diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -333,8 +333,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; let hasOpcode = 0; @@ -360,8 +359,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; } // ----- @@ -383,8 +381,7 @@ let results = (outs); - let parser = [{ return 
parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -365,8 +365,7 @@ let results = (outs); - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; + let assemblyFormat = "attr-dict"; let verifier = [{ return success(); }]; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -629,6 +629,9 @@ // Requires a constBuilderCall defined. string defaultValue = ?; + // The value type of this attribute. + Type valueType = ?; + // Whether the attribute is optional. Typically requires a custom // convertFromStorage method to handle the case where the attribute is // not present. @@ -681,14 +684,15 @@ //===----------------------------------------------------------------------===// // Primitive attribute kinds -// A generic attribute that must be constructed around a specific type +// A generic attribute that must be constructed around a specific buildable type // `attrValType`. Backed by MLIR attribute kind `attrKind`. -class TypedAttrBase : +class TypedAttrBase : Attr { let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = attrKind; + let valueType = attrValType; } // Any attribute. 
@@ -1227,6 +1231,7 @@ let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = attr.defaultValue; + let valueType = attr.valueType; let isOptional = attr.isOptional; let baseAttr = attr; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -58,6 +58,10 @@ virtual void printType(Type type) = 0; virtual void printAttribute(Attribute attr) = 0; + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + virtual void printAttributeWithoutType(Attribute attr) = 0; + /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -25,6 +25,7 @@ namespace mlir { namespace tblgen { +class Type; // Wrapper class with helper methods for accessing attribute constraints defined // in TableGen. @@ -54,6 +55,10 @@ // Returns the return type for this attribute. StringRef getReturnType() const; + // Return the type constraint corresponding to the type of this attribute, or + // None if this is not a TypedAttr. + llvm::Optional getValueType() const; + // Returns the template getter method call which reads this attribute's // storage and returns the value as of the desired return type. // The call will contain a `{0}` which will be expanded to this attribute. 
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -418,22 +418,6 @@ return success(); } -static ParseResult parseReduceReturnOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operand; - Type resultType; - if (parser.parseOperand(operand) || parser.parseColonType(resultType) || - parser.resolveOperand(operand, resultType, result.operands)) - return failure(); - - return success(); -} - -static void print(OpAsmPrinter &p, ReduceReturnOp op) { - p << op.getOperationName() << " " << op.result() << " : " - << op.result().getType(); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -488,41 +488,6 @@ // Common parsers and printers //===----------------------------------------------------------------------===// -static ParseResult parseBitFieldExtractOp(OpAsmParser &parser, - OperationState &state) { - SmallVector operandInfo; - Type baseType; - Type offsetType; - Type countType; - auto loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(operandInfo, 3) || parser.parseColon() || - parser.parseType(baseType) || parser.parseComma() || - parser.parseType(offsetType) || parser.parseComma() || - parser.parseType(countType) || - parser.resolveOperands(operandInfo, {baseType, offsetType, countType}, - loc, state.operands)) { - return failure(); - } - state.addTypes(baseType); - return success(); -} - -static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName() << ' ' << op->getOperands() << " : " - << op->getOperandTypes(); -} - -static 
LogicalResult verifyBitFieldExtractOp(Operation *op) { - if (op->getOperand(0).getType() != op->getResult(0).getType()) { - return op->emitError("expected the same type for the first operand and " - "result, but provided ") - << op->getOperand(0).getType() << " and " - << op->getResult(0).getType(); - } - return success(); -} - // Parses an atomic update op. If the update op does not take a value (like // AtomicIIncrement) `hasValue` must be false. static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, @@ -668,19 +633,6 @@ return success(); } -// Parses an op that has no inputs and no outputs. -static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { - if (parser.parseOptionalAttrDict(state.attributes)) - return failure(); - return success(); -} - -// Prints an op that has no inputs and no outputs. -static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName(); - printer.printOptionalAttrDict(op->getAttrs()); -} - static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) { OpAsmParser::OperandType operandInfo; Type type; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -823,10 +823,21 @@ mlir::interleaveComma(c, os, each_fn); } - /// Print the given attribute. If 'mayElideType' is true, some attributes are - /// printed without the type when the type matches the default used in the - /// parser (for example i64 is the default for integer attributes). - void printAttribute(Attribute attr, bool mayElideType = false); + /// This enum describes the different kinds of elision for the type of an + /// attribute when printing it. + enum class AttrTypeElision { + /// The type must not be elided. + Never, + /// The type may be elided when it matches the default used in the parser + /// (for example i64 is the default for integer attributes). + May, + /// The type must be elided.
+ Must + }; + + /// Print the given attribute. + void printAttribute(Attribute attr, + AttrTypeElision typeElision = AttrTypeElision::Never); void printType(Type type); void printLocation(LocationAttr loc); @@ -1150,7 +1161,8 @@ os << R"(opaque<"", "0xDEADBEEF">)"; } -void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { +void ModulePrinter::printAttribute(Attribute attr, + AttrTypeElision typeElision) { if (!attr) { os << "<>"; return; @@ -1165,6 +1177,7 @@ } } + auto attrType = attr.getType(); switch (attr.getKind()) { default: return printDialectAttribute(attr); @@ -1201,12 +1214,11 @@ case StandardAttributes::Integer: { auto intAttr = attr.cast(); // Print all integer attributes as signed unless i1. - bool isSigned = intAttr.getType().isIndex() || - intAttr.getType().getIntOrFloatBitWidth() != 1; + bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1; intAttr.getValue().print(os, isSigned); // IntegerAttr elides the type if I64. - if (mayElideType && intAttr.getType().isInteger(64)) + if (typeElision == AttrTypeElision::May && attrType.isInteger(64)) return; break; } @@ -1215,7 +1227,7 @@ printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. - if (mayElideType && floatAttr.getType().isF64()) + if (typeElision == AttrTypeElision::May && attrType.isF64()) return; break; } @@ -1227,7 +1239,7 @@ case StandardAttributes::Array: os << '['; interleaveComma(attr.cast().getValue(), [&](Attribute attr) { - printAttribute(attr, /*mayElideType=*/true); + printAttribute(attr, AttrTypeElision::May); }); os << ']'; break; @@ -1304,9 +1316,8 @@ break; } - // Print the type if it isn't a 'none' type. - auto attrType = attr.getType(); - if (!attrType.isa()) { + // Don't print the type if we must elide it, or if it is a None type. 
+ if (typeElision != AttrTypeElision::Must && !attrType.isa()) { os << " : "; printType(attrType); } @@ -1869,6 +1880,12 @@ ModulePrinter::printAttribute(attr); } + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + void printAttributeWithoutType(Attribute attr) override { + ModulePrinter::printAttribute(attr, AttrTypeElision::Must); + } + /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -75,6 +75,14 @@ return getValueAsString(init); } +// Return the type constraint corresponding to the type of this attribute, or +// None if this is not a TypedAttr. +llvm::Optional tblgen::Attribute::getValueType() const { + if (auto *defInit = dyn_cast(def->getValueInit("valueType"))) + return tblgen::Type(defInit->getDef()); + return llvm::None; +} + StringRef tblgen::Attribute::getConvertFromStorageCall() const { const auto *init = def->getValueInit("convertFromStorage"); return getValueAsString(init); diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -257,7 +257,7 @@ // ----- func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> { - // expected-error @+1 {{expected the same type for the first operand and result, but provided 'vector<3xi32>' and 'vector<4xi32>'}} + // expected-error @+1 {{failed to verify that all of {base, result} have same type}} %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32> spv.ReturnValue %0 : vector<4xi32> } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ 
b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -268,9 +268,10 @@ /// /// {0}: The storage type of the attribute. /// {1}: The name of the attribute. +/// {2}: The type for the attribute. const char *const attrParserCode = R"( {0} {1}Attr; - if (parser.parseAttribute({1}Attr, "{1}", result.attributes)) + if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes)) return failure(); )"; @@ -368,6 +369,10 @@ OpMethod::MP_Static); auto &body = method.body(); + // A format context used when parsing attributes with buildable types. + FmtContext attrTypeCtx; + attrTypeCtx.withBuilder("parser.getBuilder()"); + // Generate parsers for each of the elements. for (auto &element : elements) { /// Literals. @@ -377,7 +382,19 @@ /// Arguments. } else if (auto *attr = dyn_cast(element.get())) { const NamedAttribute *var = attr->getVar(); - body << formatv(attrParserCode, var->attr.getStorageType(), var->name); + + // If this attribute has a buildable type, use that when parsing the + // attribute. + std::string attrTypeStr; + if (Optional attrType = var->attr.getValueType()) { + if (Optional typeBuilder = attrType->getBuilderCall()) { + llvm::raw_string_ostream os(attrTypeStr); + os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + } + } + + body << formatv(attrParserCode, var->attr.getStorageType(), var->name, + attrTypeStr); } else if (auto *operand = dyn_cast(element.get())) { bool isVariadic = operand->getVar()->isVariadic(); body << formatv(isVariadic ? variadicOperandParserCode @@ -615,7 +632,14 @@ shouldEmitSpace = true; if (auto *attr = dyn_cast(element.get())) { - body << " p << " << attr->getVar()->name << "Attr();\n"; + const NamedAttribute *var = attr->getVar(); + + // Elide the attribute type if it is buildable; the parser supplies it.
+ Optional attrType = var->attr.getValueType(); + if (attrType && attrType->getBuilderCall()) + body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + else + body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element.get())) { body << " p << " << operand->getVar()->name << "();\n"; } else if (isa(element.get())) {