diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -917,9 +917,9 @@ verifyCustomDirectiveArguments(SMLoc loc, ArrayRef arguments) override; /// Verify the elements of an optional group. - LogicalResult - verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, - Optional anchorIndex) override; + LogicalResult verifyOptionalGroupElements(SMLoc loc, + ArrayRef elements, + FormatElement *anchor) override; /// Parse an attribute or type variable. FailureOr parseVariableImpl(SMLoc loc, StringRef name, @@ -989,7 +989,7 @@ LogicalResult DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, ArrayRef elements, - Optional anchorIndex) { + FormatElement *anchor) { // `params` and `struct` directives are allowed only if all the contained // parameters are optional. for (FormatElement *el : elements) { @@ -1011,8 +1011,8 @@ } } // The anchor must be a parameter or one of the aforementioned directives. - if (anchorIndex && !isa( - elements[*anchorIndex])) { + if (anchor && + !isa(anchor)) { return emitError(loc, "optional group anchor must be a parameter or directive"); } diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -378,9 +378,9 @@ /// Create an optional group with the given child elements. OptionalElement(std::vector &&thenElements, std::vector &&elseElements, - unsigned anchorIndex, unsigned parseStart) + FormatElement *anchor, unsigned parseStart) : thenElements(std::move(thenElements)), - elseElements(std::move(elseElements)), anchorIndex(anchorIndex), + elseElements(std::move(elseElements)), anchor(anchor), parseStart(parseStart) {} /// Return the `then` elements of the optional group. @@ -390,7 +390,7 @@ ArrayRef getElseElements() const { return elseElements; } /// Return the anchor of the optional group. - FormatElement *getAnchor() const { return thenElements[anchorIndex]; } + FormatElement *getAnchor() const { return anchor; } /// Return the index of the first element to be parsed. unsigned getParseStart() const { return parseStart; } @@ -400,9 +400,8 @@ std::vector thenElements; /// The child elements emitted when the anchor is not present. std::vector elseElements; - /// The index of the anchor element of the optional group within - /// `thenElements`. - unsigned anchorIndex; + /// The anchor element of the optional group. + FormatElement *anchor; /// The index of the first element that is parsed in `thenElements`. That is, /// the first non-whitespace element. unsigned parseStart; @@ -496,7 +495,7 @@ virtual LogicalResult verifyOptionalGroupElements(llvm::SMLoc loc, ArrayRef elements, - Optional anchorIndex) = 0; + FormatElement *anchor) = 0; //===--------------------------------------------------------------------===// // Lexer Utilities diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -320,17 +320,17 @@ // Parse the child elements for this optional group. std::vector thenElements, elseElements; - Optional anchorIndex; + FormatElement *anchor = nullptr; do { FailureOr element = parseElement(TopLevelContext); if (failed(element)) return failure(); // Check for an anchor. if (curToken.is(FormatToken::caret)) { - if (anchorIndex) + if (anchor) return emitError(curToken.getLoc(), "only one element can be marked as " "the anchor of an optional group"); - anchorIndex = thenElements.size(); + anchor = *element; consumeToken(); } thenElements.push_back(*element); @@ -357,12 +357,12 @@ return failure(); // The optional group is required to have an anchor. - if (!anchorIndex) + if (!anchor) return emitError(loc, "optional group has no anchor element"); // Verify the child elements. - if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) || - failed(verifyOptionalGroupElements(loc, elseElements, llvm::None))) + if (failed(verifyOptionalGroupElements(loc, thenElements, anchor)) || + failed(verifyOptionalGroupElements(loc, elseElements, nullptr))) return failure(); // Get the first parsable element. It must be an element that can be @@ -377,8 +377,7 @@ unsigned parseStart = std::distance(thenElements.begin(), parseBegin); return create(std::move(thenElements), - std::move(elseElements), *anchorIndex, - parseStart); + std::move(elseElements), anchor, parseStart); } FailureOr FormatParser::parseCustomDirective(SMLoc loc, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2170,9 +2170,9 @@ verifyCustomDirectiveArguments(SMLoc loc, ArrayRef arguments) override; /// Verify the elements of an optional group. - LogicalResult - verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, - Optional anchorIndex) override; + LogicalResult verifyOptionalGroupElements(SMLoc loc, + ArrayRef elements, + FormatElement *anchor) override; LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, bool isAnchor); @@ -3150,13 +3150,10 @@ return element; } -LogicalResult -OpFormatParser::verifyOptionalGroupElements(SMLoc loc, - ArrayRef elements, - Optional anchorIndex) { - for (auto &it : llvm::enumerate(elements)) { - if (failed(verifyOptionalGroupElement( - loc, it.value(), anchorIndex && *anchorIndex == it.index()))) +LogicalResult OpFormatParser::verifyOptionalGroupElements( + SMLoc loc, ArrayRef elements, FormatElement *anchor) { + for (FormatElement *element : elements) { + if (failed(verifyOptionalGroupElement(loc, element, element == anchor))) return failure(); } return success();