diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -15,14 +15,9 @@ #define MLIR_IR_DIALECTIMPLEMENTATION_H #include "mlir/IR/OpImplementation.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/SMLoc.h" -#include "llvm/Support/raw_ostream.h" namespace mlir { -class Builder; - //===----------------------------------------------------------------------===// // DialectAsmPrinter //===----------------------------------------------------------------------===// @@ -30,360 +25,26 @@ /// This is a pure-virtual base class that exposes the asmprinter hooks /// necessary to implement a custom printAttribute/printType() method on a /// dialect. -class DialectAsmPrinter { +class DialectAsmPrinter : public AsmPrinter { public: - DialectAsmPrinter() {} - virtual ~DialectAsmPrinter(); - virtual raw_ostream &getStream() const = 0; - - /// Print the given attribute to the stream. - virtual void printAttribute(Attribute attr) = 0; - - /// Print the given attribute without its type. The corresponding parser must - /// provide a valid type for the attribute. - virtual void printAttributeWithoutType(Attribute attr) = 0; - - /// Print the given floating point value in a stabilized form that can be - /// roundtripped through the IR. This is the companion to the 'parseFloat' - /// hook on the DialectAsmParser. - virtual void printFloat(const APFloat &value) = 0; - - /// Print the given type to the stream. - virtual void printType(Type type) = 0; - -private: - DialectAsmPrinter(const DialectAsmPrinter &) = delete; - void operator=(const DialectAsmPrinter &) = delete; + using AsmPrinter::AsmPrinter; + ~DialectAsmPrinter() override; }; -// Make the implementations convenient to use. 
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { - p.printAttribute(attr); - return p; -} - -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, - const APFloat &value) { - p.printFloat(value); - return p; -} -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { - return p << APFloat(value); -} -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { - return p << APFloat(value); -} - -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { - p.printType(type); - return p; -} - -// Support printing anything that isn't convertible to one of the above types, -// even if it isn't exactly one of them. For example, we want to print -// FunctionType with the Type version above, not have it match this. -template ::value && - !std::is_convertible::value && - !std::is_convertible::value && - !llvm::is_one_of::value, - T>::type * = nullptr> -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { - p.getStream() << other; - return p; -} - //===----------------------------------------------------------------------===// // DialectAsmParser //===----------------------------------------------------------------------===// -/// The DialectAsmParser has methods for interacting with the asm parser: -/// parsing things from it, emitting errors etc. It has an intentionally -/// high-level API that is designed to reduce/constrain syntax innovation in -/// individual attributes or types. -class DialectAsmParser { +/// The DialectAsmParser has methods for interacting with the asm parser when +/// parsing attributes and types. +class DialectAsmParser : public AsmParser { public: - virtual ~DialectAsmParser(); - - /// Emit a diagnostic at the specified location and return failure. 
- virtual InFlightDiagnostic emitError(llvm::SMLoc loc, - const Twine &message = {}) = 0; - - /// Return a builder which provides useful access to MLIRContext, global - /// objects like types and attributes. - virtual Builder &getBuilder() const = 0; - - /// Get the location of the next token and store it into the argument. This - /// always succeeds. - virtual llvm::SMLoc getCurrentLocation() = 0; - ParseResult getCurrentLocation(llvm::SMLoc *loc) { - *loc = getCurrentLocation(); - return success(); - } - - /// Return the location of the original name token. - virtual llvm::SMLoc getNameLoc() const = 0; - - /// Re-encode the given source location as an MLIR location and return it. - /// Note: This method should only be used when a `Location` is necessary, as - /// the encoding process is not efficient. In other cases a more suitable - /// alternative should be used, such as the `getChecked` methods defined - /// below. - virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; + using AsmParser::AsmParser; + ~DialectAsmParser() override; /// Returns the full specification of the symbol being parsed. This allows for /// using a separate parser if necessary. virtual StringRef getFullSymbolSpec() const = 0; - - // These methods emit an error and return failure or success. This allows - // these to be chained together into a linear sequence of || expressions in - // many cases. - - /// Parse a floating point value from the stream. - virtual ParseResult parseFloat(double &result) = 0; - - /// Parse an integer value from the stream. - template - ParseResult parseInteger(IntT &result) { - auto loc = getCurrentLocation(); - OptionalParseResult parseResult = parseOptionalInteger(result); - if (!parseResult.hasValue()) - return emitError(loc, "expected integer value"); - return *parseResult; - } - - /// Parse an optional integer value from the stream. 
- virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; - - template - OptionalParseResult parseOptionalInteger(IntT &result) { - auto loc = getCurrentLocation(); - - // Parse the unsigned variant. - APInt uintResult; - OptionalParseResult parseResult = parseOptionalInteger(uintResult); - if (!parseResult.hasValue() || failed(*parseResult)) - return parseResult; - - // Try to convert to the provided integer type. sextOrTrunc is correct even - // for unsigned types because parseOptionalInteger ensures the sign bit is - // zero for non-negated integers. - result = - (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue(); - if (APInt(uintResult.getBitWidth(), result) != uintResult) - return emitError(loc, "integer value too large"); - return success(); - } - - /// Invoke the `getChecked` method of the given Attribute or Type class, using - /// the provided location to emit errors in the case of failure. Note that - /// unlike `OpBuilder::getType`, this method does not implicitly insert a - /// context parameter. - template - T getChecked(llvm::SMLoc loc, ParamsT &&... params) { - return T::getChecked([&] { return emitError(loc); }, - std::forward(params)...); - } - /// A variant of `getChecked` that uses the result of `getNameLoc` to emit - /// errors. - template - T getChecked(ParamsT &&... params) { - return T::getChecked([&] { return emitError(getNameLoc()); }, - std::forward(params)...); - } - - //===--------------------------------------------------------------------===// - // Token Parsing - //===--------------------------------------------------------------------===// - - /// Parse a '->' token. - virtual ParseResult parseArrow() = 0; - - /// Parse a '->' token if present - virtual ParseResult parseOptionalArrow() = 0; - - /// Parse a '{' token. - virtual ParseResult parseLBrace() = 0; - - /// Parse a '{' token if present - virtual ParseResult parseOptionalLBrace() = 0; - - /// Parse a `}` token. 
- virtual ParseResult parseRBrace() = 0; - - /// Parse a `}` token if present - virtual ParseResult parseOptionalRBrace() = 0; - - /// Parse a `:` token. - virtual ParseResult parseColon() = 0; - - /// Parse a `:` token if present. - virtual ParseResult parseOptionalColon() = 0; - - /// Parse a `,` token. - virtual ParseResult parseComma() = 0; - - /// Parse a `,` token if present. - virtual ParseResult parseOptionalComma() = 0; - - /// Parse a `=` token. - virtual ParseResult parseEqual() = 0; - - /// Parse a `=` token if present. - virtual ParseResult parseOptionalEqual() = 0; - - /// Parse a quoted string token. - ParseResult parseString(std::string *string) { - auto loc = getCurrentLocation(); - if (parseOptionalString(string)) - return emitError(loc, "expected string"); - return success(); - } - - /// Parse a quoted string token if present. - virtual ParseResult parseOptionalString(std::string *string) = 0; - - /// Parse a given keyword. - ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { - auto loc = getCurrentLocation(); - if (parseOptionalKeyword(keyword)) - return emitError(loc, "expected '") << keyword << "'" << msg; - return success(); - } - - /// Parse a keyword into 'keyword'. - ParseResult parseKeyword(StringRef *keyword) { - auto loc = getCurrentLocation(); - if (parseOptionalKeyword(keyword)) - return emitError(loc, "expected valid keyword"); - return success(); - } - - /// Parse the given keyword if present. - virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; - - /// Parse a keyword, if present, into 'keyword'. - virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; - - /// Parse a '<' token. - virtual ParseResult parseLess() = 0; - - /// Parse a `<` token if present. - virtual ParseResult parseOptionalLess() = 0; - - /// Parse a '>' token. - virtual ParseResult parseGreater() = 0; - - /// Parse a `>` token if present. - virtual ParseResult parseOptionalGreater() = 0; - - /// Parse a `(` token. 
- virtual ParseResult parseLParen() = 0; - - /// Parse a `(` token if present. - virtual ParseResult parseOptionalLParen() = 0; - - /// Parse a `)` token. - virtual ParseResult parseRParen() = 0; - - /// Parse a `)` token if present. - virtual ParseResult parseOptionalRParen() = 0; - - /// Parse a `[` token. - virtual ParseResult parseLSquare() = 0; - - /// Parse a `[` token if present. - virtual ParseResult parseOptionalLSquare() = 0; - - /// Parse a `]` token. - virtual ParseResult parseRSquare() = 0; - - /// Parse a `]` token if present. - virtual ParseResult parseOptionalRSquare() = 0; - - /// Parse a `...` token if present; - virtual ParseResult parseOptionalEllipsis() = 0; - - /// Parse a `?` token. - virtual ParseResult parseOptionalQuestion() = 0; - - /// Parse a `*` token. - virtual ParseResult parseOptionalStar() = 0; - - //===--------------------------------------------------------------------===// - // Attribute Parsing - //===--------------------------------------------------------------------===// - - /// Parse an arbitrary attribute and return it in result. - virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; - - /// Parse an attribute of a specific kind and type. - template - ParseResult parseAttribute(AttrType &result, Type type = {}) { - llvm::SMLoc loc = getCurrentLocation(); - - // Parse any kind of attribute. - Attribute attr; - if (parseAttribute(attr, type)) - return failure(); - - // Check for the right kind of attribute. - result = attr.dyn_cast(); - if (!result) - return emitError(loc, "invalid kind of attribute specified"); - return success(); - } - - /// Parse an affine map instance into 'map'. - virtual ParseResult parseAffineMap(AffineMap &map) = 0; - - /// Parse an integer set instance into 'set'. 
- virtual ParseResult printIntegerSet(IntegerSet &set) = 0; - - //===--------------------------------------------------------------------===// - // Type Parsing - //===--------------------------------------------------------------------===// - - /// Parse a type. - virtual ParseResult parseType(Type &result) = 0; - - /// Parse a type of a specific kind, e.g. a FunctionType. - template - ParseResult parseType(TypeType &result) { - llvm::SMLoc loc = getCurrentLocation(); - - // Parse any kind of type. - Type type; - if (parseType(type)) - return failure(); - - // Check for the right kind of attribute. - result = type.dyn_cast(); - if (!result) - return emitError(loc, "invalid kind of type specified"); - return success(); - } - - /// Parse a type if present. - virtual OptionalParseResult parseOptionalType(Type &result) = 0; - - /// Parse a 'x' separated dimension list. This populates the dimension list, - /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on - /// `?` otherwise. - /// - /// dimension-list ::= (dimension `x`)* - /// dimension ::= `?` | integer - /// - /// When `allowDynamic` is not set, this is used to parse: - /// - /// static-dimension-list ::= (integer `x`)* - virtual ParseResult parseDimensionList(SmallVectorImpl &dimensions, - bool allowDynamic = true) = 0; - - /// Parse an 'x' token in a dimension list, handling the case where the x is - /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the - /// next token. 
- virtual ParseResult parseXInDimensionList() = 0; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -24,17 +24,179 @@ class Builder; +//===----------------------------------------------------------------------===// +// AsmPrinter +//===----------------------------------------------------------------------===// + +/// This base class exposes generic asm printer hooks, usable across the various +/// derived printers. +class AsmPrinter { +public: + /// This class contains the internal default implementation of the base + /// printer methods. + class Impl; + + /// Initialize the printer with the given internal implementation. + AsmPrinter(Impl &impl) : impl(&impl) {} + virtual ~AsmPrinter(); + + /// Return the raw output stream used by this printer. + virtual raw_ostream &getStream() const; + + /// Print the given floating point value in a stabilized form that can be + /// roundtripped through the IR. This is the companion to the 'parseFloat' + /// hook on the AsmParser. + virtual void printFloat(const APFloat &value); + + virtual void printType(Type type); + virtual void printAttribute(Attribute attr); + + /// Print the given attribute without its type. The corresponding parser must + /// provide a valid type for the attribute. + virtual void printAttributeWithoutType(Attribute attr); + + /// Print the given string as a symbol reference, i.e. a form representable by + /// a SymbolRefAttr. A symbol reference is represented as a string prefixed + /// with '@'. The reference is surrounded with ""'s and escaped if it has any + /// special or non-printable characters in it. + virtual void printSymbolName(StringRef symbolRef); + + /// Print an optional arrow followed by a type list. 
+ template + void printOptionalArrowTypeList(TypeRange &&types) { + if (types.begin() != types.end()) + printArrowTypeList(types); + } + template + void printArrowTypeList(TypeRange &&types) { + auto &os = getStream() << " -> "; + + bool wrapped = !llvm::hasSingleElement(types) || + (*types.begin()).template isa(); + if (wrapped) + os << '('; + llvm::interleaveComma(types, *this); + if (wrapped) + os << ')'; + } + + /// Print the two given type ranges in a functional form. + template + void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { + auto &os = getStream(); + os << '('; + llvm::interleaveComma(inputs, *this); + os << ')'; + printArrowTypeList(results); + } + +protected: + /// Initialize the printer with no internal implementation. In this case, all + /// virtual methods of this class must be overridden. + AsmPrinter() : impl(nullptr) {} + +private: + AsmPrinter(const AsmPrinter &) = delete; + void operator=(const AsmPrinter &) = delete; + + /// The internal implementation of the printer. + Impl *impl; +}; + +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, Type type) { + p.printType(type); + return p; +} + +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, Attribute attr) { + p.printAttribute(attr); + return p; +} + +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, const APFloat &value) { + p.printFloat(value); + return p; +} +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, float value) { + return p << APFloat(value); +} +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, double value) { + return p << APFloat(value); +} + +// Support printing anything that isn't convertible to one of the other +// streamable types, even if it isn't exactly one of them. For example, we want +// to print FunctionType with the Type version above, not have it match this.
+template < + typename AsmPrinterT, typename T, + typename std::enable_if::value && + !std::is_convertible::value && + !std::is_convertible::value && + !std::is_convertible::value && + !std::is_convertible::value && + !llvm::is_one_of::value, + T>::type * = nullptr> +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, const T &other) { + p.getStream() << other; + return p; +} + +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, bool value) { + return p << (value ? StringRef("true") : "false"); +} + +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, const ValueTypeRange &types) { + llvm::interleaveComma(types, p); + return p; +} +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, const TypeRange &types) { + llvm::interleaveComma(types, p); + return p; +} +template +inline std::enable_if_t::value, + AsmPrinterT &> +operator<<(AsmPrinterT &p, ArrayRef types) { + llvm::interleaveComma(types, p); + return p; +} + //===----------------------------------------------------------------------===// // OpAsmPrinter //===----------------------------------------------------------------------===// /// This is a pure-virtual base class that exposes the asmprinter hooks /// necessary to implement a custom print() method. -class OpAsmPrinter { +class OpAsmPrinter : public AsmPrinter { public: - OpAsmPrinter() {} - virtual ~OpAsmPrinter(); - virtual raw_ostream &getStream() const = 0; + using AsmPrinter::AsmPrinter; + ~OpAsmPrinter() override; /// Print a newline and indent the printer to the start of the current /// operation. @@ -70,12 +232,6 @@ printOperand(*it); } } - virtual void printType(Type type) = 0; - virtual void printAttribute(Attribute attr) = 0; - - /// Print the given attribute without its type. The corresponding parser must - /// provide a valid type for the attribute. 
- virtual void printAttributeWithoutType(Attribute attr) = 0; /// Print the given successor. virtual void printSuccessor(Block *successor) = 0; @@ -131,47 +287,9 @@ virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands) = 0; - /// Print an optional arrow followed by a type list. - template - void printOptionalArrowTypeList(TypeRange &&types) { - if (types.begin() != types.end()) - printArrowTypeList(types); - } - template - void printArrowTypeList(TypeRange &&types) { - auto &os = getStream() << " -> "; - - bool wrapped = !llvm::hasSingleElement(types) || - (*types.begin()).template isa(); - if (wrapped) - os << '('; - llvm::interleaveComma(types, *this); - if (wrapped) - os << ')'; - } - /// Print the complete type of an operation in functional form. void printFunctionalType(Operation *op); - - /// Print the two given type ranges in a functional form. - template - void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { - auto &os = getStream(); - os << '('; - llvm::interleaveComma(inputs, *this); - os << ')'; - printArrowTypeList(results); - } - - /// Print the given string as a symbol reference, i.e. a form representable by - /// a SymbolRefAttr. A symbol reference is represented as a string prefixed - /// with '@'. The reference is surrounded with ""'s and escaped if it has any - /// special or non-printable characters in it. - virtual void printSymbolName(StringRef symbolRef) = 0; - -private: - OpAsmPrinter(const OpAsmPrinter &) = delete; - void operator=(const OpAsmPrinter &) = delete; + using AsmPrinter::printFunctionalType; }; // Make the implementations convenient to use. 
@@ -189,77 +307,28 @@ return p; } -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { - p.printType(type); - return p; -} - -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) { - p.printAttribute(attr); - return p; -} - -// Support printing anything that isn't convertible to one of the above types, -// even if it isn't exactly one of them. For example, we want to print -// FunctionType with the Type version above, not have it match this. -template ::value && - !std::is_convertible::value && - !std::is_convertible::value && - !std::is_convertible::value && - !llvm::is_one_of::value, - T>::type * = nullptr> -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) { - p.getStream() << other; - return p; -} - -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) { - return p << (value ? StringRef("true") : "false"); -} - inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { p.printSuccessor(value); return p; } -template -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, - const ValueTypeRange &types) { - llvm::interleaveComma(types, p); - return p; -} -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) { - llvm::interleaveComma(types, p); - return p; -} -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef types) { - llvm::interleaveComma(types, p); - return p; -} - //===----------------------------------------------------------------------===// -// OpAsmParser +// AsmParser //===----------------------------------------------------------------------===// -/// The OpAsmParser has methods for interacting with the asm parser: parsing -/// things from it, emitting errors etc. It has an intentionally high-level API -/// that is designed to reduce/constrain syntax innovation in individual -/// operations. 
-/// -/// For example, consider an op like this: -/// -/// %x = load %p[%1, %2] : memref<...> -/// -/// The "%x = load" tokens are already parsed and therefore invisible to the -/// custom op parser. This can be supported by calling `parseOperandList` to -/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to -/// parse the indices, then calling `parseColonTypeList` to parse the result -/// type. -/// -class OpAsmParser { +/// This base class exposes generic asm parser hooks, usable across the various +/// derived parsers. +class AsmParser { public: - virtual ~OpAsmParser(); + AsmParser() = default; + virtual ~AsmParser(); + + /// Return the location of the original name token. + virtual llvm::SMLoc getNameLoc() const = 0; + + //===--------------------------------------------------------------------===// + // Utilities + //===--------------------------------------------------------------------===// /// Emit a diagnostic at the specified location and return failure. virtual InFlightDiagnostic emitError(llvm::SMLoc loc, @@ -277,44 +346,11 @@ return success(); } - /// Return the name of the specified result in the specified syntax, as well - /// as the sub-element in the name. It returns an empty string and ~0U for - /// invalid result numbers. For example, in this operation: - /// - /// %x, %y:2, %z = foo.op - /// - /// getResultName(0) == {"x", 0 } - /// getResultName(1) == {"y", 0 } - /// getResultName(2) == {"y", 1 } - /// getResultName(3) == {"z", 0 } - /// getResultName(4) == {"", ~0U } - virtual std::pair - getResultName(unsigned resultNo) const = 0; - - /// Return the number of declared SSA results. This returns 4 for the foo.op - /// example in the comment for `getResultName`. - virtual size_t getNumResults() const = 0; - - /// Return the location of the original name token. - virtual llvm::SMLoc getNameLoc() const = 0; - /// Re-encode the given source location as an MLIR location and return it. 
/// Note: This method should only be used when a `Location` is necessary, as /// the encoding process is not efficient. virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; - // These methods emit an error and return failure or success. This allows - // these to be chained together into a linear sequence of || expressions in - // many cases. - - /// Parse an operation in its generic form. - /// The parsed operation is parsed in the current context and inserted in the - /// provided block and insertion point. The results produced by this operation - /// aren't mapped to any named value in the parser. Returns nullptr on - /// failure. - virtual Operation *parseGenericOperation(Block *insertBlock, - Block::iterator insertPt) = 0; - //===--------------------------------------------------------------------===// // Token Parsing //===--------------------------------------------------------------------===// @@ -385,6 +421,17 @@ /// Parse a '*' token if present. virtual ParseResult parseOptionalStar() = 0; + /// Parse a quoted string token. + ParseResult parseString(std::string *string) { + auto loc = getCurrentLocation(); + if (parseOptionalString(string)) + return emitError(loc, "expected string"); + return success(); + } + + /// Parse a quoted string token if present. + virtual ParseResult parseOptionalString(std::string *string) = 0; + /// Parse a given keyword. ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { auto loc = getCurrentLocation(); @@ -440,6 +487,9 @@ /// Parse a `...` token if present; virtual ParseResult parseOptionalEllipsis() = 0; + /// Parse a floating point value from the stream. + virtual ParseResult parseFloat(double &result) = 0; + /// Parse an integer value from the stream. 
template ParseResult parseInteger(IntT &result) { @@ -514,6 +564,27 @@ return parseCommaSeparatedList(Delimiter::None, parseElementFn); } + //===--------------------------------------------------------------------===// + // Attribute/Type Parsing + //===--------------------------------------------------------------------===// + + /// Invoke the `getChecked` method of the given Attribute or Type class, using + /// the provided location to emit errors in the case of failure. Note that + /// unlike `OpBuilder::getType`, this method does not implicitly insert a + /// context parameter. + template + T getChecked(llvm::SMLoc loc, ParamsT &&... params) { + return T::getChecked([&] { return emitError(loc); }, + std::forward(params)...); + } + /// A variant of `getChecked` that uses the result of `getNameLoc` to emit + /// errors. + template + T getChecked(ParamsT &&... params) { + return T::getChecked([&] { return emitError(getNameLoc()); }, + std::forward(params)...); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -634,6 +705,180 @@ virtual ParseResult parseOptionalLocationSpecifier(Optional &result) = 0; + //===--------------------------------------------------------------------===// + // Type Parsing + //===--------------------------------------------------------------------===// + + /// Parse a type. + virtual ParseResult parseType(Type &result) = 0; + + /// Parse an optional type. + virtual OptionalParseResult parseOptionalType(Type &result) = 0; + + /// Parse a type of a specific kind. + template + ParseResult parseType(TypeT &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of type. + Type type; + if (parseType(type)) + return failure(); + + // Check for the right kind of type.
+ result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of type specified"); + + return success(); + } + + /// Parse a type list. + ParseResult parseTypeList(SmallVectorImpl &result) { + do { + Type type; + if (parseType(type)) + return failure(); + result.push_back(type); + } while (succeeded(parseOptionalComma())); + return success(); + } + + /// Parse an arrow followed by a type list. + virtual ParseResult parseArrowTypeList(SmallVectorImpl &result) = 0; + + /// Parse an optional arrow followed by a type list. + virtual ParseResult + parseOptionalArrowTypeList(SmallVectorImpl &result) = 0; + + /// Parse a colon followed by a type. + virtual ParseResult parseColonType(Type &result) = 0; + + /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. + template + ParseResult parseColonType(TypeType &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of type. + Type type; + if (parseColonType(type)) + return failure(); + + // Check for the right kind of type. + result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of type specified"); + + return success(); + } + + /// Parse a colon followed by a type list, which must have at least one type. + virtual ParseResult parseColonTypeList(SmallVectorImpl &result) = 0; + + /// Parse an optional colon followed by a type list, which if present must + /// have at least one type. + virtual ParseResult + parseOptionalColonTypeList(SmallVectorImpl &result) = 0; + + /// Parse a keyword followed by a type. + ParseResult parseKeywordType(const char *keyword, Type &result) { + return failure(parseKeyword(keyword) || parseType(result)); + } + + /// Add the specified type to the end of the specified type list and return + /// success. This is a helper designed to allow parse methods to be simple + /// and chain through || operators.
+ ParseResult addTypeToList(Type type, SmallVectorImpl &result) { + result.push_back(type); + return success(); + } + + /// Add the specified types to the end of the specified type list and return + /// success. This is a helper designed to allow parse methods to be simple + /// and chain through || operators. + ParseResult addTypesToList(ArrayRef types, + SmallVectorImpl &result) { + result.append(types.begin(), types.end()); + return success(); + } + + /// Parse a 'x' separated dimension list. This populates the dimension list, + /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on + /// `?` otherwise. + /// + /// dimension-list ::= (dimension `x`)* + /// dimension ::= `?` | integer + /// + /// When `allowDynamic` is not set, this is used to parse: + /// + /// static-dimension-list ::= (integer `x`)* + virtual ParseResult parseDimensionList(SmallVectorImpl &dimensions, + bool allowDynamic = true) = 0; + + /// Parse an 'x' token in a dimension list, handling the case where the x is + /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the + /// next token. + virtual ParseResult parseXInDimensionList() = 0; + +private: + AsmParser(const AsmParser &) = delete; + void operator=(const AsmParser &) = delete; +}; + +//===----------------------------------------------------------------------===// +// OpAsmParser +//===----------------------------------------------------------------------===// + +/// The OpAsmParser has methods for interacting with the asm parser: parsing +/// things from it, emitting errors etc. It has an intentionally high-level API +/// that is designed to reduce/constrain syntax innovation in individual +/// operations. +/// +/// For example, consider an op like this: +/// +/// %x = load %p[%1, %2] : memref<...> +/// +/// The "%x = load" tokens are already parsed and therefore invisible to the +/// custom op parser. 
This can be supported by calling `parseOperandList` to +/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to +/// parse the indices, then calling `parseColonTypeList` to parse the result +/// type. +/// +class OpAsmParser : public AsmParser { +public: + using AsmParser::AsmParser; + ~OpAsmParser() override; + + /// Return the name of the specified result in the specified syntax, as well + /// as the sub-element in the name. It returns an empty string and ~0U for + /// invalid result numbers. For example, in this operation: + /// + /// %x, %y:2, %z = foo.op + /// + /// getResultName(0) == {"x", 0 } + /// getResultName(1) == {"y", 0 } + /// getResultName(2) == {"y", 1 } + /// getResultName(3) == {"z", 0 } + /// getResultName(4) == {"", ~0U } + virtual std::pair + getResultName(unsigned resultNo) const = 0; + + /// Return the number of declared SSA results. This returns 4 for the foo.op + /// example in the comment for `getResultName`. + virtual size_t getNumResults() const = 0; + + // These methods emit an error and return failure or success. This allows + // these to be chained together into a linear sequence of || expressions in + // many cases. + + /// Parse an operation in its generic form. + /// The parsed operation is parsed in the current context and inserted in the + /// provided block and insertion point. The results produced by this operation + /// aren't mapped to any named value in the parser. Returns nullptr on + /// failure. + virtual Operation *parseGenericOperation(Block *insertBlock, + Block::iterator insertPt) = 0; + //===--------------------------------------------------------------------===// // Operand Parsing //===--------------------------------------------------------------------===// @@ -813,77 +1058,6 @@ // Type Parsing //===--------------------------------------------------------------------===// - /// Parse a type. - virtual ParseResult parseType(Type &result) = 0; - - /// Parse an optional type. 
- virtual OptionalParseResult parseOptionalType(Type &result) = 0; - - /// Parse a type of a specific type. - template - ParseResult parseType(TypeT &result) { - llvm::SMLoc loc = getCurrentLocation(); - - // Parse any kind of type. - Type type; - if (parseType(type)) - return failure(); - - // Check for the right kind of attribute. - result = type.dyn_cast(); - if (!result) - return emitError(loc, "invalid kind of type specified"); - - return success(); - } - - /// Parse a type list. - ParseResult parseTypeList(SmallVectorImpl &result) { - do { - Type type; - if (parseType(type)) - return failure(); - result.push_back(type); - } while (succeeded(parseOptionalComma())); - return success(); - } - - /// Parse an arrow followed by a type list. - virtual ParseResult parseArrowTypeList(SmallVectorImpl &result) = 0; - - /// Parse an optional arrow followed by a type list. - virtual ParseResult - parseOptionalArrowTypeList(SmallVectorImpl &result) = 0; - - /// Parse a colon followed by a type. - virtual ParseResult parseColonType(Type &result) = 0; - - /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. - template - ParseResult parseColonType(TypeType &result) { - llvm::SMLoc loc = getCurrentLocation(); - - // Parse any kind of type. - Type type; - if (parseColonType(type)) - return failure(); - - // Check for the right kind of attribute. - result = type.dyn_cast(); - if (!result) - return emitError(loc, "invalid kind of type specified"); - - return success(); - } - - /// Parse a colon followed by a type list, which must have at least one type. - virtual ParseResult parseColonTypeList(SmallVectorImpl &result) = 0; - - /// Parse an optional colon followed by a type list, which if present must - /// have at least one type. - virtual ParseResult - parseOptionalColonTypeList(SmallVectorImpl &result) = 0; - /// Parse a list of assignments of the form /// (%x1 = %y1, %x2 = %y2, ...) 
ParseResult parseAssignmentList(SmallVectorImpl &lhs, @@ -914,27 +1088,6 @@ parseOptionalAssignmentListWithTypes(SmallVectorImpl &lhs, SmallVectorImpl &rhs, SmallVectorImpl &types) = 0; - /// Parse a keyword followed by a type. - ParseResult parseKeywordType(const char *keyword, Type &result) { - return failure(parseKeyword(keyword) || parseType(result)); - } - - /// Add the specified type to the end of the specified type list and return - /// success. This is a helper designed to allow parse methods to be simple - /// and chain through || operators. - ParseResult addTypeToList(Type type, SmallVectorImpl &result) { - result.push_back(type); - return success(); - } - - /// Add the specified types to the end of the specified type list and return - /// success. This is a helper designed to allow parse methods to be simple - /// and chain through || operators. - ParseResult addTypesToList(ArrayRef types, - SmallVectorImpl &result) { - result.append(types.begin(), types.end()); - return success(); - } private: /// Parse either an operand list or a region argument list depending on diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -52,6 +52,18 @@ void OperationName::dump() const { print(llvm::errs()); } +//===--------------------------------------------------------------------===// +// AsmParser +//===--------------------------------------------------------------------===// + +AsmParser::~AsmParser() {} +DialectAsmParser::~DialectAsmParser() {} +OpAsmParser::~OpAsmParser() {} + +//===--------------------------------------------------------------------===// +// DialectAsmPrinter +//===--------------------------------------------------------------------===// + DialectAsmPrinter::~DialectAsmPrinter() {} //===--------------------------------------------------------------------===// @@ -250,12 +262,12 @@ struct NewLineCounter { unsigned curLine = 1; }; -} // end anonymous namespace static 
raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { ++newLine.curLine; return os << '\n'; } +} // end anonymous namespace //===----------------------------------------------------------------------===// // AliasInitializer @@ -492,6 +504,7 @@ /// The following are hooks of `OpAsmPrinter` that are not necessary for /// determining potential aliases. + void printFloat(const APFloat &value) override {} void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} void printNewline() override {} @@ -1202,18 +1215,17 @@ AsmState::~AsmState() {} //===----------------------------------------------------------------------===// -// ModulePrinter +// AsmPrinter::Impl //===----------------------------------------------------------------------===// -namespace { -class ModulePrinter { +namespace mlir { +class AsmPrinter::Impl { public: - ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, - AsmStateImpl *state = nullptr) + Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, + AsmStateImpl *state = nullptr) : os(os), printerFlags(flags), state(state) {} - explicit ModulePrinter(ModulePrinter &printer) - : os(printer.os), printerFlags(printer.printerFlags), - state(printer.state) {} + explicit Impl(Impl &other) + : Impl(other.os, other.printerFlags, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } @@ -1298,9 +1310,9 @@ /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; }; -} // end anonymous namespace +} // namespace mlir -void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) { +void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { // Check to see if we are printing debug information. 
if (!printerFlags.shouldPrintDebugInfo()) return; @@ -1309,7 +1321,7 @@ printLocation(loc, /*allowAlias=*/allowAlias); } -void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { +void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) { TypeSwitch(loc) .Case([&](OpaqueLoc loc) { printLocationInternal(loc.getFallbackLocation(), pretty); @@ -1430,7 +1442,7 @@ os << str; } -void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) { +void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { if (printerFlags.shouldPrintDebugInfoPrettyForm()) return printLocationInternal(loc, /*pretty=*/true); @@ -1578,8 +1590,8 @@ os << R"(opaque<"_", "0xDEADBEEF">)"; } -void ModulePrinter::printAttribute(Attribute attr, - AttrTypeElision typeElision) { +void AsmPrinter::Impl::printAttribute(Attribute attr, + AttrTypeElision typeElision) { if (!attr) { os << "<>"; return; @@ -1780,8 +1792,8 @@ os << ']'; } -void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, - bool allowHex) { +void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, + bool allowHex) { if (auto stringAttr = attr.dyn_cast()) return printDenseStringElementsAttr(stringAttr); @@ -1789,8 +1801,8 @@ allowHex); } -void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, - bool allowHex) { +void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( + DenseIntOrFPElementsAttr attr, bool allowHex) { auto type = attr.getType(); auto elementType = type.getElementType(); @@ -1860,7 +1872,8 @@ } } -void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { +void AsmPrinter::Impl::printDenseStringElementsAttr( + DenseStringElementsAttr attr) { ArrayRef data = attr.getRawStringData(); auto printFn = [&](unsigned index) { os << "\""; @@ -1870,7 +1883,7 @@ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } -void ModulePrinter::printType(Type type) { +void 
AsmPrinter::Impl::printType(Type type) { if (!type) { os << "<>"; return; @@ -1986,9 +1999,9 @@ .Default([&](Type type) { return printDialectType(type); }); } -void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs, - bool withKeyword) { +void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs, + bool withKeyword) { // If there are no attributes, then there is nothing to be done. if (attrs.empty()) return; @@ -2020,7 +2033,7 @@ printFilteredAttributesFn(filteredAttrs); } -void ModulePrinter::printNamedAttribute(NamedAttribute attr) { +void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { if (isBareIdentifier(attr.first)) { os << attr.first; } else { @@ -2037,81 +2050,82 @@ printAttribute(attr.second); } -//===----------------------------------------------------------------------===// -// CustomDialectAsmPrinter -//===----------------------------------------------------------------------===// - -namespace { -/// This class provides the main specialization of the DialectAsmPrinter that is -/// used to provide support for print attributes and types. This hooks allows -/// for dialects to hook into the main ModulePrinter. -struct CustomDialectAsmPrinter : public DialectAsmPrinter { -public: - CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} - ~CustomDialectAsmPrinter() override {} - - raw_ostream &getStream() const override { return printer.getStream(); } - - /// Print the given attribute to the stream. - void printAttribute(Attribute attr) override { printer.printAttribute(attr); } - - /// Print the given attribute without its type. The corresponding parser must - /// provide a valid type for the attribute. - void printAttributeWithoutType(Attribute attr) override { - printer.printAttribute(attr, ModulePrinter::AttrTypeElision::Must); - } - - /// Print the given floating point value in a stablized form. 
- void printFloat(const APFloat &value) override { - printFloatValue(value, getStream()); - } - - /// Print the given type to the stream. - void printType(Type type) override { printer.printType(type); } - - /// The main module printer. - ModulePrinter &printer; -}; -} // end anonymous namespace - -void ModulePrinter::printDialectAttribute(Attribute attr) { +void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { auto &dialect = attr.getDialect(); // Ask the dialect to serialize the attribute to a string. std::string attrName; { llvm::raw_string_ostream attrNameStr(attrName); - ModulePrinter subPrinter(attrNameStr, printerFlags, state); - CustomDialectAsmPrinter printer(subPrinter); + Impl subPrinter(attrNameStr, printerFlags, state); + DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); } -void ModulePrinter::printDialectType(Type type) { +void AsmPrinter::Impl::printDialectType(Type type) { auto &dialect = type.getDialect(); // Ask the dialect to serialize the type to a string. std::string typeName; { llvm::raw_string_ostream typeNameStr(typeName); - ModulePrinter subPrinter(typeNameStr, printerFlags, state); - CustomDialectAsmPrinter printer(subPrinter); + Impl subPrinter(typeNameStr, printerFlags, state); + DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); } +//===--------------------------------------------------------------------===// +// AsmPrinter +//===--------------------------------------------------------------------===// + +AsmPrinter::~AsmPrinter() {} + +raw_ostream &AsmPrinter::getStream() const { + assert(impl && "expected AsmPrinter::getStream to be overriden"); + return impl->getStream(); +} + +/// Print the given floating point value in a stablized form. 
+void AsmPrinter::printFloat(const APFloat &value) { + assert(impl && "expected AsmPrinter::printFloat to be overriden"); + printFloatValue(value, impl->getStream()); +} + +void AsmPrinter::printType(Type type) { + assert(impl && "expected AsmPrinter::printType to be overriden"); + impl->printType(type); +} + +void AsmPrinter::printAttribute(Attribute attr) { + assert(impl && "expected AsmPrinter::printAttribute to be overriden"); + impl->printAttribute(attr); +} + +void AsmPrinter::printAttributeWithoutType(Attribute attr) { + assert(impl && + "expected AsmPrinter::printAttributeWithoutType to be overriden"); + impl->printAttribute(attr, Impl::AttrTypeElision::Must); +} + +void AsmPrinter::printSymbolName(StringRef symbolRef) { + assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); + ::printSymbolReference(symbolRef, impl->getStream()); +} + //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// -void ModulePrinter::printAffineExpr( +void AsmPrinter::Impl::printAffineExpr( AffineExpr expr, function_ref printValueName) { printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } -void ModulePrinter::printAffineExprInternal( +void AsmPrinter::Impl::printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, function_ref printValueName) { const char *binopSpelling = nullptr; @@ -2244,12 +2258,12 @@ os << ')'; } -void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { +void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { printAffineExprInternal(expr, BindingStrength::Weak); isEq ? os << " == 0" : os << " >= 0"; } -void ModulePrinter::printAffineMap(AffineMap map) { +void AsmPrinter::Impl::printAffineMap(AffineMap map) { // Dimension identifiers. 
os << '('; for (int i = 0; i < (int)map.getNumDims() - 1; ++i) @@ -2275,7 +2289,7 @@ os << ')'; } -void ModulePrinter::printIntegerSet(IntegerSet set) { +void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { // Dimension identifiers. os << '('; for (unsigned i = 1; i < set.getNumDims(); ++i) @@ -2313,11 +2327,14 @@ namespace { /// This class contains the logic for printing operations, regions, and blocks. -class OperationPrinter : public ModulePrinter, private OpAsmPrinter { +class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { public: + using Impl = AsmPrinter::Impl; + using Impl::printType; + explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, AsmStateImpl &state) - : ModulePrinter(os, flags, &state) {} + : Impl(os, flags, &state), OpAsmPrinter(static_cast(*this)) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); @@ -2346,9 +2363,6 @@ // OpAsmPrinter methods //===--------------------------------------------------------------------===// - /// Return the current stream of the printer. - raw_ostream &getStream() const override { return os; } - /// Print a newline and indent the printer to the start of the current /// operation. void printNewline() override { @@ -2356,20 +2370,6 @@ os.indent(currentIndent); } - /// Print the given type. - void printType(Type type) override { ModulePrinter::printType(type); } - - /// Print the given attribute. - void printAttribute(Attribute attr) override { - ModulePrinter::printAttribute(attr); - } - - /// Print the given attribute without its type. The corresponding parser must - /// provide a valid type for the attribute. - void printAttributeWithoutType(Attribute attr) override { - ModulePrinter::printAttribute(attr, AttrTypeElision::Must); - } - /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. 
@@ -2388,13 +2388,13 @@ /// Print an optional attribute dictionary with a given set of elided values. void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); + Impl::printOptionalAttrDict(attrs, elidedAttrs); } void printOptionalAttrDictWithKeyword( ArrayRef attrs, ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, - /*withKeyword=*/true); + Impl::printOptionalAttrDict(attrs, elidedAttrs, + /*withKeyword=*/true); } /// Print the given successor. @@ -2427,11 +2427,6 @@ void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands) override; - /// Print the given string as a symbol reference. - void printSymbolName(StringRef symbolRef) override { - ::printSymbolReference(symbolRef, os); - } - private: // Contains the stack of default dialects to use when printing regions. // A new dialect is pushed to the stack before parsing regions nested under an @@ -2732,7 +2727,7 @@ //===----------------------------------------------------------------------===// void Attribute::print(raw_ostream &os) const { - ModulePrinter(os).printAttribute(*this); + AsmPrinter::Impl(os).printAttribute(*this); } void Attribute::dump() const { @@ -2740,7 +2735,9 @@ llvm::errs() << "\n"; } -void Type::print(raw_ostream &os) const { ModulePrinter(os).printType(*this); } +void Type::print(raw_ostream &os) const { + AsmPrinter::Impl(os).printType(*this); +} void Type::dump() const { print(llvm::errs()); } @@ -2759,7 +2756,7 @@ os << "<>"; return; } - ModulePrinter(os).printAffineExpr(*this); + AsmPrinter::Impl(os).printAffineExpr(*this); } void AffineExpr::dump() const { @@ -2772,11 +2769,11 @@ os << "<>"; return; } - ModulePrinter(os).printAffineMap(*this); + AsmPrinter::Impl(os).printAffineMap(*this); } void IntegerSet::print(raw_ostream &os) const { - ModulePrinter(os).printIntegerSet(*this); + 
AsmPrinter::Impl(os).printIntegerSet(*this); } void Value::print(raw_ostream &os) { diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -24,8 +24,6 @@ using namespace mlir; using namespace detail; -DialectAsmParser::~DialectAsmParser() {} - //===----------------------------------------------------------------------===// // DialectRegistry //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -19,8 +19,6 @@ using namespace mlir; -OpAsmParser::~OpAsmParser() {} - //===----------------------------------------------------------------------===// // OperationName //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -0,0 +1,501 @@ +//===- AsmParserImpl.h - MLIR AsmParserImpl Class ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_PARSER_ASMPARSERIMPL_H +#define MLIR_LIB_PARSER_ASMPARSERIMPL_H + +#include "Parser.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Parser/AsmParserState.h" + +namespace mlir { +namespace detail { +//===----------------------------------------------------------------------===// +// AsmParserImpl +//===----------------------------------------------------------------------===// + +/// This class provides the implementation of the generic parser methods within +/// AsmParser. 
+template +class AsmParserImpl : public BaseT { +public: + AsmParserImpl(llvm::SMLoc nameLoc, Parser &parser) + : nameLoc(nameLoc), parser(parser) {} + ~AsmParserImpl() override {} + + /// Return the location of the original name token. + llvm::SMLoc getNameLoc() const override { return nameLoc; } + + //===--------------------------------------------------------------------===// + // Utilities + //===--------------------------------------------------------------------===// + + /// Return if any errors were emitted during parsing. + bool didEmitError() const { return emittedError; } + + /// Emit a diagnostic at the specified location and return failure. + InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { + emittedError = true; + return parser.emitError(loc, message); + } + + /// Return a builder which provides useful access to MLIRContext, global + /// objects like types and attributes. + Builder &getBuilder() const override { return parser.builder; } + + /// Get the location of the next token and store it into the argument. This + /// always succeeds. + llvm::SMLoc getCurrentLocation() override { + return parser.getToken().getLoc(); + } + + /// Re-encode the given source location as an MLIR location and return it. + Location getEncodedSourceLoc(llvm::SMLoc loc) override { + return parser.getEncodedSourceLocation(loc); + } + + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + + using Delimiter = AsmParser::Delimiter; + + /// Parse a `->` token. + ParseResult parseArrow() override { + return parser.parseToken(Token::arrow, "expected '->'"); + } + + /// Parses a `->` if present. + ParseResult parseOptionalArrow() override { + return success(parser.consumeIf(Token::arrow)); + } + + /// Parse a '{' token. 
+ ParseResult parseLBrace() override { + return parser.parseToken(Token::l_brace, "expected '{'"); + } + + /// Parse a '{' token if present + ParseResult parseOptionalLBrace() override { + return success(parser.consumeIf(Token::l_brace)); + } + + /// Parse a `}` token. + ParseResult parseRBrace() override { + return parser.parseToken(Token::r_brace, "expected '}'"); + } + + /// Parse a `}` token if present + ParseResult parseOptionalRBrace() override { + return success(parser.consumeIf(Token::r_brace)); + } + + /// Parse a `:` token. + ParseResult parseColon() override { + return parser.parseToken(Token::colon, "expected ':'"); + } + + /// Parse a `:` token if present. + ParseResult parseOptionalColon() override { + return success(parser.consumeIf(Token::colon)); + } + + /// Parse a `,` token. + ParseResult parseComma() override { + return parser.parseToken(Token::comma, "expected ','"); + } + + /// Parse a `,` token if present. + ParseResult parseOptionalComma() override { + return success(parser.consumeIf(Token::comma)); + } + + /// Parses a `...` if present. + ParseResult parseOptionalEllipsis() override { + return success(parser.consumeIf(Token::ellipsis)); + } + + /// Parse a `=` token. + ParseResult parseEqual() override { + return parser.parseToken(Token::equal, "expected '='"); + } + + /// Parse a `=` token if present. + ParseResult parseOptionalEqual() override { + return success(parser.consumeIf(Token::equal)); + } + + /// Parse a '<' token. + ParseResult parseLess() override { + return parser.parseToken(Token::less, "expected '<'"); + } + + /// Parse a `<` token if present. + ParseResult parseOptionalLess() override { + return success(parser.consumeIf(Token::less)); + } + + /// Parse a '>' token. + ParseResult parseGreater() override { + return parser.parseToken(Token::greater, "expected '>'"); + } + + /// Parse a `>` token if present. 
+ ParseResult parseOptionalGreater() override { + return success(parser.consumeIf(Token::greater)); + } + + /// Parse a `(` token. + ParseResult parseLParen() override { + return parser.parseToken(Token::l_paren, "expected '('"); + } + + /// Parses a '(' if present. + ParseResult parseOptionalLParen() override { + return success(parser.consumeIf(Token::l_paren)); + } + + /// Parse a `)` token. + ParseResult parseRParen() override { + return parser.parseToken(Token::r_paren, "expected ')'"); + } + + /// Parses a ')' if present. + ParseResult parseOptionalRParen() override { + return success(parser.consumeIf(Token::r_paren)); + } + + /// Parse a `[` token. + ParseResult parseLSquare() override { + return parser.parseToken(Token::l_square, "expected '['"); + } + + /// Parses a '[' if present. + ParseResult parseOptionalLSquare() override { + return success(parser.consumeIf(Token::l_square)); + } + + /// Parse a `]` token. + ParseResult parseRSquare() override { + return parser.parseToken(Token::r_square, "expected ']'"); + } + + /// Parses a ']' if present. + ParseResult parseOptionalRSquare() override { + return success(parser.consumeIf(Token::r_square)); + } + + /// Parses a '?' token. + ParseResult parseQuestion() override { + return parser.parseToken(Token::question, "expected '?'"); + } + + /// Parses a '?' if present. + ParseResult parseOptionalQuestion() override { + return success(parser.consumeIf(Token::question)); + } + + /// Parses a '*' token. + ParseResult parseStar() override { + return parser.parseToken(Token::star, "expected '*'"); + } + + /// Parses a '*' if present. + ParseResult parseOptionalStar() override { + return success(parser.consumeIf(Token::star)); + } + + /// Parses a '+' token. + ParseResult parsePlus() override { + return parser.parseToken(Token::plus, "expected '+'"); + } + + /// Parses a '+' token if present. 
+ ParseResult parseOptionalPlus() override { + return success(parser.consumeIf(Token::plus)); + } + + /// Parses a quoted string token if present. + ParseResult parseOptionalString(std::string *string) override { + if (!parser.getToken().is(Token::string)) + return failure(); + + if (string) + *string = parser.getToken().getStringValue(); + parser.consumeToken(); + return success(); + } + + /// Returns true if the current token corresponds to a keyword. + bool isCurrentTokenAKeyword() const { + return parser.getToken().isAny(Token::bare_identifier, Token::inttype) || + parser.getToken().isKeyword(); + } + + /// Parse the given keyword if present. + ParseResult parseOptionalKeyword(StringRef keyword) override { + // Check that the current token has the same spelling. + if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) + return failure(); + parser.consumeToken(); + return success(); + } + + /// Parse a keyword, if present, into 'keyword'. + ParseResult parseOptionalKeyword(StringRef *keyword) override { + // Check that the current token is a keyword. + if (!isCurrentTokenAKeyword()) + return failure(); + + *keyword = parser.getTokenSpelling(); + parser.consumeToken(); + return success(); + } + + /// Parse a keyword if it is one of the 'allowedKeywords'. + ParseResult + parseOptionalKeyword(StringRef *keyword, + ArrayRef allowedKeywords) override { + // Check that the current token is a keyword. + if (!isCurrentTokenAKeyword()) + return failure(); + + StringRef currentKeyword = parser.getTokenSpelling(); + if (llvm::is_contained(allowedKeywords, currentKeyword)) { + *keyword = currentKeyword; + parser.consumeToken(); + return success(); + } + + return failure(); + } + + /// Parse a floating point value from the stream. + ParseResult parseFloat(double &result) override { + bool isNegative = parser.consumeIf(Token::minus); + Token curTok = parser.getToken(); + llvm::SMLoc loc = curTok.getLoc(); + + // Check for a floating point value. 
+ if (curTok.is(Token::floatliteral)) { + auto val = curTok.getFloatingPointValue(); + if (!val.hasValue()) + return emitError(loc, "floating point value too large"); + parser.consumeToken(Token::floatliteral); + result = isNegative ? -*val : *val; + return success(); + } + + // Check for a hexadecimal float value. + if (curTok.is(Token::integer)) { + Optional apResult; + if (failed(parser.parseFloatFromIntegerLiteral( + apResult, curTok, isNegative, APFloat::IEEEdouble(), + /*typeSizeInBits=*/64))) + return failure(); + + parser.consumeToken(Token::integer); + result = apResult->convertToDouble(); + return success(); + } + + return emitError(loc, "expected floating point literal"); + } + + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(APInt &result) override { + return parser.parseOptionalInteger(result); + } + + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + ParseResult parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElt, + StringRef contextMessage) override { + return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); + } + + //===--------------------------------------------------------------------===// + // Attribute Parsing + //===--------------------------------------------------------------------===// + + /// Parse an arbitrary attribute and return it in result. + ParseResult parseAttribute(Attribute &result, Type type) override { + result = parser.parseAttribute(type); + return success(static_cast(result)); + } + + /// Parse an optional attribute. 
+ template + OptionalParseResult + parseOptionalAttributeAndAddToList(AttrT &result, Type type, + StringRef attrName, NamedAttrList &attrs) { + OptionalParseResult parseResult = + parser.parseOptionalAttribute(result, type); + if (parseResult.hasValue() && succeeded(*parseResult)) + attrs.push_back(parser.builder.getNamedAttr(attrName, result)); + return parseResult; + } + OptionalParseResult parseOptionalAttribute(Attribute &result, Type type, + StringRef attrName, + NamedAttrList &attrs) override { + return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); + } + OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type, + StringRef attrName, + NamedAttrList &attrs) override { + return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); + } + OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type, + StringRef attrName, + NamedAttrList &attrs) override { + return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); + } + + /// Parse a named dictionary into 'result' if it is present. + ParseResult parseOptionalAttrDict(NamedAttrList &result) override { + if (parser.getToken().isNot(Token::l_brace)) + return success(); + return parser.parseAttributeDict(result); + } + + /// Parse a named dictionary into 'result' if the `attributes` keyword is + /// present. + ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override { + if (failed(parseOptionalKeyword("attributes"))) + return success(); + return parser.parseAttributeDict(result); + } + + /// Parse an affine map instance into 'map'. + ParseResult parseAffineMap(AffineMap &map) override { + return parser.parseAffineMapReference(map); + } + + /// Parse an integer set instance into 'set'. 
+ ParseResult printIntegerSet(IntegerSet &set) override { + return parser.parseIntegerSetReference(set); + } + + //===--------------------------------------------------------------------===// + // Identifier Parsing + //===--------------------------------------------------------------------===// + + /// Parse an optional @-identifier and store it (without the '@' symbol) in a + /// string attribute named 'attrName'. + ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, + NamedAttrList &attrs) override { + Token atToken = parser.getToken(); + if (atToken.isNot(Token::at_identifier)) + return failure(); + + result = getBuilder().getStringAttr(atToken.getSymbolReference()); + attrs.push_back(getBuilder().getNamedAttr(attrName, result)); + parser.consumeToken(); + + // If we are populating the assembly parser state, record this as a symbol + // reference. + if (parser.getState().asmState) { + parser.getState().asmState->addUses(SymbolRefAttr::get(result), + atToken.getLocRange()); + } + return success(); + } + + /// Parse a loc(...) specifier if present, filling in result if so. + ParseResult + parseOptionalLocationSpecifier(Optional &result) override { + // If there is a 'loc' we parse a trailing location. + if (!parser.consumeIf(Token::kw_loc)) + return success(); + LocationAttr directLoc; + if (parser.parseToken(Token::l_paren, "expected '(' in location") || + parser.parseLocationInstance(directLoc) || + parser.parseToken(Token::r_paren, "expected ')' in location")) + return failure(); + + result = directLoc; + return success(); + } + + //===--------------------------------------------------------------------===// + // Type Parsing + //===--------------------------------------------------------------------===// + + /// Parse a type. + ParseResult parseType(Type &result) override { + return failure(!(result = parser.parseType())); + } + + /// Parse an optional type. 
+ OptionalParseResult parseOptionalType(Type &result) override { + return parser.parseOptionalType(result); + } + + /// Parse an arrow followed by a type list. + ParseResult parseArrowTypeList(SmallVectorImpl &result) override { + if (parseArrow() || parser.parseFunctionResultTypes(result)) + return failure(); + return success(); + } + + /// Parse an optional arrow followed by a type list. + ParseResult + parseOptionalArrowTypeList(SmallVectorImpl &result) override { + if (!parser.consumeIf(Token::arrow)) + return success(); + return parser.parseFunctionResultTypes(result); + } + + /// Parse a colon followed by a type. + ParseResult parseColonType(Type &result) override { + return failure(parser.parseToken(Token::colon, "expected ':'") || + !(result = parser.parseType())); + } + + /// Parse a colon followed by a type list, which must have at least one type. + ParseResult parseColonTypeList(SmallVectorImpl &result) override { + if (parser.parseToken(Token::colon, "expected ':'")) + return failure(); + return parser.parseTypeListNoParens(result); + } + + /// Parse an optional colon followed by a type list, which if present must + /// have at least one type. + ParseResult + parseOptionalColonTypeList(SmallVectorImpl &result) override { + if (!parser.consumeIf(Token::colon)) + return success(); + return parser.parseTypeListNoParens(result); + } + + ParseResult parseDimensionList(SmallVectorImpl &dimensions, + bool allowDynamic) override { + return parser.parseDimensionListRanked(dimensions, allowDynamic); + } + + ParseResult parseXInDimensionList() override { + return parser.parseXInDimensionList(); + } + +protected: + /// The source location of the dialect symbol. + llvm::SMLoc nameLoc; + + /// The main parser. + Parser &parser; + + /// A flag that indicates if any errors were emitted during parsing. 
+ bool emittedError = false; +}; +} // namespace detail +} // end namespace mlir + +#endif // MLIR_LIB_PARSER_ASMPARSERIMPL_H diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" +#include "AsmParserImpl.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -27,304 +27,20 @@ /// This class provides the main implementation of the DialectAsmParser that /// allows for dialects to parse attributes and types. This allows for dialect /// hooking into the main MLIR parsing logic. -class CustomDialectAsmParser : public DialectAsmParser { +class CustomDialectAsmParser : public AsmParserImpl { public: CustomDialectAsmParser(StringRef fullSpec, Parser &parser) - : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), - parser(parser) {} + : AsmParserImpl(parser.getToken().getLoc(), parser), + fullSpec(fullSpec) {} ~CustomDialectAsmParser() override {} - /// Emit a diagnostic at the specified location and return failure. - InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { - return parser.emitError(loc, message); - } - - /// Return a builder which provides useful access to MLIRContext, global - /// objects like types and attributes. - Builder &getBuilder() const override { return parser.builder; } - - /// Get the location of the next token and store it into the argument. This - /// always succeeds. - llvm::SMLoc getCurrentLocation() override { - return parser.getToken().getLoc(); - } - - /// Return the location of the original name token. - llvm::SMLoc getNameLoc() const override { return nameLoc; } - - /// Re-encode the given source location as an MLIR location and return it. 
- Location getEncodedSourceLoc(llvm::SMLoc loc) override { - return parser.getEncodedSourceLocation(loc); - } - /// Returns the full specification of the symbol being parsed. This allows /// for using a separate parser if necessary. StringRef getFullSymbolSpec() const override { return fullSpec; } - /// Parse a floating point value from the stream. - ParseResult parseFloat(double &result) override { - bool isNegative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - llvm::SMLoc loc = curTok.getLoc(); - - // Check for a floating point value. - if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val.hasValue()) - return emitError(loc, "floating point value too large"); - parser.consumeToken(Token::floatliteral); - result = isNegative ? -*val : *val; - return success(); - } - - // Check for a hexadecimal float value. - if (curTok.is(Token::integer)) { - Optional apResult; - if (failed(parser.parseFloatFromIntegerLiteral( - apResult, curTok, isNegative, APFloat::IEEEdouble(), - /*typeSizeInBits=*/64))) - return failure(); - - parser.consumeToken(Token::integer); - result = apResult->convertToDouble(); - return success(); - } - - return emitError(loc, "expected floating point literal"); - } - - /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(APInt &result) override { - return parser.parseOptionalInteger(result); - } - - //===--------------------------------------------------------------------===// - // Token Parsing - //===--------------------------------------------------------------------===// - - /// Parse a `->` token. - ParseResult parseArrow() override { - return parser.parseToken(Token::arrow, "expected '->'"); - } - - /// Parses a `->` if present. - ParseResult parseOptionalArrow() override { - return success(parser.consumeIf(Token::arrow)); - } - - /// Parse a '{' token. 
- ParseResult parseLBrace() override { - return parser.parseToken(Token::l_brace, "expected '{'"); - } - - /// Parse a '{' token if present - ParseResult parseOptionalLBrace() override { - return success(parser.consumeIf(Token::l_brace)); - } - - /// Parse a `}` token. - ParseResult parseRBrace() override { - return parser.parseToken(Token::r_brace, "expected '}'"); - } - - /// Parse a `}` token if present - ParseResult parseOptionalRBrace() override { - return success(parser.consumeIf(Token::r_brace)); - } - - /// Parse a `:` token. - ParseResult parseColon() override { - return parser.parseToken(Token::colon, "expected ':'"); - } - - /// Parse a `:` token if present. - ParseResult parseOptionalColon() override { - return success(parser.consumeIf(Token::colon)); - } - - /// Parse a `,` token. - ParseResult parseComma() override { - return parser.parseToken(Token::comma, "expected ','"); - } - - /// Parse a `,` token if present. - ParseResult parseOptionalComma() override { - return success(parser.consumeIf(Token::comma)); - } - - /// Parses a `...` if present. - ParseResult parseOptionalEllipsis() override { - return success(parser.consumeIf(Token::ellipsis)); - } - - /// Parse a `=` token. - ParseResult parseEqual() override { - return parser.parseToken(Token::equal, "expected '='"); - } - - /// Parse a `=` token if present. - ParseResult parseOptionalEqual() override { - return success(parser.consumeIf(Token::equal)); - } - - /// Parse a '<' token. - ParseResult parseLess() override { - return parser.parseToken(Token::less, "expected '<'"); - } - - /// Parse a `<` token if present. - ParseResult parseOptionalLess() override { - return success(parser.consumeIf(Token::less)); - } - - /// Parse a '>' token. - ParseResult parseGreater() override { - return parser.parseToken(Token::greater, "expected '>'"); - } - - /// Parse a `>` token if present. 
- ParseResult parseOptionalGreater() override { - return success(parser.consumeIf(Token::greater)); - } - - /// Parse a `(` token. - ParseResult parseLParen() override { - return parser.parseToken(Token::l_paren, "expected '('"); - } - - /// Parses a '(' if present. - ParseResult parseOptionalLParen() override { - return success(parser.consumeIf(Token::l_paren)); - } - - /// Parse a `)` token. - ParseResult parseRParen() override { - return parser.parseToken(Token::r_paren, "expected ')'"); - } - - /// Parses a ')' if present. - ParseResult parseOptionalRParen() override { - return success(parser.consumeIf(Token::r_paren)); - } - - /// Parse a `[` token. - ParseResult parseLSquare() override { - return parser.parseToken(Token::l_square, "expected '['"); - } - - /// Parses a '[' if present. - ParseResult parseOptionalLSquare() override { - return success(parser.consumeIf(Token::l_square)); - } - - /// Parse a `]` token. - ParseResult parseRSquare() override { - return parser.parseToken(Token::r_square, "expected ']'"); - } - - /// Parses a ']' if present. - ParseResult parseOptionalRSquare() override { - return success(parser.consumeIf(Token::r_square)); - } - - /// Parses a '?' if present. - ParseResult parseOptionalQuestion() override { - return success(parser.consumeIf(Token::question)); - } - - /// Parses a '*' if present. - ParseResult parseOptionalStar() override { - return success(parser.consumeIf(Token::star)); - } - - /// Parses a quoted string token if present. - ParseResult parseOptionalString(std::string *string) override { - if (!parser.getToken().is(Token::string)) - return failure(); - - if (string) - *string = parser.getToken().getStringValue(); - parser.consumeToken(); - return success(); - } - - /// Returns true if the current token corresponds to a keyword. 
- bool isCurrentTokenAKeyword() const { - return parser.getToken().isAny(Token::bare_identifier, Token::inttype) || - parser.getToken().isKeyword(); - } - - /// Parse the given keyword if present. - ParseResult parseOptionalKeyword(StringRef keyword) override { - // Check that the current token has the same spelling. - if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) - return failure(); - parser.consumeToken(); - return success(); - } - - /// Parse a keyword, if present, into 'keyword'. - ParseResult parseOptionalKeyword(StringRef *keyword) override { - // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) - return failure(); - - *keyword = parser.getTokenSpelling(); - parser.consumeToken(); - return success(); - } - - //===--------------------------------------------------------------------===// - // Attribute Parsing - //===--------------------------------------------------------------------===// - - /// Parse an arbitrary attribute and return it in result. - ParseResult parseAttribute(Attribute &result, Type type) override { - result = parser.parseAttribute(type); - return success(static_cast(result)); - } - - /// Parse an affine map instance into 'map'. - ParseResult parseAffineMap(AffineMap &map) override { - return parser.parseAffineMapReference(map); - } - - /// Parse an integer set instance into 'set'. 
- ParseResult printIntegerSet(IntegerSet &set) override { - return parser.parseIntegerSetReference(set); - } - - //===--------------------------------------------------------------------===// - // Type Parsing - //===--------------------------------------------------------------------===// - - ParseResult parseType(Type &result) override { - result = parser.parseType(); - return success(static_cast(result)); - } - - ParseResult parseDimensionList(SmallVectorImpl &dimensions, - bool allowDynamic) override { - return parser.parseDimensionListRanked(dimensions, allowDynamic); - } - - ParseResult parseXInDimensionList() override { - return parser.parseXInDimensionList(); - } - - OptionalParseResult parseOptionalType(Type &result) override { - return parser.parseOptionalType(result); - } - private: /// The full symbol specification. StringRef fullSpec; - - /// The source location of the dialect symbol. - SMLoc nameLoc; - - /// The main parser. - Parser &parser; }; } // namespace diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Parser.h" +#include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" @@ -1093,15 +1094,15 @@ } namespace { -class CustomOpAsmParser : public OpAsmParser { +class CustomOpAsmParser : public AsmParserImpl { public: CustomOpAsmParser( SMLoc nameLoc, ArrayRef resultIDs, function_ref parseAssembly, bool isIsolatedFromAbove, StringRef opName, OperationParser &parser) - : nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly), - isIsolatedFromAbove(isIsolatedFromAbove), opName(opName), - parser(parser) { + : AsmParserImpl(nameLoc, parser), resultIDs(resultIDs), + parseAssembly(parseAssembly), isIsolatedFromAbove(isIsolatedFromAbove), + opName(opName), parser(parser) { 
(void)isIsolatedFromAbove; // Only used in assert, silence unused warning. } @@ -1131,21 +1132,6 @@ // Utilities //===--------------------------------------------------------------------===// - /// Return if any errors were emitted during parsing. - bool didEmitError() const { return emittedError; } - - /// Emit a diagnostic at the specified location and return failure. - InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { - emittedError = true; - return parser.emitError(loc, "custom op '" + opName + "' " + message); - } - - llvm::SMLoc getCurrentLocation() override { - return parser.getToken().getLoc(); - } - - Builder &getBuilder() const override { return parser.builder; } - /// Return the name of the specified result in the specified syntax, as well /// as the subelement in the name. For example, in this operation: /// @@ -1181,331 +1167,10 @@ return count; } - llvm::SMLoc getNameLoc() const override { return nameLoc; } - - /// Re-encode the given source location as an MLIR location and return it. - Location getEncodedSourceLoc(llvm::SMLoc loc) override { - return parser.getEncodedSourceLocation(loc); - } - - //===--------------------------------------------------------------------===// - // Token Parsing - //===--------------------------------------------------------------------===// - - /// Parse a `->` token. - ParseResult parseArrow() override { - return parser.parseToken(Token::arrow, "expected '->'"); - } - - /// Parses a `->` if present. - ParseResult parseOptionalArrow() override { - return success(parser.consumeIf(Token::arrow)); - } - - /// Parse a '{' token. - ParseResult parseLBrace() override { - return parser.parseToken(Token::l_brace, "expected '{'"); - } - - /// Parse a '{' token if present - ParseResult parseOptionalLBrace() override { - return success(parser.consumeIf(Token::l_brace)); - } - - /// Parse a `}` token. 
- ParseResult parseRBrace() override { - return parser.parseToken(Token::r_brace, "expected '}'"); - } - - /// Parse a `}` token if present - ParseResult parseOptionalRBrace() override { - return success(parser.consumeIf(Token::r_brace)); - } - - /// Parse a `:` token. - ParseResult parseColon() override { - return parser.parseToken(Token::colon, "expected ':'"); - } - - /// Parse a `:` token if present. - ParseResult parseOptionalColon() override { - return success(parser.consumeIf(Token::colon)); - } - - /// Parse a `,` token. - ParseResult parseComma() override { - return parser.parseToken(Token::comma, "expected ','"); - } - - /// Parse a `,` token if present. - ParseResult parseOptionalComma() override { - return success(parser.consumeIf(Token::comma)); - } - - /// Parses a `...` if present. - ParseResult parseOptionalEllipsis() override { - return success(parser.consumeIf(Token::ellipsis)); - } - - /// Parse a `=` token. - ParseResult parseEqual() override { - return parser.parseToken(Token::equal, "expected '='"); - } - - /// Parse a `=` token if present. - ParseResult parseOptionalEqual() override { - return success(parser.consumeIf(Token::equal)); - } - - /// Parse a '<' token. - ParseResult parseLess() override { - return parser.parseToken(Token::less, "expected '<'"); - } - - /// Parse a '<' token if present. - ParseResult parseOptionalLess() override { - return success(parser.consumeIf(Token::less)); - } - - /// Parse a '>' token. - ParseResult parseGreater() override { - return parser.parseToken(Token::greater, "expected '>'"); - } - - /// Parse a '>' token if present. - ParseResult parseOptionalGreater() override { - return success(parser.consumeIf(Token::greater)); - } - - /// Parse a `(` token. - ParseResult parseLParen() override { - return parser.parseToken(Token::l_paren, "expected '('"); - } - - /// Parses a '(' if present. 
- ParseResult parseOptionalLParen() override { - return success(parser.consumeIf(Token::l_paren)); - } - - /// Parse a `)` token. - ParseResult parseRParen() override { - return parser.parseToken(Token::r_paren, "expected ')'"); - } - - /// Parses a ')' if present. - ParseResult parseOptionalRParen() override { - return success(parser.consumeIf(Token::r_paren)); - } - - /// Parse a `[` token. - ParseResult parseLSquare() override { - return parser.parseToken(Token::l_square, "expected '['"); - } - - /// Parses a '[' if present. - ParseResult parseOptionalLSquare() override { - return success(parser.consumeIf(Token::l_square)); - } - - /// Parse a `]` token. - ParseResult parseRSquare() override { - return parser.parseToken(Token::r_square, "expected ']'"); - } - - /// Parses a ']' if present. - ParseResult parseOptionalRSquare() override { - return success(parser.consumeIf(Token::r_square)); - } - - /// Parses a '?' token. - ParseResult parseQuestion() override { - return parser.parseToken(Token::question, "expected '?'"); - } - - /// Parses a '?' token if present. - ParseResult parseOptionalQuestion() override { - return success(parser.consumeIf(Token::question)); - } - - /// Parses a '+' token. - ParseResult parsePlus() override { - return parser.parseToken(Token::plus, "expected '+'"); - } - - /// Parses a '+' token if present. - ParseResult parseOptionalPlus() override { - return success(parser.consumeIf(Token::plus)); - } - - /// Parses a '*' token. - ParseResult parseStar() override { - return parser.parseToken(Token::star, "expected '*'"); - } - - /// Parses a '*' token if present. - ParseResult parseOptionalStar() override { - return success(parser.consumeIf(Token::star)); - } - - /// Parse an optional integer value from the stream. - OptionalParseResult parseOptionalInteger(APInt &result) override { - return parser.parseOptionalInteger(result); - } - - /// Parse a list of comma-separated items with an optional delimiter. 
If a - /// delimiter is provided, then an empty list is allowed. If not, then at - /// least one element will be parsed. - ParseResult parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElt, - StringRef contextMessage) override { - return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); - } - - //===--------------------------------------------------------------------===// - // Attribute Parsing - //===--------------------------------------------------------------------===// - - /// Parse an arbitrary attribute of a given type and return it in result. - ParseResult parseAttribute(Attribute &result, Type type) override { - result = parser.parseAttribute(type); - return success(static_cast(result)); - } - - /// Parse an optional attribute. - template - OptionalParseResult - parseOptionalAttributeAndAddToList(AttrT &result, Type type, - StringRef attrName, NamedAttrList &attrs) { - OptionalParseResult parseResult = - parser.parseOptionalAttribute(result, type); - if (parseResult.hasValue() && succeeded(*parseResult)) - attrs.push_back(parser.builder.getNamedAttr(attrName, result)); - return parseResult; - } - OptionalParseResult parseOptionalAttribute(Attribute &result, Type type, - StringRef attrName, - NamedAttrList &attrs) override { - return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); - } - OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type, - StringRef attrName, - NamedAttrList &attrs) override { - return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); - } - OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type, - StringRef attrName, - NamedAttrList &attrs) override { - return parseOptionalAttributeAndAddToList(result, type, attrName, attrs); - } - - /// Parse a named dictionary into 'result' if it is present. 
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override { - if (parser.getToken().isNot(Token::l_brace)) - return success(); - return parser.parseAttributeDict(result); - } - - /// Parse a named dictionary into 'result' if the `attributes` keyword is - /// present. - ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override { - if (failed(parseOptionalKeyword("attributes"))) - return success(); - return parser.parseAttributeDict(result); - } - - /// Parse an affine map instance into 'map'. - ParseResult parseAffineMap(AffineMap &map) override { - return parser.parseAffineMapReference(map); - } - - /// Parse an integer set instance into 'set'. - ParseResult printIntegerSet(IntegerSet &set) override { - return parser.parseIntegerSetReference(set); - } - - //===--------------------------------------------------------------------===// - // Identifier Parsing - //===--------------------------------------------------------------------===// - - /// Returns true if the current token corresponds to a keyword. - bool isCurrentTokenAKeyword() const { - return parser.getToken().is(Token::bare_identifier) || - parser.getToken().isKeyword(); - } - - /// Parse the given keyword if present. - ParseResult parseOptionalKeyword(StringRef keyword) override { - // Check that the current token has the same spelling. - if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) - return failure(); - parser.consumeToken(); - return success(); - } - - /// Parse a keyword, if present, into 'keyword'. - ParseResult parseOptionalKeyword(StringRef *keyword) override { - // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) - return failure(); - - *keyword = parser.getTokenSpelling(); - parser.consumeToken(); - return success(); - } - - /// Parse a keyword if it is one of the 'allowedKeywords'. 
- ParseResult - parseOptionalKeyword(StringRef *keyword, - ArrayRef allowedKeywords) override { - // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) - return failure(); - - StringRef currentKeyword = parser.getTokenSpelling(); - if (llvm::is_contained(allowedKeywords, currentKeyword)) { - *keyword = currentKeyword; - parser.consumeToken(); - return success(); - } - - return failure(); - } - - /// Parse an optional @-identifier and store it (without the '@' symbol) in a - /// string attribute named 'attrName'. - ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, - NamedAttrList &attrs) override { - Token atToken = parser.getToken(); - if (atToken.isNot(Token::at_identifier)) - return failure(); - - result = getBuilder().getStringAttr(atToken.getSymbolReference()); - attrs.push_back(getBuilder().getNamedAttr(attrName, result)); - parser.consumeToken(); - - // If we are populating the assembly parser state, record this as a symbol - // reference. - if (parser.getState().asmState) { - parser.getState().asmState->addUses(SymbolRefAttr::get(result), - atToken.getLocRange()); - } - return success(); - } - - /// Parse a loc(...) specifier if present, filling in result if so. - ParseResult - parseOptionalLocationSpecifier(Optional &result) override { - // If there is a 'loc' we parse a trailing location. - if (!parser.consumeIf(Token::kw_loc)) - return success(); - LocationAttr directLoc; - if (parser.parseToken(Token::l_paren, "expected '(' in location") || - parser.parseLocationInstance(directLoc) || - parser.parseToken(Token::r_paren, "expected ')' in location")) - return failure(); - - result = directLoc; - return success(); + /// Emit a diagnostic at the specified location and return failure. 
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { + return AsmParserImpl::emitError(loc, "custom op '" + opName + + "' " + message); } //===--------------------------------------------------------------------===// @@ -1779,53 +1444,6 @@ // Type Parsing //===--------------------------------------------------------------------===// - /// Parse a type. - ParseResult parseType(Type &result) override { - return failure(!(result = parser.parseType())); - } - - /// Parse an optional type. - OptionalParseResult parseOptionalType(Type &result) override { - return parser.parseOptionalType(result); - } - - /// Parse an arrow followed by a type list. - ParseResult parseArrowTypeList(SmallVectorImpl &result) override { - if (parseArrow() || parser.parseFunctionResultTypes(result)) - return failure(); - return success(); - } - - /// Parse an optional arrow followed by a type list. - ParseResult - parseOptionalArrowTypeList(SmallVectorImpl &result) override { - if (!parser.consumeIf(Token::arrow)) - return success(); - return parser.parseFunctionResultTypes(result); - } - - /// Parse a colon followed by a type. - ParseResult parseColonType(Type &result) override { - return failure(parser.parseToken(Token::colon, "expected ':'") || - !(result = parser.parseType())); - } - - /// Parse a colon followed by a type list, which must have at least one type. - ParseResult parseColonTypeList(SmallVectorImpl &result) override { - if (parser.parseToken(Token::colon, "expected ':'")) - return failure(); - return parser.parseTypeListNoParens(result); - } - - /// Parse an optional colon followed by a type list, which if present must - /// have at least one type. - ParseResult - parseOptionalColonTypeList(SmallVectorImpl &result) override { - if (!parser.consumeIf(Token::colon)) - return success(); - return parser.parseTypeListNoParens(result); - } - /// Parse a list of assignments of the form /// (%x1 = %y1, %x2 = %y2, ...). 
OptionalParseResult @@ -1870,9 +1488,6 @@ } private: - /// The source location of the operation name. - SMLoc nameLoc; - /// Information about the result name specifiers. ArrayRef resultIDs; @@ -1881,11 +1496,8 @@ bool isIsolatedFromAbove; StringRef opName; - /// The main operation parser. + /// The backing operation parser. OperationParser &parser; - - /// A flag that indicates if any errors were emitted during parsing. - bool emittedError = false; }; } // end anonymous namespace.