diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -95,6 +95,7 @@ Int64, Int64, Int32Arr3Ty, Int32Arr3Ty, Int32) __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, false, Int8Ptr) __OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, false, SizeTy, SizeTy, Int8) +__OMP_STRUCT_TYPE(Task, kmp_task_ompbuilder_t, false, VoidPtr, VoidPtr, Int32, VoidPtr, VoidPtr) __OMP_STRUCT_TYPE(ConfigurationEnvironment, ConfigurationEnvironmentTy, false, Int8, Int8, Int8) __OMP_STRUCT_TYPE(DynamicEnvironment, DynamicEnvironmentTy, false, Int16) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1555,9 +1555,9 @@ "there must be a single user for the outlined function"); CallInst *StaleCI = cast(OutlinedFn.user_back()); - // HasTaskData is true if any variables are captured in the outlined region, + // HasShareds is true if any variables are captured in the outlined region, // false otherwise. - bool HasTaskData = StaleCI->arg_size() > 0; + bool HasShareds = StaleCI->arg_size() > 0; Builder.SetInsertPoint(StaleCI); // Gather the arguments for emitting the runtime call for @@ -1585,8 +1585,15 @@ // Argument - `sizeof_kmp_task_t` (TaskSize) // Tasksize refers to the size in bytes of kmp_task_t data structure // including private vars accessed in task. - Value *TaskSize = Builder.getInt64(0); - if (HasTaskData) { + // TODO: add kmp_task_t_with_privates (privates) + Value *TaskSize = Builder.getInt64( + divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8)); + + // Argument - `sizeof_shareds` (SharedsSize) + // SharedsSize refers to the shareds array size in the kmp_task_t data + // structure. 
+ Value *SharedsSize = Builder.getInt64(0); + if (HasShareds) { AllocaInst *ArgStructAlloca = dyn_cast(StaleCI->getArgOperand(0)); assert(ArgStructAlloca && @@ -1596,19 +1603,17 @@ dyn_cast(ArgStructAlloca->getAllocatedType()); assert(ArgStructType && "Unable to find struct type corresponding to " "arguments for extracted function"); - TaskSize = + SharedsSize = Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType)); } - // TODO: Argument - sizeof_shareds - // Argument - task_entry (the wrapper function) - // If the outlined function has some captured variables (i.e. HasTaskData is + // If the outlined function has some captured variables (i.e. HasShareds is // true), then the wrapper function will have an additional argument (the // struct containing captured variables). Otherwise, no such argument will // be present. SmallVector WrapperArgTys{Builder.getInt32Ty()}; - if (HasTaskData) + if (HasShareds) WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType()); FunctionCallee WrapperFuncVal = M.getOrInsertFunction( (Twine(OutlinedFn.getName()) + ".wrapper").str(), @@ -1617,19 +1622,19 @@ // Emit the @__kmpc_omp_task_alloc runtime call // The runtime call returns a pointer to an area where the task captured - // variables must be copied before the task is run (NewTaskData) - CallInst *NewTaskData = Builder.CreateCall( - TaskAllocFn, - {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags, - /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0), - /*task_func=*/WrapperFunc}); + // variables must be copied before the task is run (TaskData) + CallInst *TaskData = Builder.CreateCall( + TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags, + /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize, + /*task_func=*/WrapperFunc}); // Copy the arguments for outlined function - if (HasTaskData) { - Value *TaskData = StaleCI->getArgOperand(0); + if (HasShareds) { + Value *Shareds = StaleCI->getArgOperand(0); Align Alignment = 
TaskData->getPointerAlignment(M.getDataLayout()); - Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment, - TaskSize); + Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData); + Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment, + SharedsSize); } Value *DepArrayPtr = nullptr; @@ -1705,12 +1710,12 @@ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0); Function *TaskCompleteFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0); - Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData}); - if (HasTaskData) - Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData}); + Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData}); + if (HasShareds) + Builder.CreateCall(WrapperFunc, {ThreadID, TaskData}); else Builder.CreateCall(WrapperFunc, {ThreadID}); - Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData}); + Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData}); Builder.SetInsertPoint(ThenTI); } @@ -1719,14 +1724,14 @@ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps); Builder.CreateCall( TaskFn, - {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()), + {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()), DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0), ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))}); } else { // Emit the @__kmpc_omp_task runtime call to spawn the task Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task); - Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData}); + Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData}); } StaleCI->eraseFromParent(); @@ -1735,10 +1740,13 @@ BasicBlock *WrapperEntryBB = BasicBlock::Create(M.getContext(), "", WrapperFunc); Builder.SetInsertPoint(WrapperEntryBB); - if (HasTaskData) - Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)}); - else + if (HasShareds) { + llvm::Value *Shareds = + Builder.CreateLoad(VoidPtr, 
WrapperFunc->getArg(1)); + Builder.CreateCall(&OutlinedFn, {Shareds}); + } else { Builder.CreateCall(&OutlinedFn); + } Builder.CreateRet(Builder.getInt32(0)); }; diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5397,19 +5397,30 @@ ConstantInt *DataSize = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3)); ASSERT_NE(DataSize, nullptr); - EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer + EXPECT_EQ(DataSize->getSExtValue(), 40); - // TODO: Verify size of shared clause variables + ConstantInt *SharedsSize = + dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(4)); + ASSERT_NE(SharedsSize, nullptr); + EXPECT_EQ(SharedsSize->getSExtValue(), + 24); // 64-bit pointer + 128-bit integer // Verify Wrapper function Function *WrapperFunc = dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); ASSERT_NE(WrapperFunc, nullptr); + + LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin()); + ASSERT_NE(SharedsLoad, nullptr); + EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1)); + EXPECT_FALSE(WrapperFunc->isDeclaration()); - CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin()); + CallInst *OutlinedFnCall = + dyn_cast<CallInst>(++WrapperFunc->begin()->begin()); ASSERT_NE(OutlinedFnCall, nullptr); EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty()); - EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1)); + EXPECT_EQ(OutlinedFnCall->getArgOperand(0), + WrapperFunc->getArg(1)->uses().begin()->getUser()); // Verify the presence of `trunc` and `icmp` instructions in Outlined function Function *OutlinedFn = OutlinedFnCall->getCalledFunction(); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2208,7 +2208,7 @@ llvm.func @omp_task(%x: i32, 
%y: i32, %zaddr: !llvm.ptr) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc - // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, + // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) omp.task { @@ -2258,7 +2258,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc - // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, + // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}}) omp.task depend(taskdependin -> %zaddr : !llvm.ptr) { @@ -2303,9 +2303,10 @@ llvm.store %diff, %zaddr : !llvm.ptr // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc - // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 16, i64 0, + // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16, // CHECK-SAME: ptr @[[wrapper_fn:.+]]) - // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[task_data]], ptr {{.+}}, i64 16, i1 false) + // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]] + // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) omp.task { %z = llvm.add %x, %y : i32 @@ -2334,7 +2335,8 @@ // CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) { 
-// CHECK: call void @[[outlined_fn]](ptr %[[task_data]]) +// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]], align 8 +// CHECK: call void @[[outlined_fn]](ptr %[[shareds]]) // CHECK: ret i32 0 // CHECK: } @@ -2430,7 +2432,7 @@ // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper) +// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2443,8 +2445,9 @@ // CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2 // CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8 // CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper) -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false) +// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper) +// CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]] +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]]) // CHECK: br label %[[task_exit3:[^,]+]] // CHECK: [[task_exit3]]: @@ -2614,7 
+2617,7 @@ // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) // CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0 // CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1 -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 0, i64 0, ptr @omp_task_final..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2645,7 +2648,7 @@ // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, i64 0, ptr @omp_task_if..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper) // CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]] // CHECK: [[true_label]]: // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])