diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -73,8 +73,11 @@ std::unique_ptr createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config); std::unique_ptr createPolymorphicOpConversionPass(); + std::unique_ptr> createOMPEarlyOutliningPass(); +std::unique_ptr createOMPFunctionFilteringPass(); // Erases functions not meant for the current host/device compilation. + // declarative passes #define GEN_PASS_REGISTRATION #include "flang/Optimizer/Transforms/Passes.h.inc" diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -311,4 +311,13 @@ let dependentDialects = ["mlir::omp::OpenMPDialect"]; } +def OMPFunctionFiltering : Pass<"omp-function-filtering"> { + let summary = "Filters out functions intended for the host when compiling " + "for the device and vice versa."; + let constructor = "::fir::createOMPFunctionFilteringPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; +} + #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -312,6 +312,7 @@ if (isDevice) pm.addPass(fir::createOMPEarlyOutliningPass()); + pm.addPass(fir::createOMPFunctionFilteringPass()); // Not under `if (isDevice)`: runs for host and device. } pm.enableVerifier(/*verifyPasses=*/true); diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ PolymorphicOpConversion.cpp LoopVersioning.cpp OMPEarlyOutlining.cpp + OMPFunctionFiltering.cpp DEPENDS FIRDialect diff --git
a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp @@ -0,0 +1,81 @@ +//===- OMPFunctionFiltering.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to filter out functions intended for the host +// when compiling for the device and vice versa. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" + +namespace fir { +#define GEN_PASS_DEF_OMPFUNCTIONFILTERING +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace fir; +using namespace mlir; + +namespace { +/// Erases the func.func operations that do not belong to the current +/// compilation: functions whose omp.declare_target device type is `host` +/// (which includes functions without omp.declare_target) when compiling for +/// the device, and `nohost` functions when compiling for the host. Functions +/// that contain an omp.target region are always kept. +class OMPFunctionFilteringPass + : public fir::impl::OMPFunctionFilteringBase { +public: + OMPFunctionFilteringPass() = default; + + void runOnOperation() override { + auto op = dyn_cast(getOperation()); + if (!op) + return; + + bool isDeviceCompilation = op.getIsTargetDevice(); + op->walk([&](func::FuncOp funcOp) { + // Do not filter functions with target regions inside, because they have + // to be available for both host and device so that regular and reverse + // offloading can be supported.
+ bool hasTargetRegion = + funcOp + ->walk( + [&](omp::TargetOp) { return WalkResult::interrupt(); }) + .wasInterrupted(); + if (hasTargetRegion) + return; + + omp::DeclareTargetDeviceType declareType = + omp::DeclareTargetDeviceType::host; + auto declareTargetOp = + dyn_cast(funcOp.getOperation()); + if (declareTargetOp && declareTargetOp.isDeclareTarget()) + declareType = declareTargetOp.getDeclareTargetDeviceType(); + + // NOTE(review): symbol uses of an erased function (e.g. calls from a kept + // function that contains a target region) are not updated here; confirm + // that no such reference can reach later passes. + if ((isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::host) || + (!isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::nohost)) + funcOp->erase(); + }); + } +}; +} // namespace + +std::unique_ptr fir::createOMPFunctionFilteringPass() { + return std::make_unique(); +} diff --git a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 --- a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 +++ b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 @@ -1,51 +1,52 @@ -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST +!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes ALL,DEVICE ! Check specification valid forms of declare target with functions ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPfunc_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPfunc_t_device() +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_DEVICE() RESULT(I) !$omp declare target to(FUNC_T_DEVICE) device_type(nohost) INTEGER :: I I = 1 END FUNCTION FUNC_T_DEVICE -! CHECK-LABEL: func.func @_QPfunc_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPfunc_t_host() +!
HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_HOST() RESULT(I) !$omp declare target to(FUNC_T_HOST) device_type(host) INTEGER :: I I = 1 END FUNCTION FUNC_T_HOST -! CHECK-LABEL: func.func @_QPfunc_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_ANY() RESULT(I) !$omp declare target to(FUNC_T_ANY) device_type(any) INTEGER :: I I = 1 END FUNCTION FUNC_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I) !$omp declare target to(FUNC_DEFAULT_T_ANY) INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_ANY() RESULT(I) !$omp declare target INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_ANY -! CHECK-LABEL: func.func @_QPfunc_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I) !$omp declare target(FUNC_DEFAULT_EXTENDEDLIST) INTEGER :: I @@ -58,46 +59,46 @@ ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPsubr_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPsubr_t_device() +!
DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_DEVICE() !$omp declare target to(SUBR_T_DEVICE) device_type(nohost) END -! CHECK-LABEL: func.func @_QPsubr_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPsubr_t_host() +! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_HOST() !$omp declare target to(SUBR_T_HOST) device_type(host) END -! CHECK-LABEL: func.func @_QPsubr_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_ANY() !$omp declare target to(SUBR_T_ANY) device_type(any) END -! CHECK-LABEL: func.func @_QPsubr_default_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_T_ANY() !$omp declare target to(SUBR_DEFAULT_T_ANY) END -! CHECK-LABEL: func.func @_QPsubr_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_ANY() !$omp declare target END -! CHECK-LABEL: func.func @_QPsubr_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST() !$omp declare target(SUBR_DEFAULT_EXTENDEDLIST) END !! ----- -! CHECK-LABEL: func.func @_QPrecursive_declare_target -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +!
DEVICE-LABEL: func.func @_QPrecursive_declare_target +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} RECURSIVE FUNCTION RECURSIVE_DECLARE_TARGET(INCREMENT) RESULT(K) !$omp declare target to(RECURSIVE_DECLARE_TARGET) device_type(nohost) INTEGER :: INCREMENT, K diff --git a/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90 b/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90 --- a/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90 +++ b/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90 @@ -1,12 +1,12 @@ -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s -!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes=HOST,ALL +!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefix=ALL PROGRAM main - ! CHECK-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"} + ! HOST-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"} REAL :: I - ! CHECK-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget} : f32 { - ! CHECK-DAG: %0 = fir.undefined f32 - ! CHECK-DAG: fir.has_value %0 : f32 - ! CHECK-DAG: } + ! ALL-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget} : f32 { + ! ALL-DAG: %0 = fir.undefined f32 + ! ALL-DAG: fir.has_value %0 : f32 + !
ALL-DAG: } !$omp declare target(I) END diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir new file mode 100644 --- /dev/null +++ b/flang/test/Transforms/omp-function-filtering.mlir @@ -0,0 +1,111 @@ +// RUN: fir-opt -split-input-file --omp-function-filtering %s | FileCheck %s + +// CHECK: func.func @any +// CHECK: func.func @nohost +// CHECK-NOT: func.func @host +// CHECK-NOT: func.func @none +// CHECK: func.func @nohost_target +// CHECK: func.func @host_target +// CHECK: func.func @none_target +module attributes {omp.is_target_device = true} { // Device: `host` and non-declare-target functions without omp.target are erased. + func.func @any() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @host() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @none() -> () { + func.return + } + func.func @nohost_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + omp.target {} + func.return + } + func.func @host_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + omp.target {} + func.return + } + func.func @none_target() -> () { + omp.target {} + func.return + } +} + +// ----- + +// CHECK: func.func @any +// CHECK-NOT: func.func @nohost +// CHECK: func.func @host +// CHECK: func.func @none +// CHECK: func.func @nohost_target +// CHECK: func.func @host_target +// CHECK: func.func @none_target +module attributes {omp.is_target_device = false} { // Host: only `nohost` functions without omp.target are erased. + func.func @any() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @host() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @none() -> () { + func.return + } + func.func
@nohost_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + omp.target {} + func.return + } + func.func @host_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + omp.target {} + func.return + } + func.func @none_target() -> () { + omp.target {} + func.return + } +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/IRMapping.h" @@ -1667,6 +1668,50 @@ return bodyGenStatus; } +static LogicalResult +convertDeclareTargetAttr(Operation *op, + omp::DeclareTargetAttr declareTargetAttr, + LLVM::ModuleTranslation &moduleTranslation) { + // Amend omp.declare_target by deleting the IR of the outlined functions + // created for target regions. They cannot be filtered out from MLIR earlier + // because the omp.target operation inside must be translated to LLVM, but the + // wrapper functions themselves must not remain at the end of the process. + // We know that functions where omp.declare_target does not match + // omp.is_target_device at this stage can only be wrapper functions because + // those that aren't are removed earlier as an MLIR transformation pass. + // + // The attribute may also be attached to non-function operations (e.g. + // declare target globals) or appear outside of an offload module. There is + // nothing to amend in those cases and they must not abort the translation.
+ auto funcOp = dyn_cast<FunctionOpInterface>(op); + if (!funcOp) + return success(); + + auto offloadMod = dyn_cast_or_null<omp::OffloadModuleInterface>( + op->getParentOfType<ModuleOp>().getOperation()); + if (!offloadMod) + return success(); + + bool isDeviceCompilation = offloadMod.getIsTargetDevice(); + omp::DeclareTargetDeviceType declareType = + declareTargetAttr.getDeviceType().getValue(); + bool isFilteredOut = + (isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::host) || + (!isDeviceCompilation && + declareType == omp::DeclareTargetDeviceType::nohost); + if (!isFilteredOut) + return success(); + + // Be defensive in case no LLVM function was registered under this name. + if (llvm::Function *llvmFunc = + moduleTranslation.lookupFunction(funcOp.getName())) { + llvmFunc->dropAllReferences(); + llvmFunc->eraseFromParent(); + } + return success(); +} + namespace { /// Implementation of the dialect interface that converts operations belonging @@ -1694,7 +1739,6 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { - return llvm::TypeSwitch(attribute.getValue()) .Case([&](mlir::omp::FlagsAttr rtlAttr) { return convertFlagsAttr(op, rtlAttr, moduleTranslation); @@ -1706,6 +1750,10 @@ versionAttr.getVersion()); return success(); }) + .Case([&](mlir::omp::DeclareTargetAttr declareTargetAttr) { + return convertDeclareTargetAttr(op, declareTargetAttr, + moduleTranslation); + }) .Default([&](Attribute attr) { // fall through for omp attributes that do not require lowering and/or // have no concrete definition and thus no type to define a case on diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -931,6 +931,18 @@ // nodes to the results of preceding blocks. detail::connectPHINodes(func.getBody(), *this); + // Delete non-declare target functions from the OpenMP device compilation here + // because convertDialectAttributes() will skip them due to not having the + // omp.declare_target dialect attribute.
+ BoolAttr isDeviceAttr = + mlirModule->getAttrOfType<BoolAttr>("omp.is_target_device"); + if (isDeviceAttr && isDeviceAttr.getValue() && + !func->hasAttr("omp.declare_target")) { + llvmFunc->dropAllReferences(); + llvmFunc->eraseFromParent(); // NOTE(review): confirm no uses remain. + return success(); // Nothing left to amend. + } + // Finally, convert dialect attributes attached to the function. return convertDialectAttributes(func); } diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir --- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir @@ -2,7 +2,7 @@ // name stored in the omp.outline_parent_name attribute. // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s -module attributes {omp.is_device = true} { +module attributes {omp.is_target_device = true} { llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {omp.outline_parent_name = "writeindex_"} { omp.target map((from -> %arg0 : !llvm.ptr), (implicit -> %arg1: !llvm.ptr)) { %0 = llvm.mlir.constant(20 : i32) : i32 diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2543,3 +2543,57 @@ // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0 module attributes {omp.flags = #omp.flags} {} + +// ----- + +module attributes {omp.is_target_device = false} { + // CHECK-NOT: @filter_host_nohost + llvm.func @filter_host_nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + llvm.return + } + + // CHECK: @filter_host_host + llvm.func @filter_host_host() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + llvm.return + } + + // CHECK: @filter_host_none + llvm.func @filter_host_none() -> () { + llvm.return + } +}
+ +// ----- + +module attributes {omp.is_target_device = true} { + // CHECK: @filter_device_nohost + llvm.func @filter_device_nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + llvm.return + } + + // CHECK-NOT: @filter_device_host + llvm.func @filter_device_host() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + llvm.return + } + + // CHECK-NOT: @filter_device_none + llvm.func @filter_device_none() -> () { + llvm.return + } +}