diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -33,6 +33,7 @@ #include "flang/Semantics/unparse-with-symbols.h" #include "flang/Tools/CrossToolHelpers.h" +#include "mlir/Dialect/OpenMP/OpenMPPasses.h" #include "mlir/IR/Dialect.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" @@ -301,6 +302,7 @@ mlir::OpPassManager::Nesting::Implicit); pm.enableVerifier(/*verifyPasses=*/true); pm.addPass(std::make_unique()); + pm.addPass(mlir::omp::createFilterDeviceHostFunctionsPass()); if (mlir::failed(pm.run(*mlirModule))) { unsigned diagID = ci.getDiagnostics().getCustomDiagID( @@ -691,6 +693,7 @@ mlir::OpPassManager::Nesting::Implicit); pm.addPass(std::make_unique()); + pm.addPass(mlir::omp::createFilterDeviceHostFunctionsPass()); pm.enableVerifier(/*verifyPasses=*/true); // Create the pass pipeline diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90 --- a/flang/test/Driver/mlir-debug-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90 @@ -25,6 +25,7 @@ ! ALL: Pass statistics report ! ALL: Fortran::lower::VerifierPass +! ALL-NEXT: FilterDeviceHostFunctions ! ALL-NEXT: 'func.func' Pipeline ! ALL-NEXT: InlineElementals ! ALL-NEXT: LowerHLFIROrderedAssignments diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -12,6 +12,7 @@ ! ALL: Pass statistics report ! ALL: Fortran::lower::VerifierPass +! ALL-NEXT: FilterDeviceHostFunctions ! O2-NEXT: Canonicalizer ! O2-NEXT: 'func.func' Pipeline ! 
O2-NEXT: SimplifyHLFIRIntrinsics diff --git a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 --- a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 +++ b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 @@ -1,51 +1,52 @@ -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST +!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL,DEVICE ! Check specification valid forms of declare target with functions ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPfunc_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPfunc_t_device() +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_DEVICE() RESULT(I) !$omp declare target to(FUNC_T_DEVICE) device_type(nohost) INTEGER :: I I = 1 END FUNCTION FUNC_T_DEVICE -! CHECK-LABEL: func.func @_QPfunc_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPfunc_t_host() +! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_HOST() RESULT(I) !$omp declare target to(FUNC_T_HOST) device_type(host) INTEGER :: I I = 1 END FUNCTION FUNC_T_HOST -! CHECK-LABEL: func.func @_QPfunc_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_T_ANY() RESULT(I) !$omp declare target to(FUNC_T_ANY) device_type(any) INTEGER :: I I = 1 END FUNCTION FUNC_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_t_any() -! 
CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I) !$omp declare target to(FUNC_DEFAULT_T_ANY) INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_T_ANY -! CHECK-LABEL: func.func @_QPfunc_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_ANY() RESULT(I) !$omp declare target INTEGER :: I I = 1 END FUNCTION FUNC_DEFAULT_ANY -! CHECK-LABEL: func.func @_QPfunc_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPfunc_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I) !$omp declare target(FUNC_DEFAULT_EXTENDEDLIST) INTEGER :: I @@ -58,46 +59,46 @@ ! utilising device_type and to clauses as well as the default ! zero clause declare target -! CHECK-LABEL: func.func @_QPsubr_t_device() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPsubr_t_device() +! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_DEVICE() !$omp declare target to(SUBR_T_DEVICE) device_type(nohost) END -! CHECK-LABEL: func.func @_QPsubr_t_host() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! HOST-LABEL: func.func @_QPsubr_t_host() +! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_HOST() !$omp declare target to(SUBR_T_HOST) device_type(host) END -! CHECK-LABEL: func.func @_QPsubr_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! 
ALL-LABEL: func.func @_QPsubr_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_T_ANY() !$omp declare target to(SUBR_T_ANY) device_type(any) END -! CHECK-LABEL: func.func @_QPsubr_default_t_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_t_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_T_ANY() !$omp declare target to(SUBR_DEFAULT_T_ANY) END -! CHECK-LABEL: func.func @_QPsubr_default_any() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_any() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_ANY() !$omp declare target END -! CHECK-LABEL: func.func @_QPsubr_default_extendedlist() -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! ALL-LABEL: func.func @_QPsubr_default_extendedlist() +! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST() !$omp declare target(SUBR_DEFAULT_EXTENDEDLIST) END !! ----- -! CHECK-LABEL: func.func @_QPrecursive_declare_target -! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}} +! DEVICE-LABEL: func.func @_QPrecursive_declare_target +! 
DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget{{.*}}
 RECURSIVE FUNCTION RECURSIVE_DECLARE_TARGET(INCREMENT) RESULT(K)
 !$omp declare target to(RECURSIVE_DECLARE_TARGET) device_type(nohost)
     INTEGER :: INCREMENT, K
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -21,3 +21,11 @@
 mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS OpenMPPasses.td)
