diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h --- a/llvm/include/llvm/Analysis/Utils/TFUtils.h +++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h @@ -104,6 +104,9 @@ struct LoggedFeatureSpec { TensorSpec Spec; Optional LoggingName; + const std::string &getLoggingName() const { + return LoggingName ? *LoggingName : Spec.name(); + } }; /// Load the output specs. If SpecFileOverride is not empty, that path is used. @@ -170,7 +173,9 @@ // we can consider using bytes. char *addEntryAndGetFloatOrInt64Buffer(size_t FeatureID); - void print(raw_ostream &OS); + // Flush the content of the log to the stream, clearing the stored data in the + // process. + void flush(raw_ostream &OS); private: std::vector FeatureSpecs; diff --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp --- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp +++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp @@ -377,7 +377,7 @@ void TrainingLogger::print() { std::error_code EC; raw_fd_ostream OutFile(LogFileName, EC); - L->print(OutFile); + L->flush(OutFile); } DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor( diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp --- a/llvm/lib/Analysis/TFUtils.cpp +++ b/llvm/lib/Analysis/TFUtils.cpp @@ -262,29 +262,57 @@ class LoggerDataImpl { const std::vector LoggedFeatureSpecs; const TensorSpec RewardSpec; + const bool IncludeReward; + + std::vector FeatureLists; + tensorflow::FeatureList Reward; + + bool isSelfConsistent(const tensorflow::SequenceExample &SE, + size_t NrRecords) const { + bool Ret = true; + for (const auto &TSpecs : LoggedFeatureSpecs) { + const auto &Name = TSpecs.getLoggingName(); + const auto &FL = SE.feature_lists().feature_list().at(Name).feature(); + if (NrRecords != static_cast(FL.size())) { + dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected " + << NrRecords << " got " << FL.size() << "\n"; + Ret = false; + } + } + if (IncludeReward && static_cast(SE.feature_lists() + .feature_list() + .at(RewardSpec.name()) + .feature() + .size()) != NrRecords) { + dbgs() << "[TF-UTILS]: reward is missing records.\n"; + Ret = false; + } + return Ret; + } - tensorflow::SequenceExample SE; - std::vector FeatureLists; - tensorflow::FeatureList *Reward = nullptr; - -public: - LoggerDataImpl(const std::vector &LoggedSpecs, - const TensorSpec &RewardSpec, bool IncludeReward) - : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec) { + void transferLog(tensorflow::SequenceExample &SE) { auto *FL = SE.mutable_feature_lists()->mutable_feature_list(); if (IncludeReward) - Reward = &(*FL)[RewardSpec.name()]; - // Allocate first the map entries, then capture their address. We will not - // mutate the set of features after this (i.e. the pointers won't dangle). - for (const auto &LFS : LoggedSpecs) { - (*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()] = {}; + (*FL)[RewardSpec.name()].Swap(&Reward); + assert(FeatureLists.size() == LoggedFeatureSpecs.size()); + for (size_t I = 0; I < FeatureLists.size(); ++I) { + const auto &LFS = LoggedFeatureSpecs[I]; + (*FL)[LFS.getLoggingName()].Swap(&FeatureLists[I]); } - for (const auto &LFS : LoggedSpecs) - FeatureLists.push_back( - &(*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()]); } - void print(raw_ostream &OS) { +public: + LoggerDataImpl(const std::vector &LoggedSpecs, + const TensorSpec &RewardSpec, bool IncludeReward) + : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec), + IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {} + + // flush the logged info to a stream and clear the log contents. + void flush(raw_ostream &OS) { + size_t NrRecords = getNrRecords(); + tensorflow::SequenceExample SE; + transferLog(SE); + assert(isSelfConsistent(SE, NrRecords)); std::string OutStr; if (ProtobufTextMode) google::protobuf::TextFormat::PrintToString(SE, &OutStr); @@ -298,14 +326,14 @@ const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec; if (Spec.isElementType()) { auto *RF = FeatureLists[FeatureID] - ->add_feature() + .add_feature() ->mutable_float_list() ->mutable_value(); RF->Resize(Spec.getElementCount(), 0.0); return reinterpret_cast(RF->mutable_data()); } else if (Spec.isElementType() || Spec.isElementType()) { auto *RF = FeatureLists[FeatureID] - ->add_feature() + .add_feature() ->mutable_int64_list() ->mutable_value(); RF->Resize(Spec.getElementCount(), 0); @@ -315,17 +343,18 @@ } template void logReward(T Value) { + assert(IncludeReward); if (RewardSpec.isElementType()) - Reward->add_feature()->mutable_float_list()->add_value(Value); + Reward.add_feature()->mutable_float_list()->add_value(Value); else if (RewardSpec.isElementType() || RewardSpec.isElementType()) - Reward->add_feature()->mutable_int64_list()->add_value(Value); + Reward.add_feature()->mutable_int64_list()->add_value(Value); else llvm_unreachable("Unsupported tensor type."); } size_t getNrRecords() const { - return FeatureLists.empty() ? 0 : FeatureLists[0]->feature().size(); + return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size(); } }; } // namespace llvm @@ -538,5 +567,5 @@ return reinterpret_cast(LoggerData->addNewTensor(FeatureID)); } -void Logger::print(raw_ostream &OS) { LoggerData->print(OS); } +void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); } #endif // defined(LLVM_HAVE_TF_API) diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp --- a/llvm/unittests/Analysis/TFUtilsTest.cpp +++ b/llvm/unittests/Analysis/TFUtilsTest.cpp @@ -179,7 +179,7 @@ L.logFloatReward(-3.0); std::string Result; raw_string_ostream OS(Result); - L.print(OS); + L.flush(OS); tensorflow::SequenceExample Expected; EXPECT_TRUE(Expected.ParseFromString(Result)); @@ -215,7 +215,7 @@ L.logInt32Reward(-3); std::string Result; raw_string_ostream OS(Result); - L.print(OS); + L.flush(OS); tensorflow::SequenceExample Expected; EXPECT_TRUE(Expected.ParseFromString(Result)); @@ -250,7 +250,7 @@ std::string Result; raw_string_ostream OS(Result); - L.print(OS); + L.flush(OS); tensorflow::SequenceExample Expected; EXPECT_TRUE(Expected.ParseFromString(Result)); PROTO_CHECKER("the_float", float_list, 0, F00); @@ -274,7 +274,7 @@ L.logFloatFinalReward(3.14); std::string Result; raw_string_ostream OS(Result); - L.print(OS); + L.flush(OS); const float Zero[]{0.0}; const float R[]{3.14}; tensorflow::SequenceExample Expected;