Support merging when reading a MsgPack document from a binary blob.

Document::readFromBlob() gains an optional Merger callback and merges the blob
into whatever is already in the Document; when a destination item is already
set, the conflict is passed to Merger. DocNode::isNil() is added, and the unit
test TestReadMap becomes TestReadMergeMap to cover the merge cases.

diff --git a/llvm/include/llvm/BinaryFormat/MsgPackDocument.h b/llvm/include/llvm/BinaryFormat/MsgPackDocument.h
--- a/llvm/include/llvm/BinaryFormat/MsgPackDocument.h
+++ b/llvm/include/llvm/BinaryFormat/MsgPackDocument.h
@@ -65,6 +65,7 @@
   DocNode() : KindAndDoc(nullptr) {}
 
   // Type methods
+  bool isNil() const { return getKind() == Type::Nil; }
   bool isMap() const { return getKind() == Type::Map; }
   bool isArray() const { return getKind() == Type::Array; }
   bool isScalar() const { return !isMap() && !isArray(); }
@@ -345,15 +346,30 @@
     return N.getArray();
   }
 
-  /// Read a MsgPack document from a binary MsgPack blob.
-  /// The blob data must remain valid for the lifetime of this Document (because
-  /// a string object in the document contains a StringRef into the original
-  /// blob).
-  /// If Multi, then this sets root to an array and adds top-level objects to
-  /// it. If !Multi, then it only reads a single top-level object, even if there
-  /// are more, and sets root to that.
-  /// Returns false if failed due to illegal format.
-  bool readFromBlob(StringRef Blob, bool Multi);
+  /// Read a document from a binary msgpack blob, merging into anything already
+  /// in the Document. The blob data must remain valid for the lifetime of this
+  /// Document (because a string object in the document contains a StringRef
+  /// into the original blob). If Multi, then this sets root to an array and
+  /// adds top-level objects to it. If !Multi, then it only reads a single
+  /// top-level object, even if there are more, and sets root to that. Returns
+  /// false if failed due to illegal format or merge error.
+  ///
+  /// The Merger arg is a callback function that is called when the merge has a
+  /// conflict, that is, it is trying to set an item that is already set. If the
+  /// conflict can be resolved, the callback function must set *DestNode to the
+  /// resolved node and return true, otherwise it returns false. If SrcNode is
+  /// an array or map, the resolution must be that *DestNode is an array or map
+  /// respectively, although it could be the array or map (respectively) that
+  /// was already there. MapKey is the key if *DestNode is a map entry, a nil
+  /// node otherwise. The default for Merger is to allow array and map merging,
+  /// and disallow any other conflict.
+  bool readFromBlob(
+      StringRef Blob, bool Multi,
+      function_ref<bool(DocNode *DestNode, DocNode SrcNode, DocNode MapKey)>
+          Merger = [](DocNode *DestNode, DocNode SrcNode, DocNode MapKey) {
+            return (DestNode->isMap() && SrcNode.isMap()) ||
+                   (DestNode->isArray() && SrcNode.isArray());
+          });
 
   /// Write a MsgPack document to a binary MsgPack blob.
   void writeToBlob(std::string &Blob);
diff --git a/llvm/lib/BinaryFormat/MsgPackDocument.cpp b/llvm/lib/BinaryFormat/MsgPackDocument.cpp
--- a/llvm/lib/BinaryFormat/MsgPackDocument.cpp
+++ b/llvm/lib/BinaryFormat/MsgPackDocument.cpp
@@ -60,26 +60,37 @@
 
 // A level in the document reading stack.
 struct StackLevel {
+  StackLevel(DocNode Node, size_t Length, DocNode *MapEntry = nullptr)
+      : Node(Node), Length(Length), MapEntry(MapEntry), Count(0) {}
   DocNode Node;
   size_t Length;
   // Points to map entry when we have just processed a map key.
   DocNode *MapEntry;
+  // Key of the most recent map entry read at this level (unset for arrays).
+  DocNode MapKey;
+  // Number of array elements, or map key/value pairs, read so far.
+  size_t Count;
 };
 
