diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -73,6 +73,8 @@ createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config); std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass(); +std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass(); + // declarative passes #define GEN_PASS_REGISTRATION #include "flang/Optimizer/Transforms/Passes.h.inc" diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -298,4 +298,13 @@ let dependentDialects = [ "fir::FIROpsDialect" ]; } +def OMPFunctionFiltering : Pass<"omp-function-filtering"> { + let summary = "Filters out functions intended for the host when compiling " + "for the device and vice versa."; + let constructor = "::fir::createOMPFunctionFilteringPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; +} + #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -23,6 +23,7 @@ #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/Support/InitFIR.h" #include "flang/Optimizer/Support/Utils.h" +#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/parsing.h" #include "flang/Parser/provenance.h" @@ -300,6 +301,11 @@ mlir::PassManager pm((*mlirModule)->getName(), mlir::OpPassManager::Nesting::Implicit); pm.enableVerifier(/*verifyPasses=*/true); + + if (ci.getInvocation().getFrontendOpts().features.IsEnabled( + Fortran::common::LanguageFeature::OpenMP)) +
pm.addPass(fir::createOMPFunctionFilteringPass()); + pm.addPass(std::make_unique<Fortran::lower::VerifierPass>()); if (mlir::failed(pm.run(*mlirModule))) { @@ -689,9 +695,13 @@ // Set-up the MLIR pass manager mlir::PassManager pm((*mlirModule)->getName(), mlir::OpPassManager::Nesting::Implicit); + pm.enableVerifier(/*verifyPasses=*/true); + + if (ci.getInvocation().getFrontendOpts().features.IsEnabled( + Fortran::common::LanguageFeature::OpenMP)) + pm.addPass(fir::createOMPFunctionFilteringPass()); pm.addPass(std::make_unique<Fortran::lower::VerifierPass>()); - pm.enableVerifier(/*verifyPasses=*/true); // Create the pass pipeline fir::createMLIRToLLVMPassPipeline(pm, level, opts.StackArrays, diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ AddDebugFoundation.cpp PolymorphicOpConversion.cpp LoopVersioning.cpp + OMPFunctionFiltering.cpp DEPENDS FIRDialect diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp @@ -0,0 +1,70 @@ +//===- OMPFunctionFiltering.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to filter out functions intended for the host +// when compiling for the device and vice versa.
+// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" + +namespace fir { +#define GEN_PASS_DEF_OMPFUNCTIONFILTERING +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace fir; +using namespace mlir; + +namespace { +class OMPFunctionFilteringPass + : public fir::impl::OMPFunctionFilteringBase<OMPFunctionFilteringPass> { +public: + OMPFunctionFilteringPass() = default; + + void runOnOperation() override { + auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation()); + if (!op) + return; + + bool isDeviceCompilation = op.getIsDevice(); + op->walk([&](func::FuncOp funcOp) { + // Do not filter functions with target regions inside, because they have + // to be available for both host and device so that regular and reverse + // offloading can be supported.
+ // Stop walking at the first omp.target found: one is enough to keep it. + if (funcOp->walk([](omp::TargetOp) { return WalkResult::interrupt(); }) + .wasInterrupted()) + return; + + omp::DeclareTargetDeviceType declareType = + omp::DeclareTargetDeviceType::host; + auto declareTargetOp = + dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation()); + if (declareTargetOp && declareTargetOp.isDeclareTarget()) + declareType = declareTargetOp.getDeclareTargetDeviceType(); + + if ((isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::host) || + (!isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::nohost)) + funcOp->erase(); + }); + } +}; +} // namespace + +std::unique_ptr<mlir::Pass> fir::createOMPFunctionFilteringPass() { + return std::make_unique<OMPFunctionFilteringPass>(); +} diff --git a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 --- a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 +++ b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 @@ -1,51 +1,52 @@ -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST +!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL,DEVICE ! Check specification valid forms of declare target with functions ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPfunc_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPfunc_t_device() +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_DEVICE() RESULT(I) !$omp declare target to(FUNC_T_DEVICE) device_type(nohost) INTEGER :: I I = 1 END FUNCTION FUNC_T_DEVICE -! CHECK-LABEL: func.func @_QPfunc_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPfunc_t_host() +!
HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_HOST() RESULT(I) !$omp declare target to(FUNC_T_HOST) device_type(host) INTEGER :: I I = 1 END FUNCTION FUNC_T_HOST -! CHECK-LABEL: func.func @_QPfunc_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_ANY() RESULT(I) !$omp declare target to(FUNC_T_ANY) device_type(any) INTEGER :: I I = 1 END FUNCTION FUNC_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I) !$omp declare target to(FUNC_DEFAULT_T_ANY) INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_ANY() RESULT(I) !$omp declare target INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_ANY -! CHECK-LABEL: func.func @_QPfunc_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I) !$omp declare target(FUNC_DEFAULT_EXTENDEDLIST) INTEGER :: I @@ -58,46 +59,46 @@ ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPsubr_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPsubr_t_device() +!
DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_DEVICE() !$omp declare target to(SUBR_T_DEVICE) device_type(nohost) END -! CHECK-LABEL: func.func @_QPsubr_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPsubr_t_host() +! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_HOST() !$omp declare target to(SUBR_T_HOST) device_type(host) END -! CHECK-LABEL: func.func @_QPsubr_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_ANY() !$omp declare target to(SUBR_T_ANY) device_type(any) END -! CHECK-LABEL: func.func @_QPsubr_default_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_T_ANY() !$omp declare target to(SUBR_DEFAULT_T_ANY) END -! CHECK-LABEL: func.func @_QPsubr_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_ANY() !$omp declare target END -! CHECK-LABEL: func.func @_QPsubr_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST() !$omp declare target(SUBR_DEFAULT_EXTENDEDLIST) END !! ----- -! CHECK-LABEL: func.func @_QPrecursive_declare_target -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +!
DEVICE-LABEL: func.func @_QPrecursive_declare_target +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} RECURSIVE FUNCTION RECURSIVE_DECLARE_TARGET(INCREMENT) RESULT(K) !$omp declare target to(RECURSIVE_DECLARE_TARGET) device_type(nohost) INTEGER :: INCREMENT, K diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir new file mode 100644 --- /dev/null +++ b/flang/test/Transforms/omp-function-filtering.mlir @@ -0,0 +1,115 @@ +// RUN: fir-opt -split-input-file --omp-function-filtering %s | FileCheck %s + +// Device compilation: 'host' functions and functions without omp.declare_target +// are removed, unless they contain an omp.target region. +// CHECK: func.func @any +// CHECK: func.func @nohost +// CHECK-NOT: func.func @host +// CHECK-NOT: func.func @none +// CHECK: func.func @nohost_target +// CHECK: func.func @host_target +// CHECK: func.func @none_target +module attributes {omp.is_device = true} { + func.func @any() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (any), capture_clause = (to)> + } { + func.return + } + func.func @nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + func.return + } + func.func @host() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + func.return + } + func.func @none() -> () { + func.return + } + func.func @nohost_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + omp.target {} + func.return + } + func.func @host_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + omp.target {} + func.return + } + func.func @none_target() -> () { + omp.target {} + func.return + } +} + +// ----- + +// Host compilation: 'nohost' functions are removed, unless they contain an +// omp.target region. +// CHECK: func.func @any +// CHECK-NOT: func.func @nohost +// CHECK: func.func @host +// CHECK: func.func @none +// CHECK: func.func @nohost_target +// CHECK: func.func @host_target +// CHECK: func.func @none_target +module attributes {omp.is_device = false} { + func.func @any() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (any), capture_clause = (to)> + } { + func.return + } + func.func @nohost() -> () + attributes {
+ omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + func.return + } + func.func @host() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + func.return + } + func.func @none() -> () { + func.return + } + func.func @nohost_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + omp.target {} + func.return + } + func.func @host_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + omp.target {} + func.return + } + func.func @none_target() -> () { + omp.target {} + func.return + } +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/IRMapping.h" @@ -1686,6 +1687,39 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { + // Amend omp.declare_target by deleting the IR of the outlined functions + // created for target regions. They cannot be filtered out from MLIR earlier + // because the omp.target operation inside must be translated to LLVM, but the + // wrapper functions themselves must not remain at the end of the process. + // We know that functions where omp.declare_target does not match + // omp.is_device at this stage can only be wrapper functions because those + // that aren't are removed earlier as an MLIR transformation pass.
+ if (attribute.getName() == "omp.declare_target") { + if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) { + if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>( + op->getParentOfType<ModuleOp>().getOperation())) { + auto declareTarget = dyn_cast<omp::DeclareTargetInterface>(op); + + bool isDeviceCompilation = offloadMod.getIsDevice(); + omp::DeclareTargetDeviceType declareType = + omp::DeclareTargetDeviceType::host; + if (declareTarget && declareTarget.isDeclareTarget()) + declareType = declareTarget.getDeclareTargetDeviceType(); + + if ((isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::host) || + (!isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::nohost)) { + // Only erase the function if it was actually translated to LLVM IR. + if (llvm::Function *llvmFunc = + moduleTranslation.lookupFunction(funcOp.getName())) { + llvmFunc->dropAllReferences(); + llvmFunc->eraseFromParent(); + } + } + } + } + } return llvm::TypeSwitch<Attribute, LogicalResult>(attribute.getValue()) .Case([&](mlir::omp::FlagsAttr rtlAttr) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -930,6 +930,20 @@ // nodes to the results of preceding blocks. detail::connectPHINodes(func.getBody(), *this); + // Delete non-declare target functions from the OpenMP device compilation here + // because convertDialectAttributes() will skip them due to not having the + // omp.declare_target dialect attribute. + BoolAttr isDeviceAttr; + if (!func->hasAttr("omp.declare_target") && + (isDeviceAttr = mlirModule->getAttrOfType<BoolAttr>("omp.is_device")) && + isDeviceAttr.getValue()) { + llvmFunc->dropAllReferences(); + llvmFunc->eraseFromParent(); + // llvmFunc is gone, so no dialect attribute may be amended on it: return + // here instead of falling through to convertDialectAttributes(). + return success(); + } + // Finally, convert dialect attributes attached to the function.
return convertDialectAttributes(func); } diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2543,3 +2543,57 @@ // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0 module attributes {omp.flags = #omp.flags} {} + +// ----- + +module attributes {omp.is_device = false} { + // CHECK-NOT: @filter_host_nohost + llvm.func @filter_host_nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + llvm.return + } + + // CHECK: @filter_host_host + llvm.func @filter_host_host() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + llvm.return + } + + // CHECK: @filter_host_none + llvm.func @filter_host_none() -> () { + llvm.return + } +} + +// ----- + +module attributes {omp.is_device = true} { + // CHECK: @filter_device_nohost + llvm.func @filter_device_nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (nohost), capture_clause = (to)> + } { + llvm.return + } + + // CHECK-NOT: @filter_device_host + llvm.func @filter_device_host() -> () + attributes { + omp.declare_target = + #omp.declaretarget<device_type = (host), capture_clause = (to)> + } { + llvm.return + } + + // CHECK-NOT: @filter_device_none + llvm.func @filter_device_none() -> () { + llvm.return + } +}