Index: include/llvm/Support/YAMLTraits.h
===================================================================
--- include/llvm/Support/YAMLTraits.h
+++ include/llvm/Support/YAMLTraits.h
@@ -39,6 +39,12 @@
 namespace llvm {
 namespace yaml {
 
+enum class NodeKind : uint8_t {
+  Scalar,
+  Map,
+  Sequence,
+};
+
 struct EmptyContext {};
 
 /// This class should be specialized by any type that needs to be converted
@@ -145,14 +151,14 @@
   // Must provide:
   //
   // Function to write the value as a string:
-  //static void output(const T &value, void *ctxt, llvm::raw_ostream &out);
+  // static void output(const T &value, void *ctxt, llvm::raw_ostream &out);
   //
   // Function to convert a string to a value. Returns the empty
   // StringRef on success or an error string if string is malformed:
-  //static StringRef input(StringRef scalar, void *ctxt, T &value);
+  // static StringRef input(StringRef scalar, void *ctxt, T &value);
   //
   // Function to determine if the value should be quoted.
-  //static QuotingType mustQuote(StringRef);
+  // static QuotingType mustQuote(StringRef);
 };
 
 /// This class should be specialized by type that requires custom conversion
@@ -181,6 +187,47 @@
   // Function to convert a string to a value. Returns the empty
   // StringRef on success or an error string if string is malformed:
   // static StringRef input(StringRef Scalar, void *ctxt, T &Value);
+  //
+  // Optional:
+  // static StringRef inputTag(T &Val, std::string Tag)
+  // static void outputTag(const T &Val, raw_ostream &Out)
+};
+
+/// This class should be specialized by type that requires custom conversion
+/// to/from a YAML scalar with optional tags. For example:
+///
+///    template <>
+///    struct TaggedScalarTraits<MyType> {
+///      static void output(const MyType &Value, void*, llvm::raw_ostream
+///      &ScalarOut, llvm::raw_ostream &TagOut)
+///      {
+///        // stream out custom formatting including optional Tag
+///        ScalarOut << Value;
+///      }
+///      static StringRef input(StringRef Scalar, StringRef Tag, void*, MyType
+///      &Value) {
+///        // parse scalar and set `value`
+///        // return empty string on success, or error string
+///        return StringRef();
+///      }
+///      static QuotingType mustQuote(const MyType &Value, StringRef) {
+///        return QuotingType::Single;
+///      }
+///    };
+template <typename T> struct TaggedScalarTraits {
+  // Must provide:
+  //
+  // Function to write the value and tag as strings:
+  // static void output(const T &Value, void *ctx, llvm::raw_ostream &ScalarOut,
+  //                    llvm::raw_ostream &TagOut);
+  //
+  // Function to convert a string to a value. Returns the empty
+  // StringRef on success or an error string if string is malformed:
+  // static StringRef input(StringRef Scalar, StringRef Tag, void *ctxt, T
+  // &Value);
+  //
+  // Function to determine if the value should be quoted.
+  // static QuotingType mustQuote(const T &Value, StringRef Scalar);
 };
 
 /// This class should be specialized by any type that needs to be converted
@@ -234,6 +281,31 @@
   // static void output(IO &io, T &elem);
 };
 
+/// This class should be specialized by any type that can be represented as
+/// a scalar, map, or sequence, decided dynamically. For example:
+///
+///    typedef std::unique_ptr<MyBase> MyPoly;
+///
+///    template<>
+///    struct PolymorphicTraits<MyPoly> {
+///      static NodeKind getKind(const MyPoly &poly) {
+///        return poly->getKind();
+///      }
+///      static MyScalar& getAsScalar(MyPoly &poly) {
+///        if (!poly || !isa<MyScalar>(*poly))
+///          poly.reset(new MyScalar());
+///        return *cast<MyScalar>(poly.get());
+///      }
+///      // ...
+///    };
+template <typename T> struct PolymorphicTraits {
+  // Must provide:
+  // static NodeKind getKind(const T &poly);
+  // static scalar_type &getAsScalar(T &poly);
+  // static map_type &getAsMap(T &poly);
+  // static sequence_type &getAsSequence(T &poly);
+};
+
 // Only used for better diagnostics of missing traits
 template <typename T>
 struct MissingTrait;
@@ -307,6 +379,24 @@
       (sizeof(test<BlockScalarTraits<T>>(nullptr, nullptr)) == 1);
 };
 
