Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2089,6 +2089,12 @@
   /// duplicating the body code.
   enum BodyGenTy { Priv, DupNoPriv, NoPriv };
 
+  /// Callback type for creating the map infos for the kernel parameters.
+  /// \param CodeGenIP is the insertion point where code should be generated,
+  /// if any.
+  using GenMapInfoCallbackTy =
+      function_ref<MapInfosTy &(InsertPointTy CodeGenIP)>;
+
   /// Generator for '#omp target data'
   ///
   /// \param Loc The location where the target data construct was encountered.
@@ -2109,8 +2115,7 @@
   OpenMPIRBuilder::InsertPointTy createTargetData(
       const LocationDescription &Loc, InsertPointTy AllocaIP,
       InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
-      TargetDataInfo &Info,
-      function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
+      TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
       omp::RuntimeFunction *MapperFunc = nullptr,
       function_ref<InsertPointTy(InsertPointTy CodeGenIP,
                                  BodyGenTy BodyGenType)>
@@ -2134,10 +2139,15 @@
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  /// \param GenMapInfoCB Callback that generates the map infos for the kernel
+  /// arguments; it is only invoked when generating host code.
   InsertPointTy createTarget(const LocationDescription &Loc,
+                             OpenMPIRBuilder::InsertPointTy AllocaIP,
                              OpenMPIRBuilder::InsertPointTy CodeGenIP,
                              TargetRegionEntryInfo &EntryInfo,
                              int32_t NumTeams, int32_t NumThreads,
                              SmallVectorImpl<Value *> &Inputs,
+                             GenMapInfoCallbackTy GenMapInfoCB,
                              TargetBodyGenCallbackTy BodyGenCB);
 
   /// Declarations for LLVM-IR types (simple, array, function and structure) are
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4161,8 +4161,7 @@
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
-    TargetDataInfo &Info,
-    function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
+    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
     omp::RuntimeFunction *MapperFunc,
     function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
         BodyGenCB,
@@ -4347,42 +4346,95 @@
 static void
 emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                            TargetRegionEntryInfo &EntryInfo,
-                           Function *&OutlinedFn, int32_t NumTeams,
-                           int32_t NumThreads, SmallVectorImpl<Value *> &Inputs,
+                           Function *&OutlinedFn, Constant *&OutlinedFnID,
+                           int32_t NumTeams, int32_t NumThreads,
+                           SmallVectorImpl<Value *> &Inputs,
                            OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
         return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
                                       CBFunc);
       };
 
-  Constant *OutlinedFnID;
   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
                                       NumTeams, NumThreads, true, OutlinedFn,
                                       OutlinedFnID);
 }
 
-static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
-                           SmallVectorImpl<Value *> &Args) {
-  // TODO: Add kernel launch call
-  Builder.CreateCall(OutlinedFn, Args);
+/// Emit the host-side launch of the target region: lower the map infos
+/// produced by \p GenMapInfoCB into offloading arrays, then launch the kernel
+/// with a direct call to \p OutlinedFn as the host fallback.
+static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                           OpenMPIRBuilder::InsertPointTy AllocaIP,
+                           Function *OutlinedFn, Constant *OutlinedFnID,
+                           int32_t NumTeams, int32_t NumThreads,
+                           SmallVectorImpl<Value *> &Args,
+                           OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {
+
+  OpenMPIRBuilder::TargetDataInfo Info(
+      /*RequiresDevicePointerInfo=*/false,
+      /*SeparateBeginEndCalls=*/true);
+
+  // The callback returns a reference; bind to it instead of copying the infos.
+  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+  OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
+                                  /*IsNonContiguous=*/true);
+
+  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+  OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info);
+
+  // Host fallback: call the outlined function directly with the same args.
+  auto EmitTargetCallFallbackCB =
+      [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(IP);
+    Builder.CreateCall(OutlinedFn, Args);
+    return Builder.saveIP();
+  };
+
+  unsigned NumTargetItems = MapInfo.BasePointers.size();
+  // TODO: Use correct device ID
+  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+  Value *NumTeamsVal = Builder.getInt32(NumTeams);
+  Value *NumThreadsVal = Builder.getInt32(NumThreads);
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+                                             llvm::omp::IdentFlag(0), 0);
+  // TODO: Use correct NumIterations
+  Value *NumIterations = Builder.getInt64(0);
+  // TODO: Use correct DynCGGroupMem
+  Value *DynCGGroupMem = Builder.getInt32(0);
+
+  bool HasNoWait = false;
+
+  OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
+                                          NumTeamsVal, NumThreadsVal,
+                                          DynCGGroupMem, HasNoWait);
+
+  Builder.restoreIP(OMPBuilder.emitKernelLaunch(
+      Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
+      DeviceID, RTLoc, AllocaIP));
 }
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy CodeGenIP,
-    TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads,
-    SmallVectorImpl<Value *> &Args, TargetBodyGenCallbackTy CBFunc) {
+    const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy AllocaIP,
+    OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    int32_t NumTeams, int32_t NumThreads, SmallVectorImpl<Value *> &Args,
+    GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
   Builder.restoreIP(CodeGenIP);
 
   Function *OutlinedFn;
-  emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams,
-                             NumThreads, Args, CBFunc);
+  Constant *OutlinedFnID;
+  emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
+                             OutlinedFnID, NumTeams, NumThreads, Args, CBFunc);
   if (!Config.isTargetDevice())
