diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -613,6 +613,24 @@
   /// \param CriticalName Name of the critical region.
   ///
   Value *getOMPCriticalRegionLock(StringRef CriticalName);
+
+  /// Capture the values defined above the parallel region that are used
+  /// inside it. Non-pointer operands of instructions in \p Blocks whose
+  /// defining instruction lives outside \p Blocks are packed into a struct
+  /// that is allocated and filled at \p CaptureAllocaInsPoint; their uses
+  /// inside \p Blocks are rewritten to values extracted from a load of that
+  /// struct.
+  ///
+  /// \param CaptureAllocaInsPoint Insertion point for the alloca-ed struct.
+  /// \param OuterFn The function containing the omp::Parallel (not read by
+  ///                the current implementation).
+  /// \param Blocks The parallel region blocks.
+  /// \param TIDAddr The address of the TID value; never captured.
+  /// \param ZeroAddr The address of the Zero value; never captured.
+  void captureParallelRegionParameters(
+      const IRBuilder<>::InsertPoint &CaptureAllocaInsPoint, Function *OuterFn,
+      const SmallVectorImpl<BasicBlock *> &Blocks, const Value *TIDAddr,
+      const Value *ZeroAddr);
 };
 
 /// Class to represented the control flow structure of an OpenMP canonical loop.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -434,6 +434,12 @@
   Value *Ident = getOrCreateIdent(SrcLocStr);
   Value *ThreadID = getOrCreateThreadID(Ident);
 
