diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -497,12 +497,6 @@ OMP_REQ_DYNAMIC_ALLOCATORS = 0x010, LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS) }; - -enum OpenMPOffloadingReservedDeviceIDs { - /// Device ID if the device was not defined, runtime should get it - /// from environment variables in the spec. - OMP_DEVICEID_UNDEF = -1, -}; } // anonymous namespace /// Describes ident structure that describes a source location. diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h @@ -241,6 +241,12 @@ LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ OMP_MAP_MEMBER_OF) }; +enum OpenMPOffloadingReservedDeviceIDs { + /// Device ID if the device was not defined, runtime should get it + /// from environment variables in the spec. + OMP_DEVICEID_UNDEF = -1 +}; + enum class AddressSpace : unsigned { Generic = 0, Global = 1, diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1095,6 +1095,9 @@ /// variables. StringMap InternalVars; + /// Computes the size in bytes of the type of \p basePtr (not its pointee). + llvm::Value *getSizeInBytes(llvm::Value *basePtr); + /// Create the global variable holding the offload mappings information. GlobalVariable *createOffloadMaptypes(SmallVectorImpl &Mappings, std::string VarName); @@ -1550,6 +1553,32 @@ StringRef EntryFnIDName, int32_t NumTeams, int32_t NumThreads); + /// Generator for '#omp target data', '#omp target enter data' and '#omp target exit data' + /// + /// \param Loc The location where the target data construct was encountered.
+ /// \param codeGenIP The insertion point at which the target directive code + /// should be placed. + /// \param hasRegion True if the construct has an associated region (begin and + /// end mapper calls are emitted), false for a single \p mapperFunc call. + /// \param mapTypeFlags Map-type flags of the map operands; read only after + /// \p processMapOpCB has run. + /// \param mapNames Name strings of the map operands (read after \p processMapOpCB). + /// \param mapperAllocas Allocas for the base-pointer, pointer and size arrays. + /// \param mapperFunc Runtime function to call when \p hasRegion is false. + /// \param deviceID Device number from the device clause, or OMP_DEVICEID_UNDEF. + /// \param ifCond Value of the if clause condition, or nullptr if there is none. + /// \param processMapOpCB Callback that generates the code filling the map arrays. + /// \param BodyGenCB Callback that will generate the region code. + OpenMPIRBuilder::InsertPointTy + createTargetData(const LocationDescription &Loc, + llvm::OpenMPIRBuilder::InsertPointTy codeGenIP, + bool hasRegion, SmallVector &mapTypeFlags, + SmallVector &mapNames, + struct MapperAllocas &mapperAllocas, + llvm::Function *mapperFunc, int64_t deviceID, + llvm::Value *ifCond, BodyGenCallbackTy processMapOpCB, + BodyGenCallbackTy BodyGenCB); + /// Declarations for LLVM-IR types (simple, array, function and structure) are /// generated below. Their names are defined and used in OpenMPKinds.def.
Here /// we provide the declarations, the initializeTypes function will provide the
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4038,6 +4038,77 @@
   return OutlinedFnID;
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
+    const LocationDescription &Loc,
+    llvm::OpenMPIRBuilder::InsertPointTy codeGenIP, bool hasRegion,
+    SmallVector<uint64_t> &mapTypeFlags,
+    SmallVector<llvm::Constant *> &mapNames,
+    struct MapperAllocas &mapperAllocas, llvm::Function *mapperFunc,
+    int64_t deviceID, llvm::Value *ifCond, BodyGenCallbackTy processMapOpCB,
+    BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  Builder.restoreIP(codeGenIP);
+
+  // LLVM utilities like blocks with terminators. This placeholder marks where
+  // the code following the construct continues; it is erased at the end.
+  auto *UI = Builder.CreateUnreachable();
+
+  // Positions Builder so that the next emitted code only runs when the if
+  // clause holds: a conditional branch to a fresh "omp_if.then" block is split
+  // in right before SplitBefore. Without an if clause, code is simply emitted
+  // in front of SplitBefore.
+  auto SetGuardedInsertPoint = [&](Instruction *SplitBefore) {
+    if (!ifCond) {
+      Builder.SetInsertPoint(SplitBefore);
+      return;
+    }
+    Instruction *ThenTI = SplitBlockAndInsertIfThen(ifCond, SplitBefore,
+                                                    /*Unreachable=*/false);
+    ThenTI->getParent()->setName("omp_if.then");
+    Builder.SetInsertPoint(ThenTI);
+  };
+
+  // Filling the offloading arrays and the begin/standalone mapper call only
+  // happen when the if clause is true.
+  SetGuardedInsertPoint(UI);
+  processMapOpCB(Builder.saveIP(), Builder.saveIP());
+
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  llvm::Value *srcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  // mapTypeFlags/mapNames are only complete once processMapOpCB has run.
+  // Constant-index GEPs of globals are folded to constants by IRBuilder, so
+  // these arguments are also usable from the guarded end-mapper block.
+  llvm::GlobalVariable *mapTypesGV =
+      createOffloadMaptypes(mapTypeFlags, ".offload_maptypes");
+  llvm::Value *mapTypesArg = Builder.CreateConstInBoundsGEP2_32(
+      llvm::ArrayType::get(Builder.getInt64Ty(), mapTypeFlags.size()),
+      mapTypesGV,
+      /*Idx0=*/0, /*Idx1=*/0);
+
+  llvm::GlobalVariable *mapNamesGV =
+      createOffloadMapnames(mapNames, ".offload_mapnames");
+  llvm::Value *mapNamesArg = Builder.CreateConstInBoundsGEP2_32(
+      llvm::ArrayType::get(Builder.getInt8PtrTy(), mapNames.size()), mapNamesGV,
+      /*Idx0=*/0, /*Idx1=*/0);
+
+  if (!hasRegion) {
+    // Standalone construct (target enter/exit data): a single runtime call.
+    emitMapperCall(Builder.saveIP(), mapperFunc, srcLocInfo, mapTypesArg,
+                   mapNamesArg, mapperAllocas, deviceID, mapTypeFlags.size());
+  } else {
+    llvm::Function *beginMapperFunc = getOrCreateRuntimeFunctionPtr(
+        llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
+    llvm::Function *endMapperFunc = getOrCreateRuntimeFunctionPtr(
+        llvm::omp::OMPRTL___tgt_target_data_end_mapper);
+
+    // Create call to start the data region.
+    emitMapperCall(Builder.saveIP(), beginMapperFunc, srcLocInfo, mapTypesArg,
+                   mapNamesArg, mapperAllocas, deviceID, mapTypeFlags.size());
+
+    // The region body runs regardless of the if clause; only the mapping is
+    // conditional. Emit it before the placeholder (not inside the guarded
+    // block) and re-anchor afterwards, because the callback may split blocks
+    // and move UI into a new one.
+    Builder.SetInsertPoint(UI);
+    BodyGenCB(Builder.saveIP(), Builder.saveIP());
+
+    // Create call to end the data region, guarded like the begin call.
+    SetGuardedInsertPoint(UI);
+    emitMapperCall(Builder.saveIP(), endMapperFunc, srcLocInfo, mapTypesArg,
+                   mapNamesArg, mapperAllocas, deviceID, mapTypeFlags.size());
+  }
+
+  // Update the insertion point and remove the terminator we introduced.
+  Builder.SetInsertPoint(UI->getParent());
+  if (ifCond)
+    UI->getParent()->setName("omp_if.end");
+  UI->eraseFromParent();
+  return Builder.saveIP();
+}
+
 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
                                                    StringRef FirstSeparator,
                                                    StringRef Separator) {
@@ -4085,6 +4156,18 @@
   return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
 }
 
