Index: llvm/include/llvm/Frontend/Directive/DirectiveBase.td
===================================================================
--- llvm/include/llvm/Frontend/Directive/DirectiveBase.td
+++ llvm/include/llvm/Frontend/Directive/DirectiveBase.td
@@ -48,6 +48,21 @@
   string clauseEnumSetClass = "";
 }
 
+// Information about a value accepted by an enum-like clause (e.g. proc_bind).
+class ClauseVal<string n, int v, bit uv> {
+  // Name of the clause value (the string matched by get<Clause>Kind).
+  string name = n;
+
+  // Integer value of the generated C++ enumerator for this clause value.
+  int value = v;
+
+  // Can the user specify this value? Only those get an MLIR StrEnumAttrCase.
+  bit isUserValue = uv;
+
+  // Set on the clause value that get<Clause>Kind returns for unknown strings.
+  bit isDefault = 0;
+}
+
 // Information about a specific clause.
 class Clause<string c> {
   // Name of the clause.
@@ -62,11 +77,14 @@
   // Optional class holding value of the clause in flang AST.
   string flangClass = ?;
 
+  // List of allowed clause values; non-empty only for enum-like clauses.
+  list<ClauseVal> allowedClauseValues = [];
+
   // Is clause implicit? If clause is set as implicit, the default kind will
   // be return in getClauseKind instead of their own kind.
   bit isImplicit = 0;
 
-  // Set directive used by default when unknown. Function returning the kind
+  // Set clause used by default when unknown. Function returning the kind
   // of enumeration will use this clause as the default.
   bit isDefault = 0;
 }
Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -60,8 +60,21 @@
 def OMPC_CopyPrivate : Clause<"copyprivate"> {
   let clangClass = "OMPCopyprivateClause";
 }
+// proc_bind values: ClauseVal<spelling, enumerator value, isUserValue>.
+def OMP_PROC_BIND_master : ClauseVal<"master",2,1> {}
+def OMP_PROC_BIND_close : ClauseVal<"close",3,1> {}
+def OMP_PROC_BIND_spread : ClauseVal<"spread",4,1> {}
+def OMP_PROC_BIND_default : ClauseVal<"default",5,0> {}
+def OMP_PROC_BIND_unknown : ClauseVal<"unknown",6,0> { let isDefault = 1; }
 def OMPC_ProcBind : Clause<"proc_bind"> {
   let clangClass = "OMPProcBindClause";
+  let allowedClauseValues = [
+    OMP_PROC_BIND_master,
+    OMP_PROC_BIND_close,
+    OMP_PROC_BIND_spread,
+    OMP_PROC_BIND_default,
+    OMP_PROC_BIND_unknown
+  ];
 }
 def OMPC_Schedule : Clause<"schedule"> { let clangClass = "OMPScheduleClause"; }
 def OMPC_Ordered : Clause<"ordered"> { let clangClass = "OMPOrderedClause"; }
Index: llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -68,16 +68,6 @@
   constexpr auto Enum = omp::DefaultKind::Enum;
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
 
-/// IDs for the different proc bind kinds.
-enum class ProcBindKind {
-#define OMP_PROC_BIND_KIND(Enum, Str, Value) Enum = Value,
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
-};
-
-#define OMP_PROC_BIND_KIND(Enum, ...)                                          \
-  constexpr auto Enum = omp::ProcBindKind::Enum;
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
-
 /// IDs for all omp runtime library ident_t flag encodings (see
 /// their defintion in openmp/runtime/src/kmp.h).
 enum class IdentFlag {
Index: llvm/utils/TableGen/DirectiveEmitter.cpp
===================================================================
--- llvm/utils/TableGen/DirectiveEmitter.cpp
+++ llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -197,6 +197,40 @@
   }
 }
 
+// Emit an enum class <Clause>Kind for each clause with allowedClauseValues and
+// always queue its get<Clause>Kind(StringRef) declaration in EnumHelperFuncs.
+void GenerateEnumClauseVal(const std::vector<Record *> &Records,
+                           raw_ostream &OS, StringRef Prefix,
+                           StringRef CppNamespace, std::string &EnumHelperFuncs,
+                           bool MakeEnumAvailableInNamespace) {
+  for (const auto &R : Records) {
+    const auto &ClauseVals = R->getValueAsListOfDefs("allowedClauseValues");
+    if (ClauseVals.empty())
+      continue;
+
+    OS << "\n";
+    std::string enumName{(llvm::Twine(R->getName().drop_front(Prefix.size())) +
+                          llvm::Twine("Kind"))
+                             .str()};
+    OS << "enum class " << enumName << " {\n";
+    for (const auto &CV : ClauseVals) {
+      const auto Value = CV->getValueAsInt("value");
+      OS << " " << CV->getName() << "=" << Value << ",\n";
+    }
+    OS << "};\n";
+
+    if (MakeEnumAvailableInNamespace) {
+      OS << "\n";
+      for (const auto &CV : ClauseVals) {
+        OS << "constexpr auto " << CV->getName() << " = "
+           << "llvm::" << CppNamespace << "::" << enumName
+           << "::" << CV->getName() << ";\n";
+      }
+    }
+    EnumHelperFuncs += enumName + " get" + enumName + "(StringRef);\n";
+  }
+}
+
 // Generate the declaration section for the enumeration in the directive
 // language
 void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -239,6 +273,12 @@
   const auto &Clauses = Records.getAllDerivedDefinitions("Clause");
   GenerateEnumClass(Clauses, OS, "Clause", DirLang.getClausePrefix(), DirLang);
 