+  // Save the IP for the capture alloca struct generation and population.
+  // Store just before getting the thread number. (And before the end of the
+  // block).
+  IRBuilder<>::InsertPoint CaptureAllocaInsPoint(Builder.GetInsertBlock(),
+                                                 --Builder.GetInsertPoint());
+
   if (NumThreads) {
     // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
     Value *Args[] = {
@@ -687,6 +693,10 @@
   FunctionCallee TIDRTLFn =
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
 
+  // Capture the outer parameters for the ParallelRegions.
+  captureParallelRegionParameters(CaptureAllocaInsPoint, OuterFn, Blocks,
+                                  TIDAddr, ZeroAddr);
+
   auto PrivHelper = [&](Value &V) {
     if (&V == TIDAddr || &V == ZeroAddr)
       return;
@@ -740,6 +750,84 @@
   return AfterIP;
 }
 
+void OpenMPIRBuilder::captureParallelRegionParameters(
+    const IRBuilder<>::InsertPoint &CaptureAllocaInsPoint, Function *OuterFn,
+    const SmallVectorImpl<BasicBlock *> &Blocks, const Value *TIDAddr,
+    const Value *ZeroAddr) {
+  // Collect every non-pointer value that is used by an instruction of the
+  // parallel region but defined by an instruction outside of it. OuterFn is
+  // not needed for this and is left unread.
+  SetVector<BasicBlock *> RegionBlocks(Blocks.begin(), Blocks.end());
+  SetVector<Value *> CapturedValues;
+  for (BasicBlock *RegionBB : Blocks) {
+    for (Instruction &I : *RegionBB) {
+      for (Use &U : I.operands()) {
+        Value *V = U.get();
+        if (V == TIDAddr || V == ZeroAddr)
+          continue;
+
+        // Skip pointers.
+        if (V->getType()->isPointerTy())
+          continue;
+
+        // Only values defined by an instruction that is attached to a block
+        // are candidates; every other kind of operand is left untouched.
+        auto *DefInst = dyn_cast<Instruction>(V);
+        if (!DefInst || !DefInst->getParent())
+          continue;
+
+        // If the parent of the def instruction is not in the parallel
+        // region block set, the definition of the operand is in an
+        // upper block.
+        if (!RegionBlocks.contains(DefInst->getParent()))
+          CapturedValues.insert(V);
+      }
+    }
+  }
+
+  // Nothing to capture: leave the IR untouched.
+  if (CapturedValues.empty())
+    return;
+
+  // Create the struct type with one field per captured value.
+  SmallVector<Type *, 8> FieldTypes;
+  for (Value *Captured : CapturedValues)
+    FieldTypes.push_back(Captured->getType());
+  StructType *CaptureStructType =
+      StructType::create(FieldTypes, "CapturedStructType");
+
+  // Allocate and populate the capture struct at the saved insertion point;
+  // the guard restores the current insertion point afterwards.
+  // NOTE(review): the alloca lands wherever CaptureAllocaInsPoint points,
+  // which is not necessarily the entry block - confirm this is intended.
+  AllocaInst *CaptureAlloca = nullptr;
+  {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(CaptureAllocaInsPoint.getBlock(),
+                           CaptureAllocaInsPoint.getPoint());
+    CaptureAlloca = Builder.CreateAlloca(CaptureStructType, nullptr,
+                                         "CaptureStructAlloca");
+    Value *PackedStruct = UndefValue::get(CaptureStructType);
+    for (auto SrcIdx : enumerate(CapturedValues))
+      PackedStruct = Builder.CreateInsertValue(PackedStruct, SrcIdx.value(),
+                                               SrcIdx.index());
+    Builder.CreateStore(PackedStruct, CaptureAlloca);
+  }
+
+  // Load the struct at the current insertion point, extract each field and
+  // make the parallel region blocks use the extracted value instead.
+  Value *LoadedStruct = Builder.CreateLoad(CaptureStructType, CaptureAlloca);
+  for (auto SrcIdx : enumerate(CapturedValues)) {
+    Value *Extracted =
+        Builder.CreateExtractValue(LoadedStruct, SrcIdx.index());
+    // Only rewrite uses whose user is an instruction of the parallel region.
+    SrcIdx.value()->replaceUsesWithIf(Extracted, [&](Use &U) {
+      auto *UserInst = dyn_cast<Instruction>(U.getUser());
+      return UserInst && RegionBlocks.contains(UserInst->getParent());
+    });
+  }
+}
+
 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
   // Build call void __kmpc_flush(ident_t *loc)
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1232,4 +1232,83 @@
   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
 }
 
+TEST_F(OpenMPIRBuilderTest, ParallelCaptureUpperDefinedParameters) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+  Type *I32Ty = Type::getInt32Ty(M->getContext());
+  Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
+  Type *StructTy = StructType::get(I32Ty, I32PtrTy);
+  Type *StructPtrTy = StructTy->getPointerTo();
+  Type *VoidTy = Type::getVoidTy(M->getContext());
+  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
+  FunctionCallee TakeI32Func =
+      M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
+  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy);
+  FunctionCallee TakeI32PtrFunc =
+      M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy);
+  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
+  FunctionCallee TakeStructFunc =
+      M->getOrInsertFunction("take_struct", VoidTy, StructTy);
+  FunctionCallee RetStructPtrFunc =
+      M->getOrInsertFunction("ret_structptr", StructPtrTy);
+  FunctionCallee TakeStructPtrFunc =
+      M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy);
+  Value *I32Val = Builder.CreateCall(RetI32Func);
+  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
+  Value *StructVal = Builder.CreateCall(RetStructFunc);
+  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
+
+  Instruction *Internal = nullptr;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationBB) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.restoreIP(CodeGenIP);
+    Internal = Builder.CreateCall(TakeI32Func, I32Val);
+    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
+    Builder.CreateCall(TakeStructFunc, StructVal);
+    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
+  };
+  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &V,
+                    Value *&ReplacementValue) -> InsertPointTy {
+    ReplacementValue = &V;
+    return CodeGenIP;
+  };
+  auto FiniCB = [](InsertPointTy) {};
+
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  IRBuilder<>::InsertPoint AfterIP =
+      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
+                                nullptr, nullptr, OMP_PROC_BIND_default, false);
+
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  Function *OutlinedFn = Internal->getFunction();
+
+  Type *Arg2Type = OutlinedFn->getArg(2)->getType();
+  ASSERT_TRUE(Arg2Type->isPointerTy());
+  Type *StructElemTy = Arg2Type->getPointerElementType();
+  ASSERT_TRUE(StructElemTy->isStructTy());
+  EXPECT_EQ(StructElemTy->getStructName(), "CapturedStructType");
+  ASSERT_EQ(StructElemTy->getStructNumElements(), 2u);
+  StructType *StructTypeTy = cast<StructType>(StructElemTy);
+  EXPECT_TRUE(StructTypeTy->getElementType(0)->isIntegerTy(32));
+  ASSERT_TRUE(StructTypeTy->getElementType(1)->isStructTy());
+  StructType *InnerStructType =
+      cast<StructType>(StructTypeTy->getElementType(1));
+  EXPECT_TRUE(InnerStructType->getElementType(0)->isIntegerTy(32));
+  ASSERT_TRUE(InnerStructType->getElementType(1)->isPointerTy());
+  EXPECT_TRUE(
+      InnerStructType->getElementType(1)->getPointerElementType()->isIntegerTy(
+          32));
+}
 } // namespace
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -6,7 +6,7 @@
   %end = constant 0 : index
   // CHECK: omp.parallel
   omp.parallel {
-    // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
+    // CHECK: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
     br ^bb1(%start, %end : index, index)
     // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}}
     ^bb1(%0: index, %1: index):
diff --git a/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+module {
+  llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
+  llvm.func @main() {
+    %0 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %1 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %2 = llvm.mlir.null : !llvm.ptr<float>
+    %3 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %4 = llvm.getelementptr %2[%3] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+    %5 = llvm.ptrtoint %4 : !llvm.ptr<float> to !llvm.i64
+    %6 = llvm.mul %1, %5 : !llvm.i64
+    %7 = llvm.call @malloc(%6) : (!llvm.i64) -> !llvm.ptr<i8>
+    %8 = llvm.bitcast %7 : !llvm.ptr<i8> to !llvm.ptr<float>
+    %9 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.mlir.constant(0 : index) : !llvm.i64
+    %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %15 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %17 = llvm.mlir.constant(4.200000e+01 : f32) : !llvm.float
+    // CHECK: %CaptureStructAlloca = alloca %CapturedStructType
+    // CHECK: %{{.*}} = insertvalue %CapturedStructType undef, {{.*}}, 0
+    // CHECK: store %CapturedStructType %{{.*}}, %CapturedStructType* %CaptureStructAlloca
+    omp.parallel num_threads(%0 : !llvm.i64) {
+      // CHECK: %{{.*}} = load %CapturedStructType, %CapturedStructType* %CaptureStructAlloca
+      // CHECK: %{{.*}} = extractvalue %CapturedStructType %{{.*}}, 0
+      %27 = llvm.mlir.constant(1 : i64) : !llvm.i64
+      %28 = llvm.extractvalue %16[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+      %29 = llvm.mlir.constant(0 : index) : !llvm.i64
+      %30 = llvm.mlir.constant(1 : index) : !llvm.i64
+      %31 = llvm.mul %27, %30 : !llvm.i64
+      %32 = llvm.add %29, %31 : !llvm.i64
+      %33 = llvm.getelementptr %28[%32] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+      llvm.store %17, %33 : !llvm.ptr<float>
+      omp.terminator
+    }
+    llvm.return
+  }
+}