-// Read a document from a binary msgpack blob.
+// Read a document from a binary msgpack blob, merging into anything already in
+// the Document.
 // The blob data must remain valid for the lifetime of this Document (because a
 // string object in the document contains a StringRef into the original blob).
 // If Multi, then this sets root to an array and adds top-level objects to it.
 // If !Multi, then it only reads a single top-level object, even if there are
 // more, and sets root to that.
-// Returns false if failed due to illegal format.
-bool Document::readFromBlob(StringRef Blob, bool Multi) {
+// Returns false if failed due to illegal format or merge error.
+
+bool Document::readFromBlob(
+    StringRef Blob, bool Multi,
+    function_ref<bool(DocNode *DestNode, DocNode SrcNode, DocNode MapKey)>
+        Merger) {
   msgpack::Reader MPReader(Blob);
   SmallVector<StackLevel, 4> Stack;
   if (Multi) {
     // Create the array for multiple top-level objects.
     Root = getArrayNode();
-    Stack.push_back(StackLevel({Root, (size_t)-1, nullptr}));
+    Stack.push_back(StackLevel(Root, (size_t)-1, nullptr));
   }
   do {
     // On to next element (or key if doing a map key next).
@@ -124,29 +135,47 @@
     }
 
     // Store it.
+    DocNode *DestNode = nullptr;
     if (Stack.empty())
-      Root = Node;
+      DestNode = &Root;
     else if (Stack.back().Node.getKind() == Type::Array) {
       // Reading an array entry.
       auto &Array = Stack.back().Node.getArray();
-      Array.push_back(Node);
+      DestNode = &Array[Stack.back().Count++];
     } else {
       auto &Map = Stack.back().Node.getMap();
       if (!Stack.back().MapEntry) {
         // Reading a map key.
+        Stack.back().MapKey = Node;
         Stack.back().MapEntry = &Map[Node];
-      } else {
-        // Reading the value for the map key read in the last iteration.
-        *Stack.back().MapEntry = Node;
-        Stack.back().MapEntry = nullptr;
+        continue;
       }
+      // Reading the value for the map key read in the last iteration.
+      DestNode = Stack.back().MapEntry;
+      Stack.back().MapEntry = nullptr;
+      if (DestNode->isEmpty())
+        *DestNode = getNode();
+      ++Stack.back().Count;
     }
+    if (!DestNode->isNil()) {
+      // In a merge, there is already a value at this position. Call the
+      // callback to attempt to resolve the conflict. The resolution must result
+      // in an array or map if Node is an array or map respectively.
+      DocNode MapKey = !Stack.empty() && !Stack.back().MapKey.isEmpty()
+                           ? Stack.back().MapKey
+                           : getNode();
+      if (!Merger(DestNode, Node, MapKey))
+        return false; // Merge conflict resolution failed
+      assert(!((Node.isMap() && !DestNode->isMap()) ||
+               (Node.isArray() && !DestNode->isArray())));
+    } else
+      *DestNode = Node;
 
     // See if we're starting a new array or map.
-    switch (Node.getKind()) {
+    switch (DestNode->getKind()) {
     case msgpack::Type::Array:
     case msgpack::Type::Map:
-      Stack.push_back(StackLevel({Node, Obj.Length, nullptr}));
+      Stack.push_back(StackLevel(*DestNode, Obj.Length, nullptr));
       break;
     default:
       break;
@@ -154,14 +183,10 @@
 
     // Pop finished stack levels.
     while (!Stack.empty()) {
-      if (Stack.back().Node.getKind() == msgpack::Type::Array) {
-        if (Stack.back().Node.getArray().size() != Stack.back().Length)
-          break;
-      } else {
-        if (Stack.back().MapEntry ||
-            Stack.back().Node.getMap().size() != Stack.back().Length)
-          break;
-      }
+      if (Stack.back().MapEntry)
+        break;
+      if (Stack.back().Count != Stack.back().Length)
+        break;
       Stack.pop_back();
     }
   } while (!Stack.empty());
diff --git a/llvm/unittests/BinaryFormat/MsgPackDocumentTest.cpp b/llvm/unittests/BinaryFormat/MsgPackDocumentTest.cpp
--- a/llvm/unittests/BinaryFormat/MsgPackDocumentTest.cpp
+++ b/llvm/unittests/BinaryFormat/MsgPackDocumentTest.cpp
@@ -35,7 +35,7 @@
   ASSERT_EQ(SN.getKind(), Type::Nil);
 }
 
-TEST(MsgPackDocument, TestReadMap) {
+TEST(MsgPackDocument, TestReadMergeMap) {
   Document Doc;
   bool Ok = Doc.readFromBlob(StringRef("\x82\xa3"
                                        "foo"
@@ -53,6 +53,65 @@
   auto BarS = M["bar"];
   ASSERT_EQ(BarS.getKind(), Type::Int);
   ASSERT_EQ(BarS.getInt(), 2);
+
+  Ok = Doc.readFromBlob(StringRef("\x82\xa3"
+                                  "foz"
+                                  "\xd0\x03\xa3"
+                                  "baz"
+                                  "\xd0\x04"),
+                        /*Multi=*/false);
+  ASSERT_TRUE(Ok);
+  ASSERT_EQ(M.size(), 4u);
+  FooS = M["foo"];
+  ASSERT_EQ(FooS.getKind(), Type::Int);
+  ASSERT_EQ(FooS.getInt(), 1);
+  BarS = M["bar"];
+  ASSERT_EQ(BarS.getKind(), Type::Int);
+  ASSERT_EQ(BarS.getInt(), 2);
+  auto FozS = M["foz"];
+  ASSERT_EQ(FozS.getKind(), Type::Int);
+  ASSERT_EQ(FozS.getInt(), 3);
+  auto BazS = M["baz"];
+  ASSERT_EQ(BazS.getKind(), Type::Int);
+  ASSERT_EQ(BazS.getInt(), 4);
+
+  Ok = Doc.readFromBlob(
+      StringRef("\x82\xa3"
+                "foz"
+                "\xd0\x06\xa3"
+                "bay"
+                "\xd0\x08"),
+      /*Multi=*/false, [](DocNode *Dest, DocNode Src, DocNode MapKey) {
+        // Merger function that merges two ints by ORing their values, as long
+        // as the map key is "foz".
+        if (Src.isMap())
+          return Dest->isMap();
+        if (Src.isArray())
+          return Dest->isArray();
+        if (MapKey.isString() && MapKey.getString() == "foz" &&
+            Dest->getKind() == Type::Int && Src.getKind() == Type::Int) {
+          *Dest = Src.getDocument()->getNode(Dest->getInt() | Src.getInt());
+          return true;
+        }
+        return false;
+      });
+  ASSERT_TRUE(Ok);
+  ASSERT_EQ(M.size(), 5u);
+  FooS = M["foo"];
+  ASSERT_EQ(FooS.getKind(), Type::Int);
+  ASSERT_EQ(FooS.getInt(), 1);
+  BarS = M["bar"];
+  ASSERT_EQ(BarS.getKind(), Type::Int);
+  ASSERT_EQ(BarS.getInt(), 2);
+  FozS = M["foz"];
+  ASSERT_EQ(FozS.getKind(), Type::Int);
+  ASSERT_EQ(FozS.getInt(), 7);
+  BazS = M["baz"];
+  ASSERT_EQ(BazS.getKind(), Type::Int);
+  ASSERT_EQ(BazS.getInt(), 4);
+  auto BayS = M["bay"];
+  ASSERT_EQ(BayS.getKind(), Type::Int);
+  ASSERT_EQ(BayS.getInt(), 8);
 }
 
 TEST(MsgPackDocument, TestWriteInt) {