diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -345,6 +345,10 @@
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
+
+  /// The translated functions that should be masked out from the output, due
+  /// to not corresponding to the target device being compiled for.
+  llvm::SmallVector<llvm::Function *> maskedFunctions;
 };
 
 namespace detail {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -904,7 +904,28 @@
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  LogicalResult result = convertDialectAttributes(func);
+
+  // All functions are translated first to ensure target regions are always
+  // processed, so that they can be outlined. However, they must be deleted
+  // afterwards if the device they are intended for does not match the device we
+  // are currently generating code for.
+  if (auto offloadMod =
+          dyn_cast<omp::OffloadModuleInterface>(mlirModule)) {
+    bool isDevicePass = offloadMod.getIsDevice();
+    omp::DeclareTargetDeviceType declareType =
+        omp::DeclareTargetDeviceType::host;
+
+    if (omp::OpenMPDialect::isDeclareTarget(func.getOperation()))
+      declareType =
+          omp::OpenMPDialect::getDeclareTargetDeviceType(func.getOperation());
+
+    if ((isDevicePass && declareType == omp::DeclareTargetDeviceType::host) ||
+        (!isDevicePass && declareType == omp::DeclareTargetDeviceType::nohost))
+      maskedFunctions.push_back(llvmFunc);
+  }
+
+  return result;
 }
 
 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
@@ -1010,6 +1031,8 @@
 }
 
 LogicalResult ModuleTranslation::convertFunctions() {
+  maskedFunctions.clear();
+
   // Convert functions.
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     // Ignore external functions.
@@ -1020,6 +1043,25 @@
       return failure();
   }
 
+  // Delete translated functions that are intended for a device we are currently
+  // not generating code for. Their bodies are dropped first so that calls
+  // between masked functions (including recursive ones) no longer keep each
+  // other alive; only functions left without uses are then erased. A masked
+  // function still referenced from a kept function cannot be erased without
+  // leaving dangling uses, so it remains as an external declaration instead.
+  // NOTE(review): confirm no MLIR-to-LLVM function mapping is queried for the
+  // erased functions after this point.
+  for (llvm::Function *llvmFunction : maskedFunctions)
+    llvmFunction->deleteBody();
+  for (llvm::Function *llvmFunction : maskedFunctions) {
+    llvmFunction->removeDeadConstantUsers();
+    if (llvmFunction->use_empty())
+      llvmFunction->eraseFromParent();
+    else
+      llvmFunction->setComdat(nullptr); // Declarations cannot be in a comdat.
+  }
+  maskedFunctions.clear();
+
   return success();
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2493,3 +2493,119 @@
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags<assume_teams_oversubscription = true, assume_no_thread_state = true>} {}
+
+// -----
+
+// CHECK: define void @any
+// CHECK: define void @nohost
+// CHECK-NOT: define void @host
+// CHECK-NOT: define void @no_declare_target
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+  llvm.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @no_declare_target() -> () {
+    llvm.call @host() : () -> ()
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK: define void @any
+// CHECK-NOT: define void @nohost
+// CHECK: define void @host
+// CHECK: define void @no_declare_target
+module attributes {omp.is_device = #omp.isdevice<is_device = false>} {
+  llvm.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+  llvm.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @any() : () -> ()
+    llvm.return
+  }
+  llvm.func @no_declare_target() -> () {
+    llvm.call @host() : () -> ()
+    llvm.return
+  }
+}
+
+// -----
+
+// Masked functions that call each other are removed regardless of the order in
+// which they were translated, and a masked function that is still called from
+// a kept function is reduced to a declaration instead of being erased.
+// CHECK-NOT: define void @host_a
+// CHECK-NOT: define void @host_b
+// CHECK-NOT: define void @host_callee
+// CHECK: declare void @host_callee()
+// CHECK: define void @nohost_caller
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @host_a() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @host_b() : () -> ()
+    llvm.return
+  }
+  llvm.func @host_b() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.call @host_a() : () -> ()
+    llvm.return
+  }
+  llvm.func @host_callee() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+  llvm.func @nohost_caller() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.call @host_callee() : () -> ()
+    llvm.return
+  }
+}