diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -613,6 +613,18 @@ /// \param CriticalName Name of the critical region. /// Value *getOMPCriticalRegionLock(StringRef CriticalName); + + /// Capture the above-defined paraneters for the parallel regions. + /// + /// \param CaptureAllocaInsPoint Insertion point for the alloca-ed struct. + /// \param OuterFn The function containing the omp::Parallel. + /// \param Blocks The parallel region blocks. + /// \param TIDAddr The address of the TID value. + /// \param ZeroAddr The address of the Zero value. + void captureParallelRegionParameters( + const IRBuilder<>::InsertPoint &CaptureAllocaInsPoint, Function *OuterFn, + const SmallVector &Blocks, const Value *const TIDAddr, + const Value *const ZeroAddr); }; /// Class to represented the control flow structure of an OpenMP canonical loop. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -443,6 +443,12 @@ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args); } + Instruction *ThreadIDInst = dyn_cast(ThreadID); + + // Save the IP for the capture alloca struct generation and population. + IRBuilder<>::InsertPoint CaptureAllocaInsPoint(ThreadIDInst->getParent(), + ThreadIDInst->getIterator()); + if (ProcBind != OMP_PROC_BIND_default) { // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind) Value *Args[] = { @@ -687,6 +693,10 @@ FunctionCallee TIDRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num); + // Capture the outer parameters for the ParallelRegions. + captureParallelRegionParameters(CaptureAllocaInsPoint, OuterFn, Blocks, + TIDAddr, ZeroAddr); + auto PrivHelper = [&](Value &V) { if (&V == TIDAddr || &V == ZeroAddr) return; @@ -740,6 +750,99 @@ return AfterIP; } +void OpenMPIRBuilder::captureParallelRegionParameters( + const IRBuilder<>::InsertPoint &CaptureAllocaInsPoint, Function *OuterFn, + const SmallVector &Blocks, const Value *const TIDAddr, + const Value *const ZeroAddr) { + // Capture outside parameters. + SetVector capturedValues; + for (Function::iterator B = OuterFn->begin(), BE = OuterFn->end(); B != BE; + ++B) { + for (auto I = B->begin(), E = B->end(); I != E; ++I) { + for (Use &U : I->operands()) { + Value *V = U.get(); + if (V == TIDAddr || V == ZeroAddr) + continue; + + // Skip pointers. + if (V->getType()->isPointerTy()) + continue; + + Instruction *I = dyn_cast(V); + // This could be a constant (propagated constant) value, + // that has no defining instruction. + if (!I) + continue; + + // If the definition is in a parallel block, no need to capture. + bool isInParallelRegionBlock = false; + for (BasicBlock *regionBlock : Blocks) + if (regionBlock == I->getParent()) { + isInParallelRegionBlock = true; + break; + } + + if (isInParallelRegionBlock) + continue; + + // If the user is defined in the a parallel block, capture. + if (auto *UserI = dyn_cast(U.getUser())) { + bool isUserInParallelRegionBlock = false; + for (BasicBlock *regionBlock : Blocks) + if (regionBlock == UserI->getParent()) { + isUserInParallelRegionBlock = true; + break; + } + if (isUserInParallelRegionBlock) + capturedValues.insert(V); + } + } + } + } + + // If there are captured parameters to the parallel loop, + // allocate the captured struct on the stack, set the element values. + // Then, load the capture struct, extract the elements and replace the + // captured values with the extracted ones from the struct. + if (capturedValues.size()) { + // Create the StructTy + std::vector structTypes; + for (unsigned i = 0; i < capturedValues.size(); i++) + structTypes.push_back(capturedValues[i]->getType()); + + Type *captureStructType = + StructType::create(structTypes, "CapturedStructType"); + + InsertPointTy sIP = Builder.saveIP(); + + Builder.SetInsertPoint(CaptureAllocaInsPoint.getBlock(), + CaptureAllocaInsPoint.getPoint()); + // Allocate and populate the capture struct. + llvm::AllocaInst *allocaInst = + Builder.CreateAlloca(captureStructType, nullptr, "CaptureStructAlloca"); + llvm::Value *insertValue = llvm::UndefValue::get(captureStructType); + for (auto srcIdx : llvm::enumerate(capturedValues)) + insertValue = Builder.CreateInsertValue(insertValue, srcIdx.value(), + srcIdx.index()); + Builder.CreateStore(insertValue, allocaInst); + Builder.restoreIP(sIP); + + llvm::Value *loadedAlloca = Builder.CreateLoad(allocaInst); + for (auto srcIdx : llvm::enumerate(capturedValues)) { + llvm::Value *loadedValue = + Builder.CreateExtractValue(loadedAlloca, srcIdx.index()); + + // Find the usages of the captured values and replace them in the parallel + // region blocks. + for (BasicBlock *B : Blocks) + for (auto I = B->begin(), E = B->end(); I != E; ++I) + for (Use &U : I->operands()) + if (srcIdx.value() == U.get()) + U.set(loadedValue); + } + } +} + void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) { // Build call void __kmpc_flush(ident_t *loc) Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -1232,4 +1232,83 @@ EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1)); } +TEST_F(OpenMPIRBuilderTest, ParallelCaptureUpperDefinedParameters) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + Type *I32Ty = Type::getInt32Ty(M->getContext()); + Type *I32PtrTy = Type::getInt32PtrTy(M->getContext()); + Type *StructTy = StructType::get(I32Ty, I32PtrTy); + Type *StructPtrTy = StructTy->getPointerTo(); + Type *VoidTy = Type::getVoidTy(M->getContext()); + FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty); + FunctionCallee TakeI32Func = + M->getOrInsertFunction("take_i32", VoidTy, I32Ty); + FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy); + FunctionCallee TakeI32PtrFunc = + M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy); + FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy); + FunctionCallee TakeStructFunc = + M->getOrInsertFunction("take_struct", VoidTy, StructTy); + FunctionCallee RetStructPtrFunc = + M->getOrInsertFunction("ret_structptr", StructPtrTy); + FunctionCallee TakeStructPtrFunc = + M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy); + Value *I32Val = Builder.CreateCall(RetI32Func); + Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc); + Value *StructVal = Builder.CreateCall(RetStructFunc); + Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc); + + Instruction *Internal; + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + BasicBlock &ContinuationBB) { + IRBuilder<>::InsertPointGuard Guard(Builder); + Builder.restoreIP(CodeGenIP); + Internal = Builder.CreateCall(TakeI32Func, I32Val); + Builder.CreateCall(TakeI32PtrFunc, I32PtrVal); + Builder.CreateCall(TakeStructFunc, StructVal); + Builder.CreateCall(TakeStructPtrFunc, StructPtrVal); + }; + auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &V, + Value *&ReplacementValue) -> InsertPointTy { + ReplacementValue = &V; + return CodeGenIP; + }; + auto FiniCB = [](InsertPointTy) {}; + + IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + IRBuilder<>::InsertPoint AfterIP = + OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, + nullptr, nullptr, OMP_PROC_BIND_default, false); + + Builder.restoreIP(AfterIP); + Builder.CreateRetVoid(); + + OMPBuilder.finalize(); + + EXPECT_FALSE(verifyModule(*M, &errs())); + Function *OutlinedFn = Internal->getFunction(); + + Type *Arg2Type = OutlinedFn->getArg(2)->getType(); + EXPECT_TRUE(Arg2Type->isPointerTy()); + Type *structElemTy = Arg2Type->getPointerElementType(); + EXPECT_STREQ(structElemTy->getStructName().data(), "CapturedStructType"); + EXPECT_TRUE(structElemTy->isStructTy()); + EXPECT_EQ(structElemTy->getStructNumElements(), static_cast(2)); + StructType *structType = reinterpret_cast(structElemTy); + EXPECT_TRUE(structType->getElementType(0)->isIntegerTy(32)); + EXPECT_TRUE(structType->getElementType(1)->isStructTy()); + StructType *innerStructType = + reinterpret_cast(structType->getElementType(1)); + EXPECT_TRUE(innerStructType->getElementType(0)->isIntegerTy(32)); + EXPECT_TRUE(innerStructType->getElementType(1)->isPointerTy()); + EXPECT_TRUE( + innerStructType->getElementType(1)->getPointerElementType()->isIntegerTy( + 32)); +} } // namespace diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -12,6 +12,9 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/SetVector.h" using namespace mlir; diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -6,7 +6,7 @@ %end = constant 0 : index // CHECK: omp.parallel omp.parallel { - // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64 + // CHECK: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64 br ^bb1(%start, %end : index, index) // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}} ^bb1(%0: index, %1: index): diff --git a/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func @malloc(!llvm.i64) -> !llvm.ptr + llvm.func @main() { + %0 = llvm.mlir.constant(4 : index) : !llvm.i64 + %1 = llvm.mlir.constant(4 : index) : !llvm.i64 + %2 = llvm.mlir.null : !llvm.ptr + %3 = llvm.mlir.constant(1 : index) : !llvm.i64 + %4 = llvm.getelementptr %2[%3] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + %5 = llvm.ptrtoint %4 : !llvm.ptr to !llvm.i64 + %6 = llvm.mul %1, %5 : !llvm.i64 + %7 = llvm.call @malloc(%6) : (!llvm.i64) -> !llvm.ptr + %8 = llvm.bitcast %7 : !llvm.ptr to !llvm.ptr + %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %12 = llvm.mlir.constant(0 : index) : !llvm.i64 + %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.mlir.constant(1 : index) : !llvm.i64 + %15 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %17 = llvm.mlir.constant(4.200000e+01 : f32) : !llvm.float + // CHECK: %CaptureStructAlloca = alloca %CapturedStructType + // CHECK: %{{.*}} = insertvalue %CapturedStructType undef, {{.*}}, 0 + // CHECK: store %CapturedStructType %{{.*}}, %CapturedStructType* %CaptureStructAlloca + omp.parallel num_threads(%0 : !llvm.i64) { + // CHECK: %{{.*}} = load %CapturedStructType, %CapturedStructType* %CaptureStructAlloca + // CHECK: %{{.*}} = extractvalue %CapturedStructType %{{.*}}, 0 + %27 = llvm.mlir.constant(1 : i64) : !llvm.i64 + %28 = llvm.extractvalue %16[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %29 = llvm.mlir.constant(0 : index) : !llvm.i64 + %30 = llvm.mlir.constant(1 : index) : !llvm.i64 + %31 = llvm.mul %27, %30 : !llvm.i64 + %32 = llvm.add %29, %31 : !llvm.i64 + %33 = llvm.getelementptr %28[%32] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + llvm.store %17, %33 : !llvm.ptr + omp.terminator + } + llvm.return + } +}