diff --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
--- a/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -52,10 +52,29 @@ namespace llvm {
 class LoggerDataImpl {
+protected:
   const std::vector<TensorSpec> LoggedFeatureSpecs;
   const TensorSpec RewardSpec;
   const bool IncludeReward;
 
+  LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+                 const TensorSpec &RewardSpec, bool IncludeReward)
+      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+        IncludeReward(IncludeReward) {}
+  virtual void logRewardImpl(const char *Value, size_t Size) = 0;
+
+public:
+  // flush the logged info to a stream and clear the log contents.
+  virtual void flush(std::string *Str) = 0;
+  virtual char *addNewTensor(size_t FeatureID) = 0;
+  virtual size_t getNrRecords() const = 0;
+  virtual ~LoggerDataImpl() = default;
+
+  template <typename T> void logReward(T Value) {
+    logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
+  }
+};
+
+class TFSequenceExampleLoggerDataImpl : public LoggerDataImpl {
   std::vector<tensorflow::FeatureList> FeatureLists;
   tensorflow::FeatureList Reward;
@@ -94,13 +113,14 @@
   }
 
 public:
-  LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
-                 const TensorSpec &RewardSpec, bool IncludeReward)
-      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
-        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+  TFSequenceExampleLoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+                                  const TensorSpec &RewardSpec,
+                                  bool IncludeReward)
+      : LoggerDataImpl(LoggedSpecs, RewardSpec, IncludeReward),
+        FeatureLists(LoggedFeatureSpecs.size()) {}
 
   // flush the logged info to a stream and clear the log contents.
-  void flush(std::string *Str) {
+  void flush(std::string *Str) override {
     size_t NrRecords = getNrRecords();
     (void)NrRecords;
     tensorflow::SequenceExample SE;
@@ -109,7 +129,7 @@
     serialize(SE, Str);
   }
 
-  char *addNewTensor(size_t FeatureID) {
+  char *addNewTensor(size_t FeatureID) override {
     const auto &Spec = LoggedFeatureSpecs[FeatureID];
     if (Spec.isElementType<float>()) {
       auto *RF = FeatureLists[FeatureID]
@@ -129,18 +149,22 @@
     llvm_unreachable("Unsupported tensor type.");
   }
 
-  template <typename T> void logReward(T Value) {
+  void logRewardImpl(const char *Value, size_t Size) override {
     assert(IncludeReward);
     if (RewardSpec.isElementType<float>())
-      Reward.add_feature()->mutable_float_list()->add_value(Value);
-    else if (RewardSpec.isElementType<int32_t>() ||
-             RewardSpec.isElementType<int64_t>())
-      Reward.add_feature()->mutable_int64_list()->add_value(Value);
+      Reward.add_feature()->mutable_float_list()->add_value(
+          *reinterpret_cast<const float *>(Value));
+    else if (RewardSpec.isElementType<int32_t>())
+      Reward.add_feature()->mutable_int64_list()->add_value(
+          *reinterpret_cast<const int32_t *>(Value));
+    else if (RewardSpec.isElementType<int64_t>())
+      Reward.add_feature()->mutable_int64_list()->add_value(
+          *reinterpret_cast<const int64_t *>(Value));
     else
       llvm_unreachable("Unsupported tensor type.");
   }
 
-  size_t getNrRecords() const {
+  size_t getNrRecords() const override {
     return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
   }
 };
@@ -150,8 +174,8 @@
                const TensorSpec &RewardSpec, bool IncludeReward)
     : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
       IncludeReward(IncludeReward),
-      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
-                                                  IncludeReward)) {}
+      LoggerData(std::make_unique<TFSequenceExampleLoggerDataImpl>(
+          FeatureSpecs, RewardSpec, IncludeReward)) {}
 
 Logger::~Logger() {}
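
Note (not part of the patch): the core of this change is splitting the previously concrete LoggerDataImpl into an abstract base plus a TF SequenceExample-backed subclass, with reward logging type-erased behind the virtual logRewardImpl(const char *, size_t) and re-typed by the non-virtual logReward<T> shim, since member function templates cannot be virtual. The standalone sketch below only mirrors that shape under assumed, illustrative names (DemoLoggerBase and StringDumpLogger are not LLVM classes), and for simplicity it dispatches on the byte size rather than on a TensorSpec.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>

// Abstract base: owns the typed entry point, leaves the byte-level hook
// to concrete backends (mirrors LoggerDataImpl after the patch).
class DemoLoggerBase {
protected:
  // Type-erased hook: backends reinterpret the bytes for the reward type
  // they were configured with.
  virtual void logRewardImpl(const char *Value, size_t Size) = 0;

public:
  virtual ~DemoLoggerBase() = default;
  virtual void flush(std::string *Str) = 0;

  // Non-virtual template shim: callers keep a typed API, while dispatch
  // happens on raw bytes through the virtual hook above.
  template <typename T> void logReward(T Value) {
    logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
  }
};

// One possible backend: appends rewards to an in-memory text buffer
// (stand-in for the SequenceExample-based implementation).
class StringDumpLogger : public DemoLoggerBase {
  std::ostringstream Buffer;

  void logRewardImpl(const char *Value, size_t Size) override {
    if (Size == sizeof(float)) {
      float F;
      std::memcpy(&F, Value, sizeof(F));
      Buffer << "reward(float): " << F << '\n';
    } else if (Size == sizeof(int64_t)) {
      int64_t I;
      std::memcpy(&I, Value, sizeof(I));
      Buffer << "reward(int64): " << I << '\n';
    }
  }

public:
  void flush(std::string *Str) override {
    *Str = Buffer.str();
    Buffer.str("");
  }
};

int main() {
  std::unique_ptr<DemoLoggerBase> L = std::make_unique<StringDumpLogger>();
  L->logReward(3.5f);
  L->logReward(int64_t{42});
  std::string Out;
  L->flush(&Out);
  std::cout << Out;
}

In the patch itself the backend chooses between float_list and int64_list from RewardSpec's element type rather than from Size, so same-sized types such as float and int32_t are not ambiguous the way they would be in this size-based toy.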