diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1579,6 +1579,9 @@ // Custom printer. code printer = ?; + // Custom assembly format. + string assemblyFormat = ?; + // Custom verifier. code verifier = ?; diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -36,6 +36,10 @@ // Returns true if this is a variadic type constraint. bool isVariadic() const; + + // Returns the builder call for this constraint if this is a buildable type, + // returns None otherwise. + Optional getBuilderCall() const; }; // Wrapper class with helper methods for accessing Types defined in TableGen. diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -29,6 +29,18 @@ return def->isSubClassOf("Variadic"); } +// Returns the builder call for this constraint if this is a buildable type, +// returns None otherwise. +Optional TypeConstraint::getBuilderCall() const { + const llvm::Record *baseType = def; + if (isVariadic()) + baseType = baseType->getValueAsDef("baseType"); + + if (!baseType->isSubClassOf("BuildableType")) + return None; + return baseType->getValueAsString("builderCall"); +} + Type::Type(const llvm::Record *record) : TypeConstraint(record) {} StringRef Type::getTypeDescription() const { diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -0,0 +1,236 @@ +// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure + +// This file contains tests for the specification of the declarative op format. + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Directives +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// attr-dict + +// CHECK: error: format missing 'attr-dict' directive +def DirectiveAttrDictInvalidA : TestFormat_Op<"attrdict_invalid_a", [{ +}]>; +// CHECK: error: 'attr-dict' directive has already been seen +def DirectiveAttrDictInvalidB : TestFormat_Op<"attrdict_invalid_b", [{ + attr-dict attr-dict +}]>; +// CHECK: error: 'attr-dict' directive can only be used as a top-level directive +def DirectiveAttrDictInvalidC : TestFormat_Op<"attrdict_invalid_c", [{ + type(attr-dict) +}]>; +// CHECK-NOT: error +def DirectiveAttrDictValid : TestFormat_Op<"attrdict_valid", [{ + attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// functional-type + +// CHECK: error: 'functional-type' is only valid as a top-level directive +def DirectiveFunctionalTypeInvalidA : TestFormat_Op<"functype_invalid_a", [{ + functional-type(functional-type) +}]>; +// CHECK: error: expected '(' before argument list +def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{ + functional-type +}]>; +// CHECK: error: expected directive, literal, or variable +def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{ + functional-type( +}]>; +// CHECK: error: expected ',' after inputs argument +def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{ + functional-type(operands +}]>; +// CHECK: error: expected directive, literal, or variable +def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{ + functional-type(operands, +}]>; +// CHECK: error: expected ')' after argument list +def DirectiveFunctionalTypeInvalidF : TestFormat_Op<"functype_invalid_f", [{ + functional-type(operands, results +}]>; +// CHECK-NOT: error +def DirectiveFunctionalTypeValid : TestFormat_Op<"functype_invalid_a", [{ + functional-type(operands, results) attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// operands + +// CHECK: error: 'operands' directive creates overlap in format +def DirectiveOperandsInvalidA : TestFormat_Op<"operands_invalid_a", [{ + operands operands +}]>; +// CHECK: error: 'operands' directive creates overlap in format +def DirectiveOperandsInvalidB : TestFormat_Op<"operands_invalid_b", [{ + $operand operands +}]>, Arguments<(ins I64:$operand)>; +// CHECK-NOT: error: +def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{ + operands attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// results + +// CHECK: error: 'results' directive can not be used as a top-level directive +def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{ + results +}]>; + +//===----------------------------------------------------------------------===// +// type + +// CHECK: error: expected '(' before argument list +def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{ + type +}]>; +// CHECK: error: expected directive, literal, or variable +def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{ + type( +}]>; +// CHECK: error: expected ')' after argument list +def DirectiveTypeInvalidC : TestFormat_Op<"type_invalid_c", [{ + type(operands +}]>; +// CHECK-NOT: error: +def DirectiveTypeValid : TestFormat_Op<"type_valid", [{ + type(operands) attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// functional-type/type operands + +// CHECK: error: 'type' directive operand expects variable or directive operand +def DirectiveTypeZOperandInvalidA : TestFormat_Op<"type_operand_invalid_a", [{ + type(`literal`) +}]>; +// CHECK: error: 'operands' 'type' is already bound +def DirectiveTypeZOperandInvalidB : TestFormat_Op<"type_operand_invalid_b", [{ + type(operands) type(operands) +}]>; +// CHECK: error: 'operands' 'type' is already bound +def DirectiveTypeZOperandInvalidC : TestFormat_Op<"type_operand_invalid_c", [{ + type($operand) type(operands) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type' of 'operand' is already bound +def DirectiveTypeZOperandInvalidD : TestFormat_Op<"type_operand_invalid_d", [{ + type(operands) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type' of 'operand' is already bound +def DirectiveTypeZOperandInvalidE : TestFormat_Op<"type_operand_invalid_e", [{ + type($operand) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'results' 'type' is already bound +def DirectiveTypeZOperandInvalidF : TestFormat_Op<"type_operand_invalid_f", [{ + type(results) type(results) +}]>; +// CHECK: error: 'results' 'type' is already bound +def DirectiveTypeZOperandInvalidG : TestFormat_Op<"type_operand_invalid_g", [{ + type($result) type(results) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'type' of 'result' is already bound +def DirectiveTypeZOperandInvalidH : TestFormat_Op<"type_operand_invalid_h", [{ + type(results) type($result) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'type' of 'result' is already bound +def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{ + type($result) type($result) +}]>, Results<(outs I64:$result)>; +// CHECK-NOT: error: +def DirectiveTypeZOperandValid : TestFormat_Op<"type_operand_valid", [{ + type(operands) type(results) attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// Literals +//===----------------------------------------------------------------------===// + +// Test all of the valid literals. +// CHECK: error: expected valid literal +def LiteralInvalidA : TestFormat_Op<"literal_invalid_a", [{ + `1` +}]>; +// CHECK: error: unexpected end of file in literal +// CHECK: error: expected directive, literal, or variable +def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ + ` +}]>; +// CHECK-NOT: error +def LiteralValid : TestFormat_Op<"literal_valid", [{ + `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `->` `abc$._` + attr-dict +}]>; + +//===----------------------------------------------------------------------===// +// Variables +//===----------------------------------------------------------------------===// + +// CHECK: error: expected variable to refer to a argument or result +def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{ + $unknown_arg attr-dict +}]>; +// CHECK: error: attribute 'attr' is already bound +def VariableInvalidB : TestFormat_Op<"variable_invalid_b", [{ + $attr $attr attr-dict +}]>, Arguments<(ins I64Attr:$attr)>; +// CHECK: error: operand 'operand' is already bound +def VariableInvalidC : TestFormat_Op<"variable_invalid_c", [{ + $operand $operand attr-dict +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: operand 'operand' is already bound +def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{ + operands $operand attr-dict +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: results can not be used at the top level +def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{ + $result attr-dict +}]>, Results<(outs I64:$result)>; + +//===----------------------------------------------------------------------===// +// Coverage Checks +//===----------------------------------------------------------------------===// + +// CHECK: error: format missing instance of result #0('result') type +def ZCoverageInvalidA : TestFormat_Op<"variable_invalid_a", [{ + attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +// CHECK: error: format missing instance of operand #0('operand') +def ZCoverageInvalidB : TestFormat_Op<"variable_invalid_b", [{ + type($result) attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +// CHECK: error: format missing instance of operand #0('operand') type +def ZCoverageInvalidC : TestFormat_Op<"variable_invalid_c", [{ + $operand type($result) attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +// CHECK: error: format missing instance of operand #0('operand') type +def ZCoverageInvalidD : TestFormat_Op<"variable_invalid_d", [{ + operands attr-dict +}]>, Arguments<(ins Variadic:$operand)>; +// CHECK: error: format missing instance of result #0('result') type +def ZCoverageInvalidE : TestFormat_Op<"variable_invalid_e", [{ + attr-dict +}]>, Results<(outs Variadic:$result)>; +// CHECK-NOT: error +def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{ + $operand type($operand) type($result) attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +def ZCoverageValidB : TestFormat_Op<"variable_valid_b", [{ + $operand type(operands) type(results) attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; +def ZCoverageValidC : TestFormat_Op<"variable_valid_c", [{ + operands functional-type(operands, results) attr-dict +}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; + diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -10,6 +10,7 @@ mlir-tblgen.cpp OpDefinitionsGen.cpp OpDocGen.cpp + OpFormatGen.cpp OpInterfacesGen.cpp ReferenceImplGen.cpp RewriterGen.cpp diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "OpFormatGen.h" #include "mlir/Support/STLExtras.h" #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Format.h" @@ -306,6 +307,7 @@ genCanonicalizerDecls(); genFolderDecls(); genOpInterfaceMethods(); + generateOpFormat(op, opClass); } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { @@ -1065,7 +1067,8 @@ } void OpEmitter::genParser() { - if (!hasStringAttribute(def, "parser")) + if (!hasStringAttribute(def, "parser") || + hasStringAttribute(def, "assemblyFormat")) return; auto &method = opClass.newMethod( @@ -1078,6 +1081,9 @@ } void OpEmitter::genPrinter() { + if (hasStringAttribute(def, "assemblyFormat")) + return; + auto valueInit = def.getValueInit("printer"); CodeInit *codeInit = dyn_cast(valueInit); if (!codeInit) diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.h b/mlir/tools/mlir-tblgen/OpFormatGen.h new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/OpFormatGen.h @@ -0,0 +1,28 @@ +//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface for generating parsers and printers from the +// declarative format. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ +#define MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ + +namespace mlir { +namespace tblgen { +class OpClass; +class Operator; + +// Generate the assembly format for the given operator. +void generateOpFormat(const Operator &constOp, OpClass &opClass); + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -0,0 +1,800 @@ +//===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OpFormatGen.h" +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/OpClass.h" +#include "mlir/TableGen/OpInterfaces.h" +#include "mlir/TableGen/OpTrait.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +#define DEBUG_TYPE "mlir-tblgen-opformatgen" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Element +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a single format element. +class Element { +public: + enum class Kind { + /// This element is a directive. + AttrDictDirective, + FunctionalTypeDirective, + OperandsDirective, + ResultsDirective, + TypeDirective, + + /// This element is a literal. + Literal, + + /// This element is an variable value. + AttributeVariable, + OperandVariable, + ResultVariable, + }; + Element(Kind kind) : kind(kind) {} + virtual ~Element() = default; + + /// Return the kind of this element. + Kind getKind() const { return kind; } + +private: + /// The kind of this element. + Kind kind; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// VariableElement + +namespace { +/// This class represents an instance of an variable element. A variable refers +/// to something registered on the operation itself, e.g. an argument, result, +/// etc. +template +class VariableElement : public Element { +public: + VariableElement(const VarT *var) : Element(kindVal), var(var) {} + static bool classof(const Element *element) { + return element->getKind() == kindVal; + } + const VarT *getVar() { return var; } + +private: + const VarT *var; +}; + +/// This class represents a variable that refers to an attribute argument. +using AttributeVariable = + VariableElement; + +/// This class represents a variable that refers to an operand argument. +using OperandVariable = + VariableElement; + +/// This class represents a variable that refers to a result. +using ResultVariable = + VariableElement; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// DirectiveElement + +namespace { +/// This class implements single kind directives. +template +class DirectiveElement : public Element { +public: + DirectiveElement() : Element(type){}; + static bool classof(const Element *ele) { return ele->getKind() == type; } +}; +/// This class represents the `attr-dict` directive. This directive represents +/// the attribute dictionary of the operation. +using AttrDictDirective = DirectiveElement; + +/// This class represents the `operands` directive. This directive represents +/// all of the operands of an operation. +using OperandsDirective = DirectiveElement; + +/// This class represents the `results` directive. This directive represents +/// all of the results of an operation. +using ResultsDirective = DirectiveElement; + +/// This class represents the `functional-type` directive. This directive takes +/// two arguments and formats them, respectively, as the inputs and results of a +/// FunctionType. +struct FunctionalTypeDirective + : public DirectiveElement { +public: + FunctionalTypeDirective(std::unique_ptr inputs, + std::unique_ptr results) + : inputs(std::move(inputs)), results(std::move(results)) {} + Element *getInputs() const { return inputs.get(); } + Element *getResults() const { return results.get(); } + +private: + /// The input and result arguments. + std::unique_ptr inputs, results; +}; + +/// This class represents the `type` directive. +struct TypeDirective : public DirectiveElement { +public: + TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + Element *getOperand() const { return operand.get(); } + +private: + /// The operand that is used to format the directive. + std::unique_ptr operand; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// LiteralElement + +namespace { +/// This class represents an instance of a literal element. +class LiteralElement : public Element { +public: + LiteralElement(StringRef literal) + : Element{Kind::Literal}, literal(literal){}; + static bool classof(const Element *element) { + return element->getKind() == Kind::Literal; + } + + /// Return the literal for this element. + StringRef getLiteral() const { return literal; } + + /// Returns true if the given string is a valid literal. + static bool isValidLiteral(StringRef value); + +private: + /// The spelling of the literal for this element. + StringRef literal; +}; +} // end anonymous namespace + +bool LiteralElement::isValidLiteral(StringRef value) { + if (value.empty()) + return false; + char front = value.front(); + + // If there is only one character, this must either be punctuation or a + // single character bare identifier. + if (value.size() == 1) + return isalpha(front) || StringRef("_:,=<>()[]").contains(front); + + // Check the punctuation that are larger than a single character. + if (value == "->") + return true; + + // Otherwise, this must be an identifier. + if (!isalpha(front) && front != '_') + return false; + return llvm::all_of(value.drop_front(), [](char c) { + return isalnum(c) || c == '_' || c == '$' || c == '.'; + }); +} + +//===----------------------------------------------------------------------===// +// OperationFormat +//===----------------------------------------------------------------------===// + +namespace { +struct OperationFormat { + OperationFormat(const Operator &op) + : allOperandTypes(false), allResultTypes(false) { + buildableOperandTypes.resize(op.getNumOperands(), llvm::None); + buildableResultTypes.resize(op.getNumResults(), llvm::None); + } + + /// The various elements in this format. + std::vector> elements; + + /// A flag indicating if all operand/result types were seen. If the format + /// contains these, it can not contain individual type resolvers. + bool allOperandTypes, allResultTypes; + + /// A map of buildable types to indices. + llvm::MapVector> buildableTypes; + + /// The index of the buildable type, if valid, for every operand and result. + std::vector> buildableOperandTypes, buildableResultTypes; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// FormatLexer +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a specific token in the input format. +class Token { +public: + enum Kind { + // Markers. + eof, + error, + + // Tokens with no info. + l_paren, + r_paren, + comma, + equal, + + // Keywords. + keyword_start, + kw_attr_dict, + kw_functional_type, + kw_operands, + kw_results, + kw_type, + keyword_end, + + // String valued tokens. + identifier, + literal, + variable, + }; + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + /// Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + /// Return the kind of this token. + Kind getKind() const { return kind; } + + /// Return a location for this token. + llvm::SMLoc getLoc() const { + return llvm::SMLoc::getFromPointer(spelling.data()); + } + + /// Return if this token is a keyword. + bool isKeyword() const { return kind > keyword_start && kind < keyword_end; } + +private: + /// Discriminator that indicates the kind of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager. + StringRef spelling; +}; + +/// This class implements a simple lexer for operation assembly format strings. +class FormatLexer { +public: + FormatLexer(llvm::SourceMgr &mgr); + + /// Lex the next token and return it. + Token lexToken(); + + /// Emit an error to the lexer with the given location and message. + Token emitError(llvm::SMLoc loc, const Twine &msg); + Token emitError(const char *loc, const Twine &msg); + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + /// Return the next character in the stream. + int getNextChar(); + + /// Lex an identifier, literal, or variable. + Token lexIdentifier(const char *tokStart); + Token lexLiteral(const char *tokStart); + Token lexVariable(const char *tokStart); + + llvm::SourceMgr &srcMgr; + StringRef curBuffer; + const char *curPtr; +}; +} // end anonymous namespace + +FormatLexer::FormatLexer(llvm::SourceMgr &mgr) : srcMgr(mgr) { + curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); + curPtr = curBuffer.begin(); +} + +Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) { + srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); + return formToken(Token::error, loc.getPointer()); +} +Token FormatLexer::emitError(const char *loc, const Twine &msg) { + return emitError(llvm::SMLoc::getFromPointer(loc), msg); +} + +int FormatLexer::getNextChar() { + char curChar = *curPtr++; + switch (curChar) { + default: + return (unsigned char)curChar; + case 0: { + // A nul character in the stream is either the end of the current buffer or + // a random nul in the file. Disambiguate that here. + if (curPtr - 1 != curBuffer.end()) + return 0; + + // Otherwise, return end of file. + --curPtr; + return EOF; + } + case '\n': + case '\r': + // Handle the newline character by ignoring it and incrementing the line + // count. However, be careful about 'dos style' files with \n\r in them. + // Only treat a \n\r or \r\n as a single line. + if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) + ++curPtr; + return '\n'; + } +} + +Token FormatLexer::lexToken() { + const char *tokStart = curPtr; + + // This always consumes at least one character. + int curChar = getNextChar(); + switch (curChar) { + default: + // Handle identifiers: [a-zA-Z_] + if (isalpha(curChar) || curChar == '_') + return lexIdentifier(tokStart); + + // Unknown character, emit an error. + return emitError(tokStart, "unexpected character"); + case EOF: + // Return EOF denoting the end of lexing. + return formToken(Token::eof, tokStart); + + // Lex punctuation. + case ',': + return formToken(Token::comma, tokStart); + case '=': + return formToken(Token::equal, tokStart); + case '(': + return formToken(Token::l_paren, tokStart); + case ')': + return formToken(Token::r_paren, tokStart); + + // Ignore whitespace characters. + case 0: + case ' ': + case '\t': + case '\n': + return lexToken(); + + case '`': + return lexLiteral(tokStart); + case '$': + return lexVariable(tokStart); + } +} + +Token FormatLexer::lexLiteral(const char *tokStart) { + assert(curPtr[-1] == '`'); + + // Lex a literal surrounded by ``. + while (const char curChar = *curPtr++) { + if (curChar == '`') + return formToken(Token::literal, tokStart); + } + return emitError(curPtr - 1, "unexpected end of file in literal"); +} + +Token FormatLexer::lexVariable(const char *tokStart) { + if (!isalpha(curPtr[0]) && curPtr[0] != '_') + return emitError(curPtr - 1, "expected variable name"); + + // Otherwise, consume the rest of the characters. + while (isalnum(*curPtr) || *curPtr == '_') + ++curPtr; + return formToken(Token::variable, tokStart); +} + +Token FormatLexer::lexIdentifier(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* + while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') + ++curPtr; + + // Check to see if this identifier is a keyword. + StringRef str(tokStart, curPtr - tokStart); + Token::Kind kind = llvm::StringSwitch(str) + .Case("attr-dict", Token::kw_attr_dict) + .Case("functional-type", Token::kw_functional_type) + .Case("operands", Token::kw_operands) + .Case("results", Token::kw_results) + .Case("type", Token::kw_type) + .Default(Token::identifier); + return Token(kind, str); +} + +//===----------------------------------------------------------------------===// +// FormatParser +//===----------------------------------------------------------------------===// + +namespace { +/// This class implements a parser for an instance of an operation assembly +/// format. +class FormatParser { +public: + FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) + : lexer(mgr), curToken(lexer.lexToken()), fmt(format), op(op), + seenOperandTypes(op.getNumOperands()), + seenResultTypes(op.getNumResults()) {} + + /// Parse the operation assembly format. + LogicalResult parse(); + +private: + /// Parse a specific element. + LogicalResult parseElement(std::unique_ptr &element, + bool isTopLevel); + LogicalResult parseVariable(std::unique_ptr &element, + bool isTopLevel); + LogicalResult parseDirective(std::unique_ptr &element, + bool isTopLevel); + LogicalResult parseLiteral(std::unique_ptr &element); + + /// Parse the various different directives. + LogicalResult parseAttrDictDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel); + LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, + Token tok, bool isTopLevel); + LogicalResult parseOperandsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel); + LogicalResult parseResultsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel); + LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, + bool isTopLevel); + LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element); + + //===--------------------------------------------------------------------===// + // Lexer Utilities + //===--------------------------------------------------------------------===// + + /// Advance the current lexer onto the next token. + void consumeToken() { + assert(curToken.getKind() != Token::eof && + curToken.getKind() != Token::error && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + LogicalResult parseToken(Token::Kind kind, const Twine &msg) { + if (curToken.getKind() != kind) + return emitError(curToken.getLoc(), msg); + consumeToken(); + return success(); + } + LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { + lexer.emitError(loc, msg); + return failure(); + } + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + + FormatLexer lexer; + Token curToken; + OperationFormat &fmt; + Operator &op; + + // The following are various bits of format state used for verification during + // parsing. + bool hasAllOperands = false, hasAttrDict = false; + llvm::SmallBitVector seenOperandTypes, seenResultTypes; + llvm::DenseSet seenOperands; + llvm::DenseSet seenAttrs; +}; +} // end anonymous namespace + +LogicalResult FormatParser::parse() { + llvm::SMLoc loc = curToken.getLoc(); + + // Parse each of the format elements into the main format. + while (curToken.getKind() != Token::eof) { + std::unique_ptr element; + if (failed(parseElement(element, /*isTopLevel=*/true))) + return failure(); + fmt.elements.push_back(std::move(element)); + } + + // Check that the attribute dictionary is in the format. + if (!hasAttrDict) + return emitError(loc, "format missing 'attr-dict' directive"); + + // Check that all of the result types can be inferred. + auto &buildableTypes = fmt.buildableTypes; + if (!fmt.allResultTypes) { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { + if (seenResultTypes.test(i)) + continue; + + // If the result is not variadic, allow for the case where the type has a + // builder that we can use. + NamedTypeConstraint &result = op.getResult(i); + Optional builder = result.constraint.getBuilderCall(); + if (!builder || result.constraint.isVariadic()) { + return emitError(loc, "format missing instance of result #" + Twine(i) + + "('" + result.name + "') type"); + } + // Note in the format that this result uses the custom builder. + auto it = buildableTypes.insert({*builder, buildableTypes.size()}); + fmt.buildableResultTypes[i] = it.first->second; + } + } + + // Check that all of the operands are within the format, and their types can + // be inferred. + for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { + NamedTypeConstraint &operand = op.getOperand(i); + + // Check that the operand itself is in the format. + if (!hasAllOperands && !seenOperands.count(&operand)) { + return emitError(loc, "format missing instance of operand #" + Twine(i) + + "('" + operand.name + "')"); + } + + // Check that the operand type is in the format, or that it can be inferred. + if (!fmt.allOperandTypes && !seenOperandTypes.test(i)) { + // Similarly to results, allow a custom builder for resolving the type if + // we aren't using the 'operands' directive. + Optional builder = operand.constraint.getBuilderCall(); + if (!builder || (hasAllOperands && operand.isVariadic())) { + return emitError(loc, "format missing instance of operand #" + + Twine(i) + "('" + operand.name + "') type"); + } + auto it = buildableTypes.insert({*builder, buildableTypes.size()}); + fmt.buildableOperandTypes[i] = it.first->second; + } + } + return success(); +} + +LogicalResult FormatParser::parseElement(std::unique_ptr &element, + bool isTopLevel) { + // Directives. + if (curToken.isKeyword()) + return parseDirective(element, isTopLevel); + // Literals. + if (curToken.getKind() == Token::literal) + return parseLiteral(element); + // Variables. + if (curToken.getKind() == Token::variable) + return parseVariable(element, isTopLevel); + return emitError(curToken.getLoc(), + "expected directive, literal, or variable"); +} + +LogicalResult FormatParser::parseVariable(std::unique_ptr &element, + bool isTopLevel) { + Token varTok = curToken; + consumeToken(); + + StringRef name = varTok.getSpelling().drop_front(); + llvm::SMLoc loc = varTok.getLoc(); + + // Functor used to find an element within the given range that has the same + // name as 'name'. + auto findArg = [&](auto &&range) { + auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); + return it != range.end() ? &*it : nullptr; + }; + + // Check that the parsed argument is something actually registered on the op. + /// Attributes + if (const NamedAttribute *attr = findArg(op.getAttributes())) { + if (isTopLevel && !seenAttrs.insert(attr).second) + return emitError(loc, "attribute '" + name + "' is already bound"); + element = std::make_unique(attr); + return success(); + } + /// Operands + if (const NamedTypeConstraint *operand = findArg(op.getOperands())) { + if (isTopLevel) { + if (hasAllOperands || !seenOperands.insert(operand).second) + return emitError(loc, "operand '" + name + "' is already bound"); + } + element = std::make_unique(operand); + return success(); + } + /// Results. + if (const NamedTypeConstraint *result = findArg(op.getResults())) { + if (isTopLevel) + return emitError(loc, "results can not be used at the top level"); + element = std::make_unique(result); + return success(); + } + return emitError(loc, "expected variable to refer to a argument or result"); +} + +LogicalResult FormatParser::parseDirective(std::unique_ptr &element, + bool isTopLevel) { + Token dirTok = curToken; + consumeToken(); + + switch (dirTok.getKind()) { + case Token::kw_attr_dict: + return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_functional_type: + return parseFunctionalTypeDirective(element, dirTok, isTopLevel); + case Token::kw_operands: + return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_results: + return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_type: + return parseTypeDirective(element, dirTok, isTopLevel); + + default: + llvm_unreachable("unknown directive token"); + } +} + +LogicalResult FormatParser::parseLiteral(std::unique_ptr &element) { + Token literalTok = curToken; + consumeToken(); + + // Check that the parsed literal is valid. + StringRef value = literalTok.getSpelling().drop_front().drop_back(); + if (!LiteralElement::isValidLiteral(value)) + return emitError(literalTok.getLoc(), "expected valid literal"); + + element = std::make_unique(value); + return success(); +} + +LogicalResult +FormatParser::parseAttrDictDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel) { + if (!isTopLevel) + return emitError(loc, "'attr-dict' directive can only be used as a " + "top-level directive"); + if (hasAttrDict) + return emitError(loc, "'attr-dict' directive has already been seen"); + + hasAttrDict = true; + element = std::make_unique(); + return success(); +} + +LogicalResult +FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element, + Token tok, bool isTopLevel) { + llvm::SMLoc loc = tok.getLoc(); + if (!isTopLevel) + return emitError( + loc, "'functional-type' is only valid as a top-level directive"); + + // Parse the main operand. + std::unique_ptr inputs, results; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseTypeDirectiveOperand(inputs)) || + failed(parseToken(Token::comma, "expected ',' after inputs argument")) || + failed(parseTypeDirectiveOperand(results)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return failure(); + + // Get the proper directive kind and create it. + element = std::make_unique(std::move(inputs), + std::move(results)); + return success(); +} + +LogicalResult +FormatParser::parseOperandsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel) { + if (isTopLevel && (hasAllOperands || !seenOperands.empty())) + return emitError(loc, "'operands' directive creates overlap in format"); + hasAllOperands = true; + element = std::make_unique(); + return success(); +} + +LogicalResult +FormatParser::parseResultsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel) { + if (isTopLevel) + return emitError(loc, "'results' directive can not be used as a " + "top-level directive"); + element = std::make_unique(); + return success(); +} + +LogicalResult +FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, + bool isTopLevel) { + llvm::SMLoc loc = tok.getLoc(); + if (!isTopLevel) + return emitError(loc, "'type' is only valid as a top-level directive"); + + std::unique_ptr operand; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseTypeDirectiveOperand(operand)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return failure(); + element = std::make_unique(std::move(operand)); + return success(); +} + +LogicalResult +FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element) { + llvm::SMLoc loc = curToken.getLoc(); + if (failed(parseElement(element, /*isTopLevel=*/false))) + return failure(); + if (isa(element.get())) + return emitError( + loc, "'type' directive operand expects variable or directive operand"); + + if (auto *var = dyn_cast(element.get())) { + unsigned opIdx = var->getVar() - op.operand_begin(); + if (fmt.allOperandTypes || seenOperandTypes.test(opIdx)) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is already bound"); + seenOperandTypes.set(opIdx); + } else if (auto *var = dyn_cast(element.get())) { + unsigned resIdx = var->getVar() - op.result_begin(); + if (fmt.allResultTypes || seenResultTypes.test(resIdx)) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is already bound"); + seenResultTypes.set(resIdx); + } else if (isa(&*element)) { + if (fmt.allOperandTypes || seenOperandTypes.any()) + return emitError(loc, "'operands' 'type' is already bound"); + fmt.allOperandTypes = true; + } else if (isa(&*element)) { + if (fmt.allResultTypes || seenResultTypes.any()) + return emitError(loc, "'results' 'type' is already bound"); + fmt.allResultTypes = true; + } else { + return emitError(loc, "invalid argument to 'type' directive"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Interface +//===----------------------------------------------------------------------===// + +void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { + // TODO(riverriddle) Operator doesn't expose all necessary functionality via + // the const interface. + Operator &op = const_cast(constOp); + + // Check if the operation specified the format field. + StringRef formatStr; + TypeSwitch(op.getDef().getValueInit("assemblyFormat")) + .Case( + [&](auto *init) { formatStr = init->getValue(); }); + if (formatStr.empty()) + return; + + // Parse the format description. + llvm::SourceMgr mgr; + mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(formatStr), + llvm::SMLoc()); + OperationFormat format(op); + if (failed(FormatParser(mgr, format, op).parse())) + return; +}