Index: include/llvm/Support/YAMLTraits.h
===================================================================
--- include/llvm/Support/YAMLTraits.h
+++ include/llvm/Support/YAMLTraits.h
@@ -472,6 +472,16 @@
   virtual void endFlowSequence() = 0;
 
   virtual bool mapTag(StringRef Tag, bool Default=false) = 0;
+  /// Specify the type of the node.
+  ///
+  /// During output, the \p Tag is written if \p Use returns true. A lambda is
+  /// used for lazy evaluation so that the underlying structure can be used.
+  /// During parsing, the structure may not be allocated at this point.
+  ///
+  /// During input, this returns whether \p Tag was encountered. Also true if
+  /// no tag was found but this is the default tag (i.e. \p Default is true).
+  virtual bool mapTag(StringRef Tag, std::function<bool()> Use,
+                      bool Default = false) = 0;
   virtual void beginMapping() = 0;
   virtual void endMapping() = 0;
   virtual bool preflightKey(const char*, bool, bool, bool &, void *&) = 0;
@@ -982,6 +992,7 @@
 private:
   bool outputting() override;
   bool mapTag(StringRef, bool) override;
+  bool mapTag(StringRef, std::function<bool()> Use, bool Default) override;
   void beginMapping() override;
   void endMapping() override;
   bool preflightKey(const char *, bool, bool, bool &, void *&) override;
@@ -1113,6 +1124,7 @@
 
   bool outputting() override;
   bool mapTag(StringRef, bool) override;
+  bool mapTag(StringRef, std::function<bool()> Use, bool Default) override;
   void beginMapping() override;
   void endMapping() override;
   bool preflightKey(const char *key, bool, bool, bool &, void *&) override;
Index: lib/Support/YAMLTraits.cpp
===================================================================
--- lib/Support/YAMLTraits.cpp
+++ lib/Support/YAMLTraits.cpp
@@ -101,6 +101,10 @@
   return CurrentNode ? CurrentNode->_node : nullptr;
 }
 
