diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -25,6 +25,7 @@
   NVPTXLowerAggrCopies.cpp
   NVPTXLowerArgs.cpp
   NVPTXLowerAlloca.cpp
+  NVPTXLowerUnreachable.cpp
   NVPTXPeephole.cpp
   NVPTXMCExpr.cpp
   NVPTXPrologEpilogPass.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -47,6 +47,7 @@
 FunctionPass *createNVPTXImageOptimizerPass();
 FunctionPass *createNVPTXLowerArgsPass();
 FunctionPass *createNVPTXLowerAllocaPass();
+FunctionPass *createNVPTXLowerUnreachablePass();
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
@@ -0,0 +1,131 @@
+//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PTX does not have a notion of `unreachable`, which results in emitted basic
+// blocks having an edge to the next block:
+//
+//   block1:
+//     call @does_not_return();
+//     // unreachable
+//   block2:
+//     // ptxas will create a CFG edge from block1 to block2
+//
+// This may result in significant changes to the control flow graph, e.g., when
+// LLVM moves unreachable blocks to the end of the function. That's a problem
+// in the context of divergent control flow, as `ptxas` uses the CFG to
+// determine divergent regions, and some instructions may not be executed
+// divergently.
+//
+// For example, `bar.sync` is not allowed to be executed divergently on Pascal
+// or earlier. If we start with the following:
+//
+//   entry:
+//     // start of divergent region
+//     @%p0 bra cont;
+//     @%p1 bra unlikely;
+//     ...
+//     bra.uni cont;
+//   unlikely:
+//     ...
+//     // unreachable
+//   cont:
+//     // end of divergent region
+//     bar.sync 0;
+//     bra.uni exit;
+//   exit:
+//     ret;
+//
+// it is transformed by the branch-folder and block-placement passes to:
+//
+//   entry:
+//     // start of divergent region
+//     @%p0 bra cont;
+//     @%p1 bra unlikely;
+//     ...
+//     bra.uni cont;
+//   cont:
+//     bar.sync 0;
+//     bra.uni exit;
+//   unlikely:
+//     ...
+//     // unreachable
+//   exit:
+//     // end of divergent region
+//     ret;
+//
+// After moving the `unlikely` block to the end of the function, it has an edge
+// to the `exit` block, which widens the divergent region and makes the
+// `bar.sync` instruction happen divergently.
+//
+// To work around this, we add an `exit` instruction before every `unreachable`,
+// as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
+// equivalent, and only future versions of `ptxas` will model it like `exit`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTX.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeNVPTXLowerUnreachablePass(PassRegistry &);
+}
+
+namespace {
+// FunctionPass that inserts a PTX `exit;` inline-asm call in front of every
+// `unreachable` instruction (see the file header for the motivation).
+class NVPTXLowerUnreachable : public FunctionPass {
+  bool runOnFunction(Function &F) override;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  NVPTXLowerUnreachable() : FunctionPass(ID) {}
+  StringRef getPassName() const override {
+    return "add an exit instruction before every unreachable";
+  }
+};
+} // namespace
+
+char NVPTXLowerUnreachable::ID = 1;
+
+INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
+                "Lower Unreachable", false, false)
+
+// =============================================================================
+// Main function for this pass.
+//
+// Inserts a call to the side-effecting inline asm `exit;` immediately before
+// each `unreachable` terminator in F. Returns true if any call was inserted.
+// =============================================================================
+bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+
+  LLVMContext &C = F.getContext();
+  FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
+  InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);
+
+  bool Changed = false;
+  for (BasicBlock &BB : F) {
+    // `unreachable` is a terminator, so only the last instruction can match.
+    if (auto *UI = dyn_cast_or_null<UnreachableInst>(BB.getTerminator())) {
+      CallInst::Create(ExitFTy, Exit, "", UI);
+      Changed = true;
+    }
+  }
+  return Changed;
+}
+
+FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
+  return new NVPTXLowerUnreachable();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -72,6 +72,7 @@
 void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
 void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
 void initializeNVPTXLowerAllocaPass(PassRegistry &);
+void initializeNVPTXLowerUnreachablePass(PassRegistry &);
 void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
 void initializeNVPTXLowerArgsPass(PassRegistry &);
 void initializeNVPTXProxyRegErasurePass(PassRegistry &);
@@ -98,6 +99,7 @@
   initializeNVPTXAtomicLowerPass(PR);
   initializeNVPTXLowerArgsPass(PR);
   initializeNVPTXLowerAllocaPass(PR);
+  initializeNVPTXLowerUnreachablePass(PR);
   initializeNVPTXCtorDtorLoweringLegacyPass(PR);
   initializeNVPTXLowerAggrCopiesPass(PR);
   initializeNVPTXProxyRegErasurePass(PR);
@@ -400,6 +402,8 @@
       addPass(createLoadStoreVectorizerPass());
     addPass(createSROAPass());
   }
+
+  addPass(createNVPTXLowerUnreachablePass());
 }
 
 bool NVPTXPassConfig::addInstSelector() {
diff --git a/llvm/test/CodeGen/NVPTX/unreachable.ll b/llvm/test/CodeGen/NVPTX/unreachable.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/unreachable.ll
@@ -0,0 +1,23 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+; CHECK: .extern .func throw
+declare void @throw() #0
+
+; CHECK: .entry kernel_func
+define void @kernel_func() {
+; CHECK: call.uni
+; CHECK: throw,
+  call void @throw()
+; CHECK: exit
+  unreachable
+}
+
+attributes #0 = { noreturn }
+
+
+!nvvm.annotations = !{!1}
+
+!1 = !{ptr @kernel_func, !"kernel", i32 1}