+// Test if TaggedScalarTraits<T> is defined on type T.
+template <class T> struct has_TaggedScalarTraits {
+  using Signature_input = StringRef (*)(StringRef, StringRef, void *, T &);
+  using Signature_output = void (*)(const T &, void *, raw_ostream &,
+                                    raw_ostream &);
+  using Signature_mustQuote = QuotingType (*)(const T &, StringRef);
+
+  template <typename U>
+  static char test(SameType<Signature_input, &U::input> *,
+                   SameType<Signature_output, &U::output> *,
+                   SameType<Signature_mustQuote, &U::mustQuote> *);
+
+  template <typename U> static double test(...);
+
+  static bool const value =
+      (sizeof(test<TaggedScalarTraits<T>>(nullptr, nullptr, nullptr)) == 1);
+};
+
 // Test if MappingContextTraits is defined on type T.
 template <class T, class Context> struct has_MappingTraits {
   using Signature_mapping = void (*)(class IO &, T &, Context &);
@@ -438,6 +528,17 @@
   static bool const value = (sizeof(test<DocumentListTraits<T>>(nullptr))==1);
 };
 
+template <typename T> struct has_PolymorphicTraits {
+  using Signature_getKind = NodeKind (*)(const T &);
+
+  template <typename U>
+  static char test(SameType<Signature_getKind, &U::getKind> *);
+
+  template <typename U> static double test(...);
+
+  static bool const value = (sizeof(test<PolymorphicTraits<T>>(nullptr)) == 1);
+};
+
 inline bool isNumeric(StringRef S) {
   const static auto skipDigits = [](StringRef Input) {
     return Input.drop_front(
@@ -621,10 +722,12 @@
                                         !has_ScalarBitSetTraits<T>::value &&
                                         !has_ScalarTraits<T>::value &&
                                         !has_BlockScalarTraits<T>::value &&
+                                        !has_TaggedScalarTraits<T>::value &&
                                         !has_MappingTraits<T, Context>::value &&
                                         !has_SequenceTraits<T>::value &&
                                         !has_CustomMappingTraits<T>::value &&
-                                        !has_DocumentListTraits<T>::value> {};
+                                        !has_DocumentListTraits<T>::value &&
+                                        !has_PolymorphicTraits<T>::value> {};
 
 template <typename T, typename Context>
 struct validatedMappingTraits
@@ -678,6 +781,9 @@
 
   virtual void scalarString(StringRef &, QuotingType) = 0;
virtual void blockScalarString(StringRef &) = 0; + virtual void scalarTag(std::string &) = 0; + + virtual NodeKind getNodeKind() = 0; virtual void setError(const Twine &) = 0; @@ -912,6 +1018,31 @@ } } +template +typename std::enable_if::value, void>::type +yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) { + if (io.outputting()) { + std::string ScalarStorage, TagStorage; + raw_string_ostream ScalarBuffer(ScalarStorage), TagBuffer(TagStorage); + TaggedScalarTraits::output(Val, io.getContext(), ScalarBuffer, + TagBuffer); + io.scalarTag(TagBuffer.str()); + StringRef ScalarStr = ScalarBuffer.str(); + io.scalarString(ScalarStr, + TaggedScalarTraits::mustQuote(Val, ScalarStr)); + } else { + std::string Tag; + io.scalarTag(Tag); + StringRef Str; + io.scalarString(Str, QuotingType::None); + StringRef Result = + TaggedScalarTraits::input(Str, Tag, io.getContext(), Val); + if (!Result.empty()) { + io.setError(Twine(Result)); + } + } +} + template typename std::enable_if::value, void>::type yamlize(IO &io, T &Val, bool, Context &Ctx) { @@ -968,6 +1099,20 @@ } template +typename std::enable_if::value, void>::type +yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) { + switch (io.outputting() ? 
PolymorphicTraits::getKind(Val) + : io.getNodeKind()) { + case NodeKind::Scalar: + return yamlize(io, PolymorphicTraits::getAsScalar(Val), true, Ctx); + case NodeKind::Map: + return yamlize(io, PolymorphicTraits::getAsMap(Val), true, Ctx); + case NodeKind::Sequence: + return yamlize(io, PolymorphicTraits::getAsSequence(Val), true, Ctx); + } +} + +template typename std::enable_if::value, void>::type yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) { char missing_yaml_trait_for_type[sizeof(MissingTrait)]; @@ -1245,6 +1390,8 @@ void endBitSetScalar() override; void scalarString(StringRef &, QuotingType) override; void blockScalarString(StringRef &) override; + void scalarTag(std::string &) override; + NodeKind getNodeKind() override; void setError(const Twine &message) override; bool canElideEmptySequence() override; @@ -1390,6 +1537,8 @@ void endBitSetScalar() override; void scalarString(StringRef &, QuotingType) override; void blockScalarString(StringRef &) override; + void scalarTag(std::string &) override; + NodeKind getNodeKind() override; void setError(const Twine &message) override; bool canElideEmptySequence() override; @@ -1409,14 +1558,21 @@ void flowKey(StringRef Key); enum InState { - inSeq, - inFlowSeq, + inSeqFirstElement, + inSeqOtherElement, + inFlowSeqFirstElement, + inFlowSeqOtherElement, inMapFirstKey, inMapOtherKey, inFlowMapFirstKey, inFlowMapOtherKey }; + static bool inSeqAnyElement(InState State); + static bool inFlowSeqAnyElement(InState State); + static bool inMapAnyKey(InState State); + static bool inFlowMapAnyKey(InState State); + raw_ostream &Out; int WrapColumn; SmallVector StateStack; @@ -1552,6 +1708,16 @@ return In; } +// Define non-member operator>> so that Input can stream in a polymorphic type. 
+template <typename T>
+inline typename std::enable_if<has_PolymorphicTraits<T>::value, Input &>::type
+operator>>(Input &In, T &Val) {
+  EmptyContext Ctx;
+  if (In.setCurrentDocument())
+    yamlize(In, Val, true, Ctx);
+  return In;
+}
+
 // Provide better error message about types missing a trait specialization
 template <typename T>
 inline typename std::enable_if<missingTraits<T, EmptyContext>::value,
