diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2184,6 +2184,120 @@
   converter.bindSymbol(sym, symThreadprivateExv);
 }
 
+/// Lower an OpenMP `declare target` directive by attaching an
+/// `omp.declare_target` device-type attribute to each function it names.
+///
+/// Functions are gathered from the bare extended-list form
+/// (`declare target(f, g)`), from `to(...)` clauses, or, when the directive
+/// has no clauses, from the enclosing procedure. A `device_type` clause picks
+/// host, nohost or any, and defaults to any. A function that is already
+/// marked (e.g. through implicit capture from another declare target
+/// procedure) with a different device type is widened to `any`.
+void handleDeclareTarget(Fortran::lower::AbstractConverter &converter,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const Fortran::parser::OpenMPDeclareTargetConstruct
+                             &declareTargetConstruct) {
+  // Symbols named by the directive; pointers avoid copying Symbol objects.
+  std::vector<const Fortran::semantics::Symbol *> symbols;
+  auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList) {
+    for (const Fortran::parser::OmpObject &ompObject : objList.v) {
+      Fortran::common::visit(
+          Fortran::common::visitors{
+              [&](const Fortran::parser::Designator &designator) {
+                if (const Fortran::parser::Name *name =
+                        getDesignatorNameIfDataRef(designator))
+                  if (name->symbol)
+                    symbols.push_back(name->symbol);
+              },
+              [&](const Fortran::parser::Name &name) {
+                if (name.symbol)
+                  symbols.push_back(name.symbol);
+              }},
+          ompObject.u);
+    }
+  };
+
+  const auto &spec{std::get<Fortran::parser::OpenMPDeclareTargetSpecifier>(
+      declareTargetConstruct.t)};
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+
+  // device_type defaults to `any` when the clause is absent.
+  auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any;
+  if (const auto *objectList{
+          Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+    // Case: declare target(func, var1, var2)
+    findFuncAndVarSyms(*objectList);
+  } else if (const auto *clauseList{
+                 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
+                     spec.u)}) {
+    if (clauseList->v.empty()) {
+      // Case: declare target, implicit capture of the enclosing procedure.
+      // Outside of any procedure there is nothing to capture.
+      auto *owningProc = eval.getOwningProcedure();
+      if (!owningProc)
+        TODO(converter.getCurrentLocation(),
+             "declare target without clauses outside of a procedure");
+      symbols.push_back(&owningProc->getSubprogramSymbol());
+    }
+
+    for (const Fortran::parser::OmpClause &clause : clauseList->v) {
+      if (const auto *toClause{
+              std::get_if<Fortran::parser::OmpClause::To>(&clause.u)}) {
+        // Case: declare target to(func, var1, var2)...
+        findFuncAndVarSyms(toClause->v);
+      } else if (std::holds_alternative<Fortran::parser::OmpClause::Link>(
+                     clause.u)) {
+        // Case: declare target link(var1, var2)...
+        TODO(converter.getCurrentLocation(),
+             "the link clause is currently unsupported");
+      } else if (const auto *deviceClause{
+                     std::get_if<Fortran::parser::OmpClause::DeviceType>(
+                         &clause.u)}) {
+        // Case: declare target ... device_type(any | host | nohost)
+        deviceType = deviceClause->v.v;
+      }
+    }
+  }
+
+  // Translate the parser's device_type into the dialect's enumerator.
+  auto requestedType = mlir::omp::DeclareTargetDeviceType::any;
+  switch (deviceType) {
+  case Fortran::parser::OmpDeviceTypeClause::Type::Host:
+    requestedType = mlir::omp::DeclareTargetDeviceType::host;
+    break;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
+    requestedType = mlir::omp::DeclareTargetDeviceType::nohost;
+    break;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Any:
+    break;
+  }
+
+  for (const Fortran::semantics::Symbol *sym : symbols) {
+    mlir::Operation *op = mod.lookupSymbol(converter.mangleName(*sym));
+
+    // TODO: Remove this check when global data and link are supported.
+    // lookupSymbol returns null for symbols without a module-level op (e.g.
+    // procedure-local variables), hence the null-tolerant cast.
+    if (!mlir::dyn_cast_or_null<mlir::func::FuncOp>(op))
+      TODO(converter.getCurrentLocation(),
+           "only subroutines and functions are currently supported");
+
+    // Already marked, most likely through implicit capture from another
+    // declare target procedure. A differing request means the function is
+    // needed on both host and device, so widen it to `any`; an identical
+    // request changes nothing.
+    if (mlir::omp::OpenMPDialect::isDeclareTarget(op)) {
+      if (mlir::omp::OpenMPDialect::getDeclareTargetDeviceType(op) !=
+          requestedType)
+        mlir::omp::OpenMPDialect::setDeclareTarget(
+            op, mlir::omp::DeclareTargetDeviceType::any);
+      continue;
+    }
+
+    mlir::omp::OpenMPDialect::setDeclareTarget(op, requestedType);
+  }
+}
+
 void Fortran::lower::genOpenMPDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -2206,8 +2320,7 @@
           },
           [&](const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) { - TODO(converter.getCurrentLocation(), - "OpenMPDeclareTargetConstruct"); + handleDeclareTarget(converter, eval, declareTargetConstruct); }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -28,6 +28,7 @@ data-to-inits.cpp definable.cpp expression.cpp + finalize-omp.cpp mod-file.cpp pointer-assignment.cpp program-tree.cpp diff --git a/flang/lib/Semantics/finalize-omp.h b/flang/lib/Semantics/finalize-omp.h new file mode 100644 --- /dev/null +++ b/flang/lib/Semantics/finalize-omp.h @@ -0,0 +1,21 @@ +//===-- lib/Semantics/finalize-omp.h ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_SEMANTICS_FINALIZE_OMP_H_ +#define FORTRAN_SEMANTICS_FINALIZE_OMP_H_ + +namespace Fortran::parser { +struct Program; +} // namespace Fortran::parser + +namespace Fortran::semantics { +class SemanticsContext; +bool FinalizeOMP(SemanticsContext &context, parser::Program &program); +} // namespace Fortran::semantics + +#endif // FORTRAN_SEMANTICS_FINALIZE_OMP_H_ \ No newline at end of file diff --git a/flang/lib/Semantics/finalize-omp.cpp b/flang/lib/Semantics/finalize-omp.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Semantics/finalize-omp.cpp @@ -0,0 +1,141 @@ +#include "finalize-omp.h" +#include "flang/Parser/parse-tree-visitor.h" +#include "flang/Semantics/tools.h" + +#include +#include +#include + +namespace Fortran::semantics { + +using namespace parser::literals; + +class GatherCallRefs { +public: + GatherCallRefs() {} + + // Default action for a parse tree 
