diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -178,8 +178,16 @@
 
-  // Flush the content of the log to the stream, clearing the stored data in the
-  // process.
+  // Flush the content of the log to the stream (or into *Str, replacing its
+  // contents), clearing the stored data in the process.
+  void flush(std::string *Str);
   void flush(raw_ostream &OS);
 
+  // Flush a set of logs that are produced from the same module, e.g.
+  // per-function regalloc traces, as a google::protobuf::Struct message. Each
+  // log is stored under its key in \p Loggers (an empty string if the logger
+  // has no records); in binary mode each payload is base64-encoded.
+  static void flushLogs(raw_ostream &OS,
+                        const StringMap<std::unique_ptr<Logger>> &Loggers);
+
 private:
   std::vector<LoggedFeatureSpec> FeatureSpecs;
   TensorSpec RewardSpec;
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/Support/Base64.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/JSON.h"
@@ -22,6 +23,7 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "google/protobuf/struct.pb.h"
 #include "google/protobuf/text_format.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_experimental.h"
@@ -72,6 +74,15 @@
 TFSessionOptionsPtr createTFSessionOptions() {
   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
 }
+
+// Serialize \p Msg into \p OutStr (replacing its contents), as human-readable
+// text if ProtobufTextMode is set, otherwise in the binary wire format.
+void serialize(const google::protobuf::Message &Msg, std::string *OutStr) {
+  if (ProtobufTextMode)
+    google::protobuf::TextFormat::PrintToString(Msg, OutStr);
+  else
+    *OutStr = Msg.SerializeAsString();
+}
 } // namespace
 
 namespace llvm {
@@ -307,19 +318,13 @@
         IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
 
-  // flush the logged info to a stream and clear the log contents.
-  void flush(raw_ostream &OS) {
+  // Serialize the logged info into *Str and clear the log contents.
+  void flush(std::string *Str) {
     size_t NrRecords = getNrRecords();
     (void)NrRecords;
     tensorflow::SequenceExample SE;
     transferLog(SE);
     assert(isSelfConsistent(SE, NrRecords));
-    std::string OutStr;
-    if (ProtobufTextMode)
-      google::protobuf::TextFormat::PrintToString(SE, &OutStr);
-    else
-      OutStr = SE.SerializeAsString();
-
-    OS << OutStr;
+    serialize(SE, Str);
   }
 
   char *addNewTensor(size_t FeatureID) {
@@ -567,5 +572,33 @@
   return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
 }
 
-void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); }
+void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
+
+void Logger::flush(raw_ostream &OS) {
+  std::string Buff;
+  LoggerData->flush(&Buff);
+  OS << Buff;
+}
+
+void Logger::flushLogs(raw_ostream &OS,
+                       const StringMap<std::unique_ptr<Logger>> &Loggers) {
+  google::protobuf::Struct Msg;
+  for (const auto &NamedLogger : Loggers) {
+    const auto &L = NamedLogger.second;
+    std::string Unencoded;
+    // Only flush loggers that recorded something; the rest contribute an
+    // empty string.
+    if (L->LoggerData->getNrRecords() > 0)
+      L->flush(&Unencoded);
+
+    // proto3 string fields are expected to hold valid UTF-8, which a binary
+    // protobuf payload is not, hence the base64 encoding in binary mode.
+    (*Msg.mutable_fields())[NamedLogger.first().str()].set_string_value(
+        ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
+  }
+
+  std::string OutStr;
+  serialize(Msg, &OutStr);
+  OS << OutStr;
+}
 #endif // defined(LLVM_HAVE_TF_API)
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/Utils/TFUtils.h"
+#include "google/protobuf/struct.pb.h"
 #include "tensorflow/core/example/example.pb.h"
 #include "tensorflow/core/example/feature.pb.h"
 #include "llvm/AsmParser/Parser.h"
@@ -283,3 +284,40 @@
   PROTO_CHECKER("reward", float_list, 1, Zero);
   PROTO_CHECKER("reward", float_list, 2, R);
 }
+
+TEST(TFUtilsTest, LoggerGroup) {
+  std::vector<LoggedFeatureSpec> Features;
+  Features.push_back({TensorSpec::createSpec<float>("the_float", {1}), None});
+  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {1}), None});
+
+  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
+  StringMap<std::unique_ptr<Logger>> Loggers;
+  std::vector<std::string> Names{"a", "b"};
+  // Offset each logger's values so the two serialized logs differ.
+  size_t Bump = 0;
+  for (const auto &Name : Names) {
+    auto L = std::make_unique<Logger>(Features, Rewards, true);
+    for (int64_t I = 0; I < 3; ++I) {
+      float F = static_cast<float>(I) + Bump;
+      L->logFloatValue(0, &F);
+      L->logInt64Value(1, &I);
+    }
+    L->logFloatFinalReward(3.14 + Bump);
+    Loggers.insert(std::make_pair(Name, std::move(L)));
+    ++Bump;
+  }
+  std::string Result;
+  raw_string_ostream OS(Result);
+  Logger::flushLogs(OS, Loggers);
+  google::protobuf::Struct Parsed;
+  // str() flushes any buffered output into Result before parsing.
+  ASSERT_TRUE(Parsed.ParseFromString(OS.str()));
+  EXPECT_EQ(Parsed.fields_size(), 2);
+  ASSERT_EQ(Parsed.fields().count("a"), 1U);
+  ASSERT_EQ(Parsed.fields().count("b"), 1U);
+  const auto &A = Parsed.fields().at("a").string_value();
+  const auto &B = Parsed.fields().at("b").string_value();
+  EXPECT_FALSE(A.empty());
+  EXPECT_FALSE(B.empty());
+  EXPECT_NE(A, B);
+}