diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -991,18 +991,22 @@
     return success();
   }
 
-  /// Parse a 'x' separated dimension list. This populates the dimension list,
-  /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
-  /// `?` otherwise.
+  /// Parse a dimension list of a tensor or memref type. This populates the
+  /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
+  /// and errors out on `?` otherwise. With `withTrailingX`, each dimension
+  /// must be followed by an `x`; otherwise the dimensions are `x`-separated.
   ///
-  /// dimension-list ::= (dimension `x`)*
-  /// dimension ::= `?` | integer
+  /// dimension-list ::= eps | dimension (`x` dimension)*
+  /// dimension-list-with-trailing-x ::= (dimension `x`)*
+  /// dimension ::= `?` | decimal-literal
   ///
   /// When `allowDynamic` is not set, this is used to parse:
   ///
-  /// static-dimension-list ::= (integer `x`)*
+  /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
+  /// static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
   virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                         bool allowDynamic = true) = 0;
+                                         bool allowDynamic = true,
+                                         bool withTrailingX = true) = 0;
 
   /// Parse an 'x' token in a dimension list, handling the case where the x is
   /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -491,8 +491,10 @@
   }
 
   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                 bool allowDynamic) override {
-    return parser.parseDimensionListRanked(dimensions, allowDynamic);
+                                 bool allowDynamic,
+                                 bool withTrailingX) override {
+    return parser.parseDimensionListRanked(dimensions, allowDynamic,
+                                           withTrailingX);
   }
 
   ParseResult parseXInDimensionList() override {
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -203,7 +203,8 @@
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                        unsigned &numScalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
-                                       bool allowDynamic = true);
+                                       bool allowDynamic = true,
+                                       bool withTrailingX = true);
   ParseResult parseIntegerInDimensionList(int64_t &value);
   ParseResult parseXInDimensionList();
 
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -523,18 +523,22 @@
 
 /// Parse a dimension list of a tensor or memref type. This populates the
 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
-/// errors out on `?` otherwise.
+/// errors out on `?` otherwise. If `withTrailingX` is set, each dimension must
+/// be followed by an `x`; otherwise the dimensions are `x`-separated.
 ///
-/// dimension-list-ranked ::= (dimension `x`)*
+/// dimension-list ::= eps | dimension (`x` dimension)*
+/// dimension-list-with-trailing-x ::= (dimension `x`)*
 /// dimension ::= `?` | decimal-literal
 ///
 /// When `allowDynamic` is not set, this is used to parse:
 ///
-/// static-dimension-list ::= (decimal-literal `x`)*
+/// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
+/// static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
 ParseResult
 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
-                                 bool allowDynamic) {
-  while (getToken().isAny(Token::integer, Token::question)) {
+                                 bool allowDynamic, bool withTrailingX) {
+  // Parses one `?` or integer dimension and appends it to `dimensions`.
+  auto parseDim = [&]() -> LogicalResult {
     auto loc = getToken().getLoc();
     if (consumeIf(Token::question)) {
       if (!allowDynamic)
@@ -542,15 +546,34 @@
       dimensions.push_back(-1);
     } else {
       int64_t value;
-      if (parseIntegerInDimensionList(value))
+      if (failed(parseIntegerInDimensionList(value)))
         return failure();
       dimensions.push_back(value);
     }
 
-    // Make sure we have an 'x' or something like 'xbf32'.
-    if (parseXInDimensionList())
-      return failure();
+    return success();
+  };
+  if (withTrailingX) {
+    while (getToken().isAny(Token::integer, Token::question)) {
+      if (failed(parseDim()))
+        return failure();
+      if (failed(parseXInDimensionList()))
+        return failure();
+    }
+  } else {
+    if (getToken().isAny(Token::integer, Token::question)) {
+      if (failed(parseDim()))
+        return failure();
+      // The separating `x` may be lexed as the first character of a bare
+      // identifier, so match on the token spelling rather than its kind.
+      while (getToken().is(Token::bare_identifier) &&
+             getTokenSpelling()[0] == 'x') {
+        if (failed(parseXInDimensionList()))
+          return failure();
+        if (failed(parseDim()))
+          return failure();
+      }
+    }
   }
-
   return success();
 }