diff --git a/llvm/include/llvm/Transforms/Scalar/SpeculativeExecution.h b/llvm/include/llvm/Transforms/Scalar/SpeculativeExecution.h
--- a/llvm/include/llvm/Transforms/Scalar/SpeculativeExecution.h
+++ b/llvm/include/llvm/Transforms/Scalar/SpeculativeExecution.h
@@ -66,10 +66,11 @@
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
-class SpeculativeExecutionPass
-    : public PassInfoMixin<SpeculativeExecutionPass> {
+/// Implementation of the SpeculativeExecution transform, shared by the new-PM
+/// pass classes below, which add PassInfoMixin on top of it.
+class SpeculativeExecutionPassImpl {
 public:
-  SpeculativeExecutionPass(bool OnlyIfDivergentTarget = false);
+  SpeculativeExecutionPassImpl(bool OnlyIfDivergentTarget = false);
 
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 
@@ -86,6 +87,29 @@
 
   TargetTransformInfo *TTI = nullptr;
 };
+
+/// New-PM SpeculativeExecution pass. If \p OnlyIfDivergentTarget (or the
+/// SpecExecOnlyIfDivergentTarget flag) is set, the transform is skipped on
+/// targets without branch divergence.
+class SpeculativeExecutionPass
+    : public SpeculativeExecutionPassImpl,
+      public PassInfoMixin<SpeculativeExecutionPass> {
+public:
+  SpeculativeExecutionPass(bool OnlyIfDivergentTarget = false)
+      : SpeculativeExecutionPassImpl(OnlyIfDivergentTarget) {}
+};
+
+/// New-PM SpeculativeExecution pass that is always restricted to targets with
+/// branch divergence; it behaves like SpeculativeExecutionPass(true).
+/// NOTE(review): not referenced elsewhere in this patch (PassBuilder still
+/// uses SpeculativeExecutionPass(true)); confirm this class is needed.
+class SpeculativeExecutionIfHasBranchDivergencePass
+    : public SpeculativeExecutionPassImpl,
+      public PassInfoMixin<SpeculativeExecutionIfHasBranchDivergencePass> {
+public:
+  SpeculativeExecutionIfHasBranchDivergencePass()
+      : SpeculativeExecutionPassImpl(/*OnlyIfDivergentTarget=*/true) {}
+};
 }
 
 #endif //LLVM_TRANSFORMS_SCALAR_SPECULATIVEEXECUTION_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -440,7 +440,7 @@
 
   // Speculative execution if the target has divergent branches; otherwise nop.
   if (Level.getSpeedupLevel() > 1) {
-    FPM.addPass(SpeculativeExecutionPass());
+    FPM.addPass(SpeculativeExecutionPass(/*OnlyIfDivergentTarget=*/true));
 
     // Optimize based on known information about branches, and cleanup afterward.
     FPM.addPass(JumpThreadingPass());
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
--- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
@@ -123,7 +123,7 @@
   // Variable preserved purely for correct name printing.
   const bool OnlyIfDivergentTarget;
 
-  SpeculativeExecutionPass Impl;
+  SpeculativeExecutionPassImpl Impl;
 };
 } // namespace
 
@@ -150,7 +150,8 @@
 
 namespace llvm {
 
-bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) {
+bool SpeculativeExecutionPassImpl::runImpl(Function &F,
+                                           TargetTransformInfo *TTI) {
   if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) {
     LLVM_DEBUG(dbgs() << "Not running SpeculativeExecution because "
                          "TTI->hasBranchDivergence() is false.\n");
@@ -165,7 +166,7 @@
   return Changed;
 }
 
-bool SpeculativeExecutionPass::runOnBasicBlock(BasicBlock &B) {
+bool SpeculativeExecutionPassImpl::runOnBasicBlock(BasicBlock &B) {
   BranchInst *BI = dyn_cast<BranchInst>(B.getTerminator());
   if (BI == nullptr)
     return false;
@@ -251,8 +252,8 @@
   }
 }
 
-bool SpeculativeExecutionPass::considerHoistingFromTo(
-    BasicBlock &FromBlock, BasicBlock &ToBlock) {
+bool SpeculativeExecutionPassImpl::considerHoistingFromTo(BasicBlock &FromBlock,
+                                                          BasicBlock &ToBlock) {
   SmallPtrSet<const Instruction *, 8> NotHoisted;
   const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) {
     for (Value* V : U->operand_values()) {
@@ -299,12 +300,13 @@
   return new SpeculativeExecutionLegacyPass(/* OnlyIfDivergentTarget = */ true);
 }
 
-SpeculativeExecutionPass::SpeculativeExecutionPass(bool OnlyIfDivergentTarget)
+SpeculativeExecutionPassImpl::SpeculativeExecutionPassImpl(
+    bool OnlyIfDivergentTarget)
     : OnlyIfDivergentTarget(OnlyIfDivergentTarget ||
                             SpecExecOnlyIfDivergentTarget) {}
 
-PreservedAnalyses SpeculativeExecutionPass::run(Function &F,
-                                                FunctionAnalysisManager &AM) {
+PreservedAnalyses
+SpeculativeExecutionPassImpl::run(Function &F, FunctionAnalysisManager &AM) {
   auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
 
   bool Changed = runImpl(F, TTI);