+bool Input::mapTag(StringRef Tag, std::function<bool()> Use, bool Default) {
+  return mapTag(Tag, Default);
+}
+
 bool Input::mapTag(StringRef Tag, bool Default) {
   std::string foundTag = CurrentNode->_node->getVerbatimTag();
   if (foundTag.empty()) {
@@ -421,6 +425,10 @@
   NeedsNewLine = true;
 }
 
+bool Output::mapTag(StringRef Tag, std::function<bool()> Use, bool Default) {
+  return mapTag(Tag, Use());
+}
+
 bool Output::mapTag(StringRef Tag, bool Use) {
   if (Use) {
     // If this tag is being written inside a sequence we should write the start
Index: unittests/Support/YAMLIOTest.cpp
===================================================================
--- unittests/Support/YAMLIOTest.cpp
+++ unittests/Support/YAMLIOTest.cpp
@@ -1572,6 +1572,129 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Test class hierarchy
+//===----------------------------------------------------------------------===//
+
+struct DoubleBase {
+  enum Kind { decimal, fraction };
+  Kind kind;
+  DoubleBase(Kind kind) : kind(kind) {}
+};
+
+struct DoubleDecimal : public DoubleBase {
+  double value = 0.0;
+  DoubleDecimal() : DoubleBase(decimal) {}
+  DoubleDecimal(double value) : DoubleBase(decimal), value(value) {}
+};
+
+struct DoubleFraction : public DoubleBase {
+  int num = 0, denom = 1;
+  DoubleFraction() : DoubleBase(fraction) {}
+  DoubleFraction(int num, int denom)
+      : DoubleBase(fraction), num(num), denom(denom) {}
+};
+
+namespace llvm {
+namespace yaml {
+template <> struct ScalarEnumerationTraits<DoubleBase::Kind> {
+  static void enumeration(IO &io, DoubleBase::Kind &kind) {
+    io.enumCase(kind, "decimal", DoubleBase::decimal);
+    io.enumCase(kind, "fraction", DoubleBase::fraction);
+  }
+};
+
+template <> struct MappingTraits<DoubleBase *> {
+  static void mapping(IO &io, DoubleBase *&d) {
+    if (io.mapTag("!decimal", [&]() { return d->kind == DoubleBase::decimal; },
+                  true)) {
+      if (!io.outputting())
+        d = new DoubleDecimal;
+      mappingDecimal(io, static_cast<DoubleDecimal *>(d));
+    } else if (io.mapTag("!fraction",
+                         [&]() { return d->kind == DoubleBase::fraction; })) {
+      if (!io.outputting())
+        d = new DoubleFraction;
+      mappingFraction(io, static_cast<DoubleFraction *>(d));
+    } else {
+      // Neither tag matched; report it instead of silently leaving d untouched.
+      io.setError("unknown DoubleBase tag");
+    }
+  }
+
+  static void mappingDecimal(IO &io, DoubleDecimal *dd) {
+    io.mapRequired("value", dd->value);
+  }
+  static void mappingFraction(IO &io, DoubleFraction *df) {
+    io.mapRequired("numerator", df->num);
+    io.mapRequired("denominator", df->denom);
+  }
+};
+}
+}
+
+LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(DoubleBase *)
+
+TEST(YAMLIO, TestClassHierarchy) {
+  std::vector<DoubleBase *> docList;
+  Input yin("--- !decimal\nvalue: 3.0\n"
+            "--- !fraction\nnumerator: 9\ndenominator: 2\n...\n");
+  yin >> docList;
+  ASSERT_FALSE(yin.error());
+  // Fatal check: the elements are dereferenced below.
+  ASSERT_EQ(docList.size(), 2UL);
+
+  EXPECT_EQ(docList[0]->kind, DoubleBase::decimal);
+  auto *dd = static_cast<DoubleDecimal *>(docList[0]);
+  EXPECT_EQ(dd->value, 3.0);
+
+  EXPECT_EQ(docList[1]->kind, DoubleBase::fraction);
+  auto *df = static_cast<DoubleFraction *>(docList[1]);
+  EXPECT_EQ(df->num, 9);
+  EXPECT_EQ(df->denom, 2);
+
+  // The mapping allocated these nodes; free them through the derived type
+  // because DoubleBase has no virtual destructor.
+  delete dd;
+  delete df;
+}
+
+TEST(YAMLIO, TestClassHierarchyWriteAndRead) {
+  std::string intermediate;
+  {
+    DoubleFraction a(1025, 100);
+    DoubleDecimal b(-3.75);
+    std::vector<DoubleBase *> docList;
+    docList.push_back(&a);
+    docList.push_back(&b);
+
+    llvm::raw_string_ostream ostr(intermediate);
+    Output yout(ostr);
+    yout << docList;
+  }
+
+  {
+    Input yin(intermediate);
+    std::vector<DoubleBase *> docList2;
+    yin >> docList2;
+
+    ASSERT_FALSE(yin.error());
+    // Fatal check: the elements are dereferenced below.
+    ASSERT_EQ(docList2.size(), 2UL);
+    EXPECT_EQ(docList2[0]->kind, DoubleBase::fraction);
+    auto *df = static_cast<DoubleFraction *>(docList2[0]);
+    EXPECT_EQ(df->num, 1025);
+    EXPECT_EQ(df->denom, 100);
+
+    EXPECT_EQ(docList2[1]->kind, DoubleBase::decimal);
+    auto *dd = static_cast<DoubleDecimal *>(docList2[1]);
+    EXPECT_EQ(dd->value, -3.75);
+
+    // Free the nodes allocated by the mapping, through the derived type.
+    delete df;
+    delete dd;
+  }
+}
 
 //===----------------------------------------------------------------------===//
 // Test mapping validation