@@ -1640,6 +1806,25 @@
   return Out;
 }
 
+// Define non-member operator<< so that Output can stream out a polymorphic
+// type.
+template <typename T>
+inline typename std::enable_if<has_PolymorphicTraits<T>::value, Output &>::type
+operator<<(Output &Out, T &Val) {
+  EmptyContext Ctx;
+  Out.beginDocuments();
+  if (Out.preflightDocument(0)) {
+    // FIXME: The parser does not support explicit documents terminated with a
+    // plain scalar; the end-marker is included as part of the scalar token.
+    assert(PolymorphicTraits<T>::getKind(Val) != NodeKind::Scalar &&
+           "plain scalar documents are not supported");
+    yamlize(Out, Val, true, Ctx);
+    Out.postflightDocument();
+  }
+  Out.endDocuments();
+  return Out;
+}
+
 // Provide better error message about types missing a trait specialization
 template <typename T>
 inline typename std::enable_if<missingTraits<T, EmptyContext>::value,
Index: lib/Support/YAMLTraits.cpp
===================================================================
--- lib/Support/YAMLTraits.cpp
+++ lib/Support/YAMLTraits.cpp
@@ -341,11 +341,25 @@
 
 void Input::blockScalarString(StringRef &S) { scalarString(S, QuotingType::None); }
 
+void Input::scalarTag(std::string &Tag) {
+  Tag = CurrentNode->_node->getVerbatimTag();
+}
+
 void Input::setError(HNode *hnode, const Twine &message) {
   assert(hnode && "HNode must not be NULL");
   setError(hnode->_node, message);
 }
 
+NodeKind Input::getNodeKind() {
+  if (isa<MapHNode>(CurrentNode))
+    return NodeKind::Map;
+  if (isa<SequenceHNode>(CurrentNode))
+    return NodeKind::Sequence;
+  // Scalar and empty (null) nodes both take the scalar path, which reports
+  // "unexpected scalar" for an empty node rather than hitting unreachable.
+  return NodeKind::Scalar;
+}
+
 void Input::setError(Node *node, const Twine &message) {
   Strm->printError(node, message);
   EC = make_error_code(errc::invalid_argument);
@@ -436,9 +450,11 @@
   // If this tag is being written inside a sequence we should write the start
   // of the sequence before writing the tag, otherwise the tag won't be
   // attached to the element in the sequence, but rather the sequence itself.
-  bool SequenceElement =
-      StateStack.size() > 1 && (StateStack[StateStack.size() - 2] == inSeq ||
-                                StateStack[StateStack.size() - 2] == inFlowSeq);
+  bool SequenceElement = false;
+  if (StateStack.size() > 1) {
+    auto &E = StateStack[StateStack.size() - 2];
+    SequenceElement = inSeqAnyElement(E) || inFlowSeqAnyElement(E);
+  }
   if (SequenceElement && StateStack.back() == inMapFirstKey) {
     newLineCheck();
   } else {
@@ -461,6 +477,9 @@
 }
 
 void Output::endMapping() {
+  // If we did not map anything, we should explicitly emit an empty map
+  if (StateStack.back() == inMapFirstKey)
+    output("{}");
   StateStack.pop_back();
 }
 
@@ -524,12 +543,15 @@
 }
 
 unsigned Output::beginSequence() {
-  StateStack.push_back(inSeq);
+  StateStack.push_back(inSeqFirstElement);
   NeedsNewLine = true;
   return 0;
 }
 
 void Output::endSequence() {
+  // If we did not emit anything, we should explicitly emit an empty sequence
+  if (StateStack.back() == inSeqFirstElement)
+    output("[]");
   StateStack.pop_back();
 }
 