+llvm::Value *OpenMPIRBuilder::getSizeInBytes(llvm::Value *basePtr) {
+  // "sizeof" idiom: address of element 1 of a null array of basePtr's type,
+  // converted to i64. NOTE: this measures the type of basePtr itself (for a
+  // pointer value, the pointer width), not the type it points to.
+  llvm::Type *Ty = basePtr->getType();
+  llvm::Value *null = llvm::Constant::getNullValue(Ty->getPointerTo());
+  llvm::Value *sizeGep = Builder.CreateGEP(Ty, null, Builder.getInt32(1));
+  return Builder.CreatePtrToInt(sizeGep, Builder.getInt64Ty());
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
                                        std::string VarName) {
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4912,6 +4912,84 @@
EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
 }
 
+TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  unsigned numDataOperands = 1;
+  int64_t deviceID = 2;
+  struct OpenMPIRBuilder::MapperAllocas mapperAllocas;
+  SmallVector<uint64_t> mapTypeflags = {5};
+  SmallVector<Constant *> mapNames;
+  auto *i8PtrTy = Builder.getInt8PtrTy();
+  auto *arrI8PtrTy = ArrayType::get(i8PtrTy, numDataOperands);
+  auto *i64Ty = Builder.getInt64Ty();
+  auto *arrI64Ty = ArrayType::get(i64Ty, numDataOperands);
+
+  AllocaInst *Val1 =
+      Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1));
+  ASSERT_NE(Val1, nullptr);
+
+  IRBuilder<>::InsertPoint allocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  OMPBuilder.createMapperAllocas(Builder.saveIP(), allocaIP, numDataOperands,
+                                 mapperAllocas);
+
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  // Stores Val1 as base pointer and pointer of the single map operand, plus
+  // its size, into the mapper arrays (mirrors the MLIR translation).
+  auto processMapOpCB = [&](InsertPointTy /*AllocaIP*/,
+                            InsertPointTy codeGenIP) {
+    Value *dataValue = Val1;
+    Value *dataPtrBase = dataValue;
+    Value *dataPtr = dataValue;
+    Builder.restoreIP(codeGenIP);
+
+    Value *null = Constant::getNullValue(dataValue->getType()->getPointerTo());
+    Value *sizeGep =
+        Builder.CreateGEP(dataValue->getType(), null, Builder.getInt32(1));
+    Value *sizePtrToInt = Builder.CreatePtrToInt(sizeGep, i64Ty);
+
+    Value *ptrBaseGEP =
+        Builder.CreateInBoundsGEP(arrI8PtrTy, mapperAllocas.ArgsBase,
+                                  {Builder.getInt32(0), Builder.getInt32(0)});
+    Value *ptrBaseCast = Builder.CreateBitCast(
+        ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
+    Builder.CreateStore(dataPtrBase, ptrBaseCast);
+    Value *ptrGEP =
+        Builder.CreateInBoundsGEP(arrI8PtrTy, mapperAllocas.Args,
+                                  {Builder.getInt32(0), Builder.getInt32(0)});
+    Value *ptrCast =
+        Builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
+    Builder.CreateStore(dataPtr, ptrCast);
+    Value *sizeGEP =
+        Builder.CreateInBoundsGEP(arrI64Ty, mapperAllocas.ArgSizes,
+                                  {Builder.getInt32(0), Builder.getInt32(0)});
+    Builder.CreateStore(sizePtrToInt, sizeGEP);
+  };
+  // enter data has no region, so the body callback is never invoked.
+  auto bodyCB = [&](InsertPointTy, InsertPointTy) {};
+
+  Builder.restoreIP(OMPBuilder.createTargetData(
+      Loc, Builder.saveIP(), /*hasRegion=*/false, mapTypeflags, mapNames,
+      mapperAllocas,
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(
+          OMPRTL___tgt_target_data_begin_mapper),
+      deviceID, /*ifCond=*/nullptr, processMapOpCB, bodyCB));
+
+  CallInst *targetDataCall = dyn_cast<CallInst>(&BB->back());
+  // Fatal: every check below dereferences the call.
+  ASSERT_NE(targetDataCall, nullptr);
+  EXPECT_EQ(targetDataCall->arg_size(), 9U);
+  EXPECT_EQ(targetDataCall->getCalledFunction()->getName(),
+            "__tgt_target_data_begin_mapper");
+  EXPECT_TRUE(targetDataCall->getOperand(1)->getType()->isIntegerTy(64));
+  EXPECT_TRUE(targetDataCall->getOperand(2)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(targetDataCall->getOperand(8)->getType()->isPointerTy());
+
+  // The device clause value and the operand count are forwarded verbatim.
+  auto *deviceIDArg = dyn_cast<ConstantInt>(targetDataCall->getOperand(1));
+  ASSERT_NE(deviceIDArg, nullptr);
+  EXPECT_EQ(deviceIDArg->getSExtValue(), deviceID);
+  auto *numOperandsArg = dyn_cast<ConstantInt>(targetDataCall->getOperand(2));
+  ASSERT_NE(numOperandsArg, nullptr);
+  EXPECT_EQ(numOperandsArg->getZExtValue(), numDataOperands);
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTask) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/Utils.h b/mlir/include/mlir/Target/LLVMIR/Dialect/Utils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/Utils.h
@@ -0,0 +1,58 @@
+//===- Utils.h - General Utilities for translating MLIR dialect to LLVM IR-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines general utilities for MLIR Dialect translations to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_UTILS_H
+#define MLIR_TARGET_LLVMIR_DIALECT_UTILS_H
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
+#include "llvm/IR/IRBuilder.h"
+
+namespace mlir {
+namespace utils {
+
+/// Create a constant source-location string for \p loc, labelled \p name.
+inline llvm::Constant *
+createSourceLocStrFromLocation(Location loc, llvm::OpenMPIRBuilder &builder,
+                               StringRef name, uint32_t &strLen) {
+  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+    StringRef fileName = fileLoc.getFilename();
+    unsigned lineNo = fileLoc.getLine();
+    unsigned colNo = fileLoc.getColumn();
+    return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo, strLen);
+  }
+  std::string locStr;
+  llvm::raw_string_ostream locOS(locStr);
+  locOS << loc;
+  return builder.getOrCreateSrcLocStr(locOS.str(), strLen);
+}
+
+/// Create a constant string naming a mapped value from its MLIR location; the
+/// name comes from a NameLoc if there is one, otherwise it is "unknown".
+inline llvm::Constant *
+createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder) {
+  uint32_t strLen;
+  if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
+    StringRef name = nameLoc.getName();
+    return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name,
+                                          strLen);
+  }
+  return createSourceLocStrFromLocation(loc, builder, "unknown", strLen);
+}
+
+} // namespace utils
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_UTILS_H
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/Dialect/Utils.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
 #include "llvm/ADT/TypeSwitch.h"
