diff --git a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h --- a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h +++ b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h @@ -121,8 +121,14 @@ const std::string &currentContext() const { return CurrentContext; } + /// Check if there is at least one observation for `currentContext()`. bool hasObservationInProgress() const { - return ObservationIDs.find(CurrentContext) != ObservationIDs.end(); + return hasAnyObservationForContext(CurrentContext); + } + + /// Check if there is at least one observation for the context `Ctx`. + bool hasAnyObservationForContext(StringRef Ctx) const { + return ObservationIDs.find(Ctx) != ObservationIDs.end(); } template <typename T> void logReward(T Value) { diff --git a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp --- a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp +++ b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp @@ -474,7 +474,7 @@ void logRewardIfNeeded(const MachineFunction &MF, llvm::function_ref GetReward) override { - if (!Log) + if (!Log || !Log->hasAnyObservationForContext(MF.getName())) return; // The function pass manager would run all the function passes for a // function, so we assume the last context belongs to this function. If diff --git a/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp --- a/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp +++ b/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp @@ -197,7 +197,7 @@ void logRewardIfNeeded(const MachineFunction &MF, llvm::function_ref GetReward) override { - if (!Log) + if (!Log || !Log->hasAnyObservationForContext(MF.getName())) return; // The function pass manager would run all the function passes for a // function, so we assume the last context belongs to this function. 
If diff --git a/llvm/test/CodeGen/MLRegalloc/empty-log.ll b/llvm/test/CodeGen/MLRegalloc/empty-log.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/MLRegalloc/empty-log.ll @@ -0,0 +1,33 @@ +; REQUIRES: have_tflite +; REQUIRES: x86_64-linux +; +; Check that the log has a context entry for f2 but none for f1 or f3. +; +; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \ +; RUN: -regalloc-training-log=%t1 < %s +; RUN: FileCheck --input-file %t1 %s + +; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-priority-advisor=development \ +; RUN: -regalloc-priority-training-log=%t2 < %s +; RUN: FileCheck --input-file %t2 %s + +declare void @f(); + +define void @f1(i64 %lhs, i64 %rhs, i64* %addr) { + ret void +} + +define void @f2(i64 %lhs, i64 %rhs, i64* %addr) { + %sum = add i64 %lhs, %rhs + call void @f(); + store i64 %sum, i64* %addr + ret void +} + +define void @f3(i64 %lhs, i64 %rhs, i64* %addr) { + ret void +} + +; CHECK-NOT: {"context":"f1"} +; CHECK: {"context":"f2"} +; CHECK-NOT: {"context":"f3"}