diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -473,6 +473,47 @@ return success(); } + /// These are the supported delimiters around operand lists and region + /// argument lists, used by parseOperandList and parseRegionArgumentList. + enum class Delimiter { + /// Zero or more operands with no delimiters. + None, + /// Parens surrounding zero or more operands. + Paren, + /// Square brackets surrounding zero or more operands. + Square, + /// <> brackets surrounding zero or more operands. + LessGreater, + /// {} brackets surrounding zero or more operands. + Braces, + /// Parens supporting zero or more operands, or nothing. + OptionalParen, + /// Square brackets supporting zero or more ops, or nothing. + OptionalSquare, + /// <> brackets supporting zero or more ops, or nothing. + OptionalLessGreater, + /// {} brackets surrounding zero or more operands, or nothing. + OptionalBraces, + }; + + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + /// + /// contextMessage is an optional message appended to "expected '('" sorts of + /// diagnostics when parsing the delimiters. + virtual ParseResult + parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage = StringRef()) = 0; + + /// Parse a comma separated list of elements that must have at least one entry + /// in it.
+ ParseResult + parseCommaSeparatedList(function_ref parseElementFn) { + return parseCommaSeparatedList(Delimiter::None, std::move(parseElementFn)); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -610,21 +651,6 @@ /// Parse a single operand if present. virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0; - /// These are the supported delimiters around operand lists and region - /// argument lists, used by parseOperandList and parseRegionArgumentList. - enum class Delimiter { - /// Zero or more operands with no delimiters. - None, - /// Parens surrounding zero or more operands. - Paren, - /// Square brackets surrounding zero or more operands. - Square, - /// Parens supporting zero or more operands, or nothing. - OptionalParen, - /// Square brackets supporting zero or more ops, or nothing. - OptionalSquare, - }; - /// Parse zero or more SSA comma-separated operand references with a specified /// surrounding delimiter, and an optional required operand count. virtual ParseResult diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -158,7 +158,6 @@ // Sizes of parsed variadic operands, will be updated below after parsing. int32_t numDependencies = 0; - int32_t numOperands = 0; auto tokenTy = TokenType::get(ctx); @@ -179,38 +178,27 @@ SmallVector valueTypes; SmallVector unwrappedTypes; - if (succeeded(parser.parseOptionalLParen())) { - auto argsLoc = parser.getCurrentLocation(); - - // Parse a single instance of `%value as %unwrapped : !async.value`.
- auto parseAsyncValueArg = [&]() -> ParseResult { - if (parser.parseOperand(valueArgs.emplace_back()) || - parser.parseKeyword("as") || - parser.parseOperand(unwrappedArgs.emplace_back()) || - parser.parseColonType(valueTypes.emplace_back())) - return failure(); - - auto valueTy = valueTypes.back().dyn_cast(); - unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); - - return success(); - }; - - // If the next token is `)` skip async value arguments parsing. - if (failed(parser.parseOptionalRParen())) { - do { - if (parseAsyncValueArg()) - return failure(); - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen() || - parser.resolveOperands(valueArgs, valueTypes, argsLoc, - result.operands)) - return failure(); - } + // Parse a single instance of `%value as %unwrapped : !async.value`. + auto parseAsyncValueArg = [&]() -> ParseResult { + if (parser.parseOperand(valueArgs.emplace_back()) || + parser.parseKeyword("as") || + parser.parseOperand(unwrappedArgs.emplace_back()) || + parser.parseColonType(valueTypes.emplace_back())) + return failure(); - numOperands = valueArgs.size(); - } + auto valueTy = valueTypes.back().dyn_cast(); + unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); + + return success(); + }; + // Parse the optional parenthesized async value arguments and resolve them. + auto argsLoc = parser.getCurrentLocation(); + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen, + parseAsyncValueArg) || + parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) + return failure(); + + int32_t numOperands = valueArgs.size(); // Add derived `operand_segment_sizes` attribute based on parsed operands.
auto operandSegmentSizes = DenseIntElementsAttr::get( diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -77,22 +77,16 @@ parseOperandAndTypeList(OpAsmParser &parser, SmallVectorImpl &operands, SmallVectorImpl &types) { - if (parser.parseLParen()) - return failure(); - - do { - OpAsmParser::OperandType operand; - Type type; - if (parser.parseOperand(operand) || parser.parseColonType(type)) - return failure(); - operands.push_back(operand); - types.push_back(type); - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen()) - return failure(); - - return success(); + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + OpAsmParser::OperandType operand; + Type type; + if (parser.parseOperand(operand) || parser.parseColonType(type)) + return failure(); + operands.push_back(operand); + types.push_back(type); + return success(); + }); } /// Parse an allocate clause with allocators and a list of operands with types.
@@ -108,30 +102,24 @@ SmallVectorImpl &typesAllocate, SmallVectorImpl &operandsAllocator, SmallVectorImpl &typesAllocator) { - if (parser.parseLParen()) - return failure(); - do { - OpAsmParser::OperandType operand; - Type type; - - if (parser.parseOperand(operand) || parser.parseColonType(type)) - return failure(); - operandsAllocator.push_back(operand); - typesAllocator.push_back(type); - if (parser.parseArrow()) - return failure(); - if (parser.parseOperand(operand) || parser.parseColonType(type)) - return failure(); - - operandsAllocate.push_back(operand); - typesAllocate.push_back(type); - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen()) - return failure(); + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + OpAsmParser::OperandType operand; + Type type; + if (parser.parseOperand(operand) || parser.parseColonType(type)) + return failure(); + operandsAllocator.push_back(operand); + typesAllocator.push_back(type); + if (parser.parseArrow()) + return failure(); + if (parser.parseOperand(operand) || parser.parseColonType(type)) + return failure(); - return success(); + operandsAllocate.push_back(operand); + typesAllocate.push_back(type); + return success(); + }); } static LogicalResult verifyParallelOp(ParallelOp op) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1778,16 +1778,16 @@ if (!parser.parseOptionalComma()) { // Parse the interface variables - do { - // The name of the interface variable attribute isnt important - auto attrName = "var_symbol"; - FlatSymbolRefAttr var; - NamedAttrList attrs; - if (parser.parseAttribute(var, Type(), attrName, attrs)) { - return failure(); - } - interfaceVars.push_back(var); - } while (!parser.parseOptionalComma()); + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + // The name of the interface
variable attribute isn't important + FlatSymbolRefAttr var; + NamedAttrList attrs; + if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) + return failure(); + interfaceVars.push_back(var); + return success(); + })) + return failure(); } state.addAttribute(kInterfaceAttrName, parser.getBuilder().getArrayAttr(interfaceVars)); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2197,13 +2197,12 @@ SmallVectorImpl &caseDestinations, SmallVectorImpl> &caseOperands, SmallVectorImpl> &caseOperandTypes) { - if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) || - failed(parser.parseSuccessor(defaultDestination))) + if (parser.parseKeyword("default") || parser.parseColon() || + parser.parseSuccessor(defaultDestination)) return failure(); if (succeeded(parser.parseOptionalLParen())) { - if (failed(parser.parseRegionArgumentList(defaultOperands)) || - failed(parser.parseColonTypeList(defaultOperandTypes)) || - failed(parser.parseRParen())) + if (parser.parseRegionArgumentList(defaultOperands) || + parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) return failure(); } diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp --- a/mlir/lib/Parser/AffineParser.cpp +++ b/mlir/lib/Parser/AffineParser.cpp @@ -474,26 +474,22 @@ /// Parse the list of dimensional identifiers to an affine map.
ParseResult AffineParser::parseDimIdList(unsigned &numDims) { - if (parseToken(Token::l_paren, - "expected '(' at start of dimensional identifiers list")) { - return failure(); - } - auto parseElt = [&]() -> ParseResult { auto dimension = getAffineDimExpr(numDims++, getContext()); return parseIdentifierDefinition(dimension); }; - return parseCommaSeparatedListUntil(Token::r_paren, parseElt); + return parseCommaSeparatedList(Delimiter::Paren, parseElt, + " in dimensional identifier list"); } /// Parse the list of symbolic identifiers to an affine map. ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { - consumeToken(Token::l_square); auto parseElt = [&]() -> ParseResult { auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); return parseIdentifierDefinition(symbol); }; - return parseCommaSeparatedListUntil(Token::r_square, parseElt); + return parseCommaSeparatedList(Delimiter::Square, parseElt, + " in symbol list"); } /// Parse the list of symbolic identifiers to an affine map. @@ -544,21 +540,6 @@ ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map, OpAsmParser::Delimiter delimiter) { - Token::Kind rightToken; - switch (delimiter) { - case OpAsmParser::Delimiter::Square: - if (parseToken(Token::l_square, "expected '['")) - return failure(); - rightToken = Token::r_square; - break; - case OpAsmParser::Delimiter::Paren: - if (parseToken(Token::l_paren, "expected '('")) - return failure(); - rightToken = Token::r_paren; - break; - default: - return emitError("unexpected delimiter"); - } SmallVector exprs; auto parseElt = [&]() -> ParseResult { @@ -571,9 +552,9 @@ // 1-d affine expressions); the list can be empty. Grammar: // multi-dim-affine-expr ::= `(` `)` // | `(` affine-expr (`,` affine-expr)* `)` - if (parseCommaSeparatedListUntil(rightToken, parseElt, - /*allowEmptyList=*/true)) + if (parseCommaSeparatedList(delimiter, parseElt, " in affine map")) return failure(); + // Parsed a valid affine map.
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, exprs, getContext()); @@ -594,8 +575,6 @@ /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` AffineMap AffineParser::parseAffineMapRange(unsigned numDims, unsigned numSymbols) { - parseToken(Token::l_paren, "expected '(' at start of affine map range"); - SmallVector exprs; auto parseElt = [&]() -> ParseResult { auto elt = parseAffineExpr(); @@ -608,7 +587,8 @@ // 1-d affine expressions). Grammar: // multi-dim-affine-expr ::= `(` `)` // | `(` affine-expr (`,` affine-expr)* `)` - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) + if (parseCommaSeparatedList(Delimiter::Paren, parseElt, + " in affine map range")) return AffineMap(); // Parsed a valid affine map. @@ -662,10 +642,6 @@ /// IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols) { - if (parseToken(Token::l_paren, - "expected '(' at start of integer set constraint list")) - return IntegerSet(); - SmallVector constraints; SmallVector isEqs; auto parseElt = [&]() -> ParseResult { @@ -680,7 +656,8 @@ }; // Parse a list of affine constraints (comma-separated). - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) + if (parseCommaSeparatedList(Delimiter::Paren, parseElt, + " in integer set constraint list")) return IntegerSet(); // If no constraints were parsed, then treat this as a degenerate 'true' case. diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -67,15 +67,13 @@ // Parse an array attribute. case Token::l_square: { - consumeToken(Token::l_square); - SmallVector elements; auto parseElt = [&]() -> ParseResult { elements.push_back(parseAttribute()); return elements.back() ?
success() : failure(); }; - if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) + if (parseCommaSeparatedList(Delimiter::Square, parseElt)) return nullptr; return builder.getArrayAttr(elements); } @@ -262,9 +260,6 @@ /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value /// ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { - if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) - return failure(); - llvm::SmallDenseSet seenKeys; auto parseElt = [&]() -> ParseResult { // The name of an attribute can either be a bare identifier, or a string. @@ -300,7 +295,8 @@ return success(); }; - if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) + if (parseCommaSeparatedList(Delimiter::Braces, parseElt, + " in attribute dictionary")) return failure(); return success(); @@ -769,8 +765,6 @@ /// parseList([[1, 2], 3]) -> Failure /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure ParseResult TensorLiteralParser::parseList(SmallVectorImpl &dims) { - p.consumeToken(Token::l_square); - auto checkDims = [&](const SmallVectorImpl &prevDims, const SmallVectorImpl &newDims) -> ParseResult { if (prevDims == newDims) @@ -782,7 +776,7 @@ bool first = true; SmallVector newDims; unsigned size = 0; - auto parseCommaSeparatedList = [&]() -> ParseResult { + auto parseOneElement = [&]() -> ParseResult { SmallVector thisDims; if (p.getToken().getKind() == Token::l_square) { if (parseList(thisDims)) @@ -797,7 +791,7 @@ first = false; return success(); }; - if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) + if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement)) return failure(); // Return the sublists' dimensions with 'size' prepended.
diff --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp --- a/mlir/lib/Parser/LocationParser.cpp +++ b/mlir/lib/Parser/LocationParser.cpp @@ -82,9 +82,8 @@ return success(); }; - if (parseToken(Token::l_square, "expected '[' in fused location") || - parseCommaSeparatedList(parseElt) || - parseToken(Token::r_square, "expected ']' in fused location")) + if (parseCommaSeparatedList(Delimiter::Square, parseElt, + " in fused location")) return failure(); // Return the fused location. diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -24,6 +24,8 @@ /// include state. class Parser { public: + using Delimiter = OpAsmParser::Delimiter; + Builder builder; Parser(ParserState &state) : builder(state.context), state(state) {} @@ -39,9 +41,20 @@ function_ref parseElement, bool allowEmptyList = true); + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + ParseResult + parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage = StringRef()); + /// Parse a comma separated list of elements that must have at least one entry /// in it. - ParseResult parseCommaSeparatedList(function_ref parseElement); + ParseResult + parseCommaSeparatedList(function_ref parseElementFn) { + return parseCommaSeparatedList(Delimiter::None, parseElementFn); + } ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); @@ -266,7 +279,7 @@ ParseResult parseAffineMapOfSSAIds(AffineMap &map, function_ref parseElement, - OpAsmParser::Delimiter delimiter); + Delimiter delimiter); /// Parse an AffineExpr where dim and symbol identifiers are SSA ids.
ParseResult diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -35,20 +35,90 @@ // Parser //===----------------------------------------------------------------------===// -/// Parse a comma separated list of elements that must have at least one entry -/// in it. +/// Parse a list of comma-separated items with an optional delimiter. If a +/// delimiter is provided, then an empty list is allowed. If not, then at +/// least one element will be parsed. ParseResult -Parser::parseCommaSeparatedList(function_ref parseElement) { +Parser::parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage) { + switch (delimiter) { + case Delimiter::None: + break; + case Delimiter::OptionalParen: + if (getToken().isNot(Token::l_paren)) + return success(); + LLVM_FALLTHROUGH; + case Delimiter::Paren: + if (parseToken(Token::l_paren, "expected '('" + contextMessage)) + return failure(); + // Check for empty list. + if (consumeIf(Token::r_paren)) + return success(); + break; + case Delimiter::OptionalLessGreater: + // Check for absent list. + if (getToken().isNot(Token::less)) + return success(); + LLVM_FALLTHROUGH; + case Delimiter::LessGreater: + if (parseToken(Token::less, "expected '<'" + contextMessage)) + return failure(); + // Check for empty list. + if (consumeIf(Token::greater)) + return success(); + break; + case Delimiter::OptionalSquare: + if (getToken().isNot(Token::l_square)) + return success(); + LLVM_FALLTHROUGH; + case Delimiter::Square: + if (parseToken(Token::l_square, "expected '['" + contextMessage)) + return failure(); + // Check for empty list.
+ if (consumeIf(Token::r_square)) + return success(); + break; + case Delimiter::OptionalBraces: + if (getToken().isNot(Token::l_brace)) + return success(); + LLVM_FALLTHROUGH; + case Delimiter::Braces: + if (parseToken(Token::l_brace, "expected '{'" + contextMessage)) + return failure(); + // Check for empty list. + if (consumeIf(Token::r_brace)) + return success(); + break; + } + // Non-empty case starts with an element. - if (parseElement()) + if (parseElementFn()) return failure(); // Otherwise we have a list of comma separated elements. while (consumeIf(Token::comma)) { - if (parseElement()) + if (parseElementFn()) return failure(); } - return success(); + // Parse the closing token matching the opening delimiter, if any. + switch (delimiter) { + case Delimiter::None: + return success(); + case Delimiter::OptionalParen: + case Delimiter::Paren: + return parseToken(Token::r_paren, "expected ')'" + contextMessage); + case Delimiter::OptionalLessGreater: + case Delimiter::LessGreater: + return parseToken(Token::greater, "expected '>'" + contextMessage); + case Delimiter::OptionalSquare: + case Delimiter::Square: + return parseToken(Token::r_square, "expected ']'" + contextMessage); + case Delimiter::OptionalBraces: + case Delimiter::Braces: + return parseToken(Token::r_brace, "expected '}'" + contextMessage); + } + llvm_unreachable("Unknown delimiter"); } /// Parse a comma-separated list of elements, terminated with an arbitrary @@ -1282,6 +1352,15 @@ return parser.parseOptionalInteger(result); } + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(Delimiter delimiter, + function_ref parseElt, + StringRef contextMessage) override { + return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -1467,67 +1546,37 @@ Delimiter delimiter = Delimiter::None) { auto startLoc = parser.getToken().getLoc(); - // Handle delimiters. - switch (delimiter) { - case Delimiter::None: - // Don't check for the absence of a delimiter if the number of operands - // is unknown (and hence the operand list could be empty). - if (requiredOperandCount == -1) - break; - // Token already matches an identifier and so can't be a delimiter. - if (parser.getToken().is(Token::percent_identifier)) - break; - // Test against known delimiters. - if (parser.getToken().is(Token::l_paren) || - parser.getToken().is(Token::l_square)) - return emitError(startLoc, "unexpected delimiter"); - return emitError(startLoc, "invalid operand"); - case Delimiter::OptionalParen: - if (parser.getToken().isNot(Token::l_paren)) - return success(); - LLVM_FALLTHROUGH; - case Delimiter::Paren: - if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) - return failure(); - break; - case Delimiter::OptionalSquare: - if (parser.getToken().isNot(Token::l_square)) - return success(); - LLVM_FALLTHROUGH; - case Delimiter::Square: - if (parser.parseToken(Token::l_square, "expected '[' in operand list")) - return failure(); - break; - } - - // Check for zero operands. - if (parser.getToken().is(Token::percent_identifier)) { - do { - OperandType operandOrArg; - if (isOperandList ? parseOperand(operandOrArg) - : parseRegionArgument(operandOrArg)) - return failure(); - result.push_back(operandOrArg); - } while (parser.consumeIf(Token::comma)); + // The no-delimiter case has some special handling for better diagnostics.
+ if (delimiter == Delimiter::None) { + // With Delimiter::None, parseCommaSeparatedList requires at least one + // element, so an absent operand list has to be handled here. + if (parser.getToken().isNot(Token::percent_identifier)) { + // If we didn't require any operands or required exactly zero (weird) + // then this is success. + if (requiredOperandCount == -1 || requiredOperandCount == 0) + return success(); + + // Otherwise, try to produce a nice error message. + if (parser.getToken().is(Token::l_paren) || + parser.getToken().is(Token::l_square)) + return emitError(startLoc, "unexpected delimiter"); + return emitError(startLoc, "invalid operand"); + } } - // Handle delimiters. If we reach here, the optional delimiters were - // present, so we need to parse their closing one. - switch (delimiter) { - case Delimiter::None: - break; - case Delimiter::OptionalParen: - case Delimiter::Paren: - if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) - return failure(); - break; - case Delimiter::OptionalSquare: - case Delimiter::Square: - if (parser.parseToken(Token::r_square, "expected ']' in operand list")) + auto parseOneOperand = [&]() -> ParseResult { + OperandType operandOrArg; + if (isOperandList ? parseOperand(operandOrArg) + : parseRegionArgument(operandOrArg)) return failure(); - break; - } + result.push_back(operandOrArg); + return success(); + }; + + if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list")) + return failure(); + // Check that we got the expected # of elements. if (requiredOperandCount != -1 && result.size() != static_cast(requiredOperandCount)) return emitError(startLoc, "expected ") diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -165,15 +165,12 @@ return emitError("expected comma after offset value"); // Parse stride list.
- if (!consumeIf(Token::kw_strides)) - return emitError("expected `strides` keyword after offset specification"); - if (!consumeIf(Token::colon)) - return emitError("expected colon after `strides` keyword"); - if (failed(parseStrideList(strides))) - return emitError("invalid braces-enclosed stride list"); - if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) - return emitError("invalid memref stride"); + if (parseToken(Token::kw_strides, + "expected `strides` keyword after offset specification") || + parseToken(Token::colon, "expected colon after `strides` keyword") || + parseStrideList(strides)) + return failure(); return success(); } @@ -563,31 +560,30 @@ // Parse a comma-separated list of dimensions, possibly empty: // stride-list ::= `[` (dimension (`,` dimension)*)? `]` ParseResult Parser::parseStrideList(SmallVectorImpl &dimensions) { - if (!consumeIf(Token::l_square)) - return failure(); - // Empty list early exit. - if (consumeIf(Token::r_square)) - return success(); - while (true) { - if (consumeIf(Token::question)) { - dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); - } else { - // This must be an integer value. - int64_t val; - if (getToken().getSpelling().getAsInteger(10, val)) - return emitError("invalid integer value: ") << getToken().getSpelling(); - // Make sure it is not the one value for `?`. - if (ShapedType::isDynamic(val)) - return emitError("invalid integer value: ") - << getToken().getSpelling() - << ", use `?` to specify a dynamic dimension"; - dimensions.push_back(val); - consumeToken(Token::integer); - } - if (!consumeIf(Token::comma)) - break; - } - if (!consumeIf(Token::r_square)) - return failure(); - return success(); + return parseCommaSeparatedList( + Delimiter::Square, + [&]() -> ParseResult { + if (consumeIf(Token::question)) { + dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); + } else { + // This must be an integer value.
+ int64_t val; + if (getToken().getSpelling().getAsInteger(10, val)) + return emitError("invalid integer value: ") + << getToken().getSpelling(); + // Make sure it is not the one value for `?`. + if (ShapedType::isDynamic(val)) + return emitError("invalid integer value: ") + << getToken().getSpelling() + << ", use `?` to specify a dynamic dimension"; + // Zero is not a valid memref stride. + if (val == 0) + return emitError("invalid memref stride"); + + dimensions.push_back(val); + consumeToken(Token::integer); + } + return success(); + }, + " in stride list"); } diff --git a/mlir/test/IR/invalid-affinemap.mlir b/mlir/test/IR/invalid-affinemap.mlir --- a/mlir/test/IR/invalid-affinemap.mlir +++ b/mlir/test/IR/invalid-affinemap.mlir @@ -29,7 +29,9 @@ #hello_world = affine_map<(i, j) [s0] -> (((s0 + (i + j) + 5), j)> // expected-error {{expected ')'}} // ----- -#hello_world = affine_map<(i, j) [s0] -> i + s0, j)> // expected-error {{expected '(' at start of affine map range}} + +// expected-error @+1 {{expected '(' in affine map range}} +#hello_world = affine_map<(i, j) [s0] -> i + s0, j)> // ----- #hello_world = affine_map<(i, j) [s0] -> (x)> // expected-error {{use of undeclared identifier}} @@ -47,6 +49,8 @@ #hello_world = affine_map<(i, j) [s0, s1] -> (+i, j)> // expected-error {{missing left operand of binary op}} // ----- + + #hello_world = affine_map<(i, j) [s0, s1] -> (i, *j)> // expected-error {{missing left operand of binary op}} // ----- @@ -91,7 +95,8 @@ #hello_world = affine_map<(i, j) [s0, s1] -> (i, i mod (2+i))> // expected-error {{non-affine expression: right operand of mod has to be either a constant or symbolic}} // ----- -#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)> // expected-error {{expected ',' or ')'}} +// expected-error @+1 {{expected ')' in affine map range}} +#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)> // ----- #hello_world = affine_map<(i, j) -> (i, 3*d0 + )> // expected-error {{use of undeclared identifier}} diff --git a/mlir/test/IR/invalid.mlir
b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -92,7 +92,8 @@ // ----- -func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>) // expected-error {{invalid braces-enclosed stride list}} +// expected-error @+1 {{expected '['}} +func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>) // ----- @@ -633,7 +634,8 @@ // ----- -#set0 = affine_set<(i)[N, M] : )i >= 0)> // expected-error {{expected '(' at start of integer set constraint list}} +// expected-error @+1 {{expected '(' in integer set constraint list}} +#set0 = affine_set<(i)[N, M] : )i >= 0)> // ----- #set0 = affine_set<(i)[N] : (i >= 0, N - i >= 0)>