diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SourceMgr.h" #include @@ -2482,6 +2483,13 @@ } llvm::support::ulittle32_t align; memcpy(&align, blobData->data(), sizeof(uint32_t)); + if (align && !llvm::isPowerOf2_32(align)) { + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + + "' to encode alignment in first 4 bytes, but got " + "non-power-of-2 value: " + + Twine(align)); + } // Get the data portion of the blob. StringRef data = StringRef(*blobData).drop_front(sizeof(uint32_t)); diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" @@ -516,6 +517,7 @@ private: /// The table of dialect resources within the bytecode file. SmallVector dialectResources; + llvm::StringMap dialectResourceHandleRenamingMap; }; class ParsedResourceEntry : public AsmParsedResourceEntry { @@ -604,6 +606,7 @@ EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr &bufferOwnerRef, + function_ref remapKey = {}, function_ref processKeyFn = {}) { uint64_t numResources; if (failed(offsetReader.parseVarInt(numResources))) @@ -635,6 +638,7 @@ // Otherwise, parse the resource value.
EncodingReader entryReader(data, fileLoc); + key = remapKey ? remapKey(key) : key; ParsedResourceEntry entry(key, kind, entryReader, stringReader, bufferOwnerRef); if (failed(handler->parseResource(entry))) @@ -665,8 +669,17 @@ // provides most of the arguments. auto parseGroup = [&](auto *handler, bool allowEmpty = false, function_ref keyFn = {}) { + auto resolveKey = [&](StringRef key) -> StringRef { + // Only dialect resource handles are recorded here; any other key passes through unchanged. + auto it = dialectResourceHandleRenamingMap.find(key); + if (it == dialectResourceHandleRenamingMap.end()) + return key; + return it->second; + }; + return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, - stringReader, handler, bufferOwnerRef, keyFn); + stringReader, handler, bufferOwnerRef, resolveKey, + keyFn); }; // Read the external resources from the bytecode. @@ -714,6 +727,7 @@ << "unknown 'resource' key '" << key << "' for dialect '" << dialect->name << "'"; } + dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); return success(); }; diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -0,0 +1,75 @@ +//===- BytecodeTest.cpp - Bytecode unit tests -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace mlir; + +using testing::ElementsAre; + +static constexpr StringLiteral IRWithResources = R"( +module @TestDialectResources attributes { + bytecode.test = dense_resource : tensor<4xi32> +} {} +{-# + dialect_resources: { + builtin: { + resource: "0x1000000001000000020000000300000004000000" + } + } +#-} +)"; + +TEST(Bytecode, MultiModuleWithResource) { + MLIRContext context; + // Parse the textual IR (including its resource blob) into a module. + ParserConfig parseConfig(&context); + OwningOpRef module = + parseSourceString(IRWithResources, parseConfig); + ASSERT_TRUE(module); + + // Write the module to bytecode. + std::string buffer; + llvm::raw_string_ostream ostream(buffer); + ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream))); + + // Parse the bytecode back. + OwningOpRef roundTripModule = + parseSourceString(ostream.str(), parseConfig); + ASSERT_TRUE(roundTripModule); + + // Try to see if we have a valid resource in the parsed module.
+ auto checkResourceAttribute = [](Operation *op) { + Attribute attr = op->getAttr("bytecode.test"); + ASSERT_TRUE(attr); + auto denseResourceAttr = dyn_cast(attr); + ASSERT_TRUE(denseResourceAttr); + std::optional> attrData = + denseResourceAttr.tryGetAsArrayRef(); + ASSERT_TRUE(attrData.has_value()); + ASSERT_EQ(attrData->size(), static_cast(4)); + EXPECT_EQ((*attrData)[0], 1); + EXPECT_EQ((*attrData)[1], 2); + EXPECT_EQ((*attrData)[2], 3); + EXPECT_EQ((*attrData)[3], 4); + }; + + checkResourceAttribute(*module); + checkResourceAttribute(*roundTripModule); +} diff --git a/mlir/unittests/Bytecode/CMakeLists.txt b/mlir/unittests/Bytecode/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Bytecode/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_unittest(MLIRBytecodeTests + BytecodeTest.cpp +) +target_link_libraries(MLIRBytecodeTests + PRIVATE + MLIRBytecodeReader + MLIRBytecodeWriter + MLIRParser +) diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -5,6 +5,7 @@ endfunction() add_subdirectory(Analysis) +add_subdirectory(Bytecode) add_subdirectory(Conversion) add_subdirectory(Debug) add_subdirectory(Dialect)