diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -23,6 +23,10 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/IPO.h"
@@ -30,10 +34,8 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "llvm/Transforms/Utils/CodeExtractor.h"
-#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/IntrinsicsAMDGPU.h"
 
+using namespace llvm::PatternMatch;
 using namespace llvm;
 using namespace omp;
 
@@ -75,6 +77,8 @@
     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
 STATISTIC(NumOpenMPParallelRegionsMerged,
           "Number of OpenMP parallel regions merged");
+STATISTIC(NumBytesMovedToSharedMemory,
+          "Amount of memory pushed to shared memory");
 
 #if !defined(NDEBUG)
 static constexpr auto TAG = "[" DEBUG_TYPE "]";
@@ -542,6 +546,7 @@
 
     if (IsModulePass) {
       Changed |= runAttributor();
+      Changed |= replaceGlobalization();
       if (remarksEnabled())
         analysisGlobalization();
     } else {
@@ -1022,6 +1027,94 @@
     return Changed;
   }
 
+  /// Replace globalization calls in the device with global shared memory if it
+  /// is called by a single thread.
+  ///
+  /// A call to __kmpc_alloc_shared is only rewritten when it executes in a
+  /// single thread, its size operand is a constant integer, and exactly one
+  /// direct user is a call to __kmpc_free_shared. The allocation is replaced
+  /// by an internal global in the shared address space and that free call is
+  /// erased.
+  ///
+  /// \returns true if at least one allocation was replaced.
+  bool replaceGlobalization() {
+    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
+    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
+    bool Changed = false;
+
+    auto ReplaceAllocCalls = [&](Use &U, Function &F) {
+      CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
+      if (!CB)
+        return false;
+
+      auto *ED = A.lookupAAFor<AAExecutionDomain>(IRPosition::function(F));
+      if (!ED || !ED->isSingleThreadExecution(*CB))
+        return false;
+
+      ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
+      if (!AllocSize)
+        return false;
+      const uint64_t NumBytes = AllocSize->getZExtValue();
+
+      // Find the call that frees this allocation. Give up unless exactly one
+      // direct user calls __kmpc_free_shared: with none FC stays null, with
+      // several the pairing is ambiguous.
+      CallBase *FC = nullptr;
+      for (auto *AllocUser : CB->users()) {
+        CallBase *C = dyn_cast<CallBase>(AllocUser);
+        if (C && C->getCalledFunction() == FreeCall.Declaration) {
+          if (FC)
+            return false;
+          FC = C;
+        }
+      }
+      if (!FC)
+        return false;
+
+      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
+                        << CB->getCaller()->getName() << " with " << NumBytes
+                        << " bytes of shared memory\n");
+
+      FC->eraseFromParent();
+
+      // Create a new shared memory buffer of the same size as the allocation
+      // and replace all the uses of the original allocation with it.
+      Type *Int8Ty = Type::getInt8Ty(M.getContext());
+      Type *Int8ArrTy = ArrayType::get(Int8Ty, NumBytes);
+      auto *SharedMem = new GlobalVariable(
+          M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
+          UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
+          GlobalValue::NotThreadLocal,
+          static_cast<unsigned>(AddressSpace::Shared));
+      auto *NewBuffer =
+          ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
+
+      SharedMem->setAlignment(MaybeAlign(8));
+      CB->replaceAllUsesWith(NewBuffer);
+
+      auto Remark = [&](OptimizationRemark OR) {
+        return OR << "Replaced globalized variable with "
+                  << ore::NV("SharedMemory", NumBytes)
+                  << ((NumBytes != 1) ? " bytes " : " byte ")
+                  << "of shared memory";
+      };
+      emitRemark<OptimizationRemark>(CB, "OpenMPReplaceGlobalization", Remark);
+
+      CB->eraseFromParent();
+
+      NumBytesMovedToSharedMemory += NumBytes;
+      Changed = true;
+
+      return true;
+    };
+    RFI.foreachUse(SCC, ReplaceAllocCalls);
+
+    // Free calls were erased, so refresh the cached uses of that function.
+    if (Changed)
+      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_free_shared);
+    return Changed;
+  }
+
   /// Try to delete parallel regions if possible.
   bool deleteParallelRegions() {
     const unsigned CallbackCalleeOperand = 2;
@@ -1524,6 +1617,16 @@
   /// Kernel (=GPU) optimizations and utility functions
   ///
   ///{{
+
+  /// Address space numbers for globals created on the device.
+  /// replaceGlobalization() places its buffers in Shared.
+  enum class AddressSpace : unsigned {
+    Generic = 0,
+    Global = 1,
+    Shared = 3,
+    Constant = 4,
+    Local = 5,
+  };
 
   /// Check if \p F is a kernel, hence entry point for target offloading.
   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
@@ -2363,6 +2466,21 @@
     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
       return false;
 
+    // Temporarily match the pattern generated by clang for teams regions.
+    // TODO: Remove this once the new runtime is in place.
+    CmpInst::Predicate Pred;
+    auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
+    auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
+    auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
+    // Match the literal constants in place: tid == (ntid - 1) & ~(warpsize - 1).
+    // Binding both subtrahends to one variable would only validate the last.
+    if (match(Cmp, m_Cmp(Pred, m_ThreadID,
+                         m_And(m_Sub(m_BlockSize, m_One()),
+                               m_Xor(m_Sub(m_WarpSize, m_One()),
+                                     m_AllOnes())))))
+      if (Pred == CmpInst::Predicate::ICMP_EQ)
+        return true;
+
     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
     if (!C || !C->isZero())
       return false;
diff --git a/llvm/test/Transforms/OpenMP/replace_globalization.ll b/llvm/test/Transforms/OpenMP/replace_globalization.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/replace_globalization.ll
@@ -0,0 +1,109 @@
+; RUN: opt -S -passes='openmp-opt' < %s | FileCheck %s
+; RUN: opt -passes=openmp-opt -pass-remarks=openmp-opt -disable-output < %s 2>&1 | FileCheck %s -check-prefix=CHECK-REMARKS
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64"
+
+; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory
+; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory
+; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef
+; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef
+
+; @foo allocates without any thread-id guard; its runtime calls must remain.
+; CHECK: %{{.*}} = call i8* @__kmpc_alloc_shared({{.*}})
+; CHECK: call void @__kmpc_free_shared({{.*}})
+define dso_local void @foo() {
+entry:
+  %x = call i8* @__kmpc_alloc_shared(i64 4)
+  %x_on_stack = bitcast i8* %x to i32*
+  %0 = bitcast i32* %x_on_stack to i8*
+  call void @use(i8* %0)
+  call void @__kmpc_free_shared(i8* %x)
+  ret void
+}
+
+; @bar is a kernel that reaches the guarded allocations in @baz and @qux.
+define void @bar() {
+  call void @baz()
+  call void @qux()
+  ret void
+}
+
+; @baz allocates 16 bytes only when the thread id is zero.
+; CHECK: %{{.*}} = bitcast i8* addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*) to [4 x i32]*
+define internal void @baz() {
+entry:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %cmp = icmp eq i32 %tid, 0
+  br i1 %cmp, label %master, label %exit
+master:
+  %x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !9
+  %x_on_stack = bitcast i8* %x to [4 x i32]*
+  %0 = bitcast [4 x i32]* %x_on_stack to i8*
+  call void @use(i8* %0)
+  call void @__kmpc_free_shared(i8* %x)
+  br label %exit
+exit:
+  ret void
+}
+
+; @qux allocates 4 bytes only when the thread id equals
+; (ntid - 1) & ~(warpsize - 1).
+; CHECK: %{{.*}} = bitcast i8* addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*) to [4 x i32]*
+define internal void @qux() {
+entry:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
+  %0 = sub nuw i32 %warpsize, 1
+  %1 = sub nuw i32 %ntid, 1
+  %2 = xor i32 %0, -1
+  %master_tid = and i32 %1, %2
+  %3 = icmp eq i32 %tid, %master_tid
+  br i1 %3, label %master, label %exit
+master:
+  %y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !10
+  %y_on_stack = bitcast i8* %y to [4 x i32]*
+  %4 = bitcast [4 x i32]* %y_on_stack to i8*
+  call void @use(i8* %4)
+  call void @__kmpc_free_shared(i8* %y)
+  br label %exit
+exit:
+  ret void
+}
+
+
+; @use stores the pointer it is given into a stack slot.
+define void @use(i8* %x) {
+entry:
+  %addr = alloca i8*
+  store i8* %x, i8** %addr
+  ret void
+}
+
+declare i8* @__kmpc_alloc_shared(i64)
+
+declare void @__kmpc_free_shared(i8*)
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
+
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!3, !4}
+!nvvm.annotations = !{!5, !6}
+
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang version 12.0.0", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !2, splitDebugInlining: false, nameTableKind: None)
+!1 = !DIFile(filename: "replace_globalization.c", directory: "/tmp/replace_globalization.c")
+!2 = !{}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 1, !"wchar_size", i32 4}
+!5 = !{void ()* @foo, !"kernel", i32 1}
+!6 = !{void ()* @bar, !"kernel", i32 1}
+!7 = distinct !DISubprogram(name: "bar", scope: !1, file: !1, line: 1, type: !8, scopeLine: 1, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !2)
+!8 = !DISubroutineType(types: !2)
+!9 = !DILocation(line: 5, column: 7, scope: !7)
+!10 = !DILocation(line: 5, column: 14, scope: !7)