-    emitTargetCall(Builder, OutlinedFn, Args);
+    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
+                   NumThreads, Args, GenMapInfoCB);
+
   return Builder.saveIP();
 }
 
Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5056,6 +5056,35 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+namespace {
+// Some basic handling of argument mapping for the moment
+void CreateDefaultMapInfos(llvm::OpenMPIRBuilder &ompBuilder,
+                           llvm::SmallVectorImpl<llvm::Value *> &args,
+                           llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo) {
+  for (auto arg : args) {
+    if (!arg->getType()->isPointerTy()) {
+      combinedInfo.BasePointers.clear();
+      combinedInfo.Pointers.clear();
+      combinedInfo.Sizes.clear();
+      combinedInfo.Types.clear();
+      combinedInfo.Names.clear();
+      return;
+    }
+    combinedInfo.BasePointers.emplace_back(arg);
+    combinedInfo.Pointers.emplace_back(arg);
+    uint32_t SrcLocStrSize;
+    combinedInfo.Names.emplace_back(ompBuilder.getOrCreateSrcLocStr(
+        "Unknown loc - stub implementation", SrcLocStrSize));
+    combinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM));
+    combinedInfo.Sizes.emplace_back(ompBuilder.Builder.getInt64(
+        ompBuilder.M.getDataLayout().getTypeAllocSize(arg->getType())));
+  }
+}
+} // namespace
+
 TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
@@ -5087,28 +5116,52 @@
   Inputs.push_back(BPtr);
   Inputs.push_back(CPtr);
 
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & {
+    CreateDefaultMapInfos(OMPBuilder, Inputs, CombinedInfos);
+    return CombinedInfos;
+  };
+
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
-  Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), EntryInfo,
-                                            -1, -1, Inputs, BodyGenCB));
+  Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(),
+                                            Builder.saveIP(), EntryInfo, -1, 0,
+                                            Inputs, GenMapInfoCB, BodyGenCB));
   OMPBuilder.finalize();
   Builder.CreateRetVoid();
 
-  // Check the outlined call
+  // Check the kernel launch sequence
   auto Iter = F->getEntryBlock().rbegin();
-  CallInst *Call = dyn_cast<CallInst>(&*(++Iter));
-  EXPECT_NE(Call, nullptr);
+  BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+  ASSERT_NE(Branch, nullptr);
+  EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+  CallInst *Call = dyn_cast<CallInst>(&*(++Iter));
+  ASSERT_NE(Call, nullptr);
+
+  // Check that the kernel launch function is called
+  Function *KernelLaunchFunc = Call->getCalledFunction();
+  ASSERT_NE(KernelLaunchFunc, nullptr);
+  StringRef FunctionName = KernelLaunchFunc->getName();
+  EXPECT_TRUE(FunctionName.startswith("__tgt_target_kernel"));
+
+  // Check the fallback call
+  BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+  Iter = FallbackBlock->rbegin();
+  CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+  ASSERT_NE(FCall, nullptr);
 
