Filter translated LLVM functions by OpenMP offload target device.

Every non-external function is still translated so that target regions inside
it are processed and can be outlined. Functions whose declare-target device
type does not match the current compilation (host-only functions during a
device pass, nohost functions during a host pass) are recorded in
ModuleTranslation::maskedFunctions and removed once all functions have been
converted. A function without a declare-target attribute is treated as host.

Removal is done in two steps: bodies of all masked functions are dropped
first, then only functions without remaining users are erased. This avoids
erasing a function that is still referenced (e.g. a masked function calling a
masked function that was translated later); such a function is left as a
declaration instead.

NOTE(review): the template and attribute argument lists in this patch
(SmallVector element type, dyn_cast/cast/getOps targets, #omp.isdevice,
#omp.declaretarget and the #omp.flags context line) were reconstructed after
being lost in extraction. Verify them against the tree before applying; in
particular the #omp.flags context line must match the existing test file.

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -345,6 +345,10 @@
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
+
+  /// The set of functions that should be masked out from the output, due to not
+  /// corresponding to the target device being compiled for.
+  llvm::SmallVector<llvm::Function *> maskedFunctions;
 };
 
 namespace detail {
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -12,4 +12,5 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRFuncDialect
   )
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -919,7 +919,29 @@
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  LogicalResult result = convertDialectAttributes(func);
+
+  // All functions are translated first to ensure target regions are always
+  // processed, so that they can be outlined. However, they must be deleted
+  // afterwards if the device they are intended for does not match the device we
+  // are currently generating code for.
+  if (auto offloadMod =
+          dyn_cast<omp::OffloadModuleInterface>(mlirModule)) {
+    bool isDevicePass = offloadMod.getIsDevice();
+    omp::DeclareTargetDeviceType declareType =
+        omp::DeclareTargetDeviceType::host;
+
+    auto declareTargetOp =
+        cast<omp::DeclareTargetInterface>(func.getOperation());
+    if (declareTargetOp.isDeclareTarget())
+      declareType = declareTargetOp.getDeclareTargetDeviceType();
+
+    if ((isDevicePass && declareType == omp::DeclareTargetDeviceType::host) ||
+        (!isDevicePass && declareType == omp::DeclareTargetDeviceType::nohost))
+      maskedFunctions.push_back(llvmFunc);
+  }
+
+  return result;
 }
 
 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
@@ -1025,6 +1047,8 @@
 }
 
 LogicalResult ModuleTranslation::convertFunctions() {
+  maskedFunctions.clear();
+
   // Convert functions.
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     // Ignore external functions.
@@ -1035,6 +1059,20 @@
       return failure();
   }
 
+  // Remove translated functions that are intended for a device we are
+  // currently not generating code for. Bodies are dropped from all of them
+  // first so that masked functions referencing each other, in any order, stop
+  // holding uses on one another before anything is erased.
+  for (llvm::Function *llvmFunction : maskedFunctions)
+    llvmFunction->deleteBody();
+
+  // A masked function that retained IR still refers to is kept as a
+  // declaration; erasing it while it has users would be invalid.
+  for (llvm::Function *llvmFunction : maskedFunctions) {
+    if (llvmFunction->use_empty())
+      llvmFunction->eraseFromParent();
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2493,3 +2493,75 @@
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags<assume_teams_oversubscription = true, assume_threads_oversubscription = true, assume_no_thread_state = true>} {}
+
+// -----
+
+// CHECK: define void @any
+// CHECK: define void @nohost
+// CHECK-NOT: define void @host
+// CHECK-NOT: define void @no_declare_target
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+  llvm.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @no_declare_target() -> () {
+    llvm.call @host() : () -> ()
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK: define void @any
+// CHECK-NOT: define void @nohost
+// CHECK: define void @host
+// CHECK: define void @no_declare_target
+module attributes {omp.is_device = #omp.isdevice<is_device = false>} {
+  llvm.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+  llvm.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @no_declare_target() -> () {
+    llvm.call @host() : () -> ()
+    llvm.return
+  }
+}