diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -348,6 +348,11 @@
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
+
+  /// LLVM functions that should be masked out from the output because they do
+  /// not correspond to the target device being compiled for. Filled while each
+  /// function is translated; consumed and cleared by convertFunctions().
+  llvm::SmallVector<llvm::Function *> maskedFunctions;
 };
 
 namespace detail {
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIROpenMPDialect
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -898,7 +898,31 @@
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  if (failed(convertDialectAttributes(func)))
+    return failure();
+
+  // All functions are translated first to ensure target regions are always
+  // processed, so that they can be outlined. However, they must be deleted
+  // afterwards if the device they are intended for does not match the device
+  // we are currently generating code for. Only record them here; the actual
+  // removal happens at the end of convertFunctions().
+  if (auto offloadMod =
+          dyn_cast<omp::OffloadModuleInterface>(mlirModule)) {
+    bool isDevicePass = offloadMod.getIsDevice();
+    // Functions not marked as declare target are treated as host-only.
+    omp::DeclareTargetDeviceType declareType =
+        omp::DeclareTargetDeviceType::host;
+
+    if (omp::OpenMPDialect::isDeclareTarget(func.getOperation()))
+      declareType =
+          omp::OpenMPDialect::getDeclareTargetDeviceType(func.getOperation());
+
+    if ((isDevicePass && declareType == omp::DeclareTargetDeviceType::host) ||
+        (!isDevicePass && declareType == omp::DeclareTargetDeviceType::nohost))
+      maskedFunctions.push_back(llvmFunc);
+  }
+
+  return success();
 }
 
 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
@@ -1004,6 +1028,8 @@
 }
 
 LogicalResult ModuleTranslation::convertFunctions() {
+  maskedFunctions.clear();
+
   // Convert functions.
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     // Ignore external functions.
@@ -1014,6 +1040,22 @@
       return failure();
   }
 
+  // Delete translated functions that are intended for a device we are
+  // currently not generating code for. All of their bodies are deleted first,
+  // so that calls between masked functions (in any order, including cycles)
+  // no longer count as uses; erasing a callee while a masked caller still
+  // referenced it would destroy an llvm::Value that is still in use.
+  for (llvm::Function *llvmFunction : maskedFunctions)
+    llvmFunction->deleteBody();
+
+  // Only erase functions that are now unreferenced. A masked function that is
+  // still used by retained IR (e.g. called from a function that is kept) is
+  // left behind as an external declaration instead of leaving dangling uses.
+  for (llvm::Function *llvmFunction : maskedFunctions)
+    if (llvmFunction->use_empty())
+      llvmFunction->eraseFromParent();
+  maskedFunctions.clear();
+
   return success();
 }
 