@@ -538,10 +560,17 @@
 }
 
 void Output::postflightElement(void *) {
+  if (StateStack.back() == inSeqFirstElement) {
+    StateStack.pop_back();
+    StateStack.push_back(inSeqOtherElement);
+  } else if (StateStack.back() == inFlowSeqFirstElement) {
+    StateStack.pop_back();
+    StateStack.push_back(inFlowSeqOtherElement);
+  }
 }
 
 unsigned Output::beginFlowSequence() {
-  StateStack.push_back(inFlowSeq);
+  StateStack.push_back(inFlowSeqFirstElement);
   newLineCheck();
   ColumnAtFlowStart = Column;
   output("[ ");
@@ -680,6 +709,14 @@
   }
 }
 
+void Output::scalarTag(std::string &Tag) {
+  if (Tag.empty())
+    return;
+  newLineCheck();
+  output(Tag);
+  output(" ");
+}
+
 void Output::setError(const Twine &message) {
 }
 
@@ -693,7 +730,7 @@
     return true;
   if (StateStack.back() != inMapFirstKey)
     return true;
-  return (StateStack[StateStack.size()-2] != inSeq);
+  return !inSeqAnyElement(StateStack[StateStack.size() - 2]);
 }
 
 void Output::output(StringRef s) {
@@ -703,9 +740,8 @@
 
 void Output::outputUpToEndOfLine(StringRef s) {
   output(s);
-  if (StateStack.empty() || (StateStack.back() != inFlowSeq &&
-                             StateStack.back() != inFlowMapFirstKey &&
-                             StateStack.back() != inFlowMapOtherKey))
+  if (StateStack.empty() || (!inFlowSeqAnyElement(StateStack.back()) &&
+                             !inFlowMapAnyKey(StateStack.back())))
     NeedsNewLine = true;
 }
 
@@ -725,16 +761,20 @@
 
   outputNewLine();
 
-  assert(StateStack.size() > 0);
+  if (StateStack.size() == 0)
+    return;
+
   unsigned Indent = StateStack.size() - 1;
   bool OutputDash = false;
 
-  if (StateStack.back() == inSeq) {
+  if (StateStack.back() == inSeqFirstElement ||
+      StateStack.back() == inSeqOtherElement) {
     OutputDash = true;
-  } else if ((StateStack.size() > 1) && ((StateStack.back() == inMapFirstKey) ||
-             (StateStack.back() == inFlowSeq) ||
-             (StateStack.back() == inFlowMapFirstKey)) &&
-             (StateStack[StateStack.size() - 2] == inSeq)) {
+  } else if ((StateStack.size() > 1) &&
+             ((StateStack.back() == inMapFirstKey) ||
+              inFlowSeqAnyElement(StateStack.back()) ||
+              (StateStack.back() == inFlowMapFirstKey)) &&
+             inSeqAnyElement(StateStack[StateStack.size() - 2])) {
     --Indent;
     OutputDash = true;
   }
@@ -772,6 +812,24 @@
   output(": ");
 }
 
+NodeKind Output::getNodeKind() { report_fatal_error("invalid call"); }
+
+bool Output::inSeqAnyElement(InState State) {
+  return State == inSeqFirstElement || State == inSeqOtherElement;
+}
+
+bool Output::inFlowSeqAnyElement(InState State) {
+  return State == inFlowSeqFirstElement || State == inFlowSeqOtherElement;
+}
+
+bool Output::inMapAnyKey(InState State) {
+  return State == inMapFirstKey || State == inMapOtherKey;
+}
+
+bool Output::inFlowMapAnyKey(InState State) {
+  return State == inFlowMapFirstKey || State == inFlowMapOtherKey;
+}
+
 //===----------------------------------------------------------------------===//
 //  traits for built-in types
 //===----------------------------------------------------------------------===//
Index: unittests/Support/YAMLIOTest.cpp
===================================================================
--- unittests/Support/YAMLIOTest.cpp
+++ unittests/Support/YAMLIOTest.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Casting.h"
@@ -2640,3 +2641,235 @@
   EXPECT_FALSE(isNumeric("-inf"));
   EXPECT_FALSE(isNumeric("1,230.15"));
 }
