diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -284,5 +284,5 @@ ]; } - + #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2184,6 +2184,108 @@ converter.bindSymbol(sym, symThreadprivateExv); } +void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + std::vector symbols; + auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList) { + for (const auto &ompObject : objList.v) { + Fortran::common::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + if (const Fortran::parser::Name *name = + getDesignatorNameIfDataRef(designator)) { + symbols.push_back(*name->symbol); + } + }, + [&](const Fortran::parser::Name &name) { + symbols.push_back(*name.symbol); + }}, + ompObject.u); + } + }; + + const auto &spec{std::get( + declareTargetConstruct.t)}; + auto mod = converter.getFirOpBuilder().getModule(); + + // The default capture type + auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; + if (const auto *objectList{ + Fortran::parser::Unwrap(spec.u)}) { + // Case: declare target(func, var1, var2) + findFuncAndVarSyms(*objectList); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: declare target, implicit capture of function + symbols.push_back(eval.getOwningProcedure()->getSubprogramSymbol()); + } + + for (const auto &clause : clauseList->v) { + if (const auto *toClause{ + std::get_if(&clause.u)}) { + // Case: declare target to(func, var1, var2)... + findFuncAndVarSyms(toClause->v); + } else if (const auto *linkClause{ + std::get_if( + &clause.u)}) { + // Case: declare target link(var1, var2)... + TODO(converter.getCurrentLocation(), + "the link clause is currently unsupported"); + } else if (const auto *deviceClause{ + std::get_if( + &clause.u)}) { + // Case: declare target ... device_type(any | host | nohost) + deviceType = deviceClause->v.v; + } + } + } + + for (auto sym : symbols) { + auto *op = mod.lookupSymbol(converter.mangleName(sym)); + + // TODO: Remove this cast and TODO assert when global data and link are + // supported + mlir::func::FuncOp fOp = mlir::dyn_cast(op); + if (!fOp) + TODO(converter.getCurrentLocation(), + "only subroutines and functions are currently supported"); + + // The function already has a declare target applied to it, very + // likely through implicit capture (usage in another declare target + // function/subroutine). It should be marked as any if it has been + // assigned both host and nohost, else we skip, as there is no + // change + if (mlir::omp::OpenMPDialect::isDeclareTarget(fOp)) { + auto currentDeclTar = + mlir::omp::OpenMPDialect::getDeclareTargetDeviceType(fOp); + if ((currentDeclTar == mlir::omp::DeclareTargetDeviceType::nohost && + deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Host) || + (currentDeclTar == mlir::omp::DeclareTargetDeviceType::host && + deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Nohost)) { + mlir::omp::OpenMPDialect::setDeclareTarget( + op, mlir::omp::DeclareTargetDeviceType::any); + } + + continue; + } + + if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Nohost) { + mlir::omp::OpenMPDialect::setDeclareTarget( + op, mlir::omp::DeclareTargetDeviceType::nohost); + } else if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Host) { + mlir::omp::OpenMPDialect::setDeclareTarget( + op, mlir::omp::DeclareTargetDeviceType::host); + } else if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Any) { + mlir::omp::OpenMPDialect::setDeclareTarget( + op, mlir::omp::DeclareTargetDeviceType::any); + } + } +} + void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -2206,8 +2308,7 @@ }, [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { - TODO(converter.getCurrentLocation(), - "OpenMPDeclareTargetConstruct"); + handleDeclareTarget(converter, eval, declareTargetConstruct); }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp --- a/flang/lib/Semantics/rewrite-parse-tree.cpp +++ b/flang/lib/Semantics/rewrite-parse-tree.cpp @@ -16,11 +16,29 @@ #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include +#include +#include namespace Fortran::semantics { using namespace parser::literals; +class GatherCallRefs { +public: + GatherCallRefs() {} + + // Default action for a parse tree node is to visit children. + template bool Pre(T &) { return true; } + template void Post(T &) {} + + void Post(parser::Call &call) { + if (std::holds_alternative(std::get<0>(call.t).u)) + callNames.push_back(std::get(std::get<0>(call.t).u)); + } + + std::list callNames; +}; + /// Convert misidentified statement functions to array element assignments /// or pointer-valued function result assignments. /// Convert misidentified format expressions to namelist group names. @@ -44,6 +62,12 @@ void Post(parser::ReadStmt &); void Post(parser::WriteStmt &); + // Related to rewriting declare target specifiers to + // contain functions nested within the primary declare + // target function. + void Post(parser::OpenMPDeclareTargetConstruct &); + bool Pre(parser::ProgramUnit &); + // Name resolution yet implemented: // TODO: Can some/all of these now be enabled? bool Pre(parser::EquivalenceStmt &) { return false; } @@ -68,6 +92,9 @@ bool errorOnUnresolvedName_{true}; parser::Messages &messages_; std::list stmtFuncsToConvert_; + + std::map programUnits_; + parser::ProgramUnit *currentProgramUnit_ = nullptr; }; // Check that name has been resolved to a symbol @@ -172,6 +199,95 @@ FixMisparsedUntaggedNamelistName(x); } +parser::Name *getNameFromProgramUnit(parser::ProgramUnit &x) { + if (auto *func{parser::Unwrap(x.u)}) { + parser::FunctionStmt &Stmt = std::get<0>(func->t).statement; + return &std::get(Stmt.t); + } else if (auto *subr{parser::Unwrap(x.u)}) { + parser::SubroutineStmt &Stmt = std::get<0>(subr->t).statement; + return &std::get(Stmt.t); + } + + return nullptr; +} + +bool RewriteMutator::Pre(parser::ProgramUnit &x) { + currentProgramUnit_ = &x; + if (auto *name = getNameFromProgramUnit(x)) + programUnits_[name->ToString()] = &x; + return true; +} + +void markForEachProgramInList( + std::map programUnits, + parser::OmpObjectList &objList) { + auto existsInList = [](parser::OmpObjectList &objList, parser::Name name) { + for (auto &ompObject : objList.v) + if (auto *objName{parser::Unwrap(ompObject)}) + if (objName->ToString() == name.ToString()) + return true; + return false; + }; + + GatherCallRefs gatherer{}; + for (auto &ompObject : objList.v) { + if (auto *name{parser::Unwrap(ompObject)}) { + auto programUnit = programUnits.find(name->ToString()); + // something other than a subroutine or function, skip it + if (programUnit == programUnits.end()) + continue; + + parser::Walk(*programUnit->second, gatherer); + + // Currently using the Function Name rather than the CallRef name, + // unsure if these are interchangeable. However, ideally functions + // and subroutines should probably be parser::PorcedureDesignator's + // rather than parser::Designator's, but regular designators seem + // to be all that is utilised in the PFT definition for OmpObjects. + for (auto v : gatherer.callNames) { + if (!existsInList(objList, v)) { + objList.v.push_back(parser::OmpObject{parser::Designator{ + parser::DataRef{std::move(*getNameFromProgramUnit( + *programUnits.find(v.ToString())->second))}}}); + } + } + + gatherer.callNames.clear(); + } + } +} + +void RewriteMutator::Post(parser::OpenMPDeclareTargetConstruct &x) { + auto &spec{std::get(x.t)}; + if (parser::OmpObjectList * + objectList{parser::Unwrap(spec.u)}) { + markForEachProgramInList(programUnits_, *objectList); + } else if (auto *clauseList{parser::Unwrap(spec.u)}) { + for (auto &clause : clauseList->v) { + if (auto *toClause{std::get_if(&clause.u)}) { + markForEachProgramInList(programUnits_, toClause->v); + } else if (auto *linkClause{ + std::get_if(&clause.u)}) { + markForEachProgramInList(programUnits_, linkClause->v); + } + } + + // The default "declare target" inside of a function case, we must + // create and generate an to extended-list, containing at minimum the + // current function + if (clauseList->v.empty()) { + if (auto *name = getNameFromProgramUnit(*currentProgramUnit_)) { + std::list list; + list.push_back(parser::OmpObject{ + parser::Designator{parser::DataRef{std::move(*name)}}}); + auto objList = parser::OmpObjectList{std::move(list)}; + markForEachProgramInList(programUnits_, objList); + clauseList->v.push_back(parser::OmpClause::To{std::move(objList)}); + } + } + } +} + bool RewriteParseTree(SemanticsContext &context, parser::Program &program) { RewriteMutator mutator{context}; parser::Walk(program, mutator); diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 deleted file mode 100644 --- a/flang/test/Lower/OpenMP/Todo/omp-declare-target.f90 +++ /dev/null @@ -1,12 +0,0 @@ -! This test checks lowering of OpenMP declare target Directive. - -// RUN: not flang-new -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s - -module mod1 -contains - subroutine sub() - integer :: x, y - // CHECK: not yet implemented: OpenMPDeclareTargetConstruct - !$omp declare target - end -end module diff --git a/flang/test/Lower/OpenMP/omp-declare-target.f90 b/flang/test/Lower/OpenMP/omp-declare-target.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declare-target.f90 @@ -0,0 +1,262 @@ +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! Check specification valid forms of declare target with functions +! utilising device_type and to clauses as well as the default +! zero clause declare target + +! CHECK-LABEL: func.func @_QPfunc_t_device() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_T_DEVICE() RESULT(I) +!$omp declare target to(FUNC_T_DEVICE) device_type(nohost) + INTEGER :: I + I = 1 +END FUNCTION FUNC_T_DEVICE + +! CHECK-LABEL: func.func @_QPfunc_t_host() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_T_HOST() RESULT(I) +!$omp declare target to(FUNC_T_HOST) device_type(host) + INTEGER :: I + I = 1 +END FUNCTION FUNC_T_HOST + +! CHECK-LABEL: func.func @_QPfunc_t_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_T_ANY() RESULT(I) +!$omp declare target to(FUNC_T_ANY) device_type(any) + INTEGER :: I + I = 1 +END FUNCTION FUNC_T_ANY + +! CHECK-LABEL: func.func @_QPfunc_default_t_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I) +!$omp declare target to(FUNC_DEFAULT_T_ANY) + INTEGER :: I + I = 1 +END FUNCTION FUNC_DEFAULT_T_ANY + +! CHECK-LABEL: func.func @_QPfunc_default_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_DEFAULT_ANY() RESULT(I) +!$omp declare target + INTEGER :: I + I = 1 +END FUNCTION FUNC_DEFAULT_ANY + +! CHECK-LABEL: func.func @_QPfunc_default_extendedlist() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I) +!$omp declare target(FUNC_DEFAULT_EXTENDEDLIST) + INTEGER :: I + I = 1 +END FUNCTION FUNC_DEFAULT_EXTENDEDLIST + +! CHECK-LABEL: func.func @_QPexist_on_both() +! CHECK-NOT: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION EXIST_ON_BOTH() RESULT(I) + INTEGER :: I + I = 1 +END FUNCTION EXIST_ON_BOTH + +!! ----- + +! Check specification valid forms of declare target with subroutines +! utilising device_type and to clauses as well as the default +! zero clause declare target + +! CHECK-LABEL: func.func @_QPsubr_t_device() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_T_DEVICE() +!$omp declare target to(SUBR_T_DEVICE) device_type(nohost) +END + +! CHECK-LABEL: func.func @_QPsubr_t_host() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_T_HOST() +!$omp declare target to(SUBR_T_HOST) device_type(host) +END + +! CHECK-LABEL: func.func @_QPsubr_t_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_T_ANY() +!$omp declare target to(SUBR_T_ANY) device_type(any) +END + +! CHECK-LABEL: func.func @_QPsubr_default_t_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_DEFAULT_T_ANY() +!$omp declare target to(SUBR_DEFAULT_T_ANY) +END + +! CHECK-LABEL: func.func @_QPsubr_default_any() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_DEFAULT_ANY() +!$omp declare target +END + +! CHECK-LABEL: func.func @_QPsubr_default_extendedlist() +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST() +!$omp declare target(SUBR_DEFAULT_EXTENDEDLIST) +END + +! CHECK-LABEL: func.func @_QPsubr_exist_on_both() +! CHECK-NOT: {{.*}}attributes {omp.declare_target = #omp{{.*}} +SUBROUTINE SUBR_EXIST_ON_BOTH() +END + +!! ----- + +! Check declare target inconjunction with implicitly +! invoked functions, this tests the declare target +! implicit capture pass within Flang. Functions +! invoked within an explicitly declare target function +! are marked as declare target with the callers +! device_type clause + +! CHECK-LABEL: func.func @_QPimplicitly_captured +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED(TOGGLE) RESULT(K) + INTEGER :: I, J, K + LOGICAL :: TOGGLE + I = 10 + J = 5 + IF (TOGGLE) THEN + K = I + ELSE + K = J + END IF +END FUNCTION IMPLICITLY_CAPTURED + + +! CHECK-LABEL: func.func @_QPtarget_function +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION(TOGGLE) RESULT(I) +!$omp declare target + INTEGER :: I + LOGICAL :: TOGGLE + I = IMPLICITLY_CAPTURED(TOGGLE) +END FUNCTION TARGET_FUNCTION + +!! ----- + +! Check declare target inconjunction with implicitly +! invoked functions, this tests the declare target +! implicit capture pass within Flang. Functions +! invoked within an explicitly declare target function +! are marked as declare target with the callers +! device_type clause, however, if they are found with +! distinct device_type clauses i.e. host and nohost, +! then they should be marked as any + +! CHECK-LABEL: func.func @_QPimplicitly_captured_twice +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_TWICE() RESULT(K) + INTEGER :: I + I = 10 + K = I +END FUNCTION IMPLICITLY_CAPTURED_TWICE + +! CHECK-LABEL: func.func @_QPtarget_function_twice_host +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION_TWICE_HOST() RESULT(I) +!$omp declare target to(TARGET_FUNCTION_TWICE_HOST) device_type(host) + INTEGER :: I + I = IMPLICITLY_CAPTURED_TWICE() +END FUNCTION TARGET_FUNCTION_TWICE_HOST + +! CHECK-LABEL: func.func @_QPtarget_function_twice_device +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION_TWICE_DEVICE() RESULT(I) +!$omp declare target to(TARGET_FUNCTION_TWICE_DEVICE) device_type(nohost) + INTEGER :: I + I = IMPLICITLY_CAPTURED_TWICE() +END FUNCTION TARGET_FUNCTION_TWICE_DEVICE + +!! ----- + +! Check declare target inconjunction with implicitly +! invoked functions, this tests the declare target +! implicit capture pass within Flang. A slightly more +! complex test checking functions are marked implicitly +! appropriately. + +! CHECK-LABEL: func.func @_QPimplicitly_captured_nest +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_NEST() RESULT(K) + INTEGER :: I + I = 10 + K = I +END FUNCTION IMPLICITLY_CAPTURED_NEST + +! CHECK-LABEL: func.func @_QPimplicitly_captured_one +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_ONE() RESULT(K) + K = IMPLICITLY_CAPTURED_NEST() +END FUNCTION IMPLICITLY_CAPTURED_ONE + +! CHECK-LABEL: func.func @_QPimplicitly_captured_two +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_TWO() RESULT(K) + INTEGER :: I + I = 10 + K = I +END FUNCTION IMPLICITLY_CAPTURED_TWO + +! CHECK-LABEL: func.func @_QPtarget_function_test +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION_TEST() RESULT(J) +!$omp declare target to(TARGET_FUNCTION_TEST) device_type(nohost) + INTEGER :: I, J + I = IMPLICITLY_CAPTURED_ONE() + J = IMPLICITLY_CAPTURED_TWO() + I +END FUNCTION TARGET_FUNCTION_TEST + +!! ----- + +! Check declare target inconjunction with implicitly +! invoked functions, this tests the declare target +! implicit capture pass within Flang. A slightly more +! complex test checking functions are marked implicitly +! appropriately. + +! CHECK-LABEL: func.func @_QPimplicitly_captured_nest_twice +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_NEST_TWICE() RESULT(K) + INTEGER :: I + I = 10 + K = I +END FUNCTION IMPLICITLY_CAPTURED_NEST_TWICE + +! CHECK-LABEL: func.func @_QPimplicitly_captured_one_twice +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_ONE_TWICE() RESULT(K) + K = IMPLICITLY_CAPTURED_NEST_TWICE() +END FUNCTION IMPLICITLY_CAPTURED_ONE_TWICE + +! CHECK-LABEL: func.func @_QPimplicitly_captured_two_twice +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION IMPLICITLY_CAPTURED_TWO_TWICE() RESULT(K) + INTEGER :: I + I = 10 + K = I +END FUNCTION IMPLICITLY_CAPTURED_TWO_TWICE + +! CHECK-LABEL: func.func @_QPtarget_function_test_device +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION_TEST_DEVICE() RESULT(J) + !$omp declare target to(TARGET_FUNCTION_TEST_DEVICE) device_type(nohost) + INTEGER :: I, J + I = IMPLICITLY_CAPTURED_ONE_TWICE() + J = IMPLICITLY_CAPTURED_TWO_TWICE() + I +END FUNCTION TARGET_FUNCTION_TEST_DEVICE + +! CHECK-LABEL: func.func @_QPtarget_function_test_host +! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp{{.*}} +FUNCTION TARGET_FUNCTION_TEST_HOST() RESULT(J) + !$omp declare target to(TARGET_FUNCTION_TEST_HOST) device_type(host) + INTEGER :: I, J + I = IMPLICITLY_CAPTURED_ONE_TWICE() + J = IMPLICITLY_CAPTURED_TWO_TWICE() + I +END FUNCTION TARGET_FUNCTION_TEST_HOST diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -21,6 +21,10 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +namespace mlir::omp { +enum class DeclareTargetDeviceType : uint32_t; +} // namespace mlir::omp + #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -28,6 +28,13 @@ let cppNamespace = "::mlir::omp"; let dependentDialects = ["::mlir::LLVM::LLVMDialect"]; let useDefaultAttributePrinterParser = 1; + + let extraClassDeclaration = [{ + // Helper functions for assigning a DeclareTargetDeviceType Attribute to functions + static void setDeclareTarget(Operation *func, mlir::omp::DeclareTargetDeviceType deviceType); + static bool isDeclareTarget(Operation *func); + static mlir::omp::DeclareTargetDeviceType getDeclareTargetDeviceType(Operation *func); + }]; } // OmpCommon requires definition of OpenACC_Dialect. @@ -77,6 +84,27 @@ def OpenMP_PointerLikeType : TypeAlias; +//===----------------------------------------------------------------------===// +// 2.12.7 Declare Target Directive +//===----------------------------------------------------------------------===// + +def DeviceTypeAny : I32EnumAttrCase<"any", 0>; +def DeviceTypeHost : I32EnumAttrCase<"host", 1>; +def DeviceTypeNoHost : I32EnumAttrCase<"nohost", 2>; + +def DeclareTargetDeviceType : I32EnumAttr< + "DeclareTargetDeviceType", + "device_type clause", + [DeviceTypeAny, DeviceTypeHost, DeviceTypeNoHost]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::omp"; +} + +def DeclareTargetDeviceTypeAttr : EnumAttr { + let assemblyFormat = "`(` $value `)`"; +} + //===----------------------------------------------------------------------===// // 2.6 parallel Construct //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1421,6 +1421,30 @@ return success(); } +//===----------------------------------------------------------------------===// +// OpenMPDialect helper functions +//===----------------------------------------------------------------------===// + +void OpenMPDialect::setDeclareTarget( + Operation *func, mlir::omp::DeclareTargetDeviceType deviceType) { + func->setAttr("omp.declare_target", + mlir::omp::DeclareTargetDeviceTypeAttr::get(func->getContext(), + deviceType)); +} + +bool OpenMPDialect::isDeclareTarget(Operation *func) { + return func->hasAttr("omp.declare_target"); +} + +mlir::omp::DeclareTargetDeviceType +OpenMPDialect::getDeclareTargetDeviceType(Operation *func) { + if (mlir::Attribute declTar = func->getAttr("omp.declare_target")) { + if (declTar.isa()) + return declTar.cast().getValue(); + } + return {}; +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/attr.mlir b/mlir/test/Dialect/OpenMP/attr.mlir --- a/mlir/test/Dialect/OpenMP/attr.mlir +++ b/mlir/test/Dialect/OpenMP/attr.mlir @@ -29,3 +29,23 @@ // CHECK: module attributes {omp.flags = #omp.flags} { module attributes {omp.flags = #omp.flags} {} + +// ---- + +// CHECK-LABEL: func @omp_decl_tar_host +// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp} { +func.func @omp_decl_tar_host() -> () attributes {omp.declare_target = #omp} { + return +} + +// CHECK-LABEL: func @omp_decl_tar_nohost +// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp} { +func.func @omp_decl_tar_nohost() -> () attributes {omp.declare_target = #omp} { + return +} + +// CHECK-LABEL: func @omp_decl_tar_any +// CHECK-SAME: {{.*}} attributes {omp.declare_target = #omp} { +func.func @omp_decl_tar_any() -> () attributes {omp.declare_target = #omp} { + return +}