+mlir_tablegen(OpenMPPasses.h.inc -gen-pass-decls -name OpenMP)
+mlir_tablegen(OpenMPPasses.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
+mlir_tablegen(OpenMPPasses.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+# The first argument names the .td file (without extension) to document.
+add_mlir_doc(OpenMPPasses OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.h
@@ -0,0 +1,39 @@
+//===- OpenMPPasses.h - OpenMP pass entry points ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_OPENMP_OPENMPPASSES_H_
+#define MLIR_DIALECT_OPENMP_OPENMPPASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenMP/OpenMPPasses.h.inc"
+
+/// Create a pass to filter out functions intended for the host when compiling
+/// for the device and vice versa.
+std::unique_ptr<Pass> createFilterDeviceHostFunctionsPass();
+
+} // namespace omp
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/OpenMPPasses.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENMP_OPENMPPASSES_H_
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPPasses.td
@@ -0,0 +1,22 @@
+//===-- OpenMPPasses.td - OpenMP pass definition file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_OPENMPPASSES_TD_
+#define MLIR_DIALECT_OPENMP_OPENMPPASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def FilterDeviceHostFunctions : Pass<"omp-filter-device-host-functions"> {
+  let summary = "Filters out functions intended for the host when compiling for the device and vice versa.";
+  let constructor = "mlir::omp::createFilterDeviceHostFunctionsPass()";
+  let dependentDialects = [
+    "func::FuncDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_OPENMP_OPENMPPASSES_TD_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -28,6 +28,7 @@
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPPasses.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -73,6 +74,7 @@
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  registerOpenMPPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -14,3 +14,5 @@
   MLIRLLVMDialect
   MLIRFuncDialect
   )
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+  FilterDeviceHostFunctions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP + + DEPENDS + MLIROpenMPPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRFuncDialect + MLIROpenMPDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Dialect/OpenMP/Transforms/FilterDeviceHostFunctions.cpp b/mlir/lib/Dialect/OpenMP/Transforms/FilterDeviceHostFunctions.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Transforms/FilterDeviceHostFunctions.cpp @@ -0,0 +1,72 @@ +//===- FilterDeviceHostFunctions.cpp - MLIR OpenMP pass implementation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to filter out functions intended for the host +// when compiling for the device and vice versa. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenMP/OpenMPPasses.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace omp { +#define GEN_PASS_DEF_FILTERDEVICEHOSTFUNCTIONS +#include "mlir/Dialect/OpenMP/OpenMPPasses.h.inc" +} // namespace omp +} // namespace mlir + +using namespace mlir; +using namespace mlir::omp; + +namespace { +class FilterDeviceHostFunctionsPass + : public omp::impl::FilterDeviceHostFunctionsBase< + FilterDeviceHostFunctionsPass> { +public: + FilterDeviceHostFunctionsPass() = default; + + void runOnOperation() override { + auto op = dyn_cast(getOperation()); + if (!op) + return; + + bool isDeviceCompilation = op.getIsDevice(); + op->walk([&](func::FuncOp funcOp) { + // Do not filter functions with target regions inside, because they have + // to be 
available for both host and device so that regular and reverse + // offloading can be supported. + bool hasTarget = false; + funcOp->walk([&](TargetOp) { hasTarget = true; }); + if (hasTarget) + return; + + DeclareTargetDeviceType declareType = DeclareTargetDeviceType::host; + auto declareTargetOp = + dyn_cast(funcOp.getOperation()); + if (declareTargetOp && declareTargetOp.isDeclareTarget()) + declareType = declareTargetOp.getDeclareTargetDeviceType(); + + if ((isDeviceCompilation && + declareType == DeclareTargetDeviceType::host) || + (!isDeviceCompilation && + declareType == DeclareTargetDeviceType::nohost)) + funcOp->erase(); + }); + } +}; +} // namespace + +std::unique_ptr mlir::omp::createFilterDeviceHostFunctionsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/OpenMP/filter-device-host-functions.mlir b/mlir/test/Dialect/OpenMP/filter-device-host-functions.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/OpenMP/filter-device-host-functions.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(omp-filter-device-host-functions)' | FileCheck %s + +// CHECK: func.func @any +// CHECK: func.func @nohost +// CHECK-NOT: func.func @host +// CHECK-NOT: func.func @none +// CHECK: func.func @nohost_target +// CHECK: func.func @host_target +// CHECK: func.func @none_target +module attributes {omp.is_device = true} { + func.func @any() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @nohost() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @host() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + func.return + } + func.func @none() -> () { + func.return + } + func.func @nohost_target() -> () + attributes { + omp.declare_target = + #omp.declaretarget + } { + omp.target {} + func.return + } + func.func @host_target() -> () + attributes { + omp.declare_target = + 
#omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}
+
+// -----
+
+// CHECK: func.func @any
+// CHECK-NOT: func.func @nohost
+// CHECK: func.func @host
+// CHECK: func.func @none
+// CHECK: func.func @nohost_target
+// CHECK: func.func @host_target
+// CHECK: func.func @none_target
+module attributes {omp.is_device = false} {
+  func.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @none() -> () {
+    func.return
+  }
+  func.func @nohost_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @host_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}