diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -19,6 +19,7 @@ #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" @@ -34,6 +35,7 @@ registerArmSVEDialectTranslation(registry); registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); + registerOpenACCDialectTranslation(registry); registerOpenMPDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerX86VectorDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===- OpenACCToLLVMIRTranslation.h - OpenACC Dialect to LLVM IR -- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for OpenACC dialect to LLVM IR translation. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_OPENACC_OPENACCTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_OPENACC_OPENACCTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the OpenACC dialect and the translation to the LLVM IR in +/// the given registry. +void registerOpenACCDialectTranslation(DialectRegistry &registry); + +/// Register the OpenACC dialect and the translation in the registry +/// associated with the given context. +void registerOpenACCDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_OPENACC_OPENACCTOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -27,6 +27,7 @@ LINK_LIBS PUBLIC MLIRLLVMIR + MLIROpenACC MLIROpenMP MLIRLLVMIRTransforms MLIRTranslation @@ -42,6 +43,7 @@ MLIRX86VectorToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation + MLIROpenACCToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(AMX) add_subdirectory(LLVMIR) add_subdirectory(NVVM) +add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(X86Vector) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation + OpenACCToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS 
PUBLIC + MLIRIR + MLIRLLVMIR + MLIROpenACC + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -0,0 +1,287 @@ +//===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR OpenACC dialect and LLVM +// IR. +// +//===----------------------------------------------------------------------===// +#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +using OpenACCIRBuilder = llvm::OpenMPIRBuilder; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// 0 = alloc/create +static constexpr uint64_t createFlag = 0; +/// 1 = to/copyin +static constexpr uint64_t copyinFlag = 1; +/// Default value for the device id +static constexpr int64_t defaultDevice = -1; + +/// Create a constant string location from the MLIR Location information. 
+static llvm::Constant *createSourceLocStrFromLocation(Location loc,
+                                                      OpenACCIRBuilder &builder,
+                                                      StringRef name) {
+  // A file/line/column location is forwarded field by field to the builder,
+  // together with `name`.
+  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+    StringRef fileName = fileLoc.getFilename();
+    unsigned lineNo = fileLoc.getLine();
+    unsigned colNo = fileLoc.getColumn();
+    return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
+  }
+
+  // Any other kind of location is printed to a string and handed to the
+  // builder as is; `name` is not used on this path.
+  std::string locStr;
+  llvm::raw_string_ostream locOS(locStr);
+  locOS << loc;
+  return builder.getOrCreateSrcLocStr(locOS.str());
+}
+
+/// Create the location struct from the operation location information.
+///
+/// The name of the enclosing LLVM function is passed along as the name of the
+/// location; "unknown" is used when `op` has no such parent.
+static llvm::Value *createSourceLocationInfo(acc::EnterDataOp &op,
+                                             OpenACCIRBuilder &builder) {
+  auto loc = op.getLoc();
+  auto funcOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
+  StringRef funcName = funcOp ? funcOp.getName() : "unknown";
+  llvm::Constant *locStr =
+      createSourceLocStrFromLocation(loc, builder, funcName);
+  return builder.getOrCreateIdent(locStr);
+}
+
+/// Create a constant string representing the mapping information extracted from
+/// the MLIR location information.
+///
+/// For a named location, the string is built from its name and its child
+/// location; any other location is described under the name "unknown".
+static llvm::Constant *createMappingInformation(Location loc,
+                                                OpenACCIRBuilder &builder) {
+  if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
+    StringRef name = nameLoc.getName();
+    return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
+  }
+  return createSourceLocStrFromLocation(loc, builder, "unknown");
+}
+
+/// Return the runtime function used to lower the given operation.
+///
+/// Only `acc::EnterDataOp` is handled, and it is lowered to
+/// `__tgt_target_data_begin_mapper`. Reaching the end of this function with
+/// any other operation is treated as a programming error.
+static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
+                                             Operation &op) {
+  if (isa<acc::EnterDataOp>(op))
+    return builder.getOrCreateRuntimeFunctionPtr(
+        llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
+  llvm_unreachable("Unknown OpenACC operation");
+}
+
+/// Computes the size of type in bytes.
+static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
+                                   llvm::Value *basePtr) {
+  llvm::LLVMContext &ctx = builder.getContext();
+  // Null-pointer GEP idiom: step one element past a null pointer and convert
+  // the resulting address to an i64 to obtain the element size.
+  // NOTE(review): the element type used here is the type of `basePtr` itself
+  // (a pointer type), not the type it points to -- confirm that this is the
+  // size the runtime is expected to receive.
+  llvm::Value *null =
+      llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
+  llvm::Value *sizeGep =
+      builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
+  return builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
+}
+
+/// Extract pointer, size and mapping information from operands
+/// to populate the future functions arguments.
+///
+/// For every value in `operands`, the slot at position `index` of each of the
+/// `arg`, `argBase` and `argSizes` arrays is filled in, `operandFlag` is
+/// appended to `flags`, a mapping-name constant is appended to `names`, and
+/// `index` is incremented. `totalNbOperand` is the number of elements of the
+/// three arrays. An operand whose type is not an LLVM struct type makes the
+/// function emit an error on `op` and return failure.
+static LogicalResult
+processOperands(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation, Operation &op,
+                ValueRange operands, unsigned totalNbOperand,
+                uint64_t operandFlag, SmallVector<uint64_t> &flags,
+                SmallVector<llvm::Constant *> &names, unsigned &index,
+                llvm::AllocaInst *arg, llvm::AllocaInst *argBase,
+                llvm::AllocaInst *argSizes) {
+  OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::LLVMContext &ctx = builder.getContext();
+  auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
+  auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
+  auto *i64Ty = llvm::Type::getInt64Ty(ctx);
+  auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
+
+  for (Value data : operands) {
+    // Only operands that were converted to MemRefDescriptors are handled;
+    // other types are not supported for now.
+    if (!data.getType().isa<LLVM::LLVMStructType>())
+      return op.emitOpError("Unsupported type ") << data.getType();
+
+    // NOTE(review): `descriptor` is never read below; it is kept in case its
+    // construction validates `data` -- confirm before removing.
+    MemRefDescriptor descriptor(data);
+    llvm::Value *memRefValue = moduleTranslation.lookupValue(data);
+    // Field 0 of the translated struct is used as the data pointer.
+    llvm::Value *allocatedPtr = builder.CreateExtractValue(memRefValue, 0);
+
+    // Size slot.
+    llvm::Value *sizePtrToInt = getSizeInBytes(builder, allocatedPtr);
+    llvm::Value *sizesPtr = builder.CreateInBoundsGEP(
+        arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)});
+    builder.CreateStore(sizePtrToInt, sizesPtr);
+
+    // Base-pointer slot. The slot is an i8*, so its address is cast to a
+    // pointer to the actual pointer type before storing.
+    llvm::Value *basePtr = builder.CreateInBoundsGEP(
+        arrI8PtrTy, argBase, {builder.getInt32(0), builder.getInt32(index)});
+    llvm::Value *basePtrCast = builder.CreateBitCast(
+        basePtr, allocatedPtr->getType()->getPointerTo());
+    builder.CreateStore(allocatedPtr, basePtrCast);
+
+    // Pointer slot; the same pointer as the base pointer is stored.
+    llvm::Value *ptr = builder.CreateInBoundsGEP(
+        arrI8PtrTy, arg, {builder.getInt32(0), builder.getInt32(index)});
+    llvm::Value *ptrCast =
+        builder.CreateBitCast(ptr, allocatedPtr->getType()->getPointerTo());
+    builder.CreateStore(allocatedPtr, ptrCast);
+
+    flags.push_back(operandFlag);
+    llvm::Constant *mapName =
+        createMappingInformation(data.getLoc(), *accBuilder);
+    names.push_back(mapName);
+    ++index;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion functions
+//===----------------------------------------------------------------------===//
+
+/// Converts an OpenACC enter_data operation into LLVM IR.
+static LogicalResult
+convertEnterDataOp(Operation &op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  auto enterDataOp = cast<acc::EnterDataOp>(op);
+  auto enclosingFuncOp = op.getParentOfType<LLVM::LLVMFuncOp>();
+  // The argument arrays are allocated in the entry block of the enclosing
+  // function, so there must be one.
+  if (!enclosingFuncOp)
+    return op.emitOpError("expected to be nested in an LLVM function");
+  llvm::Function *enclosingFunction =
+      moduleTranslation.lookupFunction(enclosingFuncOp.getName());
+
+  OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
+
+  auto *srcLocInfo = createSourceLocationInfo(enterDataOp, *accBuilder);
+  auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
+
+  // Number of arguments in the enter_data operation.
+  // TODO include create_zero and attach operands.
+  unsigned totalNbOperand =
+      enterDataOp.createOperands().size() + enterDataOp.copyinOperands().size();
+
+  // TODO could be moved to OpenXXIRBuilder?
+  llvm::LLVMContext &ctx = builder.getContext();
+  auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
+  auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
+  auto *i64Ty = llvm::Type::getInt64Ty(ctx);
+  auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
+
+  // Emit the three argument arrays at the start of the entry block of the
+  // enclosing function, then return to the current insertion point.
+  llvm::IRBuilder<>::InsertPoint allocaIP(
+      &enclosingFunction->getEntryBlock(),
+      enclosingFunction->getEntryBlock().getFirstInsertionPt());
+  llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP();
+  builder.restoreIP(allocaIP);
+  llvm::AllocaInst *arg = builder.CreateAlloca(arrI8PtrTy);
+  llvm::AllocaInst *argBase = builder.CreateAlloca(arrI8PtrTy);
+  llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty);
+  builder.restoreIP(currentIP);
+
+  SmallVector<uint64_t> flags;
+  SmallVector<llvm::Constant *> names;
+  // Next free slot in the argument arrays; shared by both operand groups.
+  unsigned index = 0;
+
+  // create operands are handled as `alloc` call.
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             enterDataOp.createOperands(), totalNbOperand,
+                             createFlag, flags, names, index, arg, argBase,
+                             argSizes)))
+    return failure();
+
+  // copyin operands are handled as `to` call.
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             enterDataOp.copyinOperands(), totalNbOperand,
+                             copyinFlag, flags, names, index, arg, argBase,
+                             argSizes)))
+    return failure();
+
+  // Map types and map names are emitted as global arrays; the call receives a
+  // pointer to their first element.
+  llvm::GlobalVariable *maptypes =
+      accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
+  llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
+      arrI64Ty, maptypes, /*Idx0=*/0, /*Idx1=*/0);
+
+  llvm::GlobalVariable *mapnames =
+      accBuilder->createOffloadMapnames(names, ".offload_mapnames");
+  llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
+      arrI8PtrTy, mapnames, /*Idx0=*/0, /*Idx1=*/0);
+
+  // Pointers to the first element of each local argument array.
+  llvm::Value *argBaseGEP = builder.CreateInBoundsGEP(
+      arrI8PtrTy, argBase, {builder.getInt32(0), builder.getInt32(0)});
+  llvm::Value *argGEP = builder.CreateInBoundsGEP(
+      arrI8PtrTy, arg, {builder.getInt32(0), builder.getInt32(0)});
+  llvm::Value *argSizesGEP = builder.CreateInBoundsGEP(
+      arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)});
+  // The last argument of the call is always passed as a null i8**.
+  llvm::Value *nullPtr =
+      llvm::Constant::getNullValue(i8PtrTy->getPointerTo());
+
+  builder.CreateCall(mapperFunc,
+                     {srcLocInfo, builder.getInt64(defaultDevice),
+                      builder.getInt32(totalNbOperand), argBaseGEP, argGEP,
+                      argSizesGEP, maptypesArg, mapnamesArg, nullPtr});
+
+  return success();
+}
+
+namespace {
+
+/// Implementation of the dialect interface that converts operations belonging
+/// to the OpenACC dialect to LLVM IR.
+class OpenACCDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final;
+};
+
+} // end namespace
+
+/// Given an OpenACC MLIR operation, create the corresponding LLVM IR
+/// (including OpenACC runtime calls).
+///
+/// Only `acc::EnterDataOp` is translated; any other operation produces an
+/// "unsupported OpenACC operation" error.
+LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
+    Operation *op, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      .Case([&](acc::EnterDataOp) {
+        return convertEnterDataOp(*op, builder, moduleTranslation);
+      })
+      .Default([&](Operation *op) {
+        return op->emitError("unsupported OpenACC operation: ")
+               << op->getName();
+      });
+}
+
+/// Register the OpenACC dialect and its LLVM IR translation interface in
+/// `registry`.
+void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
+  registry.insert<acc::OpenACCDialect>();
+  registry.addDialectInterface<acc::OpenACCDialect,
+                               OpenACCDialectLLVMIRTranslationInterface>();
+}
+
+/// Register the OpenACC dialect and its LLVM IR translation interface in a
+/// fresh registry and append that registry to `context`.
+void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerOpenACCDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/test/Target/LLVMIR/openacc-llvm.mlir b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-std-to-llvm -split-input-file %s | mlir-translate -mlir-to-llvmir | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
+  acc.enter_data copyin(%b : memref<10xf32>) create(%a : memref<10xf32>)
+  return
+}
+
+// CHECK: %struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+// CHECK: [[LOCSTR:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};testenterdataop;{{[0-9]*}};{{[0-9]*}};;\00", align 1
+// CHECK: [[LOCGLOBAL:@.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[LOCSTR]], i32 0, i32 0) }, align 8
+// CHECK: [[MAPNAME1:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] 
c";{{.*}};unknown;8;10;;\00", align 1 +// CHECK: [[MAPNAME2:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;14;11;;\00", align 1 +// CHECK: [[MAPTYPES:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i64] [i64 0, i64 1] +// CHECK: [[MAPNAMES:@.*]] = private constant [{{[0-9]*}} x i8*] [i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME1]], i32 0, i32 0), i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME2]], i32 0, i32 0)] + + +// CHECK: define void @testenterdataop +// CHECK: %{{.*}} = alloca [{{[0-9]*}} x i8*], align 8 +// CHECK: %{{.*}} = alloca [{{[0-9]*}} x i8*], align 8 +// CHECK: %{{.*}} = alloca [{{[0-9]*}} x i64], align 8 + +// CHECK: call void @__tgt_target_data_begin_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** %{{.*}}, i8** %{{.*}}, i64* %{{.*}}, i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null) + +// CHECK: declare void @__tgt_target_data_begin_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**) #0