diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -49,16 +49,48 @@
 /// If `numRead` is provided, it is set to the number of consumed characters on
-/// succesful parse. Otherwise, parsing fails if the entire string is not
+/// successful parse. Otherwise, parsing fails if the entire string is not
 /// consumed.
-Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
-                         Type type = {}, size_t *numRead = nullptr);
+///
+/// `attrStr` must be null-terminated: the character just past its end,
+/// `attrStr.data()[attrStr.size()]`, must be readable and equal to '\0'. A
+/// StringRef that is a slice of a larger buffer does not satisfy this and may
+/// make the parser read beyond the slice; use `parseAttribute` instead.
+Attribute parseNullTerminatedAttribute(llvm::StringRef attrStr,
+                                       MLIRContext *context, Type type = {},
+                                       size_t *numRead = nullptr);
+
+/// Same as `parseNullTerminatedAttribute`, but accepts any StringRef: the
+/// input is first copied into a temporary `std::string`, whose buffer is
+/// always null-terminated. This costs one copy but is safe for slices of
+/// larger buffers. `numRead`, if provided, counts characters of `attrStr`.
+inline Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
+                                Type type = {}, size_t *numRead = nullptr) {
+  // The temporary copy stays alive until the end of the full-expression, i.e.
+  // for the whole parse.
+  return parseNullTerminatedAttribute(std::string(attrStr), context, type,
+                                      numRead);
+}
 
 /// This parses a single MLIR type to an MLIR context if it was valid. If not,
 /// an error diagnostic is emitted to the context.
 /// If `numRead` is provided, it is set to the number of consumed characters on
-/// succesful parse. Otherwise, parsing fails if the entire string is not
+/// successful parse. Otherwise, parsing fails if the entire string is not
 /// consumed.
-Type parseType(llvm::StringRef typeStr, MLIRContext *context,
-               size_t *numRead = nullptr);
+///
+/// `typeStr` must be null-terminated: the character just past its end,
+/// `typeStr.data()[typeStr.size()]`, must be readable and equal to '\0'. Use
+/// `parseType` for inputs that may not be null-terminated.
+Type parseNullTerminatedType(llvm::StringRef typeStr, MLIRContext *context,
+                             size_t *numRead = nullptr);
+
+/// Same as `parseNullTerminatedType`, but accepts any StringRef by first
+/// copying it into a temporary, always null-terminated `std::string`.
+/// `numRead`, if provided, counts characters of `typeStr`.
+inline Type parseType(llvm::StringRef typeStr, MLIRContext *context,
+                      size_t *numRead = nullptr) {
+  // The temporary copy stays alive until the end of the full-expression, i.e.
+  // for the whole parse.
+  return parseNullTerminatedType(std::string(typeStr), context, numRead);
+}
 
 /// This parses a single IntegerSet/AffineMap to an MLIR context if it was
 /// valid. If not, an error message is emitted through a new
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -342,13 +342,15 @@
   return symbol;
 }
 
-Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
-                               Type type, size_t *numRead) {
+Attribute mlir::parseNullTerminatedAttribute(StringRef attrStr,
+                                             MLIRContext *context, Type type,
+                                             size_t *numRead) {
   return parseSymbol(
       attrStr, context, numRead,
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) {
+Type mlir::parseNullTerminatedType(StringRef typeStr, MLIRContext *context,
+                                   size_t *numRead) {
   return parseSymbol(typeStr, context, numRead,
                      [](Parser &parser) { return parser.parseType(); });
 }
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1031,9 +1031,11 @@
     size_t numRead = 0;
     MLIRContext *context = fileLoc->getContext();
+    // The parseNullTerminated* entry points skip the defensive copy made by
+    // parseType/parseAttribute, so they require `asmStr` to be null-terminated.
     if constexpr (std::is_same_v)
-      result = ::parseType(asmStr, context, &numRead);
+      result = ::parseNullTerminatedType(asmStr, context, &numRead);
     else
-      result = ::parseAttribute(asmStr, context, Type(), &numRead);
+      result = ::parseNullTerminatedAttribute(asmStr, context, Type(), &numRead);
     if (!result)
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1693,7 +1693,10 @@
       // Try to parse string attributes to obtain an attribute of element type.
       if (auto stringAttr = attr.dyn_cast()) {
+        // parseNullTerminatedAttribute avoids a copy but requires a
+        // null-terminated buffer; this assumes StringAttr storage provides one.
         auto parsedAttr = dyn_cast_if_present(
-            parseAttribute(stringAttr, getContext(), elementType));
+            parseNullTerminatedAttribute(stringAttr, getContext(),
+                                         elementType));
         if (!parsedAttr || parsedAttr.getType() != elementType) {
           auto diag = this->emitOpError("expects a padding that parses to ")
                       << elementType << ", got " << std::get<0>(it);
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,10 +317,9 @@
                          SerializedAffineMap &value) {
     assert(rawYamlContext);
     auto *yamlContext = static_cast(rawYamlContext);
-    std::string nullTerminatedScalar(scalar);
-    if (auto attr =
-            mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext)
-                .dyn_cast_or_null())
+    // `scalar` need not be null-terminated; parseAttribute copies it first.
+    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
+                        .dyn_cast_or_null())
       value.affineMapAttr = attr;
     else if (!value.affineMapAttr || !value.affineMapAttr.isa())
       return "could not parse as an affine map attribute";
diff --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp
--- a/mlir/unittests/Parser/ParserTest.cpp
+++ b/mlir/unittests/Parser/ParserTest.cpp
@@ -95,5 +95,15 @@
               b.getI64IntegerAttr(10));
     EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace
   }
+  { // Parse a non-null-terminated slice: only the leading "9" is in view.
+    StringRef attrAsm = StringRef("999", 1);
+    Attribute attr = parseAttribute(attrAsm, &context);
+    EXPECT_EQ(attr, b.getI64IntegerAttr(9));
+  }
+  { // Same for types: "i32x" sliced to "i32" must not be read as "i32x".
+    StringRef typeAsm = StringRef("i32x", 3);
+    Type type = parseType(typeAsm, &context);
+    EXPECT_EQ(type, b.getI32Type());
+  }
 }
 } // namespace