Index: flang/lib/Lower/OpenMP.cpp
===================================================================
--- flang/lib/Lower/OpenMP.cpp
+++ flang/lib/Lower/OpenMP.cpp
@@ -1122,6 +1122,14 @@
                                 &opClauseList);
   } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
     createTargetDataOp(converter, opClauseList, blockDirective.v, &eval);
+  } else if (blockDirective.v == llvm::omp::OMPD_target) {
+    // TODO: Lower the if, device, thread_limit and nowait clauses. They are
+    // currently not forwarded to the omp.target operation created here.
+    auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(
+        currentLocation, /*if_clause*/ mlir::Value(), /*device*/ mlir::Value(),
+        /*thread_limit*/ mlir::Value(),
+        /*nowait*/ nullptr);
+    createBodyOfOp(targetOp, converter, currentLocation, eval, &opClauseList);
   } else {
     TODO(converter.getCurrentLocation(), "Unhandled block directive");
   }
Index: flang/test/Lower/OpenMP/target_region.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/target_region.f90
@@ -0,0 +1,59 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+!===============================================================================
+! Simple target region
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_region() {
+subroutine omp_target_region
+  integer :: a
+  integer :: b
+  integer :: c
+  a = 10
+  b = 20
+
+  !CHECK: omp.target {
+  !CHECK: %[[VAL_0:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  !CHECK: %[[VAL_1:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  !CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : i32
+  !CHECK: fir.store %[[VAL_2]] to %{{.*}} : !fir.ref<i32>
+  !CHECK: omp.terminator
+  !$omp target
+  c = a + b
+  !$omp end target
+
+end subroutine omp_target_region
+
+!===============================================================================
+! Two target regions
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_regions() {
+subroutine omp_target_regions
+  integer :: a
+  integer :: b
+  integer :: c
+  a = 10
+  b = 20
+
+  !CHECK: omp.target {
+  !CHECK: %[[VAL_0:.*]] = fir.load %[[ADDR_A:.*]] : !fir.ref<i32>
+  !CHECK: %[[VAL_1:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  !CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : i32
+  !CHECK: fir.store %[[VAL_2]] to %[[ADDR_C:.*]] : !fir.ref<i32>
+  !CHECK: omp.terminator
+  !$omp target
+  c = a + b
+  !$omp end target
+
+  !CHECK: omp.target {
+  !CHECK: %[[VAL_3:.*]] = fir.load %[[ADDR_C]] : !fir.ref<i32>
+  !CHECK: %[[VAL_4:.*]] = fir.load %[[ADDR_A]] : !fir.ref<i32>
+  !CHECK: %[[VAL_5:.*]] = arith.subi %[[VAL_3]], %[[VAL_4]] : i32
+  !CHECK: fir.store %[[VAL_5]] to %[[ADDR_C]] : !fir.ref<i32>
+  !CHECK: omp.terminator
+  !$omp target
+  c = c - a
+  !$omp end target
+
+end subroutine omp_target_regions
Index: flang/test/Lower/OpenMP/target_region_to_llvmir.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/target_region_to_llvmir.f90
@@ -0,0 +1,25 @@
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s
+
+!===============================================================================
+! Simple target region
+!===============================================================================
+
+!CHECK-LABEL: @omp_target_region
+subroutine omp_target_region
+  integer :: a
+  integer :: b
+  integer :: c
+  a = 10
+  b = 20
+  !CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_region__l[[LINE:.*]](ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}})
+
+  !CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]](ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
+  !CHECK: %[[VAL_A:.*]] = load i32, ptr %[[ADDR_A]], align 4
+  !CHECK: %[[VAL_B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+  !CHECK: %[[SUM:.*]] = add i32 %[[VAL_A]], %[[VAL_B]]
+  !CHECK: store i32 %[[SUM]], ptr %[[ADDR_C]], align 4
+  !$omp target
+  c = a + b
+  !$omp end target
+
+end subroutine omp_target_region
Index: mlir/lib/Target/LLVMIR/CMakeLists.txt
===================================================================
--- mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -38,6 +38,7 @@
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIROpenMPDialect
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
@@ -55,6 +56,7 @@
   MLIROpenACCToLLVMIRTranslation
   MLIROpenMPToLLVMIRTranslation
   MLIRROCDLToLLVMIRTranslation
+  MLIROpenMPDialect
   )
 
 add_mlir_translation_library(MLIRTargetLLVMIRImport
@@ -81,4 +83,5 @@
 
   LINK_LIBS PUBLIC
   MLIRLLVMIRToLLVMTranslation
+  MLIROpenMPDialect
   )
Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,12 +11,15 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -24,6 +27,7 @@
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/FileSystem.h"
 
 using namespace mlir;
 
@@ -1542,6 +1546,156 @@
   return bodyGenStatus;
 }
 
+/// Builds the TargetRegionEntryInfo for `targetOp` from `parentName` and the
+/// unique ID and line of the file location attached to the op.
+static llvm::TargetRegionEntryInfo
+getTargetEntryUniqueInfo(omp::TargetOp targetOp,
+                         llvm::StringRef parentName = "") {
+  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
+
+  assert(fileLoc && "No file found from location");
+  StringRef fileName = fileLoc.getFilename().getValue();
+
+  // Start from a defined value: `id` is not guaranteed to be assigned when
+  // getUniqueID fails, and it is read unconditionally below.
+  llvm::sys::fs::UniqueID id(/*Device=*/0, /*File=*/0);
+  if (llvm::sys::fs::getUniqueID(fileName, id))
+    targetOp.emitError("Unable to get unique ID for file");
+
+  uint64_t line = fileLoc.getLine();
+  return llvm::TargetRegionEntryInfo(parentName, id.getDevice(), id.getFile(),
+                                     line);
+}
+
+/// Outlines the region of `targetOp` into a new internal function named
+/// `functionName`. Each value defined above the region and used inside it
+/// becomes a function parameter and is appended to `operands` in parameter
+/// order. Region conversion failures are reported through `bodyGenStatus`.
+static llvm::Function *
+createOutlinedFunction(omp::TargetOp targetOp, StringRef functionName,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       llvm::SmallVectorImpl<Value> &operands,
+                       LogicalResult &bodyGenStatus) {
+  Region &targetRegion = targetOp.getRegion();
+
+  llvm::SetVector<Value> operandSet;
+  getUsedValuesDefinedAbove(targetRegion, operandSet);
+
+  // Convert the operand types to llvm::Type for constructing parameter types.
+  llvm::SmallVector<llvm::Type *> parameterTypes;
+  for (Value operand : operandSet) {
+    parameterTypes.push_back(moduleTranslation.convertType(operand.getType()));
+    operands.push_back(operand);
+  }
+
+  // Create function type and function.
+  auto *functionType =
+      llvm::FunctionType::get(builder.getVoidTy(), parameterTypes,
+                              /*isVarArg=*/false);
+  auto *function = llvm::Function::Create(
+      functionType, llvm::GlobalValue::InternalLinkage, functionName,
+      builder.GetInsertBlock()->getModule());
+
+  // Save insert point.
+  auto oldInsertPoint = builder.GetInsertPoint();
+  auto *oldInsertBlock = builder.GetInsertBlock();
+
+  // Generate the region into the function.
+  llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
+      moduleTranslation.getLLVMContext(), "entry", function);
+  builder.SetInsertPoint(entryBB);
+  llvm::BasicBlock *exitBlock = convertOmpOpRegions(
+      targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
+
+  // Rewrite uses of the input values inside the outlined function to the
+  // corresponding function parameters.
+  for (auto [operand, arg] : llvm::zip(operandSet, function->args())) {
+    llvm::Value *oldValue = moduleTranslation.lookupValue(operand);
+    for (llvm::User *user : llvm::make_early_inc_range(oldValue->users())) {
+      auto *instr = dyn_cast<llvm::Instruction>(user);
+      if (instr && instr->getFunction() == function)
+        instr->replaceUsesOfWith(oldValue, &arg);
+    }
+  }
+
+  // Insert return instruction.
+  builder.SetInsertPoint(exitBlock);
+  builder.CreateRetVoid();
+
+  // Restore insert point.
+  builder.SetInsertPoint(oldInsertBlock, oldInsertPoint);
+
+  return function;
+}
+
+/// Emits the outlined function for `targetOp` through the OpenMPIRBuilder,
+/// which is expected to fill in `outlinedFn` and `outlinedFnID`.
+static void emitTargetOutlinedFunction(
+    omp::TargetOp targetOp, StringRef parentName, llvm::Function *&outlinedFn,
+    llvm::Constant *&outlinedFnID, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::SmallVectorImpl<Value> &operands, LogicalResult &bodyGenStatus) {
+  llvm::TargetRegionEntryInfo entryInfo =
+      getTargetEntryUniqueInfo(targetOp, parentName);
+  int32_t defaultValTeams = -1;
+  int32_t defaultValThreads = -1;
+
+  llvm::OpenMPIRBuilder::FunctionGenCallback generateOutlinedFunction =
+      [&](llvm::StringRef entryFnName) {
+        return createOutlinedFunction(targetOp, entryFnName, builder,
+                                      moduleTranslation, operands,
+                                      bodyGenStatus);
+      };
+
+  moduleTranslation.getOpenMPBuilder()->emitTargetRegionFunction(
+      entryInfo, generateOutlinedFunction, defaultValTeams, defaultValThreads,
+      /*IsOffloadEntry=*/true, outlinedFn, outlinedFnID);
+}
+
+/// Emits the host-side call of the outlined target region function.
+static void emitTargetCall(llvm::IRBuilderBase &builder,
+                           llvm::Function *outlinedFn,
+                           llvm::Value *outlinedFnID,
+                           LLVM::ModuleTranslation &moduleTranslation,
+                           llvm::SmallVectorImpl<Value> &operands) {
+  // TODO: Add kernel launch call when device codegen is supported.
+  builder.CreateCall(outlinedFn, moduleTranslation.lookupValues(operands));
+}
+
+/// Converts an omp.target operation. On the host, the region is outlined and
+/// the outlined function is called directly. Nothing is emitted when the
+/// enclosing module reports itself as a device module.
+static LogicalResult
+convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  bool isDevice = false;
+  if (auto offloadMod = dyn_cast_or_null<mlir::omp::OffloadModuleInterface>(
+          opInst.getParentOfType<mlir::ModuleOp>().getOperation())) {
+    isDevice = offloadMod.getIsDevice();
+  }
+
+  if (isDevice) // TODO: Implement device codegen.
+    return success();
+
+  auto targetOp = cast<omp::TargetOp>(opInst);
+
+  StringRef parentName = opInst.getParentOfType<LLVM::LLVMFuncOp>().getName();
+  llvm::Function *outlinedFn = nullptr;
+  llvm::Constant *outlinedFnID = nullptr;
+  llvm::SmallVector<Value> operands;
+  LogicalResult bodyGenStatus = success();
+  emitTargetOutlinedFunction(targetOp, parentName, outlinedFn, outlinedFnID,
+                             builder, moduleTranslation, operands,
+                             bodyGenStatus);
+  if (failed(bodyGenStatus))
+    return failure();
+
+  emitTargetCall(builder, outlinedFn, outlinedFnID, moduleTranslation,
+                 operands);
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1659,6 +1813,9 @@
       .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
         return convertOmpTargetData(op, builder, moduleTranslation);
       })
+      .Case([&](omp::TargetOp) {
+        return convertOmpTarget(*op, builder, moduleTranslation);
+      })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1311,6 +1312,15 @@
   LLVM::ensureDistinctSuccessors(module);
 
   ModuleTranslation translator(module, std::move(llvmModule));
+
+  bool isDevice = false;
+  if (auto offloadMod = dyn_cast<mlir::omp::OffloadModuleInterface>(module)) {
+    isDevice = offloadMod.getIsDevice();
+  }
+  // TODO: set the flags when available
+  llvm::OpenMPIRBuilderConfig config(isDevice, false, false, false);
+  translator.getOpenMPBuilder()->setConfig(config);
+
   if (failed(translator.convertFunctionSignatures()))
     return nullptr;
   if (failed(translator.convertGlobals()))