diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1824,6 +1824,27 @@ Value *IfCond, BodyGenCallbackTy ProcessMapOpCB, BodyGenCallbackTy BodyGenCB = {}); + using TargetBodyGenCallbackTy = function_ref; + + /// Generator for '#omp target' + /// + /// \param Loc where the target construct was encountered. + /// \param CodeGenIP The insertion point where the call to the outlined + /// function should be emitted. + /// \param EntryInfo The entry information about the function. + /// \param NumTeams Number of teams specified in the num_teams clause. + /// \param NumThreads Number of threads specified in the thread_limit clause. + /// \param Inputs The input values to the region that will be passed + /// as arguments to the outlined function. + /// \param BodyGenCB Callback that will generate the region code. + InsertPointTy createTarget(const LocationDescription &Loc, + OpenMPIRBuilder::InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, + int32_t NumThreads, + SmallVectorImpl &Inputs, + TargetBodyGenCallbackTy BodyGenCB); + /// Declarations for LLVM-IR types (simple, array, function and structure) are /// generated below. Their names are defined and used in OpenMPKinds.def. 
Here /// we provide the declarations, the initializeTypes function will provide the diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4111,6 +4111,88 @@ return Builder.saveIP(); } +static Function * +createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName, + SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { + SmallVector ParameterTypes; + for (auto &Arg : Inputs) + ParameterTypes.push_back(Arg->getType()); + + auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes, + /*isVarArg*/ false); + auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, + Builder.GetInsertBlock()->getModule()); + + // Save insert point. + auto OldInsertPoint = Builder.saveIP(); + + // Generate the region into the function. + BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func); + Builder.SetInsertPoint(EntryBB); + Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP())); + + // Insert return instruction. + Builder.CreateRetVoid(); + + // Rewrite uses of input values to parameters. + for (auto InArg : zip(Inputs, Func->args())) { + Value *Input = std::get<0>(InArg); + Argument &Arg = std::get<1>(InArg); + + // Only rewrite uses by instructions inside the outlined function + for (User *User : make_early_inc_range(Input->users())) + if (auto Instr = dyn_cast(User)) + if (Instr->getFunction() == Func) + Instr->replaceUsesOfWith(Input, &Arg); + } + + // Restore insert point. 
+ Builder.restoreIP(OldInsertPoint); + + return Func; +} + +static void +emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + TargetRegionEntryInfo &EntryInfo, + Function *&OutlinedFn, int32_t NumTeams, + int32_t NumThreads, SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { + + OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = + [&Builder, &Inputs, &CBFunc](StringRef EntryFnName) { + return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc); + }; + + Constant *OutlinedFnID; + OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, + NumTeams, NumThreads, true, OutlinedFn, + OutlinedFnID); +} + +static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn, + SmallVectorImpl &Args) { + // TODO: Add kernel launch call when device codegen is supported. + Builder.CreateCall(OutlinedFn, Args); +} + +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( + const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads, + SmallVectorImpl &Args, TargetBodyGenCallbackTy CBFunc) { + if (!updateToLocation(Loc)) + return InsertPointTy(); + + Builder.restoreIP(CodeGenIP); + + Function *OutlinedFn; + emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams, + NumThreads, Args, CBFunc); + emitTargetCall(Builder, OutlinedFn, Args); + return Builder.saveIP(); +} + std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef Parts, StringRef FirstSeparator, StringRef Separator) { diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5119,6 +5119,62 @@ EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, TargetRegion) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + 
OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + OpenMPIRBuilderConfig Config(false, false, false, false); + OMPBuilder.setConfig(Config); + F->setName("func"); + IRBuilder<> Builder(BB); + auto Int32Ty = Builder.getInt32Ty(); + + AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr"); + AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr"); + AllocaInst *CPtr = Builder.CreateAlloca(Int32Ty, nullptr, "c_ptr"); + + Builder.CreateStore(Builder.getInt32(10), APtr); + Builder.CreateStore(Builder.getInt32(20), BPtr); + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> InsertPointTy { + Builder.restoreIP(CodeGenIP); + LoadInst *AVal = Builder.CreateLoad(Int32Ty, APtr); + LoadInst *BVal = Builder.CreateLoad(Int32Ty, BPtr); + Value *Sum = Builder.CreateAdd(AVal, BVal); + Builder.CreateStore(Sum, CPtr); + return Builder.saveIP(); + }; + + llvm::SmallVector Inputs; + Inputs.push_back(APtr); + Inputs.push_back(BPtr); + Inputs.push_back(CPtr); + + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); + OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); + Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), EntryInfo, + -1, -1, Inputs, BodyGenCB)); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check the outlined call + auto Iter = F->getEntryBlock().rbegin(); + CallInst *Call = dyn_cast(&*(++Iter)); + EXPECT_NE(Call, nullptr); + + // Check that the correct arguments are passed in + for (auto ArgInput : zip(Call->args(), Inputs)) { + EXPECT_EQ(std::get<0>(ArgInput), std::get<1>(ArgInput)); + } + + // Check that the outlined function exists with the expected prefix + Function *OutlinedFunc = Call->getCalledFunction(); + EXPECT_NE(OutlinedFunc, nullptr); + StringRef FunctionName = OutlinedFunc->getName(); + EXPECT_TRUE(FunctionName.startswith("__omp_offloading")); + EXPECT_FALSE(verifyModule(*M, &errs())); +} + TEST_F(OpenMPIRBuilderTest, CreateTask) { using InsertPointTy = 
OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -162,13 +162,7 @@ /// Returns the OpenMP IR builder associated with the LLVM IR module being /// constructed. - llvm::OpenMPIRBuilder *getOpenMPBuilder() { - if (!ompBuilder) { - ompBuilder = std::make_unique(*llvmModule); - ompBuilder->initialize(); - } - return ompBuilder.get(); - } + llvm::OpenMPIRBuilder *getOpenMPBuilder(); /// Translates the given location. const llvm::DILocation *translateLoc(Location loc, llvm::DILocalScope *scope); diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -39,6 +39,7 @@ MLIRLLVMDialect MLIRLLVMIRTransforms MLIRTranslateLib + MLIROpenMPDialect ) add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -12,11 +12,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -24,6 +26,7 @@ #include 
"llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/Support/FileSystem.h" using namespace mlir; @@ -1573,6 +1576,102 @@ return success(); } +static llvm::TargetRegionEntryInfo +getTargetEntryUniqueInfo(omp::TargetOp targetOp, + llvm::StringRef parentName = "") { + auto fileLoc = targetOp.getLoc()->findInstanceOf(); + + assert(fileLoc && "No file found from location"); + StringRef fileName = fileLoc.getFilename().getValue(); + + llvm::sys::fs::UniqueID id; + if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) { + targetOp.emitError("Unable to get unique ID for file"); + } + + uint64_t line = fileLoc.getLine(); + return llvm::TargetRegionEntryInfo(parentName, id.getDevice(), id.getFile(), + line); +} + +static bool targetOpSupported(Operation &opInst) { + auto targetOp = cast(opInst); + if (targetOp.getIfExpr()) { + opInst.emitError("If clause not yet supported"); + return false; + } + + if (targetOp.getDevice()) { + opInst.emitError("Device clause not yet supported"); + return false; + } + + if (targetOp.getThreadLimit()) { + opInst.emitError("Thread limit clause not yet supported"); + return false; + } + + if (targetOp.getNowait()) { + opInst.emitError("Nowait clause not yet supported"); + return false; + } + + return true; +} + +static LogicalResult +convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + + if (!targetOpSupported(opInst)) + return failure(); + + bool isDevice = false; + if (auto offloadMod = dyn_cast( + opInst.getParentOfType().getOperation())) { + isDevice = offloadMod.getIsDevice(); + } + + if (isDevice) // TODO: Implement device codegen. + return success(); + + auto targetOp = cast(opInst); + auto &targetRegion = targetOp.getRegion(); + + llvm::SetVector operandSet; + getUsedValuesDefinedAbove(targetRegion, operandSet); + + // Collect the input arguments. 
+ llvm::SmallVector inputs; + for (Value operand : operandSet) + inputs.push_back(moduleTranslation.lookupValue(operand)); + + LogicalResult bodyGenStatus = success(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + auto bodyCB = [&](InsertPointTy allocaIP, + InsertPointTy codeGenIP) -> InsertPointTy { + builder.restoreIP(codeGenIP); + llvm::BasicBlock *exitBlock = convertOmpOpRegions( + targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus); + builder.SetInsertPoint(exitBlock); + return builder.saveIP(); + }; + + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + StringRef parentName = opInst.getParentOfType().getName(); + llvm::TargetRegionEntryInfo entryInfo = + getTargetEntryUniqueInfo(targetOp, parentName); + int32_t defaultValTeams = -1; + int32_t defaultValThreads = -1; + + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget( + ompLoc, builder.saveIP(), entryInfo, defaultValTeams, defaultValThreads, + inputs, bodyCB)); + + return bodyGenStatus; +} + namespace { /// Implementation of the dialect interface that converts operations belonging @@ -1713,6 +1812,9 @@ .Case([&](auto op) { return convertOmpTargetData(op, builder, moduleTranslation); }) + .Case([&](omp::TargetOp) { + return convertOmpTarget(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -449,6 +450,7 @@ 
assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); } + ModuleTranslation::~ModuleTranslation() { if (ompBuilder) ompBuilder->finalize(); @@ -1250,6 +1252,26 @@ return remapped; } +llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { + if (!ompBuilder) { + ompBuilder = std::make_unique(*llvmModule); + ompBuilder->initialize(); + + bool isDevice = false; + if (auto offloadMod = + dyn_cast(mlirModule)) + isDevice = offloadMod.getIsDevice(); + + // TODO: set the flags when available + llvm::OpenMPIRBuilderConfig Config( + isDevice, /* IsTargetCodegen */ false, + /* HasRequiresUnifiedSharedMemory */ false, + /* OpenMPOffloadMandatory */ false); + ompBuilder->setConfig(Config); + } + return ompBuilder.get(); +} + const llvm::DILocation * ModuleTranslation::translateLoc(Location loc, llvm::DILocalScope *scope) { return debugTranslation->translateLoc(loc, scope); diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -174,3 +174,89 @@ // CHECK: ret void // ----- + +module attributes {omp.is_device = #omp.isdevice} { + llvm.func @omp_target_region_() { + %0 = llvm.mlir.constant(20 : i32) : i32 + %1 = llvm.mlir.constant(10 : i32) : i32 + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr + llvm.store %1, %3 : !llvm.ptr + llvm.store %0, %5 : !llvm.ptr + 
omp.target { + %8 = llvm.load %3 : !llvm.ptr + %9 = llvm.load %5 : !llvm.ptr + %10 = llvm.add %8, %9 : i32 + llvm.store %10, %7 : !llvm.ptr + omp.terminator + } + llvm.return + } +} + +// CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_region__l[[LINE:.*]](ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}) +// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]](ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]]) +// CHECK: %[[VAL_A:.*]] = load i32, ptr %[[ADDR_A]], align 4 +// CHECK: %[[VAL_B:.*]] = load i32, ptr %[[ADDR_B]], align 4 +// CHECK: %[[SUM:.*]] = add i32 %[[VAL_A]], %[[VAL_B]] +// CHECK: store i32 %[[SUM]], ptr %[[ADDR_C]], align 4 + +// ----- + +module attributes {omp.is_device = #omp.isdevice} { + llvm.func @omp_target_region_() { + %0 = llvm.mlir.constant(20 : i32) : i32 + %1 = llvm.mlir.constant(10 : i32) : i32 + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr + llvm.store %1, %3 : !llvm.ptr + llvm.store %0, %5 : !llvm.ptr + omp.target { + omp.parallel { + %8 = llvm.load %3 : !llvm.ptr + %9 = llvm.load %5 : !llvm.ptr + %10 = llvm.add %8, %9 : i32 + llvm.store %10, %7 : !llvm.ptr + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_region__l[[LINE:.*]](ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}) +// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]](ptr 
%[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]]) +// CHECK: %[[STRUCTARG:.*]] = alloca { ptr, ptr, ptr }, align 8 +// CHECK: %[[GEP1:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 0 +// CHECK: store ptr %[[ADDR_A]], ptr %[[GEP1]], align 8 +// CHECK: %[[GEP2:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 1 +// CHECK: store ptr %[[ADDR_B]], ptr %[[GEP2]], align 8 +// CHECK: %[[GEP3:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 2 +// CHECK: store ptr %[[ADDR_C]], ptr %[[GEP3]], align 8 +// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]]..omp_par, ptr %[[STRUCTARG]]) + + +// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]]..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %[[STRUCTARG2:.*]]) #0 { +// CHECK: %[[GEP4:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG2]], i32 0, i32 0 +// CHECK: %[[LOADGEP1:.*]] = load ptr, ptr %[[GEP4]], align 8 +// CHECK: %[[GEP5:.*]] = getelementptr { ptr, ptr, ptr }, ptr %0, i32 0, i32 1 +// CHECK: %[[LOADGEP2:.*]] = load ptr, ptr %[[GEP5]], align 8 +// CHECK: %[[GEP6:.*]] = getelementptr { ptr, ptr, ptr }, ptr %0, i32 0, i32 2 +// CHECK: %[[LOADGEP3:.*]] = load ptr, ptr %[[GEP6]], align 8 + +// CHECK: %[[VAL_A:.*]] = load i32, ptr %[[LOADGEP1]], align 4 +// CHECK: %[[VAL_B:.*]] = load i32, ptr %[[LOADGEP2]], align 4 +// CHECK: %[[SUM:.*]] = add i32 %[[VAL_A]], %[[VAL_B]] +// CHECK: store i32 %[[SUM]], ptr %[[LOADGEP3]], align 4 + +// ----- +