+  // Emit an enum class for each clause that lists allowed ClauseVal values.
+  std::string EnumHelperFuncs;
+  GenerateEnumClauseVal(Clauses, OS, DirLang.getClausePrefix(),
+                        DirLang.getCppNamespace(), EnumHelperFuncs,
+                        DirLang.hasMakeEnumAvailableInNamespace());
+
   // Generic function signatures
   OS << "\n";
   OS << "// Enumeration helper functions\n";
@@ -258,6 +298,10 @@
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
+  if (!EnumHelperFuncs.empty()) {
+    OS << EnumHelperFuncs;
+    OS << "\n";
+  }
 
   // Closing namespaces
   for (auto Ns : llvm::reverse(Namespaces))
@@ -302,7 +346,7 @@
       });
 
   if (DefaultIt == Records.end()) {
-    PrintError("A least one " + Enum + " must be defined as default.");
+    PrintError("At least one " + Enum + " must be defined as default.");
     return;
   }
 
@@ -327,6 +371,47 @@
   OS << "}\n";
 }
 
+// Generate the implementation of get<Clause>Kind(StringRef Str) for each
+// clause with allowed values. Enumerators are qualified with their enum class
+// so the generated code does not rely on makeEnumAvailableInNamespace.
+void GenerateGetKindClauseVal(const std::vector<Record *> &Records,
+                              raw_ostream &OS, StringRef Prefix,
+                              StringRef Namespace) {
+  for (const auto &R : Records) {
+    const auto &ClauseVals = R->getValueAsListOfDefs("allowedClauseValues");
+    if (ClauseVals.empty())
+      continue;
+
+    auto DefaultIt =
+        std::find_if(ClauseVals.begin(), ClauseVals.end(), [](Record *CV) {
+          return CV->getValueAsBit("isDefault");
+        });
+
+    if (DefaultIt == ClauseVals.end()) {
+      // Report the problem and keep checking the remaining clauses.
+      PrintError("At least one val in Clause " + R->getName() +
+                 " must be defined as default.");
+      continue;
+    }
+
+    const auto DefaultName = (*DefaultIt)->getName();
+    std::string EnumName{(llvm::Twine(R->getName().drop_front(Prefix.size())) +
+                          llvm::Twine("Kind"))
+                             .str()};
+    OS << "\n";
+    OS << EnumName << " llvm::" << Namespace << "::get" << EnumName
+       << "(llvm::StringRef Str) {\n";
+    OS << " return llvm::StringSwitch<" << EnumName << ">(Str)\n";
+    for (const auto &CV : ClauseVals) {
+      const auto Name = CV->getValueAsString("name");
+      OS << " .Case(\"" << Name << "\"," << EnumName << "::" << CV->getName()
+         << ")\n";
+    }
+    OS << " .Default(" << EnumName << "::" << DefaultName << ");\n";
+    OS << "}\n";
+  }
+}
+
 void GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
                                      raw_ostream &OS, StringRef DirectiveName,
                                      DirectiveLanguage &DirLang,
@@ -563,6 +648,10 @@
   // getClauseName(Clause Kind)
   GenerateGetName(Clauses, OS, "Clause", DirLang, DirLang.getClausePrefix());
 
+  // get<Clause>Kind(StringRef Str) for clauses with allowed values
+  GenerateGetKindClauseVal(Clauses, OS, DirLang.getClausePrefix(),
+                           DirLang.getCppNamespace());
+
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(Directives, OS, DirLang);
 }
Index: mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenMP/OMP.td)
+mlir_tablegen(OmpCommon.td --gen-directive-decl)
+add_public_tablegen_target(omp_common_td)
+
 set(LLVM_TARGET_DEFINITIONS OpenMPOps.td)
 mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
Index: mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
===================================================================
--- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -15,6 +15,7 @@
 #define OPENMP_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/Dialect/OpenMP/OmpCommon.td"
 
 def OpenMP_Dialect : Dialect {
   let name = "omp";
@@ -42,18 +43,6 @@
   let cppNamespace = "::mlir::omp";
 }
 