@@ -46,23 +47,6 @@
 /// Default value for the device id
 static constexpr int64_t kDefaultDevice = -1;
 
-/// Create a constant string location from the MLIR Location information.
-static llvm::Constant *createSourceLocStrFromLocation(Location loc,
-                                                      OpenACCIRBuilder &builder,
-                                                      StringRef name,
-                                                      uint32_t &strLen) {
-  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
-    StringRef fileName = fileLoc.getFilename();
-    unsigned lineNo = fileLoc.getLine();
-    unsigned colNo = fileLoc.getColumn();
-    return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo, strLen);
-  }
-  std::string locStr;
-  llvm::raw_string_ostream locOS(locStr);
-  locOS << loc;
-  return builder.getOrCreateSrcLocStr(locOS.str(), strLen);
-}
-
 /// Create the location struct from the operation location information.
 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
                                              Operation *op) {
@@ -70,24 +54,11 @@
   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
   uint32_t strLen;
-  llvm::Constant *locStr =
-      createSourceLocStrFromLocation(loc, builder, funcName, strLen);
+  llvm::Constant *locStr = mlir::utils::createSourceLocStrFromLocation(
+      loc, builder, funcName, strLen);
   return builder.getOrCreateIdent(locStr, strLen);
 }
 
-/// Create a constant string representing the mapping information extracted from
-/// the MLIR location information.
-static llvm::Constant *createMappingInformation(Location loc,
-                                                OpenACCIRBuilder &builder) {
-  uint32_t strLen;
-  if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
-    StringRef name = nameLoc.getName();
-    return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name,
-                                          strLen);
-  }
-  return createSourceLocStrFromLocation(loc, builder, "unknown", strLen);
-}
-
 /// Return the runtime function used to lower the given operation.
 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
                                              Operation *op) {
@@ -107,19 +78,6 @@
   llvm_unreachable("Unknown OpenACC operation");
 }
 
