diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -64,6 +64,19 @@ assert(*curPtr == '<'); SmallVector nestedPunctuation; const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); + + // Functor used to emit an unbalanced punctuation error. + auto emitPunctError = [&] { + return emitError() << "unbalanced '" << nestedPunctuation.back() + << "' character in pretty dialect name"; + }; + // Functor used to check for unbalanced punctuation. + auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult { + if (nestedPunctuation.back() != expectedToken) + return emitPunctError(); + nestedPunctuation.pop_back(); + return success(); + }; do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -77,10 +90,8 @@ switch (c) { case '\0': // This also handles the EOF case. - if (!nestedPunctuation.empty()) { - return emitError() << "unbalanced '" << nestedPunctuation.back() - << "' character in pretty dialect name"; - } + if (!nestedPunctuation.empty()) + return emitPunctError(); return emitError("unexpected nul or EOF in pretty dialect name"); case '<': case '[': @@ -96,20 +107,20 @@ continue; case '>': - if (nestedPunctuation.pop_back_val() != '<') - return emitError("unbalanced '>' character in pretty dialect name"); + if (failed(checkNestedPunctuation('<'))) + return failure(); break; case ']': - if (nestedPunctuation.pop_back_val() != '[') - return emitError("unbalanced ']' character in pretty dialect name"); + if (failed(checkNestedPunctuation('['))) + return failure(); break; case ')': - if (nestedPunctuation.pop_back_val() != '(') - return emitError("unbalanced ')' character in pretty dialect name"); + if (failed(checkNestedPunctuation('('))) + return failure(); break; case '}': - if (nestedPunctuation.pop_back_val() != '{') - return emitError("unbalanced '}' character in pretty dialect name"); + if (failed(checkNestedPunctuation('{'))) + return failure(); break; case '"': { // Dispatch to the lexer to lex past strings. diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -320,12 +320,12 @@ // ----- -// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +// expected-error @+1 {{unbalanced '[' character in pretty dialect name}} func.func private @struct_type_neg_offset(!spirv.struct<(f32 [0)>) -> () // ----- -// expected-error @+1 {{unbalanced ']' character in pretty dialect name}} +// expected-error @+1 {{unbalanced '(' character in pretty dialect name}} func.func private @struct_type_neg_offset(!spirv.struct<(f32 0])>) -> () // ----- @@ -497,7 +497,7 @@ // ----- -// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +// expected-error @+1 {{unbalanced '<' character in pretty dialect name}} func.func private @matrix_invalid_format(!spirv.matrix< 3 x vector<3xf32>) -> () // ----- diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -369,7 +369,7 @@ // ----- -func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced ')' character in pretty dialect name}} +func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced '<' character in pretty dialect name}} return // ----- @@ -414,7 +414,7 @@ // ----- -// expected-error @+1 {{unbalanced ']' character in pretty dialect name}} +// expected-error @+1 {{unbalanced '<' character in pretty dialect name}} func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect // ----- @@ -582,7 +582,7 @@ // ----- -// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +// expected-error @+1 {{unbalanced '<' character in pretty dialect name}} func.func @bad_arrow(%arg : !unreg.ptr<(i32)->) // -----