diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -348,6 +348,11 @@
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
+
+  /// Translated functions that must be removed from the output module because
+  /// their OpenMP declare target device type does not match the device being
+  /// compiled for.
+  llvm::SmallVector<llvm::Function *> maskedFunctions;
 };
 
 namespace detail {
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIROpenMPDialect
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -898,7 +898,29 @@
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  LogicalResult result = convertDialectAttributes(func);
+
+  // All functions are translated first to ensure target regions are always
+  // processed, so that they can be outlined. However, they must be deleted
+  // afterwards if the device they are intended for does not match the device we
+  // are currently generating code for.
+  if (auto offloadMod =
+          dyn_cast<omp::OffloadModuleInterface>(mlirModule)) {
+    bool isDevicePass = offloadMod.getIsDevice();
+    // Functions without an omp.declare_target attribute are host-only.
+    omp::DeclareTargetDeviceType declareType =
+        omp::DeclareTargetDeviceType::host;
+
+    if (omp::OpenMPDialect::isDeclareTarget(func.getOperation()))
+      declareType =
+          omp::OpenMPDialect::getDeclareTargetDeviceType(func.getOperation());
+
+    if ((isDevicePass && declareType == omp::DeclareTargetDeviceType::host) ||
+        (!isDevicePass && declareType == omp::DeclareTargetDeviceType::nohost))
+      maskedFunctions.push_back(llvmFunc);
+  }
+
+  return result;
 }
 
 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
@@ -1004,6 +1026,8 @@
 }
 
 LogicalResult ModuleTranslation::convertFunctions() {
+  maskedFunctions.clear();
+
   // Convert functions.
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     // Ignore external functions.
@@ -1014,6 +1038,24 @@
       return failure();
   }
 
+  // Remove translated functions that are intended for a device we are not
+  // currently generating code for. Masked functions may call each other in any
+  // order, so drop the references held by all of their bodies before erasing
+  // any of them; erasing a function that still has uses is invalid.
+  for (llvm::Function *llvmFunction : maskedFunctions)
+    llvmFunction->dropAllReferences();
+  for (llvm::Function *llvmFunction : maskedFunctions) {
+    if (llvmFunction->use_empty()) {
+      llvmFunction->eraseFromParent();
+      continue;
+    }
+    // Still referenced from code that is kept: demote to an external
+    // declaration so that the module remains valid.
+    llvmFunction->deleteBody();
+    llvmFunction->setComdat(nullptr);
+  }
+  maskedFunctions.clear();
+
   return success();
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2537,3 +2537,77 @@
 
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags, omp.is_device = #omp.isdevice} {}
+
+// -----
+
+// CHECK: define void @any
+// CHECK: define void @nohost
+// CHECK-NOT: define void @host
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @any() -> () attributes {omp.declare_target = #omp<device_type(any)>} {
+    llvm.return
+  }
+  llvm.func @nohost() -> () attributes {omp.declare_target = #omp<device_type(nohost)>} {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK: define void @any
+// CHECK-NOT: define void @nohost
+// CHECK: define void @host
+module attributes {omp.is_device = #omp.isdevice<is_device = false>} {
+  llvm.func @any() -> () attributes {omp.declare_target = #omp<device_type(any)>} {
+    llvm.return
+  }
+  llvm.func @nohost() -> () attributes {omp.declare_target = #omp<device_type(nohost)>} {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+}
+
+// -----
+
+// Masked functions may call each other in any order; all of them must be
+// removed without leaving dangling uses behind.
+// CHECK: define void @any
+// CHECK-NOT: @host_caller
+// CHECK-NOT: @host_callee
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @any() -> () attributes {omp.declare_target = #omp<device_type(any)>} {
+    llvm.return
+  }
+  llvm.func @host_caller() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+    llvm.call @host_callee() : () -> ()
+    llvm.return
+  }
+  llvm.func @host_callee() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+    llvm.return
+  }
+}
+
+// -----
+
+// A masked function that is still referenced from kept code is demoted to a
+// declaration instead of being erased.
+// CHECK: declare void @host()
+// CHECK: define void @nohost
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @host() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+    llvm.return
+  }
+  llvm.func @nohost() -> () attributes {omp.declare_target = #omp<device_type(nohost)>} {
+    llvm.call @host() : () -> ()
+    llvm.return
+  }
+}