diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -28,7 +28,7 @@ /// Test format has invalid syntax. def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> { let parameters = (ins "int":$v0, "int":$v1); - // CHECK: expected literal, directive, or variable + // CHECK: expected literal, variable, directive, or optional group let assemblyFormat = "`<` $v0, $v1 `>`"; } diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -97,7 +97,7 @@ def DirectiveFunctionalTypeInvalidB : TestFormat_Op<[{ functional-type }]>; -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def DirectiveFunctionalTypeInvalidC : TestFormat_Op<[{ functional-type( }]>; @@ -105,7 +105,7 @@ def DirectiveFunctionalTypeInvalidD : TestFormat_Op<[{ functional-type(operands }]>; -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def DirectiveFunctionalTypeInvalidE : TestFormat_Op<[{ functional-type(operands, }]>; @@ -262,7 +262,7 @@ def DirectiveTypeInvalidA : TestFormat_Op<[{ type }]>; -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def DirectiveTypeInvalidB : TestFormat_Op<[{ type( }]>; @@ -278,7 +278,7 @@ //===----------------------------------------------------------------------===// // functional-type/type operands -// CHECK: error: literals may only be used in a top-level section of the format +// CHECK: error: literals may only be used in the top-level section of the format def 
DirectiveTypeZOperandInvalidA : TestFormat_Op<[{ type(`literal`) }]>; @@ -334,7 +334,7 @@ }]>; // CHECK: error: unexpected end of file in literal -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def LiteralInvalidD : TestFormat_Op<[{ ` }]>; @@ -352,15 +352,15 @@ def OptionalInvalidA : TestFormat_Op<[{ type(($attr^)?) attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def OptionalInvalidB : TestFormat_Op<[{ () attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: optional group specified no anchor element +// CHECK: error: optional group has no anchor element def OptionalInvalidC : TestFormat_Op<[{ ($attr)? attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: first parsable element of an operand group must be an attribute, literal, operand, or region +// CHECK: error: first parsable element of an optional group must be a literal or variable def OptionalInvalidD : TestFormat_Op<[{ (type($operand) $operand^)? attr-dict }]>, Arguments<(ins Optional:$operand)>; @@ -370,15 +370,15 @@ }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: only one element can be marked as the anchor of an optional group def OptionalInvalidF : TestFormat_Op<[{ - ($attr^ $attr2^) attr-dict + ($attr^ $attr2^)? attr-dict }]>, Arguments<(ins OptionalAttr:$attr, OptionalAttr:$attr2)>; // CHECK: error: only optional attributes can be used to anchor an optional group def OptionalInvalidG : TestFormat_Op<[{ - ($attr^) attr-dict + ($attr^)? attr-dict }]>, Arguments<(ins I64Attr:$attr)>; // CHECK: error: only variable length operands can be used within an optional group def OptionalInvalidH : TestFormat_Op<[{ - ($arg^) attr-dict + ($arg^)? 
attr-dict }]>, Arguments<(ins I64:$arg)>; // CHECK: error: only literals, types, and variables can be used within an optional group def OptionalInvalidI : TestFormat_Op<[{ @@ -386,7 +386,7 @@ }]>, Arguments<(ins Variadic:$arg)>; // CHECK: error: only literals, types, and variables can be used within an optional group def OptionalInvalidJ : TestFormat_Op<[{ - (attr-dict) + (attr-dict^)? }]>; // CHECK: error: expected '?' after optional group def OptionalInvalidK : TestFormat_Op<[{ @@ -404,7 +404,7 @@ def OptionalInvalidN : TestFormat_Op<[{ ($arg^): }]>, Arguments<(ins Variadic:$arg)>; -// CHECK: error: expected directive, literal, variable, or optional group +// CHECK: error: expected literal, variable, directive, or optional group def OptionalInvalidO : TestFormat_Op<[{ ($arg^):(`test` }]>, Arguments<(ins Variadic:$arg)>; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -31,61 +31,12 @@ //===----------------------------------------------------------------------===// namespace { - -/// This class represents a single format element. -class Element { -public: - /// LLVM-style RTTI. - enum class Kind { - /// This element is a directive. - ParamsDirective, - StructDirective, - - /// This element is a literal. - Literal, - - /// This element is a variable. - Variable, - }; - Element(Kind kind) : kind(kind) {} - virtual ~Element() = default; - - /// Return the kind of this element. - Kind getKind() const { return kind; } - -private: - /// The kind of this element. - Kind kind; -}; - -/// This class represents an instance of a literal element. -class LiteralElement : public Element { -public: - LiteralElement(StringRef literal) - : Element(Kind::Literal), literal(literal) {} - - static bool classof(const Element *el) { - return el->getKind() == Kind::Literal; - } - - /// Get the literal spelling. 
- StringRef getSpelling() const { return literal; } - -private: - /// The spelling of the literal for this element. - StringRef literal; -}; - /// This class represents an instance of a variable element. A variable refers /// to an attribute or type parameter. -class VariableElement : public Element { +class ParameterElement + : public VariableElementBase { public: - VariableElement(AttrOrTypeParameter param) - : Element(Kind::Variable), param(param) {} - - static bool classof(const Element *el) { - return el->getKind() == Kind::Variable; - } + ParameterElement(AttrOrTypeParameter param) : param(param) {} /// Get the parameter in the element. const AttrOrTypeParameter &getParam() const { return param; } @@ -103,22 +54,18 @@ }; /// Base class for a directive that contains references to multiple variables. -template -class ParamsDirectiveBase : public Element { +template +class ParamsDirectiveBase : public DirectiveElementBase { public: - using Base = ParamsDirectiveBase; + using Base = ParamsDirectiveBase; - ParamsDirectiveBase(SmallVector> &¶ms) - : Element(ElementKind), params(std::move(params)) {} - - static bool classof(const Element *el) { - return el->getKind() == ElementKind; - } + ParamsDirectiveBase(std::vector &¶ms) + : params(std::move(params)) {} /// Get the parameters contained in this directive. auto getParams() const { - return llvm::map_range(params, [](auto &el) { - return cast(el.get())->getParam(); + return llvm::map_range(params, [](FormatElement *el) { + return cast(el)->getParam(); }); } @@ -126,13 +73,11 @@ unsigned getNumParams() const { return params.size(); } /// Take all of the parameters from this directive. - SmallVector> takeParams() { - return std::move(params); - } + std::vector takeParams() { return std::move(params); } private: /// The parameters captured by this directive. 
- SmallVector> params; + std::vector params; }; /// This class represents a `params` directive that refers to all parameters @@ -144,8 +89,7 @@ /// When used as an argument to another directive that accepts variables, /// `params` can be used in place of manually listing all parameters of an /// attribute or type. -class ParamsDirective - : public ParamsDirectiveBase { +class ParamsDirective : public ParamsDirectiveBase { public: using Base::Base; }; @@ -155,8 +99,7 @@ /// /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}` /// -class StructDirective - : public ParamsDirectiveBase { +class StructDirective : public ParamsDirectiveBase { public: using Base::Base; }; @@ -237,7 +180,7 @@ class AttrOrTypeFormat { public: AttrOrTypeFormat(const AttrOrTypeDef &def, - std::vector> &&elements) + std::vector &&elements) : def(def), elements(std::move(elements)) {} /// Generate the attribute or type parser. @@ -247,7 +190,7 @@ private: /// Generate the parser code for a specific format element. - void genElementParser(Element *el, FmtContext &ctx, MethodBody &os); + void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a literal. void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a variable. @@ -259,7 +202,7 @@ void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a specific format element. - void genElementPrinter(Element *el, FmtContext &ctx, MethodBody &os); + void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a literal. void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. @@ -275,7 +218,7 @@ const AttrOrTypeDef &def; /// The list of top-level format elements returned by the assembly format /// parser. 
- std::vector> elements; + std::vector elements; /// Flags for printing spaces. bool shouldEmitSpace = false; @@ -311,8 +254,8 @@ &ctx); /// Generate call to each parameter parser. - for (auto &el : elements) - genElementParser(el.get(), ctx, os); + for (FormatElement *el : elements) + genElementParser(el, ctx, os); /// Generate call to the attribute or type builder. Use the checked getter /// if one was generated. @@ -328,11 +271,11 @@ os << ");"; } -void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx, +void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralParser(literal->getSpelling(), ctx, os); - if (auto *var = dyn_cast(el)) + if (auto *var = dyn_cast(el)) return genVariableParser(var->getParam(), ctx, os); if (auto *params = dyn_cast(el)) return genParamsParser(params, ctx, os); @@ -435,11 +378,11 @@ /// Generate printers. shouldEmitSpace = true; lastWasPunctuation = false; - for (auto &el : elements) - genElementPrinter(el.get(), ctx, os); + for (FormatElement *el : elements) + genElementPrinter(el, ctx, os); } -void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx, +void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralPrinter(literal->getSpelling(), ctx, os); @@ -447,7 +390,7 @@ return genParamsPrinter(params, ctx, os); if (auto *strct = dyn_cast(el)) return genStructPrinter(strct, ctx, os); - if (auto *var = dyn_cast(el)) + if (auto *var = dyn_cast(el)) return genVariablePrinter(var->getParam(), ctx, os, var->shouldBeQualified()); @@ -492,7 +435,7 @@ llvm::interleave( el->getParams(), [&](auto param) { this->genVariablePrinter(param, ctx, os); }, - [&]() { this->genLiteralPrinter(",", ctx, os); }); + [&] { this->genLiteralPrinter(",", ctx, os); }); } void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, @@ -504,75 
+447,54 @@ this->genLiteralPrinter("=", ctx, os); this->genVariablePrinter(param, ctx, os); }, - [&]() { this->genLiteralPrinter(",", ctx, os); }); + [&] { this->genLiteralPrinter(",", ctx, os); }); } //===----------------------------------------------------------------------===// -// FormatParser +// DefFormatParser //===----------------------------------------------------------------------===// namespace { -class FormatParser { +class DefFormatParser : public FormatParser { public: - FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def) - : lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def), + DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def) + : FormatParser(mgr, def.getLoc()[0]), def(def), seenParams(def.getNumParameters()) {} /// Parse the attribute or type format and create the format elements. FailureOr parse(); -private: - /// The current context of the parser when parsing an element. - enum ParserContext { - /// The element is being parsed in the default context - at the top of the - /// format - TopLevelContext, - /// The element is being parsed as a child to a `struct` directive. - StructDirective, - }; - - /// Emit an error. - LogicalResult emitError(const Twine &msg) { - lexer.emitError(curToken.getLoc(), msg); - return failure(); +protected: + /// Verify the parsed elements. + LogicalResult verify(SMLoc loc, ArrayRef elements) override; + /// Verify the elements of a custom directive. + LogicalResult + verifyCustomDirectiveArguments(SMLoc loc, + ArrayRef arguments) override { + return emitError(loc, "'custom' not supported (yet)"); } - - /// Parse an expected token. - LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) { - if (curToken.getKind() != kind) - return emitError(msg); - consumeToken(); - return success(); + /// Verify the elements of an optional group. 
+ LogicalResult + verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, + Optional anchorIndex) override { + return emitError(loc, "optional groups not (yet) supported"); } - /// Advance the lexer to the next token. - void consumeToken() { - assert(curToken.getKind() != FormatToken::eof && - curToken.getKind() != FormatToken::error && - "shouldn't advance past EOF or errors"); - curToken = lexer.lexToken(); - } + /// Parse an attribute or type variable. + FailureOr parseVariableImpl(SMLoc loc, StringRef name, + Context ctx) override; + /// Parse an attribute or type format directive. + FailureOr + parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; - /// Parse any element. - FailureOr> parseElement(ParserContext ctx); - /// Parse a literal element. - FailureOr> parseLiteral(ParserContext ctx); - /// Parse a variable element. - FailureOr> parseVariable(ParserContext ctx); - /// Parse a directive. - FailureOr> parseDirective(ParserContext ctx); +private: /// Parse a `params` directive. - FailureOr> parseParamsDirective(); + FailureOr parseParamsDirective(SMLoc loc); /// Parse a `qualified` directive. - FailureOr> - parseQualifiedDirective(ParserContext ctx); + FailureOr parseQualifiedDirective(SMLoc loc, Context ctx); /// Parse a `struct` directive. - FailureOr> parseStructDirective(); + FailureOr parseStructDirective(SMLoc loc); - /// The current format lexer. - FormatLexer lexer; - /// The current token in the stream. - FormatToken curToken; /// Attribute or type tablegen def. const AttrOrTypeDef &def; @@ -581,170 +503,132 @@ }; } // namespace -FailureOr FormatParser::parse() { - std::vector> elements; - elements.reserve(16); - - /// Parse the format elements. - while (curToken.getKind() != FormatToken::eof) { - auto element = parseElement(TopLevelContext); - if (failed(element)) - return failure(); - - /// Add the format element and continue. 
- elements.push_back(std::move(*element)); - } - - /// Check that all parameters have been seen. +LogicalResult DefFormatParser::verify(SMLoc loc, + ArrayRef elements) { for (auto &it : llvm::enumerate(def.getParameters())) { if (!seenParams.test(it.index())) { - return emitError("format is missing reference to parameter: " + - it.value().getName()); + return emitError(loc, "format is missing reference to parameter: " + + it.value().getName()); } } - - return AttrOrTypeFormat(def, std::move(elements)); + return success(); } -FailureOr> -FormatParser::parseElement(ParserContext ctx) { - if (curToken.getKind() == FormatToken::literal) - return parseLiteral(ctx); - if (curToken.getKind() == FormatToken::variable) - return parseVariable(ctx); - if (curToken.isKeyword()) - return parseDirective(ctx); - - return emitError("expected literal, directive, or variable"); -} - -FailureOr> -FormatParser::parseLiteral(ParserContext ctx) { - if (ctx != TopLevelContext) { - return emitError( - "literals may only be used in the top-level section of the format"); - } - - /// Get the literal spelling without the surrounding "`". - auto value = curToken.getSpelling().drop_front().drop_back(); - if (!isValidLiteral(value, [&](Twine diag) { - (void)emitError("expected valid literal but got '" + value + - "': " + diag); - })) +FailureOr DefFormatParser::parse() { + FailureOr> elements = FormatParser::parse(); + if (failed(elements)) return failure(); - - consumeToken(); - return {std::make_unique(value)}; + return AttrOrTypeFormat(def, std::move(*elements)); } -FailureOr> -FormatParser::parseVariable(ParserContext ctx) { - /// Get the parameter name without the preceding "$". - auto name = curToken.getSpelling().drop_front(); - +FailureOr +DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { /// Lookup the parameter. 
ArrayRef params = def.getParameters(); auto *it = llvm::find_if( params, [&](auto ¶m) { return param.getName() == name; }); /// Check that the parameter reference is valid. - if (it == params.end()) - return emitError(def.getName() + " has no parameter named '" + name + "'"); + if (it == params.end()) { + return emitError(loc, + def.getName() + " has no parameter named '" + name + "'"); + } auto idx = std::distance(params.begin(), it); if (seenParams.test(idx)) - return emitError("duplicate parameter '" + name + "'"); + return emitError(loc, "duplicate parameter '" + name + "'"); seenParams.set(idx); - consumeToken(); - return {std::make_unique(*it)}; + return create(*it); } -FailureOr> -FormatParser::parseDirective(ParserContext ctx) { +FailureOr +DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, + Context ctx) { - switch (curToken.getKind()) { + switch (kind) { case FormatToken::kw_qualified: - return parseQualifiedDirective(ctx); + return parseQualifiedDirective(loc, ctx); case FormatToken::kw_params: - return parseParamsDirective(); + return parseParamsDirective(loc); case FormatToken::kw_struct: if (ctx != TopLevelContext) { return emitError( + loc, "`struct` may only be used in the top-level section of the format"); } - return parseStructDirective(); + return parseStructDirective(loc); + default: - return emitError("unknown directive in format: " + curToken.getSpelling()); + return emitError(loc, "unsupported directive kind"); } } -FailureOr> -FormatParser::parseQualifiedDirective(ParserContext ctx) { - consumeToken(); +FailureOr +DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) { if (failed(parseToken(FormatToken::l_paren, "expected '(' before argument list"))) return failure(); - FailureOr> var = parseElement(ctx); + FailureOr var = parseElement(ctx); if (failed(var)) return var; - if (!isa(*var)) - return emitError("`qualified` argument list expected a variable"); - cast(var->get())->setShouldBeQualified(); + if 
(!isa(*var)) + return emitError(loc, "`qualified` argument list expected a variable"); + cast(*var)->setShouldBeQualified(); if (failed( parseToken(FormatToken::r_paren, "expected ')' after argument list"))) return failure(); return var; } -FailureOr> FormatParser::parseParamsDirective() { - consumeToken(); +FailureOr DefFormatParser::parseParamsDirective(SMLoc loc) { /// Collect all of the attribute's or type's parameters. - SmallVector> vars; + std::vector vars; /// Ensure that none of the parameters have already been captured. for (const auto &it : llvm::enumerate(def.getParameters())) { if (seenParams.test(it.index())) { - return emitError("`params` captures duplicate parameter: " + - it.value().getName()); + return emitError(loc, "`params` captures duplicate parameter: " + + it.value().getName()); } seenParams.set(it.index()); - vars.push_back(std::make_unique(it.value())); + vars.push_back(create(it.value())); } - return {std::make_unique(std::move(vars))}; + return create(std::move(vars)); } -FailureOr> FormatParser::parseStructDirective() { - consumeToken(); +FailureOr DefFormatParser::parseStructDirective(SMLoc loc) { if (failed(parseToken(FormatToken::l_paren, "expected '(' before `struct` argument list"))) return failure(); /// Parse variables captured by `struct`. - SmallVector> vars; + std::vector vars; /// Parse first captured parameter or a `params` directive. - FailureOr> var = parseElement(StructDirective); - if (failed(var) || !isa(*var)) - return emitError("`struct` argument list expected a variable or directive"); + FailureOr var = parseElement(StructDirectiveContext); + if (failed(var) || !isa(*var)) { + return emitError(loc, + "`struct` argument list expected a variable or directive"); + } if (isa(*var)) { /// Parse any other parameters. 
vars.push_back(std::move(*var)); - while (curToken.getKind() == FormatToken::comma) { + while (peekToken().is(FormatToken::comma)) { consumeToken(); - var = parseElement(StructDirective); + var = parseElement(StructDirectiveContext); if (failed(var) || !isa(*var)) - return emitError("expected a variable in `struct` argument list"); + return emitError(loc, "expected a variable in `struct` argument list"); vars.push_back(std::move(*var)); } } else { /// `struct(params)` captures all parameters in the attribute or type. - vars = cast(var->get())->takeParams(); + vars = cast(*var)->takeParams(); } - if (curToken.getKind() != FormatToken::r_paren) - return emitError("expected ')' at the end of an argument list"); + if (failed(parseToken(FormatToken::r_paren, + "expected ')' at the end of an argument list"))) + return failure(); - consumeToken(); - return {std::make_unique<::StructDirective>(std::move(vars))}; + return create(std::move(vars)); } //===----------------------------------------------------------------------===// @@ -756,11 +640,10 @@ MethodBody &printer) { llvm::SourceMgr mgr; mgr.AddNewSourceBuffer( - llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), - SMLoc()); + llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc()); /// Parse the custom assembly format> - FormatParser fmtParser(mgr, def); + DefFormatParser fmtParser(mgr, def); FailureOr format = fmtParser.parse(); if (failed(format)) { if (formatErrorIsFatal) diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -15,9 +15,13 @@ #define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_ #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SMLoc.h" +#include namespace llvm { class SourceMgr; @@ -85,6 +89,9 
@@ /// Return a location for this token. SMLoc getLoc() const; + /// Returns true if the token is of the given kind. + bool is(Kind kind) { return getKind() == kind; } + /// Return if this token is a keyword. bool isKeyword() const { return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end; @@ -115,8 +122,7 @@ FormatToken emitError(SMLoc loc, const Twine &msg); FormatToken emitError(const char *loc, const Twine &msg); - FormatToken emitErrorAndNote(SMLoc loc, const Twine &msg, - const Twine ¬e); + FormatToken emitErrorAndNote(SMLoc loc, const Twine &msg, const Twine ¬e); private: /// Return the next character in the stream. @@ -142,6 +148,362 @@ const char *curPtr; }; +//===----------------------------------------------------------------------===// +// FormatElement +//===----------------------------------------------------------------------===// + +/// This class represents a single format element. +/// +/// If you squint and take a close look, you can see the outline of a `Format` +/// dialect. +class FormatElement { +public: + /// The top-level kinds of format elements. + enum Kind { Literal, Variable, Whitespace, Directive, Optional }; + + /// Support LLVM-style RTTI. + static bool classof(const FormatElement *el) { return true; } + + /// Get the element kind. + Kind getKind() const { return kind; } + +protected: + /// Create a format element with the given kind. + FormatElement(Kind kind) : kind(kind) {} + +private: + /// The kind of the element. + Kind kind; +}; + +/// The base class for all format elements. This class implements common methods +/// for LLVM-style RTTI. +template +class FormatElementBase : public FormatElement { +public: + /// Support LLVM-style RTTI. + static bool classof(const FormatElement *el) { + return ElementKind == el->getKind(); + } + +protected: + /// Create a format element with the given kind. + FormatElementBase() : FormatElement(ElementKind) {} +}; + +/// This class represents a literal element. 
A literal is either one of the +/// supported punctuation characters (e.g. `(` or `,`) or a string literal (e.g. +/// `literal`). +class LiteralElement : public FormatElementBase { +public: + /// Create a literal element with the given spelling. + explicit LiteralElement(StringRef spelling) : spelling(spelling) {} + + /// Get the spelling of the literal. + StringRef getSpelling() const { return spelling; } + +private: + /// The spelling of the variable, i.e. the string contained within the + /// backticks. + StringRef spelling; +}; + +/// This class represents a variable element. A variable refers to some part of +/// the object being parsed, e.g. an attribute or operand on an operation or a +/// parameter on an attribute. +class VariableElement : public FormatElementBase { +public: + /// These are the kinds of variables. + enum Kind { Attribute, Operand, Region, Result, Successor, Parameter }; + + /// Get the kind of variable. + Kind getKind() const { return kind; } + +protected: + /// Create a variable with a kind. + VariableElement(Kind kind) : kind(kind) {} + +private: + /// The kind of variable. + Kind kind; +}; + +/// Base class for variable elements. This class implements common methods for +/// LLVM-style RTTI. +template +class VariableElementBase : public VariableElement { +public: + /// An element is of this class if it is a variable and has the same variable + /// type. + static bool classof(const FormatElement *el) { + if (auto *varEl = dyn_cast(el)) + return VariableKind == varEl->getKind(); + return false; + } + +protected: + /// Create a variable element with the given variable kind. + VariableElementBase() : VariableElement(VariableKind) {} +}; + +/// This class represents a whitespace element, e.g. a newline or space. It is a +/// literal that is printed but never parsed. When the value is empty, i.e. ``, +/// a space is elided where one would have been printed automatically. 
+class WhitespaceElement : public FormatElementBase { +public: + /// Create a whitespace element. + explicit WhitespaceElement(StringRef value) : value(value) {} + + /// Get the whitespace value. + StringRef getValue() const { return value; } + +private: + /// The value of the whitespace element. Can be empty. + StringRef value; +}; + +class DirectiveElement : public FormatElementBase { +public: + /// These are the kinds of directives. + enum Kind { + AttrDict, + Custom, + FunctionalType, + Operands, + Ref, + Regions, + Results, + Successors, + Type, + Params, + Struct + }; + + /// Get the directive kind. + Kind getKind() const { return kind; } + +protected: + /// Create a directive element with a kind. + DirectiveElement(Kind kind) : kind(kind) {} + +private: + /// The directive kind. + Kind kind; +}; + +/// Base class for directive elements. This class implements common methods for +/// LLVM-style RTTI. +template +class DirectiveElementBase : public DirectiveElement { +public: + /// Create a directive element with the specified kind. + DirectiveElementBase() : DirectiveElement(DirectiveKind) {} + + /// A format element is of this class if it is a directive element and has the + /// same kind. + static bool classof(const FormatElement *el) { + if (auto *directiveEl = dyn_cast(el)) + return DirectiveKind == directiveEl->getKind(); + return false; + } +}; + +/// This class represents a custom format directive that is implemented by the +/// user in C++. The directive accepts a list of arguments that is passed to the +/// C++ function. +class CustomDirective : public DirectiveElementBase { +public: + /// Create a custom directive with a name and list of arguments. + CustomDirective(StringRef name, std::vector &&arguments) + : name(name), arguments(std::move(arguments)) {} + + /// Get the custom directive name. + StringRef getName() const { return name; } + + /// Get the arguments to the custom directive. 
+ ArrayRef getArguments() const { return arguments; } + +private: + /// The name of the custom directive. The name is used to call two C++ + /// methods: `parse{name}` and `print{name}` with the given arguments. + StringRef name; + /// The arguments with which to call the custom functions. These are either + /// variables (for which the functions are responsible for populating) or + /// references to variables. + std::vector arguments; +}; + +/// This class represents a group of elements that are optionally emitted based +/// on an optional variable "anchor" and a group of elements that are emitted +/// when the anchor element is not present. +class OptionalElement : public FormatElementBase { +public: + /// Create an optional group with the given child elements. + OptionalElement(std::vector &&thenElements, + std::vector &&elseElements, + unsigned anchorIndex, unsigned parseStart) + : thenElements(std::move(thenElements)), + elseElements(std::move(elseElements)), anchorIndex(anchorIndex), + parseStart(parseStart) {} + + /// Return the `then` elements of the optional group. + ArrayRef getThenElements() const { return thenElements; } + + /// Return the `else` elements of the optional group. + ArrayRef getElseElements() const { return elseElements; } + + /// Return the anchor of the optional group. + FormatElement *getAnchor() const { return thenElements[anchorIndex]; } + + /// Return the index of the first element to be parsed. + unsigned getParseStart() const { return parseStart; } + +private: + /// The child elements emitted when the anchor is present. + std::vector thenElements; + /// The child elements emitted when the anchor is not present. + std::vector elseElements; + /// The index of the anchor element of the optional group within + /// `thenElements`. + unsigned anchorIndex; + /// The index of the first element that is parsed in `thenElements`. That is, + /// the first non-whitespace element. 
+ unsigned parseStart; +}; + +//===----------------------------------------------------------------------===// +// FormatParserBase +//===----------------------------------------------------------------------===// + +/// Base class for a parser that implements an assembly format. This class +/// defines a common assembly format syntax and the creation of format elements. +/// Subclasses will need to implement parsing for the format elements they +/// support. +class FormatParser { +public: + /// Vtable anchor. + virtual ~FormatParser(); + + /// Parse the assembly format. + FailureOr> parse(); + +protected: + /// The current context of the parser when parsing an element. + enum Context { + /// The element is being parsed in a "top-level" context, i.e. at the top of + /// the format or in an optional group. + TopLevelContext, + /// The element is being parsed as a custom directive child. + CustomDirectiveContext, + /// The element is being parsed as a type directive child. + TypeDirectiveContext, + /// The element is being parsed as a reference directive child. + RefDirectiveContext, + /// The element is being parsed as a struct directive child. + StructDirectiveContext + }; + + /// Create a format parser with the given source manager and a location. + explicit FormatParser(llvm::SourceMgr &mgr, llvm::SMLoc loc) + : lexer(mgr, loc), curToken(lexer.lexToken()) {} + + /// Allocate and construct a format element. + template + FormatElementT *create(Args &&...args) { + FormatElementT *ptr = allocator.Allocate(); + ::new (ptr) FormatElementT(std::forward(args)...); + return ptr; + } + + //===--------------------------------------------------------------------===// + // Element Parsing + + /// Parse a single element of any kind. + FailureOr parseElement(Context ctx); + /// Parse a literal. + FailureOr parseLiteral(Context ctx); + /// Parse a variable. + FailureOr parseVariable(Context ctx); + /// Parse a directive. 
+ FailureOr parseDirective(Context ctx); + /// Parse an optional group. + FailureOr parseOptionalGroup(Context ctx); + + /// Parse a custom directive. + FailureOr parseCustomDirective(llvm::SMLoc loc, Context ctx); + + /// Parse a format-specific variable kind. + virtual FailureOr + parseVariableImpl(llvm::SMLoc loc, StringRef name, Context ctx) = 0; + /// Parse a format-specific directive kind. + virtual FailureOr + parseDirectiveImpl(llvm::SMLoc loc, FormatToken::Kind kind, Context ctx) = 0; + + //===--------------------------------------------------------------------===// + // Format Verification + + /// Verify that the format is well-formed. + virtual LogicalResult verify(llvm::SMLoc loc, + ArrayRef elements) = 0; + /// Verify the arguments to a custom directive. + virtual LogicalResult + verifyCustomDirectiveArguments(llvm::SMLoc loc, + ArrayRef arguments) = 0; + /// Verify the elements of an optional group. + virtual LogicalResult + verifyOptionalGroupElements(llvm::SMLoc loc, + ArrayRef elements, + Optional anchorIndex) = 0; + + //===--------------------------------------------------------------------===// + // Lexer Utilities + + /// Emit an error at the given location. + LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { + lexer.emitError(loc, msg); + return failure(); + } + + /// Emit an error and a note at the given notation. + LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, + const Twine ¬e) { + lexer.emitErrorAndNote(loc, msg, note); + return failure(); + } + + /// Parse a single token of the expected kind. + FailureOr parseToken(FormatToken::Kind kind, const Twine &msg) { + if (!curToken.is(kind)) + return emitError(curToken.getLoc(), msg); + FormatToken tok = curToken; + consumeToken(); + return tok; + } + + /// Advance the lexer to the next token. 
+ void consumeToken() { + assert(!curToken.is(FormatToken::eof) && !curToken.is(FormatToken::error) && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + + /// Get the current token. + FormatToken peekToken() { return curToken; } + +private: + /// The format parser retains ownership of the format elements in a bump + /// pointer allocator. + llvm::BumpPtrAllocator allocator; + /// The format lexer to use. + FormatLexer lexer; + /// The current token in the lexer. + FormatToken curToken; +}; + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + /// Whether a space needs to be emitted before a literal. E.g., two keywords /// back-to-back require a space separator, but a keyword followed by '<' does /// not require a space. diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -177,6 +177,201 @@ return FormatToken(kind, str); } +//===----------------------------------------------------------------------===// +// FormatParser +//===----------------------------------------------------------------------===// + +FormatParser::~FormatParser() = default; + +FailureOr> FormatParser::parse() { + SMLoc loc = curToken.getLoc(); + + // Parse each of the format elements into the main format. + std::vector elements; + while (curToken.getKind() != FormatToken::eof) { + FailureOr element = parseElement(TopLevelContext); + if (failed(element)) + return failure(); + elements.push_back(*element); + } + + // Verify the format. 
+ if (failed(verify(loc, elements))) + return failure(); + return elements; +} + +//===----------------------------------------------------------------------===// +// Element Parsing + +FailureOr FormatParser::parseElement(Context ctx) { + if (curToken.is(FormatToken::literal)) + return parseLiteral(ctx); + if (curToken.is(FormatToken::variable)) + return parseVariable(ctx); + if (curToken.isKeyword()) + return parseDirective(ctx); + if (curToken.is(FormatToken::l_paren)) + return parseOptionalGroup(ctx); + return emitError(curToken.getLoc(), + "expected literal, variable, directive, or optional group"); +} + +FailureOr FormatParser::parseLiteral(Context ctx) { + FormatToken tok = curToken; + SMLoc loc = tok.getLoc(); + consumeToken(); + + if (ctx != TopLevelContext) { + return emitError( + loc, + "literals may only be used in the top-level section of the format"); + } + // Get the spelling without the surrounding backticks. + StringRef value = tok.getSpelling().drop_front().drop_back(); + + // The parsed literal is a space element (`` or ` `) or a newline. + if (value.empty() || value == " " || value == "\\n") + return create(value); + + // Check that the parsed literal is valid. + if (!isValidLiteral(value, [&](Twine msg) { + (void)emitError(loc, "expected valid literal but got '" + value + + "': " + msg); + })) + return failure(); + return create(value); +} + +FailureOr FormatParser::parseVariable(Context ctx) { + FormatToken tok = curToken; + SMLoc loc = tok.getLoc(); + consumeToken(); + + // Get the name of the variable without the leading `$`. 
+ StringRef name = tok.getSpelling().drop_front(); + return parseVariableImpl(loc, name, ctx); +} + +FailureOr FormatParser::parseDirective(Context ctx) { + FormatToken tok = curToken; + SMLoc loc = tok.getLoc(); + consumeToken(); + + if (tok.is(FormatToken::kw_custom)) + return parseCustomDirective(loc, ctx); + return parseDirectiveImpl(loc, tok.getKind(), ctx); +} + +FailureOr FormatParser::parseOptionalGroup(Context ctx) { + SMLoc loc = curToken.getLoc(); + consumeToken(); + if (ctx != TopLevelContext) { + return emitError(loc, + "optional groups can only be used as top-level elements"); + } + + // Parse the child elements for this optional group. + std::vector thenElements, elseElements; + Optional anchorIndex; + do { + FailureOr element = parseElement(TopLevelContext); + if (failed(element)) + return failure(); + // Check for an anchor. + if (curToken.is(FormatToken::caret)) { + if (anchorIndex) + return emitError(curToken.getLoc(), "only one element can be marked as " + "the anchor of an optional group"); + anchorIndex = thenElements.size(); + consumeToken(); + } + thenElements.push_back(*element); + } while (!curToken.is(FormatToken::r_paren)); + consumeToken(); + + // Parse the `else` elements of this optional group. + if (curToken.is(FormatToken::colon)) { + consumeToken(); + if (failed( + parseToken(FormatToken::l_paren, + "expected '(' to start else branch of optional group"))) + return failure(); + do { + FailureOr element = parseElement(TopLevelContext); + if (failed(element)) + return failure(); + elseElements.push_back(*element); + } while (!curToken.is(FormatToken::r_paren)); + consumeToken(); + } + if (failed(parseToken(FormatToken::question, + "expected '?' after optional group"))) + return failure(); + + // The optional group is required to have an anchor. + if (!anchorIndex) + return emitError(loc, "optional group has no anchor element"); + + // Verify the child elements. 
+ if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) || + failed(verifyOptionalGroupElements(loc, elseElements, llvm::None))) + return failure(); + + // Get the first parsable element. It must be an element that can be + // optionally-parsed. + auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) { + return isa(element); + }); + if (!isa(*parseBegin)) { + return emitError(loc, "first parsable element of an optional group must be " + "a literal or variable"); + } + + unsigned parseStart = std::distance(thenElements.begin(), parseBegin); + return create(std::move(thenElements), + std::move(elseElements), *anchorIndex, + parseStart); +} + +FailureOr FormatParser::parseCustomDirective(SMLoc loc, + Context ctx) { + if (ctx != TopLevelContext) + return emitError(loc, "'custom' is only valid as a top-level directive"); + + FailureOr nameTok; + if (failed(parseToken(FormatToken::less, + "expected '<' before custom directive name")) || + failed(nameTok = + parseToken(FormatToken::identifier, + "expected custom directive name identifier")) || + failed(parseToken(FormatToken::greater, + "expected '>' after custom directive name")) || + failed(parseToken(FormatToken::l_paren, + "expected '(' before custom directive parameters"))) + return failure(); + + // Parse the arguments. 
+ std::vector arguments; + while (true) { + FailureOr argument = parseElement(CustomDirectiveContext); + if (failed(argument)) + return failure(); + arguments.push_back(*argument); + if (!curToken.is(FormatToken::comma)) + break; + consumeToken(); + } + + if (failed(parseToken(FormatToken::r_paren, + "expected ')' after custom directive parameters"))) + return failure(); + + if (failed(verifyCustomDirectiveArguments(loc, arguments))) + return failure(); + return create(nameTok->getSpelling(), std::move(arguments)); +} + //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -9,6 +9,7 @@ #include "OpFormatGen.h" #include "FormatGen.h" #include "OpClass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" @@ -33,84 +34,37 @@ using namespace mlir::tblgen; //===----------------------------------------------------------------------===// -// Element -//===----------------------------------------------------------------------===// +// VariableElement namespace { -/// This class represents a single format element. -class Element { +/// This class represents an instance of an op variable element. A variable +/// refers to something registered on the operation itself, e.g. an operand, +/// result, attribute, region, or successor. +template +class OpVariableElement : public VariableElementBase { public: - enum class Kind { - /// This element is a directive. - AttrDictDirective, - CustomDirective, - FunctionalTypeDirective, - OperandsDirective, - RefDirective, - RegionsDirective, - ResultsDirective, - SuccessorsDirective, - TypeDirective, - - /// This element is a literal. 
- Literal, - - /// This element is a whitespace. - Newline, - Space, - - /// This element is an variable value. - AttributeVariable, - OperandVariable, - RegionVariable, - ResultVariable, - SuccessorVariable, - - /// This element is an optional element. - Optional, - }; - Element(Kind kind) : kind(kind) {} - virtual ~Element() = default; + using Base = OpVariableElement; - /// Return the kind of this element. - Kind getKind() const { return kind; } - -private: - /// The kind of this element. - Kind kind; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// VariableElement + /// Create an op variable element with the variable value. + OpVariableElement(const VarT *var) : var(var) {} -namespace { -/// This class represents an instance of an variable element. A variable refers -/// to something registered on the operation itself, e.g. an argument, result, -/// etc. -template -class VariableElement : public Element { -public: - VariableElement(const VarT *var) : Element(kindVal), var(var) {} - static bool classof(const Element *element) { - return element->getKind() == kindVal; - } + /// Get the variable. const VarT *getVar() { return var; } protected: + /// The op variable, e.g. a type or attribute constraint. const VarT *var; }; /// This class represents a variable that refers to an attribute argument. struct AttributeVariable - : public VariableElement { - using VariableElement::VariableElement; + : public OpVariableElement { + using Base::Base; /// Return the constant builder call for the type of this attribute, or None /// if it doesn't have one. - Optional getTypeBuilder() const { - Optional attrType = var->attr.getValueType(); + llvm::Optional getTypeBuilder() const { + llvm::Optional attrType = var->attr.getValueType(); return attrType ? attrType->getBuilderCall() : llvm::None; } @@ -132,54 +86,49 @@ /// This class represents a variable that refers to an operand argument. 
using OperandVariable = - VariableElement; - -/// This class represents a variable that refers to a region. -using RegionVariable = - VariableElement; + OpVariableElement; /// This class represents a variable that refers to a result. using ResultVariable = - VariableElement; + OpVariableElement; + +/// This class represents a variable that refers to a region. +using RegionVariable = OpVariableElement; /// This class represents a variable that refers to a successor. using SuccessorVariable = - VariableElement; + OpVariableElement; } // namespace //===----------------------------------------------------------------------===// // DirectiveElement namespace { -/// This class implements single kind directives. -template class DirectiveElement : public Element { -public: - DirectiveElement() : Element(type){}; - static bool classof(const Element *ele) { return ele->getKind() == type; } -}; /// This class represents the `operands` directive. This directive represents /// all of the operands of an operation. -using OperandsDirective = DirectiveElement; - -/// This class represents the `regions` directive. This directive represents -/// all of the regions of an operation. -using RegionsDirective = DirectiveElement; +using OperandsDirective = DirectiveElementBase; /// This class represents the `results` directive. This directive represents /// all of the results of an operation. -using ResultsDirective = DirectiveElement; +using ResultsDirective = DirectiveElementBase; + +/// This class represents the `regions` directive. This directive represents +/// all of the regions of an operation. +using RegionsDirective = DirectiveElementBase; /// This class represents the `successors` directive. This directive represents /// all of the successors of an operation. -using SuccessorsDirective = - DirectiveElement; +using SuccessorsDirective = DirectiveElementBase; /// This class represents the `attr-dict` directive. 
This directive represents /// the attribute dictionary of the operation. class AttrDictDirective - : public DirectiveElement { + : public DirectiveElementBase { public: explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} + + /// Return whether the dictionary should be printed with the 'attributes' + /// keyword. bool isWithKeyword() const { return withKeyword; } private: @@ -187,66 +136,41 @@ bool withKeyword; }; -/// This class represents a custom format directive that is implemented by the -/// user in C++. -class CustomDirective : public Element { -public: - CustomDirective(StringRef name, - std::vector> &&arguments) - : Element{Kind::CustomDirective}, name(name), - arguments(std::move(arguments)) {} - - static bool classof(const Element *element) { - return element->getKind() == Kind::CustomDirective; - } - - /// Return the name of the custom directive. - StringRef getName() const { return name; } - - /// Return the arguments to the custom directive. - auto getArguments() const { return llvm::make_pointee_range(arguments); } - -private: - /// The user provided name of the directive. - StringRef name; - - /// The arguments to the custom directive. - std::vector> arguments; -}; - /// This class represents the `functional-type` directive. This directive takes /// two arguments and formats them, respectively, as the inputs and results of a /// FunctionType. 
class FunctionalTypeDirective - : public DirectiveElement { + : public DirectiveElementBase { public: - FunctionalTypeDirective(std::unique_ptr inputs, - std::unique_ptr results) - : inputs(std::move(inputs)), results(std::move(results)) {} - Element *getInputs() const { return inputs.get(); } - Element *getResults() const { return results.get(); } + FunctionalTypeDirective(FormatElement *inputs, FormatElement *results) + : inputs(inputs), results(results) {} + + FormatElement *getInputs() const { return inputs; } + FormatElement *getResults() const { return results; } private: /// The input and result arguments. - std::unique_ptr inputs, results; + FormatElement *inputs, *results; }; /// This class represents the `ref` directive. -class RefDirective : public DirectiveElement { +class RefDirective : public DirectiveElementBase { public: - RefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} - Element *getOperand() const { return operand.get(); } + RefDirective(FormatElement *arg) : arg(arg) {} + + FormatElement *getArg() const { return arg; } private: - /// The operand that is used to format the directive. - std::unique_ptr operand; + /// The argument that is used to format the directive. + FormatElement *arg; }; /// This class represents the `type` directive. -class TypeDirective : public DirectiveElement { +class TypeDirective : public DirectiveElementBase { public: - TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} - Element *getOperand() const { return operand.get(); } + TypeDirective(FormatElement *arg) : arg(arg) {} + + FormatElement *getArg() const { return arg; } /// Indicate if this type is printed "qualified" (that is it is /// prefixed with the `!dialect.mnemonic`). @@ -256,126 +180,13 @@ } private: - /// The operand that is used to format the directive. - std::unique_ptr operand; + /// The argument that is used to format the directive. 
+ FormatElement *arg; bool shouldBeQualifiedFlag = false; }; } // namespace -//===----------------------------------------------------------------------===// -// LiteralElement - -namespace { -/// This class represents an instance of a literal element. -class LiteralElement : public Element { -public: - LiteralElement(StringRef literal) - : Element{Kind::Literal}, literal(literal) {} - static bool classof(const Element *element) { - return element->getKind() == Kind::Literal; - } - - /// Return the literal for this element. - StringRef getLiteral() const { return literal; } - -private: - /// The spelling of the literal for this element. - StringRef literal; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// WhitespaceElement - -namespace { -/// This class represents a whitespace element, e.g. newline or space. It's a -/// literal that is printed but never parsed. -class WhitespaceElement : public Element { -public: - WhitespaceElement(Kind kind) : Element{kind} {} - static bool classof(const Element *element) { - Kind kind = element->getKind(); - return kind == Kind::Newline || kind == Kind::Space; - } -}; - -/// This class represents an instance of a newline element. It's a literal that -/// prints a newline. It is ignored by the parser. -class NewlineElement : public WhitespaceElement { -public: - NewlineElement() : WhitespaceElement(Kind::Newline) {} - static bool classof(const Element *element) { - return element->getKind() == Kind::Newline; - } -}; - -/// This class represents an instance of a space element. It's a literal that -/// prints or omits printing a space. It is ignored by the parser. -class SpaceElement : public WhitespaceElement { -public: - SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {} - static bool classof(const Element *element) { - return element->getKind() == Kind::Space; - } - - /// Returns true if this element should print as a space. 
Otherwise, the - /// element should omit printing a space between the surrounding elements. - bool getValue() const { return value; } - -private: - bool value; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// OptionalElement - -namespace { -/// This class represents a group of elements that are optionally emitted based -/// upon an optional variable of the operation, and a group of elements that are -/// emotted when the anchor element is not present. -class OptionalElement : public Element { -public: - OptionalElement(std::vector> &&thenElements, - std::vector> &&elseElements, - unsigned anchor, unsigned parseStart) - : Element{Kind::Optional}, thenElements(std::move(thenElements)), - elseElements(std::move(elseElements)), anchor(anchor), - parseStart(parseStart) {} - static bool classof(const Element *element) { - return element->getKind() == Kind::Optional; - } - - /// Return the `then` elements of this grouping. - auto getThenElements() const { - return llvm::make_pointee_range(thenElements); - } - - /// Return the `else` elements of this grouping. - auto getElseElements() const { - return llvm::make_pointee_range(elseElements); - } - - /// Return the anchor of this optional group. - Element *getAnchor() const { return thenElements[anchor].get(); } - - /// Return the index of the first element that needs to be parsed. - unsigned getParseStart() const { return parseStart; } - -private: - /// The child elements of `then` branch of this optional. - std::vector> thenElements; - /// The child elements of `else` branch of this optional. - std::vector> elseElements; - /// The index of the element that acts as the anchor for the optional group. - unsigned anchor; - /// The index of the first element that is parsed (is not a - /// WhitespaceElement). 
- unsigned parseStart; -}; -} // namespace - //===----------------------------------------------------------------------===// // OperationFormat //===----------------------------------------------------------------------===// @@ -450,7 +261,7 @@ /// Generate the operation parser from this format. void genParser(Operator &op, OpClass &opClass); /// Generate the parser code for a specific format element. - void genElementParser(Element *element, MethodBody &body, + void genElementParser(FormatElement *element, MethodBody &body, FmtContext &attrTypeCtx, GenContext genCtx = GenContext::Normal); /// Generate the C++ to resolve the types of operands and results during @@ -471,11 +282,11 @@ void genPrinter(Operator &op, OpClass &opClass); /// Generate the printer code for a specific format element. - void genElementPrinter(Element *element, MethodBody &body, Operator &op, + void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, bool &shouldEmitSpace, bool &lastWasPunctuation); /// The various elements in this format. - std::vector> elements; + std::vector elements; /// A flag indicating if all operand/result types were seen. If the format /// contains these, it can not contain individual type resolvers. @@ -848,7 +659,8 @@ /// Get the name used for the type list for the given type directive operand. /// 'lengthKind' to the corresponding kind for the given argument. -static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) { +static StringRef getTypeListName(FormatElement *arg, + ArgumentLengthKind &lengthKind) { if (auto *operand = dyn_cast(arg)) { lengthKind = getArgumentLengthKind(operand->getVar()); return operand->getVar()->name; @@ -891,26 +703,26 @@ } /// Generate the storage code required for parsing the given element. 
-static void genElementParserStorage(Element *element, const Operator &op, +static void genElementParserStorage(FormatElement *element, const Operator &op, MethodBody &body) { if (auto *optional = dyn_cast(element)) { - auto elements = optional->getThenElements(); + ArrayRef elements = optional->getThenElements(); // If the anchor is a unit attribute, it won't be parsed directly so elide // it. auto *anchor = dyn_cast(optional->getAnchor()); - Element *elidedAnchorElement = nullptr; - if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr()) + FormatElement *elidedAnchorElement = nullptr; + if (anchor && anchor != elements.front() && anchor->isUnitAttr()) elidedAnchorElement = anchor; - for (auto &childElement : elements) - if (&childElement != elidedAnchorElement) - genElementParserStorage(&childElement, op, body); - for (auto &childElement : optional->getElseElements()) - genElementParserStorage(&childElement, op, body); + for (FormatElement *childElement : elements) + if (childElement != elidedAnchorElement) + genElementParserStorage(childElement, op, body); + for (FormatElement *childElement : optional->getElseElements()) + genElementParserStorage(childElement, op, body); } else if (auto *custom = dyn_cast(element)) { - for (auto ¶mElement : custom->getArguments()) - genElementParserStorage(¶mElement, op, body); + for (FormatElement *paramElement : custom->getArguments()) + genElementParserStorage(paramElement, op, body); } else if (isa(element)) { body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " @@ -972,7 +784,7 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; - StringRef name = getTypeListName(dir->getOperand(), lengthKind); + StringRef name = getTypeListName(dir->getArg(), lengthKind); if (lengthKind != ArgumentLengthKind::Single) body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n"; else @@ -990,12 +802,12 @@ } /// Generate the parser for a parameter to a custom directive. 
-static void genCustomParameterParser(Element ¶m, MethodBody &body) { - if (auto *attr = dyn_cast(¶m)) { +static void genCustomParameterParser(FormatElement *param, MethodBody &body) { + if (auto *attr = dyn_cast(param)) { body << attr->getVar()->name << "Attr"; - } else if (isa(¶m)) { + } else if (isa(param)) { body << "result.attributes"; - } else if (auto *operand = dyn_cast(¶m)) { + } else if (auto *operand = dyn_cast(param)) { StringRef name = operand->getVar()->name; ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) @@ -1007,26 +819,26 @@ else body << formatv("{0}RawOperands[0]", name); - } else if (auto *region = dyn_cast(¶m)) { + } else if (auto *region = dyn_cast(param)) { StringRef name = region->getVar()->name; if (region->getVar()->isVariadic()) body << llvm::formatv("{0}Regions", name); else body << llvm::formatv("*{0}Region", name); - } else if (auto *successor = dyn_cast(¶m)) { + } else if (auto *successor = dyn_cast(param)) { StringRef name = successor->getVar()->name; if (successor->getVar()->isVariadic()) body << llvm::formatv("{0}Successors", name); else body << llvm::formatv("{0}Successor", name); - } else if (auto *dir = dyn_cast(¶m)) { - genCustomParameterParser(*dir->getOperand(), body); + } else if (auto *dir = dyn_cast(param)) { + genCustomParameterParser(dir->getArg(), body); - } else if (auto *dir = dyn_cast(¶m)) { + } else if (auto *dir = dyn_cast(param)) { ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) body << llvm::formatv("{0}TypeGroups", listName); else if (lengthKind == ArgumentLengthKind::Variadic) @@ -1048,48 +860,48 @@ // * Add a local variable for optional operands and types. This provides a // better API to the user defined parser methods. 
// * Set the location of operand variables. - for (Element ¶m : dir->getArguments()) { - if (auto *operand = dyn_cast(¶m)) { + for (FormatElement *param : dir->getArguments()) { + if (auto *operand = dyn_cast(param)) { auto *var = operand->getVar(); body << " " << var->name << "OperandsLoc = parser.getCurrentLocation();\n"; if (var->isOptional()) { body << llvm::formatv( - " llvm::Optional<::mlir::OpAsmParser::OperandType> " + " ::llvm::Optional<::mlir::OpAsmParser::OperandType> " "{0}Operand;\n", var->name); } else if (var->isVariadicOfVariadic()) { body << llvm::formatv(" " - "llvm::SmallVector> " "{0}OperandGroups;\n", var->name); } - } else if (auto *dir = dyn_cast(¶m)) { + } else if (auto *dir = dyn_cast(param)) { ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { body << llvm::formatv( - " llvm::SmallVector> " + " ::llvm::SmallVector> " "{0}TypeGroups;\n", listName); } - } else if (auto *dir = dyn_cast(¶m)) { - Element *input = dir->getOperand(); + } else if (auto *dir = dyn_cast(param)) { + FormatElement *input = dir->getArg(); if (auto *operand = dyn_cast(input)) { if (!operand->getVar()->isOptional()) continue; body << llvm::formatv( " {0} {1}Operand = {1}Operands.empty() ? {0}() : " "{1}Operands[0];\n", - "llvm::Optional<::mlir::OpAsmParser::OperandType>", + "::llvm::Optional<::mlir::OpAsmParser::OperandType>", operand->getVar()->name); } else if (auto *type = dyn_cast(input)) { ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(type->getOperand(), lengthKind); + StringRef listName = getTypeListName(type->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? 
" "::mlir::Type() : {0}Types[0];\n", @@ -1100,7 +912,7 @@ } body << " if (parse" << dir->getName() << "(parser"; - for (Element ¶m : dir->getArguments()) { + for (FormatElement *param : dir->getArguments()) { body << ", "; genCustomParameterParser(param, body); } @@ -1109,15 +921,15 @@ << " return ::mlir::failure();\n"; // After parsing, add handling for any of the optional constructs. - for (Element ¶m : dir->getArguments()) { - if (auto *attr = dyn_cast(¶m)) { + for (FormatElement *param : dir->getArguments()) { + if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); if (var->attr.isOptional()) body << llvm::formatv(" if ({0}Attr)\n ", var->name); body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", var->name); - } else if (auto *operand = dyn_cast(¶m)) { + } else if (auto *operand = dyn_cast(param)) { const NamedTypeConstraint *var = operand->getVar(); if (var->isOptional()) { body << llvm::formatv(" if ({0}Operand.hasValue())\n" @@ -1131,9 +943,9 @@ " }\n", var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr()); } - } else if (auto *dir = dyn_cast(¶m)) { + } else if (auto *dir = dyn_cast(param)) { ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(" if ({0}Type)\n" " {0}Types.push_back({0}Type);\n", @@ -1205,16 +1017,16 @@ // Generate variables to store the operands and type within the format. This // allows for referencing these variables in the presence of optional // groupings. - for (auto &element : elements) - genElementParserStorage(&*element, op, body); + for (FormatElement *element : elements) + genElementParserStorage(element, op, body); // A format context used when parsing attributes with buildable types. 
FmtContext attrTypeCtx; attrTypeCtx.withBuilder("parser.getBuilder()"); // Generate parsers for each of the elements. - for (auto &element : elements) - genElementParser(element.get(), body, attrTypeCtx); + for (FormatElement *element : elements) + genElementParser(element, body, attrTypeCtx); // Generate the code to resolve the operand/result types and successors now // that they have been parsed. @@ -1226,23 +1038,23 @@ body << " return ::mlir::success();\n"; } -void OperationFormat::genElementParser(Element *element, MethodBody &body, +void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, FmtContext &attrTypeCtx, GenContext genCtx) { /// Optional Group. if (auto *optional = dyn_cast(element)) { - auto elements = llvm::drop_begin(optional->getThenElements(), - optional->getParseStart()); + ArrayRef elements = + optional->getThenElements().drop_front(optional->getParseStart()); // Generate a special optional parser for the first element to gate the // parsing of the rest of the elements. - Element *firstElement = &*elements.begin(); + FormatElement *firstElement = elements.front(); if (auto *attrVar = dyn_cast(firstElement)) { genElementParser(attrVar, body, attrTypeCtx); body << " if (" << attrVar->getVar()->name << "Attr) {\n"; } else if (auto *literal = dyn_cast(firstElement)) { body << " if (succeeded(parser.parseOptional"; - genLiteralParser(literal->getLiteral(), body); + genLiteralParser(literal->getSpelling(), body); body << ")) {\n"; } else if (auto *opVar = dyn_cast(firstElement)) { genElementParser(opVar, body, attrTypeCtx); @@ -1265,7 +1077,7 @@ // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. 
- Element *elidedAnchorElement = nullptr; + FormatElement *elidedAnchorElement = nullptr; auto *anchorAttr = dyn_cast(optional->getAnchor()); if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) { elidedAnchorElement = anchorAttr; @@ -1277,20 +1089,17 @@ // Generate the rest of the elements inside an optional group. Elements in // an optional group after the guard are parsed as required. - for (Element &childElement : llvm::drop_begin(elements, 1)) { - if (&childElement != elidedAnchorElement) { - genElementParser(&childElement, body, attrTypeCtx, - GenContext::Optional); - } - } + for (FormatElement *childElement : llvm::drop_begin(elements, 1)) + if (childElement != elidedAnchorElement) + genElementParser(childElement, body, attrTypeCtx, GenContext::Optional); body << " }"; // Generate the else elements. auto elseElements = optional->getElseElements(); if (!elseElements.empty()) { body << " else {\n"; - for (Element &childElement : elseElements) - genElementParser(&childElement, body, attrTypeCtx); + for (FormatElement *childElement : elseElements) + genElementParser(childElement, body, attrTypeCtx); body << " }"; } body << "\n"; @@ -1298,7 +1107,7 @@ /// Literals. } else if (LiteralElement *literal = dyn_cast(element)) { body << " if (parser.parse"; - genLiteralParser(literal->getLiteral(), body); + genLiteralParser(literal->getSpelling(), body); body << ")\n return ::mlir::failure();\n"; /// Whitespaces. @@ -1398,7 +1207,7 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); } else if (lengthKind == ArgumentLengthKind::Variadic) { @@ -1408,7 +1217,7 @@ } else { const char *parserCode = dir->shouldBeQualified() ? 
qualifiedTypeParserCode : typeParserCode; - TypeSwitch(dir->getOperand()) + TypeSwitch(dir->getArg()) .Case([&](auto operand) { body << formatv(parserCode, operand->getVar()->constraint.getCPPClassName(), @@ -1610,7 +1419,7 @@ MethodBody &body) { // Check for the case where all regions were parsed. bool hasAllRegions = llvm::any_of( - elements, [](auto &elt) { return isa(elt.get()); }); + elements, [](FormatElement *elt) { return isa(elt); }); if (hasAllRegions) { body << " result.addRegions(fullRegions);\n"; return; @@ -1628,8 +1437,9 @@ void OperationFormat::genParserSuccessorResolution(Operator &op, MethodBody &body) { // Check for the case where all successors were parsed. - bool hasAllSuccessors = llvm::any_of( - elements, [](auto &elt) { return isa(elt.get()); }); + bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) { + return isa(elt); + }); if (hasAllSuccessors) { body << " result.addSuccessors(fullSuccessors);\n"; return; @@ -1773,7 +1583,7 @@ } /// Generate the printer for a custom directive parameter. -static void genCustomDirectiveParameterPrinter(Element *element, +static void genCustomDirectiveParameterPrinter(FormatElement *element, const Operator &op, MethodBody &body) { if (auto *attr = dyn_cast(element)) { @@ -1792,10 +1602,10 @@ body << op.getGetterName(successor->getVar()->name) << "()"; } else if (auto *dir = dyn_cast(element)) { - genCustomDirectiveParameterPrinter(dir->getOperand(), op, body); + genCustomDirectiveParameterPrinter(dir->getArg(), op, body); } else if (auto *dir = dyn_cast(element)) { - auto *typeOperand = dir->getOperand(); + auto *typeOperand = dir->getArg(); auto *operand = dyn_cast(typeOperand); auto *var = operand ? 
operand->getVar() : cast(typeOperand)->getVar(); @@ -1815,9 +1625,9 @@ static void genCustomDirectivePrinter(CustomDirective *customDir, const Operator &op, MethodBody &body) { body << " print" << customDir->getName() << "(_odsPrinter, *this"; - for (Element ¶m : customDir->getArguments()) { + for (FormatElement *param : customDir->getArguments()) { body << ", "; - genCustomDirectiveParameterPrinter(¶m, op, body); + genCustomDirectiveParameterPrinter(param, op, body); } body << ");\n"; } @@ -1841,7 +1651,7 @@ } /// Generate the C++ for an operand to a (*-)type directive. -static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, +static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, MethodBody &body, bool useArrayRef = true) { if (isa(arg)) @@ -1945,9 +1755,10 @@ } /// Generate the check for the anchor of an optional group. -static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op, +static void genOptionalGroupPrinterAnchor(FormatElement *anchor, + const Operator &op, MethodBody &body) { - TypeSwitch(anchor) + TypeSwitch(anchor) .Case([&](auto *element) { const NamedTypeConstraint *var = element->getVar(); std::string name = op.getGetterName(var->name); @@ -1963,7 +1774,7 @@ body << " if (!" 
<< name << "().empty()) {\n"; }) .Case([&](TypeDirective *element) { - genOptionalGroupPrinterAnchor(element->getOperand(), op, body); + genOptionalGroupPrinterAnchor(element->getArg(), op, body); }) .Case([&](FunctionalTypeDirective *element) { genOptionalGroupPrinterAnchor(element->getInputs(), op, body); @@ -1974,42 +1785,45 @@ }); } -void OperationFormat::genElementPrinter(Element *element, MethodBody &body, - Operator &op, bool &shouldEmitSpace, +void OperationFormat::genElementPrinter(FormatElement *element, + MethodBody &body, Operator &op, + bool &shouldEmitSpace, bool &lastWasPunctuation) { if (LiteralElement *literal = dyn_cast(element)) - return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, + return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace, lastWasPunctuation); // Emit a whitespace element. - if (isa(element)) { - body << " _odsPrinter.printNewline();\n"; + if (auto *space = dyn_cast(element)) { + if (space->getValue() == "\\n") { + body << " _odsPrinter.printNewline();\n"; + } else { + genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace, + lastWasPunctuation); + } return; } - if (SpaceElement *space = dyn_cast(element)) - return genSpacePrinter(space->getValue(), body, shouldEmitSpace, - lastWasPunctuation); // Emit an optional group. if (OptionalElement *optional = dyn_cast(element)) { // Emit the check for the presence of the anchor element. - Element *anchor = optional->getAnchor(); + FormatElement *anchor = optional->getAnchor(); genOptionalGroupPrinterAnchor(anchor, op, body); // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. 
auto elements = optional->getThenElements(); - Element *elidedAnchorElement = nullptr; + FormatElement *elidedAnchorElement = nullptr; auto *anchorAttr = dyn_cast(anchor); - if (anchorAttr && anchorAttr != &*elements.begin() && + if (anchorAttr && anchorAttr != elements.front() && anchorAttr->isUnitAttr()) { elidedAnchorElement = anchorAttr; } // Emit each of the elements. - for (Element &childElement : elements) { - if (&childElement != elidedAnchorElement) { - genElementPrinter(&childElement, body, op, shouldEmitSpace, + for (FormatElement *childElement : elements) { + if (childElement != elidedAnchorElement) { + genElementPrinter(childElement, body, op, shouldEmitSpace, lastWasPunctuation); } } @@ -2019,8 +1833,8 @@ auto elseElements = optional->getElseElements(); if (!elseElements.empty()) { body << " else {\n"; - for (Element &childElement : elseElements) { - genElementPrinter(&childElement, body, op, shouldEmitSpace, + for (FormatElement *childElement : elseElements) { + genElementPrinter(childElement, body, op, shouldEmitSpace, lastWasPunctuation); } body << " }"; @@ -2111,7 +1925,7 @@ body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " "_odsPrinter);\n"; } else if (auto *dir = dyn_cast(element)) { - if (auto *operand = dyn_cast(dir->getOperand())) { + if (auto *operand = dyn_cast(dir->getArg())) { if (operand->getVar()->isVariadicOfVariadic()) { body << llvm::formatv( " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " @@ -2123,9 +1937,9 @@ } const NamedTypeConstraint *var = nullptr; { - if (auto *operand = dyn_cast(dir->getOperand())) + if (auto *operand = dyn_cast(dir->getArg())) var = operand->getVar(); - else if (auto *operand = dyn_cast(dir->getOperand())) + else if (auto *operand = dyn_cast(dir->getArg())) var = operand->getVar(); } if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && @@ -2147,7 +1961,7 @@ return; } body << " _odsPrinter << "; - genTypeOperandPrinter(dir->getOperand(), op, body, 
/*useArrayRef=*/false) + genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false) << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " _odsPrinter.printFunctionalType("; @@ -2167,13 +1981,12 @@ // Flags for if we should emit a space, and if the last element was // punctuation. bool shouldEmitSpace = true, lastWasPunctuation = false; - for (auto &element : elements) - genElementPrinter(element.get(), body, op, shouldEmitSpace, - lastWasPunctuation); + for (FormatElement *element : elements) + genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); } //===----------------------------------------------------------------------===// -// FormatParser +// OpFormatParser //===----------------------------------------------------------------------===// /// Function to find an element within the given range that has the same name as @@ -2186,30 +1999,35 @@ namespace { /// This class implements a parser for an instance of an operation assembly /// format. -class FormatParser { +class OpFormatParser : public FormatParser { public: - FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) - : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format), - op(op), seenOperandTypes(op.getNumOperands()), + OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) + : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op), + seenOperandTypes(op.getNumOperands()), seenResultTypes(op.getNumResults()) {} - /// Parse the operation assembly format. - LogicalResult parse(); +protected: + /// Verify the format elements. + LogicalResult verify(SMLoc loc, ArrayRef elements) override; + /// Verify the arguments to a custom directive. + LogicalResult + verifyCustomDirectiveArguments(SMLoc loc, + ArrayRef arguments) override; + /// Verify the elements of an optional group. 
+ LogicalResult + verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, + Optional anchorIndex) override; + LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, + bool isAnchor); + + /// Parse an operation variable. + FailureOr parseVariableImpl(SMLoc loc, StringRef name, + Context ctx) override; + /// Parse an operation format directive. + FailureOr + parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; private: - /// The current context of the parser when parsing an element. - enum ParserContext { - /// The element is being parsed in a "top-level" context, i.e. at the top of - /// the format or in an optional group. - TopLevelContext, - /// The element is being parsed as a custom directive child. - CustomDirectiveContext, - /// The element is being parsed as a type directive child. - TypeDirectiveContext, - /// The element is being parsed as a reference directive child. - RefDirectiveContext - }; - /// This struct represents a type resolution instance. It includes a specific /// type as well as an optional transformer to apply to that type in order to /// properly resolve the type of a variable. @@ -2218,16 +2036,14 @@ Optional transformer; }; - /// An iterator over the elements of a format group. - using ElementsIterT = llvm::pointee_iterator< - std::vector>::const_iterator>; + using ElementsItT = ArrayRef::iterator; /// Verify the state of operation attributes within the format. - LogicalResult verifyAttributes(SMLoc loc); + LogicalResult verifyAttributes(SMLoc loc, ArrayRef elements); /// Verify the attribute elements at the back of the given stack of iterators. LogicalResult verifyAttributes( SMLoc loc, - SmallVectorImpl> &iteratorStack); + SmallVectorImpl> &iteratorStack); /// Verify the state of operation operands within the format. LogicalResult @@ -2266,85 +2082,28 @@ /// within the format. ConstArgument findSeenArg(StringRef name); - /// Parse a specific element. 
- LogicalResult parseElement(std::unique_ptr &element, - ParserContext context); - LogicalResult parseVariable(std::unique_ptr &element, - ParserContext context); - LogicalResult parseDirective(std::unique_ptr &element, - ParserContext context); - LogicalResult parseLiteral(std::unique_ptr &element, - ParserContext context); - LogicalResult parseOptional(std::unique_ptr &element, - ParserContext context); - LogicalResult parseOptionalChildElement( - std::vector> &childElements, - Optional &anchorIdx); - LogicalResult verifyOptionalChildElement(Element *element, - SMLoc childLoc, bool isAnchor); - /// Parse the various different directives. - LogicalResult parseAttrDictDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context, - bool withKeyword); - LogicalResult parseCustomDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context); - LogicalResult parseCustomDirectiveParameter( - std::vector> ¶meters); - LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, - FormatToken tok, - ParserContext context); - LogicalResult parseOperandsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context); - LogicalResult parseQualifiedDirective(std::unique_ptr &element, - FormatToken tok, ParserContext context); - LogicalResult parseReferenceDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context); - LogicalResult parseRegionsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context); - LogicalResult parseResultsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context); - LogicalResult parseSuccessorsDirective(std::unique_ptr &element, - SMLoc loc, - ParserContext context); - LogicalResult parseTypeDirective(std::unique_ptr &element, - FormatToken tok, ParserContext context); - LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, - bool isRefChild = false); - - //===--------------------------------------------------------------------===// - // Lexer Utilities - 
//===--------------------------------------------------------------------===// - - /// Advance the current lexer onto the next token. - void consumeToken() { - assert(curToken.getKind() != FormatToken::eof && - curToken.getKind() != FormatToken::error && - "shouldn't advance past EOF or errors"); - curToken = lexer.lexToken(); - } - LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) { - if (curToken.getKind() != kind) - return emitError(curToken.getLoc(), msg); - consumeToken(); - return ::mlir::success(); - } - LogicalResult emitError(SMLoc loc, const Twine &msg) { - lexer.emitError(loc, msg); - return ::mlir::failure(); - } - LogicalResult emitErrorAndNote(SMLoc loc, const Twine &msg, - const Twine ¬e) { - lexer.emitErrorAndNote(loc, msg, note); - return ::mlir::failure(); - } + FailureOr parseAttrDictDirective(SMLoc loc, Context context, + bool withKeyword); + FailureOr parseFunctionalTypeDirective(SMLoc loc, + Context context); + FailureOr parseOperandsDirective(SMLoc loc, Context context); + FailureOr parseQualifiedDirective(SMLoc loc, + Context context); + FailureOr parseReferenceDirective(SMLoc loc, + Context context); + FailureOr parseRegionsDirective(SMLoc loc, Context context); + FailureOr parseResultsDirective(SMLoc loc, Context context); + FailureOr parseSuccessorsDirective(SMLoc loc, + Context context); + FailureOr parseTypeDirective(SMLoc loc, Context context); + FailureOr parseTypeDirectiveOperand(SMLoc loc, + bool isRefChild = false); //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// - FormatLexer lexer; - FormatToken curToken; OperationFormat &fmt; Operator &op; @@ -2361,17 +2120,8 @@ }; } // namespace -LogicalResult FormatParser::parse() { - SMLoc loc = curToken.getLoc(); - - // Parse each of the format elements into the main format. 
- while (curToken.getKind() != FormatToken::eof) { - std::unique_ptr element; - if (failed(parseElement(element, TopLevelContext))) - return ::mlir::failure(); - fmt.elements.push_back(std::move(element)); - } - +LogicalResult OpFormatParser::verify(SMLoc loc, + ArrayRef elements) { // Check that the attribute dictionary is in the format. if (!hasAttrDict) return emitError(loc, "'attr-dict' directive not found in " @@ -2404,29 +2154,29 @@ } // Verify the state of the various operation components. - if (failed(verifyAttributes(loc)) || + if (failed(verifyAttributes(loc, elements)) || failed(verifyResults(loc, variableTyResolver)) || failed(verifyOperands(loc, variableTyResolver)) || failed(verifyRegions(loc)) || failed(verifySuccessors(loc))) - return ::mlir::failure(); + return failure(); // Collect the set of used attributes in the format. fmt.usedAttributes = seenAttrs.takeVector(); - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::verifyAttributes(SMLoc loc) { +LogicalResult +OpFormatParser::verifyAttributes(SMLoc loc, + ArrayRef elements) { // Check that there are no `:` literals after an attribute without a constant // type. The attribute grammar contains an optional trailing colon type, which // can lead to unexpected and generally unintended behavior. Given that, it is // better to just error out here instead. - using ElementsIterT = llvm::pointee_iterator< - std::vector>::const_iterator>; - SmallVector, 1> iteratorStack; - iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end()); + SmallVector, 1> iteratorStack; + iteratorStack.emplace_back(elements.begin(), elements.end()); while (!iteratorStack.empty()) if (failed(verifyAttributes(loc, iteratorStack))) - return ::mlir::failure(); + return ::failure(); // Check for VariadicOfVariadic variables. The segment attribute of those // variables will be infered. 
@@ -2437,16 +2187,16 @@ } } - return ::mlir::success(); + return success(); } /// Verify the attribute elements at the back of the given stack of iterators. -LogicalResult FormatParser::verifyAttributes( +LogicalResult OpFormatParser::verifyAttributes( SMLoc loc, - SmallVectorImpl> &iteratorStack) { + SmallVectorImpl> &iteratorStack) { auto &stackIt = iteratorStack.back(); - ElementsIterT &it = stackIt.first, e = stackIt.second; + ElementsItT &it = stackIt.first, e = stackIt.second; while (it != e) { - Element *element = &*(it++); + FormatElement *element = *(it++); // Traverse into optional groups. if (auto *optional = dyn_cast(element)) { @@ -2455,7 +2205,7 @@ auto elseElements = optional->getElseElements(); iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); - return ::mlir::success(); + return success(); } // We are checking for an attribute element followed by a `:`, so there is @@ -2470,7 +2220,7 @@ // Check the next iterator within the stack for literal elements. for (auto &nextItPair : iteratorStack) { - ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second; + ElementsItT nextIt = nextItPair.first, nextE = nextItPair.second; for (; nextIt != nextE; ++nextIt) { // Skip any trailing whitespace, attribute dictionaries, or optional // groups. @@ -2479,8 +2229,8 @@ continue; // We are only interested in `:` literals. - auto *literal = dyn_cast(&*nextIt); - if (!literal || literal->getLiteral() != ":") + auto *literal = dyn_cast(*nextIt); + if (!literal || literal->getSpelling() != ":") break; // TODO: Use the location of the literal element itself. 
@@ -2493,12 +2243,11 @@ } } iteratorStack.pop_back(); - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::verifyOperands( - SMLoc loc, - llvm::StringMap &variableTyResolver) { +LogicalResult OpFormatParser::verifyOperands( + SMLoc loc, llvm::StringMap &variableTyResolver) { // Check that all of the operands are within the format, and their types can // be inferred. auto &buildableTypes = fmt.buildableTypes; @@ -2541,13 +2290,13 @@ auto it = buildableTypes.insert({*builder, buildableTypes.size()}); fmt.operandTypes[i].setBuilderIdx(it.first->second); } - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::verifyRegions(SMLoc loc) { +LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { // Check that all of the regions are within the format. if (hasAllRegions) - return ::mlir::success(); + return success(); for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { const NamedRegion ®ion = op.getRegion(i); @@ -2559,22 +2308,21 @@ "' directive to the custom assembly format"); } } - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::verifyResults( - SMLoc loc, - llvm::StringMap &variableTyResolver) { +LogicalResult OpFormatParser::verifyResults( + SMLoc loc, llvm::StringMap &variableTyResolver) { // If we format all of the types together, there is nothing to check. if (fmt.allResultTypes) - return ::mlir::success(); + return success(); // If no result types are specified and we can infer them, infer all result // types if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && canInferResultTypes) { fmt.infersResultTypes = true; - return ::mlir::success(); + return success(); } // Check that all of the result types can be inferred. 
@@ -2608,13 +2356,13 @@ auto it = buildableTypes.insert({*builder, buildableTypes.size()}); fmt.resultTypes[i].setBuilderIdx(it.first->second); } - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::verifySuccessors(SMLoc loc) { +LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) { // Check that all of the successors are within the format. if (hasAllSuccessors) - return ::mlir::success(); + return success(); for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { const NamedSuccessor &successor = op.getSuccessor(i); @@ -2626,10 +2374,10 @@ "' directive to the custom assembly format"); } } - return ::mlir::success(); + return success(); } -void FormatParser::handleAllTypesMatchConstraint( +void OpFormatParser::handleAllTypesMatchConstraint( ArrayRef values, llvm::StringMap &variableTyResolver) { for (unsigned i = 0, e = values.size(); i != e; ++i) { @@ -2646,7 +2394,7 @@ } } -void FormatParser::handleSameTypesConstraint( +void OpFormatParser::handleSameTypesConstraint( llvm::StringMap &variableTyResolver, bool includeResults) { const NamedTypeConstraint *resolver = nullptr; @@ -2671,7 +2419,7 @@ } } -void FormatParser::handleTypesMatchConstraint( +void OpFormatParser::handleTypesMatchConstraint( llvm::StringMap &variableTyResolver, const llvm::Record &def) { StringRef lhsName = def.getValueAsString("lhs"); @@ -2681,7 +2429,7 @@ variableTyResolver[rhsName] = {arg, transformer}; } -ConstArgument FormatParser::findSeenArg(StringRef name) { +ConstArgument OpFormatParser::findSeenArg(StringRef name) { if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; if (const NamedTypeConstraint *arg = findArg(op.getResults(), name)) @@ -2691,40 +2439,15 @@ return nullptr; } -LogicalResult FormatParser::parseElement(std::unique_ptr &element, - ParserContext context) { - // Directives. 
- if (curToken.isKeyword()) - return parseDirective(element, context); - // Literals. - if (curToken.getKind() == FormatToken::literal) - return parseLiteral(element, context); - // Optionals. - if (curToken.getKind() == FormatToken::l_paren) - return parseOptional(element, context); - // Variables. - if (curToken.getKind() == FormatToken::variable) - return parseVariable(element, context); - return emitError(curToken.getLoc(), - "expected directive, literal, variable, or optional group"); -} - -LogicalResult FormatParser::parseVariable(std::unique_ptr &element, - ParserContext context) { - FormatToken varTok = curToken; - consumeToken(); - - StringRef name = varTok.getSpelling().drop_front(); - SMLoc loc = varTok.getLoc(); - - // Check that the parsed argument is something actually registered on the - // op. - /// Attributes +FailureOr +OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { + // Check that the parsed argument is something actually registered on the op. 
+ // Attributes if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { - if (context == TypeDirectiveContext) + if (ctx == TypeDirectiveContext) return emitError( loc, "attributes cannot be used as children to a `type` directive"); - if (context == RefDirectiveContext) { + if (ctx == RefDirectiveContext) { if (!seenAttrs.count(attr)) return emitError(loc, "attribute '" + name + "' must be bound before it is referenced"); @@ -2732,280 +2455,92 @@ return emitError(loc, "attribute '" + name + "' is already bound"); } - element = std::make_unique(attr); - return ::mlir::success(); + return create(attr); } - /// Operands + // Operands if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { - if (context == TopLevelContext || context == CustomDirectiveContext) { + if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { if (fmt.allOperands || !seenOperands.insert(operand).second) return emitError(loc, "operand '" + name + "' is already bound"); - } else if (context == RefDirectiveContext && !seenOperands.count(operand)) { + } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) { return emitError(loc, "operand '" + name + "' must be bound before it is referenced"); } - element = std::make_unique(operand); - return ::mlir::success(); + return create(operand); } - /// Regions + // Regions if (const NamedRegion *region = findArg(op.getRegions(), name)) { - if (context == TopLevelContext || context == CustomDirectiveContext) { + if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { if (hasAllRegions || !seenRegions.insert(region).second) return emitError(loc, "region '" + name + "' is already bound"); - } else if (context == RefDirectiveContext && !seenRegions.count(region)) { + } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) { return emitError(loc, "region '" + name + "' must be bound before it is referenced"); } else { return emitError(loc, "regions can only be used at the top level"); } 
- element = std::make_unique(region); - return ::mlir::success(); + return create(region); } - /// Results. + // Results. if (const auto *result = findArg(op.getResults(), name)) { - if (context != TypeDirectiveContext) + if (ctx != TypeDirectiveContext) return emitError(loc, "result variables can can only be used as a child " "to a 'type' directive"); - element = std::make_unique(result); - return ::mlir::success(); + return create(result); } - /// Successors. + // Successors. if (const auto *successor = findArg(op.getSuccessors(), name)) { - if (context == TopLevelContext || context == CustomDirectiveContext) { + if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { if (hasAllSuccessors || !seenSuccessors.insert(successor).second) return emitError(loc, "successor '" + name + "' is already bound"); - } else if (context == RefDirectiveContext && - !seenSuccessors.count(successor)) { + } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) { return emitError(loc, "successor '" + name + "' must be bound before it is referenced"); } else { return emitError(loc, "successors can only be used at the top level"); } - element = std::make_unique(successor); - return ::mlir::success(); + return create(successor); } return emitError(loc, "expected variable to refer to an argument, region, " "result, or successor"); } -LogicalResult FormatParser::parseDirective(std::unique_ptr &element, - ParserContext context) { - FormatToken dirTok = curToken; - consumeToken(); - - switch (dirTok.getKind()) { +FailureOr +OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, + Context ctx) { + switch (kind) { case FormatToken::kw_attr_dict: - return parseAttrDictDirective(element, dirTok.getLoc(), context, + return parseAttrDictDirective(loc, ctx, /*withKeyword=*/false); case FormatToken::kw_attr_dict_w_keyword: - return parseAttrDictDirective(element, dirTok.getLoc(), context, + return parseAttrDictDirective(loc, ctx, /*withKeyword=*/true); - case 
FormatToken::kw_custom: - return parseCustomDirective(element, dirTok.getLoc(), context); case FormatToken::kw_functional_type: - return parseFunctionalTypeDirective(element, dirTok, context); + return parseFunctionalTypeDirective(loc, ctx); case FormatToken::kw_operands: - return parseOperandsDirective(element, dirTok.getLoc(), context); + return parseOperandsDirective(loc, ctx); case FormatToken::kw_qualified: - return parseQualifiedDirective(element, dirTok, context); + return parseQualifiedDirective(loc, ctx); case FormatToken::kw_regions: - return parseRegionsDirective(element, dirTok.getLoc(), context); + return parseRegionsDirective(loc, ctx); case FormatToken::kw_results: - return parseResultsDirective(element, dirTok.getLoc(), context); + return parseResultsDirective(loc, ctx); case FormatToken::kw_successors: - return parseSuccessorsDirective(element, dirTok.getLoc(), context); + return parseSuccessorsDirective(loc, ctx); case FormatToken::kw_ref: - return parseReferenceDirective(element, dirTok.getLoc(), context); + return parseReferenceDirective(loc, ctx); case FormatToken::kw_type: - return parseTypeDirective(element, dirTok, context); + return parseTypeDirective(loc, ctx); default: - llvm_unreachable("unknown directive token"); + return emitError(loc, "unsupported directive kind"); } } -LogicalResult FormatParser::parseLiteral(std::unique_ptr &element, - ParserContext context) { - FormatToken literalTok = curToken; - if (context != TopLevelContext) { - return emitError( - literalTok.getLoc(), - "literals may only be used in a top-level section of the format"); - } - consumeToken(); - - StringRef value = literalTok.getSpelling().drop_front().drop_back(); - - // The parsed literal is a space element (`` or ` `). - if (value.empty() || (value.size() == 1 && value.front() == ' ')) { - element = std::make_unique(!value.empty()); - return ::mlir::success(); - } - // The parsed literal is a newline element. 
- if (value == "\\n") { - element = std::make_unique(); - return ::mlir::success(); - } - - // Check that the parsed literal is valid. - if (!isValidLiteral(value, [&](Twine diag) { - (void)emitError(literalTok.getLoc(), - "expected valid literal but got '" + value + - "': " + diag); - })) - return failure(); - element = std::make_unique(value); - return ::mlir::success(); -} - -LogicalResult FormatParser::parseOptional(std::unique_ptr &element, - ParserContext context) { - SMLoc curLoc = curToken.getLoc(); - if (context != TopLevelContext) - return emitError(curLoc, "optional groups can only be used as top-level " - "elements"); - consumeToken(); - - // Parse the child elements for this optional group. - std::vector> thenElements, elseElements; - Optional anchorIdx; - do { - if (failed(parseOptionalChildElement(thenElements, anchorIdx))) - return ::mlir::failure(); - } while (curToken.getKind() != FormatToken::r_paren); - consumeToken(); - - // Parse the `else` elements of this optional group. - if (curToken.getKind() == FormatToken::colon) { - consumeToken(); - if (failed(parseToken(FormatToken::l_paren, - "expected '(' to start else branch " - "of optional group"))) - return failure(); - do { - SMLoc childLoc = curToken.getLoc(); - elseElements.push_back({}); - if (failed(parseElement(elseElements.back(), TopLevelContext)) || - failed(verifyOptionalChildElement(elseElements.back().get(), childLoc, - /*isAnchor=*/false))) - return failure(); - } while (curToken.getKind() != FormatToken::r_paren); - consumeToken(); - } - - if (failed(parseToken(FormatToken::question, - "expected '?' after optional group"))) - return ::mlir::failure(); - - // The optional group is required to have an anchor. - if (!anchorIdx) - return emitError(curLoc, "optional group specified no anchor element"); - - // The first parsable element of the group must be able to be parsed in an - // optional fashion. 
- auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) { - return isa(element.get()); - }); - Element *firstElement = parseBegin->get(); - if (!isa(firstElement) && - !isa(firstElement) && - !isa(firstElement) && !isa(firstElement)) - return emitError(curLoc, - "first parsable element of an operand group must be " - "an attribute, literal, operand, or region"); - - auto parseStart = parseBegin - thenElements.begin(); - element = std::make_unique( - std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart); - return ::mlir::success(); -} - -LogicalResult FormatParser::parseOptionalChildElement( - std::vector> &childElements, - Optional &anchorIdx) { - SMLoc childLoc = curToken.getLoc(); - childElements.push_back({}); - if (failed(parseElement(childElements.back(), TopLevelContext))) - return ::mlir::failure(); - - // Check to see if this element is the anchor of the optional group. - bool isAnchor = curToken.getKind() == FormatToken::caret; - if (isAnchor) { - if (anchorIdx) - return emitError(childLoc, "only one element can be marked as the anchor " - "of an optional group"); - anchorIdx = childElements.size() - 1; - consumeToken(); - } - - return verifyOptionalChildElement(childElements.back().get(), childLoc, - isAnchor); -} - -LogicalResult FormatParser::verifyOptionalChildElement(Element *element, - SMLoc childLoc, - bool isAnchor) { - return TypeSwitch(element) - // All attributes can be within the optional group, but only optional - // attributes can be the anchor. - .Case([&](AttributeVariable *attrEle) { - if (isAnchor && !attrEle->getVar()->attr.isOptional()) - return emitError(childLoc, "only optional attributes can be used to " - "anchor an optional group"); - return ::mlir::success(); - }) - // Only optional-like(i.e. variadic) operands can be within an optional - // group. 
- .Case([&](OperandVariable *ele) { - if (!ele->getVar()->isVariableLength()) - return emitError(childLoc, "only variable length operands can be " - "used within an optional group"); - return ::mlir::success(); - }) - // Only optional-like(i.e. variadic) results can be within an optional - // group. - .Case([&](ResultVariable *ele) { - if (!ele->getVar()->isVariableLength()) - return emitError(childLoc, "only variable length results can be " - "used within an optional group"); - return ::mlir::success(); - }) - .Case([&](RegionVariable *) { - // TODO: When ODS has proper support for marking "optional" regions, add - // a check here. - return ::mlir::success(); - }) - .Case([&](TypeDirective *ele) { - return verifyOptionalChildElement(ele->getOperand(), childLoc, - /*isAnchor=*/false); - }) - .Case([&](FunctionalTypeDirective *ele) { - if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc, - /*isAnchor=*/false))) - return failure(); - return verifyOptionalChildElement(ele->getResults(), childLoc, - /*isAnchor=*/false); - }) - // Literals, whitespace, and custom directives may be used, but they can't - // anchor the group. 
- .Case([&](Element *) { - if (isAnchor) - return emitError(childLoc, "only variables and types can be used " - "to anchor an optional group"); - return ::mlir::success(); - }) - .Default([&](Element *) { - return emitError(childLoc, "only literals, types, and variables can be " - "used within an optional group"); - }); -} - -LogicalResult -FormatParser::parseAttrDictDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context, - bool withKeyword) { +FailureOr +OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context, + bool withKeyword) { if (context == TypeDirectiveContext) return emitError(loc, "'attr-dict' directive can only be used as a " "top-level directive"); @@ -3022,104 +2557,50 @@ hasAttrDict = true; } - element = std::make_unique(withKeyword); - return ::mlir::success(); + return create(withKeyword); } -LogicalResult -FormatParser::parseCustomDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { - SMLoc curLoc = curToken.getLoc(); - if (context != TopLevelContext) - return emitError(loc, "'custom' is only valid as a top-level directive"); - - // Parse the custom directive name. 
- if (failed(parseToken(FormatToken::less, - "expected '<' before custom directive name"))) - return ::mlir::failure(); - - FormatToken nameTok = curToken; - if (failed(parseToken(FormatToken::identifier, - "expected custom directive name identifier")) || - failed(parseToken(FormatToken::greater, - "expected '>' after custom directive name")) || - failed(parseToken(FormatToken::l_paren, - "expected '(' before custom directive parameters"))) - return ::mlir::failure(); - - // Parse the child elements for this optional group.= - std::vector> elements; - do { - if (failed(parseCustomDirectiveParameter(elements))) - return ::mlir::failure(); - if (curToken.getKind() != FormatToken::comma) - break; - consumeToken(); - } while (true); - - if (failed(parseToken(FormatToken::r_paren, - "expected ')' after custom directive parameters"))) - return ::mlir::failure(); - - // After parsing all of the elements, ensure that all type directives refer - // only to variables. - for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { - if (!isa(typeEle->getOperand())) { - return emitError(curLoc, "type directives within a custom directive " - "may only refer to variables"); +LogicalResult OpFormatParser::verifyCustomDirectiveArguments( + SMLoc loc, ArrayRef arguments) { + for (FormatElement *argument : arguments) { + if (!isa(argument)) { + // TODO: FormatElement should have location info attached. 
+ return emitError(loc, "only variables and types may be used as " + "parameters to a custom directive"); + } + if (auto *type = dyn_cast(argument)) { + if (!isa(type->getArg())) { + return emitError(loc, "type directives within a custom directive may " + "only refer to variables"); } } } - - element = std::make_unique(nameTok.getSpelling(), - std::move(elements)); - return ::mlir::success(); -} - -LogicalResult FormatParser::parseCustomDirectiveParameter( - std::vector> &parameters) { - SMLoc childLoc = curToken.getLoc(); - parameters.push_back({}); - if (failed(parseElement(parameters.back(), CustomDirectiveContext))) - return ::mlir::failure(); - - // Verify that the element can be placed within a custom directive. - if (!isa( - parameters.back().get())) { - return emitError(childLoc, "only variables and types may be used as " - "parameters to a custom directive"); - } - return ::mlir::success(); + return success(); } -LogicalResult FormatParser::parseFunctionalTypeDirective( - std::unique_ptr &element, FormatToken tok, ParserContext context) { - SMLoc loc = tok.getLoc(); +FailureOr +OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) { if (context != TopLevelContext) return emitError( loc, "'functional-type' is only valid as a top-level directive"); // Parse the main operand. 
- std::unique_ptr inputs, results; + FailureOr inputs, results; if (failed(parseToken(FormatToken::l_paren, "expected '(' before argument list")) || - failed(parseTypeDirectiveOperand(inputs)) || + failed(inputs = parseTypeDirectiveOperand(loc)) || failed(parseToken(FormatToken::comma, "expected ',' after inputs argument")) || - failed(parseTypeDirectiveOperand(results)) || + failed(results = parseTypeDirectiveOperand(loc)) || failed( parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return ::mlir::failure(); - element = std::make_unique(std::move(inputs), - std::move(results)); - return ::mlir::success(); + return failure(); + return create(*inputs, *results); } -LogicalResult -FormatParser::parseOperandsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { +FailureOr +OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { if (context == RefDirectiveContext) { if (!fmt.allOperands) return emitError(loc, "'ref' of 'operands' is not bound by a prior " @@ -3130,31 +2611,27 @@ return emitError(loc, "'operands' directive creates overlap in format"); fmt.allOperands = true; } - element = std::make_unique(); - return ::mlir::success(); + return create(); } -LogicalResult -FormatParser::parseReferenceDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { +FailureOr +OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) { if (context != CustomDirectiveContext) return emitError(loc, "'ref' is only valid within a `custom` directive"); - std::unique_ptr operand; + FailureOr arg; if (failed(parseToken(FormatToken::l_paren, "expected '(' before argument list")) || - failed(parseElement(operand, RefDirectiveContext)) || + failed(arg = parseElement(RefDirectiveContext)) || failed( parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return ::mlir::failure(); + return failure(); - element = std::make_unique(std::move(operand)); - return ::mlir::success(); + return 
create(*arg); } -LogicalResult -FormatParser::parseRegionsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { +FailureOr +OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { if (context == TypeDirectiveContext) return emitError(loc, "'regions' is only valid as a top-level directive"); if (context == RefDirectiveContext) { @@ -3168,23 +2645,19 @@ return emitError(loc, "'regions' directive creates overlap in format"); hasAllRegions = true; } - element = std::make_unique(); - return ::mlir::success(); + return create(); } -LogicalResult -FormatParser::parseResultsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { +FailureOr +OpFormatParser::parseResultsDirective(SMLoc loc, Context context) { if (context != TypeDirectiveContext) return emitError(loc, "'results' directive can can only be used as a child " "to a 'type' directive"); - element = std::make_unique(); - return ::mlir::success(); + return create(); } -LogicalResult -FormatParser::parseSuccessorsDirective(std::unique_ptr &element, - SMLoc loc, ParserContext context) { +FailureOr +OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) { if (context == TypeDirectiveContext) return emitError(loc, "'successors' is only valid as a top-level directive"); @@ -3199,62 +2672,59 @@ return emitError(loc, "'successors' directive creates overlap in format"); hasAllSuccessors = true; } - element = std::make_unique(); - return ::mlir::success(); + return create(); } -LogicalResult -FormatParser::parseTypeDirective(std::unique_ptr &element, - FormatToken tok, ParserContext context) { - SMLoc loc = tok.getLoc(); +FailureOr OpFormatParser::parseTypeDirective(SMLoc loc, + Context context) { if (context == TypeDirectiveContext) return emitError(loc, "'type' cannot be used as a child of another `type`"); bool isRefChild = context == RefDirectiveContext; - std::unique_ptr operand; + FailureOr operand; if (failed(parseToken(FormatToken::l_paren, "expected '(' 
before argument list")) || - failed(parseTypeDirectiveOperand(operand, isRefChild)) || + failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) || failed( parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return ::mlir::failure(); + return failure(); - element = std::make_unique(std::move(operand)); - return ::mlir::success(); + return create(*operand); } -LogicalResult -FormatParser::parseQualifiedDirective(std::unique_ptr &element, - FormatToken tok, ParserContext context) { +FailureOr +OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) { + FailureOr element; if (failed(parseToken(FormatToken::l_paren, "expected '(' before argument list")) || - failed(parseElement(element, context)) || + failed(element = parseElement(context)) || failed( parseToken(FormatToken::r_paren, "expected ')' after argument list"))) return failure(); - if (auto *attr = dyn_cast(element.get())) { - attr->setShouldBeQualified(); - } else if (auto *type = dyn_cast(element.get())) { - type->setShouldBeQualified(); - } else { - return emitError( - tok.getLoc(), - "'qualified' directive expects an attribute or a `type` directive"); - } - return success(); + return TypeSwitch>(*element) + .Case([](auto *element) { + element->setShouldBeQualified(); + return element; + }) + .Default([&](auto *element) { + return emitError( + loc, + "'qualified' directive expects an attribute or a `type` directive"); + }); } -LogicalResult -FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, - bool isRefChild) { - SMLoc loc = curToken.getLoc(); - if (failed(parseElement(element, TypeDirectiveContext))) - return ::mlir::failure(); - if (isa(element.get())) +FailureOr +OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) { + FailureOr result = parseElement(TypeDirectiveContext); + if (failed(result)) + return failure(); + + FormatElement *element = *result; + if (isa(element)) return emitError( loc, "'type' directive operand expects 
variable or directive operand"); - if (auto *var = dyn_cast(element.get())) { + if (auto *var = dyn_cast(element)) { unsigned opIdx = var->getVar() - op.operand_begin(); if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + @@ -3263,7 +2733,7 @@ return emitError(loc, "'ref' of 'type($" + var->getVar()->name + ")' is not bound by a prior 'type' directive"); seenOperandTypes.set(opIdx); - } else if (auto *var = dyn_cast(element.get())) { + } else if (auto *var = dyn_cast(element)) { unsigned resIdx = var->getVar() - op.result_begin(); if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + @@ -3289,7 +2759,78 @@ } else { return emitError(loc, "invalid argument to 'type' directive"); } - return ::mlir::success(); + return element; +} + +LogicalResult +OpFormatParser::verifyOptionalGroupElements(SMLoc loc, + ArrayRef elements, + Optional anchorIndex) { + for (auto &it : llvm::enumerate(elements)) { + if (failed(verifyOptionalGroupElement( + loc, it.value(), anchorIndex && *anchorIndex == it.index()))) + return failure(); + } + return success(); +} + +LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, + FormatElement *element, + bool isAnchor) { + return TypeSwitch(element) + // All attributes can be within the optional group, but only optional + // attributes can be the anchor. + .Case([&](AttributeVariable *attrEle) { + if (isAnchor && !attrEle->getVar()->attr.isOptional()) + return emitError(loc, "only optional attributes can be used to " + "anchor an optional group"); + return success(); + }) + // Only optional-like(i.e. variadic) operands can be within an optional + // group. 
+ .Case([&](OperandVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length operands can be used " + "within an optional group"); + return success(); + }) + // Only optional-like(i.e. variadic) results can be within an optional + // group. + .Case([&](ResultVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length results can be used " + "within an optional group"); + return success(); + }) + .Case([&](RegionVariable *) { + // TODO: When ODS has proper support for marking "optional" regions, add + // a check here. + return success(); + }) + .Case([&](TypeDirective *ele) { + return verifyOptionalGroupElement(loc, ele->getArg(), + /*isAnchor=*/false); + }) + .Case([&](FunctionalTypeDirective *ele) { + if (failed(verifyOptionalGroupElement(loc, ele->getInputs(), + /*isAnchor=*/false))) + return failure(); + return verifyOptionalGroupElement(loc, ele->getResults(), + /*isAnchor=*/false); + }) + // Literals, whitespace, and custom directives may be used, but they can't + // anchor the group. + .Case([&](FormatElement *) { + if (isAnchor) + return emitError(loc, "only variables and types can be used " + "to anchor an optional group"); + return success(); + }) + .Default([&](FormatElement *) { + return emitError(loc, "only literals, types, and variables can be " + "used within an optional group"); + }); } //===----------------------------------------------------------------------===// @@ -3308,7 +2849,9 @@ mgr.AddNewSourceBuffer( llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc()); OperationFormat format(op); - if (failed(FormatParser(mgr, format, op).parse())) { + OpFormatParser parser(mgr, format, op); + FailureOr> elements = parser.parse(); + if (failed(elements)) { // Exit the process if format errors are treated as fatal. if (formatErrorIsFatal) { // Invoke the interrupt handlers to run the file cleanup handlers. 
@@ -3317,6 +2860,7 @@ } return; } + format.elements = std::move(*elements); // Generate the printer and parser based on the parsed format. format.genParser(op, opClass);