+
+//===----------------------------------------------------------------------===//
+// Test PolymorphicTraits and TaggedScalarTraits
+//===----------------------------------------------------------------------===//
+
+struct Poly {
+  enum NodeKind {
+    NK_Scalar,
+    NK_Seq,
+    NK_Map,
+  } Kind;
+
+  Poly(NodeKind Kind) : Kind(Kind) {}
+
+  virtual ~Poly() = default;
+
+  NodeKind getKind() const { return Kind; }
+};
+
+struct Scalar : Poly {
+  enum ScalarKind {
+    SK_Unknown,
+    SK_Double,
+    SK_Bool,
+  } SKind;
+
+  union {
+    double DoubleValue;
+    bool BoolValue;
+  };
+
+  Scalar() : Poly(NK_Scalar), SKind(SK_Unknown) {}
+  Scalar(double DoubleValue)
+      : Poly(NK_Scalar), SKind(SK_Double), DoubleValue(DoubleValue) {}
+  Scalar(bool BoolValue)
+      : Poly(NK_Scalar), SKind(SK_Bool), BoolValue(BoolValue) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Scalar; }
+};
+
+struct Seq : Poly, std::vector<std::unique_ptr<Poly>> {
+  Seq() : Poly(NK_Seq) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Seq; }
+};
+
+struct Map : Poly, llvm::StringMap<std::unique_ptr<Poly>> {
+  Map() : Poly(NK_Map) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Map; }
+};
+
+namespace llvm {
+namespace yaml {
+
+template <> struct PolymorphicTraits<std::unique_ptr<Poly>> {
+  static NodeKind getKind(const std::unique_ptr<Poly> &N) {
+    if (isa<Scalar>(*N))
+      return NodeKind::Scalar;
+    if (isa<Seq>(*N))
+      return NodeKind::Sequence;
+    if (isa<Map>(*N))
+      return NodeKind::Map;
+    llvm_unreachable("unsupported node type");
+  }
+
+  static Scalar &getAsScalar(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Scalar>(*N))
+      N = llvm::make_unique<Scalar>();
+    return *cast<Scalar>(N.get());
+  }
+
+  static Seq &getAsSequence(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Seq>(*N))
+      N = llvm::make_unique<Seq>();
+    return *cast<Seq>(N.get());
+  }
+
+  static Map &getAsMap(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Map>(*N))
+      N = llvm::make_unique<Map>();
+    return *cast<Map>(N.get());
+  }
+};
+
+template <> struct TaggedScalarTraits<Scalar> {
+  static void output(const Scalar &S, void *Ctxt, raw_ostream &ScalarOS,
+                     raw_ostream &TagOS) {
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      report_fatal_error("output unknown scalar");
+      break;
+    case Scalar::SK_Double:
+      TagOS << "!double";
+      ScalarTraits<double>::output(S.DoubleValue, Ctxt, ScalarOS);
+      break;
+    case Scalar::SK_Bool:
+      TagOS << "!bool";
+      ScalarTraits<bool>::output(S.BoolValue, Ctxt, ScalarOS);
+      break;
+    }
+  }
+
+  static StringRef input(StringRef ScalarStr, StringRef Tag, void *Ctxt,
+                         Scalar &S) {
+    S.SKind = StringSwitch<Scalar::ScalarKind>(Tag)
+                  .Case("!double", Scalar::SK_Double)
+                  .Case("!bool", Scalar::SK_Bool)
+                  .Default(Scalar::SK_Unknown);
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      return StringRef("unknown scalar tag");
+    case Scalar::SK_Double:
+      return ScalarTraits<double>::input(ScalarStr, Ctxt, S.DoubleValue);
+    case Scalar::SK_Bool:
+      return ScalarTraits<bool>::input(ScalarStr, Ctxt, S.BoolValue);
+    }
+    llvm_unreachable("unknown scalar kind");
+  }
+
+  static QuotingType mustQuote(const Scalar &S, StringRef Str) {
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      report_fatal_error("quote unknown scalar");
+    case Scalar::SK_Double:
+      return ScalarTraits<double>::mustQuote(Str);
+    case Scalar::SK_Bool:
+      return ScalarTraits<bool>::mustQuote(Str);
+    }
+    llvm_unreachable("unknown scalar kind");
+  }
+};
+
+template <> struct CustomMappingTraits<Map> {
+  static void inputOne(IO &IO, StringRef Key, Map &M) {
+    IO.mapRequired(Key.str().c_str(), M[Key]);
+  }
+
+  static void output(IO &IO, Map &M) {
+    for (auto &N : M)
+      IO.mapRequired(N.getKey().str().c_str(), N.getValue());
+  }
+};
+
+template <> struct SequenceTraits<Seq> {
+  static size_t size(IO &IO, Seq &A) { return A.size(); }
+
+  static std::unique_ptr<Poly> &element(IO &IO, Seq &A, size_t Index) {
+    if (Index >= A.size())
+      A.resize(Index + 1);
+    return A[Index];
+  }
+};
+
+} // namespace yaml
+} // namespace llvm
+
+TEST(YAMLIO, TestReadWritePolymorphicScalar) {
+  std::string intermediate;
+  std::unique_ptr<Poly> node = llvm::make_unique<Scalar>(true);
+
+  llvm::raw_string_ostream ostr(intermediate);
+  Output yout(ostr);
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+  EXPECT_DEATH(yout << node, "plain scalar documents are not supported");
+#endif
+#endif
+}
+
+TEST(YAMLIO, TestReadWritePolymorphicSeq) {
+  std::string intermediate;
+  {
+    auto seq = llvm::make_unique<Seq>();
+    seq->push_back(llvm::make_unique<Scalar>(true));
+    seq->push_back(llvm::make_unique<Scalar>(1.0));
+    auto node = llvm::unique_dyn_cast<Poly>(seq);
+
+    llvm::raw_string_ostream ostr(intermediate);
+    Output yout(ostr);
+    yout << node;
+  }
+  {
+    Input yin(intermediate);
+    std::unique_ptr<Poly> node;
+    yin >> node;
+
+    EXPECT_FALSE(yin.error());
+    auto seq = llvm::dyn_cast_or_null<Seq>(node.get());
+    ASSERT_TRUE(seq);
+    ASSERT_EQ(seq->size(), 2u);
+    auto first = llvm::dyn_cast_or_null<Scalar>((*seq)[0].get());
+    ASSERT_TRUE(first);
+    EXPECT_EQ(first->SKind, Scalar::SK_Bool);
+    EXPECT_TRUE(first->BoolValue);
+    auto second = llvm::dyn_cast_or_null<Scalar>((*seq)[1].get());
+    ASSERT_TRUE(second);
+    EXPECT_EQ(second->SKind, Scalar::SK_Double);
+    EXPECT_EQ(second->DoubleValue, 1.0);
+  }
+}
+
+TEST(YAMLIO, TestReadWritePolymorphicMap) {
+  std::string intermediate;
+  {
+    auto map = llvm::make_unique<Map>();
+    (*map)["foo"] = llvm::make_unique<Scalar>(false);
+    (*map)["bar"] = llvm::make_unique<Scalar>(2.0);
+    std::unique_ptr<Poly> node = llvm::unique_dyn_cast<Poly>(map);
+
+    llvm::raw_string_ostream ostr(intermediate);
+    Output yout(ostr);
+    yout << node;
+  }
+  {
+    Input yin(intermediate);
+    std::unique_ptr<Poly> node;
+    yin >> node;
+
+    EXPECT_FALSE(yin.error());
+    auto map = llvm::dyn_cast_or_null<Map>(node.get());
+    ASSERT_TRUE(map);
+    auto foo = llvm::dyn_cast_or_null<Scalar>((*map)["foo"].get());
+    ASSERT_TRUE(foo);
+    EXPECT_EQ(foo->SKind, Scalar::SK_Bool);
+    EXPECT_FALSE(foo->BoolValue);
+    auto bar = llvm::dyn_cast_or_null<Scalar>((*map)["bar"].get());
+    ASSERT_TRUE(bar);
+    EXPECT_EQ(bar->SKind, Scalar::SK_Double);
+    EXPECT_EQ(bar->DoubleValue, 2.0);
+  }
+}