-// Possible values for the proc_bind clause
-def ClauseProcMaster : StrEnumAttrCase<"master">;
-def ClauseProcClose : StrEnumAttrCase<"close">;
-def ClauseProcSpread : StrEnumAttrCase<"spread">;
-
-def ClauseProcBind : StrEnumAttr<
-    "ClauseProcBind",
-    "procbind clause",
-    [ClauseProcMaster, ClauseProcClose, ClauseProcSpread]> {
-  let cppNamespace = "::mlir::omp";
-}
-
 def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
   let summary = "parallel construct";
   let description = [{
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -456,6 +456,9 @@
   // Parallel operation is created with some default options for now.
   llvm::Value *ifCond = nullptr;
   llvm::Value *numThreads = nullptr;
+  llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
+  if (auto bind = cast<mlir::omp::ParallelOp>(opInst).proc_bind_val())
+    pbKind = llvm::omp::getProcBindKind(bind.getValue());
   bool isCancellable = false;
   // TODO: Determine the actual alloca insertion point, e.g., the function
   // entry or the alloca insertion point as provided by the body callback
@@ -463,7 +466,7 @@
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
   builder.restoreIP(ompBuilder->CreateParallel(
       builder, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads,
-      llvm::omp::OMP_PROC_BIND_default, isCancellable));
+      pbKind, isCancellable));
   return success();
 }
 
Index: mlir/test/Target/openmp-llvm.mlir
===================================================================
--- mlir/test/Target/openmp-llvm.mlir
+++ mlir/test/Target/openmp-llvm.mlir
@@ -78,3 +78,34 @@
 // CHECK-LABEL: omp.par.region2:
 // CHECK: call void @body(i64 43)
 // CHECK: br label %omp.par.pre_finalize
+
+// CHECK-LABEL: define void @test_omp_parallel_3()
+llvm.func @test_omp_parallel_3() -> () {
+  // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_1]], i32 2)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_1:.*]] to {{.*}}
+  omp.parallel proc_bind(master) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_2]], i32 3)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_2:.*]] to {{.*}}
+  omp.parallel proc_bind(close) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_3:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_3]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_3:.*]] to {{.*}}
+  omp.parallel proc_bind(spread) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
Index: mlir/tools/mlir-tblgen/CMakeLists.txt
===================================================================
--- mlir/tools/mlir-tblgen/CMakeLists.txt
+++ mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -14,6 +14,7 @@
   OpDocGen.cpp
   OpFormatGen.cpp
   OpInterfacesGen.cpp
+  OpenMPCommonGen.cpp
   PassGen.cpp
   PassDocGen.cpp
   RewriterGen.cpp
Index: mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
===================================================================
--- /dev/null
+++ mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
@@ -0,0 +1,74 @@
+//===========- OpenMPCommonGen.cpp - OpenMP common info generator -===========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// OpenMPCommonGen generates utility information from the single OpenMP source
+// of truth in llvm/include/llvm/Frontend/OpenMP (OMP.td).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using llvm::raw_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringRef;
+using llvm::Twine;
+
+// Emits, for each Clause with allowedClauseValues, a StrEnumAttrCase per user
+// value and a "Clause<Name>" StrEnumAttr (<Name>: record name minus the
+// clausePrefix). Returns true on error (TableGen convention), else false.
+static bool emitDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  const auto &directiveLanguages =
+      recordKeeper.getAllDerivedDefinitions("DirectiveLanguage");
+  if (directiveLanguages.size() != 1) {
+    llvm::PrintError("A single definition of DirectiveLanguage is needed.");
+    return true;
+  }
+  StringRef prefix = directiveLanguages[0]->getValueAsString("clausePrefix");
+
+  for (const Record *r : recordKeeper.getAllDerivedDefinitions("Clause")) {
+    const auto &clauseVals = r->getValueAsListOfDefs("allowedClauseValues");
+    if (clauseVals.empty())
+      continue;
+
+    // Views the record's own name; a StringRef into a temporary would dangle.
+    StringRef enumName = r->getName().drop_front(prefix.size());
+    std::vector<std::string> cvDefs;
+    for (const Record *cv : clauseVals) {
+      if (!cv->getValueAsBit("isUserValue"))
+        continue;
+      StringRef name = cv->getValueAsString("name");
+      std::string cvDef = (Twine("Clause") + enumName + name).str();
+      os << "def " << cvDef << " : StrEnumAttrCase<\"" << name << "\">;\n";
+      cvDefs.push_back(std::move(cvDef));
+    }
+
+    os << "def Clause" << enumName << ": StrEnumAttr<\n";
+    os << " \"Clause" << enumName << "\",\n";
+    os << " \"" << enumName << " Clause\",\n";
+    os << " [" << llvm::join(cvDefs, ",") << "]> {\n";
+    os << " let cppNamespace = \"::mlir::omp\";\n";
+    os << "}\n";
+  }
+  return false;
+}
+
+// Registers the generator to mlir-tblgen.
+static mlir::GenRegistration genDirectiveDecls(
+    "gen-directive-decl", "Generate declarations for directives (OpenMP etc.)",
+    [](const RecordKeeper &records, raw_ostream &os) -> bool {
+      // Thin adapter: forward the parsed records and output stream as-is.
+      return emitDecls(records, os);
+    });