diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h --- a/mlir/include/mlir/AsmParser/AsmParser.h +++ b/mlir/include/mlir/AsmParser/AsmParser.h @@ -43,38 +43,22 @@ AsmParserState *asmState = nullptr, AsmParserCodeCompleteContext *codeCompleteContext = nullptr); -/// This parses a single MLIR attribute to an MLIR context if it was valid. If -/// not, an error message is emitted through a new SourceMgrDiagnosticHandler -/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping -/// `attrStr`. If the passed `attrStr` has additional tokens that were not part -/// of the type, an error is emitted. -// TODO: Improve diagnostic reporting. -Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context); -Attribute parseAttribute(llvm::StringRef attrStr, Type type); - -/// This parses a single MLIR attribute to an MLIR context if it was valid. If -/// not, an error message is emitted through a new SourceMgrDiagnosticHandler -/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping -/// `attrStr`. The number of characters of `attrStr` parsed in the process is -/// returned in `numRead`. +/// This parses a single MLIR attribute to an MLIR context if it was valid. If +/// not, an error diagnostic is emitted to the context and a null value is +/// returned. +/// If `numRead` is provided, it is set to the number of consumed characters on +/// successful parse. Otherwise, parsing fails if the entire string is not +/// consumed. Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, - size_t &numRead); -Attribute parseAttribute(llvm::StringRef attrStr, Type type, size_t &numRead); - -/// This parses a single MLIR type to an MLIR context if it was valid. If not, -/// an error message is emitted through a new SourceMgrDiagnosticHandler -/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping -/// `typeStr`. 
If the passed `typeStr` has additional tokens that were not part -/// of the type, an error is emitted. -// TODO: Improve diagnostic reporting. -Type parseType(llvm::StringRef typeStr, MLIRContext *context); + Type type = {}, size_t *numRead = nullptr); -/// This parses a single MLIR type to an MLIR context if it was valid. If not, -/// an error message is emitted through a new SourceMgrDiagnosticHandler -/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping -/// `typeStr`. The number of characters of `typeStr` parsed in the process is -/// returned in `numRead`. -Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t &numRead); +/// This parses a single MLIR type to an MLIR context if it was valid. If not, +/// an error diagnostic is emitted to the context and a null type is returned. +/// If `numRead` is provided, it is set to the number of consumed characters on +/// successful parse. Otherwise, parsing fails if the entire string is not +/// consumed. +Type parseType(llvm::StringRef typeStr, MLIRContext *context, + size_t *numRead = nullptr); /// This parses a single IntegerSet/AffineMap to an MLIR context if it was /// valid. If not, an error message is emitted through a new diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -309,12 +309,13 @@ /// parsing failed, nullptr is returned. The number of bytes read from the input /// string is returned in 'numRead'. template -static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, - ParserFn &&parserFn) { +static T parseSymbol(StringRef inputStr, MLIRContext *context, + size_t *numReadOut, ParserFn &&parserFn) { + // Set the buffer name to the string being parsed, so that it appears in error + // diagnostics. The buffer requires `inputStr` to be null-terminated. 
+ auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr, + /*RequiresNullTerminator=*/true); SourceMgr sourceMgr; - auto memBuffer = MemoryBuffer::getMemBuffer( - inputStr, /*BufferName=*/"", - /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); SymbolState aliasState; ParserConfig config(context); @@ -322,9 +323,6 @@ /*codeCompleteContext=*/nullptr); Parser parser(state); - SourceMgrDiagnosticHandler handler( - const_cast(parser.getSourceMgr()), - parser.getContext()); Token startTok = parser.getToken(); T symbol = parserFn(parser); if (!symbol) @@ -332,38 +330,25 @@ // Provide the number of bytes that were read. Token endTok = parser.getToken(); - numRead = static_cast(endTok.getLoc().getPointer() - - startTok.getLoc().getPointer()); + size_t numRead = + endTok.getLoc().getPointer() - startTok.getLoc().getPointer(); + if (numReadOut) { + *numReadOut = numRead; + } else if (numRead != inputStr.size()) { + parser.emitError(endTok.getLoc()) << "found trailing characters: '" + << inputStr.drop_front(numRead) << "'"; + return T(); + } return symbol; } -Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { - size_t numRead = 0; - return parseAttribute(attrStr, context, numRead); -} -Attribute mlir::parseAttribute(StringRef attrStr, Type type) { - size_t numRead = 0; - return parseAttribute(attrStr, type, numRead); -} - Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, - size_t &numRead) { - return parseSymbol(attrStr, context, numRead, [](Parser &parser) { - return parser.parseAttribute(); - }); -} -Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { + Type type, size_t *numRead) { return parseSymbol( - attrStr, type.getContext(), numRead, + attrStr, context, numRead, [type](Parser &parser) { return parser.parseAttribute(type); }); } - -Type mlir::parseType(StringRef typeStr, MLIRContext *context) { - size_t numRead = 0; - return 
parseType(typeStr, context, numRead); -} - -Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { +Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) { return parseSymbol(typeStr, context, numRead, [](Parser &parser) { return parser.parseType(); }); } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1031,9 +1031,9 @@ size_t numRead = 0; MLIRContext *context = fileLoc->getContext(); if constexpr (std::is_same_v) - result = ::parseType(asmStr, context, numRead); + result = ::parseType(asmStr, context, &numRead); else - result = ::parseAttribute(asmStr, context, numRead); + result = ::parseAttribute(asmStr, context, Type(), &numRead); if (!result) return failure(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1692,14 +1692,15 @@ Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { - paddingValues.push_back( - parseAttribute(attr.cast(), elementType)); - if (!paddingValues.back()) { + auto parsedAttr = dyn_cast_if_present( + parseAttribute(stringAttr, getContext(), elementType)); + if (!parsedAttr || parsedAttr.getType() != elementType) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } + paddingValues.push_back(parsedAttr); continue; } // Otherwise, add the attribute directly. 
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -117,9 +117,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation - // expected-error @below {{expects a padding that parses to 'f32', got "foo"}} + // expected-error @below {{expects a padding that parses to 'f32', got "{foo}"}} %1 = transform.structured.pad %0 { - padding_values=["foo", 0.0 : f32, 0.0 : f32], + padding_values=["{foo}", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0] } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -317,8 +317,10 @@ SerializedAffineMap &value) { assert(rawYamlContext); auto *yamlContext = static_cast(rawYamlContext); - if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext) - .dyn_cast_or_null()) + std::string nullTerminatedScalar(scalar); + if (auto attr = + mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext) + .dyn_cast_or_null()) value.affineMapAttr = attr; else if (!value.affineMapAttr || !value.affineMapAttr.isa()) return "could not parse as an affine map attribute"; diff --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp --- a/mlir/unittests/Parser/ParserTest.cpp +++ b/mlir/unittests/Parser/ParserTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Parser/Parser.h" +#include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" @@ -55,4 +56,44 @@ 
EXPECT_EQ(block.front().getName().getStringRef(), "test.first"); EXPECT_EQ(block.back().getName().getStringRef(), "test.second"); } + +TEST(MLIRParser, ParseAttr) { + using namespace testing; + MLIRContext context; + Builder b(&context); + { // Successful parse: the whole input is consumed and reported in numRead + StringLiteral attrAsm = "array"; + size_t numRead = 0; + Attribute attr = parseAttribute(attrAsm, &context, Type(), &numRead); + EXPECT_EQ(attr, b.getDenseI64ArrayAttr({1, 2, 3})); + EXPECT_EQ(numRead, attrAsm.size()); + } + { // Failed parse: a diagnostic is emitted and numRead keeps its initial 0 + std::vector diagnostics; + ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) { + llvm::raw_string_ostream(diagnostics.emplace_back()) + << d.getLocation() << ": " << d; + }); + size_t numRead = 0; + EXPECT_FALSE(parseAttribute("dense<>", &context, Type(), &numRead)); + EXPECT_THAT(diagnostics, ElementsAre("loc(\"dense<>\":1:7): expected ':'")); + EXPECT_EQ(numRead, size_t(0)); + } + { // Trailing characters: an error unless numRead is requested + std::vector diagnostics; + ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) { + llvm::raw_string_ostream(diagnostics.emplace_back()) + << d.getLocation() << ": " << d; + }); + EXPECT_FALSE(parseAttribute("10 foo", &context)); + EXPECT_THAT( + diagnostics, + ElementsAre("loc(\"10 foo\":1:5): found trailing characters: 'foo'")); + + size_t numRead = 0; + EXPECT_EQ(parseAttribute("10 foo", &context, Type(), &numRead), + b.getI64IntegerAttr(10)); + EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace + } +} } // namespace