diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.h @@ -0,0 +1,32 @@ +//===- DirectiveToLLVMIRTranslation.h - Utils used by OpenMP and OpenACC -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities and common LLVMIR translation patterns for +// OpenMP and OpenACC operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UTILS_COMMONOPENMPOPENACC_H +#define MLIR_DIALECT_UTILS_COMMONOPENMPOPENACC_H + +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +namespace mlir { + +LogicalResult convertDataOp(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +template +LogicalResult +convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +} // namespace mlir + +#endif // MLIR_DIALECT_UTILS_COMMONOPENMPOPENACC_H \ No newline at end of file diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(ROCDL) +add_subdirectory(Utils) add_subdirectory(X86Vector) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt @@ -1,10 +1,11 @@ -add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation +add_mlir_dialect_library(MLIROpenACCToLLVMIRTranslation OpenACCToLLVMIRTranslation.cpp LINK_COMPONENTS Core LINK_LIBS PUBLIC + MLIRDirectiveToLLVMIRTranslation MLIRIR MLIRLLVMDialect MLIROpenACCDialect diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "llvm/ADT/TypeSwitch.h" @@ -28,456 +29,6 @@ using OpenACCIRBuilder = llvm::OpenMPIRBuilder; -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - -/// Flag values are extracted from openmp/libomptarget/include/omptarget.h and -/// mapped to corresponding OpenACC flags. -static constexpr uint64_t kCreateFlag = 0x000; -static constexpr uint64_t kDeviceCopyinFlag = 0x001; -static constexpr uint64_t kHostCopyoutFlag = 0x002; -static constexpr uint64_t kCopyFlag = kDeviceCopyinFlag | kHostCopyoutFlag; -static constexpr uint64_t kPresentFlag = 0x1000; -static constexpr uint64_t kDeleteFlag = 0x008; -// Runtime extension to implement the OpenACC second reference counter. -static constexpr uint64_t kHoldFlag = 0x2000; - -/// Default value for the device id -static constexpr int64_t kDefaultDevice = -1; - -/// Create a constant string location from the MLIR Location information. -static llvm::Constant *createSourceLocStrFromLocation(Location loc, - OpenACCIRBuilder &builder, - StringRef name, - uint32_t &strLen) { - if (auto fileLoc = loc.dyn_cast()) { - StringRef fileName = fileLoc.getFilename(); - unsigned lineNo = fileLoc.getLine(); - unsigned colNo = fileLoc.getColumn(); - return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo, strLen); - } - std::string locStr; - llvm::raw_string_ostream locOS(locStr); - locOS << loc; - return builder.getOrCreateSrcLocStr(locOS.str(), strLen); -} - -/// Create the location struct from the operation location information. -static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, - Operation *op) { - auto loc = op->getLoc(); - auto funcOp = op->getParentOfType(); - StringRef funcName = funcOp ? funcOp.getName() : "unknown"; - uint32_t strLen; - llvm::Constant *locStr = - createSourceLocStrFromLocation(loc, builder, funcName, strLen); - return builder.getOrCreateIdent(locStr, strLen); -} - -/// Create a constant string representing the mapping information extracted from -/// the MLIR location information. -static llvm::Constant *createMappingInformation(Location loc, - OpenACCIRBuilder &builder) { - uint32_t strLen; - if (auto nameLoc = loc.dyn_cast()) { - StringRef name = nameLoc.getName(); - return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name, - strLen); - } - return createSourceLocStrFromLocation(loc, builder, "unknown", strLen); -} - -/// Return the runtime function used to lower the given operation. -static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, - Operation *op) { - return llvm::TypeSwitch(op) - .Case([&](acc::EnterDataOp) { - return builder.getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_begin_mapper); - }) - .Case([&](acc::ExitDataOp) { - return builder.getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_end_mapper); - }) - .Case([&](acc::UpdateOp) { - return builder.getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_update_mapper); - }); - llvm_unreachable("Unknown OpenACC operation"); -} - -/// Computes the size of type in bytes. -static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, - llvm::Value *basePtr) { - llvm::LLVMContext &ctx = builder.getContext(); - llvm::Value *null = - llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); - llvm::Value *sizeGep = - builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); - llvm::Value *sizePtrToInt = - builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); - return sizePtrToInt; -} - -/// Extract pointer, size and mapping information from operands -/// to populate the future functions arguments. -static LogicalResult -processOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, Operation *op, - ValueRange operands, unsigned totalNbOperand, - uint64_t operandFlag, SmallVector &flags, - SmallVectorImpl &names, unsigned &index, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::LLVMContext &ctx = builder.getContext(); - auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); - auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); - auto *i64Ty = llvm::Type::getInt64Ty(ctx); - auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); - - for (Value data : operands) { - llvm::Value *dataValue = moduleTranslation.lookupValue(data); - - llvm::Value *dataPtrBase; - llvm::Value *dataPtr; - llvm::Value *dataSize; - - // Handle operands that were converted to DataDescriptor. - if (DataDescriptor::isValid(data)) { - dataPtrBase = - builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); - dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); - dataSize = - builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); - } else if (data.getType().isa()) { - dataPtrBase = dataValue; - dataPtr = dataValue; - dataSize = getSizeInBytes(builder, dataValue); - } else { - return op->emitOpError() - << "Data operand must be legalized before translation." - << "Unsupported type: " << data.getType(); - } - - // Store base pointer extracted from operand into the i-th position of - // argBase. - llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( - arrI8PtrTy, mapperAllocas.ArgsBase, - {builder.getInt32(0), builder.getInt32(index)}); - llvm::Value *ptrBaseCast = builder.CreateBitCast( - ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); - builder.CreateStore(dataPtrBase, ptrBaseCast); - - // Store pointer extracted from operand into the i-th position of args. - llvm::Value *ptrGEP = builder.CreateInBoundsGEP( - arrI8PtrTy, mapperAllocas.Args, - {builder.getInt32(0), builder.getInt32(index)}); - llvm::Value *ptrCast = - builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); - builder.CreateStore(dataPtr, ptrCast); - - // Store size extracted from operand into the i-th position of argSizes. - llvm::Value *sizeGEP = builder.CreateInBoundsGEP( - arrI64Ty, mapperAllocas.ArgSizes, - {builder.getInt32(0), builder.getInt32(index)}); - builder.CreateStore(dataSize, sizeGEP); - - flags.push_back(operandFlag); - llvm::Constant *mapName = - createMappingInformation(data.getLoc(), *accBuilder); - names.push_back(mapName); - ++index; - } - return success(); -} - -/// Process data operands from acc::EnterDataOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::EnterDataOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { - // TODO add `create_zero` and `attach` operands - - unsigned index = 0; - - // Create operands are handled as `alloc` call. - if (failed(processOperands(builder, moduleTranslation, op, - op.createOperands(), op.getNumDataOperands(), - kCreateFlag, flags, names, index, mapperAllocas))) - return failure(); - - // Copyin operands are handled as `to` call. - if (failed(processOperands(builder, moduleTranslation, op, - op.copyinOperands(), op.getNumDataOperands(), - kDeviceCopyinFlag, flags, names, index, - mapperAllocas))) - return failure(); - - return success(); -} - -/// Process data operands from acc::ExitDataOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::ExitDataOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { - // TODO add `detach` operands - - unsigned index = 0; - - // Delete operands are handled as `delete` call. - if (failed(processOperands(builder, moduleTranslation, op, - op.deleteOperands(), op.getNumDataOperands(), - kDeleteFlag, flags, names, index, mapperAllocas))) - return failure(); - - // Copyout operands are handled as `from` call. - if (failed(processOperands(builder, moduleTranslation, op, - op.copyoutOperands(), op.getNumDataOperands(), - kHostCopyoutFlag, flags, names, index, - mapperAllocas))) - return failure(); - - return success(); -} - -/// Process data operands from acc::UpdateOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::UpdateOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { - unsigned index = 0; - - // Host operands are handled as `from` call. - if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(), - op.getNumDataOperands(), kHostCopyoutFlag, flags, - names, index, mapperAllocas))) - return failure(); - - // Device operands are handled as `to` call. - if (failed(processOperands(builder, moduleTranslation, op, - op.deviceOperands(), op.getNumDataOperands(), - kDeviceCopyinFlag, flags, names, index, - mapperAllocas))) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// Conversion functions -//===----------------------------------------------------------------------===// - -/// Converts an OpenACC data operation into LLVM IR. -static LogicalResult convertDataOp(acc::DataOp &op, - llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - llvm::LLVMContext &ctx = builder.getContext(); - auto enclosingFuncOp = op.getOperation()->getParentOfType(); - llvm::Function *enclosingFunction = - moduleTranslation.lookupFunction(enclosingFuncOp.getName()); - - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); - - llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op); - - llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_begin_mapper); - - llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( - llvm::omp::OMPRTL___tgt_target_data_end_mapper); - - // Number of arguments in the data operation. - unsigned totalNbOperand = op.getNumDataOperands(); - - struct OpenACCIRBuilder::MapperAllocas mapperAllocas; - OpenACCIRBuilder::InsertPointTy allocaIP( - &enclosingFunction->getEntryBlock(), - enclosingFunction->getEntryBlock().getFirstInsertionPt()); - accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, - mapperAllocas); - - SmallVector flags; - SmallVector names; - unsigned index = 0; - - // TODO handle no_create, deviceptr and attach operands. - - if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(), - totalNbOperand, kCopyFlag | kHoldFlag, flags, - names, index, mapperAllocas))) - return failure(); - - if (failed(processOperands( - builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand, - kDeviceCopyinFlag | kHoldFlag, flags, names, index, mapperAllocas))) - return failure(); - - // TODO copyin readonly currenlty handled as copyin. Update when extension - // available. - if (failed(processOperands(builder, moduleTranslation, op, - op.copyinReadonlyOperands(), totalNbOperand, - kDeviceCopyinFlag | kHoldFlag, flags, names, index, - mapperAllocas))) - return failure(); - - if (failed(processOperands( - builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand, - kHostCopyoutFlag | kHoldFlag, flags, names, index, mapperAllocas))) - return failure(); - - // TODO copyout zero currenlty handled as copyout. Update when extension - // available. - if (failed(processOperands(builder, moduleTranslation, op, - op.copyoutZeroOperands(), totalNbOperand, - kHostCopyoutFlag | kHoldFlag, flags, names, index, - mapperAllocas))) - return failure(); - - if (failed(processOperands(builder, moduleTranslation, op, - op.createOperands(), totalNbOperand, - kCreateFlag | kHoldFlag, flags, names, index, - mapperAllocas))) - return failure(); - - // TODO create zero currenlty handled as create. Update when extension - // available. - if (failed(processOperands(builder, moduleTranslation, op, - op.createZeroOperands(), totalNbOperand, - kCreateFlag | kHoldFlag, flags, names, index, - mapperAllocas))) - return failure(); - - if (failed(processOperands(builder, moduleTranslation, op, - op.presentOperands(), totalNbOperand, - kPresentFlag | kHoldFlag, flags, names, index, - mapperAllocas))) - return failure(); - - llvm::GlobalVariable *maptypes = - accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); - llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( - llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), - maptypes, /*Idx0=*/0, /*Idx1=*/0); - - llvm::GlobalVariable *mapnames = - accBuilder->createOffloadMapnames(names, ".offload_mapnames"); - llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( - llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), - mapnames, /*Idx0=*/0, /*Idx1=*/0); - - // Create call to start the data region. - accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo, - maptypesArg, mapnamesArg, mapperAllocas, - kDefaultDevice, totalNbOperand); - - // Convert the region. - llvm::BasicBlock *entryBlock = nullptr; - - for (Block &bb : op.region()) { - llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( - ctx, "acc.data", builder.GetInsertBlock()->getParent()); - if (entryBlock == nullptr) - entryBlock = llvmBB; - moduleTranslation.mapBlock(&bb, llvmBB); - } - - auto afterDataRegion = builder.saveIP(); - - llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock); - - builder.restoreIP(afterDataRegion); - llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create( - ctx, "acc.end_data", builder.GetInsertBlock()->getParent()); - - SetVector blocks = - LLVM::detail::getTopologicallySortedBlocks(op.region()); - for (Block *bb : blocks) { - llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb); - if (bb->isEntryBlock()) { - assert(sourceTerminator->getNumSuccessors() == 1 && - "provided entry block has multiple successors"); - sourceTerminator->setSuccessor(0, llvmBB); - } - - if (failed( - moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) { - return failure(); - } - - if (isa(bb->getTerminator())) - builder.CreateBr(endDataBlock); - } - - // Create call to end the data region. - builder.SetInsertPoint(endDataBlock); - accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo, - maptypesArg, mapnamesArg, mapperAllocas, - kDefaultDevice, totalNbOperand); - - return success(); -} - -/// Converts an OpenACC standalone data operation into LLVM IR. -template -static LogicalResult -convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - auto enclosingFuncOp = - op.getOperation()->template getParentOfType(); - llvm::Function *enclosingFunction = - moduleTranslation.lookupFunction(enclosingFuncOp.getName()); - - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); - - auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); - auto *mapperFunc = getAssociatedFunction(*accBuilder, op); - - // Number of arguments in the enter_data operation. - unsigned totalNbOperand = op.getNumDataOperands(); - - llvm::LLVMContext &ctx = builder.getContext(); - - struct OpenACCIRBuilder::MapperAllocas mapperAllocas; - OpenACCIRBuilder::InsertPointTy allocaIP( - &enclosingFunction->getEntryBlock(), - enclosingFunction->getEntryBlock().getFirstInsertionPt()); - accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, - mapperAllocas); - - SmallVector flags; - SmallVector names; - - if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, - mapperAllocas))) - return failure(); - - llvm::GlobalVariable *maptypes = - accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); - llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( - llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), - maptypes, /*Idx0=*/0, /*Idx1=*/0); - - llvm::GlobalVariable *mapnames = - accBuilder->createOffloadMapnames(names, ".offload_mapnames"); - llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( - llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), - mapnames, /*Idx0=*/0, /*Idx1=*/0); - - accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo, - maptypesArg, mapnamesArg, mapperAllocas, - kDefaultDevice, totalNbOperand); - - return success(); -} - namespace { /// Implementation of the dialect interface that converts operations belonging @@ -504,7 +55,7 @@ return llvm::TypeSwitch(op) .Case([&](acc::DataOp dataOp) { - return convertDataOp(dataOp, builder, moduleTranslation); + return convertDataOp(op, builder, moduleTranslation); }) .Case([&](acc::EnterDataOp enterDataOp) { return convertStandaloneDataOp(enterDataOp, builder, diff --git a/mlir/lib/Target/LLVMIR/Dialect/Utils/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/Utils/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/Utils/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(MLIRDirectiveToLLVMIRTranslation +DirectiveToLLVMIRTranslation.cpp + +ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation + +LINK_COMPONENTS +Core + +LINK_LIBS PUBLIC +MLIRIR +MLIRLLVMDialect +MLIROpenACCDialect +MLIROpenACCToLLVM +MLIROpenMPDialect +MLIRSupport +MLIRTargetLLVMIRExport +) \ No newline at end of file diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.cpp copy from mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp copy to mlir/lib/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.cpp @@ -1,23 +1,19 @@ -//===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// +//===- DirectiveToLLVMIRTranslation.cpp - Utils used by OpenMP and OpenACC ==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file implements a translation between the MLIR OpenACC dialect and LLVM -// IR. -// -//===----------------------------------------------------------------------===// -#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/Utils/DirectiveToLLVMIRTranslation.h" #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "llvm/ADT/TypeSwitch.h" @@ -26,12 +22,6 @@ using namespace mlir; -using OpenACCIRBuilder = llvm::OpenMPIRBuilder; - -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - /// Flag values are extracted from openmp/libomptarget/include/omptarget.h and /// mapped to corresponding OpenACC flags. static constexpr uint64_t kCreateFlag = 0x000; @@ -47,10 +37,9 @@ static constexpr int64_t kDefaultDevice = -1; /// Create a constant string location from the MLIR Location information. -static llvm::Constant *createSourceLocStrFromLocation(Location loc, - OpenACCIRBuilder &builder, - StringRef name, - uint32_t &strLen) { +static llvm::Constant * +createSourceLocStrFromLocation(Location loc, llvm::OpenMPIRBuilder &builder, + StringRef name, uint32_t &strLen) { if (auto fileLoc = loc.dyn_cast()) { StringRef fileName = fileLoc.getFilename(); unsigned lineNo = fileLoc.getLine(); @@ -64,7 +53,7 @@ } /// Create the location struct from the operation location information. -static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, +static llvm::Value *createSourceLocationInfo(llvm::OpenMPIRBuilder &builder, Operation *op) { auto loc = op->getLoc(); auto funcOp = op->getParentOfType(); @@ -77,8 +66,8 @@ /// Create a constant string representing the mapping information extracted from /// the MLIR location information. -static llvm::Constant *createMappingInformation(Location loc, - OpenACCIRBuilder &builder) { +static llvm::Constant * +createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder) { uint32_t strLen; if (auto nameLoc = loc.dyn_cast()) { StringRef name = nameLoc.getName(); @@ -89,7 +78,7 @@ } /// Return the runtime function used to lower the given operation. -static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, +static llvm::Function *getAssociatedFunction(llvm::OpenMPIRBuilder &builder, Operation *op) { return llvm::TypeSwitch(op) .Case([&](acc::EnterDataOp) { @@ -128,15 +117,15 @@ ValueRange operands, unsigned totalNbOperand, uint64_t operandFlag, SmallVector &flags, SmallVectorImpl &names, unsigned &index, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); + struct llvm::OpenMPIRBuilder::MapperAllocas &mapperAllocas) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::LLVMContext &ctx = builder.getContext(); auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); auto *i64Ty = llvm::Type::getInt64Ty(ctx); auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); - for (Value data : operands) { + for (mlir::Value data : operands) { llvm::Value *dataValue = moduleTranslation.lookupValue(data); llvm::Value *dataPtrBase; @@ -185,7 +174,7 @@ flags.push_back(operandFlag); llvm::Constant *mapName = - createMappingInformation(data.getLoc(), *accBuilder); + createMappingInformation(data.getLoc(), *ompBuilder); names.push_back(mapName); ++index; } @@ -193,12 +182,11 @@ } /// Process data operands from acc::EnterDataOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::EnterDataOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { +static LogicalResult processDataOperands( + llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, + acc::EnterDataOp op, SmallVector &flags, + SmallVectorImpl &names, + struct llvm::OpenMPIRBuilder::MapperAllocas &mapperAllocas) { // TODO add `create_zero` and `attach` operands unsigned index = 0; @@ -220,12 +208,11 @@ } /// Process data operands from acc::ExitDataOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::ExitDataOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { +static LogicalResult processDataOperands( + llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, + acc::ExitDataOp op, SmallVector &flags, + SmallVectorImpl &names, + struct llvm::OpenMPIRBuilder::MapperAllocas &mapperAllocas) { // TODO add `detach` operands unsigned index = 0; @@ -247,12 +234,11 @@ } /// Process data operands from acc::UpdateOp -static LogicalResult -processDataOperands(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - acc::UpdateOp op, SmallVector &flags, - SmallVectorImpl &names, - struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { +static LogicalResult processDataOperands( + llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, + acc::UpdateOp op, SmallVector &flags, + SmallVectorImpl &names, + struct llvm::OpenMPIRBuilder::MapperAllocas &mapperAllocas) { unsigned index = 0; // Host operands are handled as `from` call. @@ -276,32 +262,47 @@ //===----------------------------------------------------------------------===// /// Converts an OpenACC data operation into LLVM IR. -static LogicalResult convertDataOp(acc::DataOp &op, - llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +LogicalResult mlir::convertDataOp(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::LLVMContext &ctx = builder.getContext(); - auto enclosingFuncOp = op.getOperation()->getParentOfType(); + auto enclosingFuncOp = op->getParentOfType(); llvm::Function *enclosingFunction = moduleTranslation.lookupFunction(enclosingFuncOp.getName()); - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op); + llvm::Value *srcLocInfo = createSourceLocationInfo(*ompBuilder, op); - llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( + llvm::Function *beginMapperFunc = ompBuilder->getOrCreateRuntimeFunctionPtr( llvm::omp::OMPRTL___tgt_target_data_begin_mapper); - llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( + llvm::Function *endMapperFunc = ompBuilder->getOrCreateRuntimeFunctionPtr( llvm::omp::OMPRTL___tgt_target_data_end_mapper); // Number of arguments in the data operation. - unsigned totalNbOperand = op.getNumDataOperands(); + unsigned totalNbOperand = 0; + ValueRange copyOperands, copyinOperands, copyinReadonlyOperands; + ValueRange copyoutOperands, copyoutZeroOperands, createOperands; + ValueRange createZeroOperands, presentOperands; + + if(isa(op)){ + acc::DataOp dataOp = dyn_cast(op); + totalNbOperand = dataOp.getNumDataOperands(); + copyOperands = dataOp.copyOperands(); + copyinOperands = dataOp.copyinOperands(); + copyinReadonlyOperands = dataOp.copyinReadonlyOperands(); + copyoutOperands = dataOp.copyoutOperands(); + copyoutZeroOperands = dataOp.copyoutZeroOperands(); + createOperands = dataOp.createOperands(); + createZeroOperands = dataOp.createZeroOperands(); + presentOperands = dataOp.presentOperands(); + } - struct OpenACCIRBuilder::MapperAllocas mapperAllocas; - OpenACCIRBuilder::InsertPointTy allocaIP( + struct llvm::OpenMPIRBuilder::MapperAllocas mapperAllocas; + llvm::OpenMPIRBuilder::InsertPointTy allocaIP( &enclosingFunction->getEntryBlock(), enclosingFunction->getEntryBlock().getFirstInsertionPt()); - accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, + ompBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, mapperAllocas); SmallVector flags; @@ -310,78 +311,76 @@ // TODO handle no_create, deviceptr and attach operands. - if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(), + if (failed(processOperands(builder, moduleTranslation, op, copyOperands, totalNbOperand, kCopyFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); if (failed(processOperands( - builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand, + builder, moduleTranslation, op, copyinOperands, totalNbOperand, kDeviceCopyinFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); // TODO copyin readonly currenlty handled as copyin. Update when extension // available. if (failed(processOperands(builder, moduleTranslation, op, - op.copyinReadonlyOperands(), totalNbOperand, + copyinReadonlyOperands, totalNbOperand, kDeviceCopyinFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); if (failed(processOperands( - builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand, + builder, moduleTranslation, op, copyoutOperands, totalNbOperand, kHostCopyoutFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); // TODO copyout zero currenlty handled as copyout. Update when extension // available. if (failed(processOperands(builder, moduleTranslation, op, - op.copyoutZeroOperands(), totalNbOperand, + copyoutZeroOperands, totalNbOperand, kHostCopyoutFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); - if (failed(processOperands(builder, moduleTranslation, op, - op.createOperands(), totalNbOperand, - kCreateFlag | kHoldFlag, flags, names, index, - mapperAllocas))) + if (failed(processOperands( + builder, moduleTranslation, op, createOperands, totalNbOperand, + kCreateFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); // TODO create zero currenlty handled as create. Update when extension // available. if (failed(processOperands(builder, moduleTranslation, op, - op.createZeroOperands(), totalNbOperand, + createZeroOperands, totalNbOperand, kCreateFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); - if (failed(processOperands(builder, moduleTranslation, op, - op.presentOperands(), totalNbOperand, - kPresentFlag | kHoldFlag, flags, names, index, - mapperAllocas))) + if (failed(processOperands( + builder, moduleTranslation, op, presentOperands, totalNbOperand, + kPresentFlag | kHoldFlag, flags, names, index, mapperAllocas))) return failure(); llvm::GlobalVariable *maptypes = - accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); + ompBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), maptypes, /*Idx0=*/0, /*Idx1=*/0); llvm::GlobalVariable *mapnames = - accBuilder->createOffloadMapnames(names, ".offload_mapnames"); + ompBuilder->createOffloadMapnames(names, ".offload_mapnames"); llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), mapnames, /*Idx0=*/0, /*Idx1=*/0); // Create call to start the data region. - accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo, + ompBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo, maptypesArg, mapnamesArg, mapperAllocas, kDefaultDevice, totalNbOperand); // Convert the region. llvm::BasicBlock *entryBlock = nullptr; - for (Block &bb : op.region()) { + for (Block &bb : op->getRegion(0)) { llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( ctx, "acc.data", builder.GetInsertBlock()->getParent()); if (entryBlock == nullptr) @@ -398,7 +397,7 @@ ctx, "acc.end_data", builder.GetInsertBlock()->getParent()); SetVector blocks = - LLVM::detail::getTopologicallySortedBlocks(op.region()); + LLVM::detail::getTopologicallySortedBlocks(op->getRegion(0)); for (Block *bb : blocks) { llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb); if (bb->isEntryBlock()) { @@ -418,16 +417,16 @@ // Create call to end the data region. builder.SetInsertPoint(endDataBlock); - accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo, + ompBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo, maptypesArg, mapnamesArg, mapperAllocas, kDefaultDevice, totalNbOperand); return success(); } - +namespace mlir { /// Converts an OpenACC standalone data operation into LLVM IR. template -static LogicalResult +LogicalResult convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { auto enclosingFuncOp = @@ -435,7 +434,7 @@ llvm::Function *enclosingFunction = moduleTranslation.lookupFunction(enclosingFuncOp.getName()); - OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); auto *mapperFunc = getAssociatedFunction(*accBuilder, op); @@ -445,8 +444,8 @@ llvm::LLVMContext &ctx = builder.getContext(); - struct OpenACCIRBuilder::MapperAllocas mapperAllocas; - OpenACCIRBuilder::InsertPointTy allocaIP( + struct llvm::OpenMPIRBuilder::MapperAllocas mapperAllocas; + llvm::OpenMPIRBuilder::InsertPointTy allocaIP( &enclosingFunction->getEntryBlock(), enclosingFunction->getEntryBlock().getFirstInsertionPt()); accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, @@ -478,68 +477,14 @@ return success(); } -namespace { - -/// Implementation of the dialect interface that converts operations belonging -/// to the OpenACC dialect to LLVM IR. -class OpenACCDialectLLVMIRTranslationInterface - : public LLVMTranslationDialectInterface { -public: - using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - - /// Translates the given operation to LLVM IR using the provided IR builder - /// and saving the state in `moduleTranslation`. - LogicalResult - convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final; -}; - -} // namespace - -/// Given an OpenACC MLIR operation, create the corresponding LLVM IR -/// (including OpenACC runtime calls). -LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( - Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const { - - return llvm::TypeSwitch(op) - .Case([&](acc::DataOp dataOp) { - return convertDataOp(dataOp, builder, moduleTranslation); - }) - .Case([&](acc::EnterDataOp enterDataOp) { - return convertStandaloneDataOp(enterDataOp, builder, - moduleTranslation); - }) - .Case([&](acc::ExitDataOp exitDataOp) { - return convertStandaloneDataOp(exitDataOp, builder, - moduleTranslation); - }) - .Case([&](acc::UpdateOp updateOp) { - return convertStandaloneDataOp(updateOp, builder, - moduleTranslation); - }) - .Case([](auto op) { - // `yield` and `terminator` can be just omitted. The block structure was - // created in the function that handles their parent operation. - assert(op->getNumOperands() == 0 && - "unexpected OpenACC terminator with operands"); - return success(); - }) - .Default([&](Operation *op) { - return op->emitError("unsupported OpenACC operation: ") - << op->getName(); - }); -} - -void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) { - dialect->addInterfaces(); - }); -} - -void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { - DialectRegistry registry; - registerOpenACCDialectTranslation(registry); - context.appendDialectRegistry(registry); -} +template LogicalResult convertStandaloneDataOp( + acc::EnterDataOp &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); +template LogicalResult convertStandaloneDataOp( + acc::ExitDataOp &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); +template LogicalResult convertStandaloneDataOp( + acc::UpdateOp &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +} // namespace mlir \ No newline at end of file