-  // Check that the correct aguments are passed in
-  for (auto ArgInput : zip(Call->args(), Inputs)) {
+  // Check that the correct arguments are passed in
+  for (auto ArgInput : zip(FCall->args(), Inputs)) {
     EXPECT_EQ(std::get<0>(ArgInput), std::get<1>(ArgInput));
   }
 
   // Check that the outlined function exists with the expected prefix
-  Function *OutlinedFunc = Call->getCalledFunction();
-  EXPECT_NE(OutlinedFunc, nullptr);
-  StringRef FunctionName = OutlinedFunc->getName();
-  EXPECT_TRUE(FunctionName.startswith("__omp_offloading"));
+  Function *OutlinedFunc = FCall->getCalledFunction();
+  ASSERT_NE(OutlinedFunc, nullptr);
+  StringRef OutlinedFuncName = OutlinedFunc->getName();
+  EXPECT_TRUE(OutlinedFuncName.startswith("__omp_offloading"));
+
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
@@ -5126,6 +5179,13 @@
       Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
       Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
 
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & {
+    CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
+    return CombinedInfos;
+  };
+
   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
                        OpenMPIRBuilder::InsertPointTy CodeGenIP)
       -> OpenMPIRBuilder::InsertPointTy {
@@ -5139,9 +5199,10 @@
 
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
-  Builder.restoreIP(
-      OMPBuilder.createTarget(Loc, EntryIP, EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/-1, CapturedArgs, BodyGenCB));
+  Builder.restoreIP(OMPBuilder.createTarget(
+      Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
+      /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB));
+
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
 
Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1378,7 +1378,8 @@
                         DataLayout &DL,
                         llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
                         const SmallVector<Value> &mapOperands,
-                        const ArrayAttr &mapTypes) {
+                        const ArrayAttr &mapTypes,
+                        bool isTargetParams = false) {
   // Get map clause information.
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   unsigned index = 0;
@@ -1398,8 +1399,12 @@
     combinedInfo.Pointers.emplace_back(mapOpValue);
     combinedInfo.Names.emplace_back(
         mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder));
-    combinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(
-        mapTypes[index].dyn_cast<IntegerAttr>().getInt()));
+    combinedInfo.Types.emplace_back(
+        llvm::omp::OpenMPOffloadMappingFlags(
+            mapTypes[index].dyn_cast<IntegerAttr>().getInt()) |
+        (isTargetParams
+             ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
+             : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE));
     combinedInfo.Sizes.emplace_back(
         builder.getInt64(getSizeInBytes(DL, mapOp.getType())));
     index++;
@@ -1663,11 +1668,28 @@
     return failure();
 
   int32_t defaultValTeams = -1;
-  int32_t defaultValThreads = -1;
+  int32_t defaultValThreads = 0;
+
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+
+  DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
+  SmallVector<Value> mapOperands = targetOp.getMapOperands();
+  // The map clause is optional: without it there are no map types to read.
+  ArrayAttr mapTypes = targetOp.getMapTypes().value_or(ArrayAttr());
+
+  llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
+  auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & {
+    builder.restoreIP(codeGenIP);
+    genMapInfos(builder, moduleTranslation, DL, combinedInfos, mapOperands,
+                mapTypes, /*isTargetParams=*/true);
+    return combinedInfos;
+  };
 
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
-      ompLoc, builder.saveIP(), entryInfo, defaultValTeams, defaultValThreads,
-      inputs, bodyCB));
+      ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
+      defaultValThreads, inputs, genMapInfoCB, bodyCB));
 
   return bodyGenStatus;
 }
Index: mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
===================================================================
--- mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -12,7 +12,7 @@
     %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
     llvm.store %1, %3 : !llvm.ptr<i32>
     llvm.store %0, %5 : !llvm.ptr<i32>
-    omp.target {
+    omp.target map((to -> %3 : !llvm.ptr<i32>), (to -> %5 : !llvm.ptr<i32>), (from -> %7 : !llvm.ptr<i32>)) {
       %8 = llvm.load %3 : !llvm.ptr<i32>
       %9 = llvm.load %5 : !llvm.ptr<i32>
       %10 = llvm.add %8, %9 : i32
Index: mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
===================================================================
--- mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
+++ mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
@@ -12,7 +12,7 @@
     %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
     llvm.store %1, %3 : !llvm.ptr<i32>
     llvm.store %0, %5 : !llvm.ptr<i32>
-    omp.target {
+    omp.target map((to -> %3 : !llvm.ptr<i32>), (to -> %5 : !llvm.ptr<i32>), (from -> %7 : !llvm.ptr<i32>)) {
       %8 = llvm.load %3 : !llvm.ptr<i32>
       %9 = llvm.load %5 : !llvm.ptr<i32>
       %10 = llvm.add %8, %9 : i32
Index: mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
===================================================================
--- mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
+++ mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
@@ -12,7 +12,7 @@
     %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
     llvm.store %1, %3 : !llvm.ptr<i32>
     llvm.store %0, %5 : !llvm.ptr<i32>
-    omp.target {
+    omp.target map((to -> %3 : !llvm.ptr<i32>), (to -> %5 : !llvm.ptr<i32>), (from -> %7 : !llvm.ptr<i32>)) {
       omp.parallel {
         %8 = llvm.load %3 : !llvm.ptr<i32>
         %9 = llvm.load %5 : !llvm.ptr<i32>