node is to visit children. + template bool Pre(T &) { return true; } + template void Post(T &) {} + + void Post(parser::Call &call) { + if (std::holds_alternative(std::get<0>(call.t).u)) + callNames.push_back(std::get(std::get<0>(call.t).u)); + } + + std::list callNames; +}; + +class ImplicitDeclareTargetCapture { +public: + template bool Pre(T &) { return true; } + template void Post(T &) {} + ImplicitDeclareTargetCapture(SemanticsContext &context) + : messages_{context.messages()} {} + + // Related to rewriting declare target specifiers to + // contain functions nested within the primary declare + // target function. + void Post(parser::OpenMPDeclareTargetConstruct &x) { + auto &spec{std::get(x.t)}; + if (parser::OmpObjectList * + objectList{parser::Unwrap(spec.u)}) { + markDeclTarForEachProgramInList(programUnits_, *objectList); + } else if (auto *clauseList{ + parser::Unwrap(spec.u)}) { + for (auto &clause : clauseList->v) { + if (auto *toClause{std::get_if(&clause.u)}) { + markDeclTarForEachProgramInList(programUnits_, toClause->v); + } else if (auto *linkClause{ + std::get_if(&clause.u)}) { + markDeclTarForEachProgramInList(programUnits_, linkClause->v); + } + } + + // The default "declare target" inside of a function case, we must + // create and generate an to extended-list, containing at minimum the + // current function + if (clauseList->v.empty()) { + if (auto *name = getNameFromProgramUnit(*currentProgramUnit_)) { + std::list list; + list.push_back(parser::OmpObject{ + parser::Designator{parser::DataRef{std::move(*name)}}}); + auto objList = parser::OmpObjectList{std::move(list)}; + markDeclTarForEachProgramInList(programUnits_, objList); + clauseList->v.push_back(parser::OmpClause::To{std::move(objList)}); + } + } + } + } + + bool Pre(parser::ProgramUnit &x) { + currentProgramUnit_ = &x; + if (auto *name = getNameFromProgramUnit(x)) + programUnits_[name->ToString()] = &x; + return true; + } + + parser::Name *getNameFromProgramUnit(parser::ProgramUnit 
&x) { + if (auto *func{parser::Unwrap(x.u)}) { + parser::FunctionStmt &Stmt = std::get<0>(func->t).statement; + return &std::get(Stmt.t); + } else if (auto *subr{parser::Unwrap(x.u)}) { + parser::SubroutineStmt &Stmt = std::get<0>(subr->t).statement; + return &std::get(Stmt.t); + } + return nullptr; + } + + void markDeclTarForEachProgramInList( + std::map programUnits, + parser::OmpObjectList &objList) { + auto existsInList = [](parser::OmpObjectList &objList, parser::Name name) { + for (auto &ompObject : objList.v) + if (auto *objName{parser::Unwrap(ompObject)}) + if (objName->ToString() == name.ToString()) + return true; + return false; + }; + + GatherCallRefs gatherer{}; + for (auto &ompObject : objList.v) { + if (auto *name{parser::Unwrap(ompObject)}) { + auto programUnit = programUnits.find(name->ToString()); + // something other than a subroutine or function, skip it + if (programUnit == programUnits.end()) + continue; + + parser::Walk(*programUnit->second, gatherer); + + // Currently using the Function Name rather than the CallRef name, + // unsure if these are interchangeable. However, ideally functions + // and subroutines should probably be parser::PorcedureDesignator's + // rather than parser::Designator's, but regular designators seem + // to be all that is utilised in the PFT definition for OmpObjects. 
+ for (auto v : gatherer.callNames) { + if (!existsInList(objList, v)) { + objList.v.push_back(parser::OmpObject{parser::Designator{ + parser::DataRef{std::move(*getNameFromProgramUnit( + *programUnits.find(v.ToString())->second))}}}); + } + } + + gatherer.callNames.clear(); + } + } + } + +private: + std::map programUnits_; + parser::ProgramUnit *currentProgramUnit_ = nullptr; + + parser::Messages &messages_; +}; + +bool FinalizeOMP(SemanticsContext &context, parser::Program &program) { + ImplicitDeclareTargetCapture impCap{context}; + Walk(program, impCap); + return !context.AnyFatalError(); +} + +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp --- a/flang/lib/Semantics/semantics.cpp +++ b/flang/lib/Semantics/semantics.cpp @@ -31,6 +31,7 @@ #include "check-select-type.h" #include "check-stop.h" #include "compute-offsets.h" +#include "finalize-omp.h" #include "mod-file.h" #include "resolve-labels.h" #include "resolve-names.h" @@ -170,6 +171,7 @@ ResolveNames(context, program, context.globalScope()); RewriteParseTree(context, program); ComputeOffsets(context, context.globalScope()); + FinalizeOMP(context, program); CheckDeclarations(context); StatementSemanticsPass1{context}.Walk(program); StatementSemanticsPass2 pass2{context}; diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 deleted file mode 100644 --- a/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 +++ /dev/null @@ -1,12 +0,0 @@ -! This test checks lowering of OpenMP declare target Directive. 
-
-// RUN: not flang-new -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-module mod1
-contains
-  subroutine sub()
-    integer :: x, y
-    // CHECK: not yet implemented: OpenMPDeclareTargetConstruct
-    !$omp declare target
-  end
-end module
diff --git a/flang/test/Lower/OpenMP/omp-declare-target.f90 b/flang/test/Lower/OpenMP/omp-declare-target.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declare-target.f90
@@ -0,0 +1,262 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! Check valid specification forms of declare target with functions
+! utilising device_type and to clauses as well as the default
+! zero clause declare target
+
+! CHECK-LABEL: func.func @_QPfunc_t_device()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION FUNC_T_DEVICE() RESULT(I)
+!$omp declare target to(FUNC_T_DEVICE) device_type(nohost)
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_T_DEVICE
+
+! CHECK-LABEL: func.func @_QPfunc_t_host()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(host)>{{.*}}
+FUNCTION FUNC_T_HOST() RESULT(I)
+!$omp declare target to(FUNC_T_HOST) device_type(host)
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_T_HOST
+
+! CHECK-LABEL: func.func @_QPfunc_t_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION FUNC_T_ANY() RESULT(I)
+!$omp declare target to(FUNC_T_ANY) device_type(any)
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_T_ANY
+
+! CHECK-LABEL: func.func @_QPfunc_default_t_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I)
+!$omp declare target to(FUNC_DEFAULT_T_ANY)
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_DEFAULT_T_ANY
+
+! CHECK-LABEL: func.func @_QPfunc_default_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION FUNC_DEFAULT_ANY() RESULT(I)
+!$omp declare target
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_DEFAULT_ANY
+
+! CHECK-LABEL: func.func @_QPfunc_default_extendedlist()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I)
+!$omp declare target(FUNC_DEFAULT_EXTENDEDLIST)
+  INTEGER :: I
+  I = 1
+END FUNCTION FUNC_DEFAULT_EXTENDEDLIST
+
+! CHECK-LABEL: func.func @_QPexist_on_both()
+! CHECK-NOT: {{.*}}attributes {omp.declare_target = #omp<device_type({{.*}})>{{.*}}
+FUNCTION EXIST_ON_BOTH() RESULT(I)
+  INTEGER :: I
+  I = 1
+END FUNCTION EXIST_ON_BOTH
+
+!! -----
+
+! Check valid specification forms of declare target with subroutines
+! utilising device_type and to clauses as well as the default
+! zero clause declare target
+
+! CHECK-LABEL: func.func @_QPsubr_t_device()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+SUBROUTINE SUBR_T_DEVICE()
+!$omp declare target to(SUBR_T_DEVICE) device_type(nohost)
+END
+
+! CHECK-LABEL: func.func @_QPsubr_t_host()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(host)>{{.*}}
+SUBROUTINE SUBR_T_HOST()
+!$omp declare target to(SUBR_T_HOST) device_type(host)
+END
+
+! CHECK-LABEL: func.func @_QPsubr_t_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+SUBROUTINE SUBR_T_ANY()
+!$omp declare target to(SUBR_T_ANY) device_type(any)
+END
+
+! CHECK-LABEL: func.func @_QPsubr_default_t_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+SUBROUTINE SUBR_DEFAULT_T_ANY()
+!$omp declare target to(SUBR_DEFAULT_T_ANY)
+END
+
+! CHECK-LABEL: func.func @_QPsubr_default_any()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+SUBROUTINE SUBR_DEFAULT_ANY()
+!$omp declare target
+END
+
+! CHECK-LABEL: func.func @_QPsubr_default_extendedlist()
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST()
+!$omp declare target(SUBR_DEFAULT_EXTENDEDLIST)
+END
+
+! CHECK-LABEL: func.func @_QPsubr_exist_on_both()
+! CHECK-NOT: {{.*}}attributes {omp.declare_target = #omp<device_type({{.*}})>{{.*}}
+SUBROUTINE SUBR_EXIST_ON_BOTH()
+END
+
+!! -----
+
+! Check declare target in conjunction with implicitly
+! invoked functions, this tests the declare target
+! implicit capture pass within Flang. Functions
+! invoked within an explicit declare target function
+! are marked as declare target with the caller's
+! device_type clause
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED(TOGGLE) RESULT(K)
+  INTEGER :: I, J, K
+  LOGICAL :: TOGGLE
+  I = 10
+  J = 5
+  IF (TOGGLE) THEN
+    K = I
+  ELSE
+    K = J
+  END IF
+END FUNCTION IMPLICITLY_CAPTURED
+
+
+! CHECK-LABEL: func.func @_QPtarget_function
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION TARGET_FUNCTION(TOGGLE) RESULT(I)
+!$omp declare target
+  INTEGER :: I
+  LOGICAL :: TOGGLE
+  I = IMPLICITLY_CAPTURED(TOGGLE)
+END FUNCTION TARGET_FUNCTION
+
+!! -----
+
+! Check declare target in conjunction with implicitly
+! invoked functions, this tests the declare target
+! implicit capture pass within Flang. Functions
+! invoked within an explicit declare target function
+! are marked as declare target with the caller's
+! device_type clause, however, if they are found with
+! distinct device_type clauses i.e. host and nohost,
+! then they should be marked as any
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_twice
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_TWICE() RESULT(K)
+  INTEGER :: I
+  I = 10
+  K = I
+END FUNCTION IMPLICITLY_CAPTURED_TWICE
+
+! CHECK-LABEL: func.func @_QPtarget_function_twice_host
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(host)>{{.*}}
+FUNCTION TARGET_FUNCTION_TWICE_HOST() RESULT(I)
+!$omp declare target to(TARGET_FUNCTION_TWICE_HOST) device_type(host)
+  INTEGER :: I
+  I = IMPLICITLY_CAPTURED_TWICE()
+END FUNCTION TARGET_FUNCTION_TWICE_HOST
+
+! CHECK-LABEL: func.func @_QPtarget_function_twice_device
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION TARGET_FUNCTION_TWICE_DEVICE() RESULT(I)
+!$omp declare target to(TARGET_FUNCTION_TWICE_DEVICE) device_type(nohost)
+  INTEGER :: I
+  I = IMPLICITLY_CAPTURED_TWICE()
+END FUNCTION TARGET_FUNCTION_TWICE_DEVICE
+
+!! -----
+
+! Check declare target in conjunction with implicitly
+! invoked functions, this tests the declare target
+! implicit capture pass within Flang. A slightly more
+! complex test checking functions are marked implicitly
+! appropriately.
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_nest
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_NEST() RESULT(K)
+  INTEGER :: I
+  I = 10
+  K = I
+END FUNCTION IMPLICITLY_CAPTURED_NEST
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_one
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_ONE() RESULT(K)
+  K = IMPLICITLY_CAPTURED_NEST()
+END FUNCTION IMPLICITLY_CAPTURED_ONE
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_two
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_TWO() RESULT(K)
+  INTEGER :: I
+  I = 10
+  K = I
+END FUNCTION IMPLICITLY_CAPTURED_TWO
+
+! CHECK-LABEL: func.func @_QPtarget_function_test
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION TARGET_FUNCTION_TEST() RESULT(J)
+!$omp declare target to(TARGET_FUNCTION_TEST) device_type(nohost)
+  INTEGER :: I, J
+  I = IMPLICITLY_CAPTURED_ONE()
+  J = IMPLICITLY_CAPTURED_TWO() + I
+END FUNCTION TARGET_FUNCTION_TEST
+
+!! -----
+
+! Check declare target in conjunction with implicitly
+! invoked functions, this tests the declare target
+! implicit capture pass within Flang. A slightly more
+! complex test checking functions are marked implicitly
+! appropriately.
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_nest_twice
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_NEST_TWICE() RESULT(K)
+  INTEGER :: I
+  I = 10
+  K = I
+END FUNCTION IMPLICITLY_CAPTURED_NEST_TWICE
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_one_twice
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_ONE_TWICE() RESULT(K)
+  K = IMPLICITLY_CAPTURED_NEST_TWICE()
+END FUNCTION IMPLICITLY_CAPTURED_ONE_TWICE
+
+! CHECK-LABEL: func.func @_QPimplicitly_captured_two_twice
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(any)>{{.*}}
+FUNCTION IMPLICITLY_CAPTURED_TWO_TWICE() RESULT(K)
+  INTEGER :: I
+  I = 10
+  K = I
+END FUNCTION IMPLICITLY_CAPTURED_TWO_TWICE
+
+! CHECK-LABEL: func.func @_QPtarget_function_test_device
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(nohost)>{{.*}}
+FUNCTION TARGET_FUNCTION_TEST_DEVICE() RESULT(J)
+  !$omp declare target to(TARGET_FUNCTION_TEST_DEVICE) device_type(nohost)
+  INTEGER :: I, J
+  I = IMPLICITLY_CAPTURED_ONE_TWICE()
+  J = IMPLICITLY_CAPTURED_TWO_TWICE() + I
+END FUNCTION TARGET_FUNCTION_TEST_DEVICE
+
+! CHECK-LABEL: func.func @_QPtarget_function_test_host
+! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp<device_type(host)>{{.*}}
+FUNCTION TARGET_FUNCTION_TEST_HOST() RESULT(J)
+  !$omp declare target to(TARGET_FUNCTION_TEST_HOST) device_type(host)
+  INTEGER :: I, J
+  I = IMPLICITLY_CAPTURED_ONE_TWICE()
+  J = IMPLICITLY_CAPTURED_TWO_TWICE() + I
+END FUNCTION TARGET_FUNCTION_TEST_HOST
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -21,6 +21,10 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+namespace mlir::omp {
+enum class DeclareTargetDeviceType : uint32_t;
+} // namespace mlir::omp
+
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -28,6 +28,13 @@
   let cppNamespace = "::mlir::omp";
   let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    // Helper functions for assigning a DeclareTargetDeviceType Attribute to functions
+    static void setDeclareTarget(Operation *func, mlir::omp::DeclareTargetDeviceType deviceType);
+    static bool isDeclareTarget(Operation *func);
+    static mlir::omp::DeclareTargetDeviceType getDeclareTargetDeviceType(Operation *func);
+  }];
 }
 
 // OmpCommon requires definition of OpenACC_Dialect.