-/// Computes the size of type in bytes.
-static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
-                                   llvm::Value *basePtr) {
-  llvm::LLVMContext &ctx = builder.getContext();
-  llvm::Value *null =
-      llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
-  llvm::Value *sizeGep =
-      builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
-  llvm::Value *sizePtrToInt =
-      builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
-  return sizePtrToInt;
-}
-
 /// Extract pointer, size and mapping information from operands
 /// to populate the future functions arguments.
 static LogicalResult
@@ -153,7 +111,7 @@
     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
       dataPtrBase = dataValue;
       dataPtr = dataValue;
-      dataSize = getSizeInBytes(builder, dataValue);
+      dataSize = accBuilder->getSizeInBytes(dataValue);
     } else {
       return op->emitOpError()
              << "Data operand must be legalized before translation."
@@ -185,7 +143,7 @@
     flags.push_back(operandFlag);
 
     llvm::Constant *mapName =
-        createMappingInformation(data.getLoc(), *accBuilder);
+        mlir::utils::createMappingInformation(data.getLoc(), *accBuilder);
     names.push_back(mapName);
     ++index;
   }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -15,10 +15,12 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/Dialect/Utils.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/IRBuilder.h"
@@ -1338,6 +1340,187 @@
   return success();
 }
 
+/// Populates the offloading argument arrays for the map operands of a target
+/// data construct.
+///
+/// For the i-th operand this stores its base pointer and pointer into slot i of
+/// `mapperAllocas.ArgsBase` / `mapperAllocas.Args`, its size into slot i of
+/// `mapperAllocas.ArgSizes`, and appends its map-type flag to `mapTypeFlags`
+/// and a location-derived name string to `mapNames`. Reports an error at the
+/// operand's location and returns failure for an operand that is not an LLVM
+/// pointer or whose map type is not an integer attribute.
+static LogicalResult processMapOperand(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    ValueRange mapOperands, ArrayAttr mapTypes,
+    SmallVectorImpl<uint64_t> &mapTypeFlags,
+    SmallVectorImpl<llvm::Constant *> &mapNames,
+    struct llvm::OpenMPIRBuilder::MapperAllocas &mapperAllocas) {
+  unsigned numMapOperands = mapOperands.size();
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::PointerType *i8PtrTy = builder.getInt8PtrTy();
+  llvm::ArrayType *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, numMapOperands);
+  llvm::IntegerType *i64Ty = builder.getInt64Ty();
+  llvm::ArrayType *arrI64Ty = llvm::ArrayType::get(i64Ty, numMapOperands);
+
+  unsigned index = 0;
+  for (Value mapOp : mapOperands) {
+    if (!mapOp.getType().isa<LLVM::LLVMPointerType>())
+      return emitError(mapOp.getLoc(), "map operand must be an LLVM pointer");
+    auto mapTypeAttr = mapTypes[index].dyn_cast<IntegerAttr>();
+    if (!mapTypeAttr)
+      return emitError(mapOp.getLoc(), "map type must be an integer attribute");
+
+    // Only whole-pointer operands are supported, so base pointer and pointer
+    // are the same value.
+    llvm::Value *mapOpValue = moduleTranslation.lookupValue(mapOp);
+    llvm::Value *mapOpPtrBase = mapOpValue;
+    llvm::Value *mapOpPtr = mapOpValue;
+    // NOTE: getSizeInBytes emits through the OpenMPIRBuilder's own IRBuilder,
+    // not `builder`; this relies on createTargetData invoking this callback
+    // with that IRBuilder positioned at the same insertion point.
+    llvm::Value *mapOpSize = ompBuilder->getSizeInBytes(mapOpValue);
+
+    // Store base pointer extracted from operand into the i-th position of
+    // argBase.
+    llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
+        arrI8PtrTy, mapperAllocas.ArgsBase,
+        {builder.getInt32(0), builder.getInt32(index)});
+    llvm::Value *ptrBaseCast = builder.CreateBitCast(
+        ptrBaseGEP, mapOpPtrBase->getType()->getPointerTo());
+    builder.CreateStore(mapOpPtrBase, ptrBaseCast);
+
+    // Store pointer extracted from operand into the i-th position of args.
+    llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
+        arrI8PtrTy, mapperAllocas.Args,
+        {builder.getInt32(0), builder.getInt32(index)});
+    llvm::Value *ptrCast =
+        builder.CreateBitCast(ptrGEP, mapOpPtr->getType()->getPointerTo());
+    builder.CreateStore(mapOpPtr, ptrCast);
+
+    // Store size extracted from operand into the i-th position of argSizes.
+    llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
+        arrI64Ty, mapperAllocas.ArgSizes,
+        {builder.getInt32(0), builder.getInt32(index)});
+    builder.CreateStore(mapOpSize, sizeGEP);
+
+    mapTypeFlags.push_back(mapTypeAttr.getInt());
+    llvm::Constant *mapName =
+        mlir::utils::createMappingInformation(mapOp.getLoc(), *ompBuilder);
+    mapNames.push_back(mapName);
+    ++index;
+  }
+
+  return success();
+}
+
+/// Translates omp::DataOp, omp::EnterDataOp and omp::ExitDataOp through
+/// OpenMPIRBuilder::createTargetData.
+///
+/// omp::DataOp becomes a data region (begin/end mapper calls around the
+/// translated region); omp::EnterDataOp / omp::ExitDataOp become a single call
+/// to __tgt_target_data_begin_mapper / __tgt_target_data_end_mapper.
+static LogicalResult
+convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  // Only set for the standalone (region-less) operations.
+  llvm::Function *mapperFunc = nullptr;
+  llvm::Value *ifCond = nullptr;
+  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
+  ValueRange mapOperands;
+  ArrayAttr mapTypes;
+
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  // Reads the if/device/map clauses shared by all three operations.
+  auto readDataClauses = [&](auto dataOp) {
+    if (Value ifExprVar = dataOp.getIfExpr())
+      ifCond = moduleTranslation.lookupValue(ifExprVar);
+
+    // Only a device number produced by an integer LLVM::ConstantOp is used;
+    // any other device expression currently keeps OMP_DEVICEID_UNDEF.
+    // getDefiningOp<> also copes with device values that are block arguments.
+    if (Value devId = dataOp.getDevice())
+      if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
+        if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
+          deviceID = intAttr.getInt();
+
+    mapOperands = dataOp.getMapOperands();
+    mapTypes = dataOp.getMapTypes();
+  };
+
+  LogicalResult result =
+      llvm::TypeSwitch<Operation *, LogicalResult>(op)
+          .Case([&](omp::DataOp dataOp) {
+            readDataClauses(dataOp);
+            return success();
+          })
+          .Case([&](omp::EnterDataOp enterDataOp) -> LogicalResult {
+            if (enterDataOp.getNowait())
+              return enterDataOp.emitError("`nowait` is not supported yet");
+            readDataClauses(enterDataOp);
+            mapperFunc = ompBuilder->getOrCreateRuntimeFunctionPtr(
+                llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
+            return success();
+          })
+          .Case([&](omp::ExitDataOp exitDataOp) -> LogicalResult {
+            if (exitDataOp.getNowait())
+              return exitDataOp.emitError("`nowait` is not supported yet");
+            readDataClauses(exitDataOp);
+            mapperFunc = ompBuilder->getOrCreateRuntimeFunctionPtr(
+                llvm::omp::OMPRTL___tgt_target_data_end_mapper);
+            return success();
+          })
+          .Default([&](Operation *op) {
+            return op->emitError("unsupported OpenMP operation: ")
+                   << op->getName();
+          });
+
+  if (failed(result))
+    return failure();
+
+  // processMapOperand indexes mapTypes by operand position.
+  if (mapTypes.size() != mapOperands.size())
+    return op->emitError("expected one map type per map operand");
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+
+  struct llvm::OpenMPIRBuilder::MapperAllocas mapperAllocas;
+  SmallVector<uint64_t> mapTypeFlags;
+  SmallVector<llvm::Constant *> mapNames;
+  ompBuilder->createMapperAllocas(builder.saveIP(), allocaIP,
+                                  mapOperands.size(), mapperAllocas);
+
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  LogicalResult processMapOpStatus = success();
+  auto processMapOpCB = [&](InsertPointTy /*allocaIP*/,
+                            InsertPointTy codeGenIP) {
+    builder.restoreIP(codeGenIP);
+    processMapOpStatus =
+        processMapOperand(builder, moduleTranslation, mapOperands, mapTypes,
+                          mapTypeFlags, mapNames, mapperAllocas);
+  };
+
+  LogicalResult bodyGenStatus = success();
+  auto bodyCB = [&](InsertPointTy /*allocaIP*/, InsertPointTy codeGenIP) {
+    // Only omp::DataOp has a region, and hasRegion below is true only for it.
+    auto &region = cast<omp::DataOp>(op).getRegion();
+    builder.restoreIP(codeGenIP);
+    convertOmpOpRegions(region, "omp.data.region", builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  builder.restoreIP(ompBuilder->createTargetData(
+      ompLoc, builder.saveIP(), isa<omp::DataOp>(op), mapTypeFlags, mapNames,
+      mapperAllocas, mapperFunc, deviceID, ifCond, processMapOpCB, bodyCB));
+
+  if (failed(processMapOpStatus))
+    return processMapOpStatus;
+  return bodyGenStatus;
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1452,6 +1635,9 @@
       .Case([&](omp::ThreadprivateOp) {
         return convertOmpThreadprivate(*op, builder, moduleTranslation);
       })
+      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
+        return convertOmpTargetData(op, builder, moduleTranslation);
+      })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();