diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -96,11 +96,16 @@ ValueRange operands) = 0; /// Print an optional arrow followed by a type list. - void printOptionalArrowTypeList(ArrayRef types) { - if (types.empty()) - return; + template + void printOptionalArrowTypeList(TypeRange &&types) { + if (types.begin() != types.end()) + printArrowTypeList(types); + } + template void printArrowTypeList(TypeRange &&types) { auto &os = getStream() << " -> "; - bool wrapped = types.size() != 1 || types[0].isa(); + + bool wrapped = !has_single_element(types) || + (*types.begin()).template isa(); if (wrapped) os << '('; interleaveComma(types, *this); @@ -110,23 +115,17 @@ /// Print the complete type of an operation in functional form. void printFunctionalType(Operation *op) { + printFunctionalType(op->getNonSuccessorOperands().getTypes(), + op->getResultTypes()); + } + /// Print the two given type ranges in a functional form. + template + void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { auto &os = getStream(); os << "("; - interleaveComma(op->getNonSuccessorOperands(), os, [&](Value operand) { - if (operand) - printType(operand.getType()); - else - os << "<"; - }); - os << ") -> "; - if (op->getNumResults() == 1 && - !op->getResult(0).getType().isa()) { - printType(op->getResult(0).getType()); - } else { - os << '('; - interleaveComma(op->getResultTypes(), os); - os << ')'; - } + interleaveComma(inputs, os); + os << ")"; + printArrowTypeList(results); } /// Print the given string as a symbol reference, i.e. a form representable by @@ -191,6 +190,10 @@ interleaveComma(types, p); return p; } +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef types) { + interleaveComma(types, p); + return p; +} //===----------------------------------------------------------------------===// // OpAsmParser @@ -489,6 +492,20 @@ return failure(); return success(); } + template + ParseResult resolveOperands(Operands &&operands, Types &&types, + llvm::SMLoc loc, SmallVectorImpl &result) { + size_t operandSize = std::distance(operands.begin(), operands.end()); + size_t typeSize = std::distance(types.begin(), types.end()); + if (operandSize != typeSize) + return emitError(loc) + << operandSize << " operands present, but expected " << typeSize; + + for (auto it : llvm::zip(operands, types)) + if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) + return failure(); + return success(); + } /// Parses an affine map attribute where dims and symbols are SSA operands. /// Operand values must come from single-result sources, and be valid @@ -557,6 +574,34 @@ /// Parse a type. virtual ParseResult parseType(Type &result) = 0; + /// Parse a type of a specific type. + template ParseResult parseType(TypeT &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of type. + Type type; + if (parseType(type)) + return failure(); + + // Check for the right kind of attribute. + result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of type specified"); + + return success(); + } + + /// Parse a type list. + ParseResult parseTypeList(SmallVectorImpl &result) { + do { + Type type; + if (parseType(type)) + return failure(); + result.push_back(type); + } while (succeeded(parseOptionalComma())); + return success(); + } + /// Parse an optional arrow followed by a type list. virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl &result) = 0; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -1062,4 +1062,64 @@ let results = (outs AnyType); } +//===----------------------------------------------------------------------===// +// Test Op Asm Format +//===----------------------------------------------------------------------===// + +def FormatLiteralOp : TEST_Op<"format_literal_op"> { + let assemblyFormat = [{ + `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` attr-dict + }]; +} + +// Test that we elide attributes that are within the syntax. +def FormatAttrOp : TEST_Op<"format_attr_op"> { + let arguments = (ins I64Attr:$attr); + let assemblyFormat = "$attr attr-dict"; +} + +// Test that we don't need to provide types in the format if they are buildable. +def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> { + let arguments = (ins I64:$buildable); + let results = (outs I64:$buildable_res); + let assemblyFormat = "$buildable attr-dict"; +} + +// Test various mixings of result type formatting. +class FormatResultBase : TEST_Op { + let results = (outs I64:$buildable_res, AnyMemRef:$result); + let assemblyFormat = fmt; +} +def FormatResultAOp : FormatResultBase<"format_result_a_op", [{ + type($result) attr-dict +}]>; +def FormatResultBOp : FormatResultBase<"format_result_b_op", [{ + type(results) attr-dict +}]>; +def FormatResultCOp : FormatResultBase<"format_result_c_op", [{ + functional-type($buildable_res, $result) attr-dict +}]>; + +// Test various mixings of operand type formatting. +class FormatOperandBase : TEST_Op { + let arguments = (ins I64:$buildable, AnyMemRef:$operand); + let assemblyFormat = fmt; +} + +def FormatOperandAOp : FormatOperandBase<"format_operand_a_op", [{ + operands `:` type(operands) attr-dict +}]>; +def FormatOperandBOp : FormatOperandBase<"format_operand_b_op", [{ + operands `:` type($operand) attr-dict +}]>; +def FormatOperandCOp : FormatOperandBase<"format_operand_c_op", [{ + $buildable `,` $operand `:` type(operands) attr-dict +}]>; +def FormatOperandDOp : FormatOperandBase<"format_operand_d_op", [{ + $buildable `,` $operand `:` type($operand) attr-dict +}]>; +def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{ + $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict +}]>; + #endif // TEST_OPS diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s + +// CHECK: %[[I64:.*]] = +%i64 = "foo.op"() : () -> (i64) +// CHECK: %[[MEMREF:.*]] = +%memref = "foo.op"() : () -> (memref<1xf64>) + +// CHECK: test.format_literal_op keyword_$. -> :, = <> () [] {foo.some_attr} +test.format_literal_op keyword_$. -> :, = <> () [] {foo.some_attr} + +// CHECK: test.format_attr_op 10 +// CHECK-NOT: {attr +test.format_attr_op 10 + +// CHECK: test.format_buildable_type_op %[[I64]] +%ignored = test.format_buildable_type_op %i64 + +// CHECK: test.format_result_a_op memref<1xf64> +%ignored_a:2 = test.format_result_a_op memref<1xf64> + +// CHECK: test.format_result_b_op i64, memref<1xf64> +%ignored_b:2 = test.format_result_b_op i64, memref<1xf64> + +// CHECK: test.format_result_c_op (i64) -> memref<1xf64> +%ignored_c:2 = test.format_result_c_op (i64) -> memref<1xf64> + +// CHECK: test.format_operand_a_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> +test.format_operand_a_op %i64, %memref : i64, memref<1xf64> + +// CHECK: test.format_operand_b_op %[[I64]], %[[MEMREF]] : memref<1xf64> +test.format_operand_b_op %i64, %memref : memref<1xf64> + +// CHECK: test.format_operand_c_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> +test.format_operand_c_op %i64, %memref : i64, memref<1xf64> + +// CHECK: test.format_operand_d_op %[[I64]], %[[MEMREF]] : memref<1xf64> +test.format_operand_d_op %i64, %memref : memref<1xf64> + +// CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> +test.format_operand_e_op %i64, %memref : i64, memref<1xf64> diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -207,6 +207,15 @@ buildableResultTypes.resize(op.getNumResults(), llvm::None); } + /// Generate the operation parser from this format. + void genParser(Operator &op, OpClass &opClass); + /// Generate the c++ to resolve the types of operands and results during + /// parsing. + void genParserTypeResolution(Operator &op, OpMethodBody &body); + + /// Generate the operation printer from this format. + void genPrinter(Operator &op, OpClass &opClass); + /// The various elements in this format. std::vector> elements; @@ -222,6 +231,366 @@ }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Parser Gen + +/// The code snippet used to generate a parser call for an attribute. +/// +/// {0}: The storage type of the attribute. +/// {1}: The name of the attribute. +const char *attrParserCode = R"( + {0} {1}Attr; + if (parser.parseAttribute({1}Attr, "{1}", result.attributes)) + return failure(); +)"; + +/// The code snippet used to generate a parser call for an operand. +/// +/// {0}: The name of the operand. +const char *variadicOperandParserCode = R"( + llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); + (void){0}OperandsLoc; + SmallVector {0}Operands; + if (parser.parseOperandList({0}Operands)) + return failure(); +)"; +const char *operandParserCode = R"( + llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); + (void){0}OperandsLoc; + OpAsmParser::OperandType {0}RawOperands[1]; + if (parser.parseOperand({0}RawOperands[0])) + return failure(); + ArrayRef {0}Operands({0}RawOperands); +)"; + +/// The code snippet used to generate a parser call for a type list. +/// +/// {0}: The name for the type list. +const char *variadicTypeParserCode = R"( + SmallVector {0}Types; + if (parser.parseTypeList({0}Types)) + return failure(); +)"; +const char *typeParserCode = R"( + Type {0}RawTypes[1] = {{nullptr}; + if (parser.parseType({0}RawTypes[0])) + return failure(); + ArrayRef {0}Types({0}RawTypes); +)"; + +/// The code snippet used to generate a parser call for a functional type. +/// +/// {0}: The name for the input type list. +/// {1}: The name for the result type list. +const char *functionalTypeParserCode = R"( + FunctionType {0}__{1}_functionType; + if (parser.parseType({0}__{1}_functionType)) + return failure(); + ArrayRef {0}Types = {0}__{1}_functionType.getInputs(); + ArrayRef {1}Types = {0}__{1}_functionType.getResults(); +)"; + +/// Get the name used for the type list for the given type directive operand. +/// 'isVariadic' is set to true if the operand has variadic types. +static StringRef getTypeListName(Element *arg, bool &isVariadic) { + if (auto *operand = dyn_cast(arg)) { + isVariadic = operand->getVar()->isVariadic(); + return operand->getVar()->name; + } + if (auto *result = dyn_cast(arg)) { + isVariadic = result->getVar()->isVariadic(); + return result->getVar()->name; + } + isVariadic = true; + if (isa(arg)) + return "fullOperand"; + if (isa(arg)) + return "fullResult"; + llvm_unreachable("unknown 'type' directive argument"); +} + +/// Generate the parser for a literal value. +static void genLiteralParser(StringRef value, OpMethodBody &body) { + body << " if (parser.parse"; + + // Handle the case of a keyword/identifier. + if (value.front() == '_' || isalpha(value.front())) { + body << "Keyword(\"" << value << "\"))\n return failure();\n"; + return; + } + body << (StringRef)llvm::StringSwitch(value) + .Case("->", "Arrow") + .Case(":", "Colon") + .Case(",", "Comma") + .Case("=", "Equal") + .Case("<", "Less") + .Case(">", "Greater") + .Case("(", "LParen") + .Case(")", "RParen") + .Case("[", "LSquare") + .Case("]", "RSquare") + << "())\n return failure();\n"; +} + +void OperationFormat::genParser(Operator &op, OpClass &opClass) { + auto &method = opClass.newMethod( + "ParseResult", "parse", "OpAsmParser &parser, OperationState &result", + OpMethod::MP_Static); + auto &body = method.body(); + + // Generate parsers for each of the elements. + for (auto &element : elements) { + /// Literals. + if (LiteralElement *literal = dyn_cast(element.get())) { + genLiteralParser(literal->getLiteral(), body); + + /// Arguments. + } else if (auto *attr = dyn_cast(element.get())) { + const NamedAttribute *var = attr->getVar(); + body << formatv(attrParserCode, var->attr.getStorageType(), var->name); + } else if (auto *operand = dyn_cast(element.get())) { + bool isVariadic = operand->getVar()->isVariadic(); + body << formatv(isVariadic ? variadicOperandParserCode + : operandParserCode, + operand->getVar()->name); + + /// Directives. + } else if (isa(element.get())) { + body << " if (parser.parseOptionalAttrDict(result.attributes))\n" + << " return failure();\n"; + } else if (isa(element.get())) { + body << " llvm::SMLoc fullOperandLoc = parser.getCurrentLocation();\n" + << " SmallVector fullOperands;\n" + << " if (parser.parseOperandList(fullOperands))\n" + << " return failure();\n"; + } else if (auto *dir = dyn_cast(element.get())) { + bool isVariadic = false; + StringRef listName = getTypeListName(dir->getOperand(), isVariadic); + body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode, + listName); + } else if (auto *dir = dyn_cast(element.get())) { + bool ignored = false; + body << formatv(functionalTypeParserCode, + getTypeListName(dir->getInputs(), ignored), + getTypeListName(dir->getResults(), ignored)); + } else { + llvm_unreachable("unknown format element"); + } + } + + // Generate the code to resolve the operand and result types now that they + // have been parsed. + genParserTypeResolution(op, body); + body << " return success();\n"; +} + +void OperationFormat::genParserTypeResolution(Operator &op, + OpMethodBody &body) { + // Initialize the set of buildable types. + for (auto &it : buildableTypes) + body << " Type odsBuildableType" << it.second << " = parser.getBuilder()." + << it.first << ";\n"; + + // Resolve each of the result types. + if (allResultTypes) { + body << " result.addTypes(fullResultTypes);\n"; + } else { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { + body << " result.addTypes("; + if (Optional val = buildableResultTypes[i]) + body << "odsBuildableType" << *val; + else + body << op.getResultName(i) << "Types"; + body << ");\n"; + } + } + + // Early exit if there are no operands. + if (op.getNumOperands() == 0) + return; + + // Flag indicating if operands were dumped all together in a group. + bool hasAllOperands = llvm::any_of( + elements, [](auto &elt) { return isa(elt.get()); }); + + // Handle the case where all operand types are in one group. + if (allOperandTypes) { + // If we have all operands together, use the full operand list directly. + if (hasAllOperands) { + body << " if (parser.resolveOperands(fullOperands, fullOperandTypes, " + "fullOperandLoc, result.operands))\n" + " return failure();\n"; + return; + } + + // Otherwise, use llvm::concat to merge the disjoint operand lists together. + // llvm::concat does not allow the case of a single range, so guard it here. + body << " if (parser.resolveOperands("; + if (op.getNumOperands() > 1) { + body << "llvm::concat("; + interleaveComma(op.getOperands(), body, [&](auto &operand) { + body << operand.name << "Operands"; + }); + body << ")"; + } else { + body << op.operand_begin()->name << "Operands"; + } + body << ", fullOperandTypes, parser.getNameLoc(), result.operands))\n" + << " return failure();\n"; + return; + } + // Handle the case where all of the operands were grouped together. + if (hasAllOperands) { + body << " if (parser.resolveOperands(fullOperands, "; + + // Group all of the operand types together to perform the resolution all at + // once. Use llvm::concat to perform the merge. llvm::concat does not allow + // the case of a single range, so guard it here. + if (op.getNumOperands() > 1) { + body << "llvm::concat("; + interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + if (Optional val = buildableOperandTypes[i]) + body << "ArrayRef(odsBuildableType" << *val << ")"; + else + body << op.getOperand(i).name << "Types"; + }); + body << ")"; + } else { + body << op.operand_begin()->name << "Types"; + } + + body << ", fullOperandLoc, result.operands))\n" + << " return failure();\n"; + return; + } + + // The final case is the one where each of the operands types are resolved + // separately. + for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { + NamedTypeConstraint &operand = op.getOperand(i); + body << " if (parser.resolveOperands(" << operand.name << "Operands, "; + if (Optional val = buildableOperandTypes[i]) + body << "odsBuildableType" << *val << ", "; + else + body << operand.name << "Types, " << operand.name << "OperandsLoc, "; + body << "result.operands))\n return failure();\n"; + } +} + +//===----------------------------------------------------------------------===// +// PrinterGen + +/// Generate the printer for the 'attr-dict' directive. +static void genAttrDictPrinter(OperationFormat &fmt, OpMethodBody &body) { + // Collect all of the attributes used in the format, these will be elided. + SmallVector usedAttributes; + for (auto &it : fmt.elements) + if (auto *attr = dyn_cast(it.get())) + usedAttributes.push_back(attr->getVar()); + + body << " p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"; + interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) { + body << "\"" << attr->name << "\""; + }); + body << "});\n"; +} + +/// Generate the printer for a literal value. `shouldEmitSpace` is true if a +/// space should be emitted before this element. `lastWasPunctuation` is true if +/// the previous element was a punctuation literal. +static void genLiteralPrinter(StringRef value, OpMethodBody &body, + bool &shouldEmitSpace, bool &lastWasPunctuation) { + body << " p"; + + // Don't insert a space for certain punctuation. + auto shouldPrintSpaceBeforeLiteral = [&] { + if (value.size() != 1 && value != "->") + return true; + if (lastWasPunctuation) + return !StringRef(">)}],").contains(value.front()); + return !StringRef("<>(){}[],").contains(value.front()); + }; + if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral()) + body << " << \" \""; + body << " << \"" << value << "\";\n"; + + // Insert a space after certain literals. + shouldEmitSpace = + value.size() != 1 || !StringRef("<({[").contains(value.front()); + lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); +} + +/// Generate the c++ for an operand to a (*-)type directive. +static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { + if (isa(arg)) + return body << "getOperation()->getOperandTypes()"; + if (isa(arg)) + return body << "getOperation()->getResultTypes()"; + auto *operand = dyn_cast(arg); + auto *var = operand ? operand->getVar() : cast(arg)->getVar(); + if (var->isVariadic()) + return body << var->name << "().getTypes()"; + return body << "ArrayRef(" << var->name << "().getType())"; +} + +void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { + auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); + auto &body = method.body(); + + // Emit the operation name, trimming the prefix if this is the standard + // dialect. + body << " p << \""; + std::string opName = op.getOperationName(); + if (op.getDialectName() == "std") + body << StringRef(opName).drop_front(4); + else + body << opName; + body << "\";\n"; + + // Flags for if we should emit a space, and if the last element was + // punctuation. + bool shouldEmitSpace = true, lastWasPunctuation = false; + for (auto &element : elements) { + // Emit a literal element. + if (LiteralElement *literal = dyn_cast(element.get())) { + genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + continue; + } + + // Emit the attribute dictionary. + if (isa(element.get())) { + genAttrDictPrinter(*this, body); + lastWasPunctuation = false; + continue; + } + + // Optionally insert a space before the next element. The AttrDict printer + // already adds a space as necessary. + if (shouldEmitSpace || !lastWasPunctuation) + body << " p << \" \";\n"; + lastWasPunctuation = false; + shouldEmitSpace = true; + + if (auto *attr = dyn_cast(element.get())) { + body << " p << " << attr->getVar()->name << "Attr();\n"; + } else if (auto *operand = dyn_cast(element.get())) { + body << " p << " << operand->getVar()->name << "();\n"; + } else if (isa(element.get())) { + body << " p << getOperation()->getOperands();\n"; + } else if (auto *dir = dyn_cast(element.get())) { + body << " p << "; + genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + } else if (auto *dir = dyn_cast(element.get())) { + body << " p.printFunctionalType("; + genTypeOperandPrinter(dir->getInputs(), body) << ", "; + genTypeOperandPrinter(dir->getResults(), body) << ");\n"; + } else { + llvm_unreachable("unknown format element"); + } + } +} + //===----------------------------------------------------------------------===// // FormatLexer //===----------------------------------------------------------------------===// @@ -796,4 +1165,8 @@ OperationFormat format(op); if (failed(FormatParser(mgr, format, op).parse())) return; + + // Generate the printer and parser based on the parsed format. + format.genParser(op, opClass); + format.genPrinter(op, opClass); }