@@ -77,6 +84,28 @@
 def OpenMP_PointerLikeType : TypeAlias<OpenMP_PointerLikeTypeInterface,
 	"OpenMP-compatible variable type">;
 
+//===----------------------------------------------------------------------===//
+// 2.12.7 Declare Target Directive
+//===----------------------------------------------------------------------===//
+
+def DeviceTypeAny : I32EnumAttrCase<"any", 0>;
+def DeviceTypeHost : I32EnumAttrCase<"host", 1>;
+def DeviceTypeNoHost : I32EnumAttrCase<"nohost", 2>;
+
+def DeclareTargetDeviceType : I32EnumAttr<
+    "DeclareTargetDeviceType",
+    "device_type clause",
+    [DeviceTypeAny, DeviceTypeHost, DeviceTypeNoHost]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::omp";
+}
+
+def DeclareTargetDeviceTypeAttr : EnumAttr<OpenMP_Dialect,
+                                           DeclareTargetDeviceType,
+                                           "device_type"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1421,6 +1421,37 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// OpenMPDialect helper functions
+//===----------------------------------------------------------------------===//
+
+/// Attaches (or overwrites) the `omp.declare_target` attribute on `func`,
+/// recording `deviceType`.
+void OpenMPDialect::setDeclareTarget(
+    Operation *func, mlir::omp::DeclareTargetDeviceType deviceType) {
+  func->setAttr("omp.declare_target",
+                mlir::omp::DeclareTargetDeviceTypeAttr::get(func->getContext(),
+                                                            deviceType));
+}
+
+/// Returns true if `func` carries an `omp.declare_target` attribute.
+bool OpenMPDialect::isDeclareTarget(Operation *func) {
+  return func->hasAttr("omp.declare_target");
+}
+
+/// Returns the device type recorded on `func`. If `func` has no
+/// `omp.declare_target` attribute of the expected type, a value-initialized
+/// DeclareTargetDeviceType (`any`) is returned, so callers should check
+/// isDeclareTarget() first to distinguish the two cases.
+mlir::omp::DeclareTargetDeviceType
+OpenMPDialect::getDeclareTargetDeviceType(Operation *func) {
+  if (auto declTar =
+          func->getAttrOfType<mlir::omp::DeclareTargetDeviceTypeAttr>(
+              "omp.declare_target"))
+    return declTar.getValue();
+  return {};
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/attr.mlir b/mlir/test/Dialect/OpenMP/attr.mlir
--- a/mlir/test/Dialect/OpenMP/attr.mlir
+++ b/mlir/test/Dialect/OpenMP/attr.mlir
@@ -29,3 +29,23 @@
 
 // CHECK: module attributes {omp.flags = #omp.flags} {
 module attributes {omp.flags = #omp.flags} {}
+
+// ----
+
+// CHECK-LABEL: func @omp_decl_tar_host
+// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp<device_type(host)>} {
+func.func @omp_decl_tar_host() -> () attributes {omp.declare_target = #omp<device_type(host)>} {
+  return
+}
+
+// CHECK-LABEL: func @omp_decl_tar_nohost
+// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp<device_type(nohost)>} {
+func.func @omp_decl_tar_nohost() -> () attributes {omp.declare_target = #omp<device_type(nohost)>} {
+  return
+}
+
+// CHECK-LABEL: func @omp_decl_tar_any
+// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp<device_type(any)>} {
+func.func @omp_decl_tar_any() -> () attributes {omp.declare_target = #omp<device_type